mirror of https://github.com/doccano/doccano.git
python datasets active-learning text-annotation dataset natural-language-processing data-labeling machine-learning annotation-tool
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
267 lines
11 KiB
267 lines
11 KiB
import pathlib
|
|
from unittest.mock import patch
|
|
|
|
from auto_labeling_pipeline.mappings import AmazonComprehendSentimentTemplate
|
|
from auto_labeling_pipeline.models import RequestModelFactory
|
|
from model_mommy import mommy
|
|
from rest_framework import status
|
|
from rest_framework.reverse import reverse
|
|
|
|
from api.models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQUENCE_LABELING
|
|
from api.models import Category, Span
|
|
from api.tests.api.utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image,
|
|
prepare_project)
|
|
|
|
# Directory with the fixture files for this module; it sits next to this file.
data_dir = pathlib.Path(__file__).parent.joinpath('data')
|
|
|
|
|
|
class TestTemplateList(CRUDMixin):
    """Tests for fetching the list of auto-labeling templates of a project."""

    def setUp(self):
        """Create a document-classification project and resolve the template-list URL."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_templates', args=[project_id])

    def test_allow_admin_to_fetch_template_list(self):
        """The first project user (the admin, going by the test name) gets the templates."""
        self.url = self.url + '?task_name=DocumentClassification'
        admin = self.project.users[0]
        templates = self.assert_fetch(admin, status.HTTP_200_OK).data
        self.assertIn('Custom REST Request', templates)
        self.assertGreaterEqual(len(templates), 1)

    def test_deny_non_admin_to_fetch_template_list(self):
        """Every project user after the first one is answered with 403."""
        self.url = self.url + '?task_name=DocumentClassification'
        non_admins = self.project.users[1:]
        for member in non_admins:
            self.assert_fetch(member, status.HTTP_403_FORBIDDEN)

    def test_return_only_default_template_with_empty_task_name(self):
        """Without a task_name query parameter only 'Custom REST Request' is listed."""
        admin = self.project.users[0]
        templates = self.assert_fetch(admin, status.HTTP_200_OK).data
        self.assertEqual(len(templates), 1)
        self.assertIn('Custom REST Request', templates)

    def test_return_only_default_template_with_wrong_task_name(self):
        """An unknown task_name yields the same single default template."""
        self.url = self.url + '?task_name=foobar'
        admin = self.project.users[0]
        templates = self.assert_fetch(admin, status.HTTP_200_OK).data
        self.assertEqual(len(templates), 1)
        self.assertIn('Custom REST Request', templates)
|
|
|
|
|
|
class TestConfigParameter(CRUDMixin):
    """Tests for the endpoint that tries out auto-labeling request parameters.

    ``RestAPIRequestTesting.send_request`` is patched in every test, so the
    assertions only inspect the keyword arguments it was called with.
    """

    def setUp(self):
        """Build the request payload and resolve the parameter-testing URL."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        model_attrs = {'key': 'hoge', 'type': 'PLAIN_TEXT', 'language': 'en'}
        self.data = {
            'model_name': 'GCP Entity Analysis',
            'model_attrs': model_attrs,
            'text': 'example',
        }
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_parameter_testing', args=[project_id])

    @patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={})
    def test_called_with_proper_model(self, mock_send_request):
        """``send_request`` receives the model built from the posted name and attributes."""
        self.assert_create(self.project.users[0], status.HTTP_200_OK)
        # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
        sent_kwargs = mock_send_request.call_args[1]
        expected_model = RequestModelFactory.create(self.data['model_name'], self.data['model_attrs'])
        self.assertEqual(sent_kwargs['model'], expected_model)

    @patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={})
    def test_called_with_text(self, mock_send_request):
        """The posted ``text`` value is handed to ``send_request`` as ``example``."""
        self.assert_create(self.project.users[0], status.HTTP_200_OK)
        sent_kwargs = mock_send_request.call_args[1]
        self.assertEqual(sent_kwargs['example'], self.data['text'])

    @patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={})
    def test_called_with_image(self, mock_send_request):
        """An image file path sent in ``text`` is forwarded unchanged as ``example``."""
        image_path = data_dir / 'images/1500x500.jpeg'
        self.data['text'] = str(image_path)
        self.assert_create(self.project.users[0], status.HTTP_200_OK)
        sent_kwargs = mock_send_request.call_args[1]
        self.assertEqual(sent_kwargs['example'], self.data['text'])
