You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

413 lines
15 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
  1. import collections
  2. import json
  3. from django.conf import settings
  4. from django.contrib.auth.models import User
  5. from django.db import transaction
  6. from django.db.utils import IntegrityError
  7. from django.shortcuts import get_object_or_404, redirect
  8. from django_filters.rest_framework import DjangoFilterBackend
  9. from django.db.models import Count, F, Q
  10. from libcloud.base import DriverType, get_driver
  11. from libcloud.storage.types import ContainerDoesNotExistError, ObjectDoesNotExistError
  12. from rest_framework import generics, filters, status
  13. from rest_framework.exceptions import ParseError, ValidationError
  14. from rest_framework.permissions import IsAuthenticated, IsAuthenticatedOrReadOnly
  15. from rest_framework.response import Response
  16. from rest_framework.views import APIView
  17. from rest_framework.parsers import MultiPartParser
  18. from rest_framework_csv.renderers import CSVRenderer
  19. from .filters import DocumentFilter
  20. from .models import Project, Label, Document, RoleMapping, Role
  21. from .permissions import IsProjectAdmin, IsAnnotatorAndReadOnly, IsAnnotator, IsAnnotationApproverAndReadOnly, IsOwnAnnotation, IsAnnotationApprover
  22. from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer, UserSerializer
  23. from .serializers import ProjectPolymorphicSerializer, RoleMappingSerializer, RoleSerializer
  24. from .utils import CSVParser, ExcelParser, JSONParser, PlainTextParser, CoNLLParser, iterable_to_io
  25. from .utils import JSONLRenderer
  26. from .utils import JSONPainter, CSVPainter
  27. IsInProjectReadOnlyOrAdmin = (IsAnnotatorAndReadOnly | IsAnnotationApproverAndReadOnly | IsProjectAdmin)
  28. IsInProjectOrAdmin = (IsAnnotator | IsAnnotationApprover | IsProjectAdmin)
  29. class Health(APIView):
  30. permission_classes = (IsAuthenticatedOrReadOnly,)
  31. def get(self, request, *args, **kwargs):
  32. return Response({'status': 'green'})
  33. class Me(APIView):
  34. permission_classes = (IsAuthenticated,)
  35. def get(self, request, *args, **kwargs):
  36. serializer = UserSerializer(request.user, context={'request': request})
  37. return Response(serializer.data)
  38. class Features(APIView):
  39. permission_classes = (IsAuthenticated,)
  40. def get(self, request, *args, **kwargs):
  41. return Response({
  42. 'cloud_upload': bool(settings.CLOUD_BROWSER_APACHE_LIBCLOUD_PROVIDER),
  43. })
  44. class ProjectList(generics.ListCreateAPIView):
  45. serializer_class = ProjectPolymorphicSerializer
  46. pagination_class = None
  47. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  48. def get_queryset(self):
  49. return self.request.user.projects
  50. def perform_create(self, serializer):
  51. serializer.save(users=[self.request.user])
  52. class ProjectDetail(generics.RetrieveUpdateDestroyAPIView):
  53. queryset = Project.objects.all()
  54. serializer_class = ProjectSerializer
  55. lookup_url_kwarg = 'project_id'
  56. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  57. class StatisticsAPI(APIView):
  58. pagination_class = None
  59. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  60. def get(self, request, *args, **kwargs):
  61. p = get_object_or_404(Project, pk=self.kwargs['project_id'])
  62. include = set(request.GET.getlist('include'))
  63. response = {}
  64. if not include or 'label' in include:
  65. label_count, user_count = self.label_per_data(p)
  66. response['label'] = label_count
  67. # TODO: Make user_label count chart
  68. response['user_label'] = user_count
  69. if not include or 'total' in include or 'remaining' in include or 'user' in include:
  70. progress = self.progress(project=p)
  71. response.update(progress)
  72. if include:
  73. response = {key: value for (key, value) in response.items() if key in include}
  74. return Response(response)
  75. @staticmethod
  76. def _get_user_completion_data(annotation_class, annotation_filter):
  77. all_annotation_objects = annotation_class.objects.filter(annotation_filter)
  78. set_user_data = collections.defaultdict(set)
  79. for ind_obj in all_annotation_objects.values('user__username', 'document__id'):
  80. set_user_data[ind_obj['user__username']].add(ind_obj['document__id'])
  81. return {i: len(set_user_data[i]) for i in set_user_data}
  82. def progress(self, project):
  83. docs = project.documents
  84. annotation_class = project.get_annotation_class()
  85. total = docs.count()
  86. annotation_filter = Q(document_id__in=docs.all())
  87. user_data = self._get_user_completion_data(annotation_class, annotation_filter)
  88. if not project.collaborative_annotation:
  89. annotation_filter &= Q(user_id=self.request.user)
  90. done = annotation_class.objects.filter(annotation_filter)\
  91. .aggregate(Count('document', distinct=True))['document__count']
  92. remaining = total - done
  93. return {'total': total, 'remaining': remaining, 'user': user_data}
  94. def label_per_data(self, project):
  95. annotation_class = project.get_annotation_class()
  96. return annotation_class.objects.get_label_per_data(project=project)
  97. class ApproveLabelsAPI(APIView):
  98. permission_classes = [IsAuthenticated & (IsAnnotationApprover | IsProjectAdmin)]
  99. def post(self, request, *args, **kwargs):
  100. approved = self.request.data.get('approved', True)
  101. document = get_object_or_404(Document, pk=self.kwargs['doc_id'])
  102. document.annotations_approved_by = self.request.user if approved else None
  103. document.save()
  104. return Response(DocumentSerializer(document).data)
  105. class LabelList(generics.ListCreateAPIView):
  106. serializer_class = LabelSerializer
  107. pagination_class = None
  108. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  109. def get_queryset(self):
  110. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  111. return project.labels
  112. def perform_create(self, serializer):
  113. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  114. serializer.save(project=project)
  115. class LabelDetail(generics.RetrieveUpdateDestroyAPIView):
  116. queryset = Label.objects.all()
  117. serializer_class = LabelSerializer
  118. lookup_url_kwarg = 'label_id'
  119. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  120. class DocumentList(generics.ListCreateAPIView):
  121. serializer_class = DocumentSerializer
  122. filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter)
  123. search_fields = ('text', )
  124. ordering_fields = ('created_at', 'updated_at', 'doc_annotations__updated_at',
  125. 'seq_annotations__updated_at', 'seq2seq_annotations__updated_at')
  126. filter_class = DocumentFilter
  127. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  128. def get_queryset(self):
  129. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  130. queryset = project.documents
  131. if project.randomize_document_order:
  132. queryset = queryset.annotate(sort_id=F('id') % self.request.user.id).order_by('sort_id')
  133. else:
  134. queryset = queryset.order_by('id')
  135. return queryset
  136. def perform_create(self, serializer):
  137. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  138. serializer.save(project=project)
  139. class DocumentDetail(generics.RetrieveUpdateDestroyAPIView):
  140. queryset = Document.objects.all()
  141. serializer_class = DocumentSerializer
  142. lookup_url_kwarg = 'doc_id'
  143. permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
  144. class AnnotationList(generics.ListCreateAPIView):
  145. pagination_class = None
  146. permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
  147. swagger_schema = None
  148. def get_serializer_class(self):
  149. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  150. self.serializer_class = project.get_annotation_serializer()
  151. return self.serializer_class
  152. def get_queryset(self):
  153. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  154. model = project.get_annotation_class()
  155. queryset = model.objects.filter(document=self.kwargs['doc_id'])
  156. if not project.collaborative_annotation:
  157. queryset = queryset.filter(user=self.request.user)
  158. return queryset
  159. def create(self, request, *args, **kwargs):
  160. request.data['document'] = self.kwargs['doc_id']
  161. return super().create(request, args, kwargs)
  162. def perform_create(self, serializer):
  163. serializer.save(document_id=self.kwargs['doc_id'], user=self.request.user)
  164. class AnnotationDetail(generics.RetrieveUpdateDestroyAPIView):
  165. lookup_url_kwarg = 'annotation_id'
  166. permission_classes = [IsAuthenticated & (((IsAnnotator & IsOwnAnnotation) | IsAnnotationApprover) | IsProjectAdmin)]
  167. swagger_schema = None
  168. def get_serializer_class(self):
  169. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  170. self.serializer_class = project.get_annotation_serializer()
  171. return self.serializer_class
  172. def get_queryset(self):
  173. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  174. model = project.get_annotation_class()
  175. self.queryset = model.objects.all()
  176. return self.queryset
  177. class TextUploadAPI(APIView):
  178. parser_classes = (MultiPartParser,)
  179. permission_classes = [IsAuthenticated & IsProjectAdmin]
  180. def post(self, request, *args, **kwargs):
  181. if 'file' not in request.data:
  182. raise ParseError('Empty content')
  183. self.save_file(
  184. user=request.user,
  185. file=request.data['file'],
  186. file_format=request.data['format'],
  187. project_id=kwargs['project_id'],
  188. )
  189. return Response(status=status.HTTP_201_CREATED)
  190. @classmethod
  191. def save_file(cls, user, file, file_format, project_id):
  192. project = get_object_or_404(Project, pk=project_id)
  193. parser = cls.select_parser(file_format)
  194. data = parser.parse(file)
  195. storage = project.get_storage(data)
  196. storage.save(user)
  197. @classmethod
  198. def select_parser(cls, file_format):
  199. if file_format == 'plain':
  200. return PlainTextParser()
  201. elif file_format == 'csv':
  202. return CSVParser()
  203. elif file_format == 'json':
  204. return JSONParser()
  205. elif file_format == 'conll':
  206. return CoNLLParser()
  207. elif file_format == 'excel':
  208. return ExcelParser()
  209. else:
  210. raise ValidationError('format {} is invalid.'.format(file_format))
  211. class CloudUploadAPI(APIView):
  212. permission_classes = TextUploadAPI.permission_classes
  213. def get(self, request, *args, **kwargs):
  214. try:
  215. project_id = request.query_params['project_id']
  216. file_format = request.query_params['upload_format']
  217. cloud_container = request.query_params['container']
  218. cloud_object = request.query_params['object']
  219. except KeyError as ex:
  220. raise ValidationError('query parameter {} is missing'.format(ex))
  221. try:
  222. cloud_file = self.get_cloud_object_as_io(cloud_container, cloud_object)
  223. except ContainerDoesNotExistError:
  224. raise ValidationError('cloud container {} does not exist'.format(cloud_container))
  225. except ObjectDoesNotExistError:
  226. raise ValidationError('cloud object {} does not exist'.format(cloud_object))
  227. TextUploadAPI.save_file(
  228. user=request.user,
  229. file=cloud_file,
  230. file_format=file_format,
  231. project_id=project_id,
  232. )
  233. next_url = request.query_params.get('next')
  234. if next_url == 'about:blank':
  235. return Response(data='', content_type='text/plain', status=status.HTTP_201_CREATED)
  236. if next_url:
  237. return redirect(next_url)
  238. return Response(status=status.HTTP_201_CREATED)
  239. @classmethod
  240. def get_cloud_object_as_io(cls, container_name, object_name):
  241. provider = settings.CLOUD_BROWSER_APACHE_LIBCLOUD_PROVIDER.lower()
  242. account = settings.CLOUD_BROWSER_APACHE_LIBCLOUD_ACCOUNT
  243. key = settings.CLOUD_BROWSER_APACHE_LIBCLOUD_SECRET_KEY
  244. driver = get_driver(DriverType.STORAGE, provider)
  245. client = driver(account, key)
  246. cloud_container = client.get_container(container_name)
  247. cloud_object = cloud_container.get_object(object_name)
  248. return iterable_to_io(cloud_object.as_stream())
  249. class TextDownloadAPI(APIView):
  250. permission_classes = TextUploadAPI.permission_classes
  251. renderer_classes = (CSVRenderer, JSONLRenderer)
  252. def get(self, request, *args, **kwargs):
  253. format = request.query_params.get('q')
  254. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  255. documents = project.documents.all()
  256. painter = self.select_painter(format)
  257. # json1 format prints text labels while json format prints annotations with label ids
  258. # json1 format - "labels": [[0, 15, "PERSON"], ..]
  259. # json format - "annotations": [{"label": 5, "start_offset": 0, "end_offset": 2, "user": 1},..]
  260. if format == "json1":
  261. labels = project.labels.all()
  262. data = JSONPainter.paint_labels(documents, labels)
  263. else:
  264. data = painter.paint(documents)
  265. return Response(data)
  266. def select_painter(self, format):
  267. if format == 'csv':
  268. return CSVPainter()
  269. elif format == 'json' or format == "json1":
  270. return JSONPainter()
  271. else:
  272. raise ValidationError('format {} is invalid.'.format(format))
  273. class Users(APIView):
  274. permission_classes = [IsAuthenticated & IsProjectAdmin]
  275. def get(self, request, *args, **kwargs):
  276. queryset = User.objects.all()
  277. serialized_data = UserSerializer(queryset, many=True).data
  278. return Response(serialized_data)
  279. class Roles(generics.ListCreateAPIView):
  280. serializer_class = RoleSerializer
  281. pagination_class = None
  282. permission_classes = [IsAuthenticated & IsProjectAdmin]
  283. queryset = Role.objects.all()
  284. class RoleMappingList(generics.ListCreateAPIView):
  285. serializer_class = RoleMappingSerializer
  286. pagination_class = None
  287. permission_classes = [IsAuthenticated & IsProjectAdmin]
  288. def get_queryset(self):
  289. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  290. return project.role_mappings
  291. def perform_create(self, serializer):
  292. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  293. serializer.save(project=project)
  294. class RoleMappingDetail(generics.RetrieveUpdateDestroyAPIView):
  295. queryset = RoleMapping.objects.all()
  296. serializer_class = RoleMappingSerializer
  297. lookup_url_kwarg = 'rolemapping_id'
  298. permission_classes = [IsAuthenticated & IsProjectAdmin]
  299. class LabelUploadAPI(APIView):
  300. parser_classes = (MultiPartParser,)
  301. permission_classes = [IsAuthenticated & IsProjectAdmin]
  302. @transaction.atomic
  303. def post(self, request, *args, **kwargs):
  304. if 'file' not in request.data:
  305. raise ParseError('Empty content')
  306. labels = json.load(request.data['file'])
  307. project = get_object_or_404(Project, pk=kwargs['project_id'])
  308. try:
  309. for label in labels:
  310. serializer = LabelSerializer(data=label)
  311. serializer.is_valid(raise_exception=True)
  312. serializer.save(project=project)
  313. return Response(status=status.HTTP_201_CREATED)
  314. except IntegrityError:
  315. content = {'error': 'IntegrityError: you cannot create a label with same name or shortkey.'}
  316. return Response(content, status=status.HTTP_400_BAD_REQUEST)