diff --git a/app/server/api.py b/app/server/api.py index c3dbfc6e..639c8eb6 100644 --- a/app/server/api.py +++ b/app/server/api.py @@ -17,9 +17,11 @@ from rest_framework.parsers import MultiPartParser from .exceptions import FileParseException from .models import Project, Label, Document from .models import SequenceAnnotation +from .models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsMyEntity from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer -from .serializers import SequenceAnnotationSerializer +from .serializers import SequenceAnnotationSerializer, DocumentAnnotationSerializer, Seq2seqAnnotationSerializer +from .utils import extract_label class ProjectList(generics.ListCreateAPIView): @@ -146,18 +148,62 @@ class TextUploadAPI(APIView): def post(self, request, *args, **kwargs): if 'file' not in request.FILES: raise ParseError('Empty content') - self.handle_uploaded_file(request.FILES['file']) + project = get_object_or_404(Project, pk=self.kwargs['project_id']) + handler = self.decide_handler(request.data['format'], project.project_type) + handler.handle_uploaded_file(request.FILES['file'], project, self.request.user) return Response(status=status.HTTP_201_CREATED) + def decide_handler(self, format, project_type): + if format == 'plain': + return PlainTextHandler() + elif format == 'conll' and project_type: + return CoNLLHandler() + elif format == 'csv': + if project_type == DOCUMENT_CLASSIFICATION: + return CSVClassificationHandler() + elif project_type == SEQ2SEQ: + return CSVSeq2seqHandler() + elif format == 'json': + if project_type == DOCUMENT_CLASSIFICATION: + return JsonClassificationHandler() + elif project_type == SEQUENCE_LABELING: + return JsonLabelingHandler() + elif project_type == SEQ2SEQ: + return JsonSeq2seqHandler() + raise ValueError('format {} is invalid.'.format(format)) + + +class FileHandler(object): + annotation_serializer = None + @transaction.atomic - def handle_uploaded_file(self, file): + def handle_uploaded_file(self, file, project, user): raise NotImplementedError() def parse(self, file): raise NotImplementedError() + def save_doc(self, data, project): + serializer = DocumentSerializer(data=data) + serializer.is_valid(raise_exception=True) + doc = serializer.save(project=project) + return doc + + def save_label(self, data, project): + label = Label.objects.filter(project=project, **data).first() + serializer = LabelSerializer(label, data=data) + serializer.is_valid(raise_exception=True) + label = serializer.save(project=project) + return label + + def save_annotation(self, data, doc, user): + serializer = self.annotation_serializer(data=data) + serializer.is_valid(raise_exception=True) + annotation = serializer.save(document=doc, user=user) + return annotation -class CoNLLFileUploadAPI(TextUploadAPI): + +class CoNLLHandler(FileHandler): """Uploads CoNLL format file. The file format is tab-separated values. @@ -179,19 +225,22 @@ class CoNLLFileUploadAPI(TextUploadAPI): ... ``` """ + annotation_serializer = SequenceAnnotationSerializer @transaction.atomic - def handle_uploaded_file(self, file): - project = get_object_or_404(Project, pk=self.kwargs['project_id']) - for words in self.parse(file): - sent = self.words_to_sent(words) - data = {'text': sent} - serializer = DocumentSerializer(data=data) - serializer.is_valid(raise_exception=True) - serializer.save(project=project) - - def words_to_sent(self, words): - return ' '.join(words) + def handle_uploaded_file(self, file, project, user): + for words, tags in self.parse(file): + start_offset = 0 + sent = ' '.join(words) + doc = self.save_doc({'text': sent}, project) + for word, tag in zip(words, tags): + label = extract_label(tag) + label = self.save_label({'text': label}, project) + data = {'start_offset': start_offset, + 'end_offset': start_offset + len(word), + 'label': label.id} + start_offset += len(word) + 1 + self.save_annotation(data, doc, user) def parse(self, file): words, tags = [], [] @@ -206,13 +255,13 @@ class CoNLLFileUploadAPI(TextUploadAPI): words.append(word) tags.append(tag) else: - yield words + yield words, tags words, tags = [], [] if len(words) > 0: - yield words + yield words, tags -class PlainTextUploadAPI(TextUploadAPI): +class PlainTextHandler(FileHandler): """Uploads plain text. The file format is as follows: @@ -223,13 +272,9 @@ class PlainTextUploadAPI(TextUploadAPI): ``` """ @transaction.atomic - def handle_uploaded_file(self, file): - project = get_object_or_404(Project, pk=self.kwargs['project_id']) + def handle_uploaded_file(self, file, project, user): for text in self.parse(file): - data = {'text': text} - serializer = DocumentSerializer(data=data) - serializer.is_valid(raise_exception=True) - serializer.save(project=project) + self.save_doc({'text': text}, project) def parse(self, file): file = io.TextIOWrapper(file, encoding='utf-8') @@ -237,30 +282,20 @@ class PlainTextUploadAPI(TextUploadAPI): yield line.strip() -class CSVUploadAPI(TextUploadAPI): +class CSVHandler(FileHandler): """Uploads csv file. The file format is comma separated values. Column names are required at the top of a file. For example: ``` - text, label(optional) - "EU rejects German call to boycott British lamb.", - "President Obama is speaking at the White House.", - "He lives in Newark, Ohio.", + text, label + "EU rejects German call to boycott British lamb.",Politics + "President Obama is speaking at the White House.",Politics + "He lives in Newark, Ohio.",Other ... ``` """ - - @transaction.atomic - def handle_uploaded_file(self, file): - project = get_object_or_404(Project, pk=self.kwargs['project_id']) - for text, label in self.parse(file): - data = {'text': text} - serializer = DocumentSerializer(data=data) - serializer.is_valid(raise_exception=True) - serializer.save(project=project) - def parse(self, file): file = io.TextIOWrapper(file, encoding='utf-8') reader = csv.reader(file) @@ -276,7 +311,28 @@ class CSVUploadAPI(TextUploadAPI): raise FileParseException(line_num=i, line=row) -class JSONLUploadAPI(TextUploadAPI): +class CSVClassificationHandler(CSVHandler): + annotation_serializer = DocumentAnnotationSerializer + + @transaction.atomic + def handle_uploaded_file(self, file, project, user): + for text, label in self.parse(file): + doc = self.save_doc({'text': text}, project) + label = self.save_label({'text': label}, project) + self.save_annotation({'label': label.id}, doc, user) + + +class CSVSeq2seqHandler(CSVHandler): + annotation_serializer = Seq2seqAnnotationSerializer + + @transaction.atomic + def handle_uploaded_file(self, file, project, user): + for text, label in self.parse(file): + doc = self.save_doc({'text': text}, project) + self.save_annotation({'text': label}, doc, user) + + +class JsonHandler(FileHandler): """Uploads jsonl file. The file format is as follows: @@ -286,15 +342,6 @@ class JSONLUploadAPI(TextUploadAPI): ... ``` """ - - @transaction.atomic - def handle_uploaded_file(self, file): - project = get_object_or_404(Project, pk=self.kwargs['project_id']) - for data in self.parse(file): - serializer = DocumentSerializer(data=data) - serializer.is_valid(raise_exception=True) - serializer.save(project=project) - def parse(self, file): for i, line in enumerate(file, start=1): try: @@ -302,3 +349,65 @@ class JSONLUploadAPI(TextUploadAPI): yield j except json.decoder.JSONDecodeError: raise FileParseException(line_num=i, line=line) + + +class JsonClassificationHandler(JsonHandler): + """Upload jsonl for text classification. + + The format is as follows: + ``` + {"text": "Python is awesome!", "labels": ["positive"]} + ... + ``` + """ + annotation_serializer = DocumentAnnotationSerializer + + @transaction.atomic + def handle_uploaded_file(self, file, project, user): + for data in self.parse(file): + doc = self.save_doc(data, project) + for label in data['labels']: + label = self.save_label({'text': label}, project) + self.save_annotation({'label': label.id}, doc, user) + + +class JsonLabelingHandler(JsonHandler): + """Upload jsonl for sequence labeling. + + The format is as follows: + ``` + {"text": "Python is awesome!", "entities": [[0, 6, "Product"],]} + ... + ``` + """ + annotation_serializer = SequenceAnnotationSerializer + + @transaction.atomic + def handle_uploaded_file(self, file, project, user): + for data in self.parse(file): + doc = self.save_doc(data, project) + for start_offset, end_offset, label in data['entities']: + label = self.save_label({'text': label}, project) + data = {'label': label.id, + 'start_offset': start_offset, + 'end_offset': end_offset} + self.save_annotation(data, doc, user) + + +class JsonSeq2seqHandler(JsonHandler): + """Upload jsonl for seq2seq. + + The format is as follows: + ``` + {"text": "Hello, World!", "labels": ["こんにちは、世界!"]} + ... + ``` + """ + annotation_serializer = Seq2seqAnnotationSerializer + + @transaction.atomic + def handle_uploaded_file(self, file, project, user): + for data in self.parse(file): + doc = self.save_doc(data, project) + for label in data['labels']: + self.save_annotation({'text': label}, doc, user) diff --git a/app/server/api_urls.py b/app/server/api_urls.py index 1f403188..dc3aafa5 100644 --- a/app/server/api_urls.py +++ b/app/server/api_urls.py @@ -5,7 +5,7 @@ from .api import ProjectList, ProjectDetail from .api import LabelList, LabelDetail from .api import DocumentList, DocumentDetail from .api import EntityList, EntityDetail -from .api import CoNLLFileUploadAPI, CSVUploadAPI, JSONLUploadAPI, PlainTextUploadAPI +from .api import TextUploadAPI from .api import StatisticsAPI @@ -26,14 +26,8 @@ urlpatterns = [ EntityList.as_view(), name='entity_list'), path('projects//docs//entities/', EntityDetail.as_view(), name='entity_detail'), - path('projects//plain_uploader', - PlainTextUploadAPI.as_view(), name='plain_uploader'), - path('projects//conll_uploader', - CoNLLFileUploadAPI.as_view(), name='conll_uploader'), - path('projects//csv_uploader', - CSVUploadAPI.as_view(), name='csv_uploader'), - path('projects//json_uploader', - JSONLUploadAPI.as_view(), name='json_uploader'), + path('projects//docs/upload', + TextUploadAPI.as_view(), name='doc_uploader') ] urlpatterns = format_suffix_patterns(urlpatterns, allowed=['json', 'xml']) diff --git a/app/server/models.py b/app/server/models.py index b4a25dc1..2cfa2862 100644 --- a/app/server/models.py +++ b/app/server/models.py @@ -6,17 +6,17 @@ from django.contrib.staticfiles.storage import staticfiles_storage from .utils import get_key_choices -class Project(models.Model): - DOCUMENT_CLASSIFICATION = 'DocumentClassification' - SEQUENCE_LABELING = 'SequenceLabeling' - SEQ2SEQ = 'Seq2seq' +DOCUMENT_CLASSIFICATION = 'DocumentClassification' +SEQUENCE_LABELING = 'SequenceLabeling' +SEQ2SEQ = 'Seq2seq' +PROJECT_CHOICES = ( + (DOCUMENT_CLASSIFICATION, 'document classification'), + (SEQUENCE_LABELING, 'sequence labeling'), + (SEQ2SEQ, 'sequence to sequence'), +) - PROJECT_CHOICES = ( - (DOCUMENT_CLASSIFICATION, 'document classification'), - (SEQUENCE_LABELING, 'sequence labeling'), - (SEQ2SEQ, 'sequence to sequence'), - ) +class Project(models.Model): name = models.CharField(max_length=100) description = models.TextField() guideline = models.TextField() diff --git a/app/server/serializers.py b/app/server/serializers.py index eb47dc5f..57e9a0ea 100644 --- a/app/server/serializers.py +++ b/app/server/serializers.py @@ -38,7 +38,8 @@ class ProjectFilteredPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField): class DocumentAnnotationSerializer(serializers.ModelSerializer): - label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all()) + # label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all()) + label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all()) class Meta: model = DocumentAnnotation @@ -50,7 +51,8 @@ class DocumentAnnotationSerializer(serializers.ModelSerializer): class SequenceAnnotationSerializer(serializers.ModelSerializer): - label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all()) + #label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all()) + label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all()) class Meta: model = SequenceAnnotation diff --git a/app/server/tests/data/example.classification.jsonl b/app/server/tests/data/example.classification.jsonl new file mode 100644 index 00000000..b5ab4a56 --- /dev/null +++ b/app/server/tests/data/example.classification.jsonl @@ -0,0 +1,3 @@ +{"text": "example", "labels": ["positive"]} +{"text": "example", "labels": ["positive", "negative"]} +{"text": "example", "labels": ["negative"]} diff --git a/app/server/tests/data/example.labeling.jsonl b/app/server/tests/data/example.labeling.jsonl new file mode 100644 index 00000000..6591c7bf --- /dev/null +++ b/app/server/tests/data/example.labeling.jsonl @@ -0,0 +1,3 @@ +{"text": "example", "entities": [[0, 1, "LOC"], [0, 2, "ORG"]]} +{"text": "example", "entities": [[0, 1, "LOC"]]} +{"text": "example", "entities": [[0, 1, "PER"]]} diff --git a/app/server/tests/data/example.seq2seq.jsonl b/app/server/tests/data/example.seq2seq.jsonl new file mode 100644 index 00000000..6138b6af --- /dev/null +++ b/app/server/tests/data/example.seq2seq.jsonl @@ -0,0 +1,3 @@ +{"text": "example", "labels": ["example1", "example2"]} +{"text": "example", "labels": ["example"]} +{"text": "example", "labels": ["example"]} diff --git a/app/server/tests/data/example.valid.2.csv b/app/server/tests/data/example.valid.2.csv index 89d78846..206730c3 100644 --- a/app/server/tests/data/example.valid.2.csv +++ b/app/server/tests/data/example.valid.2.csv @@ -1,4 +1,4 @@ text, label -AAA, Positive -BBB, Positive -CCC, Negative \ No newline at end of file +AAA,Positive +BBB,Positive +CCC,Negative \ No newline at end of file diff --git a/app/server/tests/test_api.py b/app/server/tests/test_api.py index b5008559..d4092369 100644 --- a/app/server/tests/test_api.py +++ b/app/server/tests/test_api.py @@ -4,7 +4,11 @@ from rest_framework import status from rest_framework.reverse import reverse from rest_framework.test import APITestCase from mixer.backend.django import mixer -from ..models import User, SequenceAnnotation, Document +from ..models import User, SequenceAnnotation, Document, Label, Seq2seqAnnotation, DocumentAnnotation +from ..models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ +from ..api import CoNLLHandler, CSVClassificationHandler, CSVSeq2seqHandler +from ..api import JsonClassificationHandler, JsonLabelingHandler, JsonSeq2seqHandler +from ..exceptions import FileParseException DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') @@ -622,62 +626,149 @@ class TestUploader(APITestCase): @classmethod def setUpTestData(cls): - cls.project_member_name = 'project_member_name' - cls.project_member_pass = 'project_member_pass' - project_member = User.objects.create_user(username=cls.project_member_name, - password=cls.project_member_pass) cls.super_user_name = 'super_user_name' cls.super_user_pass = 'super_user_pass' # Todo: change super_user to project_admin. super_user = User.objects.create_superuser(username=cls.super_user_name, password=cls.super_user_pass, email='fizz@buzz.com') - cls.main_project = mixer.blend('server.Project', users=[project_member, super_user]) - cls.conll_url = reverse(viewname='conll_uploader', args=[cls.main_project.id]) - cls.csv_url = reverse(viewname='csv_uploader', args=[cls.main_project.id]) - cls.json_url = reverse(viewname='json_uploader', args=[cls.main_project.id]) - cls.plain_url = reverse(viewname='plain_uploader', args=[cls.main_project.id]) + cls.classification_project = mixer.blend('server.Project', users=[super_user], project_type=DOCUMENT_CLASSIFICATION) + cls.labeling_project = mixer.blend('server.Project', users=[super_user], project_type=SEQUENCE_LABELING) + cls.seq2seq_project = mixer.blend('server.Project', users=[super_user], project_type=SEQ2SEQ) + cls.classification_url = reverse(viewname='doc_uploader', args=[cls.classification_project.id]) + cls.labeling_url = reverse(viewname='doc_uploader', args=[cls.labeling_project.id]) + cls.seq2seq_url = reverse(viewname='doc_uploader', args=[cls.seq2seq_project.id]) def setUp(self): self.client.login(username=self.super_user_name, password=self.super_user_pass) - def upload_test_helper(self, filename, url, expected_status): + def upload_test_helper(self, url, filename, format, expected_status): with open(os.path.join(DATA_DIR, filename)) as f: - response = self.client.post(url, data={'file': f}) + response = self.client.post(url, data={'file': f, 'format': format}) self.assertEqual(response.status_code, expected_status) def test_can_upload_conll_format_file(self): - self.upload_test_helper(filename='example.valid.conll', - url=self.conll_url, + self.upload_test_helper(url=self.labeling_url, + filename='example.valid.conll', + format='conll', expected_status=status.HTTP_201_CREATED) def test_cannot_upload_wrong_conll_format_file(self): - self.upload_test_helper(filename='example.invalid.conll', - url=self.conll_url, + self.upload_test_helper(url=self.labeling_url, + filename='example.invalid.conll', + format='conll', expected_status=status.HTTP_400_BAD_REQUEST) - def test_can_upload_csv_with_label(self): - self.upload_test_helper(filename='example.valid.2.csv', - url=self.csv_url, + def test_can_upload_classification_csv(self): + self.upload_test_helper(url=self.classification_url, + filename='example.valid.2.csv', + format='csv', + expected_status=status.HTTP_201_CREATED) + + def test_can_upload_seq2seq_csv(self): + self.upload_test_helper(url=self.classification_url, + filename='example.valid.2.csv', + format='csv', expected_status=status.HTTP_201_CREATED) def test_cannot_upload_csv_file_does_not_match_column_and_row(self): - self.upload_test_helper(filename='example.invalid.1.csv', - url=self.csv_url, + self.upload_test_helper(url=self.classification_url, + filename='example.invalid.1.csv', + format='csv', expected_status=status.HTTP_400_BAD_REQUEST) def test_cannot_upload_csv_file_has_too_many_columns(self): - self.upload_test_helper(filename='example.invalid.2.csv', - url=self.csv_url, + self.upload_test_helper(url=self.classification_url, + filename='example.invalid.2.csv', + format='csv', expected_status=status.HTTP_400_BAD_REQUEST) - def test_can_upload_jsonl(self): - self.upload_test_helper(filename='example.jsonl', - url=self.json_url, + def test_can_upload_classification_jsonl(self): + self.upload_test_helper(url=self.classification_url, + filename='example.classification.jsonl', + format='json', + expected_status=status.HTTP_201_CREATED) + + def test_can_upload_labeling_jsonl(self): + self.upload_test_helper(url=self.labeling_url, + filename='example.labeling.jsonl', + format='json', + expected_status=status.HTTP_201_CREATED) + + def test_can_upload_seq2seq_jsonl(self): + self.upload_test_helper(url=self.seq2seq_url, + filename='example.seq2seq.jsonl', + format='json', expected_status=status.HTTP_201_CREATED) def test_can_upload_plain_text(self): - self.upload_test_helper(filename='example.txt', - url=self.plain_url, + self.upload_test_helper(url=self.classification_url, + filename='example.txt', + format='plain', expected_status=status.HTTP_201_CREATED) + + +class TestFileHandler(APITestCase): + + @classmethod + def setUpTestData(cls): + cls.super_user_name = 'super_user_name' + cls.super_user_pass = 'super_user_pass' + # Todo: change super_user to project_admin. + cls.super_user = User.objects.create_superuser(username=cls.super_user_name, + password=cls.super_user_pass, + email='fizz@buzz.com') + cls.project = mixer.blend('server.Project', users=[cls.super_user]) + + def handler_test_helper(self, filename, handler): + with open(os.path.join(DATA_DIR, filename), mode='rb') as f: + handler.handle_uploaded_file(f, self.project, self.super_user) + + def test_conll_handler(self): + self.handler_test_helper(filename='example.valid.conll', + handler=CoNLLHandler()) + self.assertEqual(Document.objects.count(), 3) + self.assertEqual(Label.objects.count(), 3) # LOC, PER, O + self.assertEqual(SequenceAnnotation.objects.count(), 20) # num of annotation line + + def test_conll_invalid_handler(self): + with self.assertRaises(FileParseException): + self.handler_test_helper(filename='example.invalid.conll', + handler=CoNLLHandler()) + self.assertEqual(Document.objects.count(), 0) + self.assertEqual(Label.objects.count(), 0) + self.assertEqual(SequenceAnnotation.objects.count(), 0) + + def test_csv_classification_handler(self): + self.handler_test_helper(filename='example.valid.2.csv', + handler=CSVClassificationHandler()) + self.assertEqual(Document.objects.count(), 3) + self.assertEqual(Label.objects.count(), 2) + self.assertEqual(DocumentAnnotation.objects.count(), 3) + + def test_csv_seq2seq_handler(self): + self.handler_test_helper(filename='example.valid.2.csv', + handler=CSVSeq2seqHandler()) + self.assertEqual(Document.objects.count(), 3) + self.assertEqual(Seq2seqAnnotation.objects.count(), 3) + + def test_json_classification_handler(self): + self.handler_test_helper(filename='example.classification.jsonl', + handler=JsonClassificationHandler()) + self.assertEqual(Document.objects.count(), 3) + self.assertEqual(Label.objects.count(), 2) + self.assertEqual(DocumentAnnotation.objects.count(), 4) + + def test_json_labeling_handler(self): + self.handler_test_helper(filename='example.labeling.jsonl', + handler=JsonLabelingHandler()) + self.assertEqual(Document.objects.count(), 3) + self.assertEqual(Label.objects.count(), 3) + self.assertEqual(SequenceAnnotation.objects.count(), 4) + + def test_json_seq2seq_handler(self): + self.handler_test_helper(filename='example.seq2seq.jsonl', + handler=JsonSeq2seqHandler()) + self.assertEqual(Document.objects.count(), 3) + self.assertEqual(Seq2seqAnnotation.objects.count(), 4) diff --git a/app/server/utils.py b/app/server/utils.py index fd229a0b..9c5896b8 100644 --- a/app/server/utils.py +++ b/app/server/utils.py @@ -1,5 +1,7 @@ +import re import string + def get_key_choices(): selectKey, shortKey = [c for c in string.ascii_lowercase], [c for c in string.ascii_lowercase] checkKey = 'ctrl shift' @@ -8,3 +10,12 @@ def get_key_choices(): shortKey += [''] KEY_CHOICES = ((u, c) for u, c in zip(shortKey, shortKey)) return KEY_CHOICES + + +def extract_label(tag): + ptn = re.compile(r'(B|I|E|S)-(.+)') + m = ptn.match(tag) + if m: + return m.groups()[1] + else: + return tag