Browse Source

Update AutoLabelingConfigParameterTest

pull/1413/head
Hironsan 3 years ago
parent
commit
6c3925fc9f
2 changed files with 20 additions and 15 deletions
  1. 19
      backend/api/tests/api/test_auto_labeling.py
  2. 16
      backend/api/views/auto_labeling.py

19
backend/api/tests/api/test_auto_labeling.py

@ -1,25 +1,27 @@
import pathlib
from unittest.mock import patch
from auto_labeling_pipeline.models import RequestModelFactory
from rest_framework import status
from rest_framework.reverse import reverse
from ...models import DOCUMENT_CLASSIFICATION, Category
from .utils import (CRUDMixin, make_annotation, make_doc,
make_user, prepare_project)
from ...models import DOCUMENT_CLASSIFICATION
from ...views.auto_labeling import load_data_as_b64
from .utils import CRUDMixin, prepare_project
data_dir = pathlib.Path(__file__).parent / 'data'
class TestConfigParameter(CRUDMixin):
@classmethod
def setUpTestData(cls):
cls.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
cls.data = {
def setUp(self):
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
self.data = {
'model_name': 'GCP Entity Analysis',
'model_attrs': {'key': 'hoge', 'type': 'PLAIN_TEXT', 'language': 'en'},
'text': 'example'
}
cls.url = reverse(viewname='auto_labeling_parameter_testing', args=[cls.project.item.id])
self.url = reverse(viewname='auto_labeling_parameter_testing', args=[self.project.item.id])
@patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={})
def test_called_with_proper_model(self, mock):
@ -36,6 +38,7 @@ class TestConfigParameter(CRUDMixin):
@patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={})
def test_called_with_image(self, mock):
self.data['text'] = load_data_as_b64(data_dir / 'images/1500x500.jpeg')
self.assert_create(self.project.users[0], status.HTTP_200_OK)
_, kwargs = mock.call_args
self.assertEqual(kwargs['example'], self.data['text'])

16
backend/api/views/auto_labeling.py

@ -1,3 +1,5 @@
import base64
import botocore.exceptions
import requests
from auto_labeling_pipeline.mappings import MappingTemplate
@ -22,6 +24,12 @@ from ..serializers import (AutoLabelingConfigSerializer,
get_annotation_serializer)
def load_data_as_b64(filepath):
with open(filepath, 'rb') as f:
byte_str = base64.b64encode(f.read())
return byte_str.decode('utf-8')
class AutoLabelingTemplateListAPI(APIView):
permission_classes = [IsAuthenticated & IsProjectAdmin]
@ -137,13 +145,7 @@ class AutoLabelingConfigParameterTest(APIView):
raise e
def prepare_example(self):
if self.project.is_task_of('text'):
return self.request.data['text']
elif self.project.is_task_of('image'):
return ''
elif self.project.is_task_of('speech'):
raise NotImplementedError('can not handle speech data now.')
raise NotImplementedError('The project type is unknown.')
return self.request.data['text']
def post(self, *args, **kwargs):
model = self.create_model()

Loading…
Cancel
Save