From ad84c02704c6ef4bf1d35f924372acc2acfe32ed Mon Sep 17 00:00:00 2001 From: Hironsan Date: Tue, 8 Aug 2023 20:44:12 +0900 Subject: [PATCH] Fix weighted sequential strategy --- backend/examples/assignment/strategies.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/backend/examples/assignment/strategies.py b/backend/examples/assignment/strategies.py index 93419d46..22c7f1e0 100644 --- a/backend/examples/assignment/strategies.py +++ b/backend/examples/assignment/strategies.py @@ -45,16 +45,10 @@ class WeightedSequentialStrategy(BaseStrategy): def assign(self) -> List[Assignment]: assignments = [] - proba = np.array(self.weights) / 100 - counts = np.round(proba * self.dataset_size).astype(int) - reminder = self.dataset_size - sum(counts) - for i in np.random.choice(range(len(self.weights)), size=reminder, p=proba): - counts[i] += 1 - - start = 0 - for user, count in enumerate(counts): - assignments.extend([Assignment(user=user, example=example) for example in range(start, start + count)]) - start += count + ratio = list(np.round(np.cumsum(self.weights) / 100 * self.dataset_size).astype(int)) + ratio = [0] + ratio[:-1] + [self.dataset_size] + for user, (start, end) in enumerate(zip(ratio, ratio[1:])): # Todo: use itertools.pairwise + assignments.extend([Assignment(user=user, example=example) for example in range(start, end)]) return assignments