From 993e4f655b06eaf0f8690101ce736684f2f4fde9 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Wed, 18 May 2022 08:03:38 +0900 Subject: [PATCH] Add labels test cases for data_import --- backend/data_import/pipeline/labels.py | 42 +++---- backend/data_import/tests/test_labels.py | 144 +++++++++++++++++++++++ 2 files changed, 162 insertions(+), 24 deletions(-) create mode 100644 backend/data_import/tests/test_labels.py diff --git a/backend/data_import/pipeline/labels.py b/backend/data_import/pipeline/labels.py index 166ce40c..a2d3a5b2 100644 --- a/backend/data_import/pipeline/labels.py +++ b/backend/data_import/pipeline/labels.py @@ -8,6 +8,7 @@ from .label import Label from .label_types import LabelTypes from examples.models import Example from labels.models import Category as CategoryModel +from labels.models import Label as LabelModel from labels.models import Relation as RelationModel from labels.models import Span as SpanModel from labels.models import TextLabel as TextLabelModel @@ -15,10 +16,15 @@ from projects.models import Project class Labels(abc.ABC): + label_model = LabelModel + def __init__(self, labels: List[Label], types: LabelTypes): self.labels = labels self.types = types + def __len__(self) -> int: + return len(self.labels) + def clean(self, project: Project): pass @@ -34,27 +40,26 @@ class Labels(abc.ABC): examples = Example.objects.filter(uuid__in=example_uuids) return {example.uuid: example for example in examples} - @abc.abstractmethod def save(self, user, **kwargs): - raise NotImplementedError() + labels = [ + label.create(user, self.uuid_to_example[label.example_uuid], self.types, **kwargs) for label in self.labels + ] + self.label_model.objects.bulk_create(labels) class Categories(Labels): + label_model = CategoryModel + def clean(self, project: Project): exclusive = getattr(project, "single_class_classification", False) if exclusive: groups = groupby(self.labels, lambda label: label.example_uuid) self.labels = [next(group) for _, group in groups] - def save(self, user, **kwargs): - uuid_to_example = self.uuid_to_example - categories = [ - category.create(user, uuid_to_example[category.example_uuid], self.types) for category in self.labels - ] - CategoryModel.objects.bulk_create(categories) - class Spans(Labels): + label_model = SpanModel + def clean(self, project: Project): allow_overlapping = getattr(project, "allow_overlapping", False) if allow_overlapping: @@ -68,11 +73,6 @@ class Spans(Labels): spans.append(label) self.labels = spans - def save(self, user, **kwargs): - uuid_to_example = self.uuid_to_example - spans = [span.create(user, uuid_to_example[span.example_uuid], self.types) for span in self.labels] - SpanModel.objects.bulk_create(spans) - @property def id_to_span(self) -> Dict[int, SpanModel]: span_uuids = [str(label.uuid) for label in self.labels] @@ -82,18 +82,12 @@ class Spans(Labels): class Texts(Labels): - def save(self, user, **kwargs): - uuid_to_example = self.uuid_to_example - texts = [text.create(user, uuid_to_example[text.example_uuid], self.types) for text in self.labels] - TextLabelModel.objects.bulk_create(texts) + label_model = TextLabelModel class Relations(Labels): + label_model = RelationModel + def save(self, user, **kwargs): id_to_span = kwargs["spans"].id_to_span - uuid_to_example = self.uuid_to_example - relations = [ - relation.create(user, uuid_to_example[relation.example_uuid], self.types, id_to_span=id_to_span) - for relation in self.labels - ] - RelationModel.objects.bulk_create(relations) + super().save(user, id_to_span=id_to_span) diff --git a/backend/data_import/tests/test_labels.py b/backend/data_import/tests/test_labels.py new file mode 100644 index 00000000..4c71c71b --- /dev/null +++ b/backend/data_import/tests/test_labels.py @@ -0,0 +1,144 @@ +import uuid +from unittest.mock import MagicMock + +from django.test import TestCase +from model_mommy import mommy + +from data_import.models import DummyLabelType +from data_import.pipeline.label import ( + CategoryLabel, + RelationLabel, + SpanLabel, + TextLabel, +) +from data_import.pipeline.label_types import LabelTypes +from data_import.pipeline.labels import Categories, Relations, Spans, Texts +from label_types.models import CategoryType, RelationType, SpanType +from labels.models import Category, Relation, Span +from labels.models import TextLabel as TextLabelModel +from projects.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING +from projects.tests.utils import prepare_project + + +class TestCategories(TestCase): + def setUp(self): + self.types = LabelTypes(CategoryType) + self.project = prepare_project(DOCUMENT_CLASSIFICATION) + self.user = self.project.admin + example_uuid = uuid.uuid4() + labels = [ + CategoryLabel(example_uuid=example_uuid, label="A"), + CategoryLabel(example_uuid=example_uuid, label="B"), + ] + mommy.make("Example", project=self.project.item, uuid=example_uuid) + self.categories = Categories(labels, self.types) + + def test_clean(self): + self.categories.clean(self.project.item) + self.assertEqual(len(self.categories), 2) + + def test_clean_with_exclusive_labels(self): + self.project.item.single_class_classification = True + self.project.item.save() + self.categories.clean(self.project.item) + self.assertEqual(len(self.categories), 1) + + def test_save(self): + self.categories.save_types(self.project.item) + self.categories.save(self.user) + self.assertEqual(Category.objects.count(), 2) + + def test_save_types(self): + self.categories.save_types(self.project.item) + self.assertEqual(CategoryType.objects.count(), 2) + + +class TestSpans(TestCase): + def setUp(self): + self.types = LabelTypes(SpanType) + self.project = prepare_project(SEQUENCE_LABELING, allow_overlapping=True) + self.user = self.project.admin + example_uuid = uuid.uuid4() + labels = [ + SpanLabel(example_uuid=example_uuid, label="A", start_offset=0, end_offset=1), + SpanLabel(example_uuid=example_uuid, label="B", start_offset=0, end_offset=3), + SpanLabel(example_uuid=example_uuid, label="B", start_offset=3, end_offset=4), + ] + mommy.make("Example", project=self.project.item, uuid=example_uuid) + self.spans = Spans(labels, self.types) + + def test_clean(self): + self.project.item.allow_overlapping = False + self.project.item.save() + self.spans.clean(self.project.item) + self.assertEqual(len(self.spans), 2) + + def test_clean_with_overlapping(self): + self.spans.clean(self.project.item) + self.assertEqual(len(self.spans), 3) + + def test_save(self): + self.spans.save_types(self.project.item) + self.spans.save(self.user) + self.assertEqual(Span.objects.count(), 3) + + def test_save_types(self): + self.spans.save_types(self.project.item) + self.assertEqual(SpanType.objects.count(), 2) + + +class TestTexts(TestCase): + def setUp(self): + self.types = LabelTypes(DummyLabelType) + self.project = prepare_project(SEQUENCE_LABELING) + self.user = self.project.admin + example_uuid = uuid.uuid4() + labels = [ + TextLabel(example_uuid=example_uuid, text="A"), + TextLabel(example_uuid=example_uuid, text="B"), + ] + mommy.make("Example", project=self.project.item, uuid=example_uuid) + self.texts = Texts(labels, self.types) + + def test_clean(self): + self.texts.clean(self.project.item) + self.assertEqual(len(self.texts), 2) + + def test_save(self): + self.texts.save_types(self.project.item) + self.texts.save(self.user) + self.assertEqual(TextLabelModel.objects.count(), 2) + + def test_save_types(self): + # nothing happen + self.texts.save_types(self.project.item) + + +class TestRelations(TestCase): + def setUp(self): + self.types = LabelTypes(RelationType) + self.project = prepare_project(SEQUENCE_LABELING, use_relation=True) + self.user = self.project.admin + example_uuid = uuid.uuid4() + example = mommy.make("Example", project=self.project.item, uuid=example_uuid) + from_span = mommy.make("Span", example=example, start_offset=0, end_offset=1) + to_span = mommy.make("Span", example=example, start_offset=2, end_offset=3) + labels = [ + RelationLabel(example_uuid=example_uuid, type="A", from_id=from_span.id, to_id=to_span.id), + ] + self.relations = Relations(labels, self.types) + self.spans = MagicMock() + self.spans.id_to_span = {from_span.id: from_span, to_span.id: to_span} + + def test_clean(self): + self.relations.clean(self.project.item) + self.assertEqual(len(self.relations), 1) + + def test_save(self): + self.relations.save_types(self.project.item) + self.relations.save(self.user, spans=self.spans) + self.assertEqual(Relation.objects.count(), 1) + + def test_save_types(self): + self.relations.save_types(self.project.item) + self.assertEqual(RelationType.objects.count(), 1)