You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

250 lines
9.6 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import json
  2. from django_filters.rest_framework import DjangoFilterBackend
  3. from django.http import JsonResponse, HttpResponse
  4. from django.shortcuts import render, get_object_or_404
  5. from django.views import View
  6. from django.views.generic import TemplateView
  7. from django.views.generic.list import ListView
  8. from django.views.generic.detail import DetailView
  9. from django.contrib.auth.mixins import LoginRequiredMixin
  10. from django.core.paginator import Paginator
  11. from rest_framework import viewsets, filters, generics
  12. from rest_framework.decorators import action
  13. from rest_framework.response import Response
  14. from rest_framework.permissions import IsAdminUser
  15. from django.db.models.query import QuerySet
  16. from .models import Label, Document, Project
  17. from .models import DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation
  18. from .serializers import LabelSerializer, ProjectSerializer, DocumentSerializer, DocumentAnnotationSerializer
  19. from .serializers import SequenceSerializer, SequenceAnnotationSerializer
  20. from .serializers import Seq2seqSerializer, Seq2seqAnnotationSerializer
  21. class IndexView(TemplateView):
  22. template_name = 'index.html'
  23. class ProjectView(LoginRequiredMixin, TemplateView):
  24. template_name = 'annotation.html'
  25. def get_context_data(self, **kwargs):
  26. context = super().get_context_data(**kwargs)
  27. project_id = kwargs.get('project_id')
  28. project = get_object_or_404(Project, pk=project_id)
  29. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  30. self.template_name = 'annotation/document_classification.html'
  31. elif project.is_type_of(Project.SEQUENCE_LABELING):
  32. self.template_name = 'annotation/sequence_labeling.html'
  33. elif project.is_type_of(Project.Seq2seq):
  34. self.template_name = 'annotation/seq2seq.html'
  35. else:
  36. pass
  37. return context
  38. class ProjectsView(LoginRequiredMixin, ListView):
  39. model = Project
  40. paginate_by = 100
  41. template_name = 'projects.html'
  42. class ProjectAdminView(LoginRequiredMixin, DetailView):
  43. model = Project
  44. template_name = 'project_admin.html'
  45. class RawDataAPI(View):
  46. def post(self, request, *args, **kwargs):
  47. """Upload data."""
  48. f = request.FILES['file']
  49. content = ''.join(chunk.decode('utf-8') for chunk in f.chunks())
  50. for line in content.split('\n'):
  51. j = json.loads(line)
  52. Document(text=j['text']).save()
  53. return JsonResponse({'status': 'ok'})
  54. class DataDownloadAPI(View):
  55. def get(self, request, *args, **kwargs):
  56. annotated_docs = [a.as_dict() for a in Annotation.objects.filter(manual=True)]
  57. json_str = json.dumps(annotated_docs)
  58. response = HttpResponse(json_str, content_type='application/json')
  59. response['Content-Disposition'] = 'attachment; filename=annotation_data.json'
  60. return response
  61. class ProjectViewSet(viewsets.ModelViewSet):
  62. queryset = Project.objects.all()
  63. serializer_class = ProjectSerializer
  64. @action(methods=['get'], detail=True)
  65. def progress(self, request, pk=None):
  66. project = self.get_object()
  67. docs = project.documents.all()
  68. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  69. remaining = docs.filter(doc_annotations__isnull=True).count()
  70. elif project.is_type_of(Project.SEQUENCE_LABELING):
  71. remaining = docs.filter(seq_annotations__isnull=True).count()
  72. elif project.is_type_of(Project.Seq2seq):
  73. remaining = docs.filter(seq2seq_annotations__isnull=True).count()
  74. else:
  75. remaining = 0
  76. return Response({'total': docs.count(), 'remaining': remaining})
  77. class ProjectLabelsAPI(generics.ListCreateAPIView):
  78. queryset = Label.objects.all()
  79. serializer_class = LabelSerializer
  80. pagination_class = None
  81. def get_queryset(self):
  82. project_id = self.kwargs['project_id']
  83. queryset = self.queryset.filter(project=project_id)
  84. return queryset
  85. def perform_create(self, serializer):
  86. project_id = self.kwargs['project_id']
  87. project = get_object_or_404(Project, pk=project_id)
  88. serializer.save(project=project)
  89. class ProjectLabelAPI(generics.RetrieveUpdateDestroyAPIView):
  90. queryset = Label.objects.all()
  91. serializer_class = LabelSerializer
  92. def get_queryset(self):
  93. project_id = self.kwargs['project_id']
  94. queryset = self.queryset.filter(project=project_id)
  95. return queryset
  96. def get_object(self):
  97. label_id = self.kwargs['label_id']
  98. queryset = self.filter_queryset(self.get_queryset())
  99. obj = get_object_or_404(queryset, pk=label_id)
  100. self.check_object_permissions(self.request, obj)
  101. return obj
  102. class ProjectDocsAPI(generics.ListCreateAPIView):
  103. queryset = Document.objects.all()
  104. filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter)
  105. search_fields = ('text', )
  106. def get_serializer_class(self):
  107. project_id = self.kwargs['project_id']
  108. project = get_object_or_404(Project, pk=project_id)
  109. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  110. self.serializer_class = DocumentSerializer
  111. elif project.is_type_of(Project.SEQUENCE_LABELING):
  112. self.serializer_class = SequenceSerializer
  113. elif project.is_type_of(Project.Seq2seq):
  114. self.serializer_class = Seq2seqSerializer
  115. return self.serializer_class
  116. def get_queryset(self):
  117. project_id = self.kwargs['project_id']
  118. queryset = self.queryset.filter(project=project_id)
  119. if not self.request.query_params.get('is_checked'):
  120. return queryset
  121. project = get_object_or_404(Project, pk=project_id)
  122. isnull = self.request.query_params.get('is_checked') != 'true'
  123. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  124. queryset = queryset.filter(doc_annotations__isnull=isnull).distinct()
  125. elif project.is_type_of(Project.SEQUENCE_LABELING):
  126. queryset = queryset.filter(seq_annotations__isnull=isnull).distinct()
  127. elif project.is_type_of(Project.Seq2seq):
  128. queryset = queryset.filter(seq2seq_annotations__isnull=isnull).distinct()
  129. else:
  130. queryset = queryset
  131. return queryset
  132. class AnnotationsAPI(generics.ListCreateAPIView):
  133. pagination_class = None
  134. def get_serializer_class(self):
  135. project_id = self.kwargs['project_id']
  136. project = get_object_or_404(Project, pk=project_id)
  137. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  138. self.serializer_class = DocumentAnnotationSerializer
  139. elif project.is_type_of(Project.SEQUENCE_LABELING):
  140. self.serializer_class = SequenceAnnotationSerializer
  141. elif project.is_type_of(Project.Seq2seq):
  142. self.serializer_class = Seq2seqAnnotationSerializer
  143. return self.serializer_class
  144. def get_queryset(self):
  145. doc_id = self.kwargs['doc_id']
  146. project_id = self.kwargs['project_id']
  147. project = get_object_or_404(Project, pk=project_id)
  148. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  149. self.queryset = DocumentAnnotation.objects.all()
  150. elif project.is_type_of(Project.SEQUENCE_LABELING):
  151. self.queryset = SequenceAnnotation.objects.all()
  152. elif project.is_type_of(Project.Seq2seq):
  153. self.queryset = Seq2seqAnnotation.objects.all()
  154. queryset = self.queryset.filter(document=doc_id)
  155. return queryset
  156. def post(self, request, *args, **kwargs):
  157. doc = get_object_or_404(Document, pk=self.kwargs['doc_id'])
  158. label = get_object_or_404(Label, pk=request.data['label_id'])
  159. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  160. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  161. self.serializer_class = DocumentAnnotationSerializer
  162. annotation = DocumentAnnotation(document=doc, label=label, manual=True,
  163. user=self.request.user)
  164. elif project.is_type_of(Project.SEQUENCE_LABELING):
  165. self.serializer_class = SequenceAnnotationSerializer
  166. annotation = SequenceAnnotation(document=doc, label=label, manual=True,
  167. user=self.request.user,
  168. start_offset=request.data['start_offset'],
  169. end_offset=request.data['end_offset'])
  170. elif project.is_type_of(Project.Seq2seq):
  171. self.serializer_class = Seq2seqAnnotationSerializer
  172. annotation = Seq2seqAnnotation(document=doc, manual=True, user=self.request.user)
  173. annotation.save()
  174. serializer = self.serializer_class(annotation)
  175. return Response(serializer.data)
  176. class AnnotationAPI(generics.RetrieveUpdateDestroyAPIView):
  177. def get_queryset(self):
  178. doc_id = self.kwargs['doc_id']
  179. project_id = self.kwargs['project_id']
  180. project = get_object_or_404(Project, pk=project_id)
  181. if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  182. self.queryset = DocumentAnnotation.objects.all()
  183. elif project.is_type_of(Project.SEQUENCE_LABELING):
  184. self.queryset = SequenceAnnotation.objects.all()
  185. elif project.is_type_of(Project.Seq2seq):
  186. self.queryset = Seq2seqAnnotation.objects.all()
  187. queryset = self.queryset.filter(document=doc_id)
  188. return queryset
  189. def get_object(self):
  190. annotation_id = self.kwargs['annotation_id']
  191. queryset = self.filter_queryset(self.get_queryset())
  192. obj = get_object_or_404(queryset, pk=annotation_id)
  193. self.check_object_permissions(self.request, obj)
  194. return obj