from typing import Type

from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels
from auto_labeling_pipeline.mappings import MappingTemplate
from auto_labeling_pipeline.models import RequestModelFactory
from auto_labeling_pipeline.pipeline import pipeline
from auto_labeling_pipeline.postprocessing import PostProcessor

from .labels import create_labels
from auto_labeling.models import AutoLabelingConfig

# Task type -> label collection class used by the mapping template.
# Built once at import time instead of on every lookup.
_LABEL_COLLECTIONS = {
    'Category': ClassificationLabels,
    'Span': SequenceLabels,
    'Text': Seq2seqLabels,
}


def get_label_collection(task_type: str) -> Type[Labels]:
    """Return the label collection class registered for ``task_type``.

    Args:
        task_type: One of ``'Category'``, ``'Span'`` or ``'Text'``.

    Returns:
        The ``Labels`` subclass associated with that task type.

    Raises:
        KeyError: If ``task_type`` is not one of the supported values. The
            message lists the supported task types.
    """
    try:
        return _LABEL_COLLECTIONS[task_type]
    except KeyError:
        supported = ', '.join(sorted(_LABEL_COLLECTIONS))
        # Keep KeyError so existing callers that catch it still work, but
        # make the failure self-explanatory.
        raise KeyError(
            'Unsupported task type {!r}; expected one of: {}'.format(task_type, supported)
        ) from None


def execute_pipeline(data: str, config: AutoLabelingConfig):
    """Run the auto-labeling pipeline on ``data`` as described by ``config``.

    The steps are:
    1. Select the label collection for ``config.task_type``.
    2. Build the request model from ``config.model_name`` and ``config.model_attrs``.
    3. Build the mapping template from that collection and ``config.template``.
    4. Build a post-processor from ``config.label_mapping``.
    5. Pass everything to ``pipeline``.

    Args:
        data: The text to label.
        config: The auto-labeling configuration that drives the run.

    Returns:
        The result of ``create_labels(config.task_type, ...)`` applied to the
        pipeline output. The concrete type is determined by ``create_labels``
        and is presumably task-specific (TODO confirm).

    Raises:
        KeyError: If ``config.task_type`` is unsupported. This is raised before
            any request model is created.
    """
    # Resolve the task type first so a bad config fails fast.
    label_collection = get_label_collection(config.task_type)
    model = RequestModelFactory.create(
        model_name=config.model_name,
        attributes=config.model_attrs,
    )
    template = MappingTemplate(
        label_collection=label_collection,
        template=config.template,
    )
    post_processor = PostProcessor(config.label_mapping)
    predicted = pipeline(
        text=data,
        request_model=model,
        mapping_template=template,
        post_processing=post_processor,
    )
    return create_labels(config.task_type, predicted)