from functools import partial from typing import Type from django.core.exceptions import ValidationError from django.shortcuts import get_object_or_404 from rest_framework import generics, status from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from .permissions import CanEditLabel from .serializers import ( BoundingBoxSerializer, CategorySerializer, RelationSerializer, SegmentationSerializer, SpanSerializer, TextLabelSerializer, ) from labels.models import ( BoundingBox, Category, Label, Relation, Segmentation, Span, TextLabel, ) from projects.models import Project from projects.permissions import IsProjectMember class BaseListAPI(generics.ListCreateAPIView): label_class: Type[Label] pagination_class = None permission_classes = [IsAuthenticated & IsProjectMember] swagger_schema = None @property def project(self): return get_object_or_404(Project, pk=self.kwargs["project_id"]) def get_queryset(self): queryset = self.label_class.objects.filter(example=self.kwargs["example_id"]) if not self.project.collaborative_annotation: queryset = queryset.filter(user=self.request.user) return queryset def create(self, request, *args, **kwargs): request.data["example"] = self.kwargs["example_id"] try: response = super().create(request, args, kwargs) except ValidationError as err: response = Response({"detail": err.messages}, status=status.HTTP_400_BAD_REQUEST) return response def perform_create(self, serializer): serializer.save(example_id=self.kwargs["example_id"], user=self.request.user) def delete(self, request, *args, **kwargs): queryset = self.get_queryset() queryset.all().delete() return Response(status=status.HTTP_204_NO_CONTENT) class BaseDetailAPI(generics.RetrieveUpdateDestroyAPIView): lookup_url_kwarg = "annotation_id" swagger_schema = None @property def project(self): return get_object_or_404(Project, pk=self.kwargs["project_id"]) def get_permissions(self): if self.project.collaborative_annotation: self.permission_classes = [IsAuthenticated & IsProjectMember] else: self.permission_classes = [IsAuthenticated & IsProjectMember & partial(CanEditLabel, self.queryset)] return super().get_permissions() class CategoryListAPI(BaseListAPI): label_class = Category serializer_class = CategorySerializer def create(self, request, *args, **kwargs): if self.project.single_class_classification: self.get_queryset().delete() return super().create(request, args, kwargs) class CategoryDetailAPI(BaseDetailAPI): queryset = Category.objects.all() serializer_class = CategorySerializer class SpanListAPI(BaseListAPI): label_class = Span serializer_class = SpanSerializer class SpanDetailAPI(BaseDetailAPI): queryset = Span.objects.all() serializer_class = SpanSerializer class TextLabelListAPI(BaseListAPI): label_class = TextLabel serializer_class = TextLabelSerializer class TextLabelDetailAPI(BaseDetailAPI): queryset = TextLabel.objects.all() serializer_class = TextLabelSerializer class RelationList(BaseListAPI): label_class = Relation serializer_class = RelationSerializer class RelationDetail(BaseDetailAPI): queryset = Relation.objects.all() serializer_class = RelationSerializer class BoundingBoxListAPI(BaseListAPI): label_class = BoundingBox serializer_class = BoundingBoxSerializer class BoundingBoxDetailAPI(BaseDetailAPI): queryset = BoundingBox.objects.all() serializer_class = BoundingBoxSerializer class SegmentationListAPI(BaseListAPI): label_class = Segmentation serializer_class = SegmentationSerializer class SegmentationDetailAPI(BaseDetailAPI): queryset = Segmentation.objects.all() serializer_class = SegmentationSerializer