Browse Source

Change execution_pipeline to pass config directly

pull/1650/head
Hironsan 3 years ago
parent
commit
6410af3c3e
2 changed files with 13 additions and 29 deletions
  1. 22
      backend/auto_labeling/pipeline/execution.py
  2. 20
      backend/auto_labeling/views.py

22
backend/auto_labeling/pipeline/execution.py

@ -7,6 +7,7 @@ from auto_labeling_pipeline.pipeline import pipeline
from auto_labeling_pipeline.postprocessing import PostProcessor
from .labels import create_labels
from auto_labeling.models import AutoLabelingConfig
def get_label_collection(task_type: str) -> Type[Labels]:
@ -17,27 +18,22 @@ def get_label_collection(task_type: str) -> Type[Labels]:
}[task_type]
def execute_pipeline(text: str,
task_type: str,
model_name: str,
model_attrs: dict,
template: str,
label_mapping: dict):
label_collection = get_label_collection(task_type)
def execute_pipeline(data: str, config: AutoLabelingConfig):
label_collection = get_label_collection(config.task_type)
model = RequestModelFactory.create(
model_name=model_name,
attributes=model_attrs
model_name=config.model_name,
attributes=config.model_attrs
)
template = MappingTemplate(
label_collection=label_collection,
template=template
template=config.template
)
post_processor = PostProcessor(label_mapping)
post_processor = PostProcessor(config.label_mapping)
labels = pipeline(
text=text,
text=data,
request_model=model,
mapping_template=template,
post_processing=post_processor
)
labels = create_labels(task_type, labels)
labels = create_labels(config.task_type, labels)
return labels

20
backend/auto_labeling/views.py

@ -91,14 +91,9 @@ class FullPipelineTesting(APIView):
def pass_pipeline_call(self, serializer):
test_input = self.request.data['input']
return execute_pipeline(
text=test_input,
task_type=serializer.data.get('task_type'),
model_name=serializer.data.get('model_name'),
model_attrs=serializer.data.get('model_attrs'),
template=serializer.data.get('template'),
label_mapping=serializer.data.get('label_mapping')
)
config = AutoLabelingConfig(**serializer.data)
labels = execute_pipeline(test_input, config=config)
return labels.labels
class RestAPIRequestTesting(APIView):
@ -192,13 +187,6 @@ class AutomatedLabeling(generics.CreateAPIView):
example = get_object_or_404(Example, pk=self.kwargs['example_id'])
configs = AutoLabelingConfig.objects.filter(project=project)
for config in configs:
labels = execute_pipeline(
text=example.data,
task_type=config.task_type,
model_name=config.model_name,
model_attrs=config.model_attrs,
template=config.template,
label_mapping=config.label_mapping
)
labels = execute_pipeline(example.data, config=config)
labels.save(project, example, self.request.user)
return Response({'ok': True}, status=status.HTTP_201_CREATED)
Loading…
Cancel
Save