You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

250 lines
11 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import pathlib
  2. from unittest.mock import patch
  3. from auto_labeling_pipeline.mappings import AmazonComprehendSentimentTemplate
  4. from auto_labeling_pipeline.models import RequestModelFactory
  5. from model_mommy import mommy
  6. from rest_framework import status
  7. from rest_framework.reverse import reverse
  8. from api.tests.utils import CRUDMixin
  9. from auto_labeling.pipeline.labels import Categories, Spans, Texts
  10. from examples.tests.utils import make_doc
  11. from labels.models import Category, Span, TextLabel
  12. from projects.models import ProjectType
  13. from projects.tests.utils import prepare_project
  14. data_dir = pathlib.Path(__file__).parent / "data"
  15. class TestTemplateList(CRUDMixin):
  16. def setUp(self):
  17. self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
  18. self.url = reverse(viewname="auto_labeling_templates", args=[self.project.item.id])
  19. def test_allow_admin_to_fetch_template_list(self):
  20. self.url += "?task_name=DocumentClassification"
  21. response = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
  22. self.assertIn("Custom REST Request", response.data)
  23. self.assertGreaterEqual(len(response.data), 1)
  24. def test_deny_project_staff_to_fetch_template_list(self):
  25. self.url += "?task_name=DocumentClassification"
  26. for user in self.project.staffs:
  27. self.assert_fetch(user, status.HTTP_403_FORBIDDEN)
  28. def test_return_only_default_template_with_empty_task_name(self):
  29. response = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
  30. self.assertEqual(len(response.data), 1)
  31. self.assertIn("Custom REST Request", response.data)
  32. def test_return_only_default_template_with_wrong_task_name(self):
  33. self.url += "?task_name=foobar"
  34. response = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
  35. self.assertEqual(len(response.data), 1)
  36. self.assertIn("Custom REST Request", response.data)
  37. class TestConfigParameter(CRUDMixin):
  38. def setUp(self):
  39. self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
  40. self.data = {
  41. "model_name": "GCP Entity Analysis",
  42. "model_attrs": {"key": "hoge", "type": "PLAIN_TEXT", "language": "en"},
  43. "text": "example",
  44. }
  45. self.url = reverse(viewname="auto_labeling_parameter_testing", args=[self.project.item.id])
  46. @patch("auto_labeling.views.RestAPIRequestTesting.send_request", return_value={})
  47. def test_called_with_proper_model(self, mock):
  48. self.assert_create(self.project.admin, status.HTTP_200_OK)
  49. _, kwargs = mock.call_args
  50. expected = RequestModelFactory.create(self.data["model_name"], self.data["model_attrs"])
  51. self.assertEqual(kwargs["model"], expected)
  52. @patch("auto_labeling.views.RestAPIRequestTesting.send_request", return_value={})
  53. def test_called_with_text(self, mock):
  54. self.assert_create(self.project.admin, status.HTTP_200_OK)
  55. _, kwargs = mock.call_args
  56. self.assertEqual(kwargs["example"], self.data["text"])
  57. @patch("auto_labeling.views.RestAPIRequestTesting.send_request", return_value={})
  58. def test_called_with_image(self, mock):
  59. self.data["text"] = str(data_dir / "images/1500x500.jpeg")
  60. self.assert_create(self.project.admin, status.HTTP_200_OK)
  61. _, kwargs = mock.call_args
  62. self.assertEqual(kwargs["example"], self.data["text"])
  63. class TestTemplateMapping(CRUDMixin):
  64. def setUp(self):
  65. self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
  66. self.data = {
  67. "response": {
  68. "Sentiment": "NEUTRAL",
  69. "SentimentScore": {
  70. "Positive": 0.004438233096152544,
  71. "Negative": 0.0005306027014739811,
  72. "Neutral": 0.9950305223464966,
  73. "Mixed": 5.80838445785048e-7,
  74. },
  75. },
  76. "template": AmazonComprehendSentimentTemplate().load(),
  77. "task_type": "Category",
  78. }
  79. self.url = reverse(viewname="auto_labeling_template_test", args=[self.project.item.id])
  80. def test_template_mapping(self):
  81. response = self.assert_create(self.project.admin, status.HTTP_200_OK)
  82. expected = [{"label": "NEUTRAL"}]
  83. self.assertEqual(response.json(), expected)
  84. def test_json_decode_error(self):
  85. self.data["template"] = ""
  86. self.assert_create(self.project.admin, status.HTTP_400_BAD_REQUEST)
  87. class TestLabelMapping(CRUDMixin):
  88. def setUp(self):
  89. self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
  90. self.data = {
  91. "response": [{"label": "NEGATIVE"}],
  92. "label_mapping": {"NEGATIVE": "Negative"},
  93. "task_type": "Category",
  94. }
  95. self.url = reverse(viewname="auto_labeling_mapping_test", args=[self.project.item.id])
  96. def test_label_mapping(self):
  97. response = self.assert_create(self.project.admin, status.HTTP_200_OK)
  98. expected = [{"label": "Negative"}]
  99. self.assertEqual(response.json(), expected)
  100. class TestConfigCreation(CRUDMixin):
  101. def setUp(self):
  102. self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
  103. self.data = {
  104. "model_name": "Amazon Comprehend Sentiment Analysis",
  105. "model_attrs": {
  106. "aws_access_key": "str",
  107. "aws_secret_access_key": "str",
  108. "region_name": "us-east-1",
  109. "language_code": "en",
  110. },
  111. "template": AmazonComprehendSentimentTemplate().load(),
  112. "label_mapping": {"NEGATIVE": "Negative"},
  113. "task_type": "Category",
  114. }
  115. self.url = reverse(viewname="auto_labeling_configs", args=[self.project.item.id])
  116. def test_create_config(self):
  117. response = self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  118. self.assertEqual(response.data["model_name"], self.data["model_name"])
  119. def test_list_config(self):
  120. mommy.make("AutoLabelingConfig", project=self.project.item)
  121. response = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
  122. self.assertEqual(len(response.data), 1)
  123. class TestAutomatedLabeling(CRUDMixin):
  124. def setUp(self):
  125. self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION, single_class_classification=False)
  126. self.example = make_doc(self.project.item)
  127. self.category_pos = mommy.make("CategoryType", project=self.project.item, text="POS")
  128. self.category_neg = mommy.make("CategoryType", project=self.project.item, text="NEG")
  129. self.loc = mommy.make("SpanType", project=self.project.item, text="LOC")
  130. self.url = reverse(viewname="auto_labeling", args=[self.project.item.id])
  131. self.url += f"?example={self.example.id}"
  132. @patch("auto_labeling.views.execute_pipeline", return_value=Categories([{"label": "POS"}]))
  133. def test_category_labeling(self, mock):
  134. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  135. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  136. self.assertEqual(Category.objects.count(), 1)
  137. self.assertEqual(Category.objects.first().label, self.category_pos)
  138. @patch("auto_labeling.views.execute_pipeline", return_value=Categories([{"label": "NEUTRAL"}]))
  139. def test_nonexistent_category(self, mock):
  140. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  141. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  142. self.assertEqual(Category.objects.count(), 0)
  143. @patch(
  144. "auto_labeling.views.execute_pipeline",
  145. side_effect=[Categories([{"label": "POS"}]), Categories([{"label": "NEG"}])],
  146. )
  147. def test_multiple_configs(self, mock):
  148. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  149. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  150. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  151. self.assertEqual(Category.objects.count(), 2)
  152. self.assertEqual(Category.objects.first().label, self.category_pos)
  153. self.assertEqual(Category.objects.last().label, self.category_neg)
  154. @patch(
  155. "auto_labeling.views.execute_pipeline",
  156. side_effect=[Categories([{"label": "POS"}]), Categories([{"label": "POS"}])],
  157. )
  158. def test_cannot_label_same_category_type(self, mock):
  159. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  160. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  161. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  162. self.assertEqual(Category.objects.count(), 1)
  163. @patch(
  164. "auto_labeling.views.execute_pipeline",
  165. side_effect=[
  166. Categories([{"label": "POS"}]),
  167. Spans([{"label": "LOC", "start_offset": 0, "end_offset": 5}]),
  168. ],
  169. )
  170. def test_allow_multi_type_configs(self, mock):
  171. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  172. mommy.make("AutoLabelingConfig", task_type="Span", project=self.project.item)
  173. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  174. self.assertEqual(Category.objects.count(), 1)
  175. self.assertEqual(Span.objects.count(), 1)
  176. @patch("auto_labeling.views.execute_pipeline", return_value=Categories([{"label": "POS"}]))
  177. def test_cannot_use_other_project_config(self, mock):
  178. mommy.make("AutoLabelingConfig", task_type="Category")
  179. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  180. self.assertEqual(Category.objects.count(), 0)
  181. class TestAutomatedSpanLabeling(CRUDMixin):
  182. def setUp(self):
  183. self.project = prepare_project(task=ProjectType.SEQUENCE_LABELING)
  184. self.example = make_doc(self.project.item)
  185. self.loc = mommy.make("SpanType", project=self.project.item, text="LOC")
  186. self.url = reverse(viewname="auto_labeling", args=[self.project.item.id])
  187. self.url += f"?example={self.example.id}"
  188. @patch(
  189. "auto_labeling.views.execute_pipeline",
  190. side_effect=[
  191. Spans([{"label": "LOC", "start_offset": 0, "end_offset": 5}]),
  192. Spans([{"label": "LOC", "start_offset": 4, "end_offset": 10}]),
  193. ],
  194. )
  195. def test_cannot_label_overlapping_span(self, mock):
  196. mommy.make("AutoLabelingConfig", task_type="Span", project=self.project.item)
  197. mommy.make("AutoLabelingConfig", task_type="Span", project=self.project.item)
  198. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  199. self.assertEqual(Span.objects.count(), 1)
  200. class TestAutomatedTextLabeling(CRUDMixin):
  201. def setUp(self):
  202. self.project = prepare_project(task=ProjectType.SEQ2SEQ)
  203. self.example = make_doc(self.project.item)
  204. self.url = reverse(viewname="auto_labeling", args=[self.project.item.id])
  205. self.url += f"?example={self.example.id}"
  206. @patch("auto_labeling.views.execute_pipeline", side_effect=[Texts([{"text": "foo"}]), Texts([{"text": "foo"}])])
  207. def test_cannot_label_same_text(self, mock):
  208. mommy.make("AutoLabelingConfig", task_type="Text", project=self.project.item)
  209. mommy.make("AutoLabelingConfig", task_type="Text", project=self.project.item)
  210. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  211. self.assertEqual(TextLabel.objects.count(), 1)