You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

239 lines
9.0 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import pathlib
  2. from django.test import TestCase
  3. from data_import.celery_tasks import import_dataset
  4. from examples.models import Example
  5. from label_types.models import CategoryType, SpanType
  6. from labels.models import Category, Span
  7. from projects.models import (
  8. DOCUMENT_CLASSIFICATION,
  9. IMAGE_CLASSIFICATION,
  10. INTENT_DETECTION_AND_SLOT_FILLING,
  11. SEQ2SEQ,
  12. SEQUENCE_LABELING,
  13. )
  14. from projects.tests.utils import prepare_project
  15. class TestImportData(TestCase):
  16. task = "Any"
  17. annotation_class = Category
  18. def setUp(self):
  19. self.project = prepare_project(self.task)
  20. self.user = self.project.admin
  21. self.data_path = pathlib.Path(__file__).parent / "data"
  22. def import_dataset(self, filename, file_format, kwargs=None):
  23. filenames = [str(self.data_path / filename)]
  24. kwargs = kwargs or {}
  25. return import_dataset(self.user.id, self.project.item.id, filenames, file_format, **kwargs)
  26. class TestImportClassificationData(TestImportData):
  27. task = DOCUMENT_CLASSIFICATION
  28. def assert_examples(self, dataset):
  29. self.assertEqual(Example.objects.count(), len(dataset))
  30. for text, expected_labels in dataset:
  31. example = Example.objects.get(text=text)
  32. labels = set(cat.label.text for cat in example.categories.all())
  33. self.assertEqual(labels, set(expected_labels))
  34. def assert_parse_error(self, response):
  35. self.assertGreaterEqual(len(response["error"]), 1)
  36. self.assertEqual(Example.objects.count(), 0)
  37. self.assertEqual(CategoryType.objects.count(), 0)
  38. self.assertEqual(Category.objects.count(), 0)
  39. def test_jsonl(self):
  40. filename = "text_classification/example.jsonl"
  41. file_format = "JSONL"
  42. kwargs = {"column_label": "labels"}
  43. dataset = [("exampleA", ["positive"]), ("exampleB", ["positive", "negative"]), ("exampleC", [])]
  44. self.import_dataset(filename, file_format, kwargs)
  45. self.assert_examples(dataset)
  46. def test_csv(self):
  47. filename = "text_classification/example.csv"
  48. file_format = "CSV"
  49. dataset = [("exampleA", ["positive"]), ("exampleB", [])]
  50. self.import_dataset(filename, file_format)
  51. self.assert_examples(dataset)
  52. def test_csv_out_of_order_columns(self):
  53. filename = "text_classification/example_out_of_order_columns.csv"
  54. file_format = "CSV"
  55. dataset = [("exampleA", ["positive"]), ("exampleB", [])]
  56. self.import_dataset(filename, file_format)
  57. self.assert_examples(dataset)
  58. def test_fasttext(self):
  59. filename = "text_classification/example_fasttext.txt"
  60. file_format = "fastText"
  61. dataset = [("exampleA", ["positive"]), ("exampleB", ["positive", "negative"]), ("exampleC", [])]
  62. self.import_dataset(filename, file_format)
  63. self.assert_examples(dataset)
  64. def test_excel(self):
  65. filename = "text_classification/example.xlsx"
  66. file_format = "Excel"
  67. dataset = [("exampleA", ["positive"]), ("exampleB", [])]
  68. self.import_dataset(filename, file_format)
  69. self.assert_examples(dataset)
  70. def test_json(self):
  71. filename = "text_classification/example.json"
  72. file_format = "JSON"
  73. dataset = [("exampleA", ["positive"]), ("exampleB", ["positive", "negative"]), ("exampleC", [])]
  74. self.import_dataset(filename, file_format)
  75. self.assert_examples(dataset)
  76. def test_textfile(self):
  77. filename = "example.txt"
  78. file_format = "TextFile"
  79. dataset = [("exampleA\nexampleB\n\nexampleC\n", [])]
  80. self.import_dataset(filename, file_format)
  81. self.assert_examples(dataset)
  82. def test_textline(self):
  83. filename = "example.txt"
  84. file_format = "TextLine"
  85. dataset = [("exampleA", []), ("exampleB", []), ("exampleC", [])]
  86. self.import_dataset(filename, file_format)
  87. self.assert_examples(dataset)
  88. def test_wrong_jsonl(self):
  89. filename = "text_classification/example.json"
  90. file_format = "JSONL"
  91. response = self.import_dataset(filename, file_format)
  92. self.assert_parse_error(response)
  93. def test_wrong_json(self):
  94. filename = "text_classification/example.jsonl"
  95. file_format = "JSON"
  96. response = self.import_dataset(filename, file_format)
  97. self.assert_parse_error(response)
  98. def test_wrong_excel(self):
  99. filename = "text_classification/example.jsonl"
  100. file_format = "Excel"
  101. response = self.import_dataset(filename, file_format)
  102. self.assert_parse_error(response)
  103. def test_wrong_csv(self):
  104. filename = "text_classification/example.jsonl"
  105. file_format = "CSV"
  106. response = self.import_dataset(filename, file_format)
  107. self.assert_parse_error(response)
  108. class TestImportSequenceLabelingData(TestImportData):
  109. task = SEQUENCE_LABELING
  110. def assert_examples(self, dataset):
  111. self.assertEqual(Example.objects.count(), len(dataset))
  112. for text, expected_labels in dataset:
  113. example = Example.objects.get(text=text)
  114. labels = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
  115. self.assertEqual(labels, expected_labels)
  116. def assert_parse_error(self, response):
  117. self.assertGreaterEqual(len(response["error"]), 1)
  118. self.assertEqual(Example.objects.count(), 0)
  119. self.assertEqual(SpanType.objects.count(), 0)
  120. self.assertEqual(Span.objects.count(), 0)
  121. def test_jsonl(self):
  122. filename = "sequence_labeling/example.jsonl"
  123. file_format = "JSONL"
  124. dataset = [("exampleA", [[0, 1, "LOC"]]), ("exampleB", [])]
  125. self.import_dataset(filename, file_format)
  126. self.assert_examples(dataset)
  127. def test_conll(self):
  128. filename = "sequence_labeling/example.conll"
  129. file_format = "CoNLL"
  130. dataset = [("JAPAN GET", [[0, 5, "LOC"]]), ("Nadim Ladki", [[0, 11, "PER"]])]
  131. self.import_dataset(filename, file_format)
  132. self.assert_examples(dataset)
  133. def test_wrong_conll(self):
  134. filename = "sequence_labeling/example.jsonl"
  135. file_format = "CoNLL"
  136. response = self.import_dataset(filename, file_format)
  137. self.assert_parse_error(response)
  138. def test_jsonl_with_overlapping(self):
  139. filename = "sequence_labeling/example_overlapping.jsonl"
  140. file_format = "JSONL"
  141. response = self.import_dataset(filename, file_format)
  142. self.assertEqual(len(response["error"]), 1)
  143. class TestImportSeq2seqData(TestImportData):
  144. task = SEQ2SEQ
  145. def assert_examples(self, dataset):
  146. self.assertEqual(Example.objects.count(), len(dataset))
  147. for text, expected_labels in dataset:
  148. example = Example.objects.get(text=text)
  149. labels = set(text_label.text for text_label in example.texts.all())
  150. self.assertEqual(labels, set(expected_labels))
  151. def test_jsonl(self):
  152. filename = "seq2seq/example.jsonl"
  153. file_format = "JSONL"
  154. dataset = [("exampleA", ["label1"]), ("exampleB", [])]
  155. self.import_dataset(filename, file_format)
  156. self.assert_examples(dataset)
  157. def test_json(self):
  158. filename = "seq2seq/example.json"
  159. file_format = "JSON"
  160. dataset = [("exampleA", ["label1"]), ("exampleB", [])]
  161. self.import_dataset(filename, file_format)
  162. self.assert_examples(dataset)
  163. def test_csv(self):
  164. filename = "seq2seq/example.csv"
  165. file_format = "CSV"
  166. dataset = [("exampleA", ["label1"]), ("exampleB", [])]
  167. self.import_dataset(filename, file_format)
  168. self.assert_examples(dataset)
  169. class TestImportIntentDetectionAndSlotFillingData(TestImportData):
  170. task = INTENT_DETECTION_AND_SLOT_FILLING
  171. def assert_examples(self, dataset):
  172. self.assertEqual(Example.objects.count(), len(dataset))
  173. for text, expected_labels in dataset:
  174. example = Example.objects.get(text=text)
  175. cats = set(cat.label.text for cat in example.categories.all())
  176. entities = [(span.start_offset, span.end_offset, span.label.text) for span in example.spans.all()]
  177. self.assertEqual(cats, set(expected_labels["cats"]))
  178. self.assertEqual(entities, expected_labels["entities"])
  179. def test_entities_and_cats(self):
  180. filename = "intent/example.jsonl"
  181. file_format = "JSONL"
  182. dataset = [
  183. ("exampleA", {"cats": ["positive"], "entities": [(0, 1, "LOC")]}),
  184. ("exampleB", {"cats": ["positive"], "entities": []}),
  185. ("exampleC", {"cats": [], "entities": [(0, 1, "LOC")]}),
  186. ("exampleD", {"cats": [], "entities": []}),
  187. ]
  188. self.import_dataset(filename, file_format)
  189. self.assert_examples(dataset)
  190. class TestImportImageClassificationData(TestImportData):
  191. task = IMAGE_CLASSIFICATION
  192. def test_example(self):
  193. filename = "images/1500x500.jpeg"
  194. file_format = "ImageFile"
  195. self.import_dataset(filename, file_format)
  196. self.assertEqual(Example.objects.count(), 1)