diff --git a/app/db.sqlite3 b/app/db.sqlite3 index 59e5e779..3ce39775 100644 Binary files a/app/db.sqlite3 and b/app/db.sqlite3 differ diff --git a/app/server/api.py b/app/server/api.py index 17b8ca54..0117121f 100644 --- a/app/server/api.py +++ b/app/server/api.py @@ -89,13 +89,16 @@ class LabelDetail(generics.RetrieveUpdateDestroyAPIView): class DocumentList(generics.ListCreateAPIView): queryset = Document.objects.all() - from .serializers import SequenceDocumentSerializer - # serializer_class = DocumentSerializer - serializer_class = SequenceDocumentSerializer filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) search_fields = ('text', ) permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly) + def get_serializer_class(self): + project = get_object_or_404(Project, pk=self.kwargs['project_id']) + self.serializer_class = project.get_document_serializer() + + return self.serializer_class + def get_queryset(self): queryset = self.queryset.filter(project=self.kwargs['project_id']) if not self.request.query_params.get('is_checked'): diff --git a/app/server/models.py b/app/server/models.py index 2515932b..11caa6a8 100644 --- a/app/server/models.py +++ b/app/server/models.py @@ -82,6 +82,19 @@ class Project(models.Model): return docs + def get_document_serializer(self): + from .serializers import ClassificationDocumentSerializer + from .serializers import SequenceDocumentSerializer + from .serializers import Seq2seqDocumentSerializer + if self.is_type_of(Project.DOCUMENT_CLASSIFICATION): + return ClassificationDocumentSerializer + elif self.is_type_of(Project.SEQUENCE_LABELING): + return SequenceDocumentSerializer + elif self.is_type_of(Project.Seq2seq): + return Seq2seqDocumentSerializer + else: + raise ValueError('Invalid project_type') + def get_annotation_serializer(self): from .serializers import DocumentAnnotationSerializer from .serializers import SequenceAnnotationSerializer diff --git a/app/server/serializers.py b/app/server/serializers.py index 6e50f392..0d6b0e9b 100644 --- a/app/server/serializers.py +++ b/app/server/serializers.py @@ -67,17 +67,46 @@ class Seq2seqAnnotationSerializer(serializers.ModelSerializer): fields = ('id', 'text') +class ClassificationDocumentSerializer(serializers.ModelSerializer): + annotations = serializers.SerializerMethodField() + + def get_annotations(self, instance): + request = self.context.get('request') + if request: + annotations = instance.doc_annotations.filter(user=request.user) + serializer = DocumentAnnotationSerializer(annotations, many=True) + return serializer.data + + class Meta: + model = Document + fields = ('id', 'text', 'annotations') + + class SequenceDocumentSerializer(serializers.ModelSerializer): - labels = SequenceAnnotationSerializer(source='seq_annotations', many=True) annotations = serializers.SerializerMethodField() - def get_annotations(self, obj): + def get_annotations(self, instance): + request = self.context.get('request') + if request: + annotations = instance.seq_annotations.filter(user=request.user) + serializer = SequenceAnnotationSerializer(annotations, many=True) + return serializer.data + + class Meta: + model = Document + fields = ('id', 'text', 'annotations') + + +class Seq2seqDocumentSerializer(serializers.ModelSerializer): + annotations = serializers.SerializerMethodField() + + def get_annotations(self, instance): request = self.context.get('request') if request: - annotations = obj.seq_annotations.filter(user=request.user) - serializer = SequenceAnnotationSerializer(annotations.all(), many=True) + annotations = instance.seq2seq_annotations.filter(user=request.user) + serializer = Seq2seqAnnotationSerializer(annotations, many=True) return serializer.data class Meta: model = Document - fields = ('id', 'text', 'labels', 'annotations') + fields = ('id', 'text', 'annotations')