Browse Source

Add file uploader to support plain, csv, json and CoNLL format

Corresponding to issue #11
pull/110/head
Hironsan 5 years ago
committed by Hironsan
parent
commit
5af45811ed
10 changed files with 316 additions and 100 deletions
  1. 207
      app/server/api.py
  2. 12
      app/server/api_urls.py
  3. 18
      app/server/models.py
  4. 6
      app/server/serializers.py
  5. 3
      app/server/tests/data/example.classification.jsonl
  6. 3
      app/server/tests/data/example.labeling.jsonl
  7. 3
      app/server/tests/data/example.seq2seq.jsonl
  8. 6
      app/server/tests/data/example.valid.2.csv
  9. 147
      app/server/tests/test_api.py
  10. 11
      app/server/utils.py

207
app/server/api.py

@ -17,9 +17,11 @@ from rest_framework.parsers import MultiPartParser
from .exceptions import FileParseException
from .models import Project, Label, Document
from .models import SequenceAnnotation
from .models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsMyEntity
from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer
from .serializers import SequenceAnnotationSerializer
from .serializers import SequenceAnnotationSerializer, DocumentAnnotationSerializer, Seq2seqAnnotationSerializer
from .utils import extract_label
class ProjectList(generics.ListCreateAPIView):
@ -146,18 +148,62 @@ class TextUploadAPI(APIView):
def post(self, request, *args, **kwargs):
if 'file' not in request.FILES:
raise ParseError('Empty content')
self.handle_uploaded_file(request.FILES['file'])
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
handler = self.decide_handler(request.data['format'], project.project_type)
handler.handle_uploaded_file(request.FILES['file'], project, self.request.user)
return Response(status=status.HTTP_201_CREATED)
def decide_handler(self, format, project_type):
if format == 'plain':
return PlainTextHandler()
elif format == 'conll' and project_type:
return CoNLLHandler()
elif format == 'csv':
if project_type == DOCUMENT_CLASSIFICATION:
return CSVClassificationHandler()
elif project_type == SEQ2SEQ:
return CSVSeq2seqHandler()
elif format == 'json':
if project_type == DOCUMENT_CLASSIFICATION:
return JsonClassificationHandler()
elif project_type == SEQUENCE_LABELING:
return JsonLabelingHandler()
elif project_type == SEQ2SEQ:
return JsonSeq2seqHandler()
raise ValueError('format {} is invalid.'.format(format))
class FileHandler(object):
annotation_serializer = None
@transaction.atomic
def handle_uploaded_file(self, file):
def handle_uploaded_file(self, file, project, user):
raise NotImplementedError()
def parse(self, file):
raise NotImplementedError()
def save_doc(self, data, project):
serializer = DocumentSerializer(data=data)
serializer.is_valid(raise_exception=True)
doc = serializer.save(project=project)
return doc
def save_label(self, data, project):
label = Label.objects.filter(project=project, **data).first()
serializer = LabelSerializer(label, data=data)
serializer.is_valid(raise_exception=True)
label = serializer.save(project=project)
return label
def save_annotation(self, data, doc, user):
serializer = self.annotation_serializer(data=data)
serializer.is_valid(raise_exception=True)
annotation = serializer.save(document=doc, user=user)
return annotation
class CoNLLFileUploadAPI(TextUploadAPI):
class CoNLLHandler(FileHandler):
"""Uploads CoNLL format file.
The file format is tab-separated values.
@ -179,19 +225,22 @@ class CoNLLFileUploadAPI(TextUploadAPI):
...
```
"""
annotation_serializer = SequenceAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
for words in self.parse(file):
sent = self.words_to_sent(words)
data = {'text': sent}
serializer = DocumentSerializer(data=data)
serializer.is_valid(raise_exception=True)
serializer.save(project=project)
def words_to_sent(self, words):
return ' '.join(words)
def handle_uploaded_file(self, file, project, user):
for words, tags in self.parse(file):
start_offset = 0
sent = ' '.join(words)
doc = self.save_doc({'text': sent}, project)
for word, tag in zip(words, tags):
label = extract_label(tag)
label = self.save_label({'text': label}, project)
data = {'start_offset': start_offset,
'end_offset': start_offset + len(word),
'label': label.id}
start_offset += len(word) + 1
self.save_annotation(data, doc, user)
def parse(self, file):
words, tags = [], []
@ -206,13 +255,13 @@ class CoNLLFileUploadAPI(TextUploadAPI):
words.append(word)
tags.append(tag)
else:
yield words
yield words, tags
words, tags = [], []
if len(words) > 0:
yield words
yield words, tags
class PlainTextUploadAPI(TextUploadAPI):
class PlainTextHandler(FileHandler):
"""Uploads plain text.
The file format is as follows:
@ -223,13 +272,9 @@ class PlainTextUploadAPI(TextUploadAPI):
```
"""
@transaction.atomic
def handle_uploaded_file(self, file):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
def handle_uploaded_file(self, file, project, user):
for text in self.parse(file):
data = {'text': text}
serializer = DocumentSerializer(data=data)
serializer.is_valid(raise_exception=True)
serializer.save(project=project)
self.save_doc({'text': text}, project)
def parse(self, file):
file = io.TextIOWrapper(file, encoding='utf-8')
@ -237,30 +282,20 @@ class PlainTextUploadAPI(TextUploadAPI):
yield line.strip()
class CSVUploadAPI(TextUploadAPI):
class CSVHandler(FileHandler):
"""Uploads csv file.
The file format is comma separated values.
Column names are required at the top of a file.
For example:
```
text, label(optional)
"EU rejects German call to boycott British lamb.",
"President Obama is speaking at the White House.",
"He lives in Newark, Ohio.",
text, label
"EU rejects German call to boycott British lamb.",Politics
"President Obama is speaking at the White House.",Politics
"He lives in Newark, Ohio.",Other
...
```
"""
@transaction.atomic
def handle_uploaded_file(self, file):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
for text, label in self.parse(file):
data = {'text': text}
serializer = DocumentSerializer(data=data)
serializer.is_valid(raise_exception=True)
serializer.save(project=project)
def parse(self, file):
file = io.TextIOWrapper(file, encoding='utf-8')
reader = csv.reader(file)
@ -276,7 +311,28 @@ class CSVUploadAPI(TextUploadAPI):
raise FileParseException(line_num=i, line=row)
class JSONLUploadAPI(TextUploadAPI):
class CSVClassificationHandler(CSVHandler):
annotation_serializer = DocumentAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
for text, label in self.parse(file):
doc = self.save_doc({'text': text}, project)
label = self.save_label({'text': label}, project)
self.save_annotation({'label': label.id}, doc, user)
class CSVSeq2seqHandler(CSVHandler):
annotation_serializer = Seq2seqAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
for text, label in self.parse(file):
doc = self.save_doc({'text': text}, project)
self.save_annotation({'text': label}, doc, user)
class JsonHandler(FileHandler):
"""Uploads jsonl file.
The file format is as follows:
@ -286,15 +342,6 @@ class JSONLUploadAPI(TextUploadAPI):
...
```
"""
@transaction.atomic
def handle_uploaded_file(self, file):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
for data in self.parse(file):
serializer = DocumentSerializer(data=data)
serializer.is_valid(raise_exception=True)
serializer.save(project=project)
def parse(self, file):
for i, line in enumerate(file, start=1):
try:
@ -302,3 +349,65 @@ class JSONLUploadAPI(TextUploadAPI):
yield j
except json.decoder.JSONDecodeError:
raise FileParseException(line_num=i, line=line)
class JsonClassificationHandler(JsonHandler):
"""Upload jsonl for text classification.
The format is as follows:
```
{"text": "Python is awesome!", "labels": ["positive"]}
...
```
"""
annotation_serializer = DocumentAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
for data in self.parse(file):
doc = self.save_doc(data, project)
for label in data['labels']:
label = self.save_label({'text': label}, project)
self.save_annotation({'label': label.id}, doc, user)
class JsonLabelingHandler(JsonHandler):
"""Upload jsonl for sequence labeling.
The format is as follows:
```
{"text": "Python is awesome!", "entities": [[0, 6, "Product"],]}
...
```
"""
annotation_serializer = SequenceAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
for data in self.parse(file):
doc = self.save_doc(data, project)
for start_offset, end_offset, label in data['entities']:
label = self.save_label({'text': label}, project)
data = {'label': label.id,
'start_offset': start_offset,
'end_offset': end_offset}
self.save_annotation(data, doc, user)
class JsonSeq2seqHandler(JsonHandler):
"""Upload jsonl for seq2seq.
The format is as follows:
```
{"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
...
```
"""
annotation_serializer = Seq2seqAnnotationSerializer
@transaction.atomic
def handle_uploaded_file(self, file, project, user):
for data in self.parse(file):
doc = self.save_doc(data, project)
for label in data['labels']:
self.save_annotation({'text': label}, doc, user)

