You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

266 lines
11 KiB

  1. import pathlib
  2. from unittest.mock import patch
  3. from auto_labeling_pipeline.mappings import AmazonComprehendSentimentTemplate
  4. from auto_labeling_pipeline.models import RequestModelFactory
  5. from model_mommy import mommy
  6. from rest_framework import status
  7. from rest_framework.reverse import reverse
  8. from api.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
  9. from labels.models import Category, Span, TextLabel
  10. from api.tests.api.utils import CRUDMixin, make_doc, prepare_project
  11. from auto_labeling.pipeline.labels import Categories, Spans, Texts
  12. data_dir = pathlib.Path(__file__).parent / 'data'
  13. class TestTemplateList(CRUDMixin):
  14. def setUp(self):
  15. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  16. self.url = reverse(viewname='auto_labeling_templates', args=[self.project.item.id])
  17. def test_allow_admin_to_fetch_template_list(self):
  18. self.url += '?task_name=DocumentClassification'
  19. response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
  20. self.assertIn('Custom REST Request', response.data)
  21. self.assertGreaterEqual(len(response.data), 1)
  22. def test_deny_non_admin_to_fetch_template_list(self):
  23. self.url += '?task_name=DocumentClassification'
  24. for user in self.project.users[1:]:
  25. self.assert_fetch(user, status.HTTP_403_FORBIDDEN)
  26. def test_return_only_default_template_with_empty_task_name(self):
  27. response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
  28. self.assertEqual(len(response.data), 1)
  29. self.assertIn('Custom REST Request', response.data)
  30. def test_return_only_default_template_with_wrong_task_name(self):
  31. self.url += '?task_name=foobar'
  32. response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
  33. self.assertEqual(len(response.data), 1)
  34. self.assertIn('Custom REST Request', response.data)
  35. class TestConfigParameter(CRUDMixin):
  36. def setUp(self):
  37. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  38. self.data = {
  39. 'model_name': 'GCP Entity Analysis',
  40. 'model_attrs': {'key': 'hoge', 'type': 'PLAIN_TEXT', 'language': 'en'},
  41. 'text': 'example'
  42. }
  43. self.url = reverse(viewname='auto_labeling_parameter_testing', args=[self.project.item.id])
  44. @patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={})
  45. def test_called_with_proper_model(self, mock):
  46. self.assert_create(self.project.users[0], status.HTTP_200_OK)
  47. _, kwargs = mock.call_args
  48. expected = RequestModelFactory.create(self.data['model_name'], self.data['model_attrs'])
  49. self.assertEqual(kwargs['model'], expected)
  50. @patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={})
  51. def test_called_with_text(self, mock):
  52. self.assert_create(self.project.users[0], status.HTTP_200_OK)
  53. _, kwargs = mock.call_args
  54. self.assertEqual(kwargs['example'], self.data['text'])
  55. @patch('auto_labeling.views.RestAPIRequestTesting.send_request', return_value={})
  56. def test_called_with_image(self, mock):
  57. self.data['text'] = str(data_dir / 'images/1500x500.jpeg')
  58. self.assert_create(self.project.users[0], status.HTTP_200_OK)
  59. _, kwargs = mock.call_args
  60. self.assertEqual(kwargs['example'], self.data['text'])
  61. class TestTemplateMapping(CRUDMixin):
  62. def setUp(self):
  63. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  64. self.data = {
  65. 'response': {
  66. 'Sentiment': 'NEUTRAL',
  67. 'SentimentScore': {
  68. 'Positive': 0.004438233096152544,
  69. 'Negative': 0.0005306027014739811,
  70. 'Neutral': 0.9950305223464966,
  71. 'Mixed': 5.80838445785048e-7
  72. }
  73. },
  74. 'template': AmazonComprehendSentimentTemplate().load(),
  75. 'task_type': 'Category'
  76. }
  77. self.url = reverse(viewname='auto_labeling_template_test', args=[self.project.item.id])
  78. def test_template_mapping(self):
  79. response = self.assert_create(self.project.users[0], status.HTTP_200_OK)
  80. expected = [{'label': 'NEUTRAL'}]
  81. self.assertEqual(response.json(), expected)
  82. def test_json_decode_error(self):
  83. self.data['template'] = ''
  84. self.assert_create(self.project.users[0], status.HTTP_400_BAD_REQUEST)
  85. class TestLabelMapping(CRUDMixin):
  86. def setUp(self):
  87. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  88. self.data = {
  89. 'response': [{'label': 'NEGATIVE'}],
  90. 'label_mapping': {'NEGATIVE': 'Negative'},
  91. 'task_type': 'Category'
  92. }
  93. self.url = reverse(viewname='auto_labeling_mapping_test', args=[self.project.item.id])
  94. def test_label_mapping(self):
  95. response = self.assert_create(self.project.users[0], status.HTTP_200_OK)
  96. expected = [{'label': 'Negative'}]
  97. self.assertEqual(response.json(), expected)
  98. class TestConfigCreation(CRUDMixin):
  99. def setUp(self):
  100. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  101. self.data = {
  102. 'model_name': 'Amazon Comprehend Sentiment Analysis',
  103. 'model_attrs': {
  104. 'aws_access_key': 'str',
  105. 'aws_secret_access_key': 'str',
  106. 'region_name': 'us-east-1',
  107. 'language_code': 'en'
  108. },
  109. 'template': AmazonComprehendSentimentTemplate().load(),
  110. 'label_mapping': {'NEGATIVE': 'Negative'},
  111. 'task_type': 'Category'
  112. }
  113. self.url = reverse(viewname='auto_labeling_configs', args=[self.project.item.id])
  114. def test_create_config(self):
  115. response = self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
  116. self.assertEqual(response.data['model_name'], self.data['model_name'])
  117. def test_list_config(self):
  118. mommy.make('AutoLabelingConfig', project=self.project.item)
  119. response = self.assert_fetch(self.project.users[0], status.HTTP_200_OK)
  120. self.assertEqual(len(response.data), 1)
  121. class TestAutomatedLabeling(CRUDMixin):
  122. def setUp(self):
  123. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION, single_class_classification=False)
  124. self.example = make_doc(self.project.item)
  125. self.category_pos = mommy.make(
  126. 'CategoryType', project=self.project.item, text='POS'
  127. )
  128. self.category_neg = mommy.make(
  129. 'CategoryType', project=self.project.item, text='NEG'
  130. )
  131. self.loc = mommy.make('SpanType', project=self.project.item, text='LOC')
  132. self.url = reverse(viewname='auto_labeling', args=[self.project.item.id])
  133. self.url += f'?example={self.example.id}'
  134. @patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}]))
  135. def test_category_labeling(self, mock):
  136. mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
  137. self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
  138. self.assertEqual(Category.objects.count(), 1)
  139. self.assertEqual(Category.objects.first().label, self.category_pos)
  140. @patch(
  141. 'auto_labeling.views.execute_pipeline',
  142. side_effect=[
  143. Categories([{'label': 'POS'}]),
  144. Categories([{'label': 'NEG'}])
  145. ]
  146. )
  147. def test_multiple_configs(self, mock):
  148. mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
  149. mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
  150. self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
  151. self.assertEqual(Category.objects.count(), 2)
  152. self.assertEqual(Category.objects.first().label, self.category_pos)
  153. self.assertEqual(Category.objects.last().label, self.category_neg)
  154. @patch(
  155. 'auto_labeling.views.execute_pipeline',
  156. side_effect=[
  157. Categories([{'label': 'POS'}]),
  158. Categories([{'label': 'POS'}])
  159. ]
  160. )
  161. def test_cannot_label_same_category_type(self, mock):
  162. mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
  163. mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
  164. self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
  165. self.assertEqual(Category.objects.count(), 1)
  166. @patch(
  167. 'auto_labeling.views.execute_pipeline',
  168. side_effect=[
  169. Categories([{'label': 'POS'}]),
  170. Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]),
  171. ]
  172. )
  173. def test_allow_multi_type_configs(self, mock):
  174. mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
  175. mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
  176. self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
  177. self.assertEqual(Category.objects.count(), 1)
  178. self.assertEqual(Span.objects.count(), 1)
  179. @patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}]))
  180. def test_cannot_use_other_project_config(self, mock):
  181. mommy.make('AutoLabelingConfig', task_type='Category')
  182. self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
  183. self.assertEqual(Category.objects.count(), 0)
  184. class TestAutomatedSpanLabeling(CRUDMixin):
  185. def setUp(self):
  186. self.project = prepare_project(task=SEQUENCE_LABELING)
  187. self.example = make_doc(self.project.item)
  188. self.loc = mommy.make('SpanType', project=self.project.item, text='LOC')
  189. self.url = reverse(viewname='auto_labeling', args=[self.project.item.id])
  190. self.url += f'?example={self.example.id}'
  191. @patch(
  192. 'auto_labeling.views.execute_pipeline',
  193. side_effect=[
  194. Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]),
  195. Spans([{'label': 'LOC', 'start_offset': 4, 'end_offset': 10}])
  196. ]
  197. )
  198. def test_cannot_label_overlapping_span(self, mock):
  199. mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
  200. mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
  201. self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
  202. self.assertEqual(Span.objects.count(), 1)
  203. class TestAutomatedTextLabeling(CRUDMixin):
  204. def setUp(self):
  205. self.project = prepare_project(task=SEQ2SEQ)
  206. self.example = make_doc(self.project.item)
  207. self.url = reverse(viewname='auto_labeling', args=[self.project.item.id])
  208. self.url += f'?example={self.example.id}'
  209. @patch(
  210. 'auto_labeling.views.execute_pipeline',
  211. side_effect=[
  212. Texts([{'text': 'foo'}]),
  213. Texts([{'text': 'foo'}])
  214. ]
  215. )
  216. def test_cannot_label_same_text(self, mock):
  217. mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item)
  218. mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item)
  219. self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
  220. self.assertEqual(TextLabel.objects.count(), 1)