Browse Source
Merge pull request #287 from CatalystCode/bugfix/seq2seq_label_download
Bugfix/seq2seq label download
pull/312/head
Hiroki Nakayama
5 years ago
committed by
GitHub
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with
44 additions and
11 deletions
-
app/api/managers.py
-
app/api/models.py
-
app/api/serializers.py
-
app/api/utils.py
-
app/api/views.py
|
|
@ -0,0 +1,33 @@ |
|
|
|
from collections import Counter |
|
|
|
|
|
|
|
from django.db.models import Manager, Count |
|
|
|
|
|
|
|
|
|
|
|
class AnnotationManager(Manager): |
|
|
|
|
|
|
|
def get_label_per_data(self, project): |
|
|
|
label_count = Counter() |
|
|
|
user_count = Counter() |
|
|
|
docs = project.documents.all() |
|
|
|
annotations = self.filter(document_id__in=docs.all()) |
|
|
|
|
|
|
|
for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')): |
|
|
|
label_count[d['label__text']] += d['label__count'] |
|
|
|
user_count[d['user__username']] += d['user__count'] |
|
|
|
|
|
|
|
return label_count, user_count |
|
|
|
|
|
|
|
|
|
|
|
class Seq2seqAnnotationManager(Manager): |
|
|
|
|
|
|
|
def get_label_per_data(self, project): |
|
|
|
label_count = Counter() |
|
|
|
user_count = Counter() |
|
|
|
docs = project.documents.all() |
|
|
|
annotations = self.filter(document_id__in=docs.all()) |
|
|
|
|
|
|
|
for d in annotations.values('text', 'user__username').annotate(Count('text'), Count('user')): |
|
|
|
label_count[d['text']] += d['text__count'] |
|
|
|
user_count[d['user__username']] += d['user__count'] |
|
|
|
|
|
|
|
return label_count, user_count |
|
|
@ -7,6 +7,8 @@ from django.contrib.staticfiles.storage import staticfiles_storage |
|
|
|
from django.core.exceptions import ValidationError |
|
|
|
from polymorphic.models import PolymorphicModel |
|
|
|
|
|
|
|
from .managers import AnnotationManager, Seq2seqAnnotationManager |
|
|
|
|
|
|
|
DOCUMENT_CLASSIFICATION = 'DocumentClassification' |
|
|
|
SEQUENCE_LABELING = 'SequenceLabeling' |
|
|
|
SEQ2SEQ = 'Seq2seq' |
|
|
@ -191,6 +193,8 @@ class Document(models.Model): |
|
|
|
|
|
|
|
|
|
|
|
class Annotation(models.Model): |
|
|
|
objects = AnnotationManager() |
|
|
|
|
|
|
|
prob = models.FloatField(default=0.0) |
|
|
|
manual = models.BooleanField(default=False) |
|
|
|
user = models.ForeignKey(User, on_delete=models.CASCADE) |
|
|
@ -224,6 +228,9 @@ class SequenceAnnotation(Annotation): |
|
|
|
|
|
|
|
|
|
|
|
class Seq2seqAnnotation(Annotation): |
|
|
|
# Override AnnotationManager for custom functionality |
|
|
|
objects = Seq2seqAnnotationManager() |
|
|
|
|
|
|
|
document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE) |
|
|
|
text = models.CharField(max_length=500) |
|
|
|
|
|
|
|
|
|
@ -160,5 +160,5 @@ class Seq2seqAnnotationSerializer(serializers.ModelSerializer): |
|
|
|
|
|
|
|
class Meta: |
|
|
|
model = Seq2seqAnnotation |
|
|
|
fields = ('id', 'text', 'user', 'document') |
|
|
|
fields = ('id', 'text', 'user', 'document', 'prob') |
|
|
|
read_only_fields = ('user',) |
|
|
@ -373,6 +373,7 @@ class JSONLRenderer(JSONRenderer): |
|
|
|
ensure_ascii=self.ensure_ascii, |
|
|
|
allow_nan=not self.strict) + '\n' |
|
|
|
|
|
|
|
|
|
|
|
class JSONPainter(object): |
|
|
|
|
|
|
|
def paint(self, documents): |
|
|
@ -406,6 +407,7 @@ class JSONPainter(object): |
|
|
|
data.append(d) |
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
class CSVPainter(JSONPainter): |
|
|
|
|
|
|
|
def paint(self, documents): |
|
|
|
|
|
@ -1,5 +1,3 @@ |
|
|
|
from collections import Counter |
|
|
|
|
|
|
|
from django.conf import settings |
|
|
|
from django.shortcuts import get_object_or_404, redirect |
|
|
|
from django_filters.rest_framework import DjangoFilterBackend |
|
|
@ -85,15 +83,8 @@ class StatisticsAPI(APIView): |
|
|
|
return {'total': total, 'remaining': remaining} |
|
|
|
|
|
|
|
def label_per_data(self, project): |
|
|
|
label_count = Counter() |
|
|
|
user_count = Counter() |
|
|
|
annotation_class = project.get_annotation_class() |
|
|
|
docs = project.documents.all() |
|
|
|
annotations = annotation_class.objects.filter(document_id__in=docs.all()) |
|
|
|
for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')): |
|
|
|
label_count[d['label__text']] += d['label__count'] |
|
|
|
user_count[d['user__username']] += d['user__count'] |
|
|
|
return label_count, user_count |
|
|
|
return annotation_class.objects.get_label_per_data(project=project) |
|
|
|
|
|
|
|
|
|
|
|
class ApproveLabelsAPI(APIView): |
|
|
|