class WeightedSequentialStrategy(BaseStrategy):
    """Split a dataset into contiguous slices, one per user, sized by weight.

    ``weights[i]`` is the percentage of the dataset given to the user at
    index ``i``; the percentages must add up to 100.
    """

    def __init__(self, dataset_size: int, weights: List[int]):
        """Store the dataset size and the per-user percentage weights.

        Args:
            dataset_size: Number of examples to distribute.
            weights: Percentage share per user; must sum to 100.

        Raises:
            AssertionError: If ``weights`` does not sum to 100.
        """
        assert sum(weights) == 100
        self.dataset_size = dataset_size
        self.weights = weights

    def assign(self) -> List[Assignment]:
        """Assign every example index to exactly one user, in order.

        The slice boundaries are the rounded cumulative shares of the
        dataset. Rounding the cumulative boundaries (rather than each user's
        share independently) guarantees the slice sizes add up to exactly
        ``dataset_size``: independent rounding could overshoot (e.g. 3
        examples split 50/50 rounds to 2 + 2), which previously produced a
        negative remainder and a ``ValueError`` from ``np.random.choice``.
        The result is also deterministic for a given size and weights.

        Returns:
            One ``Assignment`` per example index in ``range(dataset_size)``,
            where ``user`` is the position of the weight in ``weights``.
        """
        # Cumulative percentages: 0, w0, w0 + w1, ..., 100.
        cumulative = np.cumsum([0, *self.weights])
        # The last boundary is 100 / 100 * dataset_size == dataset_size, so
        # the slices cover the dataset exactly once with no overlap.
        boundaries = np.round(cumulative / 100 * self.dataset_size).astype(int).tolist()

        assignments = []
        for user, (start, end) in enumerate(zip(boundaries, boundaries[1:])):
            assignments.extend(Assignment(user=user, example=example) for example in range(start, end))
        return assignments