You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

347 lines
12 KiB

import json
import csv
from itertools import chain
from collections import Counter
from io import TextIOWrapper
from django import forms
from django.contrib.auth.mixins import UserPassesTestMixin
from django.urls import reverse
from django_filters.rest_framework import DjangoFilterBackend
from django.http import JsonResponse, HttpResponse, HttpResponseRedirect
from django.shortcuts import render, get_object_or_404
from django.views import View
from django.views.generic import TemplateView
from django.views.generic.list import ListView
from django.views.generic.detail import DetailView
from django.contrib.auth.mixins import LoginRequiredMixin
from rest_framework import viewsets, filters, generics
from rest_framework.views import APIView
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.permissions import SAFE_METHODS, BasePermission, IsAdminUser, IsAuthenticated
from .models import Label, Document, Project
from .models import DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation
from .serializers import LabelSerializer, ProjectSerializer
class SuperUserMixin(UserPassesTestMixin):
    """Access mixin whose test passes only for superusers."""

    def test_func(self):
        """Return the ``is_superuser`` flag of the requesting user."""
        current_user = self.request.user
        return current_user.is_superuser
class IndexView(TemplateView):
    """Template view that renders ``index.html``."""

    template_name = 'index.html'
class ProjectView(LoginRequiredMixin, TemplateView):
    """Login-protected page whose template is supplied by the project."""

    # Default value; overwritten in get_context_data() on every request.
    template_name = 'annotation.html'

    def get_context_data(self, **kwargs):
        """Build the template context and switch to the project's template.

        Responds with 404 when ``project_id`` does not match a project.
        """
        ctx = super().get_context_data(**kwargs)
        current = get_object_or_404(Project, pk=kwargs.get('project_id'))
        # The template to render is whatever the project's get_template() returns.
        self.template_name = current.get_template()
        return ctx
class ProjectForm(forms.ModelForm):
    """Model form for ``Project`` limited to four fields."""

    class Meta:
        fields = ('name', 'description', 'project_type', 'users')
        model = Project
class ProjectsView(LoginRequiredMixin, TemplateView):
    """Login-protected page that shows and processes the project form."""

    template_name = 'projects.html'
    model = Project
    paginate_by = 100

    def _render_form(self, request, form):
        """Render the page template with ``form`` as its only context entry."""
        return render(request, self.template_name, {'form': form})

    def get(self, request, *args, **kwargs):
        """Show the page with an unbound project form."""
        return self._render_form(request, ProjectForm())

    def post(self, request, *args, **kwargs):
        """Save a valid form and redirect to 'upload'; otherwise re-render it."""
        submitted = ProjectForm(request.POST)
        if not submitted.is_valid():
            # Hand the bound form back so its errors can be displayed.
            return self._render_form(request, submitted)
        created = submitted.save()
        return HttpResponseRedirect(reverse('upload', args=[created.id]))
class DatasetView(SuperUserMixin, LoginRequiredMixin, ListView):
    """Superuser-only list of a project's documents, five per page."""

    paginate_by = 5
    context_object_name = 'documents'
    template_name = 'admin/dataset.html'

    def get_queryset(self):
        """Return all documents of the URL's project (404 if it is unknown)."""
        owner = get_object_or_404(Project, pk=self.kwargs['project_id'])
        return owner.documents.all()
class LabelView(SuperUserMixin, LoginRequiredMixin, TemplateView):
    """Superuser-only page rendering ``admin/label.html``."""

    template_name = 'admin/label.html'
class StatsView(SuperUserMixin, LoginRequiredMixin, TemplateView):
    """Superuser-only page rendering ``admin/stats.html``."""

    template_name = 'admin/stats.html'
