From 82e7289bb7ddb5a9997b2ff209a9515eb5841b6c Mon Sep 17 00:00:00 2001 From: Hironsan Date: Thu, 27 Jul 2023 13:29:50 +0900 Subject: [PATCH] Add bulk assignment API --- backend/examples/assignment/strategies.py | 18 ++++++++ backend/examples/assignment/workload.py | 20 +++++++++ backend/examples/urls.py | 8 +++- backend/examples/views/assignment.py | 52 ++++++++++++++++++++++- 4 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 backend/examples/assignment/workload.py diff --git a/backend/examples/assignment/strategies.py b/backend/examples/assignment/strategies.py index 0d72033c..2913eff4 100644 --- a/backend/examples/assignment/strategies.py +++ b/backend/examples/assignment/strategies.py @@ -1,5 +1,6 @@ import abc import dataclasses +import enum import random from typing import List @@ -12,6 +13,23 @@ class Assignment: example: int +class StrategyName(enum.Enum): + weighted_sequential = enum.auto() + weighted_random = enum.auto() + sampling_without_replacement = enum.auto() + + +def create_assignment_strategy(strategy_name: StrategyName, dataset_size: int, weights: List[int]) -> "BaseStrategy": + if strategy_name == StrategyName.weighted_sequential: + return WeightedSequentialStrategy(dataset_size, weights) + elif strategy_name == StrategyName.weighted_random: + return WeightedRandomStrategy(dataset_size, weights) + elif strategy_name == StrategyName.sampling_without_replacement: + return SamplingWithoutReplacementStrategy(dataset_size, weights) + else: + raise ValueError(f"Unknown strategy name: {strategy_name}") + + class BaseStrategy(abc.ABC): @abc.abstractmethod def assign(self) -> List[Assignment]: diff --git a/backend/examples/assignment/workload.py b/backend/examples/assignment/workload.py new file mode 100644 index 00000000..a21c260c --- /dev/null +++ b/backend/examples/assignment/workload.py @@ -0,0 +1,20 @@ +from typing import List + +from pydantic import BaseModel, NonNegativeInt + + +class Workload(BaseModel): + weight: NonNegativeInt + member_id: int + + +class WorkloadAllocation(BaseModel): + workloads: List[Workload] + + @property + def member_ids(self): + return [w.member_id for w in self.workloads] + + @property + def weights(self): + return [w.weight for w in self.workloads] diff --git a/backend/examples/urls.py b/backend/examples/urls.py index 82f4b398..56c6cbdb 100644 --- a/backend/examples/urls.py +++ b/backend/examples/urls.py @@ -1,6 +1,11 @@ from django.urls import path -from .views.assignment import AssignmentDetail, AssignmentList, ResetAssignment +from .views.assignment import ( + AssignmentDetail, + AssignmentList, + BulkAssignment, + ResetAssignment, +) from .views.comment import CommentDetail, CommentList from .views.example import ExampleDetail, ExampleList from .views.example_state import ExampleStateList @@ -9,6 +14,7 @@ urlpatterns = [ path(route="assignments", view=AssignmentList.as_view(), name="assignment_list"), path(route="assignments/", view=AssignmentDetail.as_view(), name="assignment_detail"), path(route="assignments/reset", view=ResetAssignment.as_view(), name="assignment_reset"), + path(route="assignments/bulk_assign", view=BulkAssignment.as_view(), name="bulk_assignment"), path(route="examples", view=ExampleList.as_view(), name="example_list"), path(route="examples/", view=ExampleDetail.as_view(), name="example_detail"), path(route="comments", view=CommentList.as_view(), name="comment_list"), diff --git a/backend/examples/views/assignment.py b/backend/examples/views/assignment.py index 3f9767a9..11fd0a7e 100644 --- a/backend/examples/views/assignment.py +++ b/backend/examples/views/assignment.py @@ -1,12 +1,15 @@ from django.shortcuts import get_object_or_404 from django_filters.rest_framework import DjangoFilterBackend +from pydantic import ValidationError from rest_framework import filters, generics, status from rest_framework.permissions import IsAuthenticated from rest_framework.views import APIView, Response +from examples.assignment.strategies import StrategyName, create_assignment_strategy +from examples.assignment.workload import WorkloadAllocation from examples.models import Assignment from examples.serializers import AssignmentSerializer -from projects.models import Project +from projects.models import Member, Project from projects.permissions import IsProjectAdmin, IsProjectStaffAndReadOnly @@ -46,3 +49,50 @@ class ResetAssignment(APIView): def delete(self, *args, **kwargs): Assignment.objects.filter(project=self.project).delete() return Response(status=status.HTTP_204_NO_CONTENT) + + +class BulkAssignment(APIView): + serializer_class = AssignmentSerializer + permission_classes = [IsAuthenticated & IsProjectAdmin] + + def post(self, *args, **kwargs): + try: + strategy_name = StrategyName[self.request.data["strategy_name"]] + except KeyError: + return Response( + {"detail": "Invalid strategy name"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + workload_allocation = WorkloadAllocation(workloads=self.request.data["workloads"]) + except ValidationError as e: + return Response( + {"detail": e.errors()}, + status=status.HTTP_400_BAD_REQUEST, + ) + + project = get_object_or_404(Project, pk=self.kwargs["project_id"]) + members = Member.objects.filter(project=project, pk__in=workload_allocation.member_ids) + if len(members) != len(workload_allocation.member_ids): + return Response( + {"detail": "Invalid member ids"}, + status=status.HTTP_400_BAD_REQUEST, + ) + # Sort members by workload_allocation.member_ids + members = sorted(members, key=lambda m: workload_allocation.member_ids.index(m.id)) + + dataset_size = project.examples.count() # Todo: unassigned examples + strategy = create_assignment_strategy(strategy_name, dataset_size, workload_allocation.weights) + assignments = strategy.assign() + example_ids = project.examples.values_list("pk", flat=True) + assignments = [ + Assignment( + project=project, + example=example_ids[assignment.example], + assignee=members[assignment.user].user, + ) + for assignment in assignments + ] + Assignment.objects.bulk_create(assignments) + return Response(status=status.HTTP_201_CREATED)