|
|
@ -18,6 +18,27 @@ class BaseStrategy(abc.ABC): |
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
class WeightedSequentialStrategy(BaseStrategy): |
|
|
|
def __init__(self, dataset_size: int, weights: List[int]): |
|
|
|
assert sum(weights) == 100 |
|
|
|
self.dataset_size = dataset_size |
|
|
|
self.weights = weights |
|
|
|
|
|
|
|
def assign(self) -> List[Assignment]: |
|
|
|
assignments = [] |
|
|
|
proba = np.array(self.weights) / 100 |
|
|
|
counts = np.round(proba * self.dataset_size).astype(int) |
|
|
|
reminder = self.dataset_size - sum(counts) |
|
|
|
for i in np.random.choice(range(len(self.weights)), size=reminder, p=proba): |
|
|
|
counts[i] += 1 |
|
|
|
|
|
|
|
start = 0 |
|
|
|
for user, count in enumerate(counts): |
|
|
|
assignments.extend([Assignment(user=user, example=example) for example in range(start, start + count)]) |
|
|
|
start += count |
|
|
|
return assignments |
|
|
|
|
|
|
|
|
|
|
|
class WeightedRandomStrategy(BaseStrategy): |
|
|
|
def __init__(self, dataset_size: int, weights: List[int]): |
|
|
|
assert sum(weights) == 100 |
|
|
|