From d04d6571ffc0779e11bd0f3a0b4a8bead38e52e0 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Thu, 10 Jun 2021 16:28:17 +0900 Subject: [PATCH] Update AutoLabelingConfigParameterTest --- backend/api/models.py | 19 ++++++++++ backend/api/tests/api/test_auto_labeling.py | 41 +++++++++++++++++++++ backend/api/urls.py | 10 ++--- backend/api/views/auto_labeling.py | 29 +++++++++++++-- 4 files changed, 90 insertions(+), 9 deletions(-) create mode 100644 backend/api/tests/api/test_auto_labeling.py diff --git a/backend/api/models.py b/backend/api/models.py index feda70cc..18d52f45 100644 --- a/backend/api/models.py +++ b/backend/api/models.py @@ -1,4 +1,5 @@ import string +from typing import Literal from auto_labeling_pipeline.models import RequestModelFactory from django.contrib.auth.models import User @@ -38,6 +39,9 @@ class Project(PolymorphicModel): def get_annotation_class(self): raise NotImplementedError() + def is_task_of(self, task: Literal['text', 'image', 'speech']): + raise NotImplementedError() + def __str__(self): return self.name @@ -47,30 +51,45 @@ class TextClassificationProject(Project): def get_annotation_class(self): return Category + def is_task_of(self, task: Literal['text', 'image', 'speech']): + return task == 'text' + class SequenceLabelingProject(Project): def get_annotation_class(self): return Span + def is_task_of(self, task: Literal['text', 'image', 'speech']): + return task == 'text' + class Seq2seqProject(Project): def get_annotation_class(self): return TextLabel + def is_task_of(self, task: Literal['text', 'image', 'speech']): + return task == 'text' + class Speech2textProject(Project): def get_annotation_class(self): return TextLabel + def is_task_of(self, task: Literal['text', 'image', 'speech']): + return task == 'speech' + class ImageClassificationProject(Project): def get_annotation_class(self): return Category + def is_task_of(self, task: Literal['text', 'image', 'speech']): + return task == 'image' + class Label(models.Model): text = models.CharField(max_length=100) diff --git a/backend/api/tests/api/test_auto_labeling.py b/backend/api/tests/api/test_auto_labeling.py new file mode 100644 index 00000000..87dd4749 --- /dev/null +++ b/backend/api/tests/api/test_auto_labeling.py @@ -0,0 +1,41 @@ +from unittest.mock import patch + +from auto_labeling_pipeline.models import RequestModelFactory +from rest_framework import status +from rest_framework.reverse import reverse + +from ...models import DOCUMENT_CLASSIFICATION, Category +from .utils import (CRUDMixin, make_annotation, make_doc, + make_user, prepare_project) + + +class TestConfigParameter(CRUDMixin): + + @classmethod + def setUpTestData(cls): + cls.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + cls.data = { + 'model_name': 'GCP Entity Analysis', + 'model_attrs': {'key': 'hoge', 'type': 'PLAIN_TEXT', 'language': 'en'}, + 'text': 'example' + } + cls.url = reverse(viewname='auto_labeling_parameter_testing', args=[cls.project.item.id]) + + @patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={}) + def test_called_with_proper_model(self, mock): + self.assert_create(self.project.users[0], status.HTTP_200_OK) + _, kwargs = mock.call_args + expected = RequestModelFactory.create(self.data['model_name'], self.data['model_attrs']) + self.assertEqual(kwargs['model'], expected) + + @patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={}) + def test_called_with_text(self, mock): + self.assert_create(self.project.users[0], status.HTTP_200_OK) + _, kwargs = mock.call_args + self.assertEqual(kwargs['example'], self.data['text']) + + @patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={}) + def test_called_with_image(self, mock): + self.assert_create(self.project.users[0], status.HTTP_200_OK) + _, kwargs = mock.call_args + self.assertEqual(kwargs['example'], self.data['text']) diff --git a/backend/api/urls.py b/backend/api/urls.py index f3ddd21c..ce91d9fc 100644 --- a/backend/api/urls.py +++ b/backend/api/urls.py @@ -179,6 +179,11 @@ urlpatterns_project = [ view=views.AutoLabelingAnnotation.as_view(), name='auto_labeling_annotation' ), + path( + route='auto-labeling-parameter-testing', + view=views.AutoLabelingConfigParameterTest.as_view(), + name='auto_labeling_parameter_testing' + ), path( route='auto-labeling-template-testing', view=views.AutoLabelingTemplateTest.as_view(), @@ -224,11 +229,6 @@ urlpatterns = [ view=views.Roles.as_view(), name='roles' ), - path( - route='auto-labeling-parameter-testing', - view=views.AutoLabelingConfigParameterTest.as_view(), - name='auto_labeling_parameter_testing' - ), path( route='tasks/status/', view=views.TaskStatus.as_view(), diff --git a/backend/api/views/auto_labeling.py b/backend/api/views/auto_labeling.py index ae660f79..6b2cafec 100644 --- a/backend/api/views/auto_labeling.py +++ b/backend/api/views/auto_labeling.py @@ -106,12 +106,16 @@ class AutoLabelingConfigTest(APIView): class AutoLabelingConfigParameterTest(APIView): permission_classes = [IsAuthenticated & IsProjectAdmin] - def post(self, *args, **kwargs): + @property + def project(self): + return get_object_or_404(Project, pk=self.kwargs['project_id']) + + def create_model(self): model_name = self.request.data['model_name'] model_attrs = self.request.data['model_attrs'] - sample_text = self.request.data['text'] try: model = RequestModelFactory.create(model_name, model_attrs) + return model except Exception: model = RequestModelFactory.find(model_name) schema = model.schema() @@ -120,9 +124,11 @@ class AutoLabelingConfigParameterTest(APIView): 'The attributes does not match the model.' 'You need to correctly specify the required fields: {}'.format(required_fields) ) + + def send_request(self, model, example): try: - response = model.send(text=sample_text) - return Response(response, status=status.HTTP_200_OK) + response = model.send(example) + return response except requests.exceptions.ConnectionError: raise URLConnectionError except botocore.exceptions.ClientError: @@ -130,6 +136,21 @@ class AutoLabelingConfigParameterTest(APIView): except Exception as e: raise e + def prepare_example(self): + if self.project.is_task_of('text'): + return self.request.data['text'] + elif self.project.is_task_of('image'): + return '' + elif self.project.is_task_of('speech'): + raise NotImplementedError('can not handle speech data now.') + raise NotImplementedError('The project type is unknown.') + + def post(self, *args, **kwargs): + model = self.create_model() + example = self.prepare_example() + response = self.send_request(model=model, example=example) + return Response(response, status=status.HTTP_200_OK) + class AutoLabelingTemplateTest(APIView): permission_classes = [IsAuthenticated & IsProjectAdmin]