from typing import Type

from auto_labeling_pipeline.labels import (
    ClassificationLabels,
    Labels,
    Seq2seqLabels,
    SequenceLabels,
)
from auto_labeling_pipeline.mappings import MappingTemplate
from auto_labeling_pipeline.models import RequestModelFactory
from auto_labeling_pipeline.pipeline import pipeline
from auto_labeling_pipeline.postprocessing import PostProcessor

from .labels import create_labels
from auto_labeling.models import AutoLabelingConfig


def get_label_collection(task_type: str) -> Type[Labels]:
    """Return the ``auto_labeling_pipeline`` label class for a task type.

    The mapping is:
      * ``"Category"`` -> ``ClassificationLabels``
      * ``"Span"``     -> ``SequenceLabels``
      * ``"Text"``     -> ``Seq2seqLabels``

    Args:
        task_type: task-type name, as stored on the auto-labeling config.

    Returns:
        The label class (not an instance) associated with ``task_type``.

    Raises:
        KeyError: if ``task_type`` is not one of the keys above.
    """
    label_classes = {
        "Category": ClassificationLabels,
        "Span": SequenceLabels,
        "Text": Seq2seqLabels,
    }
    return label_classes[task_type]


def execute_pipeline(data: str, config: AutoLabelingConfig):
    """Run the auto-labeling pipeline on ``data`` using the settings in ``config``.

    Steps, in order:
      1. Pick the label collection class for ``config.task_type``.
      2. Create the request model from ``config.model_name`` and
         ``config.model_attrs``.
      3. Build the mapping template from that label class and
         ``config.template``.
      4. Build the post-processor from ``config.label_mapping``.
      5. Run ``pipeline`` on ``data`` and convert its output with
         ``create_labels`` for ``config.task_type``.

    Args:
        data: text handed to the pipeline as its ``text`` argument.
        config: auto-labeling configuration supplying the model, template,
            label mapping and task type.

    Returns:
        Whatever ``create_labels`` produces for the task type (presumably
        project label objects; confirm in ``.labels``).

    Raises:
        KeyError: if ``config.task_type`` is not supported by
            ``get_label_collection``.
    """
    collection_cls = get_label_collection(config.task_type)
    request_model = RequestModelFactory.create(
        model_name=config.model_name,
        attributes=config.model_attrs,
    )
    mapping_template = MappingTemplate(
        label_collection=collection_cls,
        template=config.template,
    )
    post_processing = PostProcessor(config.label_mapping)
    # Raw output from auto_labeling_pipeline, before conversion by create_labels.
    raw_labels = pipeline(
        text=data,
        request_model=request_model,
        mapping_template=mapping_template,
        post_processing=post_processing,
    )
    return create_labels(config.task_type, raw_labels)