Browse Source

Remove circular dependency

pull/1349/head
Hironsan 3 years ago
parent
commit
6c33281f14
4 changed files with 24 additions and 28 deletions
  1. 19
      backend/api/models.py
  2. 22
      backend/api/serializers.py
  3. 6
      backend/api/views/annotation.py
  4. 5
      backend/api/views/auto_labeling.py

19
backend/api/models.py

@ -33,9 +33,6 @@ class Project(PolymorphicModel):
collaborative_annotation = models.BooleanField(default=False)
single_class_classification = models.BooleanField(default=False)
def get_annotation_serializer(self):
raise NotImplementedError()
def get_annotation_class(self):
raise NotImplementedError()
@ -45,40 +42,24 @@ class Project(PolymorphicModel):
class TextClassificationProject(Project):
def get_annotation_serializer(self):
from .serializers import DocumentAnnotationSerializer
return DocumentAnnotationSerializer
def get_annotation_class(self):
return DocumentAnnotation
class SequenceLabelingProject(Project):
def get_annotation_serializer(self):
from .serializers import SequenceAnnotationSerializer
return SequenceAnnotationSerializer
def get_annotation_class(self):
return SequenceAnnotation
class Seq2seqProject(Project):
def get_annotation_serializer(self):
from .serializers import Seq2seqAnnotationSerializer
return Seq2seqAnnotationSerializer
def get_annotation_class(self):
return Seq2seqAnnotation
class Speech2textProject(Project):
def get_annotation_serializer(self):
from .serializers import Speech2textAnnotationSerializer
return Speech2textAnnotationSerializer
def get_annotation_class(self):
return Speech2textAnnotation

22
backend/api/serializers.py

@ -6,9 +6,10 @@ from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_polymorphic.serializers import PolymorphicSerializer
from .models import (AutoLabelingConfig, Comment, Document, DocumentAnnotation,
Label, Project, Role, RoleMapping, Seq2seqAnnotation,
Seq2seqProject, SequenceAnnotation,
from .models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
SPEECH2TEXT, AutoLabelingConfig, Comment, Document,
DocumentAnnotation, Label, Project, Role, RoleMapping,
Seq2seqAnnotation, Seq2seqProject, SequenceAnnotation,
SequenceLabelingProject, Speech2textAnnotation,
Speech2textProject, Tag, TextClassificationProject)
@ -85,7 +86,7 @@ class DocumentSerializer(serializers.ModelSerializer):
request = self.context.get('request')
project = instance.project
model = project.get_annotation_class()
serializer = project.get_annotation_serializer()
serializer = get_annotation_serializer(task=project.project_type)
annotations = model.objects.filter(document=instance.id)
if request and not project.collaborative_annotation:
annotations = annotations.filter(user=request.user)
@ -286,3 +287,16 @@ class AutoLabelingConfigSerializer(serializers.ModelSerializer):
'You need to correctly specify the required fields: {}'.format(required_fields)
)
return data
def get_annotation_serializer(task: str):
mapping = {
DOCUMENT_CLASSIFICATION: DocumentAnnotationSerializer,
SEQUENCE_LABELING: SequenceAnnotationSerializer,
SEQ2SEQ: Seq2seqAnnotationSerializer,
SPEECH2TEXT: Speech2textAnnotationSerializer
}
try:
return mapping[task]
except KeyError:
raise ValueError(f'{task} is not implemented.')

6
backend/api/views/annotation.py

@ -7,7 +7,7 @@ from rest_framework.views import APIView
from ..models import Document, Project
from ..permissions import (IsAnnotationApprover, IsInProjectOrAdmin,
IsOwnAnnotation, IsProjectAdmin)
from ..serializers import ApproverSerializer
from ..serializers import ApproverSerializer, get_annotation_serializer
class AnnotationList(generics.ListCreateAPIView):
@ -20,7 +20,7 @@ class AnnotationList(generics.ListCreateAPIView):
return get_object_or_404(Project, pk=self.kwargs['project_id'])
def get_serializer_class(self):
self.serializer_class = self.project.get_annotation_serializer()
self.serializer_class = get_annotation_serializer(task=self.project.project_type)
return self.serializer_class
def get_queryset(self):
@ -59,7 +59,7 @@ class AnnotationDetail(generics.RetrieveUpdateDestroyAPIView):
def get_serializer_class(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
self.serializer_class = project.get_annotation_serializer()
self.serializer_class = get_annotation_serializer(task=project.project_type)
return self.serializer_class
def get_queryset(self):

5
backend/api/views/auto_labeling.py

@ -18,7 +18,8 @@ from ..exceptions import (AutoLabeliingPermissionDenied, AutoLabelingException,
URLConnectionError)
from ..models import AutoLabelingConfig, Document, Project
from ..permissions import IsInProjectOrAdmin, IsProjectAdmin
from ..serializers import AutoLabelingConfigSerializer
from ..serializers import (AutoLabelingConfigSerializer,
get_annotation_serializer)
class AutoLabelingTemplateListAPI(APIView):
@ -170,7 +171,7 @@ class AutoLabelingAnnotation(generics.CreateAPIView):
def get_serializer_class(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
self.serializer_class = project.get_annotation_serializer()
self.serializer_class = get_annotation_serializer(task=project.project_type)
return self.serializer_class
def get_queryset(self):

Loading…
Cancel
Save