Browse Source

Add bulk assignment API

pull/2261/head
Hironsan 1 year ago
parent
commit
82e7289bb7
4 changed files with 96 additions and 2 deletions
  1. 18
      backend/examples/assignment/strategies.py
  2. 20
      backend/examples/assignment/workload.py
  3. 8
      backend/examples/urls.py
  4. 52
      backend/examples/views/assignment.py

18
backend/examples/assignment/strategies.py

@ -1,5 +1,6 @@
import abc
import dataclasses
import enum
import random
from typing import List
@ -12,6 +13,23 @@ class Assignment:
example: int
class StrategyName(enum.Enum):
weighted_sequential = enum.auto()
weighted_random = enum.auto()
sampling_without_replacement = enum.auto()
def create_assignment_strategy(strategy_name: StrategyName, dataset_size: int, weights: List[int]) -> "BaseStrategy":
if strategy_name == StrategyName.weighted_sequential:
return WeightedSequentialStrategy(dataset_size, weights)
elif strategy_name == StrategyName.weighted_random:
return WeightedRandomStrategy(dataset_size, weights)
elif strategy_name == StrategyName.sampling_without_replacement:
return SamplingWithoutReplacementStrategy(dataset_size, weights)
else:
raise ValueError(f"Unknown strategy name: {strategy_name}")
class BaseStrategy(abc.ABC):
@abc.abstractmethod
def assign(self) -> List[Assignment]:

20
backend/examples/assignment/workload.py

@ -0,0 +1,20 @@
from typing import List
from pydantic import BaseModel, NonNegativeInt
class Workload(BaseModel):
weight: NonNegativeInt
member_id: int
class WorkloadAllocation(BaseModel):
workloads: List[Workload]
@property
def member_ids(self):
return [w.member_id for w in self.workloads]
@property
def weights(self):
return [w.weight for w in self.workloads]

8
backend/examples/urls.py

@ -1,6 +1,11 @@
from django.urls import path
from .views.assignment import AssignmentDetail, AssignmentList, ResetAssignment
from .views.assignment import (
AssignmentDetail,
AssignmentList,
BulkAssignment,
ResetAssignment,
)
from .views.comment import CommentDetail, CommentList
from .views.example import ExampleDetail, ExampleList
from .views.example_state import ExampleStateList
@ -9,6 +14,7 @@ urlpatterns = [
path(route="assignments", view=AssignmentList.as_view(), name="assignment_list"),
path(route="assignments/<uuid:assignment_id>", view=AssignmentDetail.as_view(), name="assignment_detail"),
path(route="assignments/reset", view=ResetAssignment.as_view(), name="assignment_reset"),
path(route="assignments/bulk_assign", view=BulkAssignment.as_view(), name="bulk_assignment"),
path(route="examples", view=ExampleList.as_view(), name="example_list"),
path(route="examples/<int:example_id>", view=ExampleDetail.as_view(), name="example_detail"),
path(route="comments", view=CommentList.as_view(), name="comment_list"),

52
backend/examples/views/assignment.py

@ -1,12 +1,15 @@
from django.shortcuts import get_object_or_404
from django_filters.rest_framework import DjangoFilterBackend
from pydantic import ValidationError
from rest_framework import filters, generics, status
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView, Response
from examples.assignment.strategies import StrategyName, create_assignment_strategy
from examples.assignment.workload import WorkloadAllocation
from examples.models import Assignment
from examples.serializers import AssignmentSerializer
from projects.models import Project
from projects.models import Member, Project
from projects.permissions import IsProjectAdmin, IsProjectStaffAndReadOnly
@ -46,3 +49,50 @@ class ResetAssignment(APIView):
def delete(self, *args, **kwargs):
Assignment.objects.filter(project=self.project).delete()
return Response(status=status.HTTP_204_NO_CONTENT)
class BulkAssignment(APIView):
serializer_class = AssignmentSerializer
permission_classes = [IsAuthenticated & IsProjectAdmin]
def post(self, *args, **kwargs):
try:
strategy_name = StrategyName[self.request.data["strategy_name"]]
except KeyError:
return Response(
{"detail": "Invalid strategy name"},
status=status.HTTP_400_BAD_REQUEST,
)
try:
workload_allocation = WorkloadAllocation(workloads=self.request.data["workloads"])
except ValidationError as e:
return Response(
{"detail": e.errors()},
status=status.HTTP_400_BAD_REQUEST,
)
project = get_object_or_404(Project, pk=self.kwargs["project_id"])
members = Member.objects.filter(project=project, pk__in=workload_allocation.member_ids)
if len(members) != len(workload_allocation.member_ids):
return Response(
{"detail": "Invalid member ids"},
status=status.HTTP_400_BAD_REQUEST,
)
# Sort members by workload_allocation.member_ids
members = sorted(members, key=lambda m: workload_allocation.member_ids.index(m.id))
dataset_size = project.examples.count() # Todo: unassigned examples
strategy = create_assignment_strategy(strategy_name, dataset_size, workload_allocation.weights)
assignments = strategy.assign()
example_ids = project.examples.values_list("pk", flat=True)
assignments = [
Assignment(
project=project,
example=example_ids[assignment.example],
assignee=members[assignment.user].user,
)
for assignment in assignments
]
Assignment.objects.bulk_create(assignments)
return Response(status=status.HTTP_201_CREATED)
Loading…
Cancel
Save