mirror of https://github.com/doccano/doccano.git
2 changed files with 324 additions and 0 deletions
Split View
Diff Options
@@ -0,0 +1,155 @@
|||
import abc |
|||
import uuid |
|||
from typing import Any, Optional |
|||
|
|||
import pydantic.error_wrappers |
|||
from pydantic import UUID4, BaseModel, validator |
|||
|
|||
from .label_types import LabelTypes |
|||
from examples.models import Example |
|||
from label_types.models import CategoryType, LabelType, RelationType, SpanType |
|||
from labels.models import Category as CategoryModel |
|||
from labels.models import Label as LabelModel |
|||
from labels.models import Relation as RelationModel |
|||
from labels.models import Span as SpanModel |
|||
from labels.models import TextLabel as TextLabelModel |
|||
from projects.models import Project |
|||
|
|||
|
|||
class Label(BaseModel, abc.ABC):
    """Abstract base for a label read from imported data.

    Subclasses declare their own fields and implement ordering, parsing and
    the conversion into label-type and label model objects.
    """

    # Defaults to -1; nothing in this class assigns another value.
    id: int = -1
    # Always generated in ``__init__``; a caller-supplied value is replaced.
    uuid: UUID4
    # UUID supplied by the caller to tie the label to an example.
    example_uuid: UUID4

    def __init__(self, **data):
        """Validate ``data`` after assigning the label a fresh UUID4."""
        data["uuid"] = uuid.uuid4()
        super().__init__(**data)

    @abc.abstractmethod
    def __lt__(self, other):
        """Return whether this label sorts before ``other``."""
        raise NotImplementedError()

    @classmethod
    def parse(cls, example_uuid: UUID4, obj: Any):
        """Build a label from raw ``obj``; subclasses must override this."""
        raise NotImplementedError()

    @abc.abstractmethod
    def create_type(self, project: Project) -> Optional[LabelType]:
        """Return the label type for ``project``, or ``None`` if there is none."""
        raise NotImplementedError()

    @abc.abstractmethod
    def create(self, user, example: Example, types: LabelTypes, **kwargs) -> LabelModel:
        """Return the label model instance for ``user`` and ``example``."""
        raise NotImplementedError()

    def __hash__(self):
        """Hash the field names together with their values.

        ``tuple(self.dict())`` would contain only the field names, giving
        every instance of a class the same hash and degrading sets and dicts
        to linear scans. Hashing the ``(name, value)`` pairs keeps the hash
        consistent with field-wise equality while spreading instances out.

        Raises:
            TypeError: if a field holds an unhashable value.
        """
        return hash(tuple(self.dict().items()))
|||
|
|||
|
|||
class CategoryLabel(Label):
    """Label made of a single non-empty category name."""

    label: str

    def __lt__(self, other):
        """Order category labels by their ``label`` text."""
        return self.label < other.label

    @validator("label")
    def label_is_not_empty(cls, value: str):
        """Reject an empty label.

        Raises:
            ValueError: if ``value`` is falsy. The message now states the
                requirement; the previous text ("is not empty.") described
                the opposite of the actual problem.
        """
        if not value:
            raise ValueError("must not be empty.")
        return value

    @classmethod
    def parse(cls, example_uuid: UUID4, obj: Any):
        """Build a category label from ``obj``, or return ``None`` if validation fails."""
        try:
            return cls(example_uuid=example_uuid, label=obj)
        except pydantic.error_wrappers.ValidationError:
            return None

    def create_type(self, project: Project) -> Optional[LabelType]:
        """Return an unsaved ``CategoryType`` whose text is this label."""
        return CategoryType(text=self.label, project=project)

    def create(self, user, example: Example, types: LabelTypes, **kwargs):
        """Return a ``CategoryModel`` whose label is looked up in ``types`` by text."""
        return CategoryModel(
            uuid=self.uuid,
            user=user,
            example=example,
            label=types.get_by_text(self.label),
        )
|||
|
|||
|
|||
class SpanLabel(Label):
    """Label covering a character range, identified by start and end offsets."""

    label: str
    start_offset: int
    end_offset: int

    def __lt__(self, other):
        """Order spans by their start offset."""
        return self.start_offset < other.start_offset

    @classmethod
    def parse(cls, example_uuid: UUID4, obj: Any):
        """Build a span from a sequence or a dict.

        A list or tuple is read positionally as
        ``(start_offset, end_offset, label)``; a dict is passed through as
        keyword arguments. Returns ``None`` when validation fails or when
        ``obj`` is of any other type.
        """
        try:
            if isinstance(obj, (list, tuple)):
                fields = dict(zip(["start_offset", "end_offset", "label"], obj))
                return cls(example_uuid=example_uuid, **fields)
            if isinstance(obj, dict):
                return cls(example_uuid=example_uuid, **obj)
        except pydantic.error_wrappers.ValidationError:
            return None
        # Unsupported input type.
        return None

    def create_type(self, project: Project) -> Optional[LabelType]:
        """Return an unsaved ``SpanType`` whose text is this label."""
        return SpanType(text=self.label, project=project)

    def create(self, user, example: Example, types: LabelTypes, **kwargs):
        """Return a ``SpanModel`` whose label is looked up in ``types`` by text."""
        return SpanModel(
            uuid=self.uuid,
            user=user,
            example=example,
            start_offset=self.start_offset,
            end_offset=self.end_offset,
            label=types.get_by_text(self.label),
        )
