diff --git a/backend/auto_labeling/pipeline/__init__.py b/backend/auto_labeling/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/auto_labeling/pipeline/execution.py b/backend/auto_labeling/pipeline/execution.py new file mode 100644 index 00000000..7401a83b --- /dev/null +++ b/backend/auto_labeling/pipeline/execution.py @@ -0,0 +1,40 @@ +from typing import Type + +from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels +from auto_labeling_pipeline.mappings import MappingTemplate +from auto_labeling_pipeline.models import RequestModelFactory +from auto_labeling_pipeline.pipeline import pipeline +from auto_labeling_pipeline.postprocessing import PostProcessor + + +def get_label_collection(task_type: str) -> Type[Labels]: + return { + 'Category': ClassificationLabels, + 'Span': SequenceLabels, + 'Text': Seq2seqLabels + }[task_type] + + +def execute_pipeline(text: str, + task_type: str, + model_name: str, + model_attrs: dict, + template: str, + label_mapping: dict): + label_collection = get_label_collection(task_type) + model = RequestModelFactory.create( + model_name=model_name, + attributes=model_attrs + ) + template = MappingTemplate( + label_collection=label_collection, + template=template + ) + post_processor = PostProcessor(label_mapping) + labels = pipeline( + text=text, + request_model=model, + mapping_template=template, + post_processing=post_processor + ) + return labels.dict()