diff --git a/backend/auto_labeling/tests/test_views.py b/backend/auto_labeling/tests/test_views.py index 890dded4..5d3b564d 100644 --- a/backend/auto_labeling/tests/test_views.py +++ b/backend/auto_labeling/tests/test_views.py @@ -14,6 +14,35 @@ from api.tests.api.utils import (CRUDMixin, make_auto_labeling_config, make_doc, data_dir = pathlib.Path(__file__).parent / 'data' +class TestTemplateList(CRUDMixin): + + def setUp(self): + self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.url = reverse(viewname='auto_labeling_templates', args=[self.project.item.id]) + + def test_allow_admin_to_fetch_template_list(self): + self.url += '?task_name=DocumentClassification' + response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK) + self.assertIn('Custom REST Request', response.data) + self.assertGreaterEqual(len(response.data), 1) + + def test_deny_non_admin_to_fetch_template_list(self): + self.url += '?task_name=DocumentClassification' + for user in self.project.users[1:]: + self.assert_fetch(user, status.HTTP_403_FORBIDDEN) + + def test_return_only_default_template_with_empty_task_name(self): + response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertIn('Custom REST Request', response.data) + + def test_return_only_default_template_with_wrong_task_name(self): + self.url += '?task_name=foobar' + response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertIn('Custom REST Request', response.data) + + class TestConfigParameter(CRUDMixin): def setUp(self): diff --git a/backend/auto_labeling/views.py b/backend/auto_labeling/views.py index d91048be..96d27633 100644 --- a/backend/auto_labeling/views.py +++ b/backend/auto_labeling/views.py @@ -13,6 +13,7 @@ from django_drf_filepond.models import TemporaryUpload from rest_framework import generics, status from rest_framework.exceptions import ValidationError from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView @@ -28,9 +29,9 @@ from .serializers import (AutoLabelingConfigSerializer, get_annotation_serialize class TemplateListAPI(APIView): permission_classes = [IsAuthenticated & IsProjectAdmin] - def get(self, request, *args, **kwargs): - project = get_object_or_404(Project, pk=self.kwargs['project_id']) - options = Options.filter_by_task(task_name=project.project_type) + def get(self, request: Request, *args, **kwargs): + task_name = request.query_params.get('task_name') + options = Options.filter_by_task(task_name=task_name) option_names = [o.name for o in options] return Response(option_names, status=status.HTTP_200_OK) @@ -127,8 +128,7 @@ class RestAPIRequestTesting(APIView): def send_request(self, model, example): try: - response = model.send(example) - return response + return model.send(example) except requests.exceptions.ConnectionError: raise URLConnectionError except botocore.exceptions.ClientError: