From c6fd2ea494d92368bd8f12bb287fdb0cde84e0e1 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Tue, 25 Jan 2022 14:28:58 +0900 Subject: [PATCH] Extract AutomatedLabeling class from AutomatedCategoryLabeling --- backend/auto_labeling/views.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/backend/auto_labeling/views.py b/backend/auto_labeling/views.py index 63515752..6e608573 100644 --- a/backend/auto_labeling/views.py +++ b/backend/auto_labeling/views.py @@ -1,3 +1,4 @@ +import abc import json from typing import List @@ -17,7 +18,7 @@ from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView -from api.models import Example, Project, Category, CategoryType +from api.models import Example, Project, Category, CategoryType, Annotation from members.permissions import IsInProjectOrAdmin, IsProjectAdmin from .pipeline.execution import execute_pipeline from .exceptions import (AutoLabelingPermissionDenied, @@ -232,15 +233,16 @@ class AutomatedDataLabeling(generics.CreateAPIView): return labels -class AutomatedCategoryLabeling(generics.CreateAPIView): +class AutomatedLabeling(abc.ABC, generics.CreateAPIView): permission_classes = [IsAuthenticated & IsInProjectOrAdmin] swagger_schema = None - model = Category + model = None + task_type = None def create(self, request, *args, **kwargs): project = get_object_or_404(Project, pk=self.kwargs['project_id']) example = get_object_or_404(Example, pk=self.kwargs['example_id']) - configs = AutoLabelingConfig.objects.filter(task_type='Category') + configs = AutoLabelingConfig.objects.filter(task_type=self.task_type) for config in configs: labels = execute_pipeline( text=example.data, @@ -255,6 +257,15 @@ class AutomatedCategoryLabeling(generics.CreateAPIView): self.model.objects.bulk_create(labels) return Response({'ok': True}, status=status.HTTP_201_CREATED) + @abc.abstractmethod + def transform(self, labels, example: Example, project: Project) -> List[Annotation]: + raise NotImplementedError('Please implement this method in the subclass') + + +class AutomatedCategoryLabeling(AutomatedLabeling): + model = Category + task_type = 'Category' + def transform(self, labels, example: Example, project: Project) -> List[Category]: mapping = { c.text: c for c in CategoryType.objects.filter(project=project)