mirror of https://github.com/doccano/doccano.git
python, datasets, active-learning, text-annotation, dataset, natural-language-processing, data-labeling, machine-learning, annotation-tool
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
103 lines
3.1 KiB
103 lines
3.1 KiB
import abc
|
|
from typing import List, Type
|
|
|
|
from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels
|
|
from auto_labeling_pipeline.mappings import MappingTemplate
|
|
from auto_labeling_pipeline.models import RequestModelFactory
|
|
from auto_labeling_pipeline.pipeline import pipeline
|
|
from auto_labeling_pipeline.postprocessing import PostProcessor
|
|
|
|
from api.models import Example, Project, User
|
|
from api.models import CategoryType, SpanType
|
|
from api.models import Annotation, Category, Span, TextLabel
|
|
|
|
|
|
def get_label_collection(task_type: str) -> Type[Labels]:
    """Return the auto_labeling_pipeline ``Labels`` class for ``task_type``.

    Accepted task types are ``'Category'``, ``'Span'`` and ``'Text'``;
    any other value raises ``KeyError``.
    """
    collections_by_task = dict(
        Category=ClassificationLabels,
        Span=SequenceLabels,
        Text=Seq2seqLabels,
    )
    return collections_by_task[task_type]
|
|
|
|
|
|
class LabelCollection(abc.ABC):
    """Base wrapper around raw label dicts produced by the auto-labeling pipeline.

    Subclasses set ``label_type`` (the project model whose ``text`` field is
    matched against incoming label names) and ``model`` (the annotation model
    instantiated for each label).
    """

    # Project label model used to resolve label names; must be set by subclasses
    # that rely on the default ``transform``.
    label_type = None
    # Annotation model instantiated from each label dict.
    model = None

    def __init__(self, labels):
        # presumably a list of dicts as returned by ``Labels.dict()`` — TODO confirm
        self.labels = labels

    def transform(self, project: Project, example: Example, user: User) -> List[Annotation]:
        """Build unsaved annotation instances for labels known to ``project``.

        Each label dict's ``'label'`` name is looked up among the project's
        ``label_type`` rows by their ``text``; labels with no match are skipped.
        ``self.labels`` is left unmodified, so the collection can be
        transformed or saved more than once.
        """
        mapping = {
            c.text: c for c in self.label_type.objects.filter(project=project)
        }
        annotations = []
        for label in self.labels:
            if label['label'] not in mapping:
                continue
            # Build a new dict rather than mutating ``label``: overwriting
            # label['label'] with a model instance would make later calls
            # miss every entry in ``mapping``.
            fields = {
                **label,
                'example': example,
                'label': mapping[label['label']],
                'user': user,
            }
            annotations.append(self.model(**fields))
        return annotations

    def save(self, project: Project, example: Example, user: User):
        """Bulk-create this collection's annotations on ``example`` for ``user``.

        The transformed annotations are first passed through the model
        manager's ``filter_annotatable_labels`` before ``bulk_create``.
        """
        labels = self.transform(project, example, user)
        labels = self.model.objects.filter_annotatable_labels(labels, project)
        self.model.objects.bulk_create(labels)
|
|
|
|
|
|
class Categories(LabelCollection):
    """Classification labels, resolved against the project's ``CategoryType`` rows."""

    model = Category
    label_type = CategoryType
|
|
|
|
|
|
class Spans(LabelCollection):
    """Sequence (span) labels, resolved against the project's ``SpanType`` rows."""

    model = Span
    label_type = SpanType
|
|
|
|
|
|
class Texts(LabelCollection):
    """Free-text (seq2seq) labels stored as ``TextLabel`` rows.

    Text labels have no label type, so ``transform`` performs no lookup by
    label name and keeps every entry.
    """

    model = TextLabel

    def transform(self, project: Project, example: Example, user: User) -> List[Annotation]:
        """Build an unsaved ``TextLabel`` for every label dict.

        ``project`` is accepted for interface compatibility but not used.
        ``self.labels`` is left unmodified so the collection can be reused.
        """
        # Copy each dict instead of writing 'example'/'user' into it in place.
        return [
            self.model(**{**label, 'example': example, 'user': user})
            for label in self.labels
        ]
|
|
|
|
|
|
def create_labels(task_type: str, labels: Labels) -> LabelCollection:
    """Wrap pipeline ``labels`` in the ``LabelCollection`` subclass for ``task_type``.

    Accepted task types are ``'Category'``, ``'Span'`` and ``'Text'``;
    any other value raises ``KeyError`` before ``labels.dict()`` is called.
    """
    collection_classes = dict(
        Category=Categories,
        Span=Spans,
        Text=Texts,
    )
    collection_cls = collection_classes[task_type]
    raw_labels = labels.dict()
    return collection_cls(raw_labels)
|
|
|
|
|
|
def execute_pipeline(text: str,
                     task_type: str,
                     model_name: str,
                     model_attrs: dict,
                     template: str,
                     label_mapping: dict):
    """Run the auto-labeling pipeline on ``text`` and wrap the result.

    Builds the request model named ``model_name`` with ``model_attrs``, a
    ``MappingTemplate`` from ``template`` for the task's label class, and a
    ``PostProcessor`` from ``label_mapping``. It then runs ``pipeline`` and
    returns the labels wrapped by ``create_labels`` (a ``LabelCollection``).
    """
    labels_cls = get_label_collection(task_type)
    request_model = RequestModelFactory.create(
        model_name=model_name,
        attributes=model_attrs,
    )
    # Keep the raw ``template`` string and the built MappingTemplate in separate names.
    mapping_template = MappingTemplate(
        label_collection=labels_cls,
        template=template,
    )
    post_processor = PostProcessor(label_mapping)
    raw_labels = pipeline(
        text=text,
        request_model=request_model,
        mapping_template=mapping_template,
        post_processing=post_processor,
    )
    return create_labels(task_type, raw_labels)
|