diff --git a/app/server/models.py b/app/server/models.py index 39e7bf97..ce893e2e 100644 --- a/app/server/models.py +++ b/app/server/models.py @@ -1,4 +1,3 @@ -import json from django.core.exceptions import ValidationError from django.db import models from django.urls import reverse @@ -10,16 +9,16 @@ from .utils import get_key_choices class Project(models.Model): DOCUMENT_CLASSIFICATION = 'DocumentClassification' SEQUENCE_LABELING = 'SequenceLabeling' - Seq2seq = 'Seq2seq' + SEQ2SEQ = 'Seq2seq' PROJECT_CHOICES = ( (DOCUMENT_CLASSIFICATION, 'document classification'), (SEQUENCE_LABELING, 'sequence labeling'), - (Seq2seq, 'sequence to sequence'), + (SEQ2SEQ, 'sequence to sequence'), ) name = models.CharField(max_length=100) - description = models.CharField(max_length=500) + description = models.TextField() guideline = models.TextField() created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -29,94 +28,80 @@ class Project(models.Model): def get_absolute_url(self): return reverse('upload', args=[self.id]) - def is_type_of(self, project_type): - return project_type == self.project_type + def __str__(self): + return self.name + + +class TextClassificationProject(Project): - def get_progress(self, user): - docs = self.get_documents(is_null=True, user=user) - total = self.documents.count() - remaining = docs.count() - return {'total': total, 'remaining': remaining} + class Meta: + proxy = True @property def image(self): - if self.is_type_of(self.DOCUMENT_CLASSIFICATION): - url = staticfiles_storage.url('images/cat-1045782_640.jpg') - elif self.is_type_of(self.SEQUENCE_LABELING): - url = staticfiles_storage.url('images/cat-3449999_640.jpg') - elif self.is_type_of(self.Seq2seq): - url = staticfiles_storage.url('images/tiger-768574_640.jpg') - - return url + return staticfiles_storage.url('images/cats/text_classification.jpg') def get_template_name(self): - if self.is_type_of(Project.DOCUMENT_CLASSIFICATION): - template_name = 'annotation/document_classification.html' - elif self.is_type_of(Project.SEQUENCE_LABELING): - template_name = 'annotation/sequence_labeling.html' - elif self.is_type_of(Project.Seq2seq): - template_name = 'annotation/seq2seq.html' - else: - raise ValueError('Template does not exist') - - return template_name - - def get_documents(self, is_null=True, user=None): - docs = self.documents.all() - if self.is_type_of(Project.DOCUMENT_CLASSIFICATION): - if user: - docs = docs.exclude(doc_annotations__user=user) - else: - docs = docs.filter(doc_annotations__isnull=is_null) - elif self.is_type_of(Project.SEQUENCE_LABELING): - if user: - docs = docs.exclude(seq_annotations__user=user) - else: - docs = docs.filter(seq_annotations__isnull=is_null) - elif self.is_type_of(Project.Seq2seq): - if user: - docs = docs.exclude(seq2seq_annotations__user=user) - else: - docs = docs.filter(seq2seq_annotations__isnull=is_null) - else: - raise ValueError('Invalid project_type') - - return docs + return 'annotation/document_classification.html' def get_document_serializer(self): from .serializers import ClassificationDocumentSerializer - from .serializers import SequenceDocumentSerializer - from .serializers import Seq2seqDocumentSerializer - if self.is_type_of(Project.DOCUMENT_CLASSIFICATION): - return ClassificationDocumentSerializer - elif self.is_type_of(Project.SEQUENCE_LABELING): - return SequenceDocumentSerializer - elif self.is_type_of(Project.Seq2seq): - return Seq2seqDocumentSerializer - else: - raise ValueError('Invalid project_type') + return ClassificationDocumentSerializer def get_annotation_serializer(self): from .serializers import DocumentAnnotationSerializer + return DocumentAnnotationSerializer + + def get_annotation_class(self): + return DocumentAnnotation + + +class SequenceLabelingProject(Project): + + class Meta: + proxy = True + + @property + def image(self): + return staticfiles_storage.url('images/cats/sequence_labeling.jpg') + + def get_template_name(self): + return 'annotation/sequence_labeling.html' + + def get_document_serializer(self): + from .serializers import SequenceDocumentSerializer + return SequenceDocumentSerializer + + def get_annotation_serializer(self): from .serializers import SequenceAnnotationSerializer - from .serializers import Seq2seqAnnotationSerializer - if self.is_type_of(Project.DOCUMENT_CLASSIFICATION): - return DocumentAnnotationSerializer - elif self.is_type_of(Project.SEQUENCE_LABELING): - return SequenceAnnotationSerializer - elif self.is_type_of(Project.Seq2seq): - return Seq2seqAnnotationSerializer + return SequenceAnnotationSerializer def get_annotation_class(self): - if self.is_type_of(Project.DOCUMENT_CLASSIFICATION): - return DocumentAnnotation - elif self.is_type_of(Project.SEQUENCE_LABELING): - return SequenceAnnotation - elif self.is_type_of(Project.Seq2seq): - return Seq2seqAnnotation + return SequenceAnnotation - def __str__(self): - return self.name + +class Seq2seqProject(Project): + + class Meta: + proxy = True + + @property + def image(self): + return staticfiles_storage.url('images/cats/seq2seq.jpg') + + def get_template_name(self): + return 'annotation/seq2seq.html' + + def get_document_serializer(self): + from .serializers import Seq2seqDocumentSerializer + return Seq2seqDocumentSerializer + + def get_annotation_serializer(self): + from .serializers import Seq2seqAnnotationSerializer + return Seq2seqAnnotationSerializer + + def get_annotation_class(self): + return Seq2seqAnnotation class Label(models.Model): @@ -144,84 +129,37 @@ class Document(models.Model): project = models.ForeignKey(Project, related_name='documents', on_delete=models.CASCADE) metadata = models.TextField(default='{}') - def get_annotations(self): - if self.project.is_type_of(Project.DOCUMENT_CLASSIFICATION): - return self.doc_annotations.all() - elif self.project.is_type_of(Project.SEQUENCE_LABELING): - return self.seq_annotations.all() - elif self.project.is_type_of(Project.Seq2seq): - return self.seq2seq_annotations.all() - - def to_csv(self): - return self.make_dataset() - - def make_dataset(self): - if self.project.is_type_of(Project.DOCUMENT_CLASSIFICATION): - return self.make_dataset_for_classification() - elif self.project.is_type_of(Project.SEQUENCE_LABELING): - return self.make_dataset_for_sequence_labeling() - elif self.project.is_type_of(Project.Seq2seq): - return self.make_dataset_for_seq2seq() - - def make_dataset_for_classification(self): - annotations = self.get_annotations() - dataset = [[self.id, self.text, a.label.text, a.user.username, self.metadata] - for a in annotations] - return dataset - - def make_dataset_for_sequence_labeling(self): - annotations = self.get_annotations() - dataset = [[self.id, ch, 'O', self.metadata] for ch in self.text] - for a in annotations: - for i in range(a.start_offset, a.end_offset): - if i == a.start_offset: - dataset[i][2] = 'B-{}'.format(a.label.text) - else: - dataset[i][2] = 'I-{}'.format(a.label.text) - return dataset - - def make_dataset_for_seq2seq(self): - annotations = self.get_annotations() - dataset = [[self.id, self.text, a.text, a.user.username, self.metadata] - for a in annotations] - return dataset - - def to_json(self): - return self.make_dataset_json() - - def make_dataset_json(self): - if self.project.is_type_of(Project.DOCUMENT_CLASSIFICATION): - return self.make_dataset_for_classification_json() - elif self.project.is_type_of(Project.SEQUENCE_LABELING): - return self.make_dataset_for_sequence_labeling_json() - elif self.project.is_type_of(Project.Seq2seq): - return self.make_dataset_for_seq2seq_json() - - def make_dataset_for_classification_json(self): - annotations = self.get_annotations() - labels = [a.label.text for a in annotations] - username = annotations[0].user.username - dataset = {'doc_id': self.id, 'text': self.text, 'labels': labels, 'username': username, 'metadata': json.loads(self.metadata)} - return dataset - - def make_dataset_for_sequence_labeling_json(self): - annotations = self.get_annotations() - entities = [(a.start_offset, a.end_offset, a.label.text) for a in annotations] - username = annotations[0].user.username - dataset = {'doc_id': self.id, 'text': self.text, 'entities': entities, 'username': username, 'metadata': json.loads(self.metadata)} - return dataset - - def make_dataset_for_seq2seq_json(self): - annotations = self.get_annotations() - sentences = [a.text for a in annotations] - username = annotations[0].user.username - dataset = {'doc_id': self.id, 'text': self.text, 'sentences': sentences, 'username': username, 'metadata': json.loads(self.metadata)} - return dataset - def __str__(self): return self.text[:50] +class TextClassificationDocument(Document): + + class Meta: + proxy = True + + def get_annotations(self): + return self.doc_annotations.all() + + +class SequenceLabelingDocument(Document): + + class Meta: + proxy = True + + def get_annotations(self): + return self.seq_annotations.all() + + +class Seq2seqDocument(Document): + + class Meta: + proxy = True + + def get_annotations(self): + return self.seq2seq_annotations.all() + + class Annotation(models.Model): prob = models.FloatField(default=0.0) manual = models.BooleanField(default=False) @@ -232,7 +170,7 @@ class Annotation(models.Model): class DocumentAnnotation(Annotation): - document = models.ForeignKey(Document, related_name='doc_annotations', on_delete=models.CASCADE) + document = models.ForeignKey(TextClassificationDocument, related_name='doc_annotations', on_delete=models.CASCADE) label = models.ForeignKey(Label, on_delete=models.CASCADE) class Meta: @@ -240,7 +178,7 @@ class DocumentAnnotation(Annotation): class SequenceAnnotation(Annotation): - document = models.ForeignKey(Document, related_name='seq_annotations', on_delete=models.CASCADE) + document = models.ForeignKey(SequenceLabelingDocument, related_name='seq_annotations', on_delete=models.CASCADE) label = models.ForeignKey(Label, on_delete=models.CASCADE) start_offset = models.IntegerField() end_offset = models.IntegerField() @@ -254,7 +192,7 @@ class SequenceAnnotation(Annotation): class Seq2seqAnnotation(Annotation): - document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE) + document = models.ForeignKey(Seq2seqDocument, related_name='seq2seq_annotations', on_delete=models.CASCADE) text = models.TextField() class Meta: diff --git a/app/server/static/images/tiger-768574_640.jpg b/app/server/static/images/cats/seq2seq.jpg similarity index 100% rename from app/server/static/images/tiger-768574_640.jpg rename to app/server/static/images/cats/seq2seq.jpg diff --git a/app/server/static/images/cat-3449999_640.jpg b/app/server/static/images/cats/sequence_labeling.jpg similarity index 100% rename from app/server/static/images/cat-3449999_640.jpg rename to app/server/static/images/cats/sequence_labeling.jpg diff --git a/app/server/static/images/cat-1045782_640.jpg b/app/server/static/images/cats/text_classification.jpg similarity index 100% rename from app/server/static/images/cat-1045782_640.jpg rename to app/server/static/images/cats/text_classification.jpg diff --git a/app/server/tests/test_models.py b/app/server/tests/test_models.py index 3879ca8b..7c554974 100644 --- a/app/server/tests/test_models.py +++ b/app/server/tests/test_models.py @@ -3,19 +3,90 @@ from django.core.exceptions import ValidationError from django.db.utils import IntegrityError from mixer.backend.django import mixer from ..models import Label, DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation +from ..serializers import ClassificationDocumentSerializer, DocumentAnnotationSerializer +from ..serializers import SequenceDocumentSerializer, SequenceAnnotationSerializer +from ..serializers import Seq2seqDocumentSerializer, Seq2seqAnnotationSerializer -class TestProject(TestCase): +class TestTextClassificationProject(TestCase): - def test_project_type(self): - project = mixer.blend('server.Project') - project.is_type_of(project.project_type) + @classmethod + def setUpTestData(cls): + cls.project = mixer.blend('server.TextClassificationProject') - def test_get_progress(self): - project = mixer.blend('server.Project') - res = project.get_progress(None) - self.assertEqual(res['total'], 0) - self.assertEqual(res['remaining'], 0) + def test_image(self): + image_url = self.project.image + self.assertTrue(image_url.endswith('.jpg')) + + def test_get_template_name(self): + template = self.project.get_template_name() + self.assertEqual(template, 'annotation/document_classification.html') + + def test_get_document_serializer(self): + serializer = self.project.get_document_serializer() + self.assertEqual(serializer, ClassificationDocumentSerializer) + + def test_get_annotation_serializer(self): + serializer = self.project.get_annotation_serializer() + self.assertEqual(serializer, DocumentAnnotationSerializer) + + def test_get_annotation_class(self): + klass = self.project.get_annotation_class() + self.assertEqual(klass, DocumentAnnotation) + + +class TestSequenceLabelingProject(TestCase): + + @classmethod + def setUpTestData(cls): + cls.project = mixer.blend('server.SequenceLabelingProject') + + def test_image(self): + image_url = self.project.image + self.assertTrue(image_url.endswith('.jpg')) + + def test_get_template_name(self): + template = self.project.get_template_name() + self.assertEqual(template, 'annotation/sequence_labeling.html') + + def test_get_document_serializer(self): + serializer = self.project.get_document_serializer() + self.assertEqual(serializer, SequenceDocumentSerializer) + + def test_get_annotation_serializer(self): + serializer = self.project.get_annotation_serializer() + self.assertEqual(serializer, SequenceAnnotationSerializer) + + def test_get_annotation_class(self): + klass = self.project.get_annotation_class() + self.assertEqual(klass, SequenceAnnotation) + + +class TestSeq2seqProject(TestCase): + + @classmethod + def setUpTestData(cls): + cls.project = mixer.blend('server.Seq2seqProject') + + def test_image(self): + image_url = self.project.image + self.assertTrue(image_url.endswith('.jpg')) + + def test_get_template_name(self): + template = self.project.get_template_name() + self.assertEqual(template, 'annotation/seq2seq.html') + + def test_get_document_serializer(self): + serializer = self.project.get_document_serializer() + self.assertEqual(serializer, Seq2seqDocumentSerializer) + + def test_get_annotation_serializer(self): + serializer = self.project.get_annotation_serializer() + self.assertEqual(serializer, Seq2seqAnnotationSerializer) + + def test_get_annotation_class(self): + klass = self.project.get_annotation_class() + self.assertEqual(klass, Seq2seqAnnotation) class TestLabel(TestCase): @@ -37,6 +108,36 @@ class TestLabel(TestCase): Label(project=label.project, text=label.text).save() +class TestTextClassificationDocument(TestCase): + + @classmethod + def setUpTestData(cls): + cls.doc = mixer.blend('server.TextClassificationDocument') + + def test_get_annotations(self): + self.assertEqual(self.doc.get_annotations().count(), 0) + + +class TestSequenceLabelingDocument(TestCase): + + @classmethod + def setUpTestData(cls): + cls.doc = mixer.blend('server.SequenceLabelingDocument') + + def test_get_annotations(self): + self.assertEqual(self.doc.get_annotations().count(), 0) + + +class TestSeq2seqDocument(TestCase): + + @classmethod + def setUpTestData(cls): + cls.doc = mixer.blend('server.Seq2seqDocument') + + def test_get_annotations(self): + self.assertEqual(self.doc.get_annotations().count(), 0) + + class TestDocumentAnnotation(TestCase): def test_uniqueness(self):