From fa388ac6ecd18e1de7b47a01dee1188f3dab41e6 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Mon, 4 Mar 2019 21:14:59 +0900 Subject: [PATCH] Update project serializer mapping --- app/server/api.py | 22 +---------------- app/server/models.py | 48 +++++++++++++++++++++++++++++++++--- app/server/serializers.py | 30 +++++++++++++++++++--- app/server/tests/test_api.py | 16 ++++++------ 4 files changed, 82 insertions(+), 34 deletions(-) diff --git a/app/server/api.py b/app/server/api.py index d7dcb120..a827b8e1 100644 --- a/app/server/api.py +++ b/app/server/api.py @@ -17,7 +17,6 @@ from rest_framework.parsers import MultiPartParser from .exceptions import FileParseException from .models import Project, Label, Document from .models import SequenceAnnotation -from .models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsMyEntity from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer from .serializers import SequenceAnnotationSerializer, DocumentAnnotationSerializer, Seq2seqAnnotationSerializer @@ -150,29 +149,10 @@ class TextUploadAPI(APIView): if 'file' not in request.FILES: raise ParseError('Empty content') project = get_object_or_404(Project, pk=self.kwargs['project_id']) - handler = self.decide_handler(request.data['format'], project.project_type) + handler = project.get_upload_handler(request.data['format']) handler.handle_uploaded_file(request.FILES['file'], project, self.request.user) return Response(status=status.HTTP_201_CREATED) - def decide_handler(self, format, project_type): - if format == 'plain': - return PlainTextHandler() - elif format == 'conll' and project_type: - return CoNLLHandler() - elif format == 'csv': - if project_type == DOCUMENT_CLASSIFICATION: - return CSVClassificationHandler() - elif project_type == SEQ2SEQ: - return CSVSeq2seqHandler() - elif format == 'json': - if project_type == DOCUMENT_CLASSIFICATION: - return JsonClassificationHandler() - elif project_type == SEQUENCE_LABELING: - return JsonLabelingHandler() - elif project_type == SEQ2SEQ: - return JsonSeq2seqHandler() - raise ValueError('format {} is invalid.'.format(format)) - class FileHandler(object): annotation_serializer = None diff --git a/app/server/models.py b/app/server/models.py index 6d61c856..6cd0ae52 100644 --- a/app/server/models.py +++ b/app/server/models.py @@ -19,8 +19,8 @@ PROJECT_CHOICES = ( class Project(PolymorphicModel): name = models.CharField(max_length=100) - description = models.TextField() - guideline = models.TextField() + description = models.TextField(default='') + guideline = models.TextField(default='') created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) users = models.ManyToManyField(User, related_name='projects') @@ -31,7 +31,19 @@ class Project(PolymorphicModel): @property def image(self): - return staticfiles_storage.url('images/cats/text_classification.jpg') + raise NotImplementedError() + + def get_template_name(self): + raise NotImplementedError() + + def get_annotation_serializer(self): + raise NotImplementedError() + + def get_annotation_class(self): + raise NotImplementedError() + + def get_upload_handler(self, format): + raise NotImplementedError def __str__(self): return self.name @@ -53,6 +65,16 @@ class TextClassificationProject(Project): def get_annotation_class(self): return DocumentAnnotation + def get_upload_handler(self, format): + from .api import PlainTextHandler, CSVClassificationHandler, JsonClassificationHandler + if format == 'plain': + return PlainTextHandler() + elif format == 'csv': + return CSVClassificationHandler() + elif format == 'json': + return JsonClassificationHandler() + raise ValueError('format {} is invalid.'.format(format)) + class SequenceLabelingProject(Project): @@ -70,6 +92,16 @@ class SequenceLabelingProject(Project): def get_annotation_class(self): return SequenceAnnotation + def get_upload_handler(self, format): + from .api import PlainTextHandler, CoNLLHandler, JsonLabelingHandler + if format == 'plain': + return PlainTextHandler() + elif format == 'conll': + return CoNLLHandler() + elif format == 'json': + return JsonLabelingHandler() + raise ValueError('format {} is invalid.'.format(format)) + class Seq2seqProject(Project): @@ -87,6 +119,16 @@ class Seq2seqProject(Project): def get_annotation_class(self): return Seq2seqAnnotation + def get_upload_handler(self, format): + from .api import PlainTextHandler, CSVSeq2seqHandler, JsonSeq2seqHandler + if format == 'plain': + return PlainTextHandler() + elif format == 'csv': + return CSVSeq2seqHandler() + elif format == 'json': + return JsonSeq2seqHandler() + raise ValueError('format {} is invalid.'.format(format)) + class Label(models.Model): KEY_CHOICES = get_key_choices() diff --git a/app/server/serializers.py b/app/server/serializers.py index f952c0f1..d1b29a7a 100644 --- a/app/server/serializers.py +++ b/app/server/serializers.py @@ -41,12 +41,36 @@ class ProjectSerializer(serializers.ModelSerializer): read_only_fields = ('image', 'updated_at') +class TextClassificationProjectSerializer(serializers.ModelSerializer): + + class Meta: + model = TextClassificationProject + fields = '__all__' + read_only_fields = ('image', 'updated_at', 'users') + + +class SequenceLabelingProjectSerializer(serializers.ModelSerializer): + + class Meta: + model = SequenceLabelingProject + fields = '__all__' + read_only_fields = ('image', 'updated_at', 'users') + + +class Seq2seqProjectSerializer(serializers.ModelSerializer): + + class Meta: + model = Seq2seqProject + fields = '__all__' + read_only_fields = ('image', 'updated_at', 'users') + + class ProjectPolymorphicSerializer(PolymorphicSerializer): model_serializer_mapping = { Project: ProjectSerializer, - TextClassificationProject: ProjectSerializer, - SequenceLabelingProject: ProjectSerializer, - Seq2seqProject: ProjectSerializer + TextClassificationProject: TextClassificationProjectSerializer, + SequenceLabelingProject: SequenceLabelingProjectSerializer, + Seq2seqProject: Seq2seqProjectSerializer } diff --git a/app/server/tests/test_api.py b/app/server/tests/test_api.py index 907809fb..79656ba4 100644 --- a/app/server/tests/test_api.py +++ b/app/server/tests/test_api.py @@ -32,8 +32,8 @@ class TestProjectListAPI(APITestCase): password=cls.super_user_pass, email='fizz@buzz.com') - cls.main_project = mommy.make('server.Project', users=[main_project_member]) - cls.sub_project = mommy.make('server.Project', users=[sub_project_member]) + cls.main_project = mommy.make('server.TextClassificationProject', users=[main_project_member]) + cls.sub_project = mommy.make('server.TextClassificationProject', users=[sub_project_member]) cls.url = reverse(viewname='project_list') cls.data = {'name': 'example', 'project_type': 'DocumentClassification', @@ -90,8 +90,8 @@ class TestProjectDetailAPI(APITestCase): super_user = User.objects.create_superuser(username=cls.super_user_name, password=cls.super_user_pass, email='fizz@buzz.com') - cls.main_project = mommy.make('server.Project', users=[cls.project_member, super_user]) - sub_project = mommy.make('server.Project', users=[non_project_member]) + cls.main_project = mommy.make('server.TextClassificationProject', users=[cls.project_member, super_user]) + sub_project = mommy.make('server.TextClassificationProject', users=[non_project_member]) cls.url = reverse(viewname='project_detail', args=[cls.main_project.id]) cls.data = {'description': 'lorem'} @@ -634,9 +634,11 @@ class TestUploader(APITestCase): super_user = User.objects.create_superuser(username=cls.super_user_name, password=cls.super_user_pass, email='fizz@buzz.com') - cls.classification_project = mommy.make('server.Project', users=[super_user], project_type=DOCUMENT_CLASSIFICATION) - cls.labeling_project = mommy.make('server.Project', users=[super_user], project_type=SEQUENCE_LABELING) - cls.seq2seq_project = mommy.make('server.Project', users=[super_user], project_type=SEQ2SEQ) + cls.classification_project = mommy.make('server.TextClassificationProject', + users=[super_user], project_type=DOCUMENT_CLASSIFICATION) + cls.labeling_project = mommy.make('server.SequenceLabelingProject', + users=[super_user], project_type=SEQUENCE_LABELING) + cls.seq2seq_project = mommy.make('server.Seq2seqProject', users=[super_user], project_type=SEQ2SEQ) cls.classification_url = reverse(viewname='doc_uploader', args=[cls.classification_project.id]) cls.labeling_url = reverse(viewname='doc_uploader', args=[cls.labeling_project.id]) cls.seq2seq_url = reverse(viewname='doc_uploader', args=[cls.seq2seq_project.id])