diff --git a/backend/api/tests/api/test_auto_labeling.py b/backend/api/tests/api/test_auto_labeling.py index e8920f4f..e8777744 100644 --- a/backend/api/tests/api/test_auto_labeling.py +++ b/backend/api/tests/api/test_auto_labeling.py @@ -6,9 +6,10 @@ from auto_labeling_pipeline.models import RequestModelFactory from rest_framework import status from rest_framework.reverse import reverse -from ...models import DOCUMENT_CLASSIFICATION +from ...models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION from ...views.auto_labeling import load_data_as_b64 -from .utils import CRUDMixin, prepare_project +from .utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image, + prepare_project) data_dir = pathlib.Path(__file__).parent / 'data' @@ -108,3 +109,35 @@ class TestConfigCreation(CRUDMixin): def test_create_config(self): self.assert_create(self.project.users[0], status.HTTP_201_CREATED) + + +class TestAutoLabelingText(CRUDMixin): + + def setUp(self): + self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + make_auto_labeling_config(self.project.item) + self.example = make_doc(self.project.item) + self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id]) + + @patch('api.views.auto_labeling.execute_pipeline', return_value=[]) + def test_text_task(self, mock): + self.assert_create(self.project.users[0], status.HTTP_201_CREATED) + _, kwargs = mock.call_args + self.assertEqual(kwargs['text'], self.example.text) + + +class TestAutoLabelingImage(CRUDMixin): + + def setUp(self): + self.project = prepare_project(task=IMAGE_CLASSIFICATION) + make_auto_labeling_config(self.project.item) + filepath = data_dir / 'images/1500x500.jpeg' + self.example = make_image(self.project.item, str(filepath)) + self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id]) + + @patch('api.views.auto_labeling.execute_pipeline', return_value=[]) + def test_text_task(self, mock): + self.assert_create(self.project.users[0], status.HTTP_201_CREATED) + _, kwargs = mock.call_args + expected = load_data_as_b64(str(self.example.filename)) + self.assertEqual(kwargs['text'], expected) diff --git a/backend/api/tests/api/utils.py b/backend/api/tests/api/utils.py index 147b5bd3..ce8019b4 100644 --- a/backend/api/tests/api/utils.py +++ b/backend/api/tests/api/utils.py @@ -8,8 +8,8 @@ from model_mommy import mommy from rest_framework import status from rest_framework.test import APITestCase -from ...models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING, - SPEECH2TEXT, Role, RoleMapping) +from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ, + SEQUENCE_LABELING, SPEECH2TEXT, Role, RoleMapping) DATA_DIR = os.path.join(os.path.dirname(__file__), '../data') @@ -61,7 +61,8 @@ def make_project( DOCUMENT_CLASSIFICATION: 'TextClassificationProject', SEQUENCE_LABELING: 'SequenceLabelingProject', SEQ2SEQ: 'Seq2seqProject', - SPEECH2TEXT: 'Speech2TextProject' + SPEECH2TEXT: 'Speech2TextProject', + IMAGE_CLASSIFICATION: 'ImageClassificationProject' }.get(task, 'Project') project = mommy.make( _model=project_model, @@ -89,11 +90,11 @@ def make_label(project): def make_doc(project): - return mommy.make('Example', project=project) + return mommy.make('Example', text='example', project=project) -def make_image(project): - return mommy.make('Example', project=project) +def make_image(project, filepath): + return mommy.make('Example', filename=filepath, project=project) def make_comment(doc, user): @@ -104,6 +105,10 @@ def make_example_state(example, user): return mommy.make('ExampleState', example=example, confirmed_by=user) +def make_auto_labeling_config(project): + return mommy.make('AutoLabelingConfig', project=project) + + def make_annotation(task, doc, user): annotation_model = { DOCUMENT_CLASSIFICATION: 'Category', diff --git a/backend/api/views/auto_labeling.py b/backend/api/views/auto_labeling.py index fd7c832a..06a8d43e 100644 --- a/backend/api/views/auto_labeling.py +++ b/backend/api/views/auto_labeling.py @@ -223,14 +223,21 @@ class AutoLabelingAnnotation(generics.CreateAPIView): def perform_create(self, serializer): serializer.save(user=self.request.user) + def get_example(self, project): + example = get_object_or_404(Example, pk=self.kwargs['doc_id']) + if project.is_task_of('text'): + return example.text + else: + return load_data_as_b64(str(example.filename)) + def extract(self): project = get_object_or_404(Project, pk=self.kwargs['project_id']) - doc = get_object_or_404(Example, pk=self.kwargs['doc_id']) + example = self.get_example(project) config = project.auto_labeling_config.first() if not config: raise AutoLabeliingPermissionDenied() return execute_pipeline( - text=doc.text, + text=example, project_type=project.project_type, model_name=config.model_name, model_attrs=config.model_attrs,