Browse Source

Add labels test cases for data_import

pull/1823/head
Hironsan 2 years ago
parent
commit
993e4f655b
2 changed files with 162 additions and 24 deletions
  1. 42
      backend/data_import/pipeline/labels.py
  2. 144
      backend/data_import/tests/test_labels.py

42
backend/data_import/pipeline/labels.py

@ -8,6 +8,7 @@ from .label import Label
from .label_types import LabelTypes
from examples.models import Example
from labels.models import Category as CategoryModel
from labels.models import Label as LabelModel
from labels.models import Relation as RelationModel
from labels.models import Span as SpanModel
from labels.models import TextLabel as TextLabelModel
@ -15,10 +16,15 @@ from projects.models import Project
class Labels(abc.ABC):
label_model = LabelModel
def __init__(self, labels: List[Label], types: LabelTypes):
self.labels = labels
self.types = types
def __len__(self) -> int:
return len(self.labels)
def clean(self, project: Project):
pass
@ -34,27 +40,26 @@ class Labels(abc.ABC):
examples = Example.objects.filter(uuid__in=example_uuids)
return {example.uuid: example for example in examples}
@abc.abstractmethod
def save(self, user, **kwargs):
raise NotImplementedError()
labels = [
label.create(user, self.uuid_to_example[label.example_uuid], self.types, **kwargs) for label in self.labels
]
self.label_model.objects.bulk_create(labels)
class Categories(Labels):
label_model = CategoryModel
def clean(self, project: Project):
exclusive = getattr(project, "single_class_classification", False)
if exclusive:
groups = groupby(self.labels, lambda label: label.example_uuid)
self.labels = [next(group) for _, group in groups]
def save(self, user, **kwargs):
uuid_to_example = self.uuid_to_example
categories = [
category.create(user, uuid_to_example[category.example_uuid], self.types) for category in self.labels
]
CategoryModel.objects.bulk_create(categories)
class Spans(Labels):
label_model = SpanModel
def clean(self, project: Project):
allow_overlapping = getattr(project, "allow_overlapping", False)
if allow_overlapping:
@ -68,11 +73,6 @@ class Spans(Labels):
spans.append(label)
self.labels = spans
def save(self, user, **kwargs):
uuid_to_example = self.uuid_to_example
spans = [span.create(user, uuid_to_example[span.example_uuid], self.types) for span in self.labels]
SpanModel.objects.bulk_create(spans)
@property
def id_to_span(self) -> Dict[int, SpanModel]:
span_uuids = [str(label.uuid) for label in self.labels]
@ -82,18 +82,12 @@ class Spans(Labels):
class Texts(Labels):
def save(self, user, **kwargs):
uuid_to_example = self.uuid_to_example
texts = [text.create(user, uuid_to_example[text.example_uuid], self.types) for text in self.labels]
TextLabelModel.objects.bulk_create(texts)
label_model = TextLabelModel
class Relations(Labels):
label_model = RelationModel
def save(self, user, **kwargs):
id_to_span = kwargs["spans"].id_to_span
uuid_to_example = self.uuid_to_example
relations = [
relation.create(user, uuid_to_example[relation.example_uuid], self.types, id_to_span=id_to_span)
for relation in self.labels
]
RelationModel.objects.bulk_create(relations)
super().save(user, id_to_span=id_to_span)

144
backend/data_import/tests/test_labels.py

