Browse Source

Add downloader to support csv and json format

pull/110/head
Hironsan 5 years ago
parent
commit
14bcb0a703
6 changed files with 207 additions and 62 deletions
  1. 4
      app/server/admin.py
  2. 120
      app/server/api.py
  3. 6
      app/server/api_urls.py
  4. 27
      app/server/models.py
  5. 25
      app/server/serializers.py
  6. 87
      app/server/tests/test_api.py

4
app/server/admin.py

@ -2,6 +2,7 @@ from django.contrib import admin
from .models import Label, Document, Project
from .models import DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation
from .models import TextClassificationProject, SequenceLabelingProject, Seq2seqProject
admin.site.register(DocumentAnnotation)
admin.site.register(SequenceAnnotation)
@ -9,3 +10,6 @@ admin.site.register(Seq2seqAnnotation)
admin.site.register(Label)
admin.site.register(Document)
admin.site.register(Project)
admin.site.register(TextClassificationProject)
admin.site.register(SequenceLabelingProject)
admin.site.register(Seq2seqProject)

120
app/server/api.py

@ -5,10 +5,11 @@ from collections import Counter
from itertools import chain
from django.db import transaction
from django.http import HttpResponse
from django.shortcuts import get_object_or_404
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import generics, filters, status
from rest_framework.exceptions import ParseError
from rest_framework.exceptions import ParseError, ValidationError
from rest_framework.permissions import IsAuthenticated, IsAdminUser
from rest_framework.response import Response
from rest_framework.views import APIView
@ -186,31 +187,50 @@ class TextUploadAPI(APIView):
raise ParseError('Empty content')
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
handler = project.get_upload_handler(request.data['format'])
handler.handle_uploaded_file(request.FILES['file'], project, self.request.user)
handler.handle_uploaded_file(request.FILES['file'], self.request.user)
return Response(status=status.HTTP_201_CREATED)
class TextDownloadAPI(APIView):
    """API for text download.

    Handles ``GET`` requests for a single project and returns whatever the
    project's format handler renders. The export format is taken from the
    ``q`` query parameter.
    """
    permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUser)

    def get(self, request, *args, **kwargs):
        """Render the project's documents in the requested format.

        The project is looked up from the ``project_id`` URL keyword; the
        value of ``q`` (``None`` when the parameter is absent) is forwarded
        unchanged to ``project.get_upload_handler``, which selects the
        handler whose ``render()`` result becomes the response.
        """
        project_id = self.kwargs['project_id']
        # Named export_format rather than `format` to avoid shadowing the builtin.
        export_format = request.query_params.get('q')
        project = get_object_or_404(Project, pk=project_id)
        return project.get_upload_handler(export_format).render()
class FileHandler(object):
annotation_serializer = None
def __init__(self, project):
self.project = project
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
def handle_uploaded_file(self, file, user):
raise NotImplementedError()
def parse(self, file):
raise NotImplementedError()
def save_doc(self, data, project):
def render(self):
raise NotImplementedError()
def save_doc(self, data):
serializer = DocumentSerializer(data=data)
serializer.is_valid(raise_exception=True)
doc = serializer.save(project=project)
doc = serializer.save(project=self.project)
return doc
def save_label(self, data, project):
label = Label.objects.filter(project=project, **data).first()
def save_label(self, data):
label = Label.objects.filter(project=self.project, **data).first()
serializer = LabelSerializer(label, data=data)
serializer.is_valid(raise_exception=True)
label = serializer.save(project=project)
label = serializer.save(project=self.project)
return label
def save_annotation(self, data, doc, user):
@ -245,14 +265,14 @@ class CoNLLHandler(FileHandler):
annotation_serializer = SequenceAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
def handle_uploaded_file(self, file, user):
for words, tags in self.parse(file):
start_offset = 0
sent = ' '.join(words)
doc = self.save_doc({'text': sent}, project)
doc = self.save_doc({'text': sent})
for word, tag in zip(words, tags):
label = extract_label(tag)
label = self.save_label({'text': label}, project)
label = self.save_label({'text': label})
data = {'start_offset': start_offset,
'end_offset': start_offset + len(word),
'label': label.id}
@ -277,6 +297,9 @@ class CoNLLHandler(FileHandler):
if len(words) > 0:
yield words, tags
def render(self):
    """Reject the download request.

    This handler does not produce an export; calling it always raises.

    Raises:
        ValidationError: unconditionally.
    """
    message = "This project type doesn't support CoNLL format."
    raise ValidationError(message)
