You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
4.1 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from functools import partial
  2. from typing import Type
  3. from django.core.exceptions import ValidationError
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import generics, status
  6. from rest_framework.permissions import IsAuthenticated
  7. from rest_framework.response import Response
  8. from projects.models import Project
  9. from labels.models import Category, Label, Span, TextLabel, Relation
  10. from projects.permissions import IsProjectMember
  11. from .permissions import CanEditLabel
  12. from .serializers import CategorySerializer, SpanSerializer, TextLabelSerializer, RelationSerializer
  13. class BaseListAPI(generics.ListCreateAPIView):
  14. label_class: Type[Label]
  15. pagination_class = None
  16. permission_classes = [IsAuthenticated & IsProjectMember]
  17. swagger_schema = None
  18. @property
  19. def project(self):
  20. return get_object_or_404(Project, pk=self.kwargs["project_id"])
  21. def get_queryset(self):
  22. queryset = self.label_class.objects.filter(example=self.kwargs["example_id"])
  23. if not self.project.collaborative_annotation:
  24. queryset = queryset.filter(user=self.request.user)
  25. return queryset
  26. def create(self, request, *args, **kwargs):
  27. request.data["example"] = self.kwargs["example_id"]
  28. try:
  29. response = super().create(request, args, kwargs)
  30. except ValidationError as err:
  31. response = Response({"detail": err.messages}, status=status.HTTP_400_BAD_REQUEST)
  32. return response
  33. def perform_create(self, serializer):
  34. serializer.save(example_id=self.kwargs["example_id"], user=self.request.user)
  35. def delete(self, request, *args, **kwargs):
  36. queryset = self.get_queryset()
  37. queryset.all().delete()
  38. return Response(status=status.HTTP_204_NO_CONTENT)
  39. class BaseDetailAPI(generics.RetrieveUpdateDestroyAPIView):
  40. lookup_url_kwarg = "annotation_id"
  41. swagger_schema = None
  42. @property
  43. def project(self):
  44. return get_object_or_404(Project, pk=self.kwargs["project_id"])
  45. def get_permissions(self):
  46. if self.project.collaborative_annotation:
  47. self.permission_classes = [IsAuthenticated & IsProjectMember]
  48. else:
  49. self.permission_classes = [IsAuthenticated & IsProjectMember & partial(CanEditLabel, self.queryset)]
  50. return super().get_permissions()
  51. class CategoryListAPI(BaseListAPI):
  52. label_class = Category
  53. serializer_class = CategorySerializer
  54. def create(self, request, *args, **kwargs):
  55. if self.project.single_class_classification:
  56. self.get_queryset().delete()
  57. return super().create(request, args, kwargs)
  58. class CategoryDetailAPI(BaseDetailAPI):
  59. queryset = Category.objects.all()
  60. serializer_class = CategorySerializer
  61. class SpanListAPI(BaseListAPI):
  62. label_class = Span
  63. serializer_class = SpanSerializer
  64. class SpanDetailAPI(BaseDetailAPI):
  65. queryset = Span.objects.all()
  66. serializer_class = SpanSerializer
  67. class TextLabelListAPI(BaseListAPI):
  68. label_class = TextLabel
  69. serializer_class = TextLabelSerializer
  70. class TextLabelDetailAPI(BaseDetailAPI):
  71. queryset = TextLabel.objects.all()
  72. serializer_class = TextLabelSerializer
  73. class RelationList(generics.ListCreateAPIView):
  74. serializer_class = RelationSerializer
  75. pagination_class = None
  76. permission_classes = [IsAuthenticated & IsProjectMember]
  77. def get_queryset(self):
  78. project = get_object_or_404(Project, pk=self.kwargs["project_id"])
  79. return project.annotation_relations
  80. def perform_create(self, serializer):
  81. project = get_object_or_404(Project, pk=self.kwargs["project_id"])
  82. serializer.save(project=project)
  83. def delete(self, request, *args, **kwargs):
  84. delete_ids = request.data["ids"]
  85. Relation.objects.filter(pk__in=delete_ids).delete()
  86. return Response(status=status.HTTP_204_NO_CONTENT)
  87. class RelationDetail(generics.RetrieveUpdateDestroyAPIView):
  88. queryset = Relation.objects.all()
  89. serializer_class = RelationSerializer
  90. lookup_url_kwarg = "annotation_id"
  91. permission_classes = [IsAuthenticated & IsProjectMember]