|
|
@ -8,19 +8,15 @@ from django.views.generic import TemplateView |
|
|
|
from django.views.generic.list import ListView |
|
|
|
from django.views.generic.detail import DetailView |
|
|
|
from django.contrib.auth.mixins import LoginRequiredMixin |
|
|
|
from django.core.paginator import Paginator |
|
|
|
from rest_framework import viewsets, filters, generics |
|
|
|
from rest_framework.decorators import action |
|
|
|
from rest_framework.response import Response |
|
|
|
from rest_framework.permissions import IsAdminUser |
|
|
|
from django.db.models.query import QuerySet |
|
|
|
|
|
|
|
|
|
|
|
from .models import Label, Document, Project |
|
|
|
from .models import Label, Document, Project, Factory |
|
|
|
from .models import DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation |
|
|
|
from .serializers import LabelSerializer, ProjectSerializer, DocumentSerializer, DocumentAnnotationSerializer |
|
|
|
from .serializers import SequenceSerializer, SequenceAnnotationSerializer |
|
|
|
from .serializers import Seq2seqSerializer, Seq2seqAnnotationSerializer |
|
|
|
from .serializers import LabelSerializer, ProjectSerializer |
|
|
|
|
|
|
|
|
|
|
|
class IndexView(TemplateView): |
|
|
@ -34,14 +30,7 @@ class ProjectView(LoginRequiredMixin, TemplateView): |
|
|
|
context = super().get_context_data(**kwargs) |
|
|
|
project_id = kwargs.get('project_id') |
|
|
|
project = get_object_or_404(Project, pk=project_id) |
|
|
|
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION): |
|
|
|
self.template_name = 'annotation/document_classification.html' |
|
|
|
elif project.is_type_of(Project.SEQUENCE_LABELING): |
|
|
|
self.template_name = 'annotation/sequence_labeling.html' |
|
|
|
elif project.is_type_of(Project.Seq2seq): |
|
|
|
self.template_name = 'annotation/seq2seq.html' |
|
|
|
else: |
|
|
|
pass |
|
|
|
self.template_name = Factory.get_template(project) |
|
|
|
|
|
|
|
return context |
|
|
|
|
|
|
@ -88,16 +77,11 @@ class ProjectViewSet(viewsets.ModelViewSet): |
|
|
|
@action(methods=['get'], detail=True) |
|
|
|
def progress(self, request, pk=None): |
|
|
|
project = self.get_object() |
|
|
|
docs = project.documents.all() |
|
|
|
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION): |
|
|
|
remaining = docs.filter(doc_annotations__isnull=True).count() |
|
|
|
elif project.is_type_of(Project.SEQUENCE_LABELING): |
|
|
|
remaining = docs.filter(seq_annotations__isnull=True).count() |
|
|
|
elif project.is_type_of(Project.Seq2seq): |
|
|
|
remaining = docs.filter(seq2seq_annotations__isnull=True).count() |
|
|
|
else: |
|
|
|
remaining = 0 |
|
|
|
return Response({'total': docs.count(), 'remaining': remaining}) |
|
|
|
docs = Factory.get_documents(project, is_null=True) |
|
|
|
total = project.documents.count() |
|
|
|
remaining = docs.count() |
|
|
|
|
|
|
|
return Response({'total': total, 'remaining': remaining}) |
|
|
|
|
|
|
|
|
|
|
|
class ProjectLabelsAPI(generics.ListCreateAPIView): |
|
|
@ -144,12 +128,7 @@ class ProjectDocsAPI(generics.ListCreateAPIView): |
|
|
|
def get_serializer_class(self): |
|
|
|
project_id = self.kwargs['project_id'] |
|
|
|
project = get_object_or_404(Project, pk=project_id) |
|
|
|
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION): |
|
|
|
self.serializer_class = DocumentSerializer |
|
|
|
elif project.is_type_of(Project.SEQUENCE_LABELING): |
|
|
|
self.serializer_class = SequenceSerializer |
|
|
|
elif project.is_type_of(Project.Seq2seq): |
|
|
|
self.serializer_class = Seq2seqSerializer |
|
|
|
self.serializer_class = Factory.get_project_serializer(project) |
|
|
|
|
|
|
|
return self.serializer_class |
|
|
|
|
|
|
@ -160,15 +139,8 @@ class ProjectDocsAPI(generics.ListCreateAPIView): |
|
|
|
return queryset |
|
|
|
|
|
|
|
project = get_object_or_404(Project, pk=project_id) |
|
|
|
isnull = self.request.query_params.get('is_checked') == 'true' |
|
|
|
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION): |
|
|
|
queryset = queryset.filter(doc_annotations__isnull=isnull).distinct() |
|
|
|
elif project.is_type_of(Project.SEQUENCE_LABELING): |
|
|
|
queryset = queryset.filter(seq_annotations__isnull=isnull).distinct() |
|
|
|
elif project.is_type_of(Project.Seq2seq): |
|
|
|
queryset = queryset.filter(seq2seq_annotations__isnull=isnull).distinct() |
|
|
|
else: |
|
|
|
queryset = queryset |
|
|
|
is_null = self.request.query_params.get('is_checked') == 'true' |
|
|
|
queryset = Factory.get_documents(project, is_null).distinct() |
|
|
|
|
|
|
|
return queryset |
|
|
|
|
|
|
@ -179,45 +151,31 @@ class AnnotationsAPI(generics.ListCreateAPIView): |
|
|
|
def get_serializer_class(self): |
|
|
|
project_id = self.kwargs['project_id'] |
|
|
|
project = get_object_or_404(Project, pk=project_id) |
|
|
|
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION): |
|
|
|
self.serializer_class = DocumentAnnotationSerializer |
|
|
|
elif project.is_type_of(Project.SEQUENCE_LABELING): |
|
|
|
self.serializer_class = SequenceAnnotationSerializer |
|
|
|
elif project.is_type_of(Project.Seq2seq): |
|
|
|
self.serializer_class = Seq2seqAnnotationSerializer |
|
|
|
self.serializer_class = Factory.get_annotation_serializer(project) |
|
|
|
|
|
|
|
return self.serializer_class |
|
|
|
|
|
|
|
def get_queryset(self): |
|
|
|
doc_id = self.kwargs['doc_id'] |
|
|
|
project_id = self.kwargs['project_id'] |
|
|
|
project = get_object_or_404(Project, pk=project_id) |
|
|
|
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION): |
|
|
|
self.queryset = DocumentAnnotation.objects.all() |
|
|
|
elif project.is_type_of(Project.SEQUENCE_LABELING): |
|
|
|
self.queryset = SequenceAnnotation.objects.all() |
|
|
|
elif project.is_type_of(Project.Seq2seq): |
|
|
|
self.queryset = Seq2seqAnnotation.objects.all() |
|
|
|
queryset = self.queryset.filter(document=doc_id) |
|
|
|
document = get_object_or_404(Document, pk=doc_id) |
|
|
|
self.queryset = Factory.get_annotations_by_doc(document) |
|
|
|
|
|
|
|
return queryset |
|
|
|
return self.queryset |
|
|
|
|
|
|
|
def post(self, request, *args, **kwargs): |
|
|
|
doc = get_object_or_404(Document, pk=self.kwargs['doc_id']) |
|
|
|
label = get_object_or_404(Label, pk=request.data['label_id']) |
|
|
|
project = get_object_or_404(Project, pk=self.kwargs['project_id']) |
|
|
|
self.serializer_class = Factory.get_annotation_serializer(project) |
|
|
|
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION): |
|
|
|
self.serializer_class = DocumentAnnotationSerializer |
|
|
|
annotation = DocumentAnnotation(document=doc, label=label, manual=True, |
|
|
|
user=self.request.user) |
|
|
|
elif project.is_type_of(Project.SEQUENCE_LABELING): |
|
|
|
self.serializer_class = SequenceAnnotationSerializer |
|
|
|
annotation = SequenceAnnotation(document=doc, label=label, manual=True, |
|
|
|
user=self.request.user, |
|
|
|
start_offset=request.data['start_offset'], |
|
|
|
end_offset=request.data['end_offset']) |
|
|
|
elif project.is_type_of(Project.Seq2seq): |
|
|
|
self.serializer_class = Seq2seqAnnotationSerializer |
|
|
|
annotation = Seq2seqAnnotation(document=doc, manual=True, user=self.request.user) |
|
|
|
annotation.save() |
|
|
|
serializer = self.serializer_class(annotation) |
|
|
@ -229,17 +187,10 @@ class AnnotationAPI(generics.RetrieveUpdateDestroyAPIView): |
|
|
|
|
|
|
|
def get_queryset(self): |
|
|
|
doc_id = self.kwargs['doc_id'] |
|
|
|
project_id = self.kwargs['project_id'] |
|
|
|
project = get_object_or_404(Project, pk=project_id) |
|
|
|
if project.is_type_of(Project.DOCUMENT_CLASSIFICATION): |
|
|
|
self.queryset = DocumentAnnotation.objects.all() |
|
|
|
elif project.is_type_of(Project.SEQUENCE_LABELING): |
|
|
|
self.queryset = SequenceAnnotation.objects.all() |
|
|
|
elif project.is_type_of(Project.Seq2seq): |
|
|
|
self.queryset = Seq2seqAnnotation.objects.all() |
|
|
|
queryset = self.queryset.filter(document=doc_id) |
|
|
|
document = get_object_or_404(Document, pk=doc_id) |
|
|
|
self.queryset = Factory.get_annotations_by_doc(document) |
|
|
|
|
|
|
|
return queryset |
|
|
|
return self.queryset |
|
|
|
|
|
|
|
def get_object(self): |
|
|
|
annotation_id = self.kwargs['annotation_id'] |
|
|
|