Browse Source

Divide project and document models into 3 models

Because of polymorphism
pull/110/head
Hironsan 5 years ago
committed by Hironsan
parent
commit
a43738a13a
5 changed files with 201 additions and 162 deletions
  1. 244
      app/server/models.py
  2. 0
      app/server/static/images/cats/seq2seq.jpg
  3. 0
      app/server/static/images/cats/sequence_labeling.jpg
  4. 0
      app/server/static/images/cats/text_classification.jpg
  5. 119
      app/server/tests/test_models.py

244
app/server/models.py

@ -1,4 +1,3 @@
import json
from django.core.exceptions import ValidationError
from django.db import models
from django.urls import reverse
@ -10,16 +9,16 @@ from .utils import get_key_choices
class Project(models.Model):
DOCUMENT_CLASSIFICATION = 'DocumentClassification'
SEQUENCE_LABELING = 'SequenceLabeling'
Seq2seq = 'Seq2seq'
SEQ2SEQ = 'Seq2seq'
PROJECT_CHOICES = (
(DOCUMENT_CLASSIFICATION, 'document classification'),
(SEQUENCE_LABELING, 'sequence labeling'),
(Seq2seq, 'sequence to sequence'),
(SEQ2SEQ, 'sequence to sequence'),
)
name = models.CharField(max_length=100)
description = models.CharField(max_length=500)
description = models.TextField()
guideline = models.TextField()
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
@ -29,94 +28,80 @@ class Project(models.Model):
def get_absolute_url(self):
return reverse('upload', args=[self.id])
def is_type_of(self, project_type):
return project_type == self.project_type
def __str__(self):
return self.name
class TextClassificationProject(Project):
def get_progress(self, user):
docs = self.get_documents(is_null=True, user=user)
total = self.documents.count()
remaining = docs.count()
return {'total': total, 'remaining': remaining}
class Meta:
proxy = True
@property
def image(self):
if self.is_type_of(self.DOCUMENT_CLASSIFICATION):
url = staticfiles_storage.url('images/cat-1045782_640.jpg')
elif self.is_type_of(self.SEQUENCE_LABELING):
url = staticfiles_storage.url('images/cat-3449999_640.jpg')
elif self.is_type_of(self.Seq2seq):
url = staticfiles_storage.url('images/tiger-768574_640.jpg')
return url
return staticfiles_storage.url('images/cats/text_classification.jpg')
def get_template_name(self):
if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
template_name = 'annotation/document_classification.html'
elif self.is_type_of(Project.SEQUENCE_LABELING):
template_name = 'annotation/sequence_labeling.html'
elif self.is_type_of(Project.Seq2seq):
template_name = 'annotation/seq2seq.html'
else:
raise ValueError('Template does not exist')
return template_name
def get_documents(self, is_null=True, user=None):
docs = self.documents.all()
if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
if user:
docs = docs.exclude(doc_annotations__user=user)
else:
docs = docs.filter(doc_annotations__isnull=is_null)
elif self.is_type_of(Project.SEQUENCE_LABELING):
if user:
docs = docs.exclude(seq_annotations__user=user)
else:
docs = docs.filter(seq_annotations__isnull=is_null)
elif self.is_type_of(Project.Seq2seq):
if user:
docs = docs.exclude(seq2seq_annotations__user=user)
else:
docs = docs.filter(seq2seq_annotations__isnull=is_null)
else:
raise ValueError('Invalid project_type')
return docs
return 'annotation/document_classification.html'
def get_document_serializer(self):
from .serializers import ClassificationDocumentSerializer
from .serializers import SequenceDocumentSerializer
from .serializers import Seq2seqDocumentSerializer
if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
return ClassificationDocumentSerializer
elif self.is_type_of(Project.SEQUENCE_LABELING):
return SequenceDocumentSerializer
elif self.is_type_of(Project.Seq2seq):
return Seq2seqDocumentSerializer
else:
raise ValueError('Invalid project_type')
return ClassificationDocumentSerializer
def get_annotation_serializer(self):
from .serializers import DocumentAnnotationSerializer
return DocumentAnnotationSerializer
def get_annotation_class(self):
return DocumentAnnotation
class SequenceLabelingProject(Project):
class Meta:
proxy = True
@property
def image(self):
return staticfiles_storage.url('images/cats/sequence_labeling.jpg')
def get_template_name(self):
return 'annotation/sequence_labeling.html'
def get_document_serializer(self):
from .serializers import SequenceDocumentSerializer
return SequenceDocumentSerializer
def get_annotation_serializer(self):
from .serializers import SequenceAnnotationSerializer
from .serializers import Seq2seqAnnotationSerializer
if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
return DocumentAnnotationSerializer
elif self.is_type_of(Project.SEQUENCE_LABELING):
return SequenceAnnotationSerializer
elif self.is_type_of(Project.Seq2seq):
return Seq2seqAnnotationSerializer
return SequenceAnnotationSerializer
def get_annotation_class(self):
if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
return DocumentAnnotation
elif self.is_type_of(Project.SEQUENCE_LABELING):
return SequenceAnnotation
elif self.is_type_of(Project.Seq2seq):
return Seq2seqAnnotation
return SequenceAnnotation
def __str__(self):
return self.name
class Seq2seqProject(Project):
class Meta:
proxy = True
@property
def image(self):
return staticfiles_storage.url('images/cats/seq2seq.jpg')
def get_template_name(self):
return 'annotation/seq2seq.html'
def get_document_serializer(self):
from .serializers import Seq2seqDocumentSerializer
return Seq2seqDocumentSerializer
def get_annotation_serializer(self):
from .serializers import Seq2seqAnnotationSerializer
return Seq2seqAnnotationSerializer
def get_annotation_class(self):
return Seq2seqAnnotation
class Label(models.Model):
@ -144,84 +129,37 @@ class Document(models.Model):
project = models.ForeignKey(Project, related_name='documents', on_delete=models.CASCADE)
metadata = models.TextField(default='{}')
def get_annotations(self):
if self.project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
return self.doc_annotations.all()
elif self.project.is_type_of(Project.SEQUENCE_LABELING):
return self.seq_annotations.all()
elif self.project.is_type_of(Project.Seq2seq):
return self.seq2seq_annotations.all()
def to_csv(self):
return self.make_dataset()
def make_dataset(self):
if self.project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
return self.make_dataset_for_classification()
elif self.project.is_type_of(Project.SEQUENCE_LABELING):
return self.make_dataset_for_sequence_labeling()
elif self.project.is_type_of(Project.Seq2seq):
return self.make_dataset_for_seq2seq()
def make_dataset_for_classification(self):
annotations = self.get_annotations()
dataset = [[self.id, self.text, a.label.text, a.user.username, self.metadata]
for a in annotations]
return dataset
def make_dataset_for_sequence_labeling(self):
annotations = self.get_annotations()
dataset = [[self.id, ch, 'O', self.metadata] for ch in self.text]
for a in annotations:
for i in range(a.start_offset, a.end_offset):
if i == a.start_offset:
dataset[i][2] = 'B-{}'.format(a.label.text)
else:
dataset[i][2] = 'I-{}'.format(a.label.text)
return dataset
def make_dataset_for_seq2seq(self):
annotations = self.get_annotations()
dataset = [[self.id, self.text, a.text, a.user.username, self.metadata]
for a in annotations]
return dataset
def to_json(self):
return self.make_dataset_json()
def make_dataset_json(self):
if self.project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
return self.make_dataset_for_classification_json()
elif self.project.is_type_of(Project.SEQUENCE_LABELING):
return self.make_dataset_for_sequence_labeling_json()
elif self.project.is_type_of(Project.Seq2seq):
return self.make_dataset_for_seq2seq_json()
def make_dataset_for_classification_json(self):
annotations = self.get_annotations()
labels = [a.label.text for a in annotations]
username = annotations[0].user.username
dataset = {'doc_id': self.id, 'text': self.text, 'labels': labels, 'username': username, 'metadata': json.loads(self.metadata)}
return dataset
def make_dataset_for_sequence_labeling_json(self):
annotations = self.get_annotations()
entities = [(a.start_offset, a.end_offset, a.label.text) for a in annotations]
username = annotations[0].user.username
dataset = {'doc_id': self.id, 'text': self.text, 'entities': entities, 'username': username, 'metadata': json.loads(self.metadata)}
return dataset
def make_dataset_for_seq2seq_json(self):
annotations = self.get_annotations()
sentences = [a.text for a in annotations]
username = annotations[0].user.username
dataset = {'doc_id': self.id, 'text': self.text, 'sentences': sentences, 'username': username, 'metadata': json.loads(self.metadata)}
return dataset
def __str__(self):
return self.text[:50]
class TextClassificationDocument(Document):
class Meta:
proxy = True
def get_annotations(self):
return self.doc_annotations.all()
class SequenceLabelingDocument(Document):
class Meta:
proxy = True
def get_annotations(self):
return self.seq_annotations.all()
class Seq2seqDocument(Document):
class Meta:
proxy = True
def get_annotations(self):
return self.seq2seq_annotations.all()
class Annotation(models.Model):
prob = models.FloatField(default=0.0)
manual = models.BooleanField(default=False)
@ -232,7 +170,7 @@ class Annotation(models.Model):
class DocumentAnnotation(Annotation):
document = models.ForeignKey(Document, related_name='doc_annotations', on_delete=models.CASCADE)
document = models.ForeignKey(TextClassificationDocument, related_name='doc_annotations', on_delete=models.CASCADE)
label = models.ForeignKey(Label, on_delete=models.CASCADE)
class Meta:
@ -240,7 +178,7 @@ class DocumentAnnotation(Annotation):
class SequenceAnnotation(Annotation):
document = models.ForeignKey(Document, related_name='seq_annotations', on_delete=models.CASCADE)
document = models.ForeignKey(SequenceLabelingDocument, related_name='seq_annotations', on_delete=models.CASCADE)
label = models.ForeignKey(Label, on_delete=models.CASCADE)
start_offset = models.IntegerField()
end_offset = models.IntegerField()
@ -254,7 +192,7 @@ class SequenceAnnotation(Annotation):
class Seq2seqAnnotation(Annotation):
document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE)
document = models.ForeignKey(Seq2seqDocument, related_name='seq2seq_annotations', on_delete=models.CASCADE)
text = models.TextField()
class Meta:

