You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

347 lines
12 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import json
  2. import csv
  3. from itertools import chain
  4. from collections import Counter
  5. from io import TextIOWrapper
  6. from django import forms
  7. from django.contrib.auth.mixins import UserPassesTestMixin
  8. from django.urls import reverse
  9. from django_filters.rest_framework import DjangoFilterBackend
  10. from django.http import JsonResponse, HttpResponse, HttpResponseRedirect
  11. from django.shortcuts import render, get_object_or_404
  12. from django.views import View
  13. from django.views.generic import TemplateView
  14. from django.views.generic.list import ListView
  15. from django.views.generic.detail import DetailView
  16. from django.contrib.auth.mixins import LoginRequiredMixin
  17. from rest_framework import viewsets, filters, generics
  18. from rest_framework.views import APIView
  19. from rest_framework.decorators import action
  20. from rest_framework.response import Response
  21. from rest_framework.permissions import SAFE_METHODS, BasePermission, IsAdminUser, IsAuthenticated
  22. from .models import Label, Document, Project
  23. from .models import DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation
  24. from .serializers import LabelSerializer, ProjectSerializer
  25. class SuperUserMixin(UserPassesTestMixin):
  26. def test_func(self):
  27. return self.request.user.is_superuser
  28. class IndexView(TemplateView):
  29. template_name = 'index.html'
  30. class ProjectView(LoginRequiredMixin, TemplateView):
  31. template_name = 'annotation.html'
  32. def get_context_data(self, **kwargs):
  33. context = super().get_context_data(**kwargs)
  34. project_id = kwargs.get('project_id')
  35. project = get_object_or_404(Project, pk=project_id)
  36. self.template_name = project.get_template()
  37. return context
  38. class ProjectForm(forms.ModelForm):
  39. class Meta:
  40. model = Project
  41. fields = ('name', 'description', 'project_type', 'users')
  42. class ProjectsView(LoginRequiredMixin, TemplateView):
  43. model = Project
  44. paginate_by = 100
  45. template_name = 'projects.html'
  46. def get(self, request, *args, **kwargs):
  47. form = ProjectForm()
  48. return render(request, self.template_name, {'form': form})
  49. def post(self, request, *args, **kwargs):
  50. form = ProjectForm(request.POST)
  51. if form.is_valid():
  52. project = form.save()
  53. return HttpResponseRedirect(reverse('upload', args=[project.id]))
  54. else:
  55. return render(request, self.template_name, {'form': form})
  56. class DatasetView(SuperUserMixin, LoginRequiredMixin, ListView):
  57. template_name = 'admin/dataset.html'
  58. context_object_name = 'documents'
  59. paginate_by = 5
  60. def get_queryset(self):
  61. project_id = self.kwargs['project_id']
  62. project = get_object_or_404(Project, pk=project_id)
  63. return project.documents.all()
  64. class LabelView(SuperUserMixin, LoginRequiredMixin, TemplateView):
  65. template_name = 'admin/label.html'
  66. class StatsView(SuperUserMixin, LoginRequiredMixin, TemplateView):
  67. template_name = 'admin/stats.html'
  68. class DatasetUpload(SuperUserMixin, LoginRequiredMixin, TemplateView):
  69. model = Project
  70. template_name = 'admin/dataset_upload.html'
  71. def post(self, request, *args, **kwargs):
  72. project = get_object_or_404(Project, pk=kwargs.get('project_id'))
  73. try:
  74. form_data = TextIOWrapper(request.FILES['csv_file'].file, encoding='utf-8')
  75. reader = csv.reader(form_data)
  76. for line in reader:
  77. text = line[0]
  78. Document(text=text, project=project).save()
  79. return HttpResponseRedirect(reverse('dataset', args=[project.id]))
  80. except:
  81. print("failed")
  82. return HttpResponseRedirect(reverse('dataset-upload', args=[project.id]))
  83. class DataDownload(SuperUserMixin, View):
  84. def get(self, request, *args, **kwargs):
  85. project_id = self.kwargs['project_id']
  86. project = get_object_or_404(Project, pk=project_id)
  87. docs = project.get_documents(is_null=False).distinct()
  88. filename = '_'.join(project.name.lower().split())
  89. response = HttpResponse(content_type='text/csv')
  90. response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(filename)
  91. writer = csv.writer(response)
  92. for d in docs:
  93. writer.writerows(d.make_dataset())
  94. return response
  95. class IsProjectUser(BasePermission):
  96. def has_permission(self, request, view):
  97. user = request.user
  98. project_id = view.kwargs.get('project_id')
  99. project = get_object_or_404(Project, pk=project_id)
  100. return user in project.users.all()
  101. class IsAdminUserAndWriteOnly(BasePermission):
  102. def has_permission(self, request, view):
  103. if request.method in SAFE_METHODS:
  104. return True
  105. return IsAdminUser().has_permission(request, view)
  106. class IsOwnAnnotation(BasePermission):
  107. def has_permission(self, request, view):
  108. user = request.user
  109. project_id = view.kwargs.get('project_id')
  110. annotation_id = view.kwargs.get('annotation_id')
  111. project = get_object_or_404(Project, pk=project_id)
  112. Annotation = project.get_annotation_class()
  113. annotation = Annotation.objects.get(id=annotation_id)
  114. return annotation.user == user
  115. class ProjectViewSet(viewsets.ModelViewSet):
  116. queryset = Project.objects.all()
  117. serializer_class = ProjectSerializer
  118. pagination_class = None
  119. permission_classes = (IsAuthenticated, IsAdminUserAndWriteOnly)
  120. def get_queryset(self):
  121. user = self.request.user
  122. queryset = self.queryset.filter(users__id__contains=user.id)
  123. return queryset
  124. @action(methods=['get'], detail=True)
  125. def progress(self, request, pk=None):
  126. project = self.get_object()
  127. docs = project.get_documents(is_null=True)
  128. total = project.documents.count()
  129. remaining = docs.count()
  130. return Response({'total': total, 'remaining': remaining})
  131. class ProjectLabelsAPI(generics.ListCreateAPIView):
  132. queryset = Label.objects.all()
  133. serializer_class = LabelSerializer
  134. pagination_class = None
  135. permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
  136. def get_queryset(self):
  137. project_id = self.kwargs['project_id']
  138. queryset = self.queryset.filter(project=project_id)
  139. return queryset
  140. def perform_create(self, serializer):
  141. project_id = self.kwargs['project_id']
  142. project = get_object_or_404(Project, pk=project_id)
  143. serializer.save(project=project)
  144. class ProjectStatsAPI(APIView):
  145. pagination_class = None
  146. permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
  147. def get(self, request, *args, **kwargs):
  148. p = get_object_or_404(Project, pk=self.kwargs['project_id'])
  149. labels = [label.text for label in p.labels.all()]
  150. users = [user.username for user in p.users.all()]
  151. docs = [doc for doc in p.documents.all()]
  152. nested_labels = [[a.label.text for a in doc.get_annotations()] for doc in docs]
  153. nested_users = [[a.user.username for a in doc.get_annotations()] for doc in docs]
  154. label_count = Counter(chain(*nested_labels))
  155. label_data = [label_count[name] for name in labels]
  156. user_count = Counter(chain(*nested_users))
  157. user_data = [user_count[name] for name in users]
  158. response = {'label': {'labels': labels, 'data': label_data},
  159. 'user': {'users': users, 'data': user_data}}
  160. return Response(response)
  161. class ProjectLabelAPI(generics.RetrieveUpdateDestroyAPIView):
  162. queryset = Label.objects.all()
  163. serializer_class = LabelSerializer
  164. permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUser)
  165. def get_queryset(self):
  166. project_id = self.kwargs['project_id']
  167. queryset = self.queryset.filter(project=project_id)
  168. return queryset
  169. def get_object(self):
  170. label_id = self.kwargs['label_id']
  171. queryset = self.filter_queryset(self.get_queryset())
  172. obj = get_object_or_404(queryset, pk=label_id)
  173. self.check_object_permissions(self.request, obj)
  174. return obj
  175. class ProjectDocsAPI(generics.ListCreateAPIView):
  176. queryset = Document.objects.all()
  177. filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter)
  178. search_fields = ('text', )
  179. permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
  180. def get_serializer_class(self):
  181. project_id = self.kwargs['project_id']
  182. project = get_object_or_404(Project, pk=project_id)
  183. self.serializer_class = project.get_project_serializer()
  184. return self.serializer_class
  185. def get_queryset(self):
  186. project_id = self.kwargs['project_id']
  187. queryset = self.queryset.filter(project=project_id)
  188. if not self.request.query_params.get('is_checked'):
  189. return queryset
  190. project = get_object_or_404(Project, pk=project_id)
  191. is_null = self.request.query_params.get('is_checked') == 'true'
  192. queryset = project.get_documents(is_null).distinct()
  193. return queryset
  194. class AnnotationsAPI(generics.ListCreateAPIView):
  195. pagination_class = None
  196. permission_classes = (IsAuthenticated, IsProjectUser)
  197. def get_serializer_class(self):
  198. project_id = self.kwargs['project_id']
  199. project = get_object_or_404(Project, pk=project_id)
  200. self.serializer_class = project.get_annotation_serializer()
  201. return self.serializer_class
  202. def get_queryset(self):
  203. project_id = self.kwargs['project_id']
  204. project = get_object_or_404(Project, pk=project_id)
  205. doc_id = self.kwargs['doc_id']
  206. document = get_object_or_404(Document, pk=doc_id, project=project)
  207. self.queryset = document.get_annotations()
  208. return self.queryset
  209. def post(self, request, *args, **kwargs):
  210. doc = get_object_or_404(Document, pk=self.kwargs['doc_id'])
  211. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  212. self.serializer_class = project.get_annotation_serializer()
  213. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  214. label = get_object_or_404(Label, pk=request.data['label_id'])
  215. annotation = DocumentAnnotation(document=doc, label=label, manual=True,
  216. user=self.request.user)
  217. elif project.is_type_of(Project.SEQUENCE_LABELING):
  218. label = get_object_or_404(Label, pk=request.data['label_id'])
  219. annotation = SequenceAnnotation(document=doc, label=label, manual=True,
  220. user=self.request.user,
  221. start_offset=request.data['start_offset'],
  222. end_offset=request.data['end_offset'])
  223. elif project.is_type_of(Project.Seq2seq):
  224. text = request.data['text']
  225. annotation = Seq2seqAnnotation(document=doc,
  226. text=text,
  227. manual=True,
  228. user=self.request.user)
  229. annotation.save()
  230. serializer = self.serializer_class(annotation)
  231. return Response(serializer.data)
  232. class AnnotationAPI(generics.RetrieveUpdateDestroyAPIView):
  233. permission_classes = (IsAuthenticated, IsProjectUser, IsOwnAnnotation)
  234. def get_queryset(self):
  235. doc_id = self.kwargs['doc_id']
  236. document = get_object_or_404(Document, pk=doc_id)
  237. self.queryset = document.get_annotations()
  238. return self.queryset
  239. def get_object(self):
  240. annotation_id = self.kwargs['annotation_id']
  241. queryset = self.filter_queryset(self.get_queryset())
  242. obj = get_object_or_404(queryset, pk=annotation_id)
  243. self.check_object_permissions(self.request, obj)
  244. return obj
  245. def put(self, request, *args, **kwargs):
  246. doc = get_object_or_404(Document, pk=self.kwargs['doc_id'])
  247. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  248. self.serializer_class = project.get_annotation_serializer()
  249. if project.is_type_of(Project.Seq2seq):
  250. text = request.data['text']
  251. annotation = get_object_or_404(Seq2seqAnnotation, pk=request.data['id'])
  252. annotation.text = text
  253. annotation.save()
  254. serializer = self.serializer_class(annotation)
  255. return Response(serializer.data)