mirror of https://github.com/doccano/doccano.git
python annotation-tool datasets active-learning text-annotation dataset natural-language-processing data-labeling machine-learning
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
89 lines
2.8 KiB
89 lines
2.8 KiB
import abc
|
|
from itertools import groupby
|
|
from typing import Dict, List, Tuple
|
|
|
|
from .examples import Examples
|
|
from .label import Label
|
|
from .label_types import LabelTypes
|
|
from labels.models import Category as CategoryModel
|
|
from labels.models import Label as LabelModel
|
|
from labels.models import Relation as RelationModel
|
|
from labels.models import Span as SpanModel
|
|
from labels.models import TextLabel as TextLabelModel
|
|
from projects.models import Project
|
|
|
|
|
|
class Labels(abc.ABC):
    """Base container for a batch of labels of a single kind.

    Holds the label objects together with their label types and knows how to
    persist both. Subclasses point ``label_model`` at the model used for bulk
    creation and may override ``clean``.
    """

    # Model whose manager receives the ``bulk_create`` call in ``save``.
    label_model = LabelModel

    def __init__(self, labels: List[Label], types: LabelTypes):
        """Store the labels and the label-type collection they belong to."""
        self.labels = labels
        self.types = types

    def __len__(self) -> int:
        """Return the number of labels currently held."""
        return len(self.labels)

    def clean(self, project: Project):
        """Hook for project-specific filtering; the base class does nothing."""

    def save_types(self, project: Project):
        """Create label types for the held labels and persist them.

        Every label is asked for a type via ``create_type(project)``; falsy
        results are discarded before the remainder is handed to
        ``self.types.save`` and ``self.types.update(project)`` is called.
        """
        created = [item.create_type(project) for item in self.labels]
        kept = [label_type for label_type in created if label_type]
        self.types.save(kept)
        self.types.update(project)

    def save(self, user, examples: Examples, **kwargs):
        """Build model instances for the held labels and bulk-create them.

        Labels whose ``example_uuid`` is not contained in ``examples`` are
        skipped. Extra keyword arguments are forwarded to each label's
        ``create`` call.
        """
        pending = []
        for item in self.labels:
            if item.example_uuid not in examples:
                continue
            example = examples[item.example_uuid]
            pending.append(item.create(user, example, self.types, **kwargs))
        self.label_model.objects.bulk_create(pending)
|
|
|
|
|
|
class Categories(Labels):
    """Category labels, optionally limited to one per example."""

    label_model = CategoryModel

    def clean(self, project: Project):
        """Keep only the first category of each example in exclusive projects.

        When ``project.single_class_classification`` is truthy (a missing
        attribute counts as ``False``), ``self.labels`` is reduced to the
        first label seen for every ``example_uuid``, preserving first-seen
        order. Otherwise the labels are left untouched.
        """
        exclusive = getattr(project, "single_class_classification", False)
        if not exclusive:
            return
        # A dict keyed by example groups labels no matter where they sit in
        # the list; ``itertools.groupby`` would only merge *adjacent* labels
        # and could let several categories through for one example.
        # NOTE(review): assumes ``example_uuid`` is hashable (e.g. a UUID).
        first_by_example = {}
        for label in self.labels:
            first_by_example.setdefault(label.example_uuid, label)
        self.labels = list(first_by_example.values())
|
|
|
|
|
|
class Spans(Labels):
    """Span labels, with optional removal of overlapping spans."""

    label_model = SpanModel

    def clean(self, project: Project):
        """Drop overlapping spans unless the project allows them.

        When ``project.allow_overlapping`` is truthy (a missing attribute
        counts as ``False``) nothing changes. Otherwise the labels are grouped
        per ``example_uuid``, each group is sorted, and a span is kept only if
        its ``start_offset`` is not smaller than the ``end_offset`` of the
        previously kept span in that group.
        """
        allow_overlapping = getattr(project, "allow_overlapping", False)
        if allow_overlapping:
            return

        # Group through a dict so that labels of one example are checked
        # together even when they are not adjacent in ``self.labels``;
        # ``itertools.groupby`` would split them into separate runs and miss
        # overlaps between those runs. First-seen example order is preserved.
        # NOTE(review): assumes ``example_uuid`` is hashable (e.g. a UUID).
        by_example = {}
        for label in self.labels:
            by_example.setdefault(label.example_uuid, []).append(label)

        spans = []
        for group in by_example.values():
            last_offset = -1
            # Relies on the label type's own ordering -- presumably by start
            # offset; verify against the label definition.
            for label in sorted(group):
                # NOTE(review): ``getattr`` is kept from the original; the
                # offsets are presumably not declared on the base ``Label``.
                if getattr(label, "start_offset") >= last_offset:
                    last_offset = getattr(label, "end_offset")
                    spans.append(label)
        self.labels = spans

    @property
    def id_to_span(self) -> Dict[Tuple[int, str], SpanModel]:
        """Map ``(label id, example uuid as str)`` to the stored span row.

        Fetches the ``SpanModel`` rows whose ``uuid`` matches one of the held
        labels and indexes them by that uuid.

        Raises:
            KeyError: if a held label has no fetched row with the same uuid.
        """
        uuids = [str(span.uuid) for span in self.labels]
        stored = SpanModel.objects.filter(uuid__in=uuids)
        uuid_to_span = {row.uuid: row for row in stored}
        return {(span.id, str(span.example_uuid)): uuid_to_span[span.uuid] for span in self.labels}
|
|
|
|
|
|
class Texts(Labels):
    """Text labels; persisted through ``TextLabelModel``."""

    label_model = TextLabelModel
|
|
|
|
|
|
class Relations(Labels):
    """Relation labels, saved with a lookup of already stored spans."""

    label_model = RelationModel

    def save(self, user, examples: Examples, **kwargs):
        """Save relations, passing the span lookup on to the base ``save``.

        Requires a ``spans`` keyword argument exposing ``id_to_span``; only
        that mapping is forwarded (as ``id_to_span``), other keyword
        arguments are not passed on.

        Raises:
            KeyError: if ``spans`` is not supplied.
        """
        span_collection = kwargs["spans"]
        super().save(user, examples, id_to_span=span_collection.id_to_span)
|