You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

431 lines
16 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
  1. import collections
  2. import json
  3. from django.conf import settings
  4. from django.contrib.auth.models import User
  5. from django.db import transaction
  6. from django.db.utils import IntegrityError
  7. from django.shortcuts import get_object_or_404, redirect
  8. from django_filters.rest_framework import DjangoFilterBackend
  9. from django.db.models import Count, F, Q
  10. from libcloud.base import DriverType, get_driver
  11. from libcloud.storage.types import ContainerDoesNotExistError, ObjectDoesNotExistError
  12. from rest_framework import generics, filters, status
  13. from rest_framework.exceptions import ParseError, ValidationError
  14. from rest_framework.permissions import IsAuthenticated, IsAuthenticatedOrReadOnly
  15. from rest_framework.response import Response
  16. from rest_framework.views import APIView
  17. from rest_framework.parsers import MultiPartParser
  18. from rest_framework_csv.renderers import CSVRenderer
  19. from .filters import DocumentFilter
  20. from .models import Project, Label, Document, RoleMapping, Role
  21. from .permissions import IsProjectAdmin, IsAnnotatorAndReadOnly, IsAnnotator, IsAnnotationApproverAndReadOnly, IsOwnAnnotation, IsAnnotationApprover
  22. from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer, UserSerializer
  23. from .serializers import ProjectPolymorphicSerializer, RoleMappingSerializer, RoleSerializer
  24. from .utils import CSVParser, ExcelParser, JSONParser, PlainTextParser, CoNLLParser, AudioParser, iterable_to_io
  25. from .utils import JSONLRenderer
  26. from .utils import JSONPainter, CSVPainter
  27. IsInProjectReadOnlyOrAdmin = (IsAnnotatorAndReadOnly | IsAnnotationApproverAndReadOnly | IsProjectAdmin)
  28. IsInProjectOrAdmin = (IsAnnotator | IsAnnotationApprover | IsProjectAdmin)
  29. class Health(APIView):
  30. permission_classes = (IsAuthenticatedOrReadOnly,)
  31. def get(self, request, *args, **kwargs):
  32. return Response({'status': 'green'})
  33. class Me(APIView):
  34. permission_classes = (IsAuthenticated,)
  35. def get(self, request, *args, **kwargs):
  36. serializer = UserSerializer(request.user, context={'request': request})
  37. return Response(serializer.data)
  38. class Features(APIView):
  39. permission_classes = (IsAuthenticated,)
  40. def get(self, request, *args, **kwargs):
  41. return Response({
  42. 'cloud_upload': bool(settings.CLOUD_BROWSER_APACHE_LIBCLOUD_PROVIDER),
  43. })
  44. class ProjectList(generics.ListCreateAPIView):
  45. serializer_class = ProjectPolymorphicSerializer
  46. pagination_class = None
  47. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  48. def get_queryset(self):
  49. return self.request.user.projects
  50. def perform_create(self, serializer):
  51. serializer.save(users=[self.request.user])
  52. class ProjectDetail(generics.RetrieveUpdateDestroyAPIView):
  53. queryset = Project.objects.all()
  54. serializer_class = ProjectSerializer
  55. lookup_url_kwarg = 'project_id'
  56. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  57. class StatisticsAPI(APIView):
  58. pagination_class = None
  59. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  60. def get(self, request, *args, **kwargs):
  61. p = get_object_or_404(Project, pk=self.kwargs['project_id'])
  62. include = set(request.GET.getlist('include'))
  63. response = {}
  64. if not include or 'label' in include:
  65. label_count, user_count = self.label_per_data(p)
  66. response['label'] = label_count
  67. # TODO: Make user_label count chart
  68. response['user_label'] = user_count
  69. if not include or 'total' in include or 'remaining' in include or 'user' in include:
  70. progress = self.progress(project=p)
  71. response.update(progress)
  72. if include:
  73. response = {key: value for (key, value) in response.items() if key in include}
  74. return Response(response)
  75. @staticmethod
  76. def _get_user_completion_data(annotation_class, annotation_filter):
  77. all_annotation_objects = annotation_class.objects.filter(annotation_filter)
  78. set_user_data = collections.defaultdict(set)
  79. for ind_obj in all_annotation_objects.values('user__username', 'document__id'):
  80. set_user_data[ind_obj['user__username']].add(ind_obj['document__id'])
  81. return {i: len(set_user_data[i]) for i in set_user_data}
  82. def progress(self, project):
  83. docs = project.documents
  84. annotation_class = project.get_annotation_class()
  85. total = docs.count()
  86. annotation_filter = Q(document_id__in=docs.all())
  87. user_data = self._get_user_completion_data(annotation_class, annotation_filter)
  88. if not project.collaborative_annotation:
  89. annotation_filter &= Q(user_id=self.request.user)
  90. done = annotation_class.objects.filter(annotation_filter)\
  91. .aggregate(Count('document', distinct=True))['document__count']
  92. remaining = total - done
  93. return {'total': total, 'remaining': remaining, 'user': user_data}
  94. def label_per_data(self, project):
  95. annotation_class = project.get_annotation_class()
  96. return annotation_class.objects.get_label_per_data(project=project)
  97. class ApproveLabelsAPI(APIView):
  98. permission_classes = [IsAuthenticated & (IsAnnotationApprover | IsProjectAdmin)]
  99. def post(self, request, *args, **kwargs):
  100. approved = self.request.data.get('approved', True)
  101. document = get_object_or_404(Document, pk=self.kwargs['doc_id'])
  102. document.annotations_approved_by = self.request.user if approved else None
  103. document.save()
  104. return Response(DocumentSerializer(document).data)
  105. class LabelList(generics.ListCreateAPIView):
  106. serializer_class = LabelSerializer
  107. pagination_class = None
  108. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  109. def get_queryset(self):
  110. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  111. return project.labels
  112. def perform_create(self, serializer):
  113. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  114. serializer.save(project=project)
  115. class LabelDetail(generics.RetrieveUpdateDestroyAPIView):
  116. queryset = Label.objects.all()
  117. serializer_class = LabelSerializer
  118. lookup_url_kwarg = 'label_id'
  119. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  120. class DocumentList(generics.ListCreateAPIView):
  121. serializer_class = DocumentSerializer
  122. filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter)
  123. search_fields = ('text', )
  124. ordering_fields = ('created_at', 'updated_at', 'doc_annotations__updated_at',
  125. 'seq_annotations__updated_at', 'seq2seq_annotations__updated_at')
  126. filter_class = DocumentFilter
  127. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  128. def get_queryset(self):
  129. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  130. queryset = project.documents
  131. if project.randomize_document_order:
  132. queryset = queryset.annotate(sort_id=F('id') % self.request.user.id).order_by('sort_id')
  133. else:
  134. queryset = queryset.order_by('id')
  135. return queryset
  136. def perform_create(self, serializer):
  137. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  138. serializer.save(project=project)
  139. class DocumentDetail(generics.RetrieveUpdateDestroyAPIView):
  140. queryset = Document.objects.all()
  141. serializer_class = DocumentSerializer
  142. lookup_url_kwarg = 'doc_id'
  143. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  144. class AnnotationList(generics.ListCreateAPIView):
  145. pagination_class = None
  146. permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
  147. swagger_schema = None
  148. def get_serializer_class(self):
  149. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  150. self.serializer_class = project.get_annotation_serializer()
  151. return self.serializer_class
  152. def get_queryset(self):
  153. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  154. model = project.get_annotation_class()
  155. queryset = model.objects.filter(document=self.kwargs['doc_id'])
  156. if not project.collaborative_annotation:
  157. queryset = queryset.filter(user=self.request.user)
  158. return queryset
  159. def create(self, request, *args, **kwargs):
  160. self.check_single_class_classification(self.kwargs['project_id'], self.kwargs['doc_id'], request.user)
  161. request.data['document'] = self.kwargs['doc_id']
  162. return super().create(request, args, kwargs)
  163. def perform_create(self, serializer):
  164. serializer.save(document_id=self.kwargs['doc_id'], user=self.request.user)
  165. @staticmethod
  166. def check_single_class_classification(project_id, doc_id, user):
  167. project = get_object_or_404(Project, pk=project_id)
  168. if not project.single_class_classification:
  169. return
  170. model = project.get_annotation_class()
  171. annotations = model.objects.filter(document_id=doc_id)
  172. if not project.collaborative_annotation:
  173. annotations = annotations.filter(user=user)
  174. if annotations.exists():
  175. raise ValidationError('requested to create duplicate annotation for single-class-classification project')
  176. class AnnotationDetail(generics.RetrieveUpdateDestroyAPIView):
  177. lookup_url_kwarg = 'annotation_id'
  178. permission_classes = [IsAuthenticated & (((IsAnnotator & IsOwnAnnotation) | IsAnnotationApprover) | IsProjectAdmin)]
  179. swagger_schema = None
  180. def get_serializer_class(self):
  181. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  182. self.serializer_class = project.get_annotation_serializer()
  183. return self.serializer_class
  184. def get_queryset(self):
  185. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  186. model = project.get_annotation_class()
  187. self.queryset = model.objects.all()
  188. return self.queryset
  189. class TextUploadAPI(APIView):
  190. parser_classes = (MultiPartParser,)
  191. permission_classes = [IsAuthenticated & IsProjectAdmin]
  192. def post(self, request, *args, **kwargs):
  193. if 'file' not in request.data:
  194. raise ParseError('Empty content')
  195. self.save_file(
  196. user=request.user,
  197. file=request.data['file'],
  198. file_format=request.data['format'],
  199. project_id=kwargs['project_id'],
  200. )
  201. return Response(status=status.HTTP_201_CREATED)
  202. @classmethod
  203. def save_file(cls, user, file, file_format, project_id):
  204. project = get_object_or_404(Project, pk=project_id)
  205. parser = cls.select_parser(file_format)
  206. data = parser.parse(file)
  207. storage = project.get_storage(data)
  208. storage.save(user)
  209. @classmethod
  210. def select_parser(cls, file_format):
  211. if file_format == 'plain':
  212. return PlainTextParser()
  213. elif file_format == 'csv':
  214. return CSVParser()
  215. elif file_format == 'json':
  216. return JSONParser()
  217. elif file_format == 'conll':
  218. return CoNLLParser()
  219. elif file_format == 'excel':
  220. return ExcelParser()
  221. elif file_format == 'audio':
  222. return AudioParser()
  223. else:
  224. raise ValidationError('format {} is invalid.'.format(file_format))
  225. class CloudUploadAPI(APIView):
  226. permission_classes = TextUploadAPI.permission_classes
  227. def get(self, request, *args, **kwargs):
  228. try:
  229. project_id = request.query_params['project_id']
  230. file_format = request.query_params['upload_format']
  231. cloud_container = request.query_params['container']
  232. cloud_object = request.query_params['object']
  233. except KeyError as ex:
  234. raise ValidationError('query parameter {} is missing'.format(ex))
  235. try:
  236. cloud_file = self.get_cloud_object_as_io(cloud_container, cloud_object)
  237. except ContainerDoesNotExistError:
  238. raise ValidationError('cloud container {} does not exist'.format(cloud_container))
  239. except ObjectDoesNotExistError:
  240. raise ValidationError('cloud object {} does not exist'.format(cloud_object))
  241. TextUploadAPI.save_file(
  242. user=request.user,
  243. file=cloud_file,
  244. file_format=file_format,
  245. project_id=project_id,
  246. )
  247. next_url = request.query_params.get('next')
  248. if next_url == 'about:blank':
  249. return Response(data='', content_type='text/plain', status=status.HTTP_201_CREATED)
  250. if next_url:
  251. return redirect(next_url)
  252. return Response(status=status.HTTP_201_CREATED)
  253. @classmethod
  254. def get_cloud_object_as_io(cls, container_name, object_name):
  255. provider = settings.CLOUD_BROWSER_APACHE_LIBCLOUD_PROVIDER.lower()
  256. account = settings.CLOUD_BROWSER_APACHE_LIBCLOUD_ACCOUNT
  257. key = settings.CLOUD_BROWSER_APACHE_LIBCLOUD_SECRET_KEY
  258. driver = get_driver(DriverType.STORAGE, provider)
  259. client = driver(account, key)
  260. cloud_container = client.get_container(container_name)
  261. cloud_object = cloud_container.get_object(object_name)
  262. return iterable_to_io(cloud_object.as_stream())
  263. class TextDownloadAPI(APIView):
  264. permission_classes = TextUploadAPI.permission_classes
  265. renderer_classes = (CSVRenderer, JSONLRenderer)
  266. def get(self, request, *args, **kwargs):
  267. format = request.query_params.get('q')
  268. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  269. documents = project.documents.all()
  270. painter = self.select_painter(format)
  271. # json1 format prints text labels while json format prints annotations with label ids
  272. # json1 format - "labels": [[0, 15, "PERSON"], ..]
  273. # json format - "annotations": [{"label": 5, "start_offset": 0, "end_offset": 2, "user": 1},..]
  274. if format == "json1":
  275. labels = project.labels.all()
  276. data = JSONPainter.paint_labels(documents, labels)
  277. else:
  278. data = painter.paint(documents)
  279. return Response(data)
  280. def select_painter(self, format):
  281. if format == 'csv':
  282. return CSVPainter()
  283. elif format == 'json' or format == "json1":
  284. return JSONPainter()
  285. else:
  286. raise ValidationError('format {} is invalid.'.format(format))
  287. class Users(APIView):
  288. permission_classes = [IsAuthenticated & IsProjectAdmin]
  289. def get(self, request, *args, **kwargs):
  290. queryset = User.objects.all()
  291. serialized_data = UserSerializer(queryset, many=True).data
  292. return Response(serialized_data)
  293. class Roles(generics.ListCreateAPIView):
  294. serializer_class = RoleSerializer
  295. pagination_class = None
  296. permission_classes = [IsAuthenticated & IsProjectAdmin]
  297. queryset = Role.objects.all()
  298. class RoleMappingList(generics.ListCreateAPIView):
  299. serializer_class = RoleMappingSerializer
  300. pagination_class = None
  301. permission_classes = [IsAuthenticated & IsProjectAdmin]
  302. def get_queryset(self):
  303. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  304. return project.role_mappings
  305. def perform_create(self, serializer):
  306. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  307. serializer.save(project=project)
  308. class RoleMappingDetail(generics.RetrieveUpdateDestroyAPIView):
  309. queryset = RoleMapping.objects.all()
  310. serializer_class = RoleMappingSerializer
  311. lookup_url_kwarg = 'rolemapping_id'
  312. permission_classes = [IsAuthenticated & IsProjectAdmin]
  313. class LabelUploadAPI(APIView):
  314. parser_classes = (MultiPartParser,)
  315. permission_classes = [IsAuthenticated & IsProjectAdmin]
  316. @transaction.atomic
  317. def post(self, request, *args, **kwargs):
  318. if 'file' not in request.data:
  319. raise ParseError('Empty content')
  320. labels = json.load(request.data['file'])
  321. project = get_object_or_404(Project, pk=kwargs['project_id'])
  322. try:
  323. for label in labels:
  324. serializer = LabelSerializer(data=label)
  325. serializer.is_valid(raise_exception=True)
  326. serializer.save(project=project)
  327. return Response(status=status.HTTP_201_CREATED)
  328. except IntegrityError:
  329. content = {'error': 'IntegrityError: you cannot create a label with same name or shortkey.'}
  330. return Response(content, status=status.HTTP_400_BAD_REQUEST)