Browse Source

Add factory class

pull/10/head
Hironsan 6 years ago
parent
commit
e9a52a115b
3 changed files with 82 additions and 68 deletions
  1. BIN
      app/db.sqlite3
  2. 63
      app/server/models.py
  3. 87
      app/server/views.py

BIN
app/db.sqlite3

63
app/server/models.py

@ -107,3 +107,66 @@ class Seq2seqAnnotation(Annotation):
class Meta:
unique_together = ('document', 'user', 'text')
from .serializers import *
# temporary solution
class Factory(object):
@classmethod
def get_template(cls, project):
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
template_name = 'annotation/document_classification.html'
elif project.is_type_of(Project.SEQUENCE_LABELING):
template_name = 'annotation/sequence_labeling.html'
elif project.is_type_of(Project.Seq2seq):
template_name = 'annotation/seq2seq.html'
else:
raise ValueError('Template does not exist')
return template_name
@classmethod
def get_documents(cls, project, is_null=True):
docs = project.documents.all()
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
docs = docs.filter(doc_annotations__isnull=is_null)
elif project.is_type_of(Project.SEQUENCE_LABELING):
docs = docs.filter(seq_annotations__isnull=is_null)
elif project.is_type_of(Project.Seq2seq):
docs = docs.filter(seq2seq_annotations__isnull=is_null)
else:
raise ValueError('Invalid project_type')
return docs
@classmethod
def get_project_serializer(cls, project):
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
return DocumentSerializer
elif project.is_type_of(Project.SEQUENCE_LABELING):
return SequenceSerializer
elif project.is_type_of(Project.Seq2seq):
return Seq2seqSerializer
else:
raise ValueError('Invalid project_type')
@classmethod
def get_annotation_serializer(cls, project):
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
return DocumentAnnotationSerializer
elif project.is_type_of(Project.SEQUENCE_LABELING):
return SequenceAnnotationSerializer
elif project.is_type_of(Project.Seq2seq):
return Seq2seqAnnotationSerializer
@classmethod
def get_annotations_by_doc(cls, document):
if document.project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
return document.doc_annotations.all()
elif document.project.is_type_of(Project.SEQUENCE_LABELING):
return document.seq_annotations.all()
elif document.project.is_type_of(Project.Seq2seq):
return document.seq2seq_annotations.all()

87
app/server/views.py

