import os
import pathlib
import shutil

from django.core.files import File
from django.test import TestCase, override_settings
from django_drf_filepond.models import StoredUpload, TemporaryUpload
from django_drf_filepond.utils import _get_file_id

from data_import.celery_tasks import import_dataset
from data_import.pipeline.catalog import RELATION_EXTRACTION
from examples.models import Example
from label_types.models import SpanType
from labels.models import Category, Span
from projects.models import (
    DOCUMENT_CLASSIFICATION,
    IMAGE_CLASSIFICATION,
    INTENT_DETECTION_AND_SLOT_FILLING,
    SEQ2SEQ,
    SEQUENCE_LABELING,
)
from projects.tests.utils import prepare_project
@override_settings(MEDIA_ROOT=os.path.join(os.path.dirname(__file__), "data"))
class TestImportData(TestCase):
    task = "Any"
    annotation_class = Category

    def setUp(self):
        self.project = prepare_project(self.task)
        self.user = self.project.admin
        self.data_path = pathlib.Path(__file__).parent / "data"
        self.upload_id = _get_file_id()

    def tearDown(self):
        # Remove any files the upload pipeline persisted for this upload id.
        try:
            su = StoredUpload.objects.get(upload_id=self.upload_id)
            directory = pathlib.Path(su.get_absolute_file_path()).parent
            shutil.rmtree(directory)
        except StoredUpload.DoesNotExist:
            pass

    def import_dataset(self, filename, file_format, task, kwargs=None):
        # Stage the fixture file as a filepond TemporaryUpload, then run the
        # import task directly and return its result.
        file_path = str(self.data_path / filename)
        TemporaryUpload.objects.create(
            upload_id=self.upload_id,
            file_id="1",
            file=File(open(file_path, mode="rb"), filename.split("/")[-1]),
            upload_name=filename,
            upload_type="F",
        )
        upload_ids = [self.upload_id]
        kwargs = kwargs or {}
        return import_dataset(self.user.id, self.project.item.id, file_format, upload_ids, task, **kwargs)
@override_settings(MAX_UPLOAD_SIZE=0)
class TestMaxFileSize(TestImportData):
    task = DOCUMENT_CLASSIFICATION

    def test_jsonl(self):
        # MAX_UPLOAD_SIZE=0 makes every file exceed the limit, so the import must fail.
        filename = "text_classification/example.jsonl"
        file_format = "JSONL"
        kwargs = {"column_label": "labels"}
        response = self.import_dataset(filename, file_format, self.task, kwargs)
        self.assertEqual(len(response["error"]), 1)
        self.assertIn("maximum file size", response["error"][0]["message"])
class TestInvalidFileFormat(TestImportData):
    task = DOCUMENT_CLASSIFICATION

    def test_invalid_file_format(self):
        filename = "text_classification/example.csv"
        file_format = "INVALID_FORMAT"
        response = self.import_dataset(filename, file_format, self.task)
        self.assertEqual(len(response["error"]), 1)
