|
|
@ -23,23 +23,23 @@ class TestTemplateList(CRUDMixin): |
|
|
|
|
|
|
|
def test_allow_admin_to_fetch_template_list(self): |
|
|
|
self.url += '?task_name=DocumentClassification' |
|
|
|
response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK) |
|
|
|
response = self.assert_fetch(self.project.admin, status.HTTP_200_OK) |
|
|
|
self.assertIn('Custom REST Request', response.data) |
|
|
|
self.assertGreaterEqual(len(response.data), 1) |
|
|
|
|
|
|
|
def test_deny_non_admin_to_fetch_template_list(self): |
|
|
|
def test_deny_project_staff_to_fetch_template_list(self): |
|
|
|
self.url += '?task_name=DocumentClassification' |
|
|
|
for user in self.project.users[1:]: |
|
|
|
for user in self.project.staffs: |
|
|
|
self.assert_fetch(user, status.HTTP_403_FORBIDDEN) |
|
|
|
|
|
|
|
def test_return_only_default_template_with_empty_task_name(self): |
|
|
|
response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK) |
|
|
|
response = self.assert_fetch(self.project.admin, status.HTTP_200_OK) |
|
|
|
self.assertEqual(len(response.data), 1) |
|
|
|
self.assertIn('Custom REST Request', response.data) |
|
|
|
|
|
|
|
def test_return_only_default_template_with_wrong_task_name(self): |
|
|
|
self.url += '?task_name=foobar' |
|
|
|
response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK) |
|
|
|
response = self.assert_fetch(self.project.admin, status.HTTP_200_OK) |
|
|
|
self.assertEqual(len(response.data), 1) |
|
|
|
self.assertIn('Custom REST Request', response.data) |
|
|
|
|
|
|
@ -57,21 +57,21 @@ class TestConfigParameter(CRUDMixin): |
|
|
|
|
|
|
|
@patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={}) |
|
|
|
def test_called_with_proper_model(self, mock): |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_200_OK) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_200_OK) |
|
|
|
_, kwargs = mock.call_args |
|
|
|
expected = RequestModelFactory.create(self.data['model_name'], self.data['model_attrs']) |
|
|
|
self.assertEqual(kwargs['model'], expected) |
|
|
|
|
|
|
|
@patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={}) |
|
|
|
def test_called_with_text(self, mock): |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_200_OK) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_200_OK) |
|
|
|
_, kwargs = mock.call_args |
|
|
|
self.assertEqual(kwargs['example'], self.data['text']) |
|
|
|
|
|
|
|
@patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={}) |
|
|
|
def test_called_with_image(self, mock): |
|
|
|
self.data['text'] = str(data_dir / 'images/1500x500.jpeg') |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_200_OK) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_200_OK) |
|
|
|
_, kwargs = mock.call_args |
|
|
|
self.assertEqual(kwargs['example'], self.data['text']) |
|
|
|
|
|
|
@ -96,13 +96,13 @@ class TestTemplateMapping(CRUDMixin): |
|
|
|
self.url = reverse(viewname='auto_labeling_template_test', args=[self.project.item.id]) |
|
|
|
|
|
|
|
def test_template_mapping(self): |
|
|
|
response = self.assert_create(self.project.users[0], status.HTTP_200_OK) |
|
|
|
response = self.assert_create(self.project.admin, status.HTTP_200_OK) |
|
|
|
expected = [{'label': 'NEUTRAL'}] |
|
|
|
self.assertEqual(response.json(), expected) |
|
|
|
|
|
|
|
def test_json_decode_error(self): |
|
|
|
self.data['template'] = '' |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_400_BAD_REQUEST) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_400_BAD_REQUEST) |
|
|
|
|
|
|
|
|
|
|
|
class TestLabelMapping(CRUDMixin): |
|
|
@ -117,7 +117,7 @@ class TestLabelMapping(CRUDMixin): |
|
|
|
self.url = reverse(viewname='auto_labeling_mapping_test', args=[self.project.item.id]) |
|
|
|
|
|
|
|
def test_label_mapping(self): |
|
|
|
response = self.assert_create(self.project.users[0], status.HTTP_200_OK) |
|
|
|
response = self.assert_create(self.project.admin, status.HTTP_200_OK) |
|
|
|
expected = [{'label': 'Negative'}] |
|
|
|
self.assertEqual(response.json(), expected) |
|
|
|
|
|
|
@ -141,12 +141,12 @@ class TestConfigCreation(CRUDMixin): |
|
|
|
self.url = reverse(viewname='auto_labeling_configs', args=[self.project.item.id]) |
|
|
|
|
|
|
|
def test_create_config(self): |
|
|
|
response = self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
response = self.assert_create(self.project.admin, status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(response.data['model_name'], self.data['model_name']) |
|
|
|
|
|
|
|
def test_list_config(self): |
|
|
|
mommy.make('AutoLabelingConfig', project=self.project.item) |
|
|
|
response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK) |
|
|
|
response = self.assert_fetch(self.project.admin, status.HTTP_200_OK) |
|
|
|
self.assertEqual(len(response.data), 1) |
|
|
|
|
|
|
|
|
|
|
@ -168,7 +168,7 @@ class TestAutomatedLabeling(CRUDMixin): |
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}])) |
|
|
|
def test_category_labeling(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(Category.objects.count(), 1) |
|
|
|
self.assertEqual(Category.objects.first().label, self.category_pos) |
|
|
|
|
|
|
@ -182,7 +182,7 @@ class TestAutomatedLabeling(CRUDMixin): |
|
|
|
def test_multiple_configs(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(Category.objects.count(), 2) |
|
|
|
self.assertEqual(Category.objects.first().label, self.category_pos) |
|
|
|
self.assertEqual(Category.objects.last().label, self.category_neg) |
|
|
@ -197,7 +197,7 @@ class TestAutomatedLabeling(CRUDMixin): |
|
|
|
def test_cannot_label_same_category_type(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(Category.objects.count(), 1) |
|
|
|
|
|
|
|
@patch( |
|
|
@ -210,14 +210,14 @@ class TestAutomatedLabeling(CRUDMixin): |
|
|
|
def test_allow_multi_type_configs(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item) |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item) |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(Category.objects.count(), 1) |
|
|
|
self.assertEqual(Span.objects.count(), 1) |
|
|
|
|
|
|
|
@patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}])) |
|
|
|
def test_cannot_use_other_project_config(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Category') |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(Category.objects.count(), 0) |
|
|
|
|
|
|
|
|
|
|
@ -240,7 +240,7 @@ class TestAutomatedSpanLabeling(CRUDMixin): |
|
|
|
def test_cannot_label_overlapping_span(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item) |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item) |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(Span.objects.count(), 1) |
|
|
|
|
|
|
|
|
|
|
@ -262,5 +262,5 @@ class TestAutomatedTextLabeling(CRUDMixin): |
|
|
|
def test_cannot_label_same_text(self, mock): |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item) |
|
|
|
mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item) |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
self.assert_create(self.project.admin, status.HTTP_201_CREATED) |
|
|
|
self.assertEqual(TextLabel.objects.count(), 1) |