You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

82 lines
2.9 KiB

  1. import abc
  2. import dataclasses
  3. import enum
  4. import random
  5. from typing import List
  6. import numpy as np
  7. @dataclasses.dataclass
  8. class Assignment:
  9. user: int
  10. example: int
  11. class StrategyName(enum.Enum):
  12. weighted_sequential = enum.auto()
  13. weighted_random = enum.auto()
  14. sampling_without_replacement = enum.auto()
  15. def create_assignment_strategy(strategy_name: StrategyName, dataset_size: int, weights: List[int]) -> "BaseStrategy":
  16. if strategy_name == StrategyName.weighted_sequential:
  17. return WeightedSequentialStrategy(dataset_size, weights)
  18. elif strategy_name == StrategyName.weighted_random:
  19. return WeightedRandomStrategy(dataset_size, weights)
  20. elif strategy_name == StrategyName.sampling_without_replacement:
  21. return SamplingWithoutReplacementStrategy(dataset_size, weights)
  22. else:
  23. raise ValueError(f"Unknown strategy name: {strategy_name}")
  24. class BaseStrategy(abc.ABC):
  25. @abc.abstractmethod
  26. def assign(self) -> List[Assignment]:
  27. ...
  28. class WeightedSequentialStrategy(BaseStrategy):
  29. def __init__(self, dataset_size: int, weights: List[int]):
  30. if sum(weights) != 100:
  31. raise ValueError("Sum of weights must be 100")
  32. self.dataset_size = dataset_size
  33. self.weights = weights
  34. def assign(self) -> List[Assignment]:
  35. assignments = []
  36. cumsum = np.cumsum([0] + self.weights)
  37. ratio = np.round(cumsum / 100 * self.dataset_size).astype(int)
  38. for user, (start, end) in enumerate(zip(ratio, ratio[1:])): # Todo: use itertools.pairwise
  39. assignments.extend([Assignment(user=user, example=example) for example in range(start, end)])
  40. return assignments
  41. class WeightedRandomStrategy(BaseStrategy):
  42. def __init__(self, dataset_size: int, weights: List[int]):
  43. if sum(weights) != 100:
  44. raise ValueError("Sum of weights must be 100")
  45. self.dataset_size = dataset_size
  46. self.weights = weights
  47. def assign(self) -> List[Assignment]:
  48. proba = np.array(self.weights) / 100
  49. assignees = np.random.choice(range(len(self.weights)), size=self.dataset_size, p=proba)
  50. return [Assignment(user=user, example=example) for example, user in enumerate(assignees)]
  51. class SamplingWithoutReplacementStrategy(BaseStrategy):
  52. def __init__(self, dataset_size: int, weights: List[int]):
  53. if not (0 <= sum(weights) <= 100 * len(weights)):
  54. raise ValueError("Sum of weights must be between 0 and 100 x number of members")
  55. self.dataset_size = dataset_size
  56. self.weights = weights
  57. def assign(self) -> List[Assignment]:
  58. assignments = []
  59. proba = np.array(self.weights) / 100
  60. for user, p in enumerate(proba):
  61. count = int(self.dataset_size * p)
  62. examples = random.sample(range(self.dataset_size), count)
  63. assignments.extend([Assignment(user=user, example=example) for example in examples])
  64. return assignments