You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

143 lines
4.0 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from functools import partial
  2. from typing import Type
  3. from django.core.exceptions import ValidationError
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import generics, status
  6. from rest_framework.permissions import IsAuthenticated
  7. from rest_framework.response import Response
  8. from .permissions import CanEditLabel
  9. from .serializers import (
  10. BoundingBoxSerializer,
  11. CategorySerializer,
  12. RelationSerializer,
  13. SegmentationSerializer,
  14. SpanSerializer,
  15. TextLabelSerializer,
  16. )
  17. from labels.models import (
  18. BoundingBox,
  19. Category,
  20. Label,
  21. Relation,
  22. Segmentation,
  23. Span,
  24. TextLabel,
  25. )
  26. from projects.models import Project
  27. from projects.permissions import IsProjectMember
  28. class BaseListAPI(generics.ListCreateAPIView):
  29. label_class: Type[Label]
  30. pagination_class = None
  31. permission_classes = [IsAuthenticated & IsProjectMember]
  32. swagger_schema = None
  33. @property
  34. def project(self):
  35. return get_object_or_404(Project, pk=self.kwargs["project_id"])
  36. def get_queryset(self):
  37. queryset = self.label_class.objects.filter(example=self.kwargs["example_id"])
  38. if not self.project.collaborative_annotation:
  39. queryset = queryset.filter(user=self.request.user)
  40. return queryset
  41. def create(self, request, *args, **kwargs):
  42. request.data["example"] = self.kwargs["example_id"]
  43. try:
  44. response = super().create(request, args, kwargs)
  45. except ValidationError as err:
  46. response = Response({"detail": err.messages}, status=status.HTTP_400_BAD_REQUEST)
  47. return response
  48. def perform_create(self, serializer):
  49. serializer.save(example_id=self.kwargs["example_id"], user=self.request.user)
  50. def delete(self, request, *args, **kwargs):
  51. queryset = self.get_queryset()
  52. queryset.all().delete()
  53. return Response(status=status.HTTP_204_NO_CONTENT)
  54. class BaseDetailAPI(generics.RetrieveUpdateDestroyAPIView):
  55. lookup_url_kwarg = "annotation_id"
  56. swagger_schema = None
  57. @property
  58. def project(self):
  59. return get_object_or_404(Project, pk=self.kwargs["project_id"])
  60. def get_permissions(self):
  61. if self.project.collaborative_annotation:
  62. self.permission_classes = [IsAuthenticated & IsProjectMember]
  63. else:
  64. self.permission_classes = [IsAuthenticated & IsProjectMember & partial(CanEditLabel, self.queryset)]
  65. return super().get_permissions()
  66. class CategoryListAPI(BaseListAPI):
  67. label_class = Category
  68. serializer_class = CategorySerializer
  69. def create(self, request, *args, **kwargs):
  70. if self.project.single_class_classification:
  71. self.get_queryset().delete()
  72. return super().create(request, args, kwargs)
  73. class CategoryDetailAPI(BaseDetailAPI):
  74. queryset = Category.objects.all()
  75. serializer_class = CategorySerializer
  76. class SpanListAPI(BaseListAPI):
  77. label_class = Span
  78. serializer_class = SpanSerializer
  79. class SpanDetailAPI(BaseDetailAPI):
  80. queryset = Span.objects.all()
  81. serializer_class = SpanSerializer
  82. class TextLabelListAPI(BaseListAPI):
  83. label_class = TextLabel
  84. serializer_class = TextLabelSerializer
  85. class TextLabelDetailAPI(BaseDetailAPI):
  86. queryset = TextLabel.objects.all()
  87. serializer_class = TextLabelSerializer
  88. class RelationList(BaseListAPI):
  89. label_class = Relation
  90. serializer_class = RelationSerializer
  91. class RelationDetail(BaseDetailAPI):
  92. queryset = Relation.objects.all()
  93. serializer_class = RelationSerializer
  94. class BoundingBoxListAPI(BaseListAPI):
  95. label_class = BoundingBox
  96. serializer_class = BoundingBoxSerializer
  97. class BoundingBoxDetailAPI(BaseDetailAPI):
  98. queryset = BoundingBox.objects.all()
  99. serializer_class = BoundingBoxSerializer
  100. class SegmentationListAPI(BaseListAPI):
  101. label_class = Segmentation
  102. serializer_class = SegmentationSerializer
  103. class SegmentationDetailAPI(BaseDetailAPI):
  104. queryset = Segmentation.objects.all()
  105. serializer_class = SegmentationSerializer