Browse Source

Update AutoLabelingAnnotation API to handle image data

pull/1413/head
Hironsan 4 years ago
parent
commit
25989c265d
3 changed files with 55 additions and 10 deletions
  1. 37
      backend/api/tests/api/test_auto_labeling.py
  2. 17
      backend/api/tests/api/utils.py
  3. 11
      backend/api/views/auto_labeling.py

37
backend/api/tests/api/test_auto_labeling.py

@ -6,9 +6,10 @@ from auto_labeling_pipeline.models import RequestModelFactory
from rest_framework import status
from rest_framework.reverse import reverse
from ...models import DOCUMENT_CLASSIFICATION
from ...models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION
from ...views.auto_labeling import load_data_as_b64
from .utils import CRUDMixin, prepare_project
from .utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image,
prepare_project)
data_dir = pathlib.Path(__file__).parent / 'data'
@ -108,3 +109,35 @@ class TestConfigCreation(CRUDMixin):
def test_create_config(self):
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
class TestAutoLabelingText(CRUDMixin):
def setUp(self):
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
make_auto_labeling_config(self.project.item)
self.example = make_doc(self.project.item)
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id])
@patch('api.views.auto_labeling.execute_pipeline', return_value=[])
def test_text_task(self, mock):
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
_, kwargs = mock.call_args
self.assertEqual(kwargs['text'], self.example.text)
class TestAutoLabelingImage(CRUDMixin):
def setUp(self):
self.project = prepare_project(task=IMAGE_CLASSIFICATION)
make_auto_labeling_config(self.project.item)
filepath = data_dir / 'images/1500x500.jpeg'
self.example = make_image(self.project.item, str(filepath))
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id])
@patch('api.views.auto_labeling.execute_pipeline', return_value=[])
def test_text_task(self, mock):
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
_, kwargs = mock.call_args
expected = load_data_as_b64(str(self.example.filename))
self.assertEqual(kwargs['text'], expected)

17
backend/api/tests/api/utils.py

@ -8,8 +8,8 @@ from model_mommy import mommy
from rest_framework import status
from rest_framework.test import APITestCase
from ...models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
SPEECH2TEXT, Role, RoleMapping)
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT, Role, RoleMapping)
DATA_DIR = os.path.join(os.path.dirname(__file__), '../data')
@ -61,7 +61,8 @@ def make_project(
DOCUMENT_CLASSIFICATION: 'TextClassificationProject',
SEQUENCE_LABELING: 'SequenceLabelingProject',
SEQ2SEQ: 'Seq2seqProject',
SPEECH2TEXT: 'Speech2TextProject'
SPEECH2TEXT: 'Speech2TextProject',
IMAGE_CLASSIFICATION: 'ImageClassificationProject'
}.get(task, 'Project')
project = mommy.make(
_model=project_model,
@ -89,11 +90,11 @@ def make_label(project):
def make_doc(project):
return mommy.make('Example', project=project)
return mommy.make('Example', text='example', project=project)
def make_image(project):
return mommy.make('Example', project=project)
def make_image(project, filepath):
return mommy.make('Example', filename=filepath, project=project)
def make_comment(doc, user):
@ -104,6 +105,10 @@ def make_example_state(example, user):
return mommy.make('ExampleState', example=example, confirmed_by=user)
def make_auto_labeling_config(project):
return mommy.make('AutoLabelingConfig', project=project)
def make_annotation(task, doc, user):
annotation_model = {
DOCUMENT_CLASSIFICATION: 'Category',

11
backend/api/views/auto_labeling.py

@ -223,14 +223,21 @@ class AutoLabelingAnnotation(generics.CreateAPIView):
def perform_create(self, serializer):
serializer.save(user=self.request.user)
def get_example(self, project):
example = get_object_or_404(Example, pk=self.kwargs['doc_id'])
if project.is_task_of('text'):
return example.text
else:
return load_data_as_b64(str(example.filename))
def extract(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
doc = get_object_or_404(Example, pk=self.kwargs['doc_id'])
example = self.get_example(project)
config = project.auto_labeling_config.first()
if not config:
raise AutoLabeliingPermissionDenied()
return execute_pipeline(
text=doc.text,
text=example,
project_type=project.project_type,
model_name=config.model_name,
model_attrs=config.model_attrs,

Loading…
Cancel
Save