You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

39 lines
1.3 KiB

  1. from typing import Type
  2. from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels
  3. from auto_labeling_pipeline.mappings import MappingTemplate
  4. from auto_labeling_pipeline.models import RequestModelFactory
  5. from auto_labeling_pipeline.pipeline import pipeline
  6. from auto_labeling_pipeline.postprocessing import PostProcessor
  7. from .labels import create_labels
  8. from auto_labeling.models import AutoLabelingConfig
  9. def get_label_collection(task_type: str) -> Type[Labels]:
  10. return {
  11. 'Category': ClassificationLabels,
  12. 'Span': SequenceLabels,
  13. 'Text': Seq2seqLabels
  14. }[task_type]
  15. def execute_pipeline(data: str, config: AutoLabelingConfig):
  16. label_collection = get_label_collection(config.task_type)
  17. model = RequestModelFactory.create(
  18. model_name=config.model_name,
  19. attributes=config.model_attrs
  20. )
  21. template = MappingTemplate(
  22. label_collection=label_collection,
  23. template=config.template
  24. )
  25. post_processor = PostProcessor(config.label_mapping)
  26. labels = pipeline(
  27. text=data,
  28. request_model=model,
  29. mapping_template=template,
  30. post_processing=post_processor
  31. )
  32. labels = create_labels(config.task_type, labels)
  33. return labels