You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

29 lines
1.2 KiB

2 years ago
2 years ago
2 years ago
2 years ago
  1. from typing import Type
  2. from auto_labeling_pipeline.labels import (
  3. ClassificationLabels,
  4. Labels,
  5. Seq2seqLabels,
  6. SequenceLabels,
  7. )
  8. from auto_labeling_pipeline.mappings import MappingTemplate
  9. from auto_labeling_pipeline.models import RequestModelFactory
  10. from auto_labeling_pipeline.pipeline import pipeline
  11. from auto_labeling_pipeline.postprocessing import PostProcessor
  12. from .labels import create_labels
  13. from auto_labeling.models import AutoLabelingConfig
  def get_label_collection(task_type: str) -> Type[Labels]:
      """Return the label collection class associated with ``task_type``.

      Recognized task types are ``"Category"``, ``"Span"`` and ``"Text"``.

      Raises:
          KeyError: If ``task_type`` is not one of the recognized keys.
      """
      collections_by_task = {
          "Category": ClassificationLabels,
          "Span": SequenceLabels,
          "Text": Seq2seqLabels,
      }
      return collections_by_task[task_type]
  def execute_pipeline(data: str, config: AutoLabelingConfig):
      """Run the auto-labeling pipeline on ``data`` and return the created labels.

      The pipeline components are built from ``config``: ``task_type`` selects
      the label collection class, ``model_name`` and ``model_attrs`` are passed
      to ``RequestModelFactory.create``, ``template`` goes into the
      ``MappingTemplate`` and ``label_mapping`` into the ``PostProcessor``.

      Args:
          data: Text handed to ``pipeline`` as its ``text`` argument.
          config: Auto-labeling configuration supplying the attributes above.

      Returns:
          Whatever ``create_labels`` returns for ``config.task_type`` and the
          pipeline output.
      """
      collection_cls = get_label_collection(config.task_type)
      request_model = RequestModelFactory.create(
          model_name=config.model_name,
          attributes=config.model_attrs,
      )
      mapping_template = MappingTemplate(
          label_collection=collection_cls,
          template=config.template,
      )
      post_processing = PostProcessor(config.label_mapping)
      raw_labels = pipeline(
          text=data,
          request_model=request_model,
          mapping_template=mapping_template,
          post_processing=post_processing,
      )
      # Convert the pipeline output into the project's own label objects.
      return create_labels(config.task_type, raw_labels)