You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

266 lines
9.1 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. import pathlib
  2. from django.test import TestCase
  3. from data_import.celery_tasks import import_dataset
  4. from api.models import (DOCUMENT_CLASSIFICATION,
  5. INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
  6. SEQUENCE_LABELING, Example)
  7. from label_types.models import CategoryType, SpanType
  8. from labels.models import Category, Span
  9. from api.tests.api.utils import prepare_project
  10. class TestImportData(TestCase):
  11. task = 'Any'
  12. annotation_class = Category
  13. def setUp(self):
  14. self.project = prepare_project(self.task)
  15. self.user = self.project.users[0]
  16. self.data_path = pathlib.Path(__file__).parent / 'data'
  17. def import_dataset(self, filename, file_format, kwargs=None):
  18. filenames = [str(self.data_path / filename)]
  19. kwargs = kwargs or {}
  20. return import_dataset(self.user.id, self.project.item.id, filenames, file_format, **kwargs)
  21. class TestImportClassificationData(TestImportData):
  22. task = DOCUMENT_CLASSIFICATION
  23. def assert_examples(self, dataset):
  24. self.assertEqual(Example.objects.count(), len(dataset))
  25. for text, expected_labels in dataset:
  26. example = Example.objects.get(text=text)
  27. labels = set(cat.label.text for cat in example.categories.all())
  28. self.assertEqual(labels, set(expected_labels))
  29. def assert_parse_error(self, response):
  30. self.assertGreaterEqual(len(response['error']), 1)
  31. self.assertEqual(Example.objects.count(), 0)
  32. self.assertEqual(CategoryType.objects.count(), 0)
  33. self.assertEqual(Category.objects.count(), 0)
  34. def test_jsonl(self):
  35. filename = 'text_classification/example.jsonl'
  36. file_format = 'JSONL'
  37. kwargs = {'column_label': 'labels'}
  38. dataset = [
  39. ('exampleA', ['positive']),
  40. ('exampleB', ['positive', 'negative']),
  41. ('exampleC', [])
  42. ]
  43. self.import_dataset(filename, file_format, kwargs)
  44. self.assert_examples(dataset)
  45. def test_csv(self):
  46. filename = 'text_classification/example.csv'
  47. file_format = 'CSV'
  48. dataset = [
  49. ('exampleA', ['positive']),
  50. ('exampleB', [])
  51. ]
  52. self.import_dataset(filename, file_format)
  53. self.assert_examples(dataset)
  54. def test_csv_out_of_order_columns(self):
  55. filename = 'text_classification/example_out_of_order_columns.csv'
  56. file_format = 'CSV'
  57. dataset = [
  58. ('exampleA', ['positive']),
  59. ('exampleB', [])
  60. ]
  61. self.import_dataset(filename, file_format)
  62. self.assert_examples(dataset)
  63. def test_fasttext(self):
  64. filename = 'text_classification/example_fasttext.txt'
  65. file_format = 'fastText'
  66. dataset = [
  67. ('exampleA', ['positive']),
  68. ('exampleB', ['positive', 'negative']),
  69. ('exampleC', [])
  70. ]
  71. self.import_dataset(filename, file_format)
  72. self.assert_examples(dataset)
  73. def test_excel(self):
  74. filename = 'text_classification/example.xlsx'
  75. file_format = 'Excel'
  76. dataset = [
  77. ('exampleA', ['positive']),
  78. ('exampleB', [])
  79. ]
  80. self.import_dataset(filename, file_format)
  81. self.assert_examples(dataset)
  82. def test_json(self):
  83. filename = 'text_classification/example.json'
  84. file_format = 'JSON'
  85. dataset = [
  86. ('exampleA', ['positive']),
  87. ('exampleB', ['positive', 'negative']),
  88. ('exampleC', [])
  89. ]
  90. self.import_dataset(filename, file_format)
  91. self.assert_examples(dataset)
  92. def test_textfile(self):
  93. filename = 'example.txt'
  94. file_format = 'TextFile'
  95. dataset = [
  96. ('exampleA\nexampleB\n\nexampleC\n', [])
  97. ]
  98. self.import_dataset(filename, file_format)
  99. self.assert_examples(dataset)
  100. def test_textline(self):
  101. filename = 'example.txt'
  102. file_format = 'TextLine'
  103. dataset = [
  104. ('exampleA', []),
  105. ('exampleB', []),
  106. ('exampleC', [])
  107. ]
  108. self.import_dataset(filename, file_format)
  109. self.assert_examples(dataset)
  110. def test_wrong_jsonl(self):
  111. filename = 'text_classification/example.json'
  112. file_format = 'JSONL'
  113. response = self.import_dataset(filename, file_format)
  114. self.assert_parse_error(response)
  115. def test_wrong_json(self):
  116. filename = 'text_classification/example.jsonl'
  117. file_format = 'JSON'
  118. response = self.import_dataset(filename, file_format)
  119. self.assert_parse_error(response)
  120. def test_wrong_excel(self):
  121. filename = 'text_classification/example.jsonl'
  122. file_format = 'Excel'
  123. response = self.import_dataset(filename, file_format)
  124. self.assert_parse_error(response)
  125. def test_wrong_csv(self):
  126. filename = 'text_classification/example.jsonl'
  127. file_format = 'CSV'
  128. response = self.import_dataset(filename, file_format)
  129. self.assert_parse_error(response)
  130. class TestImportSequenceLabelingData(TestImportData):
  131. task = SEQUENCE_LABELING
  132. def assert_examples(self, dataset):
  133. self.assertEqual(Example.objects.count(), len(dataset))
  134. for text, expected_labels in dataset:
  135. example = Example.objects.get(text=text)
  136. labels = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
  137. self.assertEqual(labels, expected_labels)
  138. def assert_parse_error(self, response):
  139. self.assertGreaterEqual(len(response['error']), 1)
  140. self.assertEqual(Example.objects.count(), 0)
  141. self.assertEqual(SpanType.objects.count(), 0)
  142. self.assertEqual(Span.objects.count(), 0)
  143. def test_jsonl(self):
  144. filename = 'sequence_labeling/example.jsonl'
  145. file_format = 'JSONL'
  146. dataset = [
  147. ('exampleA', [[0, 1, 'LOC']]),
  148. ('exampleB', [])
  149. ]
  150. self.import_dataset(filename, file_format)
  151. self.assert_examples(dataset)
  152. def test_conll(self):
  153. filename = 'sequence_labeling/example.conll'
  154. file_format = 'CoNLL'
  155. dataset = [
  156. ('JAPAN GET', [[0, 5, 'LOC']]),
  157. ('Nadim Ladki', [[0, 11, 'PER']])
  158. ]
  159. self.import_dataset(filename, file_format)
  160. self.assert_examples(dataset)
  161. def test_wrong_conll(self):
  162. filename = 'sequence_labeling/example.jsonl'
  163. file_format = 'CoNLL'
  164. response = self.import_dataset(filename, file_format)
  165. self.assert_parse_error(response)
  166. def test_jsonl_with_overlapping(self):
  167. filename = 'sequence_labeling/example_overlapping.jsonl'
  168. file_format = 'JSONL'
  169. response = self.import_dataset(filename, file_format)
  170. self.assertEqual(len(response['error']), 1)
  171. class TestImportSeq2seqData(TestImportData):
  172. task = SEQ2SEQ
  173. def assert_examples(self, dataset):
  174. self.assertEqual(Example.objects.count(), len(dataset))
  175. for text, expected_labels in dataset:
  176. example = Example.objects.get(text=text)
  177. labels = set(text_label.text for text_label in example.texts.all())
  178. self.assertEqual(labels, set(expected_labels))
  179. def test_jsonl(self):
  180. filename = 'seq2seq/example.jsonl'
  181. file_format = 'JSONL'
  182. dataset = [
  183. ('exampleA', ['label1']),
  184. ('exampleB', [])
  185. ]
  186. self.import_dataset(filename, file_format)
  187. self.assert_examples(dataset)
  188. def test_json(self):
  189. filename = 'seq2seq/example.json'
  190. file_format = 'JSON'
  191. dataset = [
  192. ('exampleA', ['label1']),
  193. ('exampleB', [])
  194. ]
  195. self.import_dataset(filename, file_format)
  196. self.assert_examples(dataset)
  197. def test_csv(self):
  198. filename = 'seq2seq/example.csv'
  199. file_format = 'CSV'
  200. dataset = [
  201. ('exampleA', ['label1']),
  202. ('exampleB', [])
  203. ]
  204. self.import_dataset(filename, file_format)
  205. self.assert_examples(dataset)
  206. class TextImportIntentDetectionAndSlotFillingData(TestImportData):
  207. task = INTENT_DETECTION_AND_SLOT_FILLING
  208. def assert_examples(self, dataset):
  209. self.assertEqual(Example.objects.count(), len(dataset))
  210. for text, expected_labels in dataset:
  211. example = Example.objects.get(text=text)
  212. cats = set(cat.label.text for cat in example.categories.all())
  213. entities = [(span.start_offset, span.end_offset, span.label.text) for span in example.spans.all()]
  214. self.assertEqual(cats, set(expected_labels['cats']))
  215. self.assertEqual(entities, expected_labels['entities'])
  216. def test_entities_and_cats(self):
  217. filename = 'intent/example.jsonl'
  218. file_format = 'JSONL'
  219. dataset = [
  220. ('exampleA', {'cats': ['positive'], 'entities': [(0, 1, 'LOC')]}),
  221. ('exampleB', {'cats': ['positive'], 'entities': []}),
  222. ('exampleC', {'cats': [], 'entities': [(0, 1, 'LOC')]}),
  223. ('exampleD', {'cats': [], 'entities': []}),
  224. ]
  225. self.import_dataset(filename, file_format)
  226. self.assert_examples(dataset)