diff --git a/app/server/api.py b/app/server/api.py
index 5c0a13fe..9fadb468 100644
--- a/app/server/api.py
+++ b/app/server/api.py
@@ -2,18 +2,23 @@ from collections import Counter
from django.shortcuts import get_object_or_404
from django_filters.rest_framework import DjangoFilterBackend
+from django.db.models import Count
from rest_framework import generics, filters, status
-from rest_framework.exceptions import ParseError
+from rest_framework.exceptions import ParseError, ValidationError
from rest_framework.permissions import IsAuthenticated, IsAdminUser
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.parsers import MultiPartParser
+from rest_framework_csv.renderers import CSVRenderer
from .filters import DocumentFilter
from .models import Project, Label, Document
from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsOwnAnnotation
from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer
from .serializers import ProjectPolymorphicSerializer
+from .utils import CSVParser, JSONParser, PlainTextParser, CoNLLParser
+from .utils import JSONLRenderer
+from .utils import JSONPainter, CSVPainter
class ProjectList(generics.ListCreateAPIView):
@@ -51,24 +56,23 @@ class StatisticsAPI(APIView):
return Response(response)
def progress(self, project):
- total = project.documents.count()
- remaining = 0
+ docs = project.documents
annotation_class = project.get_annotation_class()
- for d in project.documents.all():
- count = annotation_class.objects.filter(document=d).count()
- if count == 0:
- remaining += 1
+ total = docs.count()
+ done = annotation_class.objects.filter(document_id__in=docs.all()).\
+ aggregate(Count('document', distinct=True))['document__count']  # distinct documents that have at least one annotation
+ remaining = total - done  # documents with no annotation yet
return {'total': total, 'remaining': remaining}
def label_per_data(self, project):
label_count = Counter()
user_count = Counter()
annotation_class = project.get_annotation_class()
- for doc in project.documents.all():
- annotations = annotation_class.objects.filter(document=doc.id)
- for a in annotations:
- label_count[a.label.text] += 1
- user_count[a.user.username] += 1
+ docs = project.documents.all()
+ annotations = annotation_class.objects.filter(document_id__in=docs.all())  # annotations on this project's documents
+ for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')):  # one grouped row per (label text, username) pair
+ label_count[d['label__text']] += d['label__count']
+ user_count[d['user__username']] += d['user__count']
return label_count, user_count
@@ -132,9 +136,14 @@ class AnnotationList(generics.ListCreateAPIView):
def get_queryset(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
model = project.get_annotation_class()
- self.queryset = model.objects.filter(document=self.kwargs['doc_id'], user=self.request.user)
+ self.queryset = model.objects.filter(document=self.kwargs['doc_id'],
+ user=self.request.user)  # only the requesting user's annotations on this document
return self.queryset
+    def create(self, request, *args, **kwargs):
+        request.data['document'] = self.kwargs['doc_id']  # the document id comes from the URL, not from the client payload
+        return super().create(request, *args, **kwargs)  # unpack: bare args/kwargs would be forwarded as two extra positionals
+
def perform_create(self, serializer):
doc = get_object_or_404(Document, pk=self.kwargs['doc_id'])
serializer.save(document=doc, user=self.request.user)
@@ -164,18 +173,41 @@ class TextUploadAPI(APIView):
if 'file' not in request.data:
raise ParseError('Empty content')
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
- handler = project.get_file_handler(request.data['format'])
- handler.handle_uploaded_file(request.data['file'], self.request.user)
+ parser = self.select_parser(request.data['format'])
+ data = parser.parse(request.data['file'])
+ storage = project.get_storage(data)
+ storage.save(self.request.user)
return Response(status=status.HTTP_201_CREATED)
+    def select_parser(self, format):
+        """Return a new parser instance for the upload `format`.
+
+        Raises ValidationError for any format other than
+        'plain', 'csv', 'json' or 'conll'.
+        """
+        parsers = {'plain': PlainTextParser, 'csv': CSVParser,
+                   'json': JSONParser, 'conll': CoNLLParser}
+        if format not in parsers:
+            raise ValidationError('format {} is invalid.'.format(format))
+        return parsers[format]()
+
class TextDownloadAPI(APIView):
permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUser)
+ renderer_classes = (CSVRenderer, JSONLRenderer)
def get(self, request, *args, **kwargs):
- project_id = self.kwargs['project_id']
format = request.query_params.get('q')
- project = get_object_or_404(Project, pk=project_id)
- handler = project.get_file_handler(format)
- response = handler.render()
- return response
+ project = get_object_or_404(Project, pk=self.kwargs['project_id'])
+ documents = project.documents.all()
+ painter = self.select_painter(format)  # raises ValidationError unless format is 'csv' or 'json'
+ data = painter.paint(documents)
+ return Response(data)  # serialization is left to renderer_classes (CSVRenderer, JSONLRenderer)
+
+    def select_painter(self, format):
+        """Return a new painter for download `format` ('csv' or 'json'), else raise ValidationError."""
+        painters = {'csv': CSVPainter, 'json': JSONPainter}
+        painter_class = painters.get(format)
+        if painter_class is None:
+            raise ValidationError('format {} is invalid.'.format(format))
+        return painter_class()
diff --git a/app/server/models.py b/app/server/models.py
index 93f9010f..441b05cb 100644
--- a/app/server/models.py
+++ b/app/server/models.py
@@ -58,7 +58,7 @@ class Project(PolymorphicModel):
def get_annotation_class(self):
raise NotImplementedError()
- def get_file_handler(self, format):
+ def get_storage(self, data):  # abstract hook: the concrete project types override this
raise NotImplementedError()
def __str__(self):
@@ -87,17 +87,9 @@ class TextClassificationProject(Project):
def get_annotation_class(self):
return DocumentAnnotation
- def get_file_handler(self, format):
- from .utils import JsonClassificationHandler
- from .utils import CSVClassificationHandler
- from .utils import PlainTextHandler
- if format == 'plain':
- return PlainTextHandler(self)
- elif format == 'csv':
- return CSVClassificationHandler(self)
- elif format == 'json':
- return JsonClassificationHandler(self)
- raise ValidationError('format {} is invalid.'.format(format))
+ def get_storage(self, data):
+ from .utils import ClassificationStorage  # function-scope import; presumably avoids a models/utils import cycle — confirm
+ return ClassificationStorage(data, self)  # built from the caller's data and this project
class SequenceLabelingProject(Project):
@@ -122,17 +114,9 @@ class SequenceLabelingProject(Project):
def get_annotation_class(self):
return SequenceAnnotation
- def get_file_handler(self, format):
- from .utils import JsonLabelingHandler
- from .utils import PlainTextHandler
- from .utils import CoNLLHandler
- if format == 'plain':
- return PlainTextHandler(self)
- elif format == 'conll':
- return CoNLLHandler(self)
- elif format == 'json':
- return JsonLabelingHandler(self)
- raise ValidationError('format {} is invalid.'.format(format))
+ def get_storage(self, data):
+ from .utils import SequenceLabelingStorage  # function-scope import; presumably avoids a models/utils import cycle — confirm
+ return SequenceLabelingStorage(data, self)  # built from the caller's data and this project
class Seq2seqProject(Project):
@@ -157,17 +141,9 @@ class Seq2seqProject(Project):
def get_annotation_class(self):
return Seq2seqAnnotation
- def get_file_handler(self, format):
- from .utils import JsonSeq2seqHandler
- from .utils import CSVSeq2seqHandler
- from .utils import PlainTextHandler
- if format == 'plain':
- return PlainTextHandler(self)
- elif format == 'csv':
- return CSVSeq2seqHandler(self)
- elif format == 'json':
- return JsonSeq2seqHandler(self)
- raise ValidationError('format {} is invalid.'.format(format))
+ def get_storage(self, data):
+ from .utils import Seq2seqStorage  # function-scope import; presumably avoids a models/utils import cycle — confirm
+ return Seq2seqStorage(data, self)  # built from the caller's data and this project
class Label(models.Model):
diff --git a/app/server/serializers.py b/app/server/serializers.py
index 48b9c83f..889286da 100644
--- a/app/server/serializers.py
+++ b/app/server/serializers.py
@@ -87,34 +87,29 @@ class ProjectFilteredPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField):
class DocumentAnnotationSerializer(serializers.ModelSerializer):
# label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
+ document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())  # NOTE(review): queryset is every Document, not only the current project's
class Meta:
model = DocumentAnnotation
- fields = ('id', 'prob', 'label', 'user')
+ fields = ('id', 'prob', 'label', 'user', 'document')
read_only_fields = ('user', )
- def create(self, validated_data):
- annotation = DocumentAnnotation.objects.create(**validated_data)
- return annotation
-
class SequenceAnnotationSerializer(serializers.ModelSerializer):
#label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
+ document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())  # NOTE(review): queryset is every Document, not only the current project's
class Meta:
model = SequenceAnnotation
- fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user')
+ fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user', 'document')
read_only_fields = ('user',)
- def create(self, validated_data):
- annotation = SequenceAnnotation.objects.create(**validated_data)
- return annotation
-
class Seq2seqAnnotationSerializer(serializers.ModelSerializer):
+ document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())  # NOTE(review): queryset is every Document, not only the current project's
class Meta:
model = Seq2seqAnnotation
- fields = ('id', 'text', 'user')
+ fields = ('id', 'text', 'user', 'document')
read_only_fields = ('user',)
diff --git a/app/server/static/js/upload.js b/app/server/static/js/upload.js
index be46e5d8..508801e6 100644
--- a/app/server/static/js/upload.js
+++ b/app/server/static/js/upload.js
@@ -8,11 +8,13 @@ const vm = new Vue({
file: '',
messages: [],
format: 'json',
+ isLoading: false,
},
methods: {
upload() {
+ this.isLoading = true;
this.file = this.$refs.file.files[0];
let formData = new FormData();
formData.append('file', this.file);
@@ -27,8 +29,10 @@ const vm = new Vue({
.then((response) => {
console.log(response);
this.messages = [];
+ window.location = window.location.pathname.split('/').slice(0, -1).join('/');
})
.catch((error) => {
+ this.isLoading = false;
if ('detail' in error.response.data) {
this.messages.push(error.response.data.detail);
} else if ('text' in error.response.data) {
@@ -38,6 +42,14 @@ const vm = new Vue({
},
download() {
+ let headers = {};
+ if (this.format === 'csv') {
+ headers.Accept = 'text/csv; charset=utf-8';
+ headers['Content-Type'] = 'text/csv; charset=utf-8';
+ } else {
+ headers.Accept = 'application/json';
+ headers['Content-Type'] = 'application/json';
+ }
HTTP({
url: 'docs/download',
method: 'GET',
@@ -45,6 +57,7 @@ const vm = new Vue({
params: {
q: this.format,
},
+ headers,
}).then((response) => {
const url = window.URL.createObjectURL(new Blob([response.data]));
const link = document.createElement('a');
diff --git a/app/server/templates/admin/upload/base.html b/app/server/templates/admin/upload/base.html
index 53f77591..e9586278 100644
--- a/app/server/templates/admin/upload/base.html
+++ b/app/server/templates/admin/upload/base.html
@@ -27,7 +27,7 @@