You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

143 lines
4.0 KiB

from functools import partial
from typing import Type

from django.core.exceptions import ValidationError
from django.db import transaction
from django.shortcuts import get_object_or_404
from rest_framework import generics, status
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response

from labels.models import (
    BoundingBox,
    Category,
    Label,
    Relation,
    Segmentation,
    Span,
    TextLabel,
)
from projects.models import Project
from projects.permissions import IsProjectMember

from .permissions import CanEditLabel
from .serializers import (
    BoundingBoxSerializer,
    CategorySerializer,
    RelationSerializer,
    SegmentationSerializer,
    SpanSerializer,
    TextLabelSerializer,
)
class BaseListAPI(generics.ListCreateAPIView):
    """List, create, and bulk-delete labels attached to a single example.

    Subclasses set ``label_class`` (the label model to query) and
    ``serializer_class``. Results are not paginated. When the project does
    not enable collaborative annotation, every operation is restricted to
    the requesting user's own labels.
    """

    label_class: Type[Label]
    pagination_class = None
    permission_classes = [IsAuthenticated & IsProjectMember]
    swagger_schema = None

    @property
    def project(self):
        """Return the project named by the ``project_id`` URL kwarg, or raise Http404."""
        return get_object_or_404(Project, pk=self.kwargs["project_id"])

    def get_queryset(self):
        """Return this example's labels, limited to the current user unless collaborative."""
        queryset = self.label_class.objects.filter(example=self.kwargs["example_id"])
        if not self.project.collaborative_annotation:
            queryset = queryset.filter(user=self.request.user)
        return queryset

    def create(self, request, *args, **kwargs):
        """Create a label on the example given in the URL.

        The example id from the URL overrides any ``example`` value in the
        request body. A Django ``ValidationError`` raised during creation is
        returned as a 400 response whose ``detail`` lists the error messages.
        """
        # NOTE(review): this mutates request.data in place. That works for a
        # dict (JSON body) but presumably fails for an immutable QueryDict
        # (form-encoded / multipart body) — confirm which content types
        # clients send.
        request.data["example"] = self.kwargs["example_id"]
        try:
            # Forward the extra arguments unpacked, matching the
            # create(request, *args, **kwargs) signature of the parent.
            response = super().create(request, *args, **kwargs)
        except ValidationError as err:
            response = Response({"detail": err.messages}, status=status.HTTP_400_BAD_REQUEST)
        return response

    def perform_create(self, serializer):
        """Save the new label bound to the URL's example and the requesting user."""
        serializer.save(example_id=self.kwargs["example_id"], user=self.request.user)

    def delete(self, request, *args, **kwargs):
        """Delete every label returned by ``get_queryset`` and answer 204 No Content."""
        self.get_queryset().delete()
        return Response(status=status.HTTP_204_NO_CONTENT)
class BaseDetailAPI(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update, or delete one label identified by ``annotation_id``.

    Subclasses provide ``queryset`` and ``serializer_class``. In projects
    without collaborative annotation, a ``CanEditLabel`` check built from
    this view's ``queryset`` is added on top of the membership check.
    """

    swagger_schema = None
    lookup_url_kwarg = "annotation_id"

    @property
    def project(self):
        """The project from the ``project_id`` URL kwarg (Http404 if it does not exist)."""
        project_id = self.kwargs["project_id"]
        return get_object_or_404(Project, pk=project_id)

    def get_permissions(self):
        """Pick the permission expression for this request, then let DRF instantiate it."""
        # NOTE(review): objects are looked up in ``queryset`` by annotation_id
        # alone; nothing visible here scopes that lookup to ``project_id``.
        # Confirm that IsProjectMember / CanEditLabel (defined elsewhere)
        # prevent access to labels that belong to other projects.
        combined = IsAuthenticated & IsProjectMember
        if not self.project.collaborative_annotation:
            combined = combined & partial(CanEditLabel, self.queryset)
        self.permission_classes = [combined]
        return super().get_permissions()
class CategoryListAPI(BaseListAPI):
    """List, create, and bulk-delete ``Category`` labels on one example.

    In a project with ``single_class_classification`` enabled, creating a
    category replaces the categories currently returned by ``get_queryset``.
    """

    label_class = Category
    serializer_class = CategorySerializer

    def create(self, request, *args, **kwargs):
        """Create a category, first clearing existing ones in single-class projects.

        Clearing and creation run in one transaction, which is rolled back
        when creation does not succeed. A rejected request therefore leaves
        the example's current categories untouched instead of deleting them.
        """
        with transaction.atomic():
            if self.project.single_class_classification:
                self.get_queryset().delete()
            # Forward the extra arguments unpacked, matching the parent's
            # create(request, *args, **kwargs) signature.
            response = super().create(request, *args, **kwargs)
            # BaseListAPI.create turns a Django ValidationError into a 400
            # response instead of raising, so atomic() would not roll back on
            # its own; request the rollback explicitly for any non-2xx result.
            if not status.is_success(response.status_code):
                transaction.set_rollback(True)
        return response
class CategoryDetailAPI(BaseDetailAPI):
    """Retrieve, update, or delete a single ``Category`` label."""

    serializer_class = CategorySerializer
    queryset = Category.objects.all()
class SpanListAPI(BaseListAPI):
    """List, create, and bulk-delete ``Span`` labels on one example."""

    serializer_class = SpanSerializer
    label_class = Span
class SpanDetailAPI(BaseDetailAPI):
    """Retrieve, update, or delete a single ``Span`` label."""

    serializer_class = SpanSerializer
    queryset = Span.objects.all()
class TextLabelListAPI(BaseListAPI):
    """List, create, and bulk-delete ``TextLabel`` labels on one example."""

    serializer_class = TextLabelSerializer
    label_class = TextLabel
class TextLabelDetailAPI(BaseDetailAPI):
    """Retrieve, update, or delete a single ``TextLabel`` label."""

    serializer_class = TextLabelSerializer
    queryset = TextLabel.objects.all()
class RelationList(BaseListAPI):
    """List, create, and bulk-delete ``Relation`` labels on one example."""

    serializer_class = RelationSerializer
    label_class = Relation
class RelationDetail(BaseDetailAPI):
    """Retrieve, update, or delete a single ``Relation`` label."""

    serializer_class = RelationSerializer
    queryset = Relation.objects.all()
class BoundingBoxListAPI(BaseListAPI):
    """List, create, and bulk-delete ``BoundingBox`` labels on one example."""

    serializer_class = BoundingBoxSerializer
    label_class = BoundingBox
class BoundingBoxDetailAPI(BaseDetailAPI):
    """Retrieve, update, or delete a single ``BoundingBox`` label."""

    serializer_class = BoundingBoxSerializer
    queryset = BoundingBox.objects.all()
class SegmentationListAPI(BaseListAPI):
    """List, create, and bulk-delete ``Segmentation`` labels on one example."""

    serializer_class = SegmentationSerializer
    label_class = Segmentation
class SegmentationDetailAPI(BaseDetailAPI):
    """Retrieve, update, or delete a single ``Segmentation`` label."""

    serializer_class = SegmentationSerializer
    queryset = Segmentation.objects.all()