import abc
from itertools import groupby  # noqa: F401  (no longer used here; import kept so the module namespace is unchanged)
from typing import Dict, List, Tuple

from .examples import Examples
from .label import Label
from .label_types import LabelTypes
from labels.models import Category as CategoryModel
from labels.models import Label as LabelModel
from labels.models import Relation as RelationModel
from labels.models import Span as SpanModel
from labels.models import TextLabel as TextLabelModel
from projects.models import Project


def _group_by_example(labels: List[Label]) -> List[List[Label]]:
    """Bucket labels by ``example_uuid``, keeping first-seen order.

    Unlike ``itertools.groupby`` this does not require labels of the same
    example to be adjacent: every label sharing an ``example_uuid`` ends up
    in the same bucket. Buckets are returned in the order their example was
    first encountered, and labels keep their relative order inside a bucket.

    NOTE(review): relies on ``example_uuid`` being hashable — confirm against
    the ``Label`` definition.
    """
    buckets: Dict[object, List[Label]] = {}
    for label in labels:
        buckets.setdefault(label.example_uuid, []).append(label)
    return list(buckets.values())


class Labels(abc.ABC):
    """A collection of parsed labels together with their label types.

    Subclasses select the model used for persistence through ``label_model``
    and may override ``clean`` to drop labels that violate project settings.
    """

    # Model class whose manager receives the ``bulk_create`` call in ``save``.
    label_model = LabelModel

    def __init__(self, labels: List[Label], types: LabelTypes):
        self.labels = labels
        self.types = types

    def __len__(self) -> int:
        """Return the number of labels currently held."""
        return len(self.labels)

    def clean(self, project: Project):
        """Filter ``self.labels`` according to ``project``; a no-op here."""
        pass

    def save_types(self, project: Project):
        """Create the label types required by the labels and store them.

        Calls ``create_type(project)`` on every label, discards falsy
        results, passes the remainder to ``self.types.save`` and finally
        calls ``self.types.update(project)``.
        """
        types = [label.create_type(project) for label in self.labels]
        filtered_types = list(filter(None, types))
        self.types.save(filtered_types)
        self.types.update(project)

    def save(self, user, examples: Examples, **kwargs):
        """Build model instances for the labels and bulk-insert them.

        Labels whose ``example_uuid`` is not contained in ``examples`` are
        skipped. Extra keyword arguments are forwarded to each label's
        ``create`` call.
        """
        labels = [
            label.create(user, examples[label.example_uuid], self.types, **kwargs)
            for label in self.labels
            if label.example_uuid in examples
        ]
        self.label_model.objects.bulk_create(labels)


class Categories(Labels):
    """Category labels."""

    label_model = CategoryModel

    def clean(self, project: Project):
        """Keep one category per example when the project is single-class.

        If ``project.single_class_classification`` is truthy, only the first
        label seen for each ``example_uuid`` is retained; otherwise the
        labels are left untouched.
        """
        exclusive = getattr(project, "single_class_classification", False)
        if not exclusive:
            return
        # Grouping is done per example regardless of adjacency, so an example
        # whose labels are interleaved with others still yields one label.
        self.labels = [group[0] for group in _group_by_example(self.labels)]


class Spans(Labels):
    """Span labels."""

    label_model = SpanModel

    def clean(self, project: Project):
        """Drop overlapping spans unless the project allows overlap.

        If ``project.allow_overlapping`` is truthy nothing is changed.
        Otherwise, for each example, the spans are sorted and a span is kept
        only when its ``start_offset`` is not smaller than the
        ``end_offset`` of the previously kept span.
        """
        allow_overlapping = getattr(project, "allow_overlapping", False)
        if allow_overlapping:
            return
        spans = []
        # All spans of one example are examined together (even when they are
        # not adjacent in ``self.labels``) so overlaps cannot slip through.
        for group in _group_by_example(self.labels):
            last_offset = -1
            # ``sorted`` uses the ordering defined by the label objects.
            for label in sorted(group):
                if getattr(label, "start_offset") >= last_offset:
                    last_offset = getattr(label, "end_offset")
                    spans.append(label)
        self.labels = spans

    @property
    def id_to_span(self) -> Dict[Tuple[int, str], SpanModel]:
        """Map ``(label id, example uuid string)`` to the stored span model.

        Looks up ``SpanModel`` rows whose ``uuid`` matches one of the held
        labels and returns a dict keyed by each label's ``id`` and
        ``str(example_uuid)``.

        Raises:
            KeyError: if a held label's ``uuid`` is absent from the query
                result.
        """
        uuids = [str(span.uuid) for span in self.labels]
        spans = SpanModel.objects.filter(uuid__in=uuids)
        uuid_to_span = {span.uuid: span for span in spans}
        return {
            (span.id, str(span.example_uuid)): uuid_to_span[span.uuid]
            for span in self.labels
        }


class Texts(Labels):
    """Text labels."""

    label_model = TextLabelModel


class Relations(Labels):
    """Relation labels, which refer to previously stored spans."""

    label_model = RelationModel

    def save(self, user, examples: Examples, **kwargs):
        """Save relations, resolving span references via ``kwargs["spans"]``.

        Only the ``id_to_span`` mapping derived from ``kwargs["spans"]`` is
        forwarded to ``Labels.save``; other keyword arguments are not.

        Raises:
            KeyError: if no ``spans`` keyword argument is supplied.
        """
        id_to_span = kwargs["spans"].id_to_span
        super().save(user, examples, id_to_span=id_to_span)