Browse Source

Update project serializer mapping

pull/110/head
Hironsan 5 years ago
parent
commit
fa388ac6ec
4 changed files with 82 additions and 34 deletions
  1. 22
      app/server/api.py
  2. 48
      app/server/models.py
  3. 30
      app/server/serializers.py
  4. 16
      app/server/tests/test_api.py

22
app/server/api.py

@ -17,7 +17,6 @@ from rest_framework.parsers import MultiPartParser
from .exceptions import FileParseException
from .models import Project, Label, Document
from .models import SequenceAnnotation
from .models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsMyEntity
from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer
from .serializers import SequenceAnnotationSerializer, DocumentAnnotationSerializer, Seq2seqAnnotationSerializer
@ -150,29 +149,10 @@ class TextUploadAPI(APIView):
if 'file' not in request.FILES:
raise ParseError('Empty content')
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
handler = self.decide_handler(request.data['format'], project.project_type)
handler = project.get_upload_handler(request.data['format'])
handler.handle_uploaded_file(request.FILES['file'], project, self.request.user)
return Response(status=status.HTTP_201_CREATED)
def decide_handler(self, format, project_type):
if format == 'plain':
return PlainTextHandler()
elif format == 'conll' and project_type:
return CoNLLHandler()
elif format == 'csv':
if project_type == DOCUMENT_CLASSIFICATION:
return CSVClassificationHandler()
elif project_type == SEQ2SEQ:
return CSVSeq2seqHandler()
elif format == 'json':
if project_type == DOCUMENT_CLASSIFICATION:
return JsonClassificationHandler()
elif project_type == SEQUENCE_LABELING:
return JsonLabelingHandler()
elif project_type == SEQ2SEQ:
return JsonSeq2seqHandler()
raise ValueError('format {} is invalid.'.format(format))
class FileHandler(object):
annotation_serializer = None

48
app/server/models.py

