Browse Source

Update AutoLabelingConfigParameterTest

pull/1413/head
Hironsan 3 years ago
parent
commit
d04d6571ff
4 changed files with 90 additions and 9 deletions
  1. 19
      backend/api/models.py
  2. 41
      backend/api/tests/api/test_auto_labeling.py
  3. 10
      backend/api/urls.py
  4. 29
      backend/api/views/auto_labeling.py

19
backend/api/models.py

@ -1,4 +1,5 @@
import string
from typing import Literal
from auto_labeling_pipeline.models import RequestModelFactory
from django.contrib.auth.models import User
@ -38,6 +39,9 @@ class Project(PolymorphicModel):
def get_annotation_class(self):
raise NotImplementedError()
def is_task_of(self, task: Literal['text', 'image', 'speech']):
raise NotImplementedError()
def __str__(self):
return self.name
@ -47,30 +51,45 @@ class TextClassificationProject(Project):
def get_annotation_class(self):
return Category
def is_task_of(self, task: Literal['text', 'image', 'speech']):
return task == 'text'
class SequenceLabelingProject(Project):
def get_annotation_class(self):
return Span
def is_task_of(self, task: Literal['text', 'image', 'speech']):
return task == 'text'
class Seq2seqProject(Project):
def get_annotation_class(self):
return TextLabel
def is_task_of(self, task: Literal['text', 'image', 'speech']):
return task == 'text'
class Speech2textProject(Project):
def get_annotation_class(self):
return TextLabel
def is_task_of(self, task: Literal['text', 'image', 'speech']):
return task == 'speech'
class ImageClassificationProject(Project):
def get_annotation_class(self):
return Category
def is_task_of(self, task: Literal['text', 'image', 'speech']):
return task == 'image'
class Label(models.Model):
text = models.CharField(max_length=100)

41
backend/api/tests/api/test_auto_labeling.py

@ -0,0 +1,41 @@
from unittest.mock import patch
from auto_labeling_pipeline.models import RequestModelFactory
from rest_framework import status
from rest_framework.reverse import reverse
from ...models import DOCUMENT_CLASSIFICATION, Category
from .utils import (CRUDMixin, make_annotation, make_doc,
make_user, prepare_project)
class TestConfigParameter(CRUDMixin):
@classmethod
def setUpTestData(cls):
cls.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
cls.data = {
'model_name': 'GCP Entity Analysis',
'model_attrs': {'key': 'hoge', 'type': 'PLAIN_TEXT', 'language': 'en'},
'text': 'example'
}
cls.url = reverse(viewname='auto_labeling_parameter_testing', args=[cls.project.item.id])
@patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={})
def test_called_with_proper_model(self, mock):
self.assert_create(self.project.users[0], status.HTTP_200_OK)
_, kwargs = mock.call_args
expected = RequestModelFactory.create(self.data['model_name'], self.data['model_attrs'])
self.assertEqual(kwargs['model'], expected)
@patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={})
def test_called_with_text(self, mock):
self.assert_create(self.project.users[0], status.HTTP_200_OK)
_, kwargs = mock.call_args
self.assertEqual(kwargs['example'], self.data['text'])
@patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={})
def test_called_with_image(self, mock):
self.assert_create(self.project.users[0], status.HTTP_200_OK)
_, kwargs = mock.call_args
self.assertEqual(kwargs['example'], self.data['text'])

10
backend/api/urls.py

@ -179,6 +179,11 @@ urlpatterns_project = [
view=views.AutoLabelingAnnotation.as_view(),
name='auto_labeling_annotation'
),
path(
route='auto-labeling-parameter-testing',
view=views.AutoLabelingConfigParameterTest.as_view(),
name='auto_labeling_parameter_testing'
),
path(
route='auto-labeling-template-testing',
view=views.AutoLabelingTemplateTest.as_view(),
@ -224,11 +229,6 @@ urlpatterns = [
view=views.Roles.as_view(),
name='roles'
),
path(
route='auto-labeling-parameter-testing',
view=views.AutoLabelingConfigParameterTest.as_view(),
name='auto_labeling_parameter_testing'
),
path(
route='tasks/status/<task_id>',
view=views.TaskStatus.as_view(),

29
backend/api/views/auto_labeling.py

@ -106,12 +106,16 @@ class AutoLabelingConfigTest(APIView):
class AutoLabelingConfigParameterTest(APIView):
permission_classes = [IsAuthenticated & IsProjectAdmin]
def post(self, *args, **kwargs):
@property
def project(self):
return get_object_or_404(Project, pk=self.kwargs['project_id'])
def create_model(self):
model_name = self.request.data['model_name']
model_attrs = self.request.data['model_attrs']
sample_text = self.request.data['text']
try:
model = RequestModelFactory.create(model_name, model_attrs)
return model
except Exception:
model = RequestModelFactory.find(model_name)
schema = model.schema()
@ -120,9 +124,11 @@ class AutoLabelingConfigParameterTest(APIView):
'The attributes does not match the model.'
'You need to correctly specify the required fields: {}'.format(required_fields)
)
def send_request(self, model, example):
try:
response = model.send(text=sample_text)
return Response(response, status=status.HTTP_200_OK)
response = model.send(example)
return response
except requests.exceptions.ConnectionError:
raise URLConnectionError
except botocore.exceptions.ClientError:
@ -130,6 +136,21 @@ class AutoLabelingConfigParameterTest(APIView):
except Exception as e:
raise e
def prepare_example(self):
if self.project.is_task_of('text'):
return self.request.data['text']
elif self.project.is_task_of('image'):
return ''
elif self.project.is_task_of('speech'):
raise NotImplementedError('can not handle speech data now.')
raise NotImplementedError('The project type is unknown.')
def post(self, *args, **kwargs):
model = self.create_model()
example = self.prepare_example()
response = self.send_request(model=model, example=example)
return Response(response, status=status.HTTP_200_OK)
class AutoLabelingTemplateTest(APIView):
permission_classes = [IsAuthenticated & IsProjectAdmin]

Loading…
Cancel
Save