@ -8,19 +8,15 @@ from django.views.generic import TemplateView
from django.views.generic.list import ListView
from django.views.generic.detail import DetailView
from django.contrib.auth.mixins import LoginRequiredMixin
from django.core.paginator import Paginator
from rest_framework import viewsets, filters, generics
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.permissions import IsAdminUser
from django.db.models.query import QuerySet
from .models import Label, Document, Project
from .models import Label, Document, Project, Factory
from .models import DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation
from .serializers import LabelSerializer, ProjectSerializer, DocumentSerializer, DocumentAnnotationSerializer
from .serializers import SequenceSerializer, SequenceAnnotationSerializer
from .serializers import Seq2seqSerializer, Seq2seqAnnotationSerializer
from .serializers import LabelSerializer, ProjectSerializer
class IndexView(TemplateView):
@ -34,14 +30,7 @@ class ProjectView(LoginRequiredMixin, TemplateView):
context = super().get_context_data(**kwargs)
project_id = kwargs.get('project_id')
project = get_object_or_404(Project, pk=project_id)
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
self.template_name = 'annotation/document_classification.html'
elif project.is_type_of(Project.SEQUENCE_LABELING):
self.template_name = 'annotation/sequence_labeling.html'
elif project.is_type_of(Project.Seq2seq):
self.template_name = 'annotation/seq2seq.html'
else:
pass
self.template_name = Factory.get_template(project)
return context
@ -88,16 +77,11 @@ class ProjectViewSet(viewsets.ModelViewSet):
@action(methods=['get'], detail=True)
def progress(self, request, pk=None):
project = self.get_object()
docs = project.documents.all()
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
remaining = docs.filter(doc_annotations__isnull=True).count()
elif project.is_type_of(Project.SEQUENCE_LABELING):
remaining = docs.filter(seq_annotations__isnull=True).count()
elif project.is_type_of(Project.Seq2seq):
remaining = docs.filter(seq2seq_annotations__isnull=True).count()
else:
remaining = 0
return Response({'total': docs.count(), 'remaining': remaining})
docs = Factory.get_documents(project, is_null=True)
total = project.documents.count()
remaining = docs.count()
return Response({'total': total, 'remaining': remaining})
class ProjectLabelsAPI(generics.ListCreateAPIView):
@ -144,12 +128,7 @@ class ProjectDocsAPI(generics.ListCreateAPIView):
def get_serializer_class(self):
project_id = self.kwargs['project_id']
project = get_object_or_404(Project, pk=project_id)
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
self.serializer_class = DocumentSerializer
elif project.is_type_of(Project.SEQUENCE_LABELING):
self.serializer_class = SequenceSerializer
elif project.is_type_of(Project.Seq2seq):
self.serializer_class = Seq2seqSerializer
self.serializer_class = Factory.get_project_serializer(project)
return self.serializer_class
@ -160,15 +139,8 @@ class ProjectDocsAPI(generics.ListCreateAPIView):
return queryset
project = get_object_or_404(Project, pk=project_id)
isnull = self.request.query_params.get('is_checked') == 'true'
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
queryset = queryset.filter(doc_annotations__isnull=isnull).distinct()
elif project.is_type_of(Project.SEQUENCE_LABELING):
queryset = queryset.filter(seq_annotations__isnull=isnull).distinct()
elif project.is_type_of(Project.Seq2seq):
queryset = queryset.filter(seq2seq_annotations__isnull=isnull).distinct()
else:
queryset = queryset
is_null = self.request.query_params.get('is_checked') == 'true'
queryset = Factory.get_documents(project, is_null).distinct()
return queryset
@ -179,45 +151,31 @@ class AnnotationsAPI(generics.ListCreateAPIView):
def get_serializer_class(self):
project_id = self.kwargs['project_id']
project = get_object_or_404(Project, pk=project_id)
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
self.serializer_class = DocumentAnnotationSerializer
elif project.is_type_of(Project.SEQUENCE_LABELING):
self.serializer_class = SequenceAnnotationSerializer
elif project.is_type_of(Project.Seq2seq):
self.serializer_class = Seq2seqAnnotationSerializer
self.serializer_class = Factory.get_annotation_serializer(project)
return self.serializer_class
def get_queryset(self):
doc_id = self.kwargs['doc_id']
project_id = self.kwargs['project_id']
project = get_object_or_404(Project, pk=project_id)
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
self.queryset = DocumentAnnotation.objects.all()
elif project.is_type_of(Project.SEQUENCE_LABELING):
self.queryset = SequenceAnnotation.objects.all()
elif project.is_type_of(Project.Seq2seq):
self.queryset = Seq2seqAnnotation.objects.all()
queryset = self.queryset.filter(document=doc_id)
document = get_object_or_404(Document, pk=doc_id)
self.queryset = Factory.get_annotations_by_doc(document)
return queryset
return self.queryset
def post(self, request, *args, **kwargs):
doc = get_object_or_404(Document, pk=self.kwargs['doc_id'])
label = get_object_or_404(Label, pk=request.data['label_id'])
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
self.serializer_class = Factory.get_annotation_serializer(project)
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
self.serializer_class = DocumentAnnotationSerializer
annotation = DocumentAnnotation(document=doc, label=label, manual=True,
user=self.request.user)
elif project.is_type_of(Project.SEQUENCE_LABELING):
self.serializer_class = SequenceAnnotationSerializer
annotation = SequenceAnnotation(document=doc, label=label, manual=True,
user=self.request.user,
start_offset=request.data['start_offset'],
end_offset=request.data['end_offset'])
elif project.is_type_of(Project.Seq2seq):
self.serializer_class = Seq2seqAnnotationSerializer
annotation = Seq2seqAnnotation(document=doc, manual=True, user=self.request.user)
annotation.save()
serializer = self.serializer_class(annotation)
@ -229,17 +187,10 @@ class AnnotationAPI(generics.RetrieveUpdateDestroyAPIView):
def get_queryset(self):
doc_id = self.kwargs['doc_id']
project_id = self.kwargs['project_id']
project = get_object_or_404(Project, pk=project_id)
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
self.queryset = DocumentAnnotation.objects.all()
elif project.is_type_of(Project.SEQUENCE_LABELING):
self.queryset = SequenceAnnotation.objects.all()
elif project.is_type_of(Project.Seq2seq):
self.queryset = Seq2seqAnnotation.objects.all()
queryset = self.queryset.filter(document=doc_id)
document = get_object_or_404(Document, pk=doc_id)
self.queryset = Factory.get_annotations_by_doc(document)
return queryset
return self.queryset
def get_object(self):
annotation_id = self.kwargs['annotation_id']

Loading…
Cancel
Save