You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

173 lines
6.6 KiB

  1. import uuid
  2. from unittest.mock import MagicMock
  3. from django.test import TestCase
  4. from model_mommy import mommy
  5. from data_import.models import DummyLabelType
  6. from data_import.pipeline.label import (
  7. CategoryLabel,
  8. RelationLabel,
  9. SpanLabel,
  10. TextLabel,
  11. )
  12. from data_import.pipeline.label_types import LabelTypes
  13. from data_import.pipeline.labels import Categories, Relations, Spans, Texts
  14. from label_types.models import CategoryType, RelationType, SpanType
  15. from labels.models import Category, Relation, Span
  16. from labels.models import TextLabel as TextLabelModel
  17. from projects.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING
  18. from projects.tests.utils import prepare_project
  19. class TestCategories(TestCase):
  20. def setUp(self):
  21. self.types = LabelTypes(CategoryType)
  22. self.project = prepare_project(DOCUMENT_CLASSIFICATION)
  23. self.user = self.project.admin
  24. example_uuid = uuid.uuid4()
  25. labels = [
  26. CategoryLabel(example_uuid=example_uuid, label="A"),
  27. CategoryLabel(example_uuid=example_uuid, label="B"),
  28. ]
  29. example = mommy.make("Example", project=self.project.item, uuid=example_uuid)
  30. self.examples = MagicMock()
  31. self.examples.__getitem__.return_value = example
  32. self.examples.__contains__.return_value = True
  33. self.categories = Categories(labels, self.types)
  34. def test_clean(self):
  35. self.categories.clean(self.project.item)
  36. self.assertEqual(len(self.categories), 2)
  37. def test_clean_with_exclusive_labels(self):
  38. self.project.item.single_class_classification = True
  39. self.project.item.save()
  40. self.categories.clean(self.project.item)
  41. self.assertEqual(len(self.categories), 1)
  42. def test_save(self):
  43. self.categories.save_types(self.project.item)
  44. self.categories.save(self.user, self.examples)
  45. self.assertEqual(Category.objects.count(), 2)
  46. def test_save_types(self):
  47. self.categories.save_types(self.project.item)
  48. self.assertEqual(CategoryType.objects.count(), 2)
  49. class TestSpans(TestCase):
  50. def setUp(self):
  51. self.types = LabelTypes(SpanType)
  52. self.project = prepare_project(SEQUENCE_LABELING, allow_overlapping=True)
  53. self.user = self.project.admin
  54. example_uuid = uuid.uuid4()
  55. labels = [
  56. SpanLabel(example_uuid=example_uuid, label="A", start_offset=0, end_offset=1),
  57. SpanLabel(example_uuid=example_uuid, label="B", start_offset=0, end_offset=3),
  58. SpanLabel(example_uuid=example_uuid, label="B", start_offset=3, end_offset=4),
  59. ]
  60. example = mommy.make("Example", project=self.project.item, uuid=example_uuid)
  61. self.examples = MagicMock()
  62. self.examples.__getitem__.return_value = example
  63. self.examples.__contains__.return_value = True
  64. self.spans = Spans(labels, self.types)
  65. def disable_overlapping(self):
  66. self.project.item.allow_overlapping = False
  67. self.project.item.save()
  68. def test_clean(self):
  69. self.disable_overlapping()
  70. self.spans.clean(self.project.item)
  71. self.assertEqual(len(self.spans), 2)
  72. def test_clean_with_overlapping(self):
  73. self.spans.clean(self.project.item)
  74. self.assertEqual(len(self.spans), 3)
  75. def test_clean_with_multiple_examples(self):
  76. self.disable_overlapping()
  77. example_uuid1 = uuid.uuid4()
  78. example_uuid2 = uuid.uuid4()
  79. labels = [
  80. SpanLabel(example_uuid=example_uuid1, label="A", start_offset=0, end_offset=1),
  81. SpanLabel(example_uuid=example_uuid2, label="B", start_offset=0, end_offset=3),
  82. ]
  83. mommy.make("Example", project=self.project.item, uuid=example_uuid1)
  84. mommy.make("Example", project=self.project.item, uuid=example_uuid2)
  85. spans = Spans(labels, self.types)
  86. spans.clean(self.project.item)
  87. self.assertEqual(len(spans), 2)
  88. def test_save(self):
  89. self.spans.save_types(self.project.item)
  90. self.spans.save(self.user, self.examples)
  91. self.assertEqual(Span.objects.count(), 3)
  92. def test_save_types(self):
  93. self.spans.save_types(self.project.item)
  94. self.assertEqual(SpanType.objects.count(), 2)
  95. class TestTexts(TestCase):
  96. def setUp(self):
  97. self.types = LabelTypes(DummyLabelType)
  98. self.project = prepare_project(SEQUENCE_LABELING)
  99. self.user = self.project.admin
  100. example_uuid = uuid.uuid4()
  101. labels = [
  102. TextLabel(example_uuid=example_uuid, text="A"),
  103. TextLabel(example_uuid=example_uuid, text="B"),
  104. ]
  105. example = mommy.make("Example", project=self.project.item, uuid=example_uuid)
  106. self.examples = MagicMock()
  107. self.examples.__getitem__.return_value = example
  108. self.examples.__contains__.return_value = True
  109. self.texts = Texts(labels, self.types)
  110. def test_clean(self):
  111. self.texts.clean(self.project.item)
  112. self.assertEqual(len(self.texts), 2)
  113. def test_save(self):
  114. self.texts.save_types(self.project.item)
  115. self.texts.save(self.user, self.examples)
  116. self.assertEqual(TextLabelModel.objects.count(), 2)
  117. def test_save_types(self):
  118. # nothing happen
  119. self.texts.save_types(self.project.item)
  120. class TestRelations(TestCase):
  121. def setUp(self):
  122. self.types = LabelTypes(RelationType)
  123. self.project = prepare_project(SEQUENCE_LABELING, use_relation=True)
  124. self.user = self.project.admin
  125. example_uuid = uuid.uuid4()
  126. example = mommy.make("Example", project=self.project.item, uuid=example_uuid)
  127. from_span = mommy.make("Span", example=example, start_offset=0, end_offset=1)
  128. to_span = mommy.make("Span", example=example, start_offset=2, end_offset=3)
  129. labels = [
  130. RelationLabel(example_uuid=example_uuid, type="A", from_id=from_span.id, to_id=to_span.id),
  131. ]
  132. self.relations = Relations(labels, self.types)
  133. self.spans = MagicMock()
  134. self.spans.id_to_span = {from_span.id: from_span, to_span.id: to_span}
  135. self.examples = MagicMock()
  136. self.examples.__getitem__.return_value = example
  137. self.examples.__contains__.return_value = True
  138. def test_clean(self):
  139. self.relations.clean(self.project.item)
  140. self.assertEqual(len(self.relations), 1)
  141. def test_save(self):
  142. self.relations.save_types(self.project.item)
  143. self.relations.save(self.user, self.examples, spans=self.spans)
  144. self.assertEqual(Relation.objects.count(), 1)
  145. def test_save_types(self):
  146. self.relations.save_types(self.project.item)
  147. self.assertEqual(RelationType.objects.count(), 1)