diff --git a/app/server/api.py b/app/server/api.py index 5c0a13fe..9fadb468 100644 --- a/app/server/api.py +++ b/app/server/api.py @@ -2,18 +2,23 @@ from collections import Counter from django.shortcuts import get_object_or_404 from django_filters.rest_framework import DjangoFilterBackend +from django.db.models import Count from rest_framework import generics, filters, status -from rest_framework.exceptions import ParseError +from rest_framework.exceptions import ParseError, ValidationError from rest_framework.permissions import IsAuthenticated, IsAdminUser from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.parsers import MultiPartParser +from rest_framework_csv.renderers import CSVRenderer from .filters import DocumentFilter from .models import Project, Label, Document from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsOwnAnnotation from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer from .serializers import ProjectPolymorphicSerializer +from .utils import CSVParser, JSONParser, PlainTextParser, CoNLLParser +from .utils import JSONLRenderer +from .utils import JSONPainter, CSVPainter class ProjectList(generics.ListCreateAPIView): @@ -51,24 +56,23 @@ class StatisticsAPI(APIView): return Response(response) def progress(self, project): - total = project.documents.count() - remaining = 0 + docs = project.documents annotation_class = project.get_annotation_class() - for d in project.documents.all(): - count = annotation_class.objects.filter(document=d).count() - if count == 0: - remaining += 1 + total = docs.count() + done = annotation_class.objects.filter(document_id__in=docs.all()).\ + aggregate(Count('document', distinct=True))['document__count'] + remaining = total - done return {'total': total, 'remaining': remaining} def label_per_data(self, project): label_count = Counter() user_count = Counter() annotation_class = project.get_annotation_class() - for doc in project.documents.all(): - annotations = annotation_class.objects.filter(document=doc.id) - for a in annotations: - label_count[a.label.text] += 1 - user_count[a.user.username] += 1 + docs = project.documents.all() + annotations = annotation_class.objects.filter(document_id__in=docs.all()) + for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')): + label_count[d['label__text']] += d['label__count'] + user_count[d['user__username']] += d['user__count'] return label_count, user_count @@ -132,9 +136,14 @@ class AnnotationList(generics.ListCreateAPIView): def get_queryset(self): project = get_object_or_404(Project, pk=self.kwargs['project_id']) model = project.get_annotation_class() - self.queryset = model.objects.filter(document=self.kwargs['doc_id'], user=self.request.user) + self.queryset = model.objects.filter(document=self.kwargs['doc_id'], + user=self.request.user) return self.queryset + def create(self, request, *args, **kwargs): + request.data['document'] = self.kwargs['doc_id'] + return super().create(request, args, kwargs) + def perform_create(self, serializer): doc = get_object_or_404(Document, pk=self.kwargs['doc_id']) serializer.save(document=doc, user=self.request.user) @@ -164,18 +173,41 @@ class TextUploadAPI(APIView): if 'file' not in request.data: raise ParseError('Empty content') project = get_object_or_404(Project, pk=self.kwargs['project_id']) - handler = project.get_file_handler(request.data['format']) - handler.handle_uploaded_file(request.data['file'], self.request.user) + parser = self.select_parser(request.data['format']) + data = parser.parse(request.data['file']) + storage = project.get_storage(data) + storage.save(self.request.user) return Response(status=status.HTTP_201_CREATED) + def select_parser(self, format): + if format == 'plain': + return PlainTextParser() + elif format == 'csv': + return CSVParser() + elif format == 'json': + return JSONParser() + elif format == 'conll': + return CoNLLParser() + else: + raise ValidationError('format {} is invalid.'.format(format)) + class TextDownloadAPI(APIView): permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUser) + renderer_classes = (CSVRenderer, JSONLRenderer) def get(self, request, *args, **kwargs): - project_id = self.kwargs['project_id'] format = request.query_params.get('q') - project = get_object_or_404(Project, pk=project_id) - handler = project.get_file_handler(format) - response = handler.render() - return response + project = get_object_or_404(Project, pk=self.kwargs['project_id']) + documents = project.documents.all() + painter = self.select_painter(format) + data = painter.paint(documents) + return Response(data) + + def select_painter(self, format): + if format == 'csv': + return CSVPainter() + elif format == 'json': + return JSONPainter() + else: + raise ValidationError('format {} is invalid.'.format(format)) diff --git a/app/server/models.py b/app/server/models.py index 93f9010f..441b05cb 100644 --- a/app/server/models.py +++ b/app/server/models.py @@ -58,7 +58,7 @@ class Project(PolymorphicModel): def get_annotation_class(self): raise NotImplementedError() - def get_file_handler(self, format): + def get_storage(self, data): raise NotImplementedError() def __str__(self): @@ -87,17 +87,9 @@ class TextClassificationProject(Project): def get_annotation_class(self): return DocumentAnnotation - def get_file_handler(self, format): - from .utils import JsonClassificationHandler - from .utils import CSVClassificationHandler - from .utils import PlainTextHandler - if format == 'plain': - return PlainTextHandler(self) - elif format == 'csv': - return CSVClassificationHandler(self) - elif format == 'json': - return JsonClassificationHandler(self) - raise ValidationError('format {} is invalid.'.format(format)) + def get_storage(self, data): + from .utils import ClassificationStorage + return ClassificationStorage(data, self) class SequenceLabelingProject(Project): @@ -122,17 +114,9 @@ class SequenceLabelingProject(Project): def get_annotation_class(self): return SequenceAnnotation - def get_file_handler(self, format): - from .utils import JsonLabelingHandler - from .utils import PlainTextHandler - from .utils import CoNLLHandler - if format == 'plain': - return PlainTextHandler(self) - elif format == 'conll': - return CoNLLHandler(self) - elif format == 'json': - return JsonLabelingHandler(self) - raise ValidationError('format {} is invalid.'.format(format)) + def get_storage(self, data): + from .utils import SequenceLabelingStorage + return SequenceLabelingStorage(data, self) class Seq2seqProject(Project): @@ -157,17 +141,9 @@ class Seq2seqProject(Project): def get_annotation_class(self): return Seq2seqAnnotation - def get_file_handler(self, format): - from .utils import JsonSeq2seqHandler - from .utils import CSVSeq2seqHandler - from .utils import PlainTextHandler - if format == 'plain': - return PlainTextHandler(self) - elif format == 'csv': - return CSVSeq2seqHandler(self) - elif format == 'json': - return JsonSeq2seqHandler(self) - raise ValidationError('format {} is invalid.'.format(format)) + def get_storage(self, data): + from .utils import Seq2seqStorage + return Seq2seqStorage(data, self) class Label(models.Model): diff --git a/app/server/serializers.py b/app/server/serializers.py index 48b9c83f..889286da 100644 --- a/app/server/serializers.py +++ b/app/server/serializers.py @@ -87,34 +87,29 @@ class ProjectFilteredPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField): class DocumentAnnotationSerializer(serializers.ModelSerializer): # label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all()) label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all()) + document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all()) class Meta: model = DocumentAnnotation - fields = ('id', 'prob', 'label', 'user') + fields = ('id', 'prob', 'label', 'user', 'document') read_only_fields = ('user', ) - def create(self, validated_data): - annotation = DocumentAnnotation.objects.create(**validated_data) - return annotation - class SequenceAnnotationSerializer(serializers.ModelSerializer): #label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all()) label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all()) + document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all()) class Meta: model = SequenceAnnotation - fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user') + fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user', 'document') read_only_fields = ('user',) - def create(self, validated_data): - annotation = SequenceAnnotation.objects.create(**validated_data) - return annotation - class Seq2seqAnnotationSerializer(serializers.ModelSerializer): + document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all()) class Meta: model = Seq2seqAnnotation - fields = ('id', 'text', 'user') + fields = ('id', 'text', 'user', 'document') read_only_fields = ('user',) diff --git a/app/server/static/js/upload.js b/app/server/static/js/upload.js index be46e5d8..508801e6 100644 --- a/app/server/static/js/upload.js +++ b/app/server/static/js/upload.js @@ -8,11 +8,13 @@ const vm = new Vue({ file: '', messages: [], format: 'json', + isLoading: false, }, methods: { upload() { + this.isLoading = true; this.file = this.$refs.file.files[0]; let formData = new FormData(); formData.append('file', this.file); @@ -27,8 +29,10 @@ const vm = new Vue({ .then((response) => { console.log(response); this.messages = []; + window.location = window.location.pathname.split('/').slice(0, -1).join('/'); }) .catch((error) => { + this.isLoading = false; if ('detail' in error.response.data) { this.messages.push(error.response.data.detail); } else if ('text' in error.response.data) { @@ -38,6 +42,14 @@ const vm = new Vue({ }, download() { + let headers = {}; + if (this.format === 'csv') { + headers.Accept = 'text/csv; charset=utf-8'; + headers['Content-Type'] = 'text/csv; charset=utf-8'; + } else { + headers.Accept = 'application/json'; + headers['Content-Type'] = 'application/json'; + } HTTP({ url: 'docs/download', method: 'GET', @@ -45,6 +57,7 @@ const vm = new Vue({ params: { q: this.format, }, + headers, }).then((response) => { const url = window.URL.createObjectURL(new Blob([response.data])); const link = document.createElement('a'); diff --git a/app/server/templates/admin/upload/base.html b/app/server/templates/admin/upload/base.html index 53f77591..e9586278 100644 --- a/app/server/templates/admin/upload/base.html +++ b/app/server/templates/admin/upload/base.html @@ -27,7 +27,7 @@