|
|
import pathlib from unittest.mock import patch
from auto_labeling_pipeline.mappings import AmazonComprehendSentimentTemplate from auto_labeling_pipeline.models import RequestModelFactory from model_mommy import mommy from rest_framework import status from rest_framework.reverse import reverse
from api.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ from labels.models import Category, Span, TextLabel from api.tests.api.utils import CRUDMixin, make_doc, prepare_project from auto_labeling.pipeline.labels import Categories, Spans, Texts
# Directory named 'data' that sits next to this test module; the image
# fixture referenced by the tests below is resolved relative to it.
data_dir = pathlib.Path(__file__).parent.joinpath('data')
class TestTemplateList(CRUDMixin):
    """Tests for the view registered under 'auto_labeling_templates'.

    Requests are issued through ``assert_fetch``, a helper inherited from
    ``CRUDMixin`` (defined outside this file) that reads ``self.url``.
    """

    def setUp(self):
        """Create a document classification project and resolve its URL."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_templates', args=[project_id])

    def test_allow_admin_to_fetch_template_list(self):
        """The admin gets 200 and a non-empty list with the custom template."""
        self.url = self.url + '?task_name=DocumentClassification'
        res = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
        self.assertIn('Custom REST Request', res.data)
        self.assertGreaterEqual(len(res.data), 1)

    def test_deny_project_staff_to_fetch_template_list(self):
        """Every user in ``project.staffs`` is answered with 403."""
        self.url = self.url + '?task_name=DocumentClassification'
        for member in self.project.staffs:
            self.assert_fetch(member, status.HTTP_403_FORBIDDEN)

    def test_return_only_default_template_with_empty_task_name(self):
        """Without a task_name query only 'Custom REST Request' is listed."""
        res = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
        self.assertEqual(len(res.data), 1)
        self.assertIn('Custom REST Request', res.data)

    def test_return_only_default_template_with_wrong_task_name(self):
        """With task_name=foobar only 'Custom REST Request' is listed."""
        self.url = self.url + '?task_name=foobar'
        res = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
        self.assertEqual(len(res.data), 1)
        self.assertIn('Custom REST Request', res.data)
class TestConfigParameter(CRUDMixin):
    """Tests for the view registered under 'auto_labeling_parameter_testing'.

    ``RestAPIRequestTesting.send_request`` is patched in every test, so only
    the keyword arguments it is called with are inspected.
    """

    def setUp(self):
        """Prepare the project, the request payload and the endpoint URL."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        model_attrs = {'key': 'hoge', 'type': 'PLAIN_TEXT', 'language': 'en'}
        self.data = {
            'model_name': 'GCP Entity Analysis',
            'model_attrs': model_attrs,
            'text': 'example',
        }
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_parameter_testing', args=[project_id])

    @patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={})
    def test_called_with_proper_model(self, mock):
        """``model`` kwarg equals the model built from the payload."""
        self.assert_create(self.project.admin, status.HTTP_200_OK)
        _args, call_kwargs = mock.call_args
        expected = RequestModelFactory.create(
            self.data['model_name'],
            self.data['model_attrs'],
        )
        self.assertEqual(call_kwargs['model'], expected)

    @patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={})
    def test_called_with_text(self, mock):
        """``example`` kwarg equals the payload's 'text' value."""
        self.assert_create(self.project.admin, status.HTTP_200_OK)
        _args, call_kwargs = mock.call_args
        self.assertEqual(call_kwargs['example'], self.data['text'])

    @patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={})
    def test_called_with_image(self, mock):
        """``example`` kwarg equals 'text' when it holds an image file path."""
        image_path = data_dir / 'images/1500x500.jpeg'
        self.data['text'] = str(image_path)
        self.assert_create(self.project.admin, status.HTTP_200_OK)
        _args, call_kwargs = mock.call_args
        self.assertEqual(call_kwargs['example'], self.data['text'])
class TestTemplateMapping(CRUDMixin):
    """Tests for the view registered under 'auto_labeling_template_test'."""

    def setUp(self):
        """Build a payload holding a sample sentiment response and template."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        sentiment_response = {
            'Sentiment': 'NEUTRAL',
            'SentimentScore': {
                'Positive': 0.004438233096152544,
                'Negative': 0.0005306027014739811,
                'Neutral': 0.9950305223464966,
                'Mixed': 5.80838445785048e-7,
            },
        }
        self.data = {
            'response': sentiment_response,
            'template': AmazonComprehendSentimentTemplate().load(),
            'task_type': 'Category',
        }
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_template_test', args=[project_id])

    def test_template_mapping(self):
        """The sample response is mapped to a single 'NEUTRAL' label."""
        res = self.assert_create(self.project.admin, status.HTTP_200_OK)
        expected = [{'label': 'NEUTRAL'}]
        self.assertEqual(res.json(), expected)

    def test_json_decode_error(self):
        """An empty template string is answered with 400."""
        self.data['template'] = ''
        self.assert_create(self.project.admin, status.HTTP_400_BAD_REQUEST)
class TestLabelMapping(CRUDMixin):
    """Tests for the view registered under 'auto_labeling_mapping_test'."""

    def setUp(self):
        """Build a payload with one label and a mapping that renames it."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        self.data = {
            'response': [{'label': 'NEGATIVE'}],
            'label_mapping': {'NEGATIVE': 'Negative'},
            'task_type': 'Category',
        }
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_mapping_test', args=[project_id])

    def test_label_mapping(self):
        """'NEGATIVE' comes back as 'Negative' after the mapping is applied."""
        res = self.assert_create(self.project.admin, status.HTTP_200_OK)
        expected = [{'label': 'Negative'}]
        self.assertEqual(res.json(), expected)
