Browse Source

Add first class collection for label

pull/1823/head
Hironsan 2 years ago
parent
commit
a63a1e9751
1 changed files with 89 additions and 125 deletions
  1. 214
      backend/data_import/pipeline/labels.py

214
backend/data_import/pipeline/labels.py

@ -1,135 +1,99 @@
import abc
import uuid
from typing import Any, Dict, Optional
import pydantic.error_wrappers
from pydantic import UUID4, BaseModel, validator
from label_types.models import CategoryType, LabelType, RelationType, SpanType
from labels.models import Category
from labels.models import Label as LabelModel
from labels.models import Relation, Span
from labels.models import TextLabel as TL
from itertools import groupby
from typing import Dict, List
from pydantic import UUID4
from .label import Label
from .label_types import LabelTypes
from examples.models import Example
from labels.models import Category as CategoryModel
from labels.models import Relation as RelationModel
from labels.models import Span as SpanModel
from labels.models import TextLabel as TextLabelModel
from projects.models import Project
class Label(BaseModel, abc.ABC):
id: int = -1
uuid: UUID4
class Labels(abc.ABC):
def __init__(self, labels: List[Label], types: LabelTypes):
self.labels = labels
self.types = types
def __init__(self, **data):
data["uuid"] = uuid.uuid4()
super().__init__(**data)
def clean(self, project: Project):
pass
@classmethod
def parse(cls, obj: Any):
raise NotImplementedError()
def save_types(self, project: Project):
types = [label.create_type(project) for label in self.labels]
filtered_types = list(filter(None, types))
self.types.save(filtered_types)
self.types.update(project)
@property
def uuid_to_example(self) -> Dict[UUID4, Example]:
example_uuids = {str(label.example_uuid) for label in self.labels}
examples = Example.objects.filter(uuid__in=example_uuids)
return {example.uuid: example for example in examples}
@abc.abstractmethod
def create_type(self, project: Project) -> Optional[LabelType]:
def save(self, user, **kwargs):
raise NotImplementedError()
@abc.abstractmethod
def create(self, user, example, mapping, **kwargs) -> LabelModel:
raise NotImplementedError
def __hash__(self):
return hash(tuple(self.dict()))
class CategoryLabel(Label):
label: str
@validator("label")
def label_is_not_empty(cls, value: str):
if value:
return value
else:
raise ValueError("is not empty.")
@classmethod
def parse(cls, obj: Any):
try:
return cls(label=obj)
except pydantic.error_wrappers.ValidationError:
return None
def create_type(self, project: Project) -> Optional[LabelType]:
return CategoryType(text=self.label, project=project)
def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
return Category(uuid=self.uuid, user=user, example=example, label=mapping[self.label])
class SpanLabel(Label):
label: str
start_offset: int
end_offset: int
@classmethod
def parse(cls, obj: Any):
try:
if isinstance(obj, list) or isinstance(obj, tuple):
columns = ["start_offset", "end_offset", "label"]
obj = zip(columns, obj)
return cls.parse_obj(obj)
elif isinstance(obj, dict):
return cls.parse_obj(obj)
except pydantic.error_wrappers.ValidationError:
return None
def create_type(self, project: Project) -> Optional[LabelType]:
return SpanType(text=self.label, project=project)
def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
return Span(
uuid=self.uuid,
user=user,
example=example,
start_offset=self.start_offset,
end_offset=self.end_offset,
label=mapping[self.label],
)
class TextLabel(Label):
text: str
@classmethod
def parse(cls, obj: Any):
try:
return cls(text=obj)
except pydantic.error_wrappers.ValidationError:
return None
def create_type(self, project: Project) -> Optional[LabelType]:
return None
def create(self, user, example, mapping, **kwargs):
return TL(uuid=self.uuid, user=user, example=example, text=self.text)
class RelationLabel(Label):
from_id: int
to_id: int
type: str
@classmethod
def parse(cls, obj: Any):
try:
return cls.parse_obj(obj)
except pydantic.error_wrappers.ValidationError:
return None
def create_type(self, project: Project) -> Optional[LabelType]:
return RelationType(text=self.type, project=project)
def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
return Relation(
uuid=self.uuid,
user=user,
example=example,
type=mapping[self.type],
from_id=kwargs["span_mapping"][self.from_id],
to_id=kwargs["span_mapping"][self.to_id],
)
class Categories(Labels):
def clean(self, project: Project):
exclusive = getattr(project, "single_class_classification", False)
if exclusive:
groups = groupby(self.labels, lambda label: label.example_uuid)
self.labels = [next(group) for _, group in groups]
def save(self, user, **kwargs):
uuid_to_example = self.uuid_to_example
categories = [
category.create(user, uuid_to_example[category.example_uuid], self.types) for category in self.labels
]
CategoryModel.objects.bulk_create(categories)
class Spans(Labels):
def clean(self, project: Project):
allow_overlapping = getattr(project, "allow_overlapping", False)
if allow_overlapping:
return
self.labels.sort()
last_offset = -1
spans = []
for label in self.labels:
if getattr(label, "start_offset") >= last_offset:
last_offset = getattr(label, "end_offset")
spans.append(label)
self.labels = spans
def save(self, user, **kwargs):
uuid_to_example = self.uuid_to_example
spans = [span.create(user, uuid_to_example[span.example_uuid], self.types) for span in self.labels]
SpanModel.objects.bulk_create(spans)
@property
def id_to_span(self) -> Dict[int, SpanModel]:
span_uuids = [str(label.uuid) for label in self.labels]
spans = SpanModel.objects.filter(uuid__in=span_uuids)
uuid_to_span = {span.uuid: span for span in spans}
return {span.id: uuid_to_span[span.uuid] for span in self.labels}
class Texts(Labels):
def save(self, user, **kwargs):
uuid_to_example = self.uuid_to_example
texts = [text.create(user, uuid_to_example[text.example_uuid], self.types) for text in self.labels]
TextLabelModel.objects.bulk_create(texts)
class Relations(Labels):
def save(self, user, **kwargs):
id_to_span = kwargs["spans"].id_to_span
uuid_to_example = self.uuid_to_example
relations = [
relation.create(user, uuid_to_example[relation.example_uuid], self.types, id_to_span=id_to_span)
for relation in self.labels
]
RelationModel.objects.bulk_create(relations)
Loading…
Cancel
Save