class DatasetUpload(SuperUserMixin, LoginRequiredMixin, TemplateView):
    """Superuser-only page for importing documents from an uploaded CSV file."""

    model = Project
    template_name = 'admin/dataset_upload.html'

    def post(self, request, *args, **kwargs):
        """Create one ``Document`` per non-empty row of the uploaded CSV.

        The file is read from ``request.FILES['csv_file']`` as UTF-8 and the
        first column of each row becomes the document text.

        Returns:
            A redirect to 'dataset' on success, or back to 'dataset-upload'
            when the upload is missing, cannot be parsed, or cannot be saved.

        Raises:
            Http404: if ``project_id`` does not match a project.
        """
        import logging

        project = get_object_or_404(Project, pk=kwargs.get('project_id'))
        try:
            form_data = TextIOWrapper(request.FILES['csv_file'].file, encoding='utf-8')
            # Parse the whole file before saving anything so that a decoding
            # or CSV error cannot leave a half-imported dataset behind.
            # Blank lines yield empty rows and are skipped.
            texts = [row[0] for row in csv.reader(form_data) if row]
            for text in texts:
                Document(text=text, project=project).save()
        except Exception:
            # Best-effort import: record the traceback and send the user back
            # to the upload form. SystemExit/KeyboardInterrupt still propagate.
            logging.getLogger(__name__).exception(
                'Dataset upload failed for project %s', project.id)
            return HttpResponseRedirect(reverse('dataset-upload', args=[project.id]))
        return HttpResponseRedirect(reverse('dataset', args=[project.id]))
class DataDownload(SuperUserMixin, View):
    """Superuser-only CSV download of a project's dataset."""

    def get(self, request, *args, **kwargs):
        """Stream the rows of each document's ``make_dataset()`` as a CSV attachment.

        Only documents returned by ``get_documents(is_null=False)`` are
        exported. Responds with 404 when the project does not exist.
        """
        target = get_object_or_404(Project, pk=self.kwargs['project_id'])

        # Lower-case the name and collapse whitespace runs into underscores.
        stem = '_'.join(target.name.lower().split())
        disposition = 'attachment; filename="{}.csv"'.format(stem)

        response = HttpResponse(content_type='text/csv')
        response['Content-Disposition'] = disposition

        out = csv.writer(response)
        for document in target.get_documents(is_null=False).distinct():
            out.writerows(document.make_dataset())
        return response
class IsProjectUser(BasePermission):
    """Permission granted to users who belong to the project in the URL."""

    def has_permission(self, request, view):
        """Return True when the requesting user is among the project's users.

        Raises:
            Http404: if the view's ``project_id`` does not match a project.
        """
        project = get_object_or_404(Project, pk=view.kwargs.get('project_id'))
        # Ask the database for membership instead of loading every project
        # user into memory just to test for one of them.
        return project.users.filter(pk=request.user.pk).exists()
class IsAdminUserAndWriteOnly(BasePermission):
    """Allow safe methods to everyone; defer to ``IsAdminUser`` for the rest."""

    def has_permission(self, request, view):
        """Return True for SAFE_METHODS, else the ``IsAdminUser`` verdict."""
        is_read_only = request.method in SAFE_METHODS
        return is_read_only or IsAdminUser().has_permission(request, view)
class IsOwnAnnotation(BasePermission):
    """Permission granted only to the user who owns the annotation in the URL."""

    def has_permission(self, request, view):
        """Return True when the annotation's ``user`` is the requesting user.

        The annotation model is the one returned by the project's
        ``get_annotation_class()``.

        Raises:
            Http404: if the project or the annotation does not exist.
        """
        project = get_object_or_404(Project, pk=view.kwargs.get('project_id'))
        Annotation = project.get_annotation_class()
        # An unknown id must answer 404, not leak DoesNotExist as a 500.
        annotation = get_object_or_404(Annotation, pk=view.kwargs.get('annotation_id'))
        return annotation.user == request.user
class ProjectViewSet(viewsets.ModelViewSet):
    """Project API restricted to the projects the requesting user belongs to."""

    queryset = Project.objects.all()
    serializer_class = ProjectSerializer
    pagination_class = None
    permission_classes = (IsAuthenticated, IsAdminUserAndWriteOnly)

    def get_queryset(self):
        """Return the projects whose ``users`` include the requesting user."""
        user = self.request.user
        # Exact match on the id: a ``contains`` lookup is a substring test and
        # would let user 1 see the projects of users 10, 11, 21, ...
        return self.queryset.filter(users__id=user.id)

    @action(methods=['get'], detail=True)
    def progress(self, request, pk=None):
        """Report the project's document count and how many remain.

        ``remaining`` counts the documents returned by
        ``get_documents(is_null=True)``.
        """
        project = self.get_object()
        docs = project.get_documents(is_null=True)
        total = project.documents.count()
        remaining = docs.count()
        return Response({'total': total, 'remaining': remaining})
