From 6410af3c3ece47ef93b4a6acab1dfa70a9142295 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Wed, 26 Jan 2022 08:15:23 +0900 Subject: [PATCH] Change execution_pipeline to pass config directly --- backend/auto_labeling/pipeline/execution.py | 22 +++++++++------------ backend/auto_labeling/views.py | 20 ++++--------------- 2 files changed, 13 insertions(+), 29 deletions(-) diff --git a/backend/auto_labeling/pipeline/execution.py b/backend/auto_labeling/pipeline/execution.py index 6bb97ae5..795c1e3d 100644 --- a/backend/auto_labeling/pipeline/execution.py +++ b/backend/auto_labeling/pipeline/execution.py @@ -7,6 +7,7 @@ from auto_labeling_pipeline.pipeline import pipeline from auto_labeling_pipeline.postprocessing import PostProcessor from .labels import create_labels +from auto_labeling.models import AutoLabelingConfig def get_label_collection(task_type: str) -> Type[Labels]: @@ -17,27 +18,22 @@ def get_label_collection(task_type: str) -> Type[Labels]: }[task_type] -def execute_pipeline(text: str, - task_type: str, - model_name: str, - model_attrs: dict, - template: str, - label_mapping: dict): - label_collection = get_label_collection(task_type) +def execute_pipeline(data: str, config: AutoLabelingConfig): + label_collection = get_label_collection(config.task_type) model = RequestModelFactory.create( - model_name=model_name, - attributes=model_attrs + model_name=config.model_name, + attributes=config.model_attrs ) template = MappingTemplate( label_collection=label_collection, - template=template + template=config.template ) - post_processor = PostProcessor(label_mapping) + post_processor = PostProcessor(config.label_mapping) labels = pipeline( - text=text, + text=data, request_model=model, mapping_template=template, post_processing=post_processor ) - labels = create_labels(task_type, labels) + labels = create_labels(config.task_type, labels) return labels diff --git a/backend/auto_labeling/views.py b/backend/auto_labeling/views.py index 4296d4e2..e357ff36 100644 --- a/backend/auto_labeling/views.py +++ b/backend/auto_labeling/views.py @@ -91,14 +91,9 @@ class FullPipelineTesting(APIView): def pass_pipeline_call(self, serializer): test_input = self.request.data['input'] - return execute_pipeline( - text=test_input, - task_type=serializer.data.get('task_type'), - model_name=serializer.data.get('model_name'), - model_attrs=serializer.data.get('model_attrs'), - template=serializer.data.get('template'), - label_mapping=serializer.data.get('label_mapping') - ) + config = AutoLabelingConfig(**serializer.data) + labels = execute_pipeline(test_input, config=config) + return labels.labels class RestAPIRequestTesting(APIView): @@ -192,13 +187,6 @@ class AutomatedLabeling(generics.CreateAPIView): example = get_object_or_404(Example, pk=self.kwargs['example_id']) configs = AutoLabelingConfig.objects.filter(project=project) for config in configs: - labels = execute_pipeline( - text=example.data, - task_type=config.task_type, - model_name=config.model_name, - model_attrs=config.model_attrs, - template=config.template, - label_mapping=config.label_mapping - ) + labels = execute_pipeline(example.data, config=config) labels.save(project, example, self.request.user) return Response({'ok': True}, status=status.HTTP_201_CREATED)