diff --git a/backend/api/tests/api/test_auto_labeling.py b/backend/api/tests/api/test_auto_labeling.py index 87dd4749..8d89dc55 100644 --- a/backend/api/tests/api/test_auto_labeling.py +++ b/backend/api/tests/api/test_auto_labeling.py @@ -1,25 +1,27 @@ +import pathlib from unittest.mock import patch from auto_labeling_pipeline.models import RequestModelFactory from rest_framework import status from rest_framework.reverse import reverse -from ...models import DOCUMENT_CLASSIFICATION, Category -from .utils import (CRUDMixin, make_annotation, make_doc, - make_user, prepare_project) +from ...models import DOCUMENT_CLASSIFICATION +from ...views.auto_labeling import load_data_as_b64 +from .utils import CRUDMixin, prepare_project + +data_dir = pathlib.Path(__file__).parent / 'data' class TestConfigParameter(CRUDMixin): - @classmethod - def setUpTestData(cls): - cls.project = prepare_project(task=DOCUMENT_CLASSIFICATION) - cls.data = { + def setUp(self): + self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.data = { 'model_name': 'GCP Entity Analysis', 'model_attrs': {'key': 'hoge', 'type': 'PLAIN_TEXT', 'language': 'en'}, 'text': 'example' } - cls.url = reverse(viewname='auto_labeling_parameter_testing', args=[cls.project.item.id]) + self.url = reverse(viewname='auto_labeling_parameter_testing', args=[self.project.item.id]) @patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={}) def test_called_with_proper_model(self, mock): @@ -36,6 +38,7 @@ class TestConfigParameter(CRUDMixin): @patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={}) def test_called_with_image(self, mock): + self.data['text'] = load_data_as_b64(data_dir / 'images/1500x500.jpeg') self.assert_create(self.project.users[0], status.HTTP_200_OK) _, kwargs = mock.call_args self.assertEqual(kwargs['example'], self.data['text']) diff --git a/backend/api/views/auto_labeling.py b/backend/api/views/auto_labeling.py index 6b2cafec..81b2a0a0 100644 --- a/backend/api/views/auto_labeling.py +++ b/backend/api/views/auto_labeling.py @@ -1,3 +1,5 @@ +import base64 + import botocore.exceptions import requests from auto_labeling_pipeline.mappings import MappingTemplate @@ -22,6 +24,12 @@ from ..serializers import (AutoLabelingConfigSerializer, get_annotation_serializer) +def load_data_as_b64(filepath): + with open(filepath, 'rb') as f: + byte_str = base64.b64encode(f.read()) + return byte_str.decode('utf-8') + + class AutoLabelingTemplateListAPI(APIView): permission_classes = [IsAuthenticated & IsProjectAdmin] @@ -137,13 +145,7 @@ class AutoLabelingConfigParameterTest(APIView): raise e def prepare_example(self): - if self.project.is_task_of('text'): - return self.request.data['text'] - elif self.project.is_task_of('image'): - return '' - elif self.project.is_task_of('speech'): - raise NotImplementedError('can not handle speech data now.') - raise NotImplementedError('The project type is unknown.') + return self.request.data['text'] def post(self, *args, **kwargs): model = self.create_model()