You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

306 lines
11 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import json
  2. import csv
  3. from io import TextIOWrapper
  4. from django.urls import reverse
  5. from django_filters.rest_framework import DjangoFilterBackend
  6. from django.http import JsonResponse, HttpResponse, HttpResponseRedirect
  7. from django.shortcuts import render, get_object_or_404
  8. from django.views import View
  9. from django.views.generic import TemplateView
  10. from django.views.generic.list import ListView
  11. from django.views.generic.detail import DetailView
  12. from django.contrib.auth.mixins import LoginRequiredMixin
  13. from rest_framework import viewsets, filters, generics
  14. from rest_framework.decorators import action
  15. from rest_framework.response import Response
  16. from rest_framework.permissions import SAFE_METHODS, BasePermission, IsAdminUser, IsAuthenticated
  17. from .models import Label, Document, Project
  18. from .models import DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation
  19. from .serializers import LabelSerializer, ProjectSerializer
  20. class IndexView(TemplateView):
  21. template_name = 'index.html'
  22. class ProjectView(LoginRequiredMixin, TemplateView):
  23. template_name = 'annotation.html'
  24. def get_context_data(self, **kwargs):
  25. context = super().get_context_data(**kwargs)
  26. project_id = kwargs.get('project_id')
  27. project = get_object_or_404(Project, pk=project_id)
  28. self.template_name = project.get_template()
  29. return context
  30. class ProjectAdminView(LoginRequiredMixin, DetailView):
  31. model = Project
  32. template_name = 'project_admin.html'
  33. class ProjectsView(LoginRequiredMixin, ListView):
  34. model = Project
  35. paginate_by = 100
  36. template_name = 'projects.html'
  37. class DatasetView(LoginRequiredMixin, ListView):
  38. template_name = 'admin/dataset.html'
  39. context_object_name = 'documents'
  40. paginate_by = 5
  41. def get_queryset(self):
  42. project_id = self.kwargs['pk']
  43. project = get_object_or_404(Project, pk=project_id)
  44. return project.documents.all()
  45. class DatasetUpload(LoginRequiredMixin, View):
  46. model = Project
  47. def get(self, request, *args, **kwargs):
  48. return render(request, 'admin/dataset_upload.html')
  49. def post(self, request, *args, **kwargs):
  50. project = get_object_or_404(Project, pk=kwargs.get('pk'))
  51. try:
  52. form_data = TextIOWrapper(request.FILES['csv_file'].file, encoding='utf-8')
  53. reader = csv.reader(form_data)
  54. for line in reader:
  55. text = line[0]
  56. Document(text=text, project=project).save()
  57. return HttpResponseRedirect(reverse('dataset', args=[project.id]))
  58. except:
  59. print("failed")
  60. return HttpResponseRedirect(reverse('dataset-upload', args=[project.id]))
  61. class RawDataAPI(View):
  62. def post(self, request, *args, **kwargs):
  63. """Upload data."""
  64. f = request.FILES['file']
  65. content = ''.join(chunk.decode('utf-8') for chunk in f.chunks())
  66. for line in content.split('\n'):
  67. j = json.loads(line)
  68. Document(text=j['text']).save()
  69. return JsonResponse({'status': 'ok'})
  70. class DataDownload(View):
  71. def get(self, request, *args, **kwargs):
  72. project_id = self.kwargs['project_id']
  73. project = get_object_or_404(Project, pk=project_id)
  74. docs = project.get_documents(is_null=False).distinct()
  75. response = HttpResponse(content_type='text/csv')
  76. response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(project.name)
  77. writer = csv.writer(response)
  78. for d in docs:
  79. writer.writerows(d.make_dataset())
  80. return response
  81. class IsProjectUser(BasePermission):
  82. def has_permission(self, request, view):
  83. user = request.user
  84. project_id = view.kwargs.get('project_id')
  85. project = get_object_or_404(Project, pk=project_id)
  86. return user in project.users.all()
  87. class IsAdminUserAndWriteOnly(BasePermission):
  88. def has_permission(self, request, view):
  89. if request.method in SAFE_METHODS:
  90. return True
  91. return IsAdminUser().has_permission(request, view)
  92. class IsOwnAnnotation(BasePermission):
  93. def has_permission(self, request, view):
  94. user = request.user
  95. project_id = view.kwargs.get('project_id')
  96. annotation_id = view.kwargs.get('annotation_id')
  97. project = get_object_or_404(Project, pk=project_id)
  98. Annotation = project.get_annotation_class()
  99. annotation = Annotation.objects.get(id=annotation_id)
  100. return annotation.user == user
  101. class ProjectViewSet(viewsets.ModelViewSet):
  102. queryset = Project.objects.all()
  103. serializer_class = ProjectSerializer
  104. pagination_class = None
  105. permission_classes = (IsAuthenticated, IsAdminUserAndWriteOnly)
  106. def get_queryset(self):
  107. user = self.request.user
  108. queryset = self.queryset.filter(users__id__contains=user.id)
  109. return queryset
  110. @action(methods=['get'], detail=True)
  111. def progress(self, request, pk=None):
  112. project = self.get_object()
  113. docs = project.get_documents(is_null=True)
  114. total = project.documents.count()
  115. remaining = docs.count()
  116. return Response({'total': total, 'remaining': remaining})
  117. class ProjectLabelsAPI(generics.ListCreateAPIView):
  118. queryset = Label.objects.all()
  119. serializer_class = LabelSerializer
  120. pagination_class = None
  121. permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
  122. def get_queryset(self):
  123. project_id = self.kwargs['project_id']
  124. queryset = self.queryset.filter(project=project_id)
  125. return queryset
  126. def perform_create(self, serializer):
  127. project_id = self.kwargs['project_id']
  128. project = get_object_or_404(Project, pk=project_id)
  129. serializer.save(project=project)
  130. class ProjectLabelAPI(generics.RetrieveUpdateDestroyAPIView):
  131. queryset = Label.objects.all()
  132. serializer_class = LabelSerializer
  133. permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUser)
  134. def get_queryset(self):
  135. project_id = self.kwargs['project_id']
  136. queryset = self.queryset.filter(project=project_id)
  137. return queryset
  138. def get_object(self):
  139. label_id = self.kwargs['label_id']
  140. queryset = self.filter_queryset(self.get_queryset())
  141. obj = get_object_or_404(queryset, pk=label_id)
  142. self.check_object_permissions(self.request, obj)
  143. return obj
  144. class ProjectDocsAPI(generics.ListCreateAPIView):
  145. queryset = Document.objects.all()
  146. filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter)
  147. search_fields = ('text', )
  148. permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
  149. def get_serializer_class(self):
  150. project_id = self.kwargs['project_id']
  151. project = get_object_or_404(Project, pk=project_id)
  152. self.serializer_class = project.get_project_serializer()
  153. return self.serializer_class
  154. def get_queryset(self):
  155. project_id = self.kwargs['project_id']
  156. queryset = self.queryset.filter(project=project_id)
  157. if not self.request.query_params.get('is_checked'):
  158. return queryset
  159. project = get_object_or_404(Project, pk=project_id)
  160. is_null = self.request.query_params.get('is_checked') == 'true'
  161. queryset = project.get_documents(is_null).distinct()
  162. return queryset
  163. class AnnotationsAPI(generics.ListCreateAPIView):
  164. pagination_class = None
  165. permission_classes = (IsAuthenticated, IsProjectUser)
  166. def get_serializer_class(self):
  167. project_id = self.kwargs['project_id']
  168. project = get_object_or_404(Project, pk=project_id)
  169. self.serializer_class = project.get_annotation_serializer()
  170. return self.serializer_class
  171. def get_queryset(self):
  172. project_id = self.kwargs['project_id']
  173. project = get_object_or_404(Project, pk=project_id)
  174. doc_id = self.kwargs['doc_id']
  175. document = get_object_or_404(Document, pk=doc_id, project=project)
  176. self.queryset = document.get_annotations()
  177. return self.queryset
  178. def post(self, request, *args, **kwargs):
  179. doc = get_object_or_404(Document, pk=self.kwargs['doc_id'])
  180. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  181. self.serializer_class = project.get_annotation_serializer()
  182. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  183. label = get_object_or_404(Label, pk=request.data['label_id'])
  184. annotation = DocumentAnnotation(document=doc, label=label, manual=True,
  185. user=self.request.user)
  186. elif project.is_type_of(Project.SEQUENCE_LABELING):
  187. label = get_object_or_404(Label, pk=request.data['label_id'])
  188. annotation = SequenceAnnotation(document=doc, label=label, manual=True,
  189. user=self.request.user,
  190. start_offset=request.data['start_offset'],
  191. end_offset=request.data['end_offset'])
  192. elif project.is_type_of(Project.Seq2seq):
  193. text = request.data['text']
  194. annotation = Seq2seqAnnotation(document=doc,
  195. text=text,
  196. manual=True,
  197. user=self.request.user)
  198. annotation.save()
  199. serializer = self.serializer_class(annotation)
  200. return Response(serializer.data)
  201. class AnnotationAPI(generics.RetrieveUpdateDestroyAPIView):
  202. permission_classes = (IsAuthenticated, IsProjectUser, IsOwnAnnotation)
  203. def get_queryset(self):
  204. doc_id = self.kwargs['doc_id']
  205. document = get_object_or_404(Document, pk=doc_id)
  206. self.queryset = document.get_annotations()
  207. return self.queryset
  208. def get_object(self):
  209. annotation_id = self.kwargs['annotation_id']
  210. queryset = self.filter_queryset(self.get_queryset())
  211. obj = get_object_or_404(queryset, pk=annotation_id)
  212. self.check_object_permissions(self.request, obj)
  213. return obj
  214. def put(self, request, *args, **kwargs):
  215. doc = get_object_or_404(Document, pk=self.kwargs['doc_id'])
  216. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  217. self.serializer_class = project.get_annotation_serializer()
  218. if project.is_type_of(Project.Seq2seq):
  219. text = request.data['text']
  220. annotation = get_object_or_404(Seq2seqAnnotation, pk=request.data['id'])
  221. annotation.text = text
  222. annotation.save()
  223. serializer = self.serializer_class(annotation)
  224. return Response(serializer.data)