class PlainTextHandler(FileHandler):
"""Uploads plain text.
@ -289,15 +312,18 @@ class PlainTextHandler(FileHandler):
```
"""
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
def handle_uploaded_file(self, file, user):
for text in self.parse(file):
self.save_doc({'text': text}, project)
self.save_doc({'text': text})
def parse(self, file):
file = io.TextIOWrapper(file, encoding='utf-8')
for i, line in enumerate(file, start=1):
yield line.strip()
def render(self):
    """Reject the download request.

    Plain text is not offered as a download format; the error message
    points the caller at csv or json instead.

    Raises:
        ValidationError: unconditionally.
    """
    message = "You cannot download plain text. Please specify csv or json."
    raise ValidationError(message)
class CSVHandler(FileHandler):
"""Uploads csv file.
@ -327,27 +353,58 @@ class CSVHandler(FileHandler):
else:
raise FileParseException(line_num=i, line=row)
def render(self):
    """Placeholder for CSV export; concrete CSV handlers override this.

    Raises:
        NotImplementedError: always, because this base implementation
            writes nothing itself.
    """
    raise NotImplementedError()
class CSVClassificationHandler(CSVHandler):
annotation_serializer = DocumentAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
def handle_uploaded_file(self, file, user):
for text, label in self.parse(file):
doc = self.save_doc({'text': text}, project)
label = self.save_label({'text': label}, project)
doc = self.save_doc({'text': text})
label = self.save_label({'text': label})
self.save_annotation({'label': label.id}, doc, user)
def render(self):
    """Build a CSV attachment listing every annotation in the project.

    The file starts with the header ``id,text,label,user`` and then holds
    one row per (document, annotation) pair, so a document without
    annotations contributes no rows. The ``label`` column carries the
    annotation's serialized ``label`` field.

    Returns:
        HttpResponse: a ``text/csv`` response whose filename is the
        lower-cased project name with whitespace runs replaced by ``_``.
    """
    serializer = DocumentSerializer(self.project.documents.all(), many=True)
    basename = '_'.join(self.project.name.lower().split())

    response = HttpResponse(content_type='text/csv')
    response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(basename)

    # csv.writer only needs an object with write(); the response provides it.
    writer = csv.writer(response)
    writer.writerow(['id', 'text', 'label', 'user'])
    writer.writerows(
        [doc['id'], doc['text'], annotation['label'], annotation['user']]
        for doc in serializer.data
        for annotation in doc['annotations']
    )
    return response
class CSVSeq2seqHandler(CSVHandler):
annotation_serializer = Seq2seqAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
def handle_uploaded_file(self, file, user):
for text, label in self.parse(file):
doc = self.save_doc({'text': text}, project)
doc = self.save_doc({'text': text})
self.save_annotation({'text': label}, doc, user)
def render(self):
    """Build a CSV attachment listing every seq2seq annotation in the project.

    The file starts with the header ``id,text,label,user`` and then holds
    one row per (document, annotation) pair, so a document without
    annotations contributes no rows. The ``label`` column carries the
    annotation's serialized ``text`` field.

    Returns:
        HttpResponse: a ``text/csv`` response whose filename is the
        lower-cased project name with whitespace runs replaced by ``_``.
    """
    serializer = DocumentSerializer(self.project.documents.all(), many=True)
    basename = '_'.join(self.project.name.lower().split())

    response = HttpResponse(content_type='text/csv')
    response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(basename)

    # csv.writer only needs an object with write(); the response provides it.
    writer = csv.writer(response)
    writer.writerow(['id', 'text', 'label', 'user'])
    writer.writerows(
        [doc['id'], doc['text'], annotation['text'], annotation['user']]
        for doc in serializer.data
        for annotation in doc['annotations']
    )
    return response
class JsonHandler(FileHandler):
"""Uploads jsonl file.
@ -367,6 +424,17 @@ class JsonHandler(FileHandler):
except json.decoder.JSONDecodeError:
raise FileParseException(line_num=i, line=line)
def render(self):
    """Build a JSON Lines attachment with one serialized document per line.

    Each document is dumped with ``ensure_ascii=False`` so non-ASCII text
    is written as-is rather than as ``\\u`` escapes, and every record is
    terminated by a newline.

    Returns:
        HttpResponse: an ``application/json`` response whose filename is
        the lower-cased project name with whitespace runs replaced by
        ``_`` and a ``.jsonl`` extension.
    """
    serializer = DocumentSerializer(self.project.documents.all(), many=True)
    basename = '_'.join(self.project.name.lower().split())

    response = HttpResponse(content_type='application/json')
    response['Content-Disposition'] = 'attachment; filename="{}.jsonl"'.format(basename)

    for record in serializer.data:
        line = json.dumps(record, ensure_ascii=False)
        response.write(line + '\n')
    return response