class TestConfigCreation(CRUDMixin):
    """Tests for the view registered under 'auto_labeling_configs'."""

    def setUp(self):
        """Build a complete config payload and resolve the endpoint URL."""
        self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
        model_attrs = {
            'aws_access_key': 'str',
            'aws_secret_access_key': 'str',
            'region_name': 'us-east-1',
            'language_code': 'en',
        }
        self.data = {
            'model_name': 'Amazon Comprehend Sentiment Analysis',
            'model_attrs': model_attrs,
            'template': AmazonComprehendSentimentTemplate().load(),
            'label_mapping': {'NEGATIVE': 'Negative'},
            'task_type': 'Category',
        }
        project_id = self.project.item.id
        self.url = reverse(viewname='auto_labeling_configs', args=[project_id])

    def test_create_config(self):
        """Creating as admin returns 201 and echoes the model name."""
        res = self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(res.data['model_name'], self.data['model_name'])

    def test_list_config(self):
        """One stored config for the project yields a one-element list."""
        mommy.make('AutoLabelingConfig', project=self.project.item)
        res = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
        self.assertEqual(len(res.data), 1)
class TestAutomatedLabeling(CRUDMixin):
    """Tests for the view registered under 'auto_labeling'.

    ``auto_labeling.views.execute_pipeline`` is patched in every test so the
    labels it would produce are fixed by the decorator arguments.
    """

    def setUp(self):
        """Create the project, one example, three label types and the URL."""
        self.project = prepare_project(
            task=DOCUMENT_CLASSIFICATION,
            single_class_classification=False,
        )
        project = self.project.item
        self.example = make_doc(project)
        # Creation order is kept as POS, NEG, LOC.
        self.category_pos = mommy.make('CategoryType', project=project, text='POS')
        self.category_neg = mommy.make('CategoryType', project=project, text='NEG')
        self.loc = mommy.make('SpanType', project=project, text='LOC')
        base_url = reverse(viewname='auto_labeling', args=[project.id])
        self.url = base_url + f'?example={self.example.id}'

    @patch(
        'auto_labeling.views.execute_pipeline',
        return_value=Categories([{'label': 'POS'}]),
    )
    def test_category_labeling(self, mock):
        """One Category config yields one Category labelled POS."""
        mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 1)
        self.assertEqual(Category.objects.first().label, self.category_pos)

    @patch(
        'auto_labeling.views.execute_pipeline',
        side_effect=[
            Categories([{'label': 'POS'}]),
            Categories([{'label': 'NEG'}]),
        ],
    )
    def test_multiple_configs(self, mock):
        """Two configs yielding POS then NEG store two categories in order."""
        for _ in range(2):
            mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 2)
        self.assertEqual(Category.objects.first().label, self.category_pos)
        self.assertEqual(Category.objects.last().label, self.category_neg)

    @patch(
        'auto_labeling.views.execute_pipeline',
        side_effect=[
            Categories([{'label': 'POS'}]),
            Categories([{'label': 'POS'}]),
        ],
    )
    def test_cannot_label_same_category_type(self, mock):
        """Two configs both yielding POS store only one Category."""
        for _ in range(2):
            mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 1)

    @patch(
        'auto_labeling.views.execute_pipeline',
        side_effect=[
            Categories([{'label': 'POS'}]),
            Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]),
        ],
    )
    def test_allow_multi_type_configs(self, mock):
        """A Category config plus a Span config store one label of each kind."""
        for task_type in ('Category', 'Span'):
            mommy.make('AutoLabelingConfig', task_type=task_type, project=self.project.item)
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 1)
        self.assertEqual(Span.objects.count(), 1)

    @patch(
        'auto_labeling.views.execute_pipeline',
        return_value=Categories([{'label': 'POS'}]),
    )
    def test_cannot_use_other_project_config(self, mock):
        """A config not created for ``self.project`` stores no Category."""
        # No ``project=`` argument here, unlike the other tests in this class.
        mommy.make('AutoLabelingConfig', task_type='Category')
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 0)
class TestAutomatedSpanLabeling(CRUDMixin):
    """Tests for the 'auto_labeling' view on a sequence labeling project."""

    def setUp(self):
        """Create the project, one example, the LOC span type and the URL."""
        self.project = prepare_project(task=SEQUENCE_LABELING)
        project = self.project.item
        self.example = make_doc(project)
        self.loc = mommy.make('SpanType', project=project, text='LOC')
        base_url = reverse(viewname='auto_labeling', args=[project.id])
        self.url = base_url + f'?example={self.example.id}'

    @patch(
        'auto_labeling.views.execute_pipeline',
        side_effect=[
            Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]),
            Spans([{'label': 'LOC', 'start_offset': 4, 'end_offset': 10}]),
        ],
    )
    def test_cannot_label_overlapping_span(self, mock):
        """Spans 0-5 and 4-10 from two configs leave only one Span stored."""
        for _ in range(2):
            mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Span.objects.count(), 1)
class TestAutomatedTextLabeling(CRUDMixin):
    """Tests for the 'auto_labeling' view on a seq2seq project."""

    def setUp(self):
        """Create the project, one example and the example-specific URL."""
        self.project = prepare_project(task=SEQ2SEQ)
        project = self.project.item
        self.example = make_doc(project)
        base_url = reverse(viewname='auto_labeling', args=[project.id])
        self.url = base_url + f'?example={self.example.id}'

    @patch(
        'auto_labeling.views.execute_pipeline',
        side_effect=[
            Texts([{'text': 'foo'}]),
            Texts([{'text': 'foo'}]),
        ],
    )
    def test_cannot_label_same_text(self, mock):
        """Two configs both yielding 'foo' leave only one TextLabel stored."""
        for _ in range(2):
            mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item)
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(TextLabel.objects.count(), 1)
|