12
app/server/api_urls.py

@ -5,7 +5,7 @@ from .api import ProjectList, ProjectDetail
from .api import LabelList, LabelDetail
from .api import DocumentList, DocumentDetail
from .api import EntityList, EntityDetail
from .api import CoNLLFileUploadAPI, CSVUploadAPI, JSONLUploadAPI, PlainTextUploadAPI
from .api import TextUploadAPI
from .api import StatisticsAPI
@ -26,14 +26,8 @@ urlpatterns = [
EntityList.as_view(), name='entity_list'),
path('projects/<int:project_id>/docs/<int:doc_id>/entities/<int:entity_id>',
EntityDetail.as_view(), name='entity_detail'),
path('projects/<int:project_id>/plain_uploader',
PlainTextUploadAPI.as_view(), name='plain_uploader'),
path('projects/<int:project_id>/conll_uploader',
CoNLLFileUploadAPI.as_view(), name='conll_uploader'),
path('projects/<int:project_id>/csv_uploader',
CSVUploadAPI.as_view(), name='csv_uploader'),
path('projects/<int:project_id>/json_uploader',
JSONLUploadAPI.as_view(), name='json_uploader'),
path('projects/<int:project_id>/docs/upload',
TextUploadAPI.as_view(), name='doc_uploader')
]
urlpatterns = format_suffix_patterns(urlpatterns, allowed=['json', 'xml'])

