|
|
@ -1,3 +1,4 @@ |
|
|
|
import abc |
|
|
|
import json |
|
|
|
from typing import List |
|
|
|
|
|
|
@ -17,7 +18,7 @@ from rest_framework.request import Request |
|
|
|
from rest_framework.response import Response |
|
|
|
from rest_framework.views import APIView |
|
|
|
|
|
|
|
from api.models import Example, Project, Category, CategoryType |
|
|
|
from api.models import Example, Project, Category, CategoryType, Annotation |
|
|
|
from members.permissions import IsInProjectOrAdmin, IsProjectAdmin |
|
|
|
from .pipeline.execution import execute_pipeline |
|
|
|
from .exceptions import (AutoLabelingPermissionDenied, |
|
|
@ -232,15 +233,16 @@ class AutomatedDataLabeling(generics.CreateAPIView): |
|
|
|
return labels |
|
|
|
|
|
|
|
|
|
|
|
class AutomatedCategoryLabeling(generics.CreateAPIView): |
|
|
|
class AutomatedLabeling(abc.ABC, generics.CreateAPIView): |
|
|
|
permission_classes = [IsAuthenticated & IsInProjectOrAdmin] |
|
|
|
swagger_schema = None |
|
|
|
model = Category |
|
|
|
model = None |
|
|
|
task_type = None |
|
|
|
|
|
|
|
def create(self, request, *args, **kwargs): |
|
|
|
project = get_object_or_404(Project, pk=self.kwargs['project_id']) |
|
|
|
example = get_object_or_404(Example, pk=self.kwargs['example_id']) |
|
|
|
configs = AutoLabelingConfig.objects.filter(task_type='Category') |
|
|
|
configs = AutoLabelingConfig.objects.filter(task_type=self.task_type) |
|
|
|
for config in configs: |
|
|
|
labels = execute_pipeline( |
|
|
|
text=example.data, |
|
|
@ -255,6 +257,15 @@ class AutomatedCategoryLabeling(generics.CreateAPIView): |
|
|
|
self.model.objects.bulk_create(labels) |
|
|
|
return Response({'ok': True}, status=status.HTTP_201_CREATED) |
|
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|
def transform(self, labels, example: Example, project: Project) -> List[Annotation]: |
|
|
|
raise NotImplementedError('Please implement this method in the subclass') |
|
|
|
|
|
|
|
|
|
|
|
class AutomatedCategoryLabeling(AutomatedLabeling): |
|
|
|
model = Category |
|
|
|
task_type = 'Category' |
|
|
|
|
|
|
|
def transform(self, labels, example: Example, project: Project) -> List[Category]: |
|
|
|
mapping = { |
|
|
|
c.text: c for c in CategoryType.objects.filter(project=project) |
|
|
|