Browse Source

Add test cases for TemplateList

pull/1650/head
Hironsan 3 years ago
parent
commit
5159ff61c7
2 changed files with 34 additions and 5 deletions
  1. 29
      backend/auto_labeling/tests/test_views.py
  2. 10
      backend/auto_labeling/views.py

29
backend/auto_labeling/tests/test_views.py

@ -14,6 +14,35 @@ from api.tests.api.utils import (CRUDMixin, make_auto_labeling_config, make_doc,
data_dir = pathlib.Path(__file__).parent / 'data'
class TestTemplateList(CRUDMixin):
def setUp(self):
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
self.url = reverse(viewname='auto_labeling_templates', args=[self.project.item.id])
def test_allow_admin_to_fetch_template_list(self):
self.url += '?task_name=DocumentClassification'
response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
self.assertIn('Custom REST Request', response.data)
self.assertGreaterEqual(len(response.data), 1)
def test_deny_non_admin_to_fetch_template_list(self):
self.url += '?task_name=DocumentClassification'
for user in self.project.users[1:]:
self.assert_fetch(user, status.HTTP_403_FORBIDDEN)
def test_return_only_default_template_with_empty_task_name(self):
response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)
self.assertIn('Custom REST Request', response.data)
def test_return_only_default_template_with_wrong_task_name(self):
self.url += '?task_name=foobar'
response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)
self.assertIn('Custom REST Request', response.data)
class TestConfigParameter(CRUDMixin):
def setUp(self):

10
backend/auto_labeling/views.py

@ -13,6 +13,7 @@ from django_drf_filepond.models import TemporaryUpload
from rest_framework import generics, status
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
@ -28,9 +29,9 @@ from .serializers import (AutoLabelingConfigSerializer, get_annotation_serialize
class TemplateListAPI(APIView):
permission_classes = [IsAuthenticated & IsProjectAdmin]
def get(self, request, *args, **kwargs):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
options = Options.filter_by_task(task_name=project.project_type)
def get(self, request: Request, *args, **kwargs):
task_name = request.query_params.get('task_name')
options = Options.filter_by_task(task_name=task_name)
option_names = [o.name for o in options]
return Response(option_names, status=status.HTTP_200_OK)
@ -127,8 +128,7 @@ class RestAPIRequestTesting(APIView):
def send_request(self, model, example):
try:
response = model.send(example)
return response
return model.send(example)
except requests.exceptions.ConnectionError:
raise URLConnectionError
except botocore.exceptions.ClientError:

Loading…
Cancel
Save