diff --git a/backend/data_import/pipeline/label.py b/backend/data_import/pipeline/label.py new file mode 100644 index 00000000..5e8113de --- /dev/null +++ b/backend/data_import/pipeline/label.py @@ -0,0 +1,155 @@ +import abc +import uuid +from typing import Any, Optional + +import pydantic.error_wrappers +from pydantic import UUID4, BaseModel, validator + +from .label_types import LabelTypes +from examples.models import Example +from label_types.models import CategoryType, LabelType, RelationType, SpanType +from labels.models import Category as CategoryModel +from labels.models import Label as LabelModel +from labels.models import Relation as RelationModel +from labels.models import Span as SpanModel +from labels.models import TextLabel as TextLabelModel +from projects.models import Project + + +class Label(BaseModel, abc.ABC): + id: int = -1 + uuid: UUID4 + example_uuid: UUID4 + + def __init__(self, **data): + data["uuid"] = uuid.uuid4() + super().__init__(**data) + + @abc.abstractmethod + def __lt__(self, other): + raise NotImplementedError() + + @classmethod + def parse(cls, example_uuid: UUID4, obj: Any): + raise NotImplementedError() + + @abc.abstractmethod + def create_type(self, project: Project) -> Optional[LabelType]: + raise NotImplementedError() + + @abc.abstractmethod + def create(self, user, example: Example, types: LabelTypes, **kwargs) -> LabelModel: + raise NotImplementedError + + def __hash__(self): + return hash(tuple(self.dict())) + + +class CategoryLabel(Label): + label: str + + def __lt__(self, other): + return self.label < other.label + + @validator("label") + def label_is_not_empty(cls, value: str): + if value: + return value + else: + raise ValueError("is not empty.") + + @classmethod + def parse(cls, example_uuid: UUID4, obj: Any): + try: + return cls(example_uuid=example_uuid, label=obj) + except pydantic.error_wrappers.ValidationError: + return None + + def create_type(self, project: Project) -> Optional[LabelType]: + return CategoryType(text=self.label, project=project) + + def create(self, user, example: Example, types: LabelTypes, **kwargs): + return CategoryModel(uuid=self.uuid, user=user, example=example, label=types.get_by_text(self.label)) + + +class SpanLabel(Label): + label: str + start_offset: int + end_offset: int + + def __lt__(self, other): + return self.start_offset < other.start_offset + + @classmethod + def parse(cls, example_uuid: UUID4, obj: Any): + try: + if isinstance(obj, list) or isinstance(obj, tuple): + columns = ["start_offset", "end_offset", "label"] + obj = zip(columns, obj) + return cls(example_uuid=example_uuid, **dict(obj)) + elif isinstance(obj, dict): + return cls(example_uuid=example_uuid, **obj) + except pydantic.error_wrappers.ValidationError: + return None + + def create_type(self, project: Project) -> Optional[LabelType]: + return SpanType(text=self.label, project=project) + + def create(self, user, example: Example, types: LabelTypes, **kwargs): + return SpanModel( + uuid=self.uuid, + user=user, + example=example, + start_offset=self.start_offset, + end_offset=self.end_offset, + label=types.get_by_text(self.label), + ) + + +class TextLabel(Label): + text: str + + def __lt__(self, other): + return self.text < other.text + + @classmethod + def parse(cls, example_uuid: UUID4, obj: Any): + try: + return cls(example_uuid=example_uuid, text=obj) + except pydantic.error_wrappers.ValidationError: + return None + + def create_type(self, project: Project) -> Optional[LabelType]: + return None + + def create(self, user, example: Example, types: LabelTypes, **kwargs): + return TextLabelModel(uuid=self.uuid, user=user, example=example, text=self.text) + + +class RelationLabel(Label): + from_id: int + to_id: int + type: str + + def __lt__(self, other): + return self.from_id < other.from_id + + @classmethod + def parse(cls, example_uuid: UUID4, obj: Any): + try: + return cls(example_uuid=example_uuid, **obj) + except pydantic.error_wrappers.ValidationError: + return None + + def create_type(self, project: Project) -> Optional[LabelType]: + return RelationType(text=self.type, project=project) + + def create(self, user, example: Example, types: LabelTypes, **kwargs): + return RelationModel( + uuid=self.uuid, + user=user, + example=example, + type=types.get_by_text(self.type), + from_id=kwargs["id_to_span"][self.from_id], + to_id=kwargs["id_to_span"][self.to_id], + ) diff --git a/backend/data_import/tests/test_label.py b/backend/data_import/tests/test_label.py new file mode 100644 index 00000000..a68efd4e --- /dev/null +++ b/backend/data_import/tests/test_label.py @@ -0,0 +1,169 @@ +import uuid +from unittest.mock import MagicMock + +from django.test import TestCase +from model_mommy import mommy + +from data_import.pipeline.label import ( + CategoryLabel, + RelationLabel, + SpanLabel, + TextLabel, +) +from label_types.models import CategoryType, RelationType, SpanType +from labels.models import Category as CategoryModel +from labels.models import Relation as RelationModel +from labels.models import Span as SpanModel +from labels.models import TextLabel as TextModel +from projects.models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING +from projects.tests.utils import prepare_project + + +class TestLabel(TestCase): + task = "Any" + + def setUp(self): + self.project = prepare_project(self.task) + self.user = self.project.admin + self.example = mommy.make("Example", project=self.project.item) + + +class TestCategoryLabel(TestLabel): + task = DOCUMENT_CLASSIFICATION + + def test_comparison(self): + category1 = CategoryLabel(label="A", example_uuid=uuid.uuid4()) + category2 = CategoryLabel(label="B", example_uuid=uuid.uuid4()) + self.assertLess(category1, category2) + + def test_empty_label_raises_value_error(self): + with self.assertRaises(ValueError): + CategoryLabel(label="", example_uuid=uuid.uuid4()) + + def test_parse(self): + example_uuid = uuid.uuid4() + category = CategoryLabel.parse(example_uuid, obj="A") + self.assertEqual(category.label, "A") + self.assertEqual(category.example_uuid, example_uuid) + + def test_create_type(self): + category = CategoryLabel(label="A", example_uuid=uuid.uuid4()) + category_type = category.create_type(self.project.item) + self.assertIsInstance(category_type, CategoryType) + self.assertEqual(category_type.text, "A") + + def test_create(self): + category = CategoryLabel(label="A", example_uuid=uuid.uuid4()) + types = MagicMock() + types.get_by_text.return_value = mommy.make(CategoryType, project=self.project.item) + category_model = category.create(self.user, self.example, types) + self.assertIsInstance(category_model, CategoryModel) + + +class TestSpanLabel(TestLabel): + task = SEQUENCE_LABELING + + def test_comparison(self): + span1 = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4()) + span2 = SpanLabel(label="A", start_offset=1, end_offset=1, example_uuid=uuid.uuid4()) + self.assertLess(span1, span2) + + def test_parse_tuple(self): + example_uuid = uuid.uuid4() + span = SpanLabel.parse(example_uuid, obj=(0, 1, "A")) + self.assertEqual(span.label, "A") + self.assertEqual(span.start_offset, 0) + self.assertEqual(span.end_offset, 1) + + def test_parse_dict(self): + example_uuid = uuid.uuid4() + span = SpanLabel.parse(example_uuid, obj={"label": "A", "start_offset": 0, "end_offset": 1}) + self.assertEqual(span.label, "A") + self.assertEqual(span.start_offset, 0) + self.assertEqual(span.end_offset, 1) + + def test_parse_invalid_dict(self): + example_uuid = uuid.uuid4() + span = SpanLabel.parse(example_uuid, obj={"label": "A", "start_offset": 0}) + self.assertEqual(span, None) + + def test_create_type(self): + span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4()) + span_type = span.create_type(self.project.item) + self.assertIsInstance(span_type, SpanType) + self.assertEqual(span_type.text, "A") + + def test_create(self): + span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4()) + types = MagicMock() + types.get_by_text.return_value = mommy.make(SpanType, project=self.project.item) + span_model = span.create(self.user, self.example, types) + self.assertIsInstance(span_model, SpanModel) + + +class TestTextLabel(TestLabel): + task = SEQ2SEQ + + def test_comparison(self): + text1 = TextLabel(text="A", example_uuid=uuid.uuid4()) + text2 = TextLabel(text="B", example_uuid=uuid.uuid4()) + self.assertLess(text1, text2) + + def test_parse(self): + example_uuid = uuid.uuid4() + text = TextLabel.parse(example_uuid, obj="A") + self.assertEqual(text.text, "A") + + def test_parse_invalid_data(self): + example_uuid = uuid.uuid4() + text = TextLabel.parse(example_uuid, obj=[]) + self.assertEqual(text, None) + + def test_create_type(self): + text = TextLabel(text="A", example_uuid=uuid.uuid4()) + text_type = text.create_type(self.project.item) + self.assertEqual(text_type, None) + + def test_create(self): + text = TextLabel(text="A", example_uuid=uuid.uuid4()) + types = MagicMock() + text_model = text.create(self.user, self.example, types) + self.assertIsInstance(text_model, TextModel) + + +class TestRelationLabel(TestLabel): + task = SEQUENCE_LABELING + + def test_comparison(self): + relation1 = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4()) + relation2 = RelationLabel(type="A", from_id=1, to_id=1, example_uuid=uuid.uuid4()) + self.assertLess(relation1, relation2) + + def test_parse(self): + example_uuid = uuid.uuid4() + relation = RelationLabel.parse(example_uuid, obj={"type": "A", "from_id": 0, "to_id": 1}) + self.assertEqual(relation.type, "A") + self.assertEqual(relation.from_id, 0) + self.assertEqual(relation.to_id, 1) + + def test_parse_invalid_data(self): + example_uuid = uuid.uuid4() + relation = RelationLabel.parse(example_uuid, obj={"type": "A", "from_id": 0}) + self.assertEqual(relation, None) + + def test_create_type(self): + relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4()) + relation_type = relation.create_type(self.project.item) + self.assertIsInstance(relation_type, RelationType) + self.assertEqual(relation_type.text, "A") + + def test_create(self): + relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4()) + types = MagicMock() + types.get_by_text.return_value = mommy.make(RelationType, project=self.project.item) + id_to_span = { + 0: mommy.make(SpanModel, start_offset=0, end_offset=1), + 1: mommy.make(SpanModel, start_offset=2, end_offset=3), + } + relation_model = relation.create(self.user, self.example, types, id_to_span=id_to_span) + self.assertIsInstance(relation_model, RelationModel)