You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

113 lines
3.4 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from functools import partial
  2. from typing import Type
  3. from django.core.exceptions import ValidationError
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import generics, status
  6. from rest_framework.permissions import IsAuthenticated
  7. from rest_framework.response import Response
  8. from .permissions import CanEditLabel
  9. from .serializers import (
  10. CategorySerializer,
  11. RelationSerializer,
  12. SpanSerializer,
  13. TextLabelSerializer,
  14. )
  15. from labels.models import Category, Label, Relation, Span, TextLabel
  16. from projects.models import Project
  17. from projects.permissions import IsProjectMember
  class BaseListAPI(generics.ListCreateAPIView):
      """List, create and bulk-delete the annotations attached to one example.

      Subclasses set ``label_class`` (the annotation model queried here) and
      ``serializer_class``. The ``project_id`` and ``example_id`` URL kwargs
      select the project and the example whose annotations are handled.
      Results are not paginated and the view is hidden from the API schema.
      """

      label_class: Type[Label]
      pagination_class = None
      permission_classes = [IsAuthenticated & IsProjectMember]
      swagger_schema = None

      @property
      def project(self):
          """Return the ``Project`` named by the ``project_id`` URL kwarg (Http404 if missing)."""
          return get_object_or_404(Project, pk=self.kwargs["project_id"])

      def get_queryset(self):
          """Return the example's annotations of type ``label_class``.

          If the project does not have ``collaborative_annotation`` enabled, the
          result is further restricted to annotations owned by the requesting user.
          """
          queryset = self.label_class.objects.filter(example=self.kwargs["example_id"])
          if not self.project.collaborative_annotation:
              queryset = queryset.filter(user=self.request.user)
          return queryset

      def create(self, request, *args, **kwargs):
          """Create an annotation on the example given in the URL.

          A Django ``ValidationError`` raised while creating is converted into a
          400 response of the form ``{"detail": [<messages>]}``.
          """
          # Force the example from the URL into the payload so the serializer
          # sees it regardless of what the client sent.
          # NOTE(review): this mutates request.data in place and assumes the
          # parsed body is a mutable mapping (e.g. a JSON object); a form-encoded
          # QueryDict would be immutable — confirm only JSON reaches this view.
          request.data["example"] = self.kwargs["example_id"]
          try:
              # Unpack so the parent receives exactly the arguments this view got.
              return super().create(request, *args, **kwargs)
          except ValidationError as err:
              return Response({"detail": err.messages}, status=status.HTTP_400_BAD_REQUEST)

      def perform_create(self, serializer):
          """Save the new annotation bound to the URL's example and the requesting user."""
          serializer.save(example_id=self.kwargs["example_id"], user=self.request.user)

      def delete(self, request, *args, **kwargs):
          """Delete every annotation in ``get_queryset()``'s scope and return 204 No Content."""
          self.get_queryset().delete()
          return Response(status=status.HTTP_204_NO_CONTENT)
  class BaseDetailAPI(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete a single annotation selected by ``annotation_id``.

      Subclasses supply ``queryset`` and ``serializer_class``. When the project
      is not in collaborative-annotation mode, an extra ``CanEditLabel``
      permission, constructed with this view's ``queryset``, is required.
      The view is hidden from the API schema.
      """

      swagger_schema = None
      lookup_url_kwarg = "annotation_id"

      @property
      def project(self):
          """Look up the ``Project`` given by the ``project_id`` URL kwarg (Http404 if absent)."""
          project_id = self.kwargs["project_id"]
          return get_object_or_404(Project, pk=project_id)

      def get_permissions(self):
          """Pick the permission set for the project's annotation mode, then let DRF instantiate it."""
          required = IsAuthenticated & IsProjectMember
          if not self.project.collaborative_annotation:
              # Non-collaborative projects add CanEditLabel; partial binds the
              # queryset so DRF can construct it with no arguments.
              required = required & partial(CanEditLabel, self.queryset)
          self.permission_classes = [required]
          return super().get_permissions()
  class CategoryListAPI(BaseListAPI):
      """List and create ``Category`` annotations on an example."""

      label_class = Category
      serializer_class = CategorySerializer

      def create(self, request, *args, **kwargs):
          """Create a category, first clearing existing ones in single-class projects.

          When ``project.single_class_classification`` is set, every category in
          ``get_queryset()``'s scope is deleted before the new one is created
          (presumably so the example keeps only one category).
          """
          if self.project.single_class_classification:
              # NOTE(review): the old categories are removed before the new one is
              # validated, so a rejected request still deletes them — consider
              # running delete + create inside a transaction.
              self.get_queryset().delete()
          # Unpack so the parent receives exactly the arguments this view got.
          return super().create(request, *args, **kwargs)
  class CategoryDetailAPI(BaseDetailAPI):
      """Retrieve, update or delete one ``Category`` annotation."""

      serializer_class = CategorySerializer
      queryset = Category.objects.all()
  class SpanListAPI(BaseListAPI):
      """List, create and bulk-delete ``Span`` annotations on an example."""

      serializer_class = SpanSerializer
      label_class = Span
  class SpanDetailAPI(BaseDetailAPI):
      """Retrieve, update or delete one ``Span`` annotation."""

      serializer_class = SpanSerializer
      queryset = Span.objects.all()
  class TextLabelListAPI(BaseListAPI):
      """List, create and bulk-delete ``TextLabel`` annotations on an example."""

      serializer_class = TextLabelSerializer
      label_class = TextLabel
  class TextLabelDetailAPI(BaseDetailAPI):
      """Retrieve, update or delete one ``TextLabel`` annotation."""

      serializer_class = TextLabelSerializer
      queryset = TextLabel.objects.all()
  class RelationList(BaseListAPI):
      """List, create and bulk-delete ``Relation`` annotations on an example."""

      serializer_class = RelationSerializer
      label_class = Relation
  class RelationDetail(BaseDetailAPI):
      """Retrieve, update or delete one ``Relation`` annotation."""

      serializer_class = RelationSerializer
      queryset = Relation.objects.all()