diff --git a/app/server/admin.py b/app/server/admin.py index 2806c245..bf921b8b 100644 --- a/app/server/admin.py +++ b/app/server/admin.py @@ -2,6 +2,7 @@ from django.contrib import admin from .models import Label, Document, Project from .models import DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation +from .models import TextClassificationProject, SequenceLabelingProject, Seq2seqProject admin.site.register(DocumentAnnotation) admin.site.register(SequenceAnnotation) @@ -9,3 +10,6 @@ admin.site.register(Seq2seqAnnotation) admin.site.register(Label) admin.site.register(Document) admin.site.register(Project) +admin.site.register(TextClassificationProject) +admin.site.register(SequenceLabelingProject) +admin.site.register(Seq2seqProject) diff --git a/app/server/api.py b/app/server/api.py index 42f10bfa..34e8b742 100644 --- a/app/server/api.py +++ b/app/server/api.py @@ -5,10 +5,11 @@ from collections import Counter from itertools import chain from django.db import transaction +from django.http import HttpResponse from django.shortcuts import get_object_or_404 from django_filters.rest_framework import DjangoFilterBackend from rest_framework import generics, filters, status -from rest_framework.exceptions import ParseError +from rest_framework.exceptions import ParseError, ValidationError from rest_framework.permissions import IsAuthenticated, IsAdminUser from rest_framework.response import Response from rest_framework.views import APIView @@ -186,31 +187,50 @@ class TextUploadAPI(APIView): raise ParseError('Empty content') project = get_object_or_404(Project, pk=self.kwargs['project_id']) handler = project.get_upload_handler(request.data['format']) - handler.handle_uploaded_file(request.FILES['file'], project, self.request.user) + handler.handle_uploaded_file(request.FILES['file'], self.request.user) return Response(status=status.HTTP_201_CREATED) +class TextDownloadAPI(APIView): + """API for text download.""" + permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUser) + + def get(self, request, *args, **kwargs): + project_id = self.kwargs['project_id'] + format = request.query_params.get('q') + project = get_object_or_404(Project, pk=project_id) + handler = project.get_upload_handler(format) + response = handler.render() + return response + + class FileHandler(object): annotation_serializer = None + def __init__(self, project): + self.project = project + @transaction.atomic - def handle_uploaded_file(self, file, project, user): + def handle_uploaded_file(self, file, user): raise NotImplementedError() def parse(self, file): raise NotImplementedError() - def save_doc(self, data, project): + def render(self): + raise NotImplementedError() + + def save_doc(self, data): serializer = DocumentSerializer(data=data) serializer.is_valid(raise_exception=True) - doc = serializer.save(project=project) + doc = serializer.save(project=self.project) return doc - def save_label(self, data, project): - label = Label.objects.filter(project=project, **data).first() + def save_label(self, data): + label = Label.objects.filter(project=self.project, **data).first() serializer = LabelSerializer(label, data=data) serializer.is_valid(raise_exception=True) - label = serializer.save(project=project) + label = serializer.save(project=self.project) return label def save_annotation(self, data, doc, user): @@ -245,14 +265,14 @@ class CoNLLHandler(FileHandler): annotation_serializer = SequenceAnnotationSerializer @transaction.atomic - def handle_uploaded_file(self, file, project, user): + def handle_uploaded_file(self, file, user): for words, tags in self.parse(file): start_offset = 0 sent = ' '.join(words) - doc = self.save_doc({'text': sent}, project) + doc = self.save_doc({'text': sent}) for word, tag in zip(words, tags): label = extract_label(tag) - label = self.save_label({'text': label}, project) + label = self.save_label({'text': label}) data = {'start_offset': start_offset, 'end_offset': start_offset + len(word), 'label': label.id} @@ -277,6 +297,9 @@ class CoNLLHandler(FileHandler): if len(words) > 0: yield words, tags + def render(self): + raise ValidationError("This project type doesn't support CoNLL format.") + class PlainTextHandler(FileHandler): """Uploads plain text. @@ -289,15 +312,18 @@ class PlainTextHandler(FileHandler): ``` """ @transaction.atomic - def handle_uploaded_file(self, file, project, user): + def handle_uploaded_file(self, file, user): for text in self.parse(file): - self.save_doc({'text': text}, project) + self.save_doc({'text': text}) def parse(self, file): file = io.TextIOWrapper(file, encoding='utf-8') for i, line in enumerate(file, start=1): yield line.strip() + def render(self): + raise ValidationError("You cannot download plain text. Please specify csv or json.") + class CSVHandler(FileHandler): """Uploads csv file. @@ -327,27 +353,58 @@ class CSVHandler(FileHandler): else: raise FileParseException(line_num=i, line=row) + def render(self): + raise NotImplementedError() + class CSVClassificationHandler(CSVHandler): annotation_serializer = DocumentAnnotationSerializer @transaction.atomic - def handle_uploaded_file(self, file, project, user): + def handle_uploaded_file(self, file, user): for text, label in self.parse(file): - doc = self.save_doc({'text': text}, project) - label = self.save_label({'text': label}, project) + doc = self.save_doc({'text': text}) + label = self.save_label({'text': label}) self.save_annotation({'label': label.id}, doc, user) + def render(self): + queryset = self.project.documents.all() + serializer = DocumentSerializer(queryset, many=True) + filename = '_'.join(self.project.name.lower().split()) + response = HttpResponse(content_type='text/csv') + response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(filename) + writer = csv.writer(response) + writer.writerow(['id', 'text', 'label', 'user']) + for d in serializer.data: + for a in d['annotations']: + row = [d['id'], d['text'], a['label'], a['user']] + writer.writerow(row) + return response + class CSVSeq2seqHandler(CSVHandler): annotation_serializer = Seq2seqAnnotationSerializer @transaction.atomic - def handle_uploaded_file(self, file, project, user): + def handle_uploaded_file(self, file, user): for text, label in self.parse(file): - doc = self.save_doc({'text': text}, project) + doc = self.save_doc({'text': text}) self.save_annotation({'text': label}, doc, user) + def render(self): + queryset = self.project.documents.all() + serializer = DocumentSerializer(queryset, many=True) + filename = '_'.join(self.project.name.lower().split()) + response = HttpResponse(content_type='text/csv') + response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(filename) + writer = csv.writer(response) + writer.writerow(['id', 'text', 'label', 'user']) + for d in serializer.data: + for a in d['annotations']: + row = [d['id'], d['text'], a['text'], a['user']] + writer.writerow(row) + return response + class JsonHandler(FileHandler): """Uploads jsonl file. @@ -367,6 +424,17 @@ class JsonHandler(FileHandler): except json.decoder.JSONDecodeError: raise FileParseException(line_num=i, line=line) + def render(self): + queryset = self.project.documents.all() + serializer = DocumentSerializer(queryset, many=True) + filename = '_'.join(self.project.name.lower().split()) + response = HttpResponse(content_type='application/json') + response['Content-Disposition'] = 'attachment; filename="{}.jsonl"'.format(filename) + for d in serializer.data: + dump = json.dumps(d, ensure_ascii=False) + response.write(dump + '\n') + return response + class JsonClassificationHandler(JsonHandler): """Upload jsonl for text classification. @@ -380,11 +448,11 @@ class JsonClassificationHandler(JsonHandler): annotation_serializer = DocumentAnnotationSerializer @transaction.atomic - def handle_uploaded_file(self, file, project, user): + def handle_uploaded_file(self, file, user): for data in self.parse(file): - doc = self.save_doc(data, project) + doc = self.save_doc(data) for label in data['labels']: - label = self.save_label({'text': label}, project) + label = self.save_label({'text': label}) self.save_annotation({'label': label.id}, doc, user) @@ -400,11 +468,11 @@ class JsonLabelingHandler(JsonHandler): annotation_serializer = SequenceAnnotationSerializer @transaction.atomic - def handle_uploaded_file(self, file, project, user): + def handle_uploaded_file(self, file, user): for data in self.parse(file): - doc = self.save_doc(data, project) + doc = self.save_doc(data) for start_offset, end_offset, label in data['entities']: - label = self.save_label({'text': label}, project) + label = self.save_label({'text': label}) data = {'label': label.id, 'start_offset': start_offset, 'end_offset': end_offset} @@ -423,8 +491,8 @@ class JsonSeq2seqHandler(JsonHandler): annotation_serializer = Seq2seqAnnotationSerializer @transaction.atomic - def handle_uploaded_file(self, file, project, user): + def handle_uploaded_file(self, file, user): for data in self.parse(file): - doc = self.save_doc(data, project) + doc = self.save_doc(data) for label in data['labels']: self.save_annotation({'text': label}, doc, user) diff --git a/app/server/api_urls.py b/app/server/api_urls.py index 721199e2..c4a74087 100644 --- a/app/server/api_urls.py +++ b/app/server/api_urls.py @@ -6,7 +6,7 @@ from .api import LabelList, LabelDetail from .api import DocumentList, DocumentDetail from .api import EntityList, EntityDetail from .api import AnnotationList, AnnotationDetail -from .api import TextUploadAPI +from .api import TextUploadAPI, TextDownloadAPI from .api import StatisticsAPI @@ -32,7 +32,9 @@ urlpatterns = [ path('projects//docs//annotations/', AnnotationDetail.as_view(), name='annotation_detail'), path('projects//docs/upload', - TextUploadAPI.as_view(), name='doc_uploader') + TextUploadAPI.as_view(), name='doc_uploader'), + path('projects//docs/download', + TextDownloadAPI.as_view(), name='doc_downloader') ] urlpatterns = format_suffix_patterns(urlpatterns, allowed=['json', 'xml']) diff --git a/app/server/models.py b/app/server/models.py index 6cd0ae52..59d5af65 100644 --- a/app/server/models.py +++ b/app/server/models.py @@ -3,6 +3,7 @@ from django.db import models from django.urls import reverse from django.contrib.auth.models import User from django.contrib.staticfiles.storage import staticfiles_storage +from rest_framework.exceptions import ValidationError from polymorphic.models import PolymorphicModel from .utils import get_key_choices @@ -43,7 +44,7 @@ class Project(PolymorphicModel): raise NotImplementedError() def get_upload_handler(self, format): - raise NotImplementedError + raise NotImplementedError() def __str__(self): return self.name @@ -68,12 +69,12 @@ class TextClassificationProject(Project): def get_upload_handler(self, format): from .api import PlainTextHandler, CSVClassificationHandler, JsonClassificationHandler if format == 'plain': - return PlainTextHandler() + return PlainTextHandler(self) elif format == 'csv': - return CSVClassificationHandler() + return CSVClassificationHandler(self) elif format == 'json': - return JsonClassificationHandler() - raise ValueError('format {} is invalid.'.format(format)) + return JsonClassificationHandler(self) + raise ValidationError('format {} is invalid.'.format(format)) class SequenceLabelingProject(Project): @@ -95,12 +96,12 @@ class SequenceLabelingProject(Project): def get_upload_handler(self, format): from .api import PlainTextHandler, CoNLLHandler, JsonLabelingHandler if format == 'plain': - return PlainTextHandler() + return PlainTextHandler(self) elif format == 'conll': - return CoNLLHandler() + return CoNLLHandler(self) elif format == 'json': - return JsonLabelingHandler() - raise ValueError('format {} is invalid.'.format(format)) + return JsonLabelingHandler(self) + raise ValidationError('format {} is invalid.'.format(format)) class Seq2seqProject(Project): @@ -122,12 +123,12 @@ class Seq2seqProject(Project): def get_upload_handler(self, format): from .api import PlainTextHandler, CSVSeq2seqHandler, JsonSeq2seqHandler if format == 'plain': - return PlainTextHandler() + return PlainTextHandler(self) elif format == 'csv': - return CSVSeq2seqHandler() + return CSVSeq2seqHandler(self) elif format == 'json': - return JsonSeq2seqHandler() - raise ValueError('format {} is invalid.'.format(format)) + return JsonSeq2seqHandler(self) + raise ValidationError('format {} is invalid.'.format(format)) class Label(models.Model): diff --git a/app/server/serializers.py b/app/server/serializers.py index d1b29a7a..c8540633 100644 --- a/app/server/serializers.py +++ b/app/server/serializers.py @@ -19,14 +19,14 @@ class DocumentSerializer(serializers.ModelSerializer): def get_annotations(self, instance): request = self.context.get('request') - view = self.context.get('view', None) - if request and view: - project = get_object_or_404(Project, pk=view.kwargs['project_id']) - model = project.get_annotation_class() - serializer = project.get_annotation_serializer() - annotations = model.objects.filter(user=request.user, document=instance.id) - serializer = serializer(annotations, many=True) - return serializer.data + project = instance.project + model = project.get_annotation_class() + serializer = project.get_annotation_serializer() + annotations = model.objects.filter(document=instance.id) + if request: + annotations = annotations.filter(user=request.user) + serializer = serializer(annotations, many=True) + return serializer.data class Meta: model = Document @@ -91,7 +91,8 @@ class DocumentAnnotationSerializer(serializers.ModelSerializer): class Meta: model = DocumentAnnotation - fields = ('id', 'prob', 'label') + fields = ('id', 'prob', 'label', 'user') + read_only_fields = ('user', ) def create(self, validated_data): annotation = DocumentAnnotation.objects.create(**validated_data) @@ -104,7 +105,8 @@ class SequenceAnnotationSerializer(serializers.ModelSerializer): class Meta: model = SequenceAnnotation - fields = ('id', 'prob', 'label', 'start_offset', 'end_offset') + fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user') + read_only_fields = ('user',) def create(self, validated_data): annotation = SequenceAnnotation.objects.create(**validated_data) @@ -115,4 +117,5 @@ class Seq2seqAnnotationSerializer(serializers.ModelSerializer): class Meta: model = Seq2seqAnnotation - fields = ('id', 'text') + fields = ('id', 'text', 'user') + read_only_fields = ('user',) diff --git a/app/server/tests/test_api.py b/app/server/tests/test_api.py index 41cc3690..bd7da7b2 100644 --- a/app/server/tests/test_api.py +++ b/app/server/tests/test_api.py @@ -3,7 +3,6 @@ import os from rest_framework import status from rest_framework.reverse import reverse from rest_framework.test import APITestCase -from mixer.backend.django import mixer from model_mommy import mommy from ..models import User, SequenceAnnotation, Document, Label, Seq2seqAnnotation, DocumentAnnotation from ..models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ @@ -826,7 +825,7 @@ class TestUploader(APITestCase): expected_status=status.HTTP_201_CREATED) def test_can_upload_seq2seq_csv(self): - self.upload_test_helper(url=self.classification_url, + self.upload_test_helper(url=self.seq2seq_url, filename='example.valid.2.csv', format='csv', expected_status=status.HTTP_201_CREATED) @@ -882,11 +881,11 @@ class TestFileHandler(APITestCase): def handler_test_helper(self, filename, handler): with open(os.path.join(DATA_DIR, filename), mode='rb') as f: - handler.handle_uploaded_file(f, self.project, self.super_user) + handler.handle_uploaded_file(f, self.super_user) def test_conll_handler(self): self.handler_test_helper(filename='example.valid.conll', - handler=CoNLLHandler()) + handler=CoNLLHandler(self.project)) self.assertEqual(Document.objects.count(), 3) self.assertEqual(Label.objects.count(), 3) # LOC, PER, O self.assertEqual(SequenceAnnotation.objects.count(), 20) # num of annotation line @@ -894,40 +893,108 @@ class TestFileHandler(APITestCase): def test_conll_invalid_handler(self): with self.assertRaises(FileParseException): self.handler_test_helper(filename='example.invalid.conll', - handler=CoNLLHandler()) + handler=CoNLLHandler(self.project)) self.assertEqual(Document.objects.count(), 0) self.assertEqual(Label.objects.count(), 0) self.assertEqual(SequenceAnnotation.objects.count(), 0) def test_csv_classification_handler(self): self.handler_test_helper(filename='example.valid.2.csv', - handler=CSVClassificationHandler()) + handler=CSVClassificationHandler(self.project)) self.assertEqual(Document.objects.count(), 3) self.assertEqual(Label.objects.count(), 2) self.assertEqual(DocumentAnnotation.objects.count(), 3) def test_csv_seq2seq_handler(self): self.handler_test_helper(filename='example.valid.2.csv', - handler=CSVSeq2seqHandler()) + handler=CSVSeq2seqHandler(self.project)) self.assertEqual(Document.objects.count(), 3) self.assertEqual(Seq2seqAnnotation.objects.count(), 3) def test_json_classification_handler(self): self.handler_test_helper(filename='example.classification.jsonl', - handler=JsonClassificationHandler()) + handler=JsonClassificationHandler(self.project)) self.assertEqual(Document.objects.count(), 3) self.assertEqual(Label.objects.count(), 2) self.assertEqual(DocumentAnnotation.objects.count(), 4) def test_json_labeling_handler(self): self.handler_test_helper(filename='example.labeling.jsonl', - handler=JsonLabelingHandler()) + handler=JsonLabelingHandler(self.project)) self.assertEqual(Document.objects.count(), 3) self.assertEqual(Label.objects.count(), 3) self.assertEqual(SequenceAnnotation.objects.count(), 4) def test_json_seq2seq_handler(self): self.handler_test_helper(filename='example.seq2seq.jsonl', - handler=JsonSeq2seqHandler()) + handler=JsonSeq2seqHandler(self.project)) self.assertEqual(Document.objects.count(), 3) self.assertEqual(Seq2seqAnnotation.objects.count(), 4) + + +class TestDownloader(APITestCase): + + @classmethod + def setUpTestData(cls): + cls.super_user_name = 'super_user_name' + cls.super_user_pass = 'super_user_pass' + # Todo: change super_user to project_admin. + super_user = User.objects.create_superuser(username=cls.super_user_name, + password=cls.super_user_pass, + email='fizz@buzz.com') + cls.classification_project = mommy.make('server.TextClassificationProject', + users=[super_user], project_type=DOCUMENT_CLASSIFICATION) + cls.labeling_project = mommy.make('server.SequenceLabelingProject', + users=[super_user], project_type=SEQUENCE_LABELING) + cls.seq2seq_project = mommy.make('server.Seq2seqProject', users=[super_user], project_type=SEQ2SEQ) + cls.classification_url = reverse(viewname='doc_downloader', args=[cls.classification_project.id]) + cls.labeling_url = reverse(viewname='doc_downloader', args=[cls.labeling_project.id]) + cls.seq2seq_url = reverse(viewname='doc_downloader', args=[cls.seq2seq_project.id]) + + def setUp(self): + self.client.login(username=self.super_user_name, + password=self.super_user_pass) + + def download_test_helper(self, url, format, expected_status): + response = self.client.get(url, data={'q': format}) + self.assertEqual(response.status_code, expected_status) + + def test_can_upload_conll_format_file(self): + self.download_test_helper(url=self.labeling_url, + format='conll', + expected_status=status.HTTP_400_BAD_REQUEST) + + def test_can_download_classification_csv(self): + self.download_test_helper(url=self.classification_url, + format='csv', + expected_status=status.HTTP_200_OK) + + def test_cannot_download_labeling_csv(self): + self.download_test_helper(url=self.labeling_url, + format='csv', + expected_status=status.HTTP_400_BAD_REQUEST) + + def test_can_download_seq2seq_csv(self): + self.download_test_helper(url=self.seq2seq_url, + format='csv', + expected_status=status.HTTP_200_OK) + + def test_can_download_classification_jsonl(self): + self.download_test_helper(url=self.classification_url, + format='json', + expected_status=status.HTTP_200_OK) + + def test_can_download_labeling_jsonl(self): + self.download_test_helper(url=self.labeling_url, + format='json', + expected_status=status.HTTP_200_OK) + + def test_can_download_seq2seq_jsonl(self): + self.download_test_helper(url=self.seq2seq_url, + format='json', + expected_status=status.HTTP_200_OK) + + def test_can_download_plain_text(self): + self.download_test_helper(url=self.classification_url, + format='plain', + expected_status=status.HTTP_400_BAD_REQUEST)