From a63a1e9751a6d064df1ab60339847fa260c6028a Mon Sep 17 00:00:00 2001 From: Hironsan Date: Tue, 17 May 2022 23:13:31 +0900 Subject: [PATCH] Add first class collection for label --- backend/data_import/pipeline/labels.py | 214 ++++++++++--------------- 1 file changed, 89 insertions(+), 125 deletions(-) diff --git a/backend/data_import/pipeline/labels.py b/backend/data_import/pipeline/labels.py index 02e64536..166ce40c 100644 --- a/backend/data_import/pipeline/labels.py +++ b/backend/data_import/pipeline/labels.py @@ -1,135 +1,99 @@ import abc -import uuid -from typing import Any, Dict, Optional - -import pydantic.error_wrappers -from pydantic import UUID4, BaseModel, validator - -from label_types.models import CategoryType, LabelType, RelationType, SpanType -from labels.models import Category -from labels.models import Label as LabelModel -from labels.models import Relation, Span -from labels.models import TextLabel as TL +from itertools import groupby +from typing import Dict, List + +from pydantic import UUID4 + +from .label import Label +from .label_types import LabelTypes +from examples.models import Example +from labels.models import Category as CategoryModel +from labels.models import Relation as RelationModel +from labels.models import Span as SpanModel +from labels.models import TextLabel as TextLabelModel from projects.models import Project -class Label(BaseModel, abc.ABC): - id: int = -1 - uuid: UUID4 +class Labels(abc.ABC): + def __init__(self, labels: List[Label], types: LabelTypes): + self.labels = labels + self.types = types - def __init__(self, **data): - data["uuid"] = uuid.uuid4() - super().__init__(**data) + def clean(self, project: Project): + pass - @classmethod - def parse(cls, obj: Any): - raise NotImplementedError() + def save_types(self, project: Project): + types = [label.create_type(project) for label in self.labels] + filtered_types = list(filter(None, types)) + self.types.save(filtered_types) + self.types.update(project) + + @property + def uuid_to_example(self) -> Dict[UUID4, Example]: + example_uuids = {str(label.example_uuid) for label in self.labels} + examples = Example.objects.filter(uuid__in=example_uuids) + return {example.uuid: example for example in examples} @abc.abstractmethod - def create_type(self, project: Project) -> Optional[LabelType]: + def save(self, user, **kwargs): raise NotImplementedError() - @abc.abstractmethod - def create(self, user, example, mapping, **kwargs) -> LabelModel: - raise NotImplementedError - - def __hash__(self): - return hash(tuple(self.dict())) - - -class CategoryLabel(Label): - label: str - - @validator("label") - def label_is_not_empty(cls, value: str): - if value: - return value - else: - raise ValueError("is not empty.") - - @classmethod - def parse(cls, obj: Any): - try: - return cls(label=obj) - except pydantic.error_wrappers.ValidationError: - return None - - def create_type(self, project: Project) -> Optional[LabelType]: - return CategoryType(text=self.label, project=project) - - def create(self, user, example, mapping: Dict[str, LabelType], **kwargs): - return Category(uuid=self.uuid, user=user, example=example, label=mapping[self.label]) - - -class SpanLabel(Label): - label: str - start_offset: int - end_offset: int - - @classmethod - def parse(cls, obj: Any): - try: - if isinstance(obj, list) or isinstance(obj, tuple): - columns = ["start_offset", "end_offset", "label"] - obj = zip(columns, obj) - return cls.parse_obj(obj) - elif isinstance(obj, dict): - return cls.parse_obj(obj) - except pydantic.error_wrappers.ValidationError: - return None - - def create_type(self, project: Project) -> Optional[LabelType]: - return SpanType(text=self.label, project=project) - - def create(self, user, example, mapping: Dict[str, LabelType], **kwargs): - return Span( - uuid=self.uuid, - user=user, - example=example, - start_offset=self.start_offset, - end_offset=self.end_offset, - label=mapping[self.label], - ) - - -class TextLabel(Label): - text: str - - @classmethod - def parse(cls, obj: Any): - try: - return cls(text=obj) - except pydantic.error_wrappers.ValidationError: - return None - - def create_type(self, project: Project) -> Optional[LabelType]: - return None - - def create(self, user, example, mapping, **kwargs): - return TL(uuid=self.uuid, user=user, example=example, text=self.text) - - -class RelationLabel(Label): - from_id: int - to_id: int - type: str - - @classmethod - def parse(cls, obj: Any): - try: - return cls.parse_obj(obj) - except pydantic.error_wrappers.ValidationError: - return None - - def create_type(self, project: Project) -> Optional[LabelType]: - return RelationType(text=self.type, project=project) - - def create(self, user, example, mapping: Dict[str, LabelType], **kwargs): - return Relation( - uuid=self.uuid, - user=user, - example=example, - type=mapping[self.type], - from_id=kwargs["span_mapping"][self.from_id], - to_id=kwargs["span_mapping"][self.to_id], - ) + +class Categories(Labels): + def clean(self, project: Project): + exclusive = getattr(project, "single_class_classification", False) + if exclusive: + groups = groupby(self.labels, lambda label: label.example_uuid) + self.labels = [next(group) for _, group in groups] + + def save(self, user, **kwargs): + uuid_to_example = self.uuid_to_example + categories = [ + category.create(user, uuid_to_example[category.example_uuid], self.types) for category in self.labels + ] + CategoryModel.objects.bulk_create(categories) + + +class Spans(Labels): + def clean(self, project: Project): + allow_overlapping = getattr(project, "allow_overlapping", False) + if allow_overlapping: + return + self.labels.sort() + last_offset = -1 + spans = [] + for label in self.labels: + if getattr(label, "start_offset") >= last_offset: + last_offset = getattr(label, "end_offset") + spans.append(label) + self.labels = spans + + def save(self, user, **kwargs): + uuid_to_example = self.uuid_to_example + spans = [span.create(user, uuid_to_example[span.example_uuid], self.types) for span in self.labels] + SpanModel.objects.bulk_create(spans) + + @property + def id_to_span(self) -> Dict[int, SpanModel]: + span_uuids = [str(label.uuid) for label in self.labels] + spans = SpanModel.objects.filter(uuid__in=span_uuids) + uuid_to_span = {span.uuid: span for span in spans} + return {span.id: uuid_to_span[span.uuid] for span in self.labels} + + +class Texts(Labels): + def save(self, user, **kwargs): + uuid_to_example = self.uuid_to_example + texts = [text.create(user, uuid_to_example[text.example_uuid], self.types) for text in self.labels] + TextLabelModel.objects.bulk_create(texts) + + +class Relations(Labels): + def save(self, user, **kwargs): + id_to_span = kwargs["spans"].id_to_span + uuid_to_example = self.uuid_to_example + relations = [ + relation.create(user, uuid_to_example[relation.example_uuid], self.types, id_to_span=id_to_span) + for relation in self.labels + ] + RelationModel.objects.bulk_create(relations)