Browse Source

Extract AutomatedLabeling class from AutomatedCategoryLabeling

pull/1650/head
Hironsan 2 years ago
parent
commit
c6fd2ea494
1 changed files with 15 additions and 4 deletions
  1. 19
      backend/auto_labeling/views.py

19
backend/auto_labeling/views.py

@ -1,3 +1,4 @@
import abc
import json import json
from typing import List from typing import List
@ -17,7 +18,7 @@ from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from api.models import Example, Project, Category, CategoryType
from api.models import Example, Project, Category, CategoryType, Annotation
from members.permissions import IsInProjectOrAdmin, IsProjectAdmin from members.permissions import IsInProjectOrAdmin, IsProjectAdmin
from .pipeline.execution import execute_pipeline from .pipeline.execution import execute_pipeline
from .exceptions import (AutoLabelingPermissionDenied, from .exceptions import (AutoLabelingPermissionDenied,
@ -232,15 +233,16 @@ class AutomatedDataLabeling(generics.CreateAPIView):
return labels return labels
class AutomatedCategoryLabeling(generics.CreateAPIView):
class AutomatedLabeling(abc.ABC, generics.CreateAPIView):
permission_classes = [IsAuthenticated & IsInProjectOrAdmin] permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
swagger_schema = None swagger_schema = None
model = Category
model = None
task_type = None
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
project = get_object_or_404(Project, pk=self.kwargs['project_id']) project = get_object_or_404(Project, pk=self.kwargs['project_id'])
example = get_object_or_404(Example, pk=self.kwargs['example_id']) example = get_object_or_404(Example, pk=self.kwargs['example_id'])
configs = AutoLabelingConfig.objects.filter(task_type='Category')
configs = AutoLabelingConfig.objects.filter(task_type=self.task_type)
for config in configs: for config in configs:
labels = execute_pipeline( labels = execute_pipeline(
text=example.data, text=example.data,
@ -255,6 +257,15 @@ class AutomatedCategoryLabeling(generics.CreateAPIView):
self.model.objects.bulk_create(labels) self.model.objects.bulk_create(labels)
return Response({'ok': True}, status=status.HTTP_201_CREATED) return Response({'ok': True}, status=status.HTTP_201_CREATED)
@abc.abstractmethod
def transform(self, labels, example: Example, project: Project) -> List[Annotation]:
raise NotImplementedError('Please implement this method in the subclass')
class AutomatedCategoryLabeling(AutomatedLabeling):
model = Category
task_type = 'Category'
def transform(self, labels, example: Example, project: Project) -> List[Category]: def transform(self, labels, example: Example, project: Project) -> List[Category]:
mapping = { mapping = {
c.text: c for c in CategoryType.objects.filter(project=project) c.text: c for c in CategoryType.objects.filter(project=project)

Loading…
Cancel
Save