class JsonClassificationHandler(JsonHandler):
"""Upload jsonl for text classification.
@ -380,11 +448,11 @@ class JsonClassificationHandler(JsonHandler):
annotation_serializer = DocumentAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
def handle_uploaded_file(self, file, user):
for data in self.parse(file):
doc = self.save_doc(data, project)
doc = self.save_doc(data)
for label in data['labels']:
label = self.save_label({'text': label}, project)
label = self.save_label({'text': label})
self.save_annotation({'label': label.id}, doc, user)
@ -400,11 +468,11 @@ class JsonLabelingHandler(JsonHandler):
annotation_serializer = SequenceAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
def handle_uploaded_file(self, file, user):
for data in self.parse(file):
doc = self.save_doc(data, project)
doc = self.save_doc(data)
for start_offset, end_offset, label in data['entities']:
label = self.save_label({'text': label}, project)
label = self.save_label({'text': label})
data = {'label': label.id,
'start_offset': start_offset,
'end_offset': end_offset}
@ -423,8 +491,8 @@ class JsonSeq2seqHandler(JsonHandler):
annotation_serializer = Seq2seqAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
def handle_uploaded_file(self, file, user):
for data in self.parse(file):
doc = self.save_doc(data, project)
doc = self.save_doc(data)
for label in data['labels']:
self.save_annotation({'text': label}, doc, user)

6
app/server/api_urls.py

@ -6,7 +6,7 @@ from .api import LabelList, LabelDetail
from .api import DocumentList, DocumentDetail
from .api import EntityList, EntityDetail
from .api import AnnotationList, AnnotationDetail
from .api import TextUploadAPI
from .api import TextUploadAPI, TextDownloadAPI
from .api import StatisticsAPI
@ -32,7 +32,9 @@ urlpatterns = [
path('projects/<int:project_id>/docs/<int:doc_id>/annotations/<int:annotation_id>',
AnnotationDetail.as_view(), name='annotation_detail'),
path('projects/<int:project_id>/docs/upload',
TextUploadAPI.as_view(), name='doc_uploader')
TextUploadAPI.as_view(), name='doc_uploader'),
path('projects/<int:project_id>/docs/download',
TextDownloadAPI.as_view(), name='doc_downloader')
]
urlpatterns = format_suffix_patterns(urlpatterns, allowed=['json', 'xml'])

27
app/server/models.py

@ -3,6 +3,7 @@ from django.db import models
from django.urls import reverse
from django.contrib.auth.models import User
from django.contrib.staticfiles.storage import staticfiles_storage
from rest_framework.exceptions import ValidationError
from polymorphic.models import PolymorphicModel
from .utils import get_key_choices
@ -43,7 +44,7 @@ class Project(PolymorphicModel):
raise NotImplementedError()
def get_upload_handler(self, format):
raise NotImplementedError
raise NotImplementedError()
def __str__(self):
return self.name
@ -68,12 +69,12 @@ class TextClassificationProject(Project):
def get_upload_handler(self, format):
from .api import PlainTextHandler, CSVClassificationHandler, JsonClassificationHandler
if format == 'plain':
return PlainTextHandler()
return PlainTextHandler(self)
elif format == 'csv':
return CSVClassificationHandler()
return CSVClassificationHandler(self)
elif format == 'json':
return JsonClassificationHandler()
raise ValueError('format {} is invalid.'.format(format))
return JsonClassificationHandler(self)
raise ValidationError('format {} is invalid.'.format(format))
class SequenceLabelingProject(Project):
@ -95,12 +96,12 @@ class SequenceLabelingProject(Project):
def get_upload_handler(self, format):
from .api import PlainTextHandler, CoNLLHandler, JsonLabelingHandler
if format == 'plain':
return PlainTextHandler()
return PlainTextHandler(self)
elif format == 'conll':
return CoNLLHandler()
return CoNLLHandler(self)
elif format == 'json':
return JsonLabelingHandler()
raise ValueError('format {} is invalid.'.format(format))
return JsonLabelingHandler(self)
raise ValidationError('format {} is invalid.'.format(format))
class Seq2seqProject(Project):
@ -122,12 +123,12 @@ class Seq2seqProject(Project):
def get_upload_handler(self, format):
from .api import PlainTextHandler, CSVSeq2seqHandler, JsonSeq2seqHandler
if format == 'plain':
return PlainTextHandler()
return PlainTextHandler(self)
elif format == 'csv':
return CSVSeq2seqHandler()
return CSVSeq2seqHandler(self)
elif format == 'json':
return JsonSeq2seqHandler()
raise ValueError('format {} is invalid.'.format(format))
return JsonSeq2seqHandler(self)
raise ValidationError('format {} is invalid.'.format(format))
class Label(models.Model):

25
app/server/serializers.py

@ -19,14 +19,14 @@ class DocumentSerializer(serializers.ModelSerializer):
def get_annotations(self, instance):
request = self.context.get('request')
view = self.context.get('view', None)
if request and view:
project = get_object_or_404(Project, pk=view.kwargs['project_id'])
model = project.get_annotation_class()
serializer = project.get_annotation_serializer()
annotations = model.objects.filter(user=request.user, document=instance.id)
serializer = serializer(annotations, many=True)
return serializer.data
project = instance.project
model = project.get_annotation_class()
serializer = project.get_annotation_serializer()
annotations = model.objects.filter(document=instance.id)
if request:
annotations = annotations.filter(user=request.user)
serializer = serializer(annotations, many=True)
return serializer.data
class Meta:
model = Document
@ -91,7 +91,8 @@ class DocumentAnnotationSerializer(serializers.ModelSerializer):
class Meta:
model = DocumentAnnotation
fields = ('id', 'prob', 'label')
fields = ('id', 'prob', 'label', 'user')
read_only_fields = ('user', )
def create(self, validated_data):
annotation = DocumentAnnotation.objects.create(**validated_data)
@ -104,7 +105,8 @@ class SequenceAnnotationSerializer(serializers.ModelSerializer):
class Meta:
model = SequenceAnnotation
fields = ('id', 'prob', 'label', 'start_offset', 'end_offset')
fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user')
read_only_fields = ('user',)
def create(self, validated_data):
annotation = SequenceAnnotation.objects.create(**validated_data)
@ -115,4 +117,5 @@ class Seq2seqAnnotationSerializer(serializers.ModelSerializer):
class Meta:
model = Seq2seqAnnotation
fields = ('id', 'text')
fields = ('id', 'text', 'user')
read_only_fields = ('user',)

87
app/server/tests/test_api.py

@ -3,7 +3,6 @@ import os
from rest_framework import status
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from mixer.backend.django import mixer
from model_mommy import mommy
from ..models import User, SequenceAnnotation, Document, Label, Seq2seqAnnotation, DocumentAnnotation
from ..models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
@ -826,7 +825,7 @@ class TestUploader(APITestCase):
expected_status=status.HTTP_201_CREATED)
def test_can_upload_seq2seq_csv(self):
self.upload_test_helper(url=self.classification_url,
self.upload_test_helper(url=self.seq2seq_url,
filename='example.valid.2.csv',
format='csv',
expected_status=status.HTTP_201_CREATED)
@ -882,11 +881,11 @@ class TestFileHandler(APITestCase):
def handler_test_helper(self, filename, handler):
with open(os.path.join(DATA_DIR, filename), mode='rb') as f:
handler.handle_uploaded_file(f, self.project, self.super_user)
handler.handle_uploaded_file(f, self.super_user)
def test_conll_handler(self):
self.handler_test_helper(filename='example.valid.conll',
handler=CoNLLHandler())
handler=CoNLLHandler(self.project))
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Label.objects.count(), 3) # LOC, PER, O
self.assertEqual(SequenceAnnotation.objects.count(), 20) # num of annotation line
@ -894,40 +893,108 @@ class TestFileHandler(APITestCase):
def test_conll_invalid_handler(self):
with self.assertRaises(FileParseException):
self.handler_test_helper(filename='example.invalid.conll',
handler=CoNLLHandler())
handler=CoNLLHandler(self.project))
self.assertEqual(Document.objects.count(), 0)
self.assertEqual(Label.objects.count(), 0)
self.assertEqual(SequenceAnnotation.objects.count(), 0)
def test_csv_classification_handler(self):
self.handler_test_helper(filename='example.valid.2.csv',
handler=CSVClassificationHandler())
handler=CSVClassificationHandler(self.project))
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Label.objects.count(), 2)
self.assertEqual(DocumentAnnotation.objects.count(), 3)
def test_csv_seq2seq_handler(self):
self.handler_test_helper(filename='example.valid.2.csv',
handler=CSVSeq2seqHandler())
handler=CSVSeq2seqHandler(self.project))
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Seq2seqAnnotation.objects.count(), 3)
def test_json_classification_handler(self):
self.handler_test_helper(filename='example.classification.jsonl',
handler=JsonClassificationHandler())
handler=JsonClassificationHandler(self.project))
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Label.objects.count(), 2)
self.assertEqual(DocumentAnnotation.objects.count(), 4)
def test_json_labeling_handler(self):
self.handler_test_helper(filename='example.labeling.jsonl',
handler=JsonLabelingHandler())
handler=JsonLabelingHandler(self.project))
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Label.objects.count(), 3)
self.assertEqual(SequenceAnnotation.objects.count(), 4)
def test_json_seq2seq_handler(self):
self.handler_test_helper(filename='example.seq2seq.jsonl',
handler=JsonSeq2seqHandler())
handler=JsonSeq2seqHandler(self.project))
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Seq2seqAnnotation.objects.count(), 4)
class TestDownloader(APITestCase):
    """Status-code checks for the ``doc_downloader`` endpoint.

    One project of each type is created and the endpoint is requested with
    each export format through the ``q`` query parameter. The tests assert
    only the HTTP status, not the body of the download.
    """

    @classmethod
    def setUpTestData(cls):
        """Create one superuser and one project per project type."""
        cls.super_user_name = 'super_user_name'
        cls.super_user_pass = 'super_user_pass'
        # Todo: change super_user to project_admin.
        super_user = User.objects.create_superuser(username=cls.super_user_name,
                                                   password=cls.super_user_pass,
                                                   email='fizz@buzz.com')
        cls.classification_project = mommy.make('server.TextClassificationProject',
                                                users=[super_user], project_type=DOCUMENT_CLASSIFICATION)
        cls.labeling_project = mommy.make('server.SequenceLabelingProject',
                                          users=[super_user], project_type=SEQUENCE_LABELING)
        cls.seq2seq_project = mommy.make('server.Seq2seqProject', users=[super_user], project_type=SEQ2SEQ)
        cls.classification_url = reverse(viewname='doc_downloader', args=[cls.classification_project.id])
        cls.labeling_url = reverse(viewname='doc_downloader', args=[cls.labeling_project.id])
        cls.seq2seq_url = reverse(viewname='doc_downloader', args=[cls.seq2seq_project.id])

    def setUp(self):
        """Log the test client in as the superuser before every test."""
        self.client.login(username=self.super_user_name,
                          password=self.super_user_pass)

    def download_test_helper(self, url, format, expected_status):
        """GET ``url`` with ``q=format`` and assert the response status code."""
        response = self.client.get(url, data={'q': format})
        self.assertEqual(response.status_code, expected_status)

    # Renamed from test_can_upload_conll_format_file: this is a download
    # request and the expected outcome is a 400, not a success.
    def test_cannot_download_conll_format_file(self):
        self.download_test_helper(url=self.labeling_url,
                                  format='conll',
                                  expected_status=status.HTTP_400_BAD_REQUEST)

    def test_can_download_classification_csv(self):
        self.download_test_helper(url=self.classification_url,
                                  format='csv',
                                  expected_status=status.HTTP_200_OK)

    def test_cannot_download_labeling_csv(self):
        self.download_test_helper(url=self.labeling_url,
                                  format='csv',
                                  expected_status=status.HTTP_400_BAD_REQUEST)

    def test_can_download_seq2seq_csv(self):
        self.download_test_helper(url=self.seq2seq_url,
                                  format='csv',
                                  expected_status=status.HTTP_200_OK)

    def test_can_download_classification_jsonl(self):
        self.download_test_helper(url=self.classification_url,
                                  format='json',
                                  expected_status=status.HTTP_200_OK)

    def test_can_download_labeling_jsonl(self):
        self.download_test_helper(url=self.labeling_url,
                                  format='json',
                                  expected_status=status.HTTP_200_OK)

    def test_can_download_seq2seq_jsonl(self):
        self.download_test_helper(url=self.seq2seq_url,
                                  format='json',
                                  expected_status=status.HTTP_200_OK)

    # Renamed from test_can_download_plain_text: the assertion is that the
    # request is rejected with a 400.
    def test_cannot_download_plain_text(self):
        self.download_test_helper(url=self.classification_url,
                                  format='plain',
                                  expected_status=status.HTTP_400_BAD_REQUEST)
Loading…
Cancel
Save