18
app/server/models.py

@ -6,17 +6,17 @@ from django.contrib.staticfiles.storage import staticfiles_storage
from .utils import get_key_choices
class Project(models.Model):
DOCUMENT_CLASSIFICATION = 'DocumentClassification'
SEQUENCE_LABELING = 'SequenceLabeling'
SEQ2SEQ = 'Seq2seq'
DOCUMENT_CLASSIFICATION = 'DocumentClassification'
SEQUENCE_LABELING = 'SequenceLabeling'
SEQ2SEQ = 'Seq2seq'
PROJECT_CHOICES = (
(DOCUMENT_CLASSIFICATION, 'document classification'),
(SEQUENCE_LABELING, 'sequence labeling'),
(SEQ2SEQ, 'sequence to sequence'),
)
PROJECT_CHOICES = (
(DOCUMENT_CLASSIFICATION, 'document classification'),
(SEQUENCE_LABELING, 'sequence labeling'),
(SEQ2SEQ, 'sequence to sequence'),
)
class Project(models.Model):
name = models.CharField(max_length=100)
description = models.TextField()
guideline = models.TextField()

6
app/server/serializers.py

@ -38,7 +38,8 @@ class ProjectFilteredPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField):
class DocumentAnnotationSerializer(serializers.ModelSerializer):
label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
# label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
class Meta:
model = DocumentAnnotation
@ -50,7 +51,8 @@ class DocumentAnnotationSerializer(serializers.ModelSerializer):
class SequenceAnnotationSerializer(serializers.ModelSerializer):
label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
#label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
class Meta:
model = SequenceAnnotation

3
app/server/tests/data/example.classification.jsonl

@ -0,0 +1,3 @@
{"text": "example", "labels": ["positive"]}
{"text": "example", "labels": ["positive", "negative"]}
{"text": "example", "labels": ["negative"]}

3
app/server/tests/data/example.labeling.jsonl

@ -0,0 +1,3 @@
{"text": "example", "entities": [[0, 1, "LOC"], [0, 2, "ORG"]]}
{"text": "example", "entities": [[0, 1, "LOC"]]}
{"text": "example", "entities": [[0, 1, "PER"]]}

3
app/server/tests/data/example.seq2seq.jsonl

@ -0,0 +1,3 @@
{"text": "example", "labels": ["example1", "example2"]}
{"text": "example", "labels": ["example"]}
{"text": "example", "labels": ["example"]}

6
app/server/tests/data/example.valid.2.csv

