Browse Source

Extract AutomatedLabeling class from AutomatedCategoryLabeling

pull/1650/head
Hironsan 2 years ago
parent
commit
c6fd2ea494
1 changed files with 15 additions and 4 deletions
  1. 19
      backend/auto_labeling/views.py

19
backend/auto_labeling/views.py

@ -1,3 +1,4 @@
import abc
import json
from typing import List
@ -17,7 +18,7 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from api.models import Example, Project, Category, CategoryType
from api.models import Example, Project, Category, CategoryType, Annotation
from members.permissions import IsInProjectOrAdmin, IsProjectAdmin
from .pipeline.execution import execute_pipeline
from .exceptions import (AutoLabelingPermissionDenied,
@ -232,15 +233,16 @@ class AutomatedDataLabeling(generics.CreateAPIView):
return labels
class AutomatedCategoryLabeling(generics.CreateAPIView):
class AutomatedLabeling(abc.ABC, generics.CreateAPIView):
permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
swagger_schema = None
model = Category
model = None
task_type = None
def create(self, request, *args, **kwargs):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
example = get_object_or_404(Example, pk=self.kwargs['example_id'])
configs = AutoLabelingConfig.objects.filter(task_type='Category')
configs = AutoLabelingConfig.objects.filter(task_type=self.task_type)
for config in configs:
labels = execute_pipeline(
text=example.data,
@ -255,6 +257,15 @@ class AutomatedCategoryLabeling(generics.CreateAPIView):
self.model.objects.bulk_create(labels)
return Response({'ok': True}, status=status.HTTP_201_CREATED)
@abc.abstractmethod
def transform(self, labels, example: Example, project: Project) -> List[Annotation]:
raise NotImplementedError('Please implement this method in the subclass')
class AutomatedCategoryLabeling(AutomatedLabeling):
model = Category
task_type = 'Category'
def transform(self, labels, example: Example, project: Project) -> List[Category]:
mapping = {
c.text: c for c in CategoryType.objects.filter(project=project)

Loading…
Cancel
Save