You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

24 lines
1.1 KiB

2 years ago
2 years ago
2 years ago
  1. from typing import Type
  2. from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels
  3. from auto_labeling_pipeline.mappings import MappingTemplate
  4. from auto_labeling_pipeline.models import RequestModelFactory
  5. from auto_labeling_pipeline.pipeline import pipeline
  6. from auto_labeling_pipeline.postprocessing import PostProcessor
  7. from .labels import create_labels
  8. from auto_labeling.models import AutoLabelingConfig
  9. def get_label_collection(task_type: str) -> Type[Labels]:
  10. return {"Category": ClassificationLabels, "Span": SequenceLabels, "Text": Seq2seqLabels}[task_type]
  11. def execute_pipeline(data: str, config: AutoLabelingConfig):
  12. label_collection = get_label_collection(config.task_type)
  13. model = RequestModelFactory.create(model_name=config.model_name, attributes=config.model_attrs)
  14. template = MappingTemplate(label_collection=label_collection, template=config.template)
  15. post_processor = PostProcessor(config.label_mapping)
  16. labels = pipeline(text=data, request_model=model, mapping_template=template, post_processing=post_processor)
  17. labels = create_labels(config.task_type, labels)
  18. return labels