app/server/static/images/tiger-768574_640.jpg → app/server/static/images/cats/seq2seq.jpg

app/server/static/images/cat-3449999_640.jpg → app/server/static/images/cats/sequence_labeling.jpg

app/server/static/images/cat-1045782_640.jpg → app/server/static/images/cats/text_classification.jpg

119
app/server/tests/test_models.py

@ -3,19 +3,90 @@ from django.core.exceptions import ValidationError
from django.db.utils import IntegrityError
from mixer.backend.django import mixer
from ..models import Label, DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation
from ..serializers import ClassificationDocumentSerializer, DocumentAnnotationSerializer
from ..serializers import SequenceDocumentSerializer, SequenceAnnotationSerializer
from ..serializers import Seq2seqDocumentSerializer, Seq2seqAnnotationSerializer
class TestProject(TestCase):
class TestTextClassificationProject(TestCase):
def test_project_type(self):
project = mixer.blend('server.Project')
project.is_type_of(project.project_type)
@classmethod
def setUpTestData(cls):
cls.project = mixer.blend('server.TextClassificationProject')
def test_get_progress(self):
project = mixer.blend('server.Project')
res = project.get_progress(None)
self.assertEqual(res['total'], 0)
self.assertEqual(res['remaining'], 0)
def test_image(self):
image_url = self.project.image
self.assertTrue(image_url.endswith('.jpg'))
def test_get_template_name(self):
template = self.project.get_template_name()
self.assertEqual(template, 'annotation/document_classification.html')
def test_get_document_serializer(self):
serializer = self.project.get_document_serializer()
self.assertEqual(serializer, ClassificationDocumentSerializer)
def test_get_annotation_serializer(self):
serializer = self.project.get_annotation_serializer()
self.assertEqual(serializer, DocumentAnnotationSerializer)
def test_get_annotation_class(self):
klass = self.project.get_annotation_class()
self.assertEqual(klass, DocumentAnnotation)
class TestSequenceLabelingProject(TestCase):
@classmethod
def setUpTestData(cls):
cls.project = mixer.blend('server.SequenceLabelingProject')
def test_image(self):
image_url = self.project.image
self.assertTrue(image_url.endswith('.jpg'))
def test_get_template_name(self):
template = self.project.get_template_name()
self.assertEqual(template, 'annotation/sequence_labeling.html')
def test_get_document_serializer(self):
serializer = self.project.get_document_serializer()
self.assertEqual(serializer, SequenceDocumentSerializer)
def test_get_annotation_serializer(self):
serializer = self.project.get_annotation_serializer()
self.assertEqual(serializer, SequenceAnnotationSerializer)
def test_get_annotation_class(self):
klass = self.project.get_annotation_class()
self.assertEqual(klass, SequenceAnnotation)
class TestSeq2seqProject(TestCase):
@classmethod
def setUpTestData(cls):
cls.project = mixer.blend('server.Seq2seqProject')
def test_image(self):
image_url = self.project.image
self.assertTrue(image_url.endswith('.jpg'))
def test_get_template_name(self):
template = self.project.get_template_name()
self.assertEqual(template, 'annotation/seq2seq.html')
def test_get_document_serializer(self):
serializer = self.project.get_document_serializer()
self.assertEqual(serializer, Seq2seqDocumentSerializer)
def test_get_annotation_serializer(self):
serializer = self.project.get_annotation_serializer()
self.assertEqual(serializer, Seq2seqAnnotationSerializer)
def test_get_annotation_class(self):
klass = self.project.get_annotation_class()
self.assertEqual(klass, Seq2seqAnnotation)
class TestLabel(TestCase):
@ -37,6 +108,36 @@ class TestLabel(TestCase):
Label(project=label.project, text=label.text).save()
class TestTextClassificationDocument(TestCase):
@classmethod
def setUpTestData(cls):
cls.doc = mixer.blend('server.TextClassificationDocument')
def test_get_annotations(self):
self.assertEqual(self.doc.get_annotations().count(), 0)
class TestSequenceLabelingDocument(TestCase):
@classmethod
def setUpTestData(cls):
cls.doc = mixer.blend('server.SequenceLabelingDocument')
def test_get_annotations(self):
self.assertEqual(self.doc.get_annotations().count(), 0)
class TestSeq2seqDocument(TestCase):
@classmethod
def setUpTestData(cls):
cls.doc = mixer.blend('server.Seq2seqDocument')
def test_get_annotations(self):
self.assertEqual(self.doc.get_annotations().count(), 0)
class TestDocumentAnnotation(TestCase):
def test_uniqueness(self):

Loading…
Cancel
Save