@ -0,0 +1,144 @@
import uuid
from unittest.mock import MagicMock
from django.test import TestCase
from model_mommy import mommy
from data_import.models import DummyLabelType
from data_import.pipeline.label import (
CategoryLabel,
RelationLabel,
SpanLabel,
TextLabel,
)
from data_import.pipeline.label_types import LabelTypes
from data_import.pipeline.labels import Categories, Relations, Spans, Texts
from label_types.models import CategoryType, RelationType, SpanType
from labels.models import Category, Relation, Span
from labels.models import TextLabel as TextLabelModel
from projects.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING
from projects.tests.utils import prepare_project
class TestCategories(TestCase):
def setUp(self):
self.types = LabelTypes(CategoryType)
self.project = prepare_project(DOCUMENT_CLASSIFICATION)
self.user = self.project.admin
example_uuid = uuid.uuid4()
labels = [
CategoryLabel(example_uuid=example_uuid, label="A"),
CategoryLabel(example_uuid=example_uuid, label="B"),
]
mommy.make("Example", project=self.project.item, uuid=example_uuid)
self.categories = Categories(labels, self.types)
def test_clean(self):
self.categories.clean(self.project.item)
self.assertEqual(len(self.categories), 2)
def test_clean_with_exclusive_labels(self):
self.project.item.single_class_classification = True
self.project.item.save()
self.categories.clean(self.project.item)
self.assertEqual(len(self.categories), 1)
def test_save(self):
self.categories.save_types(self.project.item)
self.categories.save(self.user)
self.assertEqual(Category.objects.count(), 2)
def test_save_types(self):
self.categories.save_types(self.project.item)
self.assertEqual(CategoryType.objects.count(), 2)
class TestSpans(TestCase):
def setUp(self):
self.types = LabelTypes(SpanType)
self.project = prepare_project(SEQUENCE_LABELING, allow_overlapping=True)
self.user = self.project.admin
example_uuid = uuid.uuid4()
labels = [
SpanLabel(example_uuid=example_uuid, label="A", start_offset=0, end_offset=1),
SpanLabel(example_uuid=example_uuid, label="B", start_offset=0, end_offset=3),
SpanLabel(example_uuid=example_uuid, label="B", start_offset=3, end_offset=4),
]
mommy.make("Example", project=self.project.item, uuid=example_uuid)
self.spans = Spans(labels, self.types)
def test_clean(self):
self.project.item.allow_overlapping = False
self.project.item.save()
self.spans.clean(self.project.item)
self.assertEqual(len(self.spans), 2)
def test_clean_with_overlapping(self):
self.spans.clean(self.project.item)
self.assertEqual(len(self.spans), 3)
def test_save(self):
self.spans.save_types(self.project.item)
self.spans.save(self.user)
self.assertEqual(Span.objects.count(), 3)
def test_save_types(self):
self.spans.save_types(self.project.item)
self.assertEqual(SpanType.objects.count(), 2)
class TestTexts(TestCase):
def setUp(self):
self.types = LabelTypes(DummyLabelType)
self.project = prepare_project(SEQUENCE_LABELING)
self.user = self.project.admin
example_uuid = uuid.uuid4()
labels = [
TextLabel(example_uuid=example_uuid, text="A"),
TextLabel(example_uuid=example_uuid, text="B"),
]
mommy.make("Example", project=self.project.item, uuid=example_uuid)
self.texts = Texts(labels, self.types)
def test_clean(self):
self.texts.clean(self.project.item)
self.assertEqual(len(self.texts), 2)
def test_save(self):
self.texts.save_types(self.project.item)
self.texts.save(self.user)
self.assertEqual(TextLabelModel.objects.count(), 2)
def test_save_types(self):
# nothing happen
self.texts.save_types(self.project.item)
class TestRelations(TestCase):
def setUp(self):
self.types = LabelTypes(RelationType)
self.project = prepare_project(SEQUENCE_LABELING, use_relation=True)
self.user = self.project.admin
example_uuid = uuid.uuid4()
example = mommy.make("Example", project=self.project.item, uuid=example_uuid)
from_span = mommy.make("Span", example=example, start_offset=0, end_offset=1)
to_span = mommy.make("Span", example=example, start_offset=2, end_offset=3)
labels = [
RelationLabel(example_uuid=example_uuid, type="A", from_id=from_span.id, to_id=to_span.id),
]
self.relations = Relations(labels, self.types)
self.spans = MagicMock()
self.spans.id_to_span = {from_span.id: from_span, to_span.id: to_span}
def test_clean(self):
self.relations.clean(self.project.item)
self.assertEqual(len(self.relations), 1)
def test_save(self):
self.relations.save_types(self.project.item)
self.relations.save(self.user, spans=self.spans)
self.assertEqual(Relation.objects.count(), 1)
def test_save_types(self):
self.relations.save_types(self.project.item)
self.assertEqual(RelationType.objects.count(), 1)
Loading…
Cancel
Save