You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

129 lines
4.1 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from functools import partial
  2. from typing import Type
  3. from django.core.exceptions import ValidationError
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import generics, status
  6. from rest_framework.permissions import IsAuthenticated
  7. from rest_framework.response import Response
  8. from .permissions import CanEditLabel
  9. from .serializers import (
  10. CategorySerializer,
  11. RelationSerializer,
  12. SpanSerializer,
  13. TextLabelSerializer,
  14. )
  15. from labels.models import Category, Label, Relation, Span, TextLabel
  16. from projects.models import Project
  17. from projects.permissions import IsProjectMember
  class BaseListAPI(generics.ListCreateAPIView):
      """List, create and bulk-delete the labels attached to a single example.

      Subclasses provide ``label_class`` (the model that is queried) and a
      ``serializer_class``. The route must supply the ``project_id`` and
      ``example_id`` URL keyword arguments.
      """

      label_class: Type[Label]
      pagination_class = None
      permission_classes = [IsAuthenticated & IsProjectMember]
      swagger_schema = None

      @property
      def project(self):
          """Return the project identified by ``project_id``, or raise a 404."""
          return get_object_or_404(Project, pk=self.kwargs["project_id"])

      def get_queryset(self):
          """Return the labels of the example named in the URL.

          Unless the project has ``collaborative_annotation`` enabled, the
          result is narrowed to labels owned by the requesting user.
          """
          queryset = self.label_class.objects.filter(example=self.kwargs["example_id"])
          if not self.project.collaborative_annotation:
              queryset = queryset.filter(user=self.request.user)
          return queryset

      def create(self, request, *args, **kwargs):
          """Create a label for the example named in the URL.

          A Django ``ValidationError`` raised during creation is reported as
          an HTTP 400 response whose ``detail`` holds the error messages.
          """
          # NOTE(review): item assignment on request.data fails when the payload
          # is an immutable QueryDict (form-encoded bodies) -- confirm that
          # clients always send JSON.
          request.data["example"] = self.kwargs["example_id"]
          try:
              # Unpack the extra arguments; forwarding the bare tuple and dict
              # would hand them to the parent as two positional arguments.
              return super().create(request, *args, **kwargs)
          except ValidationError as err:
              return Response({"detail": err.messages}, status=status.HTTP_400_BAD_REQUEST)

      def perform_create(self, serializer):
          """Save the new label bound to the URL's example and the requesting user."""
          serializer.save(example_id=self.kwargs["example_id"], user=self.request.user)

      def delete(self, request, *args, **kwargs):
          """Delete every label returned by ``get_queryset`` and answer with HTTP 204."""
          self.get_queryset().delete()
          return Response(status=status.HTTP_204_NO_CONTENT)
  class BaseDetailAPI(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete one label, looked up by ``annotation_id``.

      Subclasses provide ``queryset`` and ``serializer_class``.
      """

      swagger_schema = None
      lookup_url_kwarg = "annotation_id"

      @property
      def project(self):
          """Return the project identified by ``project_id``, or raise a 404."""
          project_pk = self.kwargs["project_id"]
          return get_object_or_404(Project, pk=project_pk)

      def get_permissions(self):
          """Choose the permission set according to the project's annotation mode.

          Authenticated project members are always required. When the project
          is not collaborative, ``CanEditLabel`` (bound to this view's
          queryset) is required as well.
          """
          required = IsAuthenticated & IsProjectMember
          if not self.project.collaborative_annotation:
              # partial() defers construction so the permission class receives
              # the queryset when it is instantiated.
              required = required & partial(CanEditLabel, self.queryset)
          self.permission_classes = [required]
          return super().get_permissions()
  class CategoryListAPI(BaseListAPI):
      """List, create and bulk-delete ``Category`` labels of an example."""

      label_class = Category
      serializer_class = CategorySerializer

      def create(self, request, *args, **kwargs):
          """Create a category, first clearing existing ones in single-class projects.

          When the project has ``single_class_classification`` set, every
          category returned by ``get_queryset`` is deleted before the new one
          is created.
          """
          if self.project.single_class_classification:
              self.get_queryset().delete()
          # Unpack the extra arguments; forwarding the bare tuple and dict
          # would hand them to the parent as two positional arguments.
          return super().create(request, *args, **kwargs)
  class CategoryDetailAPI(BaseDetailAPI):
      """Detail endpoint for a single ``Category`` label."""

      serializer_class = CategorySerializer
      queryset = Category.objects.all()
  class SpanListAPI(BaseListAPI):
      """List endpoint for the ``Span`` labels of an example."""

      serializer_class = SpanSerializer
      label_class = Span
  class SpanDetailAPI(BaseDetailAPI):
      """Detail endpoint for a single ``Span`` label."""

      serializer_class = SpanSerializer
      queryset = Span.objects.all()
  class TextLabelListAPI(BaseListAPI):
      """List endpoint for the ``TextLabel`` labels of an example."""

      serializer_class = TextLabelSerializer
      label_class = TextLabel
  class TextLabelDetailAPI(BaseDetailAPI):
      """Detail endpoint for a single ``TextLabel`` label."""

      serializer_class = TextLabelSerializer
      queryset = TextLabel.objects.all()
  class RelationList(generics.ListCreateAPIView):
      """List, create and bulk-delete the relations of a single project.

      The route must supply the ``project_id`` URL keyword argument.
      """

      serializer_class = RelationSerializer
      pagination_class = None
      permission_classes = [IsAuthenticated & IsProjectMember]

      @property
      def project(self):
          """Return the project identified by ``project_id``, or raise a 404."""
          return get_object_or_404(Project, pk=self.kwargs["project_id"])

      def get_queryset(self):
          """Return the relations reachable through the project's ``annotation_relations``."""
          return self.project.annotation_relations

      def perform_create(self, serializer):
          """Save the new relation attached to the project named in the URL."""
          serializer.save(project=self.project)

      def delete(self, request, *args, **kwargs):
          """Delete the relations whose primary keys are listed under ``ids``.

          Only relations belonging to the project in the URL are removed; ids
          that point at other projects are ignored. A request body without an
          ``ids`` entry is answered with HTTP 400.
          """
          try:
              delete_ids = request.data["ids"]
          except KeyError:
              return Response({"detail": "'ids' is required."}, status=status.HTTP_400_BAD_REQUEST)
          # Go through get_queryset() rather than Relation.objects so that a
          # member of this project cannot delete another project's relations.
          self.get_queryset().filter(pk__in=delete_ids).delete()
          return Response(status=status.HTTP_204_NO_CONTENT)
  class RelationDetail(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete one relation, looked up by ``annotation_id``."""

      permission_classes = [IsAuthenticated & IsProjectMember]
      lookup_url_kwarg = "annotation_id"
      serializer_class = RelationSerializer
      queryset = Relation.objects.all()