class TestImportClassificationData(TestImportData):
    task = DOCUMENT_CLASSIFICATION

    def assert_examples(self, dataset):
        with self.subTest():
            self.assertEqual(Example.objects.count(), len(dataset))
            for text, expected_labels in dataset:
                example = Example.objects.get(text=text)
                labels = set(cat.label.text for cat in example.categories.all())
                self.assertEqual(labels, set(expected_labels))

    def assert_parse_error(self, response):
        with self.subTest():
            self.assertGreaterEqual(len(response["error"]), 1)
            self.assertEqual(Example.objects.count(), 0)
            self.assertEqual(Category.objects.count(), 0)

    def test_jsonl(self):
        filename = "text_classification/example.jsonl"
        file_format = "JSONL"
        kwargs = {"column_label": "labels"}
        dataset = [("exampleA", ["positive"]), ("exampleB", ["positive", "negative"]), ("exampleC", [])]
        self.import_dataset(filename, file_format, self.task, kwargs)
        self.assert_examples(dataset)

    def test_csv(self):
        filename = "text_classification/example.csv"
        file_format = "CSV"
        dataset = [("exampleA", ["positive"]), ("exampleB", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_csv_out_of_order_columns(self):
        filename = "text_classification/example_out_of_order_columns.csv"
        file_format = "CSV"
        dataset = [("exampleA", ["positive"]), ("exampleB", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_fasttext(self):
        filename = "text_classification/example_fasttext.txt"
        file_format = "fastText"
        dataset = [("exampleA", ["positive"]), ("exampleB", ["positive", "negative"]), ("exampleC", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_excel(self):
        filename = "text_classification/example.xlsx"
        file_format = "Excel"
        dataset = [("exampleA", ["positive"]), ("exampleB", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_json(self):
        filename = "text_classification/example.json"
        file_format = "JSON"
        dataset = [("exampleA", ["positive"]), ("exampleB", ["positive", "negative"]), ("exampleC", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_textfile(self):
        filename = "example.txt"
        file_format = "TextFile"
        dataset = [("exampleA\nexampleB\n\nexampleC\n", [])]
        response = self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)
        self.assertEqual(len(response["error"]), 0)

    def test_textline(self):
        filename = "example.txt"
        file_format = "TextLine"
        dataset = [("exampleA", []), ("exampleB", []), ("exampleC", [])]
        response = self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)
        self.assertEqual(len(response["error"]), 1)

    def test_wrong_jsonl(self):
        filename = "text_classification/example.json"
        file_format = "JSONL"
        response = self.import_dataset(filename, file_format, self.task)
        self.assert_parse_error(response)

    def test_wrong_json(self):
        filename = "text_classification/example.jsonl"
        file_format = "JSON"
        response = self.import_dataset(filename, file_format, self.task)
        self.assert_parse_error(response)

    def test_wrong_excel(self):
        filename = "text_classification/example.jsonl"
        file_format = "Excel"
        response = self.import_dataset(filename, file_format, self.task)
        self.assert_parse_error(response)

    def test_wrong_csv(self):
        filename = "text_classification/example.jsonl"
        file_format = "CSV"
        response = self.import_dataset(filename, file_format, self.task)
        self.assert_parse_error(response)
class TestImportSequenceLabelingData(TestImportData):
    task = SEQUENCE_LABELING

    def assert_examples(self, dataset):
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            example = Example.objects.get(text=text)
            labels = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
            self.assertEqual(labels, expected_labels)

    def assert_parse_error(self, response):
        self.assertGreaterEqual(len(response["error"]), 1)
        self.assertEqual(Example.objects.count(), 0)
        self.assertEqual(SpanType.objects.count(), 0)
        self.assertEqual(Span.objects.count(), 0)

    def test_jsonl(self):
        filename = "sequence_labeling/example.jsonl"
        file_format = "JSONL"
        dataset = [("exampleA", [[0, 1, "LOC"]]), ("exampleB", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_conll(self):
        filename = "sequence_labeling/example.conll"
        file_format = "CoNLL"
        dataset = [("JAPAN GET", [[0, 5, "LOC"]]), ("Nadim Ladki", [[0, 11, "PER"]])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_wrong_conll(self):
        filename = "sequence_labeling/example.jsonl"
        file_format = "CoNLL"
        response = self.import_dataset(filename, file_format, self.task)
        self.assert_parse_error(response)

    def test_jsonl_with_overlapping(self):
        filename = "sequence_labeling/example_overlapping.jsonl"
        file_format = "JSONL"
        response = self.import_dataset(filename, file_format, self.task)
        self.assertEqual(len(response["error"]), 0)
class TestImportRelationExtractionData(TestImportData):
    task = SEQUENCE_LABELING

    def setUp(self):
        self.project = prepare_project(self.task, use_relation=True)
        self.user = self.project.admin
        self.data_path = pathlib.Path(__file__).parent / "data"
        self.upload_id = _get_file_id()

    def assert_examples(self, dataset):
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_spans in dataset:
            example = Example.objects.get(text=text)
            spans = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
            self.assertEqual(spans, expected_spans)
            self.assertEqual(example.relations.count(), 3)

    def assert_parse_error(self, response):
        self.assertGreaterEqual(len(response["error"]), 1)
        self.assertEqual(Example.objects.count(), 0)
        self.assertEqual(SpanType.objects.count(), 0)
        self.assertEqual(Span.objects.count(), 0)

    def test_jsonl(self):
        filename = "relation_extraction/example.jsonl"
        file_format = "JSONL"
        dataset = [
            (
                "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.",
                [[0, 6, "ORG"], [22, 39, "DATE"], [44, 54, "PERSON"], [59, 70, "PERSON"]],
            ),
        ]
        self.import_dataset(filename, file_format, RELATION_EXTRACTION)
        self.assert_examples(dataset)
class TestImportSeq2seqData(TestImportData):
    task = SEQ2SEQ

    def assert_examples(self, dataset):
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            example = Example.objects.get(text=text)
            labels = set(text_label.text for text_label in example.texts.all())
            self.assertEqual(labels, set(expected_labels))

    def test_jsonl(self):
        filename = "seq2seq/example.jsonl"
        file_format = "JSONL"
        dataset = [("exampleA", ["label1"]), ("exampleB", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_json(self):
        filename = "seq2seq/example.json"
        file_format = "JSON"
        dataset = [("exampleA", ["label1"]), ("exampleB", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_csv(self):
        filename = "seq2seq/example.csv"
        file_format = "CSV"
        dataset = [("exampleA", ["label1"]), ("exampleB", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)
class TestImportIntentDetectionAndSlotFillingData(TestImportData):
    task = INTENT_DETECTION_AND_SLOT_FILLING

    def assert_examples(self, dataset):
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            example = Example.objects.get(text=text)
            cats = set(cat.label.text for cat in example.categories.all())
            entities = [(span.start_offset, span.end_offset, span.label.text) for span in example.spans.all()]
            self.assertEqual(cats, set(expected_labels["cats"]))
            self.assertEqual(entities, expected_labels["entities"])

    def test_entities_and_cats(self):
        filename = "intent/example.jsonl"
        file_format = "JSONL"
        dataset = [
            ("exampleA", {"cats": ["positive"], "entities": [(0, 1, "LOC")]}),
            ("exampleB", {"cats": ["positive"], "entities": []}),
            ("exampleC", {"cats": [], "entities": [(0, 1, "LOC")]}),
            ("exampleD", {"cats": [], "entities": []}),
        ]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)
class TestImportImageClassificationData(TestImportData):
    task = IMAGE_CLASSIFICATION

    def test_example(self):
        filename = "images/1500x500.jpeg"
        file_format = "ImageFile"
        self.import_dataset(filename, file_format, self.task)
        self.assertEqual(Example.objects.count(), 1)
@override_settings(ENABLE_FILE_TYPE_CHECK=True)
class TestFileTypeChecking(TestImportData):
    task = IMAGE_CLASSIFICATION

    def test_example(self):
        # With file type checking enabled, an .ico file uploaded as an image is rejected.
        filename = "images/example.ico"
        file_format = "ImageFile"
        response = self.import_dataset(filename, file_format, self.task)
        self.assertEqual(len(response["error"]), 1)
        self.assertIn("unexpected", response["error"][0]["message"])