You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

267 lines
9.1 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. import pathlib
  2. from django.test import TestCase
  3. from data_import.celery_tasks import import_dataset
  4. from api.models import (DOCUMENT_CLASSIFICATION,
  5. INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
  6. SEQUENCE_LABELING)
  7. from examples.models import Example
  8. from label_types.models import CategoryType, SpanType
  9. from labels.models import Category, Span
  10. from api.tests.api.utils import prepare_project
  11. class TestImportData(TestCase):
  12. task = 'Any'
  13. annotation_class = Category
  14. def setUp(self):
  15. self.project = prepare_project(self.task)
  16. self.user = self.project.users[0]
  17. self.data_path = pathlib.Path(__file__).parent / 'data'
  18. def import_dataset(self, filename, file_format, kwargs=None):
  19. filenames = [str(self.data_path / filename)]
  20. kwargs = kwargs or {}
  21. return import_dataset(self.user.id, self.project.item.id, filenames, file_format, **kwargs)
  22. class TestImportClassificationData(TestImportData):
  23. task = DOCUMENT_CLASSIFICATION
  24. def assert_examples(self, dataset):
  25. self.assertEqual(Example.objects.count(), len(dataset))
  26. for text, expected_labels in dataset:
  27. example = Example.objects.get(text=text)
  28. labels = set(cat.label.text for cat in example.categories.all())
  29. self.assertEqual(labels, set(expected_labels))
  30. def assert_parse_error(self, response):
  31. self.assertGreaterEqual(len(response['error']), 1)
  32. self.assertEqual(Example.objects.count(), 0)
  33. self.assertEqual(CategoryType.objects.count(), 0)
  34. self.assertEqual(Category.objects.count(), 0)
  35. def test_jsonl(self):
  36. filename = 'text_classification/example.jsonl'
  37. file_format = 'JSONL'
  38. kwargs = {'column_label': 'labels'}
  39. dataset = [
  40. ('exampleA', ['positive']),
  41. ('exampleB', ['positive', 'negative']),
  42. ('exampleC', [])
  43. ]
  44. self.import_dataset(filename, file_format, kwargs)
  45. self.assert_examples(dataset)
  46. def test_csv(self):
  47. filename = 'text_classification/example.csv'
  48. file_format = 'CSV'
  49. dataset = [
  50. ('exampleA', ['positive']),
  51. ('exampleB', [])
  52. ]
  53. self.import_dataset(filename, file_format)
  54. self.assert_examples(dataset)
  55. def test_csv_out_of_order_columns(self):
  56. filename = 'text_classification/example_out_of_order_columns.csv'
  57. file_format = 'CSV'
  58. dataset = [
  59. ('exampleA', ['positive']),
  60. ('exampleB', [])
  61. ]
  62. self.import_dataset(filename, file_format)
  63. self.assert_examples(dataset)
  64. def test_fasttext(self):
  65. filename = 'text_classification/example_fasttext.txt'
  66. file_format = 'fastText'
  67. dataset = [
  68. ('exampleA', ['positive']),
  69. ('exampleB', ['positive', 'negative']),
  70. ('exampleC', [])
  71. ]
  72. self.import_dataset(filename, file_format)
  73. self.assert_examples(dataset)
  74. def test_excel(self):
  75. filename = 'text_classification/example.xlsx'
  76. file_format = 'Excel'
  77. dataset = [
  78. ('exampleA', ['positive']),
  79. ('exampleB', [])
  80. ]
  81. self.import_dataset(filename, file_format)
  82. self.assert_examples(dataset)
  83. def test_json(self):
  84. filename = 'text_classification/example.json'
  85. file_format = 'JSON'
  86. dataset = [
  87. ('exampleA', ['positive']),
  88. ('exampleB', ['positive', 'negative']),
  89. ('exampleC', [])
  90. ]
  91. self.import_dataset(filename, file_format)
  92. self.assert_examples(dataset)
  93. def test_textfile(self):
  94. filename = 'example.txt'
  95. file_format = 'TextFile'
  96. dataset = [
  97. ('exampleA\nexampleB\n\nexampleC\n', [])
  98. ]
  99. self.import_dataset(filename, file_format)
  100. self.assert_examples(dataset)
  101. def test_textline(self):
  102. filename = 'example.txt'
  103. file_format = 'TextLine'
  104. dataset = [
  105. ('exampleA', []),
  106. ('exampleB', []),
  107. ('exampleC', [])
  108. ]
  109. self.import_dataset(filename, file_format)
  110. self.assert_examples(dataset)
  111. def test_wrong_jsonl(self):
  112. filename = 'text_classification/example.json'
  113. file_format = 'JSONL'
  114. response = self.import_dataset(filename, file_format)
  115. self.assert_parse_error(response)
  116. def test_wrong_json(self):
  117. filename = 'text_classification/example.jsonl'
  118. file_format = 'JSON'
  119. response = self.import_dataset(filename, file_format)
  120. self.assert_parse_error(response)
  121. def test_wrong_excel(self):
  122. filename = 'text_classification/example.jsonl'
  123. file_format = 'Excel'
  124. response = self.import_dataset(filename, file_format)
  125. self.assert_parse_error(response)
  126. def test_wrong_csv(self):
  127. filename = 'text_classification/example.jsonl'
  128. file_format = 'CSV'
  129. response = self.import_dataset(filename, file_format)
  130. self.assert_parse_error(response)
  131. class TestImportSequenceLabelingData(TestImportData):
  132. task = SEQUENCE_LABELING
  133. def assert_examples(self, dataset):
  134. self.assertEqual(Example.objects.count(), len(dataset))
  135. for text, expected_labels in dataset:
  136. example = Example.objects.get(text=text)
  137. labels = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
  138. self.assertEqual(labels, expected_labels)
  139. def assert_parse_error(self, response):
  140. self.assertGreaterEqual(len(response['error']), 1)
  141. self.assertEqual(Example.objects.count(), 0)
  142. self.assertEqual(SpanType.objects.count(), 0)
  143. self.assertEqual(Span.objects.count(), 0)
  144. def test_jsonl(self):
  145. filename = 'sequence_labeling/example.jsonl'
  146. file_format = 'JSONL'
  147. dataset = [
  148. ('exampleA', [[0, 1, 'LOC']]),
  149. ('exampleB', [])
  150. ]
  151. self.import_dataset(filename, file_format)
  152. self.assert_examples(dataset)
  153. def test_conll(self):
  154. filename = 'sequence_labeling/example.conll'
  155. file_format = 'CoNLL'
  156. dataset = [
  157. ('JAPAN GET', [[0, 5, 'LOC']]),
  158. ('Nadim Ladki', [[0, 11, 'PER']])
  159. ]
  160. self.import_dataset(filename, file_format)
  161. self.assert_examples(dataset)
  162. def test_wrong_conll(self):
  163. filename = 'sequence_labeling/example.jsonl'
  164. file_format = 'CoNLL'
  165. response = self.import_dataset(filename, file_format)
  166. self.assert_parse_error(response)
  167. def test_jsonl_with_overlapping(self):
  168. filename = 'sequence_labeling/example_overlapping.jsonl'
  169. file_format = 'JSONL'
  170. response = self.import_dataset(filename, file_format)
  171. self.assertEqual(len(response['error']), 1)
  172. class TestImportSeq2seqData(TestImportData):
  173. task = SEQ2SEQ
  174. def assert_examples(self, dataset):
  175. self.assertEqual(Example.objects.count(), len(dataset))
  176. for text, expected_labels in dataset:
  177. example = Example.objects.get(text=text)
  178. labels = set(text_label.text for text_label in example.texts.all())
  179. self.assertEqual(labels, set(expected_labels))
  180. def test_jsonl(self):
  181. filename = 'seq2seq/example.jsonl'
  182. file_format = 'JSONL'
  183. dataset = [
  184. ('exampleA', ['label1']),
  185. ('exampleB', [])
  186. ]
  187. self.import_dataset(filename, file_format)
  188. self.assert_examples(dataset)
  189. def test_json(self):
  190. filename = 'seq2seq/example.json'
  191. file_format = 'JSON'
  192. dataset = [
  193. ('exampleA', ['label1']),
  194. ('exampleB', [])
  195. ]
  196. self.import_dataset(filename, file_format)
  197. self.assert_examples(dataset)
  198. def test_csv(self):
  199. filename = 'seq2seq/example.csv'
  200. file_format = 'CSV'
  201. dataset = [
  202. ('exampleA', ['label1']),
  203. ('exampleB', [])
  204. ]
  205. self.import_dataset(filename, file_format)
  206. self.assert_examples(dataset)
  207. class TextImportIntentDetectionAndSlotFillingData(TestImportData):
  208. task = INTENT_DETECTION_AND_SLOT_FILLING
  209. def assert_examples(self, dataset):
  210. self.assertEqual(Example.objects.count(), len(dataset))
  211. for text, expected_labels in dataset:
  212. example = Example.objects.get(text=text)
  213. cats = set(cat.label.text for cat in example.categories.all())
  214. entities = [(span.start_offset, span.end_offset, span.label.text) for span in example.spans.all()]
  215. self.assertEqual(cats, set(expected_labels['cats']))
  216. self.assertEqual(entities, expected_labels['entities'])
  217. def test_entities_and_cats(self):
  218. filename = 'intent/example.jsonl'
  219. file_format = 'JSONL'
  220. dataset = [
  221. ('exampleA', {'cats': ['positive'], 'entities': [(0, 1, 'LOC')]}),
  222. ('exampleB', {'cats': ['positive'], 'entities': []}),
  223. ('exampleC', {'cats': [], 'entities': [(0, 1, 'LOC')]}),
  224. ('exampleD', {'cats': [], 'entities': []}),
  225. ]
  226. self.import_dataset(filename, file_format)
  227. self.assert_examples(dataset)