|
|
|
|
|
|
class TestTemplateMapping(CRUDMixin):
    """Tests for the endpoint that applies a mapping template to a sample response."""

    def setUp(self):
        """Prepare a sample sentiment response together with the matching template."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        sentiment_score = {
            'Positive': 0.004438233096152544,
            'Negative': 0.0005306027014739811,
            'Neutral': 0.9950305223464966,
            'Mixed': 5.80838445785048e-7,
        }
        self.data = {
            'response': {
                'Sentiment': 'NEUTRAL',
                'SentimentScore': sentiment_score,
            },
            'template': AmazonComprehendSentimentTemplate().load(),
        }
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_template_test', args=[project_id])

    def test_template_mapping(self):
        """The sample response is mapped to a single NEUTRAL label."""
        result = self.assert_create(self.project.users[0], status.HTTP_200_OK)
        self.assertEqual(result.json(), [{'label': 'NEUTRAL'}])

    def test_json_decode_error(self):
        """An empty template string is answered with 400."""
        self.data['template'] = ''
        self.assert_create(self.project.users[0], status.HTTP_400_BAD_REQUEST)
|
|
|
|
|
|
class TestLabelMapping(CRUDMixin):
    """Tests for the endpoint that renames labels according to a label mapping."""

    def setUp(self):
        """Prepare one NEGATIVE label and a mapping that renames it to 'Negative'."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        labels = [{'label': 'NEGATIVE'}]
        mapping = {'NEGATIVE': 'Negative'}
        self.data = {'response': labels, 'label_mapping': mapping}
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_mapping_test', args=[project_id])

    def test_label_mapping(self):
        """The returned label carries the mapped name."""
        result = self.assert_create(self.project.users[0], status.HTTP_200_OK)
        self.assertEqual(result.json(), [{'label': 'Negative'}])
|
|
|
|
|
|
class TestConfigCreation(CRUDMixin):
    """Tests for creating and listing auto-labeling configs of a project."""

    def setUp(self):
        """Build a complete config payload and resolve the configs URL."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        # Placeholder values only; nothing in this class shows them being used against AWS.
        model_attrs = {
            'aws_access_key': 'str',
            'aws_secret_access_key': 'str',
            'region_name': 'us-east-1',
            'language_code': 'en',
        }
        self.data = {
            'model_name': 'Amazon Comprehend Sentiment Analysis',
            'model_attrs': model_attrs,
            'template': AmazonComprehendSentimentTemplate().load(),
            'label_mapping': {'NEGATIVE': 'Negative'},
            'task_type': 'Category',
        }
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_configs', args=[project_id])

    def test_create_config(self):
        """Creating a config answers 201 and echoes the model name."""
        created = self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
        self.assertEqual(created.data['model_name'], self.data['model_name'])

    def test_list_config(self):
        """A config stored for the project shows up as the only list entry."""
        mommy.make('AutoLabelingConfig', project=self.project.item)
        listing = self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
        self.assertEqual(len(listing.data), 1)
|
|
|
|
|
|
class TestAutoLabelingText(CRUDMixin):
    """Tests for the auto-labeling annotation endpoint on a text example."""

    def setUp(self):
        """Create a project with an auto-labeling config and one document."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        project = self.project.item
        make_auto_labeling_config(project)
        self.example = make_doc(project)
        self.url = reverse(viewname='auto_labeling_annotation', args=[project.id, self.example.id])

    @patch('auto_labeling.views.execute_pipeline', return_value=[])
    def test_text_task(self, mock_pipeline):
        """The pipeline is invoked with the example's text as the ``text`` keyword."""
        self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
        # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
        pipeline_kwargs = mock_pipeline.call_args[1]
        self.assertEqual(pipeline_kwargs['text'], self.example.text)
|
|
|
|
|
|
class TestAutoLabelingImage(CRUDMixin):
    """Tests for the auto-labeling annotation endpoint on an image example."""

    def setUp(self):
        """Create an image-classification project, a config and one image example."""
        self.project = prepare_project(task=IMAGE_CLASSIFICATION)
        project = self.project.item
        make_auto_labeling_config(project)
        image_path = data_dir / 'images/1500x500.jpeg'
        self.example = make_image(project, str(image_path))
        self.url = reverse(viewname='auto_labeling_annotation', args=[project.id, self.example.id])

    @patch('auto_labeling.views.execute_pipeline', return_value=[])
    def test_text_task(self, mock_pipeline):
        """The pipeline receives the example's filename (as a string) in ``text``."""
        # NOTE(review): the method name says "text" although this class exercises an
        # image example; the name is kept so existing test selections keep working.
        self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
        pipeline_kwargs = mock_pipeline.call_args[1]
        self.assertEqual(pipeline_kwargs['text'], str(self.example.filename))