|||
|
|||
|
|||
class TextLabel(Label):
    """Label consisting of free text."""

    text: str

    def __lt__(self, other):
        """Order text labels by their ``text``."""
        return self.text < other.text

    @classmethod
    def parse(cls, example_uuid: UUID4, obj: Any):
        """Build a text label from ``obj``, or return ``None`` if validation fails."""
        try:
            parsed = cls(example_uuid=example_uuid, text=obj)
        except pydantic.error_wrappers.ValidationError:
            return None
        return parsed

    def create_type(self, project: Project) -> Optional[LabelType]:
        """Return ``None``: no label type is created for a text label."""
        return None

    def create(self, user, example: Example, types: LabelTypes, **kwargs):
        """Return a ``TextLabelModel`` carrying this text; ``types`` is not used."""
        return TextLabelModel(
            uuid=self.uuid,
            user=user,
            example=example,
            text=self.text,
        )
|||
|
|||
|
|||
class RelationLabel(Label):
    """Label linking two spans, referenced by id, with a relation type name."""

    from_id: int
    to_id: int
    type: str

    def __lt__(self, other):
        """Order relations by their ``from_id``."""
        return self.from_id < other.from_id

    @classmethod
    def parse(cls, example_uuid: UUID4, obj: Any):
        """Build a relation from a mapping of field values.

        Returns ``None`` when validation fails. It also returns ``None`` when
        ``obj`` cannot be expanded into keyword arguments (for example a list
        or a string), so malformed input is skipped like in the sibling
        parsers instead of escaping as an uncaught ``TypeError``.
        """
        try:
            return cls(example_uuid=example_uuid, **obj)
        except (pydantic.error_wrappers.ValidationError, TypeError):
            return None

    def create_type(self, project: Project) -> Optional[LabelType]:
        """Return an unsaved ``RelationType`` whose text is this relation's type."""
        return RelationType(text=self.type, project=project)

    def create(self, user, example: Example, types: LabelTypes, **kwargs):
        """Return a ``RelationModel`` connecting the two referenced spans.

        Requires ``kwargs["id_to_span"]``, a lookup from the ids used in the
        imported data to span objects.

        Raises:
            KeyError: if ``id_to_span`` is missing or lacks one of the ids.
        """
        id_to_span = kwargs["id_to_span"]
        return RelationModel(
            uuid=self.uuid,
            user=user,
            example=example,
            type=types.get_by_text(self.type),
            from_id=id_to_span[self.from_id],
            to_id=id_to_span[self.to_id],
        )
@@ -0,0 +1,169 @@
|||
import uuid |
|||
from unittest.mock import MagicMock |
|||
|
|||
from django.test import TestCase |
|||
from model_mommy import mommy |
|||
|
|||
from data_import.pipeline.label import ( |
|||
CategoryLabel, |
|||
RelationLabel, |
|||
SpanLabel, |
|||
TextLabel, |
|||
) |
|||
from label_types.models import CategoryType, RelationType, SpanType |
|||
from labels.models import Category as CategoryModel |
|||
from labels.models import Relation as RelationModel |
|||
from labels.models import Span as SpanModel |
|||
from labels.models import TextLabel as TextModel |
|||
from projects.models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING |
|||
from projects.tests.utils import prepare_project |
|||
|
|||
|
|||
class TestLabel(TestCase):
    """Shared fixture: a project for ``task``, its admin user and one example."""

    task = "Any"

    def setUp(self):
        project = prepare_project(self.task)
        self.project = project
        self.user = project.admin
        self.example = mommy.make("Example", project=project.item)
|||
|
|||
|
|||
class TestCategoryLabel(TestLabel):
    """Tests for ``CategoryLabel``."""

    task = DOCUMENT_CLASSIFICATION

    def test_comparison(self):
        smaller = CategoryLabel(label="A", example_uuid=uuid.uuid4())
        larger = CategoryLabel(label="B", example_uuid=uuid.uuid4())
        self.assertLess(smaller, larger)

    def test_empty_label_raises_value_error(self):
        with self.assertRaises(ValueError):
            CategoryLabel(label="", example_uuid=uuid.uuid4())

    def test_parse(self):
        expected_uuid = uuid.uuid4()
        parsed = CategoryLabel.parse(expected_uuid, obj="A")
        self.assertEqual(parsed.label, "A")
        self.assertEqual(parsed.example_uuid, expected_uuid)

    def test_create_type(self):
        label = CategoryLabel(label="A", example_uuid=uuid.uuid4())
        created_type = label.create_type(self.project.item)
        self.assertIsInstance(created_type, CategoryType)
        self.assertEqual(created_type.text, "A")

    def test_create(self):
        label = CategoryLabel(label="A", example_uuid=uuid.uuid4())
        # Stand-in for LabelTypes that always resolves to a stored CategoryType.
        label_types = MagicMock()
        label_types.get_by_text.return_value = mommy.make(CategoryType, project=self.project.item)
        created = label.create(self.user, self.example, label_types)
        self.assertIsInstance(created, CategoryModel)
