From 4d5b1858b41020e01246c648a52b4da204bd4edb Mon Sep 17 00:00:00 2001 From: Hironsan Date: Mon, 31 Jul 2023 14:11:59 +0900 Subject: [PATCH] Extract bulk_assign function --- backend/examples/assignment/strategies.py | 9 ++++-- backend/examples/assignment/usecase.py | 30 +++++++++++++++++++ backend/examples/views/assignment.py | 35 +++++++---------------- 3 files changed, 47 insertions(+), 27 deletions(-) create mode 100644 backend/examples/assignment/usecase.py diff --git a/backend/examples/assignment/strategies.py b/backend/examples/assignment/strategies.py index 2913eff4..93419d46 100644 --- a/backend/examples/assignment/strategies.py +++ b/backend/examples/assignment/strategies.py @@ -38,7 +38,8 @@ class BaseStrategy(abc.ABC): class WeightedSequentialStrategy(BaseStrategy): def __init__(self, dataset_size: int, weights: List[int]): - assert sum(weights) == 100 + if sum(weights) != 100: + raise ValueError("Sum of weights must be 100") self.dataset_size = dataset_size self.weights = weights @@ -59,7 +60,8 @@ class WeightedSequentialStrategy(BaseStrategy): class WeightedRandomStrategy(BaseStrategy): def __init__(self, dataset_size: int, weights: List[int]): - assert sum(weights) == 100 + if sum(weights) != 100: + raise ValueError("Sum of weights must be 100") self.dataset_size = dataset_size self.weights = weights @@ -71,7 +73,8 @@ class WeightedRandomStrategy(BaseStrategy): class SamplingWithoutReplacementStrategy(BaseStrategy): def __init__(self, dataset_size: int, weights: List[int]): - assert 0 <= sum(weights) <= 100 * len(weights) + if not (0 <= sum(weights) <= 100 * len(weights)): + raise ValueError("Sum of weights must be between 0 and 100 x number of members") self.dataset_size = dataset_size self.weights = weights diff --git a/backend/examples/assignment/usecase.py b/backend/examples/assignment/usecase.py new file mode 100644 index 00000000..8b6aa046 --- /dev/null +++ b/backend/examples/assignment/usecase.py @@ -0,0 +1,30 @@ +from django.shortcuts import get_object_or_404 + +from examples.assignment.strategies import StrategyName, create_assignment_strategy +from examples.assignment.workload import WorkloadAllocation +from examples.models import Assignment +from projects.models import Member, Project + + +def bulk_assign(project_id: int, workload_allocation: WorkloadAllocation, strategy_name: StrategyName) -> None: + project = get_object_or_404(Project, pk=project_id) + members = Member.objects.filter(project=project, pk__in=workload_allocation.member_ids) + if len(members) != len(workload_allocation.member_ids): + raise ValueError("Invalid member ids") + # Sort members by workload_allocation.member_ids + members = sorted(members, key=lambda m: workload_allocation.member_ids.index(m.id)) + + dataset_size = project.examples.count() # Todo: unassigned examples + + strategy = create_assignment_strategy(strategy_name, dataset_size, workload_allocation.weights) + assignments = strategy.assign() + examples = project.examples.all() + assignments = [ + Assignment( + project=project, + example=examples[assignment.example], + assignee=members[assignment.user].user, + ) + for assignment in assignments + ] + Assignment.objects.bulk_create(assignments) diff --git a/backend/examples/views/assignment.py b/backend/examples/views/assignment.py index f0b50efb..f4256d43 100644 --- a/backend/examples/views/assignment.py +++ b/backend/examples/views/assignment.py @@ -5,11 +5,12 @@ from rest_framework import filters, generics, status from rest_framework.permissions import IsAuthenticated from rest_framework.views import APIView, Response -from examples.assignment.strategies import StrategyName, create_assignment_strategy +from examples.assignment.strategies import StrategyName +from examples.assignment.usecase import bulk_assign from examples.assignment.workload import WorkloadAllocation from examples.models import Assignment from examples.serializers import AssignmentSerializer -from projects.models import Member, Project +from projects.models import Project from projects.permissions import IsProjectAdmin, IsProjectStaffAndReadOnly @@ -72,29 +73,15 @@ class BulkAssignment(APIView): status=status.HTTP_400_BAD_REQUEST, ) - project = get_object_or_404(Project, pk=self.kwargs["project_id"]) - members = Member.objects.filter(project=project, pk__in=workload_allocation.member_ids) - if len(members) != len(workload_allocation.member_ids): + try: + bulk_assign( + project_id=self.kwargs["project_id"], + workload_allocation=workload_allocation, + strategy_name=strategy_name, + ) + except ValueError as e: return Response( - {"detail": "Invalid member ids"}, + {"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST, ) - # Sort members by workload_allocation.member_ids - members = sorted(members, key=lambda m: workload_allocation.member_ids.index(m.id)) - - dataset_size = project.examples.count() # Todo: unassigned examples - strategy = create_assignment_strategy( - strategy_name, dataset_size, workload_allocation.weights - ) # Todo: raise 400 if weights are not valid - assignments = strategy.assign() - examples = project.examples.all() - assignments = [ - Assignment( - project=project, - example=examples[assignment.example], - assignee=members[assignment.user].user, - ) - for assignment in assignments - ] - Assignment.objects.bulk_create(assignments) return Response(status=status.HTTP_201_CREATED)