You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

265 lines
9.1 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. import pathlib
  2. from django.test import TestCase
  3. from data_import.celery_tasks import import_dataset
  4. from api.models import (DOCUMENT_CLASSIFICATION,
  5. INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
  6. SEQUENCE_LABELING, Category, CategoryType, Example, Span,
  7. SpanType)
  8. from api.tests.api.utils import prepare_project
  9. class TestImportData(TestCase):
  10. task = 'Any'
  11. annotation_class = Category
  12. def setUp(self):
  13. self.project = prepare_project(self.task)
  14. self.user = self.project.users[0]
  15. self.data_path = pathlib.Path(__file__).parent / 'data'
  16. def import_dataset(self, filename, file_format, kwargs=None):
  17. filenames = [str(self.data_path / filename)]
  18. kwargs = kwargs or {}
  19. return import_dataset(self.user.id, self.project.item.id, filenames, file_format, **kwargs)
  20. class TestImportClassificationData(TestImportData):
  21. task = DOCUMENT_CLASSIFICATION
  22. def assert_examples(self, dataset):
  23. self.assertEqual(Example.objects.count(), len(dataset))
  24. for text, expected_labels in dataset:
  25. example = Example.objects.get(text=text)
  26. labels = set(cat.label.text for cat in example.categories.all())
  27. self.assertEqual(labels, set(expected_labels))
  28. def assert_parse_error(self, response):
  29. self.assertGreaterEqual(len(response['error']), 1)
  30. self.assertEqual(Example.objects.count(), 0)
  31. self.assertEqual(CategoryType.objects.count(), 0)
  32. self.assertEqual(Category.objects.count(), 0)
  33. def test_jsonl(self):
  34. filename = 'text_classification/example.jsonl'
  35. file_format = 'JSONL'
  36. kwargs = {'column_label': 'labels'}
  37. dataset = [
  38. ('exampleA', ['positive']),
  39. ('exampleB', ['positive', 'negative']),
  40. ('exampleC', [])
  41. ]
  42. self.import_dataset(filename, file_format, kwargs)
  43. self.assert_examples(dataset)
  44. def test_csv(self):
  45. filename = 'text_classification/example.csv'
  46. file_format = 'CSV'
  47. dataset = [
  48. ('exampleA', ['positive']),
  49. ('exampleB', [])
  50. ]
  51. self.import_dataset(filename, file_format)
  52. self.assert_examples(dataset)
  53. def test_csv_out_of_order_columns(self):
  54. filename = 'text_classification/example_out_of_order_columns.csv'
  55. file_format = 'CSV'
  56. dataset = [
  57. ('exampleA', ['positive']),
  58. ('exampleB', [])
  59. ]
  60. self.import_dataset(filename, file_format)
  61. self.assert_examples(dataset)
  62. def test_fasttext(self):
  63. filename = 'text_classification/example_fasttext.txt'
  64. file_format = 'fastText'
  65. dataset = [
  66. ('exampleA', ['positive']),
  67. ('exampleB', ['positive', 'negative']),
  68. ('exampleC', [])
  69. ]
  70. self.import_dataset(filename, file_format)
  71. self.assert_examples(dataset)
  72. def test_excel(self):
  73. filename = 'text_classification/example.xlsx'
  74. file_format = 'Excel'
  75. dataset = [
  76. ('exampleA', ['positive']),
  77. ('exampleB', [])
  78. ]
  79. self.import_dataset(filename, file_format)
  80. self.assert_examples(dataset)
  81. def test_json(self):
  82. filename = 'text_classification/example.json'
  83. file_format = 'JSON'
  84. dataset = [
  85. ('exampleA', ['positive']),
  86. ('exampleB', ['positive', 'negative']),
  87. ('exampleC', [])
  88. ]
  89. self.import_dataset(filename, file_format)
  90. self.assert_examples(dataset)
  91. def test_textfile(self):
  92. filename = 'example.txt'
  93. file_format = 'TextFile'
  94. dataset = [
  95. ('exampleA\nexampleB\n\nexampleC\n', [])
  96. ]
  97. self.import_dataset(filename, file_format)
  98. self.assert_examples(dataset)
  99. def test_textline(self):
  100. filename = 'example.txt'
  101. file_format = 'TextLine'
  102. dataset = [
  103. ('exampleA', []),
  104. ('exampleB', []),
  105. ('exampleC', [])
  106. ]
  107. self.import_dataset(filename, file_format)
  108. self.assert_examples(dataset)
  109. def test_wrong_jsonl(self):
  110. filename = 'text_classification/example.json'
  111. file_format = 'JSONL'
  112. response = self.import_dataset(filename, file_format)
  113. self.assert_parse_error(response)
  114. def test_wrong_json(self):
  115. filename = 'text_classification/example.jsonl'
  116. file_format = 'JSON'
  117. response = self.import_dataset(filename, file_format)
  118. self.assert_parse_error(response)
  119. def test_wrong_excel(self):
  120. filename = 'text_classification/example.jsonl'
  121. file_format = 'Excel'
  122. response = self.import_dataset(filename, file_format)
  123. self.assert_parse_error(response)
  124. def test_wrong_csv(self):
  125. filename = 'text_classification/example.jsonl'
  126. file_format = 'CSV'
  127. response = self.import_dataset(filename, file_format)
  128. self.assert_parse_error(response)
  129. class TestImportSequenceLabelingData(TestImportData):
  130. task = SEQUENCE_LABELING
  131. def assert_examples(self, dataset):
  132. self.assertEqual(Example.objects.count(), len(dataset))
  133. for text, expected_labels in dataset:
  134. example = Example.objects.get(text=text)
  135. labels = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
  136. self.assertEqual(labels, expected_labels)
  137. def assert_parse_error(self, response):
  138. self.assertGreaterEqual(len(response['error']), 1)
  139. self.assertEqual(Example.objects.count(), 0)
  140. self.assertEqual(SpanType.objects.count(), 0)
  141. self.assertEqual(Span.objects.count(), 0)
  142. def test_jsonl(self):
  143. filename = 'sequence_labeling/example.jsonl'
  144. file_format = 'JSONL'
  145. dataset = [
  146. ('exampleA', [[0, 1, 'LOC']]),
  147. ('exampleB', [])
  148. ]
  149. self.import_dataset(filename, file_format)
  150. self.assert_examples(dataset)
  151. def test_conll(self):
  152. filename = 'sequence_labeling/example.conll'
  153. file_format = 'CoNLL'
  154. dataset = [
  155. ('JAPAN GET', [[0, 5, 'LOC']]),
  156. ('Nadim Ladki', [[0, 11, 'PER']])
  157. ]
  158. self.import_dataset(filename, file_format)
  159. self.assert_examples(dataset)
  160. def test_wrong_conll(self):
  161. filename = 'sequence_labeling/example.jsonl'
  162. file_format = 'CoNLL'
  163. response = self.import_dataset(filename, file_format)
  164. self.assert_parse_error(response)
  165. def test_jsonl_with_overlapping(self):
  166. filename = 'sequence_labeling/example_overlapping.jsonl'
  167. file_format = 'JSONL'
  168. response = self.import_dataset(filename, file_format)
  169. self.assertEqual(len(response['error']), 1)
  170. class TestImportSeq2seqData(TestImportData):
  171. task = SEQ2SEQ
  172. def assert_examples(self, dataset):
  173. self.assertEqual(Example.objects.count(), len(dataset))
  174. for text, expected_labels in dataset:
  175. example = Example.objects.get(text=text)
  176. labels = set(text_label.text for text_label in example.texts.all())
  177. self.assertEqual(labels, set(expected_labels))
  178. def test_jsonl(self):
  179. filename = 'seq2seq/example.jsonl'
  180. file_format = 'JSONL'
  181. dataset = [
  182. ('exampleA', ['label1']),
  183. ('exampleB', [])
  184. ]
  185. self.import_dataset(filename, file_format)
  186. self.assert_examples(dataset)
  187. def test_json(self):
  188. filename = 'seq2seq/example.json'
  189. file_format = 'JSON'
  190. dataset = [
  191. ('exampleA', ['label1']),
  192. ('exampleB', [])
  193. ]
  194. self.import_dataset(filename, file_format)
  195. self.assert_examples(dataset)
  196. def test_csv(self):
  197. filename = 'seq2seq/example.csv'
  198. file_format = 'CSV'
  199. dataset = [
  200. ('exampleA', ['label1']),
  201. ('exampleB', [])
  202. ]
  203. self.import_dataset(filename, file_format)
  204. self.assert_examples(dataset)
  205. class TextImportIntentDetectionAndSlotFillingData(TestImportData):
  206. task = INTENT_DETECTION_AND_SLOT_FILLING
  207. def assert_examples(self, dataset):
  208. self.assertEqual(Example.objects.count(), len(dataset))
  209. for text, expected_labels in dataset:
  210. example = Example.objects.get(text=text)
  211. cats = set(cat.label.text for cat in example.categories.all())
  212. entities = [(span.start_offset, span.end_offset, span.label.text) for span in example.spans.all()]
  213. self.assertEqual(cats, set(expected_labels['cats']))
  214. self.assertEqual(entities, expected_labels['entities'])
  215. def test_entities_and_cats(self):
  216. filename = 'intent/example.jsonl'
  217. file_format = 'JSONL'
  218. dataset = [
  219. ('exampleA', {'cats': ['positive'], 'entities': [(0, 1, 'LOC')]}),
  220. ('exampleB', {'cats': ['positive'], 'entities': []}),
  221. ('exampleC', {'cats': [], 'entities': [(0, 1, 'LOC')]}),
  222. ('exampleD', {'cats': [], 'entities': []}),
  223. ]
  224. self.import_dataset(filename, file_format)
  225. self.assert_examples(dataset)