Browse Source

Fix weighted sequential strategy

pull/2261/head
Hironsan 1 year ago
parent
commit
ad84c02704
1 changed files with 4 additions and 10 deletions
  1. 14
      backend/examples/assignment/strategies.py

14
backend/examples/assignment/strategies.py

@ -45,16 +45,10 @@ class WeightedSequentialStrategy(BaseStrategy):
def assign(self) -> List[Assignment]:
assignments = []
proba = np.array(self.weights) / 100
counts = np.round(proba * self.dataset_size).astype(int)
reminder = self.dataset_size - sum(counts)
for i in np.random.choice(range(len(self.weights)), size=reminder, p=proba):
counts[i] += 1
start = 0
for user, count in enumerate(counts):
assignments.extend([Assignment(user=user, example=example) for example in range(start, start + count)])
start += count
ratio = list(np.round(np.cumsum(self.weights) / 100 * self.dataset_size).astype(int))
ratio = [0] + ratio[:-1] + [self.dataset_size]
for user, (start, end) in enumerate(zip(ratio, ratio[1:])): # Todo: use itertools.pairwise
assignments.extend([Assignment(user=user, example=example) for example in range(start, end)])
return assignments

Loading…
Cancel
Save