You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

103 lines
3.1 KiB

import abc
from typing import List, Type
from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels
from auto_labeling_pipeline.mappings import MappingTemplate
from auto_labeling_pipeline.models import RequestModelFactory
from auto_labeling_pipeline.pipeline import pipeline
from auto_labeling_pipeline.postprocessing import PostProcessor
from api.models import Example, Project, User
from api.models import CategoryType, SpanType
from api.models import Annotation, Category, Span, TextLabel
def get_label_collection(task_type: str) -> Type[Labels]:
    """Return the pipeline label class registered for ``task_type``.

    Args:
        task_type: One of ``'Category'``, ``'Span'`` or ``'Text'``.

    Returns:
        The ``Labels`` subclass associated with the task type.

    Raises:
        KeyError: If ``task_type`` is not one of the known task types.
    """
    collections_by_task = {
        'Category': ClassificationLabels,
        'Span': SequenceLabels,
        'Text': Seq2seqLabels,
    }
    return collections_by_task[task_type]
class LabelCollection(abc.ABC):
    """Base class that turns raw auto-labeling output into annotation objects.

    Subclasses configure two class attributes:

    * ``label_type`` -- model whose ``objects.filter(project=...)`` yields the
      project's label definitions; each definition exposes a ``text``
      attribute.
    * ``model`` -- annotation model, instantiated with each label's fields as
      keyword arguments.
    """

    label_type = None
    model = None

    def __init__(self, labels):
        """Keep the raw labels.

        Args:
            labels: Iterable of dict-like labels. ``transform`` reads the
                ``'label'`` key of each item.
        """
        self.labels = labels

    def transform(self, project: Project, example: Example, user: User) -> List[Annotation]:
        """Build unsaved annotation instances for the labels known to ``project``.

        Labels whose ``'label'`` value does not match the ``text`` of any
        label definition of the project are skipped.

        The stored labels are never modified: each annotation is built from a
        fresh copy of the label dict. Writing the resolved values back into
        ``self.labels`` would replace the textual label with a model object,
        so a later call would no longer find it in the text lookup.

        Args:
            project: Project whose label definitions are used for the lookup.
            example: Example attached to every created annotation.
            user: User attached to every created annotation.

        Returns:
            A list of ``self.model`` instances (not yet persisted here).
        """
        mapping = {
            c.text: c for c in self.label_type.objects.filter(project=project)
        }
        annotations = []
        for label in self.labels:
            if label['label'] not in mapping:
                continue
            # Copy instead of mutating the caller-provided dict.
            fields = {
                **label,
                'example': example,
                'label': mapping[label['label']],
                'user': user,
            }
            annotations.append(self.model(**fields))
        return annotations

    def save(self, project: Project, example: Example, user: User):
        """Transform the labels and persist the resulting annotations.

        The transformed annotations are passed through the model manager's
        ``filter_annotatable_labels`` and whatever it returns is written with
        ``bulk_create``.
        """
        labels = self.transform(project, example, user)
        labels = self.model.objects.filter_annotatable_labels(labels, project)
        self.model.objects.bulk_create(labels)
class Categories(LabelCollection):
    """Label collection that creates ``Category`` annotations.

    Label texts are looked up among the project's ``CategoryType`` entries.
    """

    # Annotation model instantiated for each label.
    model = Category
    # Model providing the project's label definitions.
    label_type = CategoryType
class Spans(LabelCollection):
    """Label collection that creates ``Span`` annotations.

    Label texts are looked up among the project's ``SpanType`` entries.
    """

    # Annotation model instantiated for each label.
    model = Span
    # Model providing the project's label definitions.
    label_type = SpanType
class Texts(LabelCollection):
    """Label collection that creates ``TextLabel`` annotations.

    No ``label_type`` is configured, so ``transform`` is overridden to skip
    the label-definition lookup entirely.
    """

    model = TextLabel

    def transform(self, project: Project, example: Example, user: User) -> List[Annotation]:
        """Build one unsaved ``TextLabel`` per stored label.

        The stored labels are never modified: each annotation is built from a
        copy of the label dict extended with ``example`` and ``user``, so the
        collection can be transformed repeatedly with different arguments.

        Args:
            project: Unused here; kept for signature compatibility with the
                base class.
            example: Example attached to every created annotation.
            user: User attached to every created annotation.

        Returns:
            A list of ``self.model`` instances (not yet persisted here).
        """
        return [
            self.model(**{**label, 'example': example, 'user': user})
            for label in self.labels
        ]
def create_labels(task_type: str, labels: Labels) -> LabelCollection:
    """Wrap pipeline output in the label collection matching ``task_type``.

    Args:
        task_type: One of ``'Category'``, ``'Span'`` or ``'Text'``.
        labels: Pipeline result; its ``dict()`` output is handed to the
            collection constructor.

    Returns:
        A ``Categories``, ``Spans`` or ``Texts`` instance.

    Raises:
        KeyError: If ``task_type`` is not one of the known task types.
    """
    collection_classes = {
        'Category': Categories,
        'Span': Spans,
        'Text': Texts,
    }
    # Resolve the class first so an unknown task type fails before
    # ``labels.dict()`` is evaluated.
    collection_cls = collection_classes[task_type]
    return collection_cls(labels.dict())
def execute_pipeline(
    text: str,
    task_type: str,
    model_name: str,
    model_attrs: dict,
    template: str,
    label_mapping: dict,
):
    """Run the auto-labeling pipeline on ``text`` and wrap the result.

    Args:
        text: Input passed to the pipeline.
        task_type: One of ``'Category'``, ``'Span'`` or ``'Text'``; selects
            both the pipeline label class and the returned collection type.
        model_name: Name forwarded to ``RequestModelFactory.create``.
        model_attrs: Attributes forwarded to ``RequestModelFactory.create``.
        template: Template source given to ``MappingTemplate``.
        label_mapping: Mapping given to ``PostProcessor``.

    Returns:
        The label collection produced by ``create_labels`` for the pipeline
        output.

    Raises:
        KeyError: If ``task_type`` is not one of the known task types.
    """
    labels_cls = get_label_collection(task_type)
    request_model = RequestModelFactory.create(
        model_name=model_name,
        attributes=model_attrs,
    )
    mapping_template = MappingTemplate(
        label_collection=labels_cls,
        template=template,
    )
    post_processing = PostProcessor(label_mapping)
    raw_labels = pipeline(
        text=text,
        request_model=request_model,
        mapping_template=mapping_template,
        post_processing=post_processing,
    )
    return create_labels(task_type, raw_labels)