Browse Source

Add get_label_collection function

pull/1650/head
Hironsan 2 years ago
parent
commit
1a9a3b56c1
2 changed files with 40 additions and 0 deletions
  1. 0
      backend/auto_labeling/pipeline/__init__.py
  2. 40
      backend/auto_labeling/pipeline/execution.py

0
backend/auto_labeling/pipeline/__init__.py

40
backend/auto_labeling/pipeline/execution.py

@ -0,0 +1,40 @@
from typing import Type
from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels
from auto_labeling_pipeline.mappings import MappingTemplate
from auto_labeling_pipeline.models import RequestModelFactory
from auto_labeling_pipeline.pipeline import pipeline
from auto_labeling_pipeline.postprocessing import PostProcessor
def get_label_collection(task_type: str) -> Type[Labels]:
return {
'Category': ClassificationLabels,
'Span': SequenceLabels,
'Text': Seq2seqLabels
}[task_type]
def execute_pipeline(text: str,
task_type: str,
model_name: str,
model_attrs: dict,
template: str,
label_mapping: dict):
label_collection = get_label_collection(task_type)
model = RequestModelFactory.create(
model_name=model_name,
attributes=model_attrs
)
template = MappingTemplate(
label_collection=label_collection,
template=template
)
post_processor = PostProcessor(label_mapping)
labels = pipeline(
text=text,
request_model=model,
mapping_template=template,
post_processing=post_processor
)
return labels.dict()
Loading…
Cancel
Save