|
@ -7,10 +7,10 @@ from model_mommy import mommy |
|
|
from rest_framework import status |
|
|
from rest_framework import status |
|
|
from rest_framework.reverse import reverse |
|
|
from rest_framework.reverse import reverse |
|
|
|
|
|
|
|
|
from api.models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ |
|
|
|
|
|
|
|
|
from api.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ |
|
|
from api.models import Category, Span, TextLabel |
|
|
from api.models import Category, Span, TextLabel |
|
|
from api.tests.api.utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image, |
|
|
|
|
|
prepare_project) |
|
|
|
|
|
|
|
|
from api.tests.api.utils import CRUDMixin, make_doc, prepare_project |
|
|
|
|
|
from auto_labeling.pipeline.execution import Categories, Spans, Texts |
|
|
|
|
|
|
|
|
data_dir = pathlib.Path(__file__).parent / 'data' |
|
|
data_dir = pathlib.Path(__file__).parent / 'data' |
|
|
|
|
|
|
|
@ -148,38 +148,6 @@ class TestConfigCreation(CRUDMixin): |
|
|
self.assertEqual(len(response.data), 1) |
|
|
self.assertEqual(len(response.data), 1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestAutoLabelingText(CRUDMixin): |
|
|
|
|
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
|
|
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) |
|
|
|
|
|
make_auto_labeling_config(self.project.item) |
|
|
|
|
|
self.example = make_doc(self.project.item) |
|
|
|
|
|
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=[]) |
|
|
|
|
|
def test_text_task(self, mock): |
|
|
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
|
|
_, kwargs = mock.call_args |
|
|
|
|
|
self.assertEqual(kwargs['text'], self.example.text) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestAutoLabelingImage(CRUDMixin): |
|
|
|
|
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
|
|
self.project = prepare_project(task=IMAGE_CLASSIFICATION) |
|
|
|
|
|
make_auto_labeling_config(self.project.item) |
|
|
|
|
|
filepath = data_dir / 'images/1500x500.jpeg' |
|
|
|
|
|
self.example = make_image(self.project.item, str(filepath)) |
|
|
|
|
|
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=[]) |
|
|
|
|
|
def test_text_task(self, mock): |
|
|
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
|
|
_, kwargs = mock.call_args |
|
|
|
|
|
expected = str(self.example.filename) |
|
|
|
|
|
self.assertEqual(kwargs['text'], expected) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestAutomatedCategoryLabeling(CRUDMixin): |
|
|
class TestAutomatedCategoryLabeling(CRUDMixin): |
|
|
|
|
|
|
|
|
def setUp(self): |
|
|
def setUp(self): |
|
@ -191,30 +159,63 @@ class TestAutomatedCategoryLabeling(CRUDMixin): |
|
|
self.category_neg = mommy.make( |
|
|
self.category_neg = mommy.make( |
|
|
'CategoryType', project=self.project.item, text='NEG' |
|
|
'CategoryType', project=self.project.item, text='NEG' |
|
|
) |
|
|
) |
|
|
self.url = reverse(viewname='automated_category_labeling', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
|
self.loc = mommy.make('SpanType', project=self.project.item, text='LOC') |
|
|
|
|
|
self.url = reverse(viewname='automated_labeling', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=[{'label': 'POS'}]) |
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}])) |
|
|
def test_category_labeling(self, mock): |
|
|
def test_category_labeling(self, mock): |
|
|
mommy.make('AutoLabelingConfig', task_type='Category') |
|
|
|
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assertEqual(Category.objects.count(), 1) |
|
|
self.assertEqual(Category.objects.count(), 1) |
|
|
self.assertEqual(Category.objects.first().label, self.category_pos) |
|
|
self.assertEqual(Category.objects.first().label, self.category_pos) |
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'label': 'POS'}], [{'label': 'NEG'}]]) |
|
|
|
|
|
|
|
|
@patch( |
|
|
|
|
|
'auto_labeling.views.execute_pipeline', |
|
|
|
|
|
side_effect=[ |
|
|
|
|
|
Categories([{'label': 'POS'}]), |
|
|
|
|
|
Categories([{'label': 'NEG'}]) |
|
|
|
|
|
] |
|
|
|
|
|
) |
|
|
def test_multiple_configs(self, mock): |
|
|
def test_multiple_configs(self, mock): |
|
|
mommy.make('AutoLabelingConfig', task_type='Category') |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category') |
|
|
|
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assertEqual(Category.objects.count(), 2) |
|
|
self.assertEqual(Category.objects.count(), 2) |
|
|
self.assertEqual(Category.objects.first().label, self.category_pos) |
|
|
self.assertEqual(Category.objects.first().label, self.category_pos) |
|
|
self.assertEqual(Category.objects.last().label, self.category_neg) |
|
|
self.assertEqual(Category.objects.last().label, self.category_neg) |
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'label': 'POS'}], [{'label': 'POS'}]]) |
|
|
|
|
|
|
|
|
@patch( |
|
|
|
|
|
'auto_labeling.views.execute_pipeline', |
|
|
|
|
|
side_effect=[ |
|
|
|
|
|
Categories([{'label': 'POS'}]), |
|
|
|
|
|
Categories([{'label': 'POS'}]) |
|
|
|
|
|
] |
|
|
|
|
|
) |
|
|
def test_cannot_label_same_category_type(self, mock): |
|
|
def test_cannot_label_same_category_type(self, mock): |
|
|
mommy.make('AutoLabelingConfig', task_type='Category') |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category') |
|
|
|
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
|
|
self.assertEqual(Category.objects.count(), 1) |
|
|
|
|
|
|
|
|
|
|
|
@patch( |
|
|
|
|
|
'auto_labeling.views.execute_pipeline', |
|
|
|
|
|
side_effect=[ |
|
|
|
|
|
Categories([{'label': 'POS'}]), |
|
|
|
|
|
Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]), |
|
|
|
|
|
] |
|
|
|
|
|
) |
|
|
|
|
|
def test_allow_multi_type_configs(self, mock): |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assertEqual(Category.objects.count(), 1) |
|
|
self.assertEqual(Category.objects.count(), 1) |
|
|
|
|
|
self.assertEqual(Span.objects.count(), 1) |
|
|
|
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}])) |
|
|
|
|
|
def test_cannot_use_other_project_config(self, mock): |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category') |
|
|
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
|
|
self.assertEqual(Category.objects.count(), 0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestAutomatedSpanLabeling(CRUDMixin): |
|
|
class TestAutomatedSpanLabeling(CRUDMixin): |
|
@ -223,46 +224,18 @@ class TestAutomatedSpanLabeling(CRUDMixin): |
|
|
self.project = prepare_project(task=SEQUENCE_LABELING) |
|
|
self.project = prepare_project(task=SEQUENCE_LABELING) |
|
|
self.example = make_doc(self.project.item) |
|
|
self.example = make_doc(self.project.item) |
|
|
self.loc = mommy.make('SpanType', project=self.project.item, text='LOC') |
|
|
self.loc = mommy.make('SpanType', project=self.project.item, text='LOC') |
|
|
self.url = reverse(viewname='automated_span_labeling', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]) |
|
|
|
|
|
def test_span_labeling(self, mock): |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span') |
|
|
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
|
|
self.assertEqual(Span.objects.count(), 1) |
|
|
|
|
|
self.assertEqual(Span.objects.first().label, self.loc) |
|
|
|
|
|
|
|
|
|
|
|
@patch( |
|
|
|
|
|
'auto_labeling.views.execute_pipeline', |
|
|
|
|
|
side_effect=[ |
|
|
|
|
|
[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}], |
|
|
|
|
|
[{'label': 'LOC', 'start_offset': 5, 'end_offset': 10}] |
|
|
|
|
|
] |
|
|
|
|
|
) |
|
|
|
|
|
def test_multiple_configs(self, mock): |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span') |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span') |
|
|
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
|
|
expected_spans = [ |
|
|
|
|
|
{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}, |
|
|
|
|
|
{'label': 'LOC', 'start_offset': 5, 'end_offset': 10} |
|
|
|
|
|
] |
|
|
|
|
|
self.assertEqual(Span.objects.count(), len(expected_spans)) |
|
|
|
|
|
for actual, expected in zip(Span.objects.all(), expected_spans): |
|
|
|
|
|
self.assertEqual(actual.label, self.loc) |
|
|
|
|
|
self.assertEqual(actual.start_offset, expected['start_offset']) |
|
|
|
|
|
self.assertEqual(actual.end_offset, expected['end_offset']) |
|
|
|
|
|
|
|
|
self.url = reverse(viewname='automated_labeling', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
|
@patch( |
|
|
@patch( |
|
|
'auto_labeling.views.execute_pipeline', |
|
|
'auto_labeling.views.execute_pipeline', |
|
|
side_effect=[ |
|
|
side_effect=[ |
|
|
[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}], |
|
|
|
|
|
[{'label': 'LOC', 'start_offset': 4, 'end_offset': 10}] |
|
|
|
|
|
|
|
|
Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]), |
|
|
|
|
|
Spans([{'label': 'LOC', 'start_offset': 4, 'end_offset': 10}]) |
|
|
] |
|
|
] |
|
|
) |
|
|
) |
|
|
def test_cannot_label_overlapping_span(self, mock): |
|
|
def test_cannot_label_overlapping_span(self, mock): |
|
|
mommy.make('AutoLabelingConfig', task_type='Span') |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span') |
|
|
|
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item) |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assertEqual(Span.objects.count(), 1) |
|
|
self.assertEqual(Span.objects.count(), 1) |
|
|
|
|
|
|
|
@ -272,27 +245,17 @@ class TestAutomatedTextLabeling(CRUDMixin): |
|
|
def setUp(self): |
|
|
def setUp(self): |
|
|
self.project = prepare_project(task=SEQ2SEQ) |
|
|
self.project = prepare_project(task=SEQ2SEQ) |
|
|
self.example = make_doc(self.project.item) |
|
|
self.example = make_doc(self.project.item) |
|
|
self.url = reverse(viewname='automated_text_labeling', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
|
self.url = reverse(viewname='automated_labeling', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=[{'text': 'foo'}]) |
|
|
|
|
|
def test_category_labeling(self, mock): |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
|
|
self.assertEqual(TextLabel.objects.count(), 1) |
|
|
|
|
|
self.assertEqual(TextLabel.objects.first().text, 'foo') |
|
|
|
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'bar'}]]) |
|
|
|
|
|
def test_multiple_configs(self, mock): |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
|
|
self.assertEqual(TextLabel.objects.count(), 2) |
|
|
|
|
|
self.assertEqual(TextLabel.objects.first().text, 'foo') |
|
|
|
|
|
self.assertEqual(TextLabel.objects.last().text, 'bar') |
|
|
|
|
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'foo'}]]) |
|
|
|
|
|
def test_cannot_label_same_category_type(self, mock): |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
|
|
|
|
|
@patch( |
|
|
|
|
|
'auto_labeling.views.execute_pipeline', |
|
|
|
|
|
side_effect=[ |
|
|
|
|
|
Texts([{'text': 'foo'}]), |
|
|
|
|
|
Texts([{'text': 'foo'}]) |
|
|
|
|
|
] |
|
|
|
|
|
) |
|
|
|
|
|
def test_cannot_label_same_text(self, mock): |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item) |
|
|
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
self.assertEqual(TextLabel.objects.count(), 1) |
|
|
self.assertEqual(TextLabel.objects.count(), 1) |