You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

261 lines
9.9 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import json
  2. import string
  3. from django.core.exceptions import ValidationError
  4. from django.db import models
  5. from django.urls import reverse
  6. from django.contrib.auth.models import User
  7. from django.contrib.staticfiles.storage import staticfiles_storage
  8. class Project(models.Model):
  9. DOCUMENT_CLASSIFICATION = 'DocumentClassification'
  10. SEQUENCE_LABELING = 'SequenceLabeling'
  11. Seq2seq = 'Seq2seq'
  12. PROJECT_CHOICES = (
  13. (DOCUMENT_CLASSIFICATION, 'document classification'),
  14. (SEQUENCE_LABELING, 'sequence labeling'),
  15. (Seq2seq, 'sequence to sequence'),
  16. )
  17. name = models.CharField(max_length=100)
  18. description = models.CharField(max_length=500)
  19. guideline = models.TextField()
  20. created_at = models.DateTimeField(auto_now_add=True)
  21. updated_at = models.DateTimeField(auto_now=True)
  22. users = models.ManyToManyField(User, related_name='projects')
  23. project_type = models.CharField(max_length=30, choices=PROJECT_CHOICES)
  24. def get_absolute_url(self):
  25. return reverse('upload', args=[self.id])
  26. def is_type_of(self, project_type):
  27. return project_type == self.project_type
  28. def get_progress(self, user):
  29. docs = self.get_documents(is_null=True, user=user)
  30. total = self.documents.count()
  31. remaining = docs.count()
  32. return {'total': total, 'remaining': remaining}
  33. @property
  34. def image(self):
  35. if self.is_type_of(self.DOCUMENT_CLASSIFICATION):
  36. url = staticfiles_storage.url('images/cat-1045782_640.jpg')
  37. elif self.is_type_of(self.SEQUENCE_LABELING):
  38. url = staticfiles_storage.url('images/cat-3449999_640.jpg')
  39. elif self.is_type_of(self.Seq2seq):
  40. url = staticfiles_storage.url('images/tiger-768574_640.jpg')
  41. return url
  42. def get_template_name(self):
  43. if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  44. template_name = 'annotation/document_classification.html'
  45. elif self.is_type_of(Project.SEQUENCE_LABELING):
  46. template_name = 'annotation/sequence_labeling.html'
  47. elif self.is_type_of(Project.Seq2seq):
  48. template_name = 'annotation/seq2seq.html'
  49. else:
  50. raise ValueError('Template does not exist')
  51. return template_name
  52. def get_documents(self, is_null=True, user=None):
  53. docs = self.documents.all()
  54. if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  55. if user:
  56. docs = docs.exclude(doc_annotations__user=user)
  57. else:
  58. docs = docs.filter(doc_annotations__isnull=is_null)
  59. elif self.is_type_of(Project.SEQUENCE_LABELING):
  60. if user:
  61. docs = docs.exclude(seq_annotations__user=user)
  62. else:
  63. docs = docs.filter(seq_annotations__isnull=is_null)
  64. elif self.is_type_of(Project.Seq2seq):
  65. if user:
  66. docs = docs.exclude(seq2seq_annotations__user=user)
  67. else:
  68. docs = docs.filter(seq2seq_annotations__isnull=is_null)
  69. else:
  70. raise ValueError('Invalid project_type')
  71. return docs
  72. def get_document_serializer(self):
  73. from .serializers import ClassificationDocumentSerializer
  74. from .serializers import SequenceDocumentSerializer
  75. from .serializers import Seq2seqDocumentSerializer
  76. if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  77. return ClassificationDocumentSerializer
  78. elif self.is_type_of(Project.SEQUENCE_LABELING):
  79. return SequenceDocumentSerializer
  80. elif self.is_type_of(Project.Seq2seq):
  81. return Seq2seqDocumentSerializer
  82. else:
  83. raise ValueError('Invalid project_type')
  84. def get_annotation_serializer(self):
  85. from .serializers import DocumentAnnotationSerializer
  86. from .serializers import SequenceAnnotationSerializer
  87. from .serializers import Seq2seqAnnotationSerializer
  88. if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  89. return DocumentAnnotationSerializer
  90. elif self.is_type_of(Project.SEQUENCE_LABELING):
  91. return SequenceAnnotationSerializer
  92. elif self.is_type_of(Project.Seq2seq):
  93. return Seq2seqAnnotationSerializer
  94. def get_annotation_class(self):
  95. if self.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  96. return DocumentAnnotation
  97. elif self.is_type_of(Project.SEQUENCE_LABELING):
  98. return SequenceAnnotation
  99. elif self.is_type_of(Project.Seq2seq):
  100. return Seq2seqAnnotation
  101. def __str__(self):
  102. return self.name
  103. class Label(models.Model):
  104. KEY_CHOICES = ((u, c) for u, c in zip(string.ascii_lowercase, string.ascii_lowercase))
  105. COLOR_CHOICES = ()
  106. text = models.CharField(max_length=100)
  107. shortcut = models.CharField(max_length=10, choices=KEY_CHOICES)
  108. project = models.ForeignKey(Project, related_name='labels', on_delete=models.CASCADE)
  109. background_color = models.CharField(max_length=7, default='#209cee')
  110. text_color = models.CharField(max_length=7, default='#ffffff')
  111. def __str__(self):
  112. return self.text
  113. class Meta:
  114. unique_together = (
  115. ('project', 'text'),
  116. ('project', 'shortcut')
  117. )
  118. class Document(models.Model):
  119. text = models.TextField()
  120. project = models.ForeignKey(Project, related_name='documents', on_delete=models.CASCADE)
  121. metadata = models.TextField(default='{}')
  122. def get_annotations(self):
  123. if self.project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  124. return self.doc_annotations.all()
  125. elif self.project.is_type_of(Project.SEQUENCE_LABELING):
  126. return self.seq_annotations.all()
  127. elif self.project.is_type_of(Project.Seq2seq):
  128. return self.seq2seq_annotations.all()
  129. def to_csv(self):
  130. return self.make_dataset()
  131. def make_dataset(self):
  132. if self.project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  133. return self.make_dataset_for_classification()
  134. elif self.project.is_type_of(Project.SEQUENCE_LABELING):
  135. return self.make_dataset_for_sequence_labeling()
  136. elif self.project.is_type_of(Project.Seq2seq):
  137. return self.make_dataset_for_seq2seq()
  138. def make_dataset_for_classification(self):
  139. annotations = self.get_annotations()
  140. dataset = [[self.id, self.text, a.label.text, a.user.username, self.metadata]
  141. for a in annotations]
  142. return dataset
  143. def make_dataset_for_sequence_labeling(self):
  144. annotations = self.get_annotations()
  145. dataset = [[self.id, ch, 'O', self.metadata] for ch in self.text]
  146. for a in annotations:
  147. for i in range(a.start_offset, a.end_offset):
  148. if i == a.start_offset:
  149. dataset[i][2] = 'B-{}'.format(a.label.text)
  150. else:
  151. dataset[i][2] = 'I-{}'.format(a.label.text)
  152. return dataset
  153. def make_dataset_for_seq2seq(self):
  154. annotations = self.get_annotations()
  155. dataset = [[self.id, self.text, a.text, a.user.username, self.metadata]
  156. for a in annotations]
  157. return dataset
  158. def to_json(self):
  159. return self.make_dataset_json()
  160. def make_dataset_json(self):
  161. if self.project.is_type_of(Project.DOCUMENT_CLASSIFICATION):
  162. return self.make_dataset_for_classification_json()
  163. elif self.project.is_type_of(Project.SEQUENCE_LABELING):
  164. return self.make_dataset_for_sequence_labeling_json()
  165. elif self.project.is_type_of(Project.Seq2seq):
  166. return self.make_dataset_for_seq2seq_json()
  167. def make_dataset_for_classification_json(self):
  168. annotations = self.get_annotations()
  169. labels = [a.label.text for a in annotations]
  170. username = annotations[0].user.username
  171. dataset = {'doc_id': self.id, 'text': self.text, 'labels': labels, 'username': username, 'metadata': json.loads(self.metadata)}
  172. return dataset
  173. def make_dataset_for_sequence_labeling_json(self):
  174. annotations = self.get_annotations()
  175. entities = [(a.start_offset, a.end_offset, a.label.text) for a in annotations]
  176. username = annotations[0].user.username
  177. dataset = {'doc_id': self.id, 'text': self.text, 'entities': entities, 'username': username, 'metadata': json.loads(self.metadata)}
  178. return dataset
  179. def make_dataset_for_seq2seq_json(self):
  180. annotations = self.get_annotations()
  181. sentences = [a.text for a in annotations]
  182. username = annotations[0].user.username
  183. dataset = {'doc_id': self.id, 'text': self.text, 'sentences': sentences, 'username': username, 'metadata': json.loads(self.metadata)}
  184. return dataset
  185. def __str__(self):
  186. return self.text[:50]
  187. class Annotation(models.Model):
  188. prob = models.FloatField(default=0.0)
  189. manual = models.BooleanField(default=False)
  190. user = models.ForeignKey(User, on_delete=models.CASCADE)
  191. class Meta:
  192. abstract = True
  193. class DocumentAnnotation(Annotation):
  194. document = models.ForeignKey(Document, related_name='doc_annotations', on_delete=models.CASCADE)
  195. label = models.ForeignKey(Label, on_delete=models.CASCADE)
  196. class Meta:
  197. unique_together = ('document', 'user', 'label')
  198. class SequenceAnnotation(Annotation):
  199. document = models.ForeignKey(Document, related_name='seq_annotations', on_delete=models.CASCADE)
  200. label = models.ForeignKey(Label, on_delete=models.CASCADE)
  201. start_offset = models.IntegerField()
  202. end_offset = models.IntegerField()
  203. def clean(self):
  204. if self.start_offset >= self.end_offset:
  205. raise ValidationError('start_offset is after end_offset')
  206. class Meta:
  207. unique_together = ('document', 'user', 'label', 'start_offset', 'end_offset')
  208. class Seq2seqAnnotation(Annotation):
  209. document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE)
  210. text = models.TextField()
  211. class Meta:
  212. unique_together = ('document', 'user', 'text')