You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

98 lines
3.1 KiB

import abc
from itertools import groupby
from typing import Dict, List
from pydantic import UUID4
from .label import Label
from .label_types import LabelTypes
from examples.models import Example
from labels.models import Category as CategoryModel
from labels.models import Label as LabelModel
from labels.models import Relation as RelationModel
from labels.models import Span as SpanModel
from labels.models import TextLabel as TextLabelModel
from projects.models import Project
class Labels(abc.ABC):
label_model = LabelModel
def __init__(self, labels: List[Label], types: LabelTypes):
self.labels = labels
self.types = types
def __len__(self) -> int:
return len(self.labels)
def clean(self, project: Project):
pass
def save_types(self, project: Project):
types = [label.create_type(project) for label in self.labels]
filtered_types = list(filter(None, types))
self.types.save(filtered_types)
self.types.update(project)
@property
def uuid_to_example(self) -> Dict[UUID4, Example]:
example_uuids = {str(label.example_uuid) for label in self.labels}
examples = Example.objects.filter(uuid__in=example_uuids)
return {example.uuid: example for example in examples}
def save(self, user, **kwargs):
uuid_to_example = self.uuid_to_example
labels = [
label.create(user, uuid_to_example[label.example_uuid], self.types, **kwargs)
for label in self.labels
if label.example_uuid in uuid_to_example
]
self.label_model.objects.bulk_create(labels)
class Categories(Labels):
label_model = CategoryModel
def clean(self, project: Project):
exclusive = getattr(project, "single_class_classification", False)
if exclusive:
groups = groupby(self.labels, lambda label: label.example_uuid)
self.labels = [next(group) for _, group in groups]
class Spans(Labels):
label_model = SpanModel
def clean(self, project: Project):
allow_overlapping = getattr(project, "allow_overlapping", False)
if allow_overlapping:
return
spans = []
groups = groupby(self.labels, lambda label: label.example_uuid)
for _, group in groups:
labels = sorted(group)
last_offset = -1
for label in labels:
if getattr(label, "start_offset") >= last_offset:
last_offset = getattr(label, "end_offset")
spans.append(label)
self.labels = spans
@property
def id_to_span(self) -> Dict[int, SpanModel]:
span_uuids = [str(label.uuid) for label in self.labels]
spans = SpanModel.objects.filter(uuid__in=span_uuids)
uuid_to_span = {span.uuid: span for span in spans}
return {span.id: uuid_to_span[span.uuid] for span in self.labels}
class Texts(Labels):
label_model = TextLabelModel
class Relations(Labels):
label_model = RelationModel
def save(self, user, **kwargs):
id_to_span = kwargs["spans"].id_to_span
super().save(user, id_to_span=id_to_span)