You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

177 lines
6.8 KiB

  1. import uuid
  2. from unittest.mock import MagicMock
  3. from django.test import TestCase
  4. from model_mommy import mommy
  5. from data_import.pipeline.label import (
  6. CategoryLabel,
  7. RelationLabel,
  8. SpanLabel,
  9. TextLabel,
  10. )
  11. from label_types.models import CategoryType, RelationType, SpanType
  12. from labels.models import Category as CategoryModel
  13. from labels.models import Relation as RelationModel
  14. from labels.models import Span as SpanModel
  15. from labels.models import TextLabel as TextModel
  16. from projects.models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING
  17. from projects.tests.utils import prepare_project
  18. class TestLabel(TestCase):
  19. task = "Any"
  20. def setUp(self):
  21. self.project = prepare_project(self.task)
  22. self.user = self.project.admin
  23. self.example = mommy.make("Example", project=self.project.item)
  24. class TestCategoryLabel(TestLabel):
  25. task = DOCUMENT_CLASSIFICATION
  26. def test_comparison(self):
  27. category1 = CategoryLabel(label="A", example_uuid=uuid.uuid4())
  28. category2 = CategoryLabel(label="B", example_uuid=uuid.uuid4())
  29. self.assertLess(category1, category2)
  30. def test_empty_label_raises_value_error(self):
  31. with self.assertRaises(ValueError):
  32. CategoryLabel(label="", example_uuid=uuid.uuid4())
  33. def test_parse(self):
  34. example_uuid = uuid.uuid4()
  35. category = CategoryLabel.parse(example_uuid, obj="A")
  36. self.assertEqual(category.label, "A")
  37. self.assertEqual(category.example_uuid, example_uuid)
  38. def test_create_type(self):
  39. category = CategoryLabel(label="A", example_uuid=uuid.uuid4())
  40. category_type = category.create_type(self.project.item)
  41. self.assertIsInstance(category_type, CategoryType)
  42. self.assertEqual(category_type.text, "A")
  43. def test_create(self):
  44. category = CategoryLabel(label="A", example_uuid=uuid.uuid4())
  45. types = MagicMock()
  46. types.__getitem__.return_value = mommy.make(CategoryType, project=self.project.item)
  47. category_model = category.create(self.user, self.example, types)
  48. self.assertIsInstance(category_model, CategoryModel)
  49. class TestSpanLabel(TestLabel):
  50. task = SEQUENCE_LABELING
  51. def test_comparison(self):
  52. span1 = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
  53. span2 = SpanLabel(label="A", start_offset=1, end_offset=2, example_uuid=uuid.uuid4())
  54. self.assertLess(span1, span2)
  55. def test_parse_tuple(self):
  56. example_uuid = uuid.uuid4()
  57. span = SpanLabel.parse(example_uuid, obj=(0, 1, "A"))
  58. self.assertEqual(span.label, "A")
  59. self.assertEqual(span.start_offset, 0)
  60. self.assertEqual(span.end_offset, 1)
  61. def test_parse_dict(self):
  62. example_uuid = uuid.uuid4()
  63. span = SpanLabel.parse(example_uuid, obj={"label": "A", "start_offset": 0, "end_offset": 1})
  64. self.assertEqual(span.label, "A")
  65. self.assertEqual(span.start_offset, 0)
  66. self.assertEqual(span.end_offset, 1)
  67. def test_invalid_negative_offset(self):
  68. with self.assertRaises(ValueError):
  69. SpanLabel(label="A", start_offset=-1, end_offset=1, example_uuid=uuid.uuid4())
  70. def test_invalid_offset(self):
  71. with self.assertRaises(ValueError):
  72. SpanLabel(label="A", start_offset=1, end_offset=0, example_uuid=uuid.uuid4())
  73. def test_parse_invalid_dict(self):
  74. example_uuid = uuid.uuid4()
  75. with self.assertRaises(ValueError):
  76. SpanLabel.parse(example_uuid, obj={"label": "A", "start_offset": 0})
  77. def test_create_type(self):
  78. span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
  79. span_type = span.create_type(self.project.item)
  80. self.assertIsInstance(span_type, SpanType)
  81. self.assertEqual(span_type.text, "A")
  82. def test_create(self):
  83. span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
  84. types = MagicMock()
  85. types.__getitem__.return_value = mommy.make(SpanType, project=self.project.item)
  86. span_model = span.create(self.user, self.example, types)
  87. self.assertIsInstance(span_model, SpanModel)
  88. class TestTextLabel(TestLabel):
  89. task = SEQ2SEQ
  90. def test_comparison(self):
  91. text1 = TextLabel(text="A", example_uuid=uuid.uuid4())
  92. text2 = TextLabel(text="B", example_uuid=uuid.uuid4())
  93. self.assertLess(text1, text2)
  94. def test_parse(self):
  95. example_uuid = uuid.uuid4()
  96. text = TextLabel.parse(example_uuid, obj="A")
  97. self.assertEqual(text.text, "A")
  98. def test_parse_invalid_data(self):
  99. example_uuid = uuid.uuid4()
  100. with self.assertRaises(ValueError):
  101. TextLabel.parse(example_uuid, obj=[])
  102. def test_create_type(self):
  103. text = TextLabel(text="A", example_uuid=uuid.uuid4())
  104. text_type = text.create_type(self.project.item)
  105. self.assertEqual(text_type, None)
  106. def test_create(self):
  107. text = TextLabel(text="A", example_uuid=uuid.uuid4())
  108. types = MagicMock()
  109. text_model = text.create(self.user, self.example, types)
  110. self.assertIsInstance(text_model, TextModel)
  111. class TestRelationLabel(TestLabel):
  112. task = SEQUENCE_LABELING
  113. def test_comparison(self):
  114. relation1 = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
  115. relation2 = RelationLabel(type="A", from_id=1, to_id=1, example_uuid=uuid.uuid4())
  116. self.assertLess(relation1, relation2)
  117. def test_parse(self):
  118. example_uuid = uuid.uuid4()
  119. relation = RelationLabel.parse(example_uuid, obj={"type": "A", "from_id": 0, "to_id": 1})
  120. self.assertEqual(relation.type, "A")
  121. self.assertEqual(relation.from_id, 0)
  122. self.assertEqual(relation.to_id, 1)
  123. def test_parse_invalid_data(self):
  124. example_uuid = uuid.uuid4()
  125. with self.assertRaises(ValueError):
  126. RelationLabel.parse(example_uuid, obj={"type": "A", "from_id": 0})
  127. def test_create_type(self):
  128. relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
  129. relation_type = relation.create_type(self.project.item)
  130. self.assertIsInstance(relation_type, RelationType)
  131. self.assertEqual(relation_type.text, "A")
  132. def test_create(self):
  133. relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
  134. types = MagicMock()
  135. types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item)
  136. id_to_span = {
  137. 0: mommy.make(SpanModel, start_offset=0, end_offset=1),
  138. 1: mommy.make(SpanModel, start_offset=2, end_offset=3),
  139. }
  140. relation_model = relation.create(self.user, self.example, types, id_to_span=id_to_span)
  141. self.assertIsInstance(relation_model, RelationModel)