@ -19,8 +19,8 @@ PROJECT_CHOICES = (
class Project(PolymorphicModel):
name = models.CharField(max_length=100)
description = models.TextField()
guideline = models.TextField()
description = models.TextField(default='')
guideline = models.TextField(default='')
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
users = models.ManyToManyField(User, related_name='projects')
@ -31,7 +31,19 @@ class Project(PolymorphicModel):
@property
def image(self):
return staticfiles_storage.url('images/cats/text_classification.jpg')
raise NotImplementedError()
def get_template_name(self):
raise NotImplementedError()
def get_annotation_serializer(self):
raise NotImplementedError()
def get_annotation_class(self):
raise NotImplementedError()
def get_upload_handler(self, format):
raise NotImplementedError
def __str__(self):
return self.name
@ -53,6 +65,16 @@ class TextClassificationProject(Project):
def get_annotation_class(self):
return DocumentAnnotation
def get_upload_handler(self, format):
from .api import PlainTextHandler, CSVClassificationHandler, JsonClassificationHandler
if format == 'plain':
return PlainTextHandler()
elif format == 'csv':
return CSVClassificationHandler()
elif format == 'json':
return JsonClassificationHandler()
raise ValueError('format {} is invalid.'.format(format))
class SequenceLabelingProject(Project):
@ -70,6 +92,16 @@ class SequenceLabelingProject(Project):
def get_annotation_class(self):
return SequenceAnnotation
def get_upload_handler(self, format):
from .api import PlainTextHandler, CoNLLHandler, JsonLabelingHandler
if format == 'plain':
return PlainTextHandler()
elif format == 'conll':
return CoNLLHandler()
elif format == 'json':
return JsonLabelingHandler()
raise ValueError('format {} is invalid.'.format(format))
class Seq2seqProject(Project):
@ -87,6 +119,16 @@ class Seq2seqProject(Project):
def get_annotation_class(self):
return Seq2seqAnnotation
def get_upload_handler(self, format):
from .api import PlainTextHandler, CSVSeq2seqHandler, JsonSeq2seqHandler
if format == 'plain':
return PlainTextHandler()
elif format == 'csv':
return CSVSeq2seqHandler()
elif format == 'json':
return JsonSeq2seqHandler()
raise ValueError('format {} is invalid.'.format(format))
class Label(models.Model):
KEY_CHOICES = get_key_choices()

30
app/server/serializers.py

@ -41,12 +41,36 @@ class ProjectSerializer(serializers.ModelSerializer):
read_only_fields = ('image', 'updated_at')
class TextClassificationProjectSerializer(serializers.ModelSerializer):
class Meta:
model = TextClassificationProject
fields = '__all__'
read_only_fields = ('image', 'updated_at', 'users')
class SequenceLabelingProjectSerializer(serializers.ModelSerializer):
class Meta:
model = SequenceLabelingProject
fields = '__all__'
read_only_fields = ('image', 'updated_at', 'users')
class Seq2seqProjectSerializer(serializers.ModelSerializer):
class Meta:
model = Seq2seqProject
fields = '__all__'
read_only_fields = ('image', 'updated_at', 'users')
class ProjectPolymorphicSerializer(PolymorphicSerializer):
model_serializer_mapping = {
Project: ProjectSerializer,
TextClassificationProject: ProjectSerializer,
SequenceLabelingProject: ProjectSerializer,
Seq2seqProject: ProjectSerializer
TextClassificationProject: TextClassificationProjectSerializer,
SequenceLabelingProject: SequenceLabelingProjectSerializer,
Seq2seqProject: Seq2seqProjectSerializer
}

16
app/server/tests/test_api.py

@ -32,8 +32,8 @@ class TestProjectListAPI(APITestCase):
password=cls.super_user_pass,
email='fizz@buzz.com')
cls.main_project = mommy.make('server.Project', users=[main_project_member])
cls.sub_project = mommy.make('server.Project', users=[sub_project_member])
cls.main_project = mommy.make('server.TextClassificationProject', users=[main_project_member])
cls.sub_project = mommy.make('server.TextClassificationProject', users=[sub_project_member])
cls.url = reverse(viewname='project_list')
cls.data = {'name': 'example', 'project_type': 'DocumentClassification',
@ -90,8 +90,8 @@ class TestProjectDetailAPI(APITestCase):
super_user = User.objects.create_superuser(username=cls.super_user_name,
password=cls.super_user_pass,
email='fizz@buzz.com')
cls.main_project = mommy.make('server.Project', users=[cls.project_member, super_user])
sub_project = mommy.make('server.Project', users=[non_project_member])
cls.main_project = mommy.make('server.TextClassificationProject', users=[cls.project_member, super_user])
sub_project = mommy.make('server.TextClassificationProject', users=[non_project_member])
cls.url = reverse(viewname='project_detail', args=[cls.main_project.id])
cls.data = {'description': 'lorem'}
@ -634,9 +634,11 @@ class TestUploader(APITestCase):
super_user = User.objects.create_superuser(username=cls.super_user_name,
password=cls.super_user_pass,
email='fizz@buzz.com')
cls.classification_project = mommy.make('server.Project', users=[super_user], project_type=DOCUMENT_CLASSIFICATION)
cls.labeling_project = mommy.make('server.Project', users=[super_user], project_type=SEQUENCE_LABELING)
cls.seq2seq_project = mommy.make('server.Project', users=[super_user], project_type=SEQ2SEQ)
cls.classification_project = mommy.make('server.TextClassificationProject',
users=[super_user], project_type=DOCUMENT_CLASSIFICATION)
cls.labeling_project = mommy.make('server.SequenceLabelingProject',
users=[super_user], project_type=SEQUENCE_LABELING)
cls.seq2seq_project = mommy.make('server.Seq2seqProject', users=[super_user], project_type=SEQ2SEQ)
cls.classification_url = reverse(viewname='doc_uploader', args=[cls.classification_project.id])
cls.labeling_url = reverse(viewname='doc_uploader', args=[cls.labeling_project.id])
cls.seq2seq_url = reverse(viewname='doc_uploader', args=[cls.seq2seq_project.id])

Loading…
Cancel
Save