class ProjectLabelsAPI(generics.ListCreateAPIView):
    """List the labels of a project or create a new one in it."""

    permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
    pagination_class = None
    serializer_class = LabelSerializer
    queryset = Label.objects.all()

    def get_queryset(self):
        """Return only the labels attached to the URL's project."""
        return self.queryset.filter(project=self.kwargs['project_id'])

    def perform_create(self, serializer):
        """Save the new label with the URL's project (404 if it is unknown)."""
        owner = get_object_or_404(Project, pk=self.kwargs['project_id'])
        serializer.save(project=owner)
class ProjectStatsAPI(APIView):
    """Per-label and per-user annotation counts for one project."""

    pagination_class = None
    permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)

    def get(self, request, *args, **kwargs):
        """Count the project's annotations by label text and by username.

        Returns:
            A response shaped as
            ``{'label': {'labels': [...], 'data': [...]},
               'user': {'users': [...], 'data': [...]}}``
            where each ``data`` list is aligned with its name list and holds
            0 for names that have no annotations.

        Raises:
            Http404: if ``project_id`` does not match a project.
        """
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        labels = [label.text for label in project.labels.all()]
        users = [user.username for user in project.users.all()]

        label_count = Counter()
        user_count = Counter()
        for doc in project.documents.all():
            # Fetch each document's annotations once and feed both tallies,
            # rather than calling get_annotations() twice per document.
            for annotation in doc.get_annotations():
                label_count[annotation.label.text] += 1
                user_count[annotation.user.username] += 1

        # Counter returns 0 for missing keys, so unused names need no special case.
        response = {
            'label': {'labels': labels, 'data': [label_count[name] for name in labels]},
            'user': {'users': users, 'data': [user_count[name] for name in users]},
        }
        return Response(response)
class ProjectLabelAPI(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update or delete a single label of a project."""

    permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUser)
    serializer_class = LabelSerializer
    queryset = Label.objects.all()

    def get_queryset(self):
        """Return only the labels attached to the URL's project."""
        return self.queryset.filter(project=self.kwargs['project_id'])

    def get_object(self):
        """Look the label up by the ``label_id`` URL argument (404 if absent)."""
        candidates = self.filter_queryset(self.get_queryset())
        label = get_object_or_404(candidates, pk=self.kwargs['label_id'])
        self.check_object_permissions(self.request, label)
        return label
class ProjectDocsAPI(generics.ListCreateAPIView):
    """List or create the documents of a project, with search and ordering."""

    permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
    search_fields = ('text', )
    filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter)
    queryset = Document.objects.all()

    def get_serializer_class(self):
        """Use (and remember on the view) the project's own serializer."""
        owner = get_object_or_404(Project, pk=self.kwargs['project_id'])
        self.serializer_class = owner.get_project_serializer()
        return self.serializer_class

    def get_queryset(self):
        """Return the project's documents, optionally narrowed by ``is_checked``.

        With no (or an empty) ``is_checked`` query parameter every document of
        the project is returned. Otherwise the result comes from the project's
        ``get_documents()``, called with True only when the value is ``'true'``.
        """
        project_id = self.kwargs['project_id']
        flag = self.request.query_params.get('is_checked')
        if not flag:
            return self.queryset.filter(project=project_id)
        owner = get_object_or_404(Project, pk=project_id)
        return owner.get_documents(flag == 'true').distinct()
