|
|
@ -7,8 +7,8 @@ from model_mommy import mommy |
|
|
|
from rest_framework import status |
|
|
|
from rest_framework.reverse import reverse |
|
|
|
|
|
|
|
from api.models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQUENCE_LABELING |
|
|
|
from api.models import Category, Span |
|
|
|
from api.models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ |
|
|
|
from api.models import Category, Span, TextLabel |
|
|
|
from api.tests.api.utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image, |
|
|
|
prepare_project) |
|
|
|
|
|
|
@ -265,3 +265,34 @@ class TestAutomatedSpanLabeling(CRUDMixin): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span') |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(Span.objects.count(), 1) |
|
|
|
|
|
|
|
|
|
|
|
class TestAutomatedTextLabeling(CRUDMixin): |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
self.project = prepare_project(task=SEQ2SEQ) |
|
|
|
self.example = make_doc(self.project.item) |
|
|
|
self.url = reverse(viewname='automated_text_labeling', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=[{'text': 'foo'}]) |
|
|
|
def test_category_labeling(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(TextLabel.objects.count(), 1) |
|
|
|
self.assertEqual(TextLabel.objects.first().text, 'foo') |
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'bar'}]]) |
|
|
|
def test_multiple_configs(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(TextLabel.objects.count(), 2) |
|
|
|
self.assertEqual(TextLabel.objects.first().text, 'foo') |
|
|
|
self.assertEqual(TextLabel.objects.last().text, 'bar') |
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'foo'}]]) |
|
|
|
def test_cannot_label_same_category_type(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text') |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(TextLabel.objects.count(), 1) |