From 1f168b78af78ff6cdbdc254f5bf1b92ac888526e Mon Sep 17 00:00:00 2001 From: Hironsan Date: Mon, 31 Jul 2023 16:23:13 +0900 Subject: [PATCH] Add test cases for bulk_assign --- backend/examples/assignment/usecase.py | 15 +++++------ backend/examples/assignment/workload.py | 4 +-- backend/examples/tests/test_usecase.py | 33 +++++++++++++++++++++++++ backend/examples/views/assignment.py | 3 ++- 4 files changed, 45 insertions(+), 10 deletions(-) create mode 100644 backend/examples/tests/test_usecase.py diff --git a/backend/examples/assignment/usecase.py b/backend/examples/assignment/usecase.py index a93be052..00408e48 100644 --- a/backend/examples/assignment/usecase.py +++ b/backend/examples/assignment/usecase.py @@ -1,23 +1,24 @@ +from typing import List + from django.shortcuts import get_object_or_404 from examples.assignment.strategies import StrategyName, create_assignment_strategy -from examples.assignment.workload import WorkloadAllocation from examples.models import Assignment, Example from projects.models import Member, Project -def bulk_assign(project_id: int, workload_allocation: WorkloadAllocation, strategy_name: StrategyName) -> None: +def bulk_assign(project_id: int, strategy_name: StrategyName, member_ids: List[int], weights: List[int]) -> None: project = get_object_or_404(Project, pk=project_id) - members = Member.objects.filter(project=project, pk__in=workload_allocation.member_ids) - if len(members) != len(workload_allocation.member_ids): + members = Member.objects.filter(project=project, pk__in=member_ids) + if len(members) != len(member_ids): raise ValueError("Invalid member ids") - # Sort members by workload_allocation.member_ids - members = sorted(members, key=lambda m: workload_allocation.member_ids.index(m.id)) + # Sort members by member_ids + members = sorted(members, key=lambda m: member_ids.index(m.id)) unassigned_examples = Example.objects.filter(project=project, assignments__isnull=True) dataset_size = unassigned_examples.count() - strategy = create_assignment_strategy(strategy_name, dataset_size, workload_allocation.weights) + strategy = create_assignment_strategy(strategy_name, dataset_size, weights) assignments = strategy.assign() assignments = [ Assignment( diff --git a/backend/examples/assignment/workload.py b/backend/examples/assignment/workload.py index a21c260c..26eb6a04 100644 --- a/backend/examples/assignment/workload.py +++ b/backend/examples/assignment/workload.py @@ -12,9 +12,9 @@ class WorkloadAllocation(BaseModel): workloads: List[Workload] @property - def member_ids(self): + def member_ids(self) -> List[int]: return [w.member_id for w in self.workloads] @property - def weights(self): + def weights(self) -> List[int]: return [w.weight for w in self.workloads] diff --git a/backend/examples/tests/test_usecase.py b/backend/examples/tests/test_usecase.py new file mode 100644 index 00000000..3bd3e3b6 --- /dev/null +++ b/backend/examples/tests/test_usecase.py @@ -0,0 +1,33 @@ +from django.test import TestCase +from model_mommy import mommy + +from examples.assignment.usecase import StrategyName, bulk_assign +from projects.models import Member, ProjectType +from projects.tests.utils import prepare_project + + +class TestBulkAssignment(TestCase): + def setUp(self): + self.project = prepare_project(ProjectType.SEQUENCE_LABELING) + self.member_ids = list(Member.objects.values_list("id", flat=True)) + self.example = mommy.make("Example", project=self.project.item) + + def test_raise_error_if_weights_is_invalid(self): + with self.assertRaises(ValueError): + bulk_assign( + self.project.item.id, StrategyName.weighted_sequential, self.member_ids, [0] * len(self.member_ids) + ) + + def test_raise_error_if_passing_wrong_member_ids(self): + with self.assertRaises(ValueError): + bulk_assign( + self.project.item.id, + StrategyName.weighted_sequential, + self.member_ids + [100], + [0] * len(self.member_ids), + ) + + def test_assign_examples(self): + bulk_assign(self.project.item.id, StrategyName.weighted_sequential, self.member_ids, [100, 0, 0]) + self.assertEqual(self.example.assignments.count(), 1) + self.assertEqual(self.example.assignments.first().assignee, self.project.admin) diff --git a/backend/examples/views/assignment.py b/backend/examples/views/assignment.py index f4256d43..08c10873 100644 --- a/backend/examples/views/assignment.py +++ b/backend/examples/views/assignment.py @@ -76,8 +76,9 @@ class BulkAssignment(APIView): try: bulk_assign( project_id=self.kwargs["project_id"], - workload_allocation=workload_allocation, strategy_name=strategy_name, + member_ids=workload_allocation.member_ids, + weights=workload_allocation.weights, ) except ValueError as e: return Response(