From 6c33281f14621e6c0bc090c930b8d5c623896369 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Thu, 6 May 2021 16:16:48 +0900 Subject: [PATCH] Remove circular dependency --- backend/api/models.py | 19 ------------------- backend/api/serializers.py | 22 ++++++++++++++++++---- backend/api/views/annotation.py | 6 +++--- backend/api/views/auto_labeling.py | 5 +++-- 4 files changed, 24 insertions(+), 28 deletions(-) diff --git a/backend/api/models.py b/backend/api/models.py index 30c8e130..dd83dc08 100644 --- a/backend/api/models.py +++ b/backend/api/models.py @@ -33,9 +33,6 @@ class Project(PolymorphicModel): collaborative_annotation = models.BooleanField(default=False) single_class_classification = models.BooleanField(default=False) - def get_annotation_serializer(self): - raise NotImplementedError() - def get_annotation_class(self): raise NotImplementedError() @@ -45,40 +42,24 @@ class Project(PolymorphicModel): class TextClassificationProject(Project): - def get_annotation_serializer(self): - from .serializers import DocumentAnnotationSerializer - return DocumentAnnotationSerializer - def get_annotation_class(self): return DocumentAnnotation class SequenceLabelingProject(Project): - def get_annotation_serializer(self): - from .serializers import SequenceAnnotationSerializer - return SequenceAnnotationSerializer - def get_annotation_class(self): return SequenceAnnotation class Seq2seqProject(Project): - def get_annotation_serializer(self): - from .serializers import Seq2seqAnnotationSerializer - return Seq2seqAnnotationSerializer - def get_annotation_class(self): return Seq2seqAnnotation class Speech2textProject(Project): - def get_annotation_serializer(self): - from .serializers import Speech2textAnnotationSerializer - return Speech2textAnnotationSerializer - def get_annotation_class(self): return Speech2textAnnotation diff --git a/backend/api/serializers.py b/backend/api/serializers.py index 4ef09792..f1ba2318 100644 --- a/backend/api/serializers.py +++ b/backend/api/serializers.py @@ -6,9 +6,10 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError from rest_polymorphic.serializers import PolymorphicSerializer -from .models import (AutoLabelingConfig, Comment, Document, DocumentAnnotation, - Label, Project, Role, RoleMapping, Seq2seqAnnotation, - Seq2seqProject, SequenceAnnotation, +from .models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING, + SPEECH2TEXT, AutoLabelingConfig, Comment, Document, + DocumentAnnotation, Label, Project, Role, RoleMapping, + Seq2seqAnnotation, Seq2seqProject, SequenceAnnotation, SequenceLabelingProject, Speech2textAnnotation, Speech2textProject, Tag, TextClassificationProject) @@ -85,7 +86,7 @@ class DocumentSerializer(serializers.ModelSerializer): request = self.context.get('request') project = instance.project model = project.get_annotation_class() - serializer = project.get_annotation_serializer() + serializer = get_annotation_serializer(task=project.project_type) annotations = model.objects.filter(document=instance.id) if request and not project.collaborative_annotation: annotations = annotations.filter(user=request.user) @@ -286,3 +287,16 @@ class AutoLabelingConfigSerializer(serializers.ModelSerializer): 'You need to correctly specify the required fields: {}'.format(required_fields) ) return data + + +def get_annotation_serializer(task: str): + mapping = { + DOCUMENT_CLASSIFICATION: DocumentAnnotationSerializer, + SEQUENCE_LABELING: SequenceAnnotationSerializer, + SEQ2SEQ: Seq2seqAnnotationSerializer, + SPEECH2TEXT: Speech2textAnnotationSerializer + } + try: + return mapping[task] + except KeyError: + raise ValueError(f'{task} is not implemented.') diff --git a/backend/api/views/annotation.py b/backend/api/views/annotation.py index 32e6af58..d770a252 100644 --- a/backend/api/views/annotation.py +++ b/backend/api/views/annotation.py @@ -7,7 +7,7 @@ from rest_framework.views import APIView from ..models import Document, Project from ..permissions import (IsAnnotationApprover, IsInProjectOrAdmin, IsOwnAnnotation, IsProjectAdmin) -from ..serializers import ApproverSerializer +from ..serializers import ApproverSerializer, get_annotation_serializer class AnnotationList(generics.ListCreateAPIView): @@ -20,7 +20,7 @@ class AnnotationList(generics.ListCreateAPIView): return get_object_or_404(Project, pk=self.kwargs['project_id']) def get_serializer_class(self): - self.serializer_class = self.project.get_annotation_serializer() + self.serializer_class = get_annotation_serializer(task=self.project.project_type) return self.serializer_class def get_queryset(self): @@ -59,7 +59,7 @@ class AnnotationDetail(generics.RetrieveUpdateDestroyAPIView): def get_serializer_class(self): project = get_object_or_404(Project, pk=self.kwargs['project_id']) - self.serializer_class = project.get_annotation_serializer() + self.serializer_class = get_annotation_serializer(task=project.project_type) return self.serializer_class def get_queryset(self): diff --git a/backend/api/views/auto_labeling.py b/backend/api/views/auto_labeling.py index 3919909d..46b5cb7d 100644 --- a/backend/api/views/auto_labeling.py +++ b/backend/api/views/auto_labeling.py @@ -18,7 +18,8 @@ from ..exceptions import (AutoLabeliingPermissionDenied, AutoLabelingException, URLConnectionError) from ..models import AutoLabelingConfig, Document, Project from ..permissions import IsInProjectOrAdmin, IsProjectAdmin -from ..serializers import AutoLabelingConfigSerializer +from ..serializers import (AutoLabelingConfigSerializer, + get_annotation_serializer) class AutoLabelingTemplateListAPI(APIView): @@ -170,7 +171,7 @@ class AutoLabelingAnnotation(generics.CreateAPIView): def get_serializer_class(self): project = get_object_or_404(Project, pk=self.kwargs['project_id']) - self.serializer_class = project.get_annotation_serializer() + self.serializer_class = get_annotation_serializer(task=project.project_type) return self.serializer_class def get_queryset(self):