You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

85 lines
2.9 KiB

import abc
import dataclasses
import enum
import random
from typing import List
import numpy as np
@dataclasses.dataclass()
class Assignment:
    """One (user, example) pairing produced by an assignment strategy."""

    # Position of the user in the strategy's ``weights`` list.
    user: int
    # Index of the example handed to that user (an index into the dataset).
    example: int
class StrategyName(enum.Enum):
    """Identifiers accepted by ``create_assignment_strategy``."""

    # Explicit values mirror what ``enum.auto()`` assigns (1, 2, 3, ...).
    weighted_sequential = 1
    weighted_random = 2
    sampling_without_replacement = 3
def create_assignment_strategy(strategy_name: StrategyName, dataset_size: int, weights: List[int]) -> "BaseStrategy":
    """Build the strategy object that corresponds to ``strategy_name``.

    Args:
        strategy_name: Which strategy to construct.
        dataset_size: Number of examples to distribute; forwarded to the strategy.
        weights: Per-user weights (percentages); forwarded to the strategy.

    Returns:
        A freshly constructed ``BaseStrategy`` subclass instance.

    Raises:
        ValueError: If ``strategy_name`` matches none of the known strategies.
    """
    # (name, class) pairs are checked in order with ``==``, like an if/elif chain.
    registry = (
        (StrategyName.weighted_sequential, WeightedSequentialStrategy),
        (StrategyName.weighted_random, WeightedRandomStrategy),
        (StrategyName.sampling_without_replacement, SamplingWithoutReplacementStrategy),
    )
    for known_name, strategy_cls in registry:
        if strategy_name == known_name:
            return strategy_cls(dataset_size, weights)
    raise ValueError(f"Unknown strategy name: {strategy_name}")
class BaseStrategy(abc.ABC):
    """Abstract interface for strategies that map dataset examples to users."""

    @abc.abstractmethod
    def assign(self) -> List[Assignment]:
        """Return the list of ``Assignment`` objects produced by this strategy."""
class WeightedSequentialStrategy(BaseStrategy):
    """Split the dataset into contiguous runs, one run per user, sized by weight.

    User 0 receives examples ``0 .. c0-1``, user 1 the next ``c1`` examples and
    so on, where ``c_i`` is ``weights[i]`` percent of ``dataset_size`` rounded
    to an integer. The counts always add up to exactly ``dataset_size``.
    """

    def __init__(self, dataset_size: int, weights: List[int]):
        """
        Args:
            dataset_size: Number of examples to distribute (indices ``0 .. dataset_size-1``).
            weights: Per-user percentages; must sum to exactly 100.

        Raises:
            AssertionError: If ``weights`` does not sum to 100 (only while
                assertions are enabled, i.e. not under ``python -O``).
        """
        assert sum(weights) == 100, f"weights must sum to 100, got {sum(weights)}"
        self.dataset_size = dataset_size
        self.weights = weights

    def assign(self) -> List[Assignment]:
        """Return one ``Assignment`` per example, grouped contiguously by user."""
        proba = np.array(self.weights) / 100
        exact = proba * self.dataset_size
        counts = np.round(exact).astype(int)
        remainder = self.dataset_size - int(counts.sum())
        if remainder < 0:
            # np.round rounds halves to even (1.5 -> 2, 2.5 -> 2), so the rounded
            # counts can overshoot dataset_size (e.g. [50, 50] of 3 -> [2, 2]).
            # Take one example back from each of the -remainder users that were
            # rounded up the most; each of those has count >= 1, so no count
            # goes negative.
            for user in np.argsort(exact - counts, kind="stable")[:-remainder]:
                counts[user] -= 1
        else:
            # Hand out the examples lost to rounding down, drawing users by weight.
            for user in np.random.choice(len(self.weights), size=remainder, p=proba):
                counts[user] += 1

        assignments = []
        start = 0
        for user, count in enumerate(counts.tolist()):
            stop = start + count
            assignments.extend(Assignment(user=user, example=example) for example in range(start, stop))
            start = stop
        return assignments
class WeightedRandomStrategy(BaseStrategy):
    """Assign every example independently to a user drawn at random by weight.

    Each example ``0 .. dataset_size-1`` gets exactly one user; user ``i`` is
    drawn with probability ``weights[i] / 100``. Realized per-user counts
    therefore only approximate the weights.
    """

    def __init__(self, dataset_size: int, weights: List[int]):
        """
        Args:
            dataset_size: Number of examples to distribute.
            weights: Per-user percentages; must sum to exactly 100.

        Raises:
            AssertionError: If ``weights`` does not sum to 100 (only while
                assertions are enabled, i.e. not under ``python -O``).
        """
        assert sum(weights) == 100, f"weights must sum to 100, got {sum(weights)}"
        self.dataset_size = dataset_size
        self.weights = weights

    def assign(self) -> List[Assignment]:
        """Return one ``Assignment`` per example, in example order."""
        proba = np.array(self.weights) / 100
        users = np.random.choice(len(self.weights), size=self.dataset_size, p=proba)
        # tolist() converts numpy int64 scalars to plain ints, matching the
        # ``user: int`` field (numpy scalars are e.g. not JSON-serializable).
        return [Assignment(user=user, example=example) for example, user in enumerate(users.tolist())]
class SamplingWithoutReplacementStrategy(BaseStrategy):
    """Give each user its own random subset of the examples.

    User ``i`` receives ``int(dataset_size * weights[i] / 100)`` distinct
    examples, sampled without replacement from ``range(dataset_size)``.
    Each user's subset is drawn independently, so subsets of different users
    may overlap and the weights need not sum to 100.
    """

    def __init__(self, dataset_size: int, weights: List[int]):
        """
        Args:
            dataset_size: Number of examples available to sample from.
            weights: Per-user percentages; their sum must lie in
                ``[0, 100 * len(weights)]``.

        Raises:
            AssertionError: If the sum of ``weights`` is out of range (only
                while assertions are enabled, i.e. not under ``python -O``).
        """
        # NOTE(review): this checks only the sum. An individual weight above 100
        # or below 0 can make a user's count exceed dataset_size or go negative,
        # and then ``random.sample`` raises ValueError in assign().
        assert 0 <= sum(weights) <= 100 * len(weights)
        self.dataset_size = dataset_size
        self.weights = weights

    def assign(self) -> List[Assignment]:
        """Return each user's sampled examples, grouped by user in user order."""
        assignments = []
        for user, weight in enumerate(self.weights):
            # Multiply before dividing: ``dataset_size * weight`` is exact for ints,
            # and the single division is exact whenever the true result is whole.
            # The previous ``dataset_size * (weight / 100)`` under-counted through
            # float error, e.g. 100 * 0.29 == 28.999999999999996 -> 28.
            # int() still truncates toward zero, as before.
            count = int(self.dataset_size * weight / 100)
            examples = random.sample(range(self.dataset_size), count)
            assignments.extend(Assignment(user=user, example=example) for example in examples)
        return assignments