class AnnotationsAPI(generics.ListCreateAPIView):
    """List the annotations of a document or add a new one to it."""

    pagination_class = None
    permission_classes = (IsAuthenticated, IsProjectUser)

    def get_serializer_class(self):
        """Use (and remember on the view) the project's annotation serializer."""
        project_id = self.kwargs['project_id']
        project = get_object_or_404(Project, pk=project_id)
        self.serializer_class = project.get_annotation_serializer()
        return self.serializer_class

    def get_queryset(self):
        """Return the annotations of the URL's document within the URL's project."""
        project_id = self.kwargs['project_id']
        project = get_object_or_404(Project, pk=project_id)
        doc_id = self.kwargs['doc_id']
        document = get_object_or_404(Document, pk=doc_id, project=project)
        self.queryset = document.get_annotations()
        return self.queryset

    def post(self, request, *args, **kwargs):
        """Create a manual annotation owned by the requesting user.

        The annotation model depends on the project type. The request body
        must carry ``label_id`` (document classification), ``label_id``,
        ``start_offset`` and ``end_offset`` (sequence labeling), or ``text``
        (seq2seq).

        Returns:
            The serialized annotation, or a 400 response when a required
            field is missing or the project type is not one of the above.

        Raises:
            Http404: if the project, the document (within that project) or
                the label does not exist.
        """
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        # Scope the document to the URL's project, as get_queryset() does, so a
        # member of one project cannot annotate another project's documents.
        doc = get_object_or_404(Document, pk=self.kwargs['doc_id'], project=project)
        self.serializer_class = project.get_annotation_serializer()

        data = request.data
        shared = {'document': doc, 'manual': True, 'user': self.request.user}
        try:
            if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
                label = get_object_or_404(Label, pk=data['label_id'])
                annotation = DocumentAnnotation(label=label, **shared)
            elif project.is_type_of(Project.SEQUENCE_LABELING):
                label = get_object_or_404(Label, pk=data['label_id'])
                annotation = SequenceAnnotation(label=label,
                                                start_offset=data['start_offset'],
                                                end_offset=data['end_offset'],
                                                **shared)
            elif project.is_type_of(Project.Seq2seq):
                annotation = Seq2seqAnnotation(text=data['text'], **shared)
            else:
                # Previously this fell through to an unbound ``annotation``.
                return Response({'detail': 'Unsupported project type.'}, status=400)
        except KeyError as missing:
            # A missing body field is a client error, not a server crash.
            detail = 'Missing field: {}'.format(missing.args[0])
            return Response({'detail': detail}, status=400)

        annotation.save()
        serializer = self.serializer_class(annotation)
        return Response(serializer.data)
class AnnotationAPI(generics.RetrieveUpdateDestroyAPIView):
    """Operate on a single annotation owned by the requesting user."""

    permission_classes = (IsAuthenticated, IsProjectUser, IsOwnAnnotation)

    def get_queryset(self):
        """Return the annotations of the URL's document within the URL's project."""
        # Restrict the document to the project the permissions were checked
        # against, so a document id from another project answers 404.
        document = get_object_or_404(Document, pk=self.kwargs['doc_id'],
                                     project=self.kwargs['project_id'])
        self.queryset = document.get_annotations()
        return self.queryset

    def get_object(self):
        """Look the annotation up by the ``annotation_id`` URL argument."""
        annotation_id = self.kwargs['annotation_id']
        queryset = self.filter_queryset(self.get_queryset())
        obj = get_object_or_404(queryset, pk=annotation_id)
        self.check_object_permissions(self.request, obj)
        return obj

    def put(self, request, *args, **kwargs):
        """Replace the text of a seq2seq annotation.

        The annotation is the one named in the URL, which is the one
        ``IsOwnAnnotation`` verified, never an id taken from the request body.

        Returns:
            The serialized annotation, or a 400 response when the project is
            not a seq2seq project or the body has no ``text`` field.

        Raises:
            Http404: if the project, document or annotation does not exist.
        """
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        self.serializer_class = project.get_annotation_serializer()

        if not project.is_type_of(Project.Seq2seq):
            # Previously this fell through to an unbound ``annotation``.
            return Response({'detail': 'Only seq2seq annotations can be updated.'},
                            status=400)
        if 'text' not in request.data:
            return Response({'detail': 'Missing field: text'}, status=400)

        # NOTE(review): presumably get_annotations() yields Seq2seqAnnotation
        # rows for a seq2seq project - confirm against the Document model.
        annotation = self.get_object()
        annotation.text = request.data['text']
        annotation.save()
        serializer = self.serializer_class(annotation)
        return Response(serializer.data)