diff --git a/app/api/managers.py b/app/api/managers.py new file mode 100644 index 00000000..7aa8dd48 --- /dev/null +++ b/app/api/managers.py @@ -0,0 +1,33 @@ +from collections import Counter + +from django.db.models import Manager, Count + + +class AnnotationManager(Manager): + + def get_label_per_data(self, project): + label_count = Counter() + user_count = Counter() + docs = project.documents.all() + annotations = self.filter(document_id__in=docs.all()) + + for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')): + label_count[d['label__text']] += d['label__count'] + user_count[d['user__username']] += d['user__count'] + + return label_count, user_count + + +class Seq2seqAnnotationManager(Manager): + + def get_label_per_data(self, project): + label_count = Counter() + user_count = Counter() + docs = project.documents.all() + annotations = self.filter(document_id__in=docs.all()) + + for d in annotations.values('text', 'user__username').annotate(Count('text'), Count('user')): + label_count[d['text']] += d['text__count'] + user_count[d['user__username']] += d['user__count'] + + return label_count, user_count diff --git a/app/api/models.py b/app/api/models.py index 573d7661..93307487 100644 --- a/app/api/models.py +++ b/app/api/models.py @@ -7,6 +7,8 @@ from django.contrib.staticfiles.storage import staticfiles_storage from django.core.exceptions import ValidationError from polymorphic.models import PolymorphicModel +from .managers import AnnotationManager, Seq2seqAnnotationManager + DOCUMENT_CLASSIFICATION = 'DocumentClassification' SEQUENCE_LABELING = 'SequenceLabeling' SEQ2SEQ = 'Seq2seq' @@ -192,6 +194,8 @@ class Document(models.Model): class Annotation(models.Model): + objects = AnnotationManager() + prob = models.FloatField(default=0.0) manual = models.BooleanField(default=False) user = models.ForeignKey(User, on_delete=models.CASCADE) @@ -225,6 +229,9 @@ class SequenceAnnotation(Annotation): class Seq2seqAnnotation(Annotation): + # Override AnnotationManager for custom functionality + objects = Seq2seqAnnotationManager() + document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE) text = models.TextField() diff --git a/app/api/views.py b/app/api/views.py index 015b8e17..90bb4f25 100644 --- a/app/api/views.py +++ b/app/api/views.py @@ -1,5 +1,3 @@ -from collections import Counter - from django.conf import settings from django.shortcuts import get_object_or_404, redirect from django_filters.rest_framework import DjangoFilterBackend @@ -15,7 +13,7 @@ from rest_framework.parsers import MultiPartParser from rest_framework_csv.renderers import CSVRenderer from .filters import DocumentFilter -from .models import Project, Label, Document, Seq2seqAnnotation +from .models import Project, Label, Document from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsOwnAnnotation from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer, UserSerializer from .serializers import ProjectPolymorphicSerializer @@ -85,20 +83,8 @@ class StatisticsAPI(APIView): return {'total': total, 'remaining': remaining} def label_per_data(self, project): - label_count = Counter() - user_count = Counter() annotation_class = project.get_annotation_class() - docs = project.documents.all() - annotations = annotation_class.objects.filter(document_id__in=docs.all()) - if annotation_class == Seq2seqAnnotation: - for d in annotations.values('text', 'user__username').annotate(Count('text'), Count('user')): - label_count[d['text']] += d['text__count'] - user_count[d['user__username']] += d['user__count'] - else: - for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')): - label_count[d['label__text']] += d['label__count'] - user_count[d['user__username']] += d['user__count'] - return label_count, user_count + return annotation_class.objects.get_label_per_data(project=project) class ApproveLabelsAPI(APIView):