diff --git a/backend/auto_labeling/tests/test_views.py b/backend/auto_labeling/tests/test_views.py index 8d6eefd4..fbd141c4 100644 --- a/backend/auto_labeling/tests/test_views.py +++ b/backend/auto_labeling/tests/test_views.py @@ -3,6 +3,7 @@ from unittest.mock import patch from auto_labeling_pipeline.mappings import AmazonComprehendSentimentTemplate from auto_labeling_pipeline.models import RequestModelFactory +from model_mommy import mommy from rest_framework import status from rest_framework.reverse import reverse @@ -107,7 +108,13 @@ class TestConfigCreation(CRUDMixin): self.url = reverse(viewname='auto_labeling_configs', args=[self.project.item.id]) def test_create_config(self): - self.assert_create(self.project.users[0], status.HTTP_201_CREATED) + response = self.assert_create(self.project.users[0], status.HTTP_201_CREATED) + self.assertEqual(response.data['model_name'], self.data['model_name']) + + def test_list_config(self): + mommy.make('AutoLabelingConfig', project=self.project.item) + response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) class TestAutoLabelingText(CRUDMixin): diff --git a/backend/auto_labeling/views.py b/backend/auto_labeling/views.py index 345d2724..279758c4 100644 --- a/backend/auto_labeling/views.py +++ b/backend/auto_labeling/views.py @@ -39,8 +39,7 @@ class AutoLabelingTemplateDetailAPI(APIView): permission_classes = [IsAuthenticated & IsProjectAdmin] def get(self, request, *args, **kwargs): - option_name = self.kwargs['option_name'] - option = Options.find(option_name=option_name) + option = Options.find(option_name=self.kwargs['option_name']) return Response(option.to_dict(), status=status.HTTP_200_OK) @@ -50,12 +49,10 @@ class AutoLabelingConfigList(generics.ListCreateAPIView): permission_classes = [IsAuthenticated & IsProjectAdmin] def get_queryset(self): - project = get_object_or_404(Project, pk=self.kwargs['project_id']) - return project.auto_labeling_config + return AutoLabelingConfig.objects.filter(project=self.kwargs['project_id']) def perform_create(self, serializer): - project = get_object_or_404(Project, pk=self.kwargs['project_id']) - serializer.save(project=project) + serializer.save(project_id=self.kwargs['project_id']) class AutoLabelingConfigDetail(generics.RetrieveUpdateDestroyAPIView):