You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

39 lines
1.3 KiB

from typing import Type
from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels
from auto_labeling_pipeline.mappings import MappingTemplate
from auto_labeling_pipeline.models import RequestModelFactory
from auto_labeling_pipeline.pipeline import pipeline
from auto_labeling_pipeline.postprocessing import PostProcessor
from .labels import create_labels
from auto_labeling.models import AutoLabelingConfig
def get_label_collection(task_type: str) -> Type[Labels]:
    """Return the label-collection class used for an auto-labeling task type.

    Mapping:
        ``'Category'`` -> ``ClassificationLabels``
        ``'Span'``     -> ``SequenceLabels``
        ``'Text'``     -> ``Seq2seqLabels``

    Args:
        task_type: Name of the task type to look up.

    Returns:
        The ``Labels`` subclass associated with ``task_type``.

    Raises:
        KeyError: If ``task_type`` is not one of the supported names. The
            exception type is unchanged from a plain dict lookup so existing
            callers that catch ``KeyError`` keep working; only the message is
            more descriptive.
    """
    collections = {
        'Category': ClassificationLabels,
        'Span': SequenceLabels,
        'Text': Seq2seqLabels,
    }
    try:
        return collections[task_type]
    except KeyError:
        # Re-raise with a message that names the bad value and the valid
        # choices; `from None` drops the redundant inner KeyError traceback.
        raise KeyError(
            f'Unsupported task_type {task_type!r}; '
            f'expected one of {sorted(collections)}'
        ) from None
def execute_pipeline(data: str, config: AutoLabelingConfig):
    """Auto-label ``data`` using the model and mappings stored in ``config``.

    Steps, in order:
      1. resolve the label-collection class for ``config.task_type``;
      2. build the request model from ``config.model_name`` /
         ``config.model_attrs`` via ``RequestModelFactory.create``;
      3. build a ``MappingTemplate`` from that class and ``config.template``;
      4. run ``pipeline`` with a ``PostProcessor`` built from
         ``config.label_mapping``;
      5. convert the pipeline output with ``create_labels``.

    Args:
        data: Input text, passed to ``pipeline`` as its ``text`` argument.
        config: Auto-labeling configuration; its ``task_type``,
            ``model_name``, ``model_attrs``, ``template`` and
            ``label_mapping`` attributes are read.

    Returns:
        The value produced by ``create_labels(config.task_type, ...)`` for
        the pipeline output (its concrete type is defined in ``.labels``).

    Raises:
        KeyError: Propagated from ``get_label_collection`` when
            ``config.task_type`` is unsupported; this happens before any
            model is created.
    """
    # Resolve the label class first so an unsupported task type fails fast.
    collection_cls = get_label_collection(config.task_type)

    request_model = RequestModelFactory.create(
        model_name=config.model_name,
        attributes=config.model_attrs,
    )
    mapping_template = MappingTemplate(
        label_collection=collection_cls,
        template=config.template,
    )

    raw_labels = pipeline(
        text=data,
        request_model=request_model,
        mapping_template=mapping_template,
        post_processing=PostProcessor(config.label_mapping),
    )
    return create_labels(config.task_type, raw_labels)