You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

171 lines
5.4 KiB

5 years ago
  1. import io
  2. from django.test import TestCase
  3. from seqeval.metrics.sequence_labeling import get_entities
  4. from ..models import Label, Document
  5. from ..utils import BaseStorage, ClassificationStorage, SequenceLabelingStorage, Seq2seqStorage, CoNLLParser
  6. from ..utils import Color, iterable_to_io
  7. class TestColor(TestCase):
  8. def test_random_color(self):
  9. color = Color.random()
  10. self.assertTrue(0 <= color.red <= 255)
  11. self.assertTrue(0 <= color.green <= 255)
  12. self.assertTrue(0 <= color.blue <= 255)
  13. def test_hex(self):
  14. color = Color(red=255, green=192, blue=203)
  15. self.assertEqual(color.hex, '#ffc0cb')
  16. def test_contrast_color(self):
  17. color = Color(red=255, green=192, blue=203)
  18. self.assertEqual(color.contrast_color.hex, '#000000')
  19. color = Color(red=199, green=21, blue=133)
  20. self.assertEqual(color.contrast_color.hex, '#ffffff')
  21. class TestBaseStorage(TestCase):
  22. def test_extract_label(self):
  23. data = [{'labels': ['positive']}, {'labels': ['negative']}]
  24. actual = BaseStorage.extract_label(data)
  25. self.assertEqual(actual, [['positive'], ['negative']])
  26. def test_exclude_created_labels(self):
  27. labels = ['positive', 'negative']
  28. created = {'positive': Label(text='positive')}
  29. actual = BaseStorage.exclude_created_labels(labels, created)
  30. self.assertEqual(actual, ['negative'])
  31. def test_to_serializer_format(self):
  32. labels = ['positive']
  33. created = {}
  34. actual = BaseStorage.to_serializer_format(labels, created, random_seed=123)
  35. self.assertEqual(actual, [{
  36. 'text': 'positive',
  37. 'prefix_key': None,
  38. 'suffix_key': 'p',
  39. 'background_color': '#0d1668',
  40. 'text_color': '#ffffff',
  41. }])
  42. def test_get_shortkey_without_existing_shortkey(self):
  43. label = 'positive'
  44. created = {}
  45. actual = BaseStorage.get_shortkey(label, created)
  46. self.assertEqual(actual, ('p', None))
  47. def test_get_shortkey_with_existing_shortkey(self):
  48. label = 'positive'
  49. created = {('p', None)}
  50. actual = BaseStorage.get_shortkey(label, created)
  51. self.assertEqual(actual, ('p', 'ctrl'))
  52. def test_update_saved_labels(self):
  53. saved = {'positive': Label(text='positive', text_color='#000000')}
  54. new = [Label(text='positive', text_color='#ffffff')]
  55. actual = BaseStorage.update_saved_labels(saved, new)
  56. self.assertEqual(actual['positive'].text_color, '#ffffff')
  57. class TestClassificationStorage(TestCase):
  58. def test_extract_unique_labels(self):
  59. labels = [['positive'], ['positive', 'negative'], ['negative']]
  60. actual = ClassificationStorage.extract_unique_labels(labels)
  61. self.assertCountEqual(actual, ['positive', 'negative'])
  62. def test_make_annotations(self):
  63. docs = [Document(text='a', id=1), Document(text='b', id=2), Document(text='c', id=3)]
  64. labels = [['positive'], ['positive', 'negative'], ['negative']]
  65. saved_labels = {'positive': Label(text='positive', id=1), 'negative': Label(text='negative', id=2)}
  66. actual = ClassificationStorage.make_annotations(docs, labels, saved_labels)
  67. self.assertCountEqual(actual, [
  68. {'document': 1, 'label': 1},
  69. {'document': 2, 'label': 1},
  70. {'document': 2, 'label': 2},
  71. {'document': 3, 'label': 2},
  72. ])
  73. class TestSequenceLabelingStorage(TestCase):
  74. def test_extract_unique_labels(self):
  75. labels = [[[0, 1, 'LOC']], [[3, 4, 'ORG']]]
  76. actual = SequenceLabelingStorage.extract_unique_labels(labels)
  77. self.assertCountEqual(actual, ['LOC', 'ORG'])
  78. def test_make_annotations(self):
  79. docs = [Document(text='a', id=1), Document(text='b', id=2)]
  80. labels = [[[0, 1, 'LOC']], [[3, 4, 'ORG']]]
  81. saved_labels = {'LOC': Label(text='LOC', id=1), 'ORG': Label(text='ORG', id=2)}
  82. actual = SequenceLabelingStorage.make_annotations(docs, labels, saved_labels)
  83. self.assertEqual(actual, [
  84. {'document': 1, 'label': 1, 'start_offset': 0, 'end_offset': 1},
  85. {'document': 2, 'label': 2, 'start_offset': 3, 'end_offset': 4},
  86. ])
  87. class TestSeq2seqStorage(TestCase):
  88. def test_make_annotations(self):
  89. docs = [Document(text='a', id=1), Document(text='b', id=2)]
  90. labels = [['Hello!'], ['How are you?', "What's up?"]]
  91. actual = Seq2seqStorage.make_annotations(docs, labels)
  92. self.assertEqual(actual, [
  93. {'document': 1, 'text': 'Hello!'},
  94. {'document': 2, 'text': 'How are you?'},
  95. {'document': 2, 'text': "What's up?"},
  96. ])
  97. class TestCoNLLParser(TestCase):
  98. def test_calc_char_offset(self):
  99. f = io.BytesIO(
  100. b"EU\tORG\n"
  101. b"rejects\t_\n"
  102. b"German\tMISC\n"
  103. b"call\t_\n"
  104. )
  105. actual = next(CoNLLParser().parse(f))[0]
  106. self.assertEqual(actual, {
  107. 'text': 'EU rejects German call',
  108. 'labels': [[0, 2, 'ORG'], [11, 17, 'MISC']]
  109. })
  110. class TestIterableToIO(TestCase):
  111. def test(self):
  112. def iterable():
  113. yield b'fo'
  114. yield b'o\nbar\n'
  115. yield b'baz\nrest'
  116. stream = iterable_to_io(iterable())
  117. stream = io.TextIOWrapper(stream)
  118. self.assertEqual(stream.readlines(), ['foo\n', 'bar\n', 'baz\n', 'rest'])