You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

172 lines
5.4 KiB

5 years ago
  1. import io
  2. from django.test import TestCase
  3. from seqeval.metrics.sequence_labeling import get_entities
  4. from ..exceptions import FileParseException
  5. from ..models import Document, Label
  6. from ..utils import (AudioParser, BaseStorage, ClassificationStorage,
  7. CoNLLParser, Seq2seqStorage, SequenceLabelingStorage,
  8. iterable_to_io)
  class TestBaseStorage(TestCase):
      """Tests for the label-handling helpers called directly on ``BaseStorage``."""

      def test_extract_label(self):
          """Each record's ``labels`` value is collected, in record order."""
          records = [{'labels': ['positive']}, {'labels': ['negative']}]
          extracted = BaseStorage.extract_label(records)
          self.assertEqual(extracted, [['positive'], ['negative']])

      def test_exclude_created_labels(self):
          """A label text already keyed in ``created`` is left out of the result."""
          already_created = {'positive': Label(text='positive')}
          remaining = BaseStorage.exclude_created_labels(
              ['positive', 'negative'],
              already_created,
          )
          self.assertEqual(remaining, ['negative'])

      def test_to_serializer_format(self):
          """One new label text yields one entry with text, shortkey and colors."""
          serialized = BaseStorage.to_serializer_format(['positive'], {})
          self.assertEqual(len(serialized), 1)

          entry = serialized[0]
          self.assertEqual(entry['text'], 'positive')
          self.assertIsNone(entry['prefix_key'])
          self.assertEqual(entry['suffix_key'], 'p')
          # Only presence of the colors is checked, not their values.
          self.assertIsNotNone(entry['background_color'])
          self.assertIsNotNone(entry['text_color'])

      def test_get_shortkey_without_existing_shortkey(self):
          """With nothing created yet, 'positive' is given ``('p', None)``."""
          shortkey = BaseStorage.get_shortkey('positive', {})
          self.assertEqual(shortkey, ('p', None))

      def test_get_shortkey_with_existing_shortkey(self):
          """When ``('p', None)`` is already present, ``('p', 'ctrl')`` is returned."""
          taken = {('p', None)}
          shortkey = BaseStorage.get_shortkey('positive', taken)
          self.assertEqual(shortkey, ('p', 'ctrl'))

      def test_update_saved_labels(self):
          """The entry for 'positive' carries the new label's text color."""
          previously_saved = {
              'positive': Label(text='positive', text_color='#000000'),
          }
          incoming = [Label(text='positive', text_color='#ffffff')]
          merged = BaseStorage.update_saved_labels(previously_saved, incoming)
          self.assertEqual(merged['positive'].text_color, '#ffffff')
  class TestClassificationStorage(TestCase):
      """Tests for the label helpers of ``ClassificationStorage``."""

      def test_extract_unique_labels(self):
          """Label texts repeated across documents are reported once each."""
          per_document = [['positive'], ['positive', 'negative'], ['negative']]
          unique = ClassificationStorage.extract_unique_labels(per_document)
          # assertCountEqual: the order of the returned labels is not checked.
          self.assertCountEqual(unique, ['positive', 'negative'])

      def test_make_annotations(self):
          """Every (document, label text) pair becomes a dict of their ids."""
          documents = [
              Document(text=text, id=pk)
              for pk, text in enumerate(['a', 'b', 'c'], start=1)
          ]
          per_document = [['positive'], ['positive', 'negative'], ['negative']]
          label_by_text = {
              text: Label(text=text, id=pk)
              for pk, text in enumerate(['positive', 'negative'], start=1)
          }

          annotations = ClassificationStorage.make_annotations(
              documents, per_document, label_by_text
          )

          expected = [
              {'document': doc_id, 'label': label_id}
              for doc_id, label_id in [(1, 1), (2, 1), (2, 2), (3, 2)]
          ]
          self.assertCountEqual(annotations, expected)
  class TestSequenceLabelingStorage(TestCase):
      """Tests for the label helpers of ``SequenceLabelingStorage``."""

      def test_extract_unique_labels(self):
          """The label name of each ``[start, end, name]`` span is collected."""
          spans_per_document = [[[0, 1, 'LOC']], [[3, 4, 'ORG']]]
          unique = SequenceLabelingStorage.extract_unique_labels(spans_per_document)
          # assertCountEqual: the order of the returned labels is not checked.
          self.assertCountEqual(unique, ['LOC', 'ORG'])

      def test_make_annotations(self):
          """Each span becomes a dict of document id, label id and both offsets."""
          documents = [
              Document(text=text, id=pk)
              for pk, text in enumerate(['a', 'b'], start=1)
          ]
          spans_per_document = [[[0, 1, 'LOC']], [[3, 4, 'ORG']]]
          label_by_text = {
              text: Label(text=text, id=pk)
              for pk, text in enumerate(['LOC', 'ORG'], start=1)
          }

          annotations = SequenceLabelingStorage.make_annotations(
              documents, spans_per_document, label_by_text
          )

          first = {'document': 1, 'label': 1, 'start_offset': 0, 'end_offset': 1}
          second = {'document': 2, 'label': 2, 'start_offset': 3, 'end_offset': 4}
          self.assertEqual(annotations, [first, second])
  class TestSeq2seqStorage(TestCase):
      """Tests for ``Seq2seqStorage.make_annotations``."""

      def test_make_annotations(self):
          """Every text of a document yields one ``{'document', 'text'}`` dict."""
          documents = [
              Document(text=text, id=pk)
              for pk, text in enumerate(['a', 'b'], start=1)
          ]
          texts_per_document = [['Hello!'], ['How are you?', "What's up?"]]

          annotations = Seq2seqStorage.make_annotations(documents, texts_per_document)

          expected = [
              {'document': doc_id, 'text': text}
              for doc_id, text in [
                  (1, 'Hello!'),
                  (2, 'How are you?'),
                  (2, "What's up?"),
              ]
          ]
          self.assertEqual(annotations, expected)
  class TestCoNLLParser(TestCase):
      """Tests for ``CoNLLParser``."""

      def test_calc_char_offset(self):
          """Tagged tokens are reported as character spans of the joined text."""
          # One "token<TAB>tag" row per line; "_" rows are expected to yield no label.
          rows = (
              b"EU\tORG\n",
              b"rejects\t_\n",
              b"German\tMISC\n",
              b"call\t_\n",
          )
          conll_file = io.BytesIO(b"".join(rows))

          first_batch = next(CoNLLParser().parse(conll_file))
          first_record = first_batch[0]

          expected = {
              'text': 'EU rejects German call',
              'labels': [[0, 2, 'ORG'], [11, 17, 'MISC']],
          }
          self.assertEqual(first_record, expected)
  class TestAudioParser(TestCase):
      """Tests for ``AudioParser``."""

      @staticmethod
      def _named_stream(filename):
          """Return an in-memory stream holding ``b'...'`` with ``name`` set."""
          stream = io.BytesIO(b'...')
          stream.name = filename
          return stream

      def test_parse_mp3(self):
          """An ``.mp3`` stream yields an ``audio/mpeg`` data URI plus its filename."""
          upload = self._named_stream('test.mp3')
          first_batch = next(AudioParser().parse(upload))
          # 'Li4u' is the base64 encoding of b'...'.
          expected = [{
              'audio': 'data:audio/mpeg;base64,Li4u',
              'meta': '{"filename": "test.mp3"}',
          }]
          self.assertEqual(first_batch, expected)

      def test_parse_unknown(self):
          """A stream named with an unknown extension raises ``FileParseException``."""
          upload = self._named_stream('unknown.unknown')
          with self.assertRaises(FileParseException):
              next(AudioParser().parse(upload))
  class TestIterableToIO(TestCase):
      """Tests for ``iterable_to_io``."""

      def test(self):
          """Byte chunks split mid-line are read back as complete text lines."""
          def chunks():
              # Chunk boundaries deliberately do not line up with newlines.
              yield from (b'fo', b'o\nbar\n', b'baz\nrest')

          text_stream = io.TextIOWrapper(iterable_to_io(chunks()))
          lines = text_stream.readlines()
          self.assertEqual(lines, ['foo\n', 'bar\n', 'baz\n', 'rest'])