@ -1,4 +1,4 @@
text, label
AAA, Positive
BBB, Positive
CCC, Negative
AAA,Positive
BBB,Positive
CCC,Negative

147
app/server/tests/test_api.py

@ -4,7 +4,11 @@ from rest_framework import status
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from mixer.backend.django import mixer
from ..models import User, SequenceAnnotation, Document
from ..models import User, SequenceAnnotation, Document, Label, Seq2seqAnnotation, DocumentAnnotation
from ..models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
from ..api import CoNLLHandler, CSVClassificationHandler, CSVSeq2seqHandler
from ..api import JsonClassificationHandler, JsonLabelingHandler, JsonSeq2seqHandler
from ..exceptions import FileParseException
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
@ -622,62 +626,149 @@ class TestUploader(APITestCase):
@classmethod
def setUpTestData(cls):
cls.project_member_name = 'project_member_name'
cls.project_member_pass = 'project_member_pass'
project_member = User.objects.create_user(username=cls.project_member_name,
password=cls.project_member_pass)
cls.super_user_name = 'super_user_name'
cls.super_user_pass = 'super_user_pass'
# Todo: change super_user to project_admin.
super_user = User.objects.create_superuser(username=cls.super_user_name,
password=cls.super_user_pass,
email='fizz@buzz.com')
cls.main_project = mixer.blend('server.Project', users=[project_member, super_user])
cls.conll_url = reverse(viewname='conll_uploader', args=[cls.main_project.id])
cls.csv_url = reverse(viewname='csv_uploader', args=[cls.main_project.id])
cls.json_url = reverse(viewname='json_uploader', args=[cls.main_project.id])
cls.plain_url = reverse(viewname='plain_uploader', args=[cls.main_project.id])
cls.classification_project = mixer.blend('server.Project', users=[super_user], project_type=DOCUMENT_CLASSIFICATION)
cls.labeling_project = mixer.blend('server.Project', users=[super_user], project_type=SEQUENCE_LABELING)
cls.seq2seq_project = mixer.blend('server.Project', users=[super_user], project_type=SEQ2SEQ)
cls.classification_url = reverse(viewname='doc_uploader', args=[cls.classification_project.id])
cls.labeling_url = reverse(viewname='doc_uploader', args=[cls.labeling_project.id])
cls.seq2seq_url = reverse(viewname='doc_uploader', args=[cls.seq2seq_project.id])
def setUp(self):
self.client.login(username=self.super_user_name,
password=self.super_user_pass)
def upload_test_helper(self, filename, url, expected_status):
def upload_test_helper(self, url, filename, format, expected_status):
with open(os.path.join(DATA_DIR, filename)) as f:
response = self.client.post(url, data={'file': f})
response = self.client.post(url, data={'file': f, 'format': format})
self.assertEqual(response.status_code, expected_status)
def test_can_upload_conll_format_file(self):
self.upload_test_helper(filename='example.valid.conll',
url=self.conll_url,
self.upload_test_helper(url=self.labeling_url,
filename='example.valid.conll',
format='conll',
expected_status=status.HTTP_201_CREATED)
def test_cannot_upload_wrong_conll_format_file(self):
self.upload_test_helper(filename='example.invalid.conll',
url=self.conll_url,
self.upload_test_helper(url=self.labeling_url,
filename='example.invalid.conll',
format='conll',
expected_status=status.HTTP_400_BAD_REQUEST)
def test_can_upload_csv_with_label(self):
self.upload_test_helper(filename='example.valid.2.csv',
url=self.csv_url,
def test_can_upload_classification_csv(self):
self.upload_test_helper(url=self.classification_url,
filename='example.valid.2.csv',
format='csv',
expected_status=status.HTTP_201_CREATED)
def test_can_upload_seq2seq_csv(self):
self.upload_test_helper(url=self.classification_url,
filename='example.valid.2.csv',
format='csv',
expected_status=status.HTTP_201_CREATED)
def test_cannot_upload_csv_file_does_not_match_column_and_row(self):
self.upload_test_helper(filename='example.invalid.1.csv',
url=self.csv_url,
self.upload_test_helper(url=self.classification_url,
filename='example.invalid.1.csv',
format='csv',
expected_status=status.HTTP_400_BAD_REQUEST)
def test_cannot_upload_csv_file_has_too_many_columns(self):
self.upload_test_helper(filename='example.invalid.2.csv',
url=self.csv_url,
self.upload_test_helper(url=self.classification_url,
filename='example.invalid.2.csv',
format='csv',
expected_status=status.HTTP_400_BAD_REQUEST)
def test_can_upload_jsonl(self):
self.upload_test_helper(filename='example.jsonl',
url=self.json_url,
def test_can_upload_classification_jsonl(self):
self.upload_test_helper(url=self.classification_url,
filename='example.classification.jsonl',
format='json',
expected_status=status.HTTP_201_CREATED)
def test_can_upload_labeling_jsonl(self):
self.upload_test_helper(url=self.labeling_url,
filename='example.labeling.jsonl',
format='json',
expected_status=status.HTTP_201_CREATED)
def test_can_upload_seq2seq_jsonl(self):
self.upload_test_helper(url=self.seq2seq_url,
filename='example.seq2seq.jsonl',
format='json',
expected_status=status.HTTP_201_CREATED)
def test_can_upload_plain_text(self):
self.upload_test_helper(filename='example.txt',
url=self.plain_url,
self.upload_test_helper(url=self.classification_url,
filename='example.txt',
format='plain',
expected_status=status.HTTP_201_CREATED)
class TestFileHandler(APITestCase):
@classmethod
def setUpTestData(cls):
cls.super_user_name = 'super_user_name'
cls.super_user_pass = 'super_user_pass'
# Todo: change super_user to project_admin.
cls.super_user = User.objects.create_superuser(username=cls.super_user_name,
password=cls.super_user_pass,
email='fizz@buzz.com')
cls.project = mixer.blend('server.Project', users=[cls.super_user])
def handler_test_helper(self, filename, handler):
with open(os.path.join(DATA_DIR, filename), mode='rb') as f:
handler.handle_uploaded_file(f, self.project, self.super_user)
def test_conll_handler(self):
self.handler_test_helper(filename='example.valid.conll',
handler=CoNLLHandler())
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Label.objects.count(), 3) # LOC, PER, O
self.assertEqual(SequenceAnnotation.objects.count(), 20) # num of annotation line
def test_conll_invalid_handler(self):
with self.assertRaises(FileParseException):
self.handler_test_helper(filename='example.invalid.conll',
handler=CoNLLHandler())
self.assertEqual(Document.objects.count(), 0)
self.assertEqual(Label.objects.count(), 0)
self.assertEqual(SequenceAnnotation.objects.count(), 0)
def test_csv_classification_handler(self):
self.handler_test_helper(filename='example.valid.2.csv',
handler=CSVClassificationHandler())
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Label.objects.count(), 2)
self.assertEqual(DocumentAnnotation.objects.count(), 3)
def test_csv_seq2seq_handler(self):
self.handler_test_helper(filename='example.valid.2.csv',
handler=CSVSeq2seqHandler())
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Seq2seqAnnotation.objects.count(), 3)
def test_json_classification_handler(self):
self.handler_test_helper(filename='example.classification.jsonl',
handler=JsonClassificationHandler())
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Label.objects.count(), 2)
self.assertEqual(DocumentAnnotation.objects.count(), 4)
def test_json_labeling_handler(self):
self.handler_test_helper(filename='example.labeling.jsonl',
handler=JsonLabelingHandler())
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Label.objects.count(), 3)
self.assertEqual(SequenceAnnotation.objects.count(), 4)
def test_json_seq2seq_handler(self):
self.handler_test_helper(filename='example.seq2seq.jsonl',
handler=JsonSeq2seqHandler())
self.assertEqual(Document.objects.count(), 3)
self.assertEqual(Seq2seqAnnotation.objects.count(), 4)

11
app/server/utils.py

@ -1,5 +1,7 @@
import re
import string
def get_key_choices():
selectKey, shortKey = [c for c in string.ascii_lowercase], [c for c in string.ascii_lowercase]
checkKey = 'ctrl shift'
@ -8,3 +10,12 @@ def get_key_choices():
shortKey += ['']
KEY_CHOICES = ((u, c) for u, c in zip(shortKey, shortKey))
return KEY_CHOICES
def extract_label(tag):
ptn = re.compile(r'(B|I|E|S)-(.+)')
m = ptn.match(tag)
if m:
return m.groups()[1]
else:
return tag
Loading…
Cancel
Save