|
|
|
|
|
|
class TestAutomatedCategoryLabeling(CRUDMixin):
    """Tests for automated category labeling of a single example.

    ``execute_pipeline`` is patched in every test, so the labels "predicted"
    for the example are fully controlled by the decorator arguments.
    """

    def setUp(self):
        """Create a multi-label classification project with POS/NEG category types."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION, single_class_classification=False)
        self.example = make_doc(self.project.item)
        self.category_pos = mommy.make(
            'CategoryType', project=self.project.item, text='POS'
        )
        self.category_neg = mommy.make(
            'CategoryType', project=self.project.item, text='NEG'
        )
        self.url = reverse(viewname='automated_category_labeling', args=[self.project.item.id, self.example.id])

    @patch('auto_labeling.views.execute_pipeline', return_value=[{'label': 'POS'}])
    def test_category_labeling(self, mock):
        """A single config predicting POS produces exactly one POS category."""
        # Attach the config to the project under test; without ``project=`` the
        # fixture library creates it for an unrelated, auto-generated project.
        mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
        self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 1)
        self.assertEqual(Category.objects.first().label, self.category_pos)

    @patch('auto_labeling.views.execute_pipeline', side_effect=[[{'label': 'POS'}], [{'label': 'NEG'}]])
    def test_multiple_configs(self, mock):
        """Two configs predicting different labels produce two categories."""
        # side_effect hands out one prediction list per pipeline call.
        mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
        mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
        self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 2)
        self.assertEqual(Category.objects.first().label, self.category_pos)
        self.assertEqual(Category.objects.last().label, self.category_neg)

    @patch('auto_labeling.views.execute_pipeline', side_effect=[[{'label': 'POS'}], [{'label': 'POS'}]])
    def test_cannot_label_same_category_type(self, mock):
        """Two configs predicting the same label still yield a single category."""
        mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
        mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
        self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 1)
|
|
|
|
|
|
class TestAutomatedSpanLabeling(CRUDMixin):
    """Tests for automated span labeling of a single example.

    ``execute_pipeline`` is patched in every test, so the spans "predicted"
    for the example are fully controlled by the decorator arguments.
    """

    def setUp(self):
        """Create a sequence-labeling project with one document and a LOC span type."""
        self.project = prepare_project(task=SEQUENCE_LABELING)
        self.example = make_doc(self.project.item)
        self.loc = mommy.make('SpanType', project=self.project.item, text='LOC')
        self.url = reverse(viewname='automated_span_labeling', args=[self.project.item.id, self.example.id])

    @patch('auto_labeling.views.execute_pipeline', return_value=[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}])
    def test_span_labeling(self, mock):
        """A single config predicting one LOC span produces exactly that span."""
        # Attach the config to the project under test; without ``project=`` the
        # fixture library creates it for an unrelated, auto-generated project.
        mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
        self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
        self.assertEqual(Span.objects.count(), 1)
        self.assertEqual(Span.objects.first().label, self.loc)

    @patch(
        'auto_labeling.views.execute_pipeline',
        side_effect=[
            [{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}],
            [{'label': 'LOC', 'start_offset': 5, 'end_offset': 10}]
        ]
    )
    def test_multiple_configs(self, mock):
        """Two configs predicting adjacent, non-overlapping spans produce both spans."""
        mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
        mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
        self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
        expected_spans = [
            {'label': 'LOC', 'start_offset': 0, 'end_offset': 5},
            {'label': 'LOC', 'start_offset': 5, 'end_offset': 10}
        ]
        self.assertEqual(Span.objects.count(), len(expected_spans))
        # Order explicitly so the pairwise comparison does not depend on the
        # database's row order; expected_spans is sorted by start_offset.
        actual_spans = Span.objects.order_by('start_offset')
        for actual, expected in zip(actual_spans, expected_spans):
            self.assertEqual(actual.label, self.loc)
            self.assertEqual(actual.start_offset, expected['start_offset'])
            self.assertEqual(actual.end_offset, expected['end_offset'])

    @patch(
        'auto_labeling.views.execute_pipeline',
        side_effect=[
            [{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}],
            [{'label': 'LOC', 'start_offset': 4, 'end_offset': 10}]
        ]
    )
    def test_cannot_label_overlapping_span(self, mock):
        """A second predicted span overlapping the first one is not stored."""
        mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
        mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
        self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
        self.assertEqual(Span.objects.count(), 1)
|