You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

244 lines
11 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import pathlib
  2. from unittest.mock import patch
  3. from auto_labeling_pipeline.mappings import AmazonComprehendSentimentTemplate
  4. from auto_labeling_pipeline.models import RequestModelFactory
  5. from model_mommy import mommy
  6. from rest_framework import status
  7. from rest_framework.reverse import reverse
  8. from api.tests.utils import CRUDMixin
  9. from auto_labeling.pipeline.labels import Categories, Spans, Texts
  10. from examples.tests.utils import make_doc
  11. from labels.models import Category, Span, TextLabel
  12. from projects.models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING
  13. from projects.tests.utils import prepare_project
  14. data_dir = pathlib.Path(__file__).parent / "data"
  15. class TestTemplateList(CRUDMixin):
  16. def setUp(self):
  17. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  18. self.url = reverse(viewname="auto_labeling_templates", args=[self.project.item.id])
  19. def test_allow_admin_to_fetch_template_list(self):
  20. self.url += "?task_name=DocumentClassification"
  21. response = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
  22. self.assertIn("Custom REST Request", response.data)
  23. self.assertGreaterEqual(len(response.data), 1)
  24. def test_deny_project_staff_to_fetch_template_list(self):
  25. self.url += "?task_name=DocumentClassification"
  26. for user in self.project.staffs:
  27. self.assert_fetch(user, status.HTTP_403_FORBIDDEN)
  28. def test_return_only_default_template_with_empty_task_name(self):
  29. response = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
  30. self.assertEqual(len(response.data), 1)
  31. self.assertIn("Custom REST Request", response.data)
  32. def test_return_only_default_template_with_wrong_task_name(self):
  33. self.url += "?task_name=foobar"
  34. response = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
  35. self.assertEqual(len(response.data), 1)
  36. self.assertIn("Custom REST Request", response.data)
  37. class TestConfigParameter(CRUDMixin):
  38. def setUp(self):
  39. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  40. self.data = {
  41. "model_name": "GCP Entity Analysis",
  42. "model_attrs": {"key": "hoge", "type": "PLAIN_TEXT", "language": "en"},
  43. "text": "example",
  44. }
  45. self.url = reverse(viewname="auto_labeling_parameter_testing", args=[self.project.item.id])
  46. @patch("auto_labeling.views.RestAPIRequestTesting.send_request", return_value={})
  47. def test_called_with_proper_model(self, mock):
  48. self.assert_create(self.project.admin, status.HTTP_200_OK)
  49. _, kwargs = mock.call_args
  50. expected = RequestModelFactory.create(self.data["model_name"], self.data["model_attrs"])
  51. self.assertEqual(kwargs["model"], expected)
  52. @patch("auto_labeling.views.RestAPIRequestTesting.send_request", return_value={})
  53. def test_called_with_text(self, mock):
  54. self.assert_create(self.project.admin, status.HTTP_200_OK)
  55. _, kwargs = mock.call_args
  56. self.assertEqual(kwargs["example"], self.data["text"])
  57. @patch("auto_labeling.views.RestAPIRequestTesting.send_request", return_value={})
  58. def test_called_with_image(self, mock):
  59. self.data["text"] = str(data_dir / "images/1500x500.jpeg")
  60. self.assert_create(self.project.admin, status.HTTP_200_OK)
  61. _, kwargs = mock.call_args
  62. self.assertEqual(kwargs["example"], self.data["text"])
  63. class TestTemplateMapping(CRUDMixin):
  64. def setUp(self):
  65. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  66. self.data = {
  67. "response": {
  68. "Sentiment": "NEUTRAL",
  69. "SentimentScore": {
  70. "Positive": 0.004438233096152544,
  71. "Negative": 0.0005306027014739811,
  72. "Neutral": 0.9950305223464966,
  73. "Mixed": 5.80838445785048e-7,
  74. },
  75. },
  76. "template": AmazonComprehendSentimentTemplate().load(),
  77. "task_type": "Category",
  78. }
  79. self.url = reverse(viewname="auto_labeling_template_test", args=[self.project.item.id])
  80. def test_template_mapping(self):
  81. response = self.assert_create(self.project.admin, status.HTTP_200_OK)
  82. expected = [{"label": "NEUTRAL"}]
  83. self.assertEqual(response.json(), expected)
  84. def test_json_decode_error(self):
  85. self.data["template"] = ""
  86. self.assert_create(self.project.admin, status.HTTP_400_BAD_REQUEST)
  87. class TestLabelMapping(CRUDMixin):
  88. def setUp(self):
  89. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  90. self.data = {
  91. "response": [{"label": "NEGATIVE"}],
  92. "label_mapping": {"NEGATIVE": "Negative"},
  93. "task_type": "Category",
  94. }
  95. self.url = reverse(viewname="auto_labeling_mapping_test", args=[self.project.item.id])
  96. def test_label_mapping(self):
  97. response = self.assert_create(self.project.admin, status.HTTP_200_OK)
  98. expected = [{"label": "Negative"}]
  99. self.assertEqual(response.json(), expected)
  100. class TestConfigCreation(CRUDMixin):
  101. def setUp(self):
  102. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
  103. self.data = {
  104. "model_name": "Amazon Comprehend Sentiment Analysis",
  105. "model_attrs": {
  106. "aws_access_key": "str",
  107. "aws_secret_access_key": "str",
  108. "region_name": "us-east-1",
  109. "language_code": "en",
  110. },
  111. "template": AmazonComprehendSentimentTemplate().load(),
  112. "label_mapping": {"NEGATIVE": "Negative"},
  113. "task_type": "Category",
  114. }
  115. self.url = reverse(viewname="auto_labeling_configs", args=[self.project.item.id])
  116. def test_create_config(self):
  117. response = self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  118. self.assertEqual(response.data["model_name"], self.data["model_name"])
  119. def test_list_config(self):
  120. mommy.make("AutoLabelingConfig", project=self.project.item)
  121. response = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
  122. self.assertEqual(len(response.data), 1)
  123. class TestAutomatedLabeling(CRUDMixin):
  124. def setUp(self):
  125. self.project = prepare_project(task=DOCUMENT_CLASSIFICATION, single_class_classification=False)
  126. self.example = make_doc(self.project.item)
  127. self.category_pos = mommy.make("CategoryType", project=self.project.item, text="POS")
  128. self.category_neg = mommy.make("CategoryType", project=self.project.item, text="NEG")
  129. self.loc = mommy.make("SpanType", project=self.project.item, text="LOC")
  130. self.url = reverse(viewname="auto_labeling", args=[self.project.item.id])
  131. self.url += f"?example={self.example.id}"
  132. @patch("auto_labeling.views.execute_pipeline", return_value=Categories([{"label": "POS"}]))
  133. def test_category_labeling(self, mock):
  134. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  135. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  136. self.assertEqual(Category.objects.count(), 1)
  137. self.assertEqual(Category.objects.first().label, self.category_pos)
  138. @patch(
  139. "auto_labeling.views.execute_pipeline",
  140. side_effect=[Categories([{"label": "POS"}]), Categories([{"label": "NEG"}])],
  141. )
  142. def test_multiple_configs(self, mock):
  143. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  144. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  145. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  146. self.assertEqual(Category.objects.count(), 2)
  147. self.assertEqual(Category.objects.first().label, self.category_pos)
  148. self.assertEqual(Category.objects.last().label, self.category_neg)
  149. @patch(
  150. "auto_labeling.views.execute_pipeline",
  151. side_effect=[Categories([{"label": "POS"}]), Categories([{"label": "POS"}])],
  152. )
  153. def test_cannot_label_same_category_type(self, mock):
  154. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  155. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  156. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  157. self.assertEqual(Category.objects.count(), 1)
  158. @patch(
  159. "auto_labeling.views.execute_pipeline",
  160. side_effect=[
  161. Categories([{"label": "POS"}]),
  162. Spans([{"label": "LOC", "start_offset": 0, "end_offset": 5}]),
  163. ],
  164. )
  165. def test_allow_multi_type_configs(self, mock):
  166. mommy.make("AutoLabelingConfig", task_type="Category", project=self.project.item)
  167. mommy.make("AutoLabelingConfig", task_type="Span", project=self.project.item)
  168. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  169. self.assertEqual(Category.objects.count(), 1)
  170. self.assertEqual(Span.objects.count(), 1)
  171. @patch("auto_labeling.views.execute_pipeline", return_value=Categories([{"label": "POS"}]))
  172. def test_cannot_use_other_project_config(self, mock):
  173. mommy.make("AutoLabelingConfig", task_type="Category")
  174. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  175. self.assertEqual(Category.objects.count(), 0)
  176. class TestAutomatedSpanLabeling(CRUDMixin):
  177. def setUp(self):
  178. self.project = prepare_project(task=SEQUENCE_LABELING)
  179. self.example = make_doc(self.project.item)
  180. self.loc = mommy.make("SpanType", project=self.project.item, text="LOC")
  181. self.url = reverse(viewname="auto_labeling", args=[self.project.item.id])
  182. self.url += f"?example={self.example.id}"
  183. @patch(
  184. "auto_labeling.views.execute_pipeline",
  185. side_effect=[
  186. Spans([{"label": "LOC", "start_offset": 0, "end_offset": 5}]),
  187. Spans([{"label": "LOC", "start_offset": 4, "end_offset": 10}]),
  188. ],
  189. )
  190. def test_cannot_label_overlapping_span(self, mock):
  191. mommy.make("AutoLabelingConfig", task_type="Span", project=self.project.item)
  192. mommy.make("AutoLabelingConfig", task_type="Span", project=self.project.item)
  193. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  194. self.assertEqual(Span.objects.count(), 1)
  195. class TestAutomatedTextLabeling(CRUDMixin):
  196. def setUp(self):
  197. self.project = prepare_project(task=SEQ2SEQ)
  198. self.example = make_doc(self.project.item)
  199. self.url = reverse(viewname="auto_labeling", args=[self.project.item.id])
  200. self.url += f"?example={self.example.id}"
  201. @patch("auto_labeling.views.execute_pipeline", side_effect=[Texts([{"text": "foo"}]), Texts([{"text": "foo"}])])
  202. def test_cannot_label_same_text(self, mock):
  203. mommy.make("AutoLabelingConfig", task_type="Text", project=self.project.item)
  204. mommy.make("AutoLabelingConfig", task_type="Text", project=self.project.item)
  205. self.assert_create(self.project.admin, status.HTTP_201_CREATED)
  206. self.assertEqual(TextLabel.objects.count(), 1)