|||
|
|||
|
|||
class TestSpanLabel(TestLabel):
    """Tests for ``SpanLabel``."""

    task = SEQUENCE_LABELING

    def test_comparison(self):
        span1 = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
        span2 = SpanLabel(label="A", start_offset=1, end_offset=1, example_uuid=uuid.uuid4())
        self.assertLess(span1, span2)

    def test_parse_tuple(self):
        example_uuid = uuid.uuid4()
        span = SpanLabel.parse(example_uuid, obj=(0, 1, "A"))
        self.assertEqual(span.label, "A")
        self.assertEqual(span.start_offset, 0)
        self.assertEqual(span.end_offset, 1)

    def test_parse_dict(self):
        example_uuid = uuid.uuid4()
        span = SpanLabel.parse(example_uuid, obj={"label": "A", "start_offset": 0, "end_offset": 1})
        self.assertEqual(span.label, "A")
        self.assertEqual(span.start_offset, 0)
        self.assertEqual(span.end_offset, 1)

    def test_parse_invalid_dict(self):
        example_uuid = uuid.uuid4()
        span = SpanLabel.parse(example_uuid, obj={"label": "A", "start_offset": 0})
        # Identity check: a failed parse must yield None itself, not merely
        # something that compares equal to None.
        self.assertIsNone(span)

    def test_create_type(self):
        span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
        span_type = span.create_type(self.project.item)
        self.assertIsInstance(span_type, SpanType)
        self.assertEqual(span_type.text, "A")

    def test_create(self):
        span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
        # Stand-in for LabelTypes that always resolves to a stored SpanType.
        types = MagicMock()
        types.get_by_text.return_value = mommy.make(SpanType, project=self.project.item)
        span_model = span.create(self.user, self.example, types)
        self.assertIsInstance(span_model, SpanModel)
|||
|
|||
|
|||
class TestTextLabel(TestLabel):
    """Tests for ``TextLabel``."""

    task = SEQ2SEQ

    def test_comparison(self):
        text1 = TextLabel(text="A", example_uuid=uuid.uuid4())
        text2 = TextLabel(text="B", example_uuid=uuid.uuid4())
        self.assertLess(text1, text2)

    def test_parse(self):
        example_uuid = uuid.uuid4()
        text = TextLabel.parse(example_uuid, obj="A")
        self.assertEqual(text.text, "A")

    def test_parse_invalid_data(self):
        example_uuid = uuid.uuid4()
        text = TextLabel.parse(example_uuid, obj=[])
        # Identity check: a failed parse must yield None itself.
        self.assertIsNone(text)

    def test_create_type(self):
        text = TextLabel(text="A", example_uuid=uuid.uuid4())
        text_type = text.create_type(self.project.item)
        self.assertIsNone(text_type)

    def test_create(self):
        text = TextLabel(text="A", example_uuid=uuid.uuid4())
        types = MagicMock()
        text_model = text.create(self.user, self.example, types)
        self.assertIsInstance(text_model, TextModel)
|||
|
|||
|
|||
class TestRelationLabel(TestLabel):
    """Tests for ``RelationLabel``."""

    task = SEQUENCE_LABELING

    def test_comparison(self):
        relation1 = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
        relation2 = RelationLabel(type="A", from_id=1, to_id=1, example_uuid=uuid.uuid4())
        self.assertLess(relation1, relation2)

    def test_parse(self):
        example_uuid = uuid.uuid4()
        relation = RelationLabel.parse(example_uuid, obj={"type": "A", "from_id": 0, "to_id": 1})
        self.assertEqual(relation.type, "A")
        self.assertEqual(relation.from_id, 0)
        self.assertEqual(relation.to_id, 1)

    def test_parse_invalid_data(self):
        example_uuid = uuid.uuid4()
        relation = RelationLabel.parse(example_uuid, obj={"type": "A", "from_id": 0})
        # Identity check: a failed parse must yield None itself.
        self.assertIsNone(relation)

    def test_create_type(self):
        relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
        relation_type = relation.create_type(self.project.item)
        self.assertIsInstance(relation_type, RelationType)
        self.assertEqual(relation_type.text, "A")

    def test_create(self):
        relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
        # Stand-in for LabelTypes that always resolves to a stored RelationType.
        types = MagicMock()
        types.get_by_text.return_value = mommy.make(RelationType, project=self.project.item)
        # Lookup from the ids used by the relation to stored spans.
        id_to_span = {
            0: mommy.make(SpanModel, start_offset=0, end_offset=1),
            1: mommy.make(SpanModel, start_offset=2, end_offset=3),
        }
        relation_model = relation.create(self.user, self.example, types, id_to_span=id_to_span)
        self.assertIsInstance(relation_model, RelationModel)
Write
Preview
Loading…
Cancel
Save