You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

384 lines
12 KiB

5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
3 years ago
6 years ago
3 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import string
  2. from auto_labeling_pipeline.models import RequestModelFactory
  3. from django.conf import settings
  4. from django.contrib.auth.models import User
  5. from django.core.exceptions import ValidationError
  6. from django.db import models
  7. from django.db.models.signals import post_save, pre_delete
  8. from django.dispatch import receiver
  9. from django.urls import reverse
  10. from polymorphic.models import PolymorphicModel
  11. from .managers import AnnotationManager, Seq2seqAnnotationManager
  12. DOCUMENT_CLASSIFICATION = 'DocumentClassification'
  13. SEQUENCE_LABELING = 'SequenceLabeling'
  14. SEQ2SEQ = 'Seq2seq'
  15. SPEECH2TEXT = 'Speech2text'
  16. PROJECT_CHOICES = (
  17. (DOCUMENT_CLASSIFICATION, 'document classification'),
  18. (SEQUENCE_LABELING, 'sequence labeling'),
  19. (SEQ2SEQ, 'sequence to sequence'),
  20. (SPEECH2TEXT, 'speech to text'),
  21. )
  22. class Project(PolymorphicModel):
  23. name = models.CharField(max_length=100)
  24. description = models.TextField(default='')
  25. guideline = models.TextField(default='')
  26. created_at = models.DateTimeField(auto_now_add=True)
  27. updated_at = models.DateTimeField(auto_now=True)
  28. users = models.ManyToManyField(User, related_name='projects')
  29. project_type = models.CharField(max_length=30, choices=PROJECT_CHOICES)
  30. randomize_document_order = models.BooleanField(default=False)
  31. collaborative_annotation = models.BooleanField(default=False)
  32. single_class_classification = models.BooleanField(default=False)
  33. def get_absolute_url(self):
  34. return reverse('upload', args=[self.id])
  35. def get_bundle_name(self):
  36. raise NotImplementedError()
  37. def get_bundle_name_upload(self):
  38. raise NotImplementedError()
  39. def get_bundle_name_download(self):
  40. raise NotImplementedError()
  41. def get_annotation_serializer(self):
  42. raise NotImplementedError()
  43. def get_annotation_class(self):
  44. raise NotImplementedError()
  45. def get_storage(self, data):
  46. raise NotImplementedError()
  47. def __str__(self):
  48. return self.name
  49. class TextClassificationProject(Project):
  50. def get_bundle_name(self):
  51. return 'document_classification'
  52. def get_bundle_name_upload(self):
  53. return 'upload_text_classification'
  54. def get_bundle_name_download(self):
  55. return 'download_text_classification'
  56. def get_annotation_serializer(self):
  57. from .serializers import DocumentAnnotationSerializer
  58. return DocumentAnnotationSerializer
  59. def get_annotation_class(self):
  60. return DocumentAnnotation
  61. def get_storage(self, data):
  62. from .utils import ClassificationStorage
  63. return ClassificationStorage(data, self)
  64. class SequenceLabelingProject(Project):
  65. def get_bundle_name(self):
  66. return 'sequence_labeling'
  67. def get_bundle_name_upload(self):
  68. return 'upload_sequence_labeling'
  69. def get_bundle_name_download(self):
  70. return 'download_sequence_labeling'
  71. def get_annotation_serializer(self):
  72. from .serializers import SequenceAnnotationSerializer
  73. return SequenceAnnotationSerializer
  74. def get_annotation_class(self):
  75. return SequenceAnnotation
  76. def get_storage(self, data):
  77. from .utils import SequenceLabelingStorage
  78. return SequenceLabelingStorage(data, self)
  79. class Seq2seqProject(Project):
  80. def get_bundle_name(self):
  81. return 'seq2seq'
  82. def get_bundle_name_upload(self):
  83. return 'upload_seq2seq'
  84. def get_bundle_name_download(self):
  85. return 'download_seq2seq'
  86. def get_annotation_serializer(self):
  87. from .serializers import Seq2seqAnnotationSerializer
  88. return Seq2seqAnnotationSerializer
  89. def get_annotation_class(self):
  90. return Seq2seqAnnotation
  91. def get_storage(self, data):
  92. from .utils import Seq2seqStorage
  93. return Seq2seqStorage(data, self)
  94. class Speech2textProject(Project):
  95. def get_bundle_name(self):
  96. return 'speech2text'
  97. def get_bundle_name_upload(self):
  98. return 'upload_speech2text'
  99. def get_bundle_name_download(self):
  100. return 'download_speech2text'
  101. def get_annotation_serializer(self):
  102. from .serializers import Speech2textAnnotationSerializer
  103. return Speech2textAnnotationSerializer
  104. def get_annotation_class(self):
  105. return Speech2textAnnotation
  106. def get_storage(self, data):
  107. from .utils import Speech2textStorage
  108. return Speech2textStorage(data, self)
  109. class Label(models.Model):
  110. PREFIX_KEYS = (
  111. ('ctrl', 'ctrl'),
  112. ('shift', 'shift'),
  113. ('ctrl shift', 'ctrl shift')
  114. )
  115. SUFFIX_KEYS = tuple(
  116. (c, c) for c in string.digits + string.ascii_lowercase
  117. )
  118. text = models.CharField(max_length=100)
  119. prefix_key = models.CharField(max_length=10, blank=True, null=True, choices=PREFIX_KEYS)
  120. suffix_key = models.CharField(max_length=1, blank=True, null=True, choices=SUFFIX_KEYS)
  121. project = models.ForeignKey(Project, related_name='labels', on_delete=models.CASCADE)
  122. background_color = models.CharField(max_length=7, default='#209cee')
  123. text_color = models.CharField(max_length=7, default='#ffffff')
  124. created_at = models.DateTimeField(auto_now_add=True)
  125. updated_at = models.DateTimeField(auto_now=True)
  126. def __str__(self):
  127. return self.text
  128. def clean(self):
  129. # Don't allow shortcut key not to have a suffix key.
  130. if self.prefix_key and not self.suffix_key:
  131. raise ValidationError('Shortcut key may not have a suffix key.')
  132. # each shortcut (prefix key + suffix key) can only be assigned to one label
  133. if self.suffix_key or self.prefix_key:
  134. other_labels = self.project.labels.exclude(id=self.id)
  135. if other_labels.filter(suffix_key=self.suffix_key, prefix_key=self.prefix_key).exists():
  136. raise ValidationError('A label with this shortcut already exists in the project')
  137. super().clean()
  138. class Meta:
  139. unique_together = (
  140. ('project', 'text'),
  141. )
  142. class Document(models.Model):
  143. text = models.TextField()
  144. project = models.ForeignKey(Project, related_name='documents', on_delete=models.CASCADE)
  145. meta = models.TextField(default='{}')
  146. created_at = models.DateTimeField(auto_now_add=True)
  147. updated_at = models.DateTimeField(auto_now=True)
  148. annotations_approved_by = models.ForeignKey(User, on_delete=models.SET_NULL, null=True)
  149. def __str__(self):
  150. return self.text[:50]
  151. @property
  152. def comment_count(self):
  153. return Comment.objects.filter(document=self.id).count()
  154. class Comment(models.Model):
  155. text = models.TextField()
  156. document = models.ForeignKey(Document, related_name='comments', on_delete=models.CASCADE)
  157. user = models.ForeignKey(User, on_delete=models.CASCADE, null=True)
  158. created_at = models.DateTimeField(auto_now_add=True)
  159. updated_at = models.DateTimeField(auto_now=True)
  160. @property
  161. def username(self):
  162. return self.user.username
  163. @property
  164. def document_text(self):
  165. return self.document.text
  166. class Meta:
  167. ordering = ('-created_at', )
  168. class Annotation(models.Model):
  169. objects = AnnotationManager()
  170. prob = models.FloatField(default=0.0)
  171. manual = models.BooleanField(default=False)
  172. user = models.ForeignKey(User, on_delete=models.CASCADE)
  173. created_at = models.DateTimeField(auto_now_add=True)
  174. updated_at = models.DateTimeField(auto_now=True)
  175. class Meta:
  176. abstract = True
  177. class DocumentAnnotation(Annotation):
  178. document = models.ForeignKey(Document, related_name='doc_annotations', on_delete=models.CASCADE)
  179. label = models.ForeignKey(Label, on_delete=models.CASCADE)
  180. class Meta:
  181. unique_together = ('document', 'user', 'label')
  182. class SequenceAnnotation(Annotation):
  183. document = models.ForeignKey(Document, related_name='seq_annotations', on_delete=models.CASCADE)
  184. label = models.ForeignKey(Label, on_delete=models.CASCADE)
  185. start_offset = models.IntegerField()
  186. end_offset = models.IntegerField()
  187. def clean(self):
  188. if self.start_offset >= self.end_offset:
  189. raise ValidationError('start_offset is after end_offset')
  190. class Meta:
  191. unique_together = ('document', 'user', 'label', 'start_offset', 'end_offset')
  192. class Seq2seqAnnotation(Annotation):
  193. # Override AnnotationManager for custom functionality
  194. objects = Seq2seqAnnotationManager()
  195. document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE)
  196. text = models.CharField(max_length=500)
  197. class Meta:
  198. unique_together = ('document', 'user', 'text')
  199. class Speech2textAnnotation(Annotation):
  200. document = models.ForeignKey(Document, related_name='speech2text_annotations', on_delete=models.CASCADE)
  201. text = models.TextField()
  202. class Meta:
  203. unique_together = ('document', 'user')
  204. class Role(models.Model):
  205. name = models.CharField(max_length=100, unique=True)
  206. description = models.TextField(default='')
  207. created_at = models.DateTimeField(auto_now_add=True)
  208. updated_at = models.DateTimeField(auto_now=True)
  209. def __str__(self):
  210. return self.name
  211. class RoleMapping(models.Model):
  212. user = models.ForeignKey(User, related_name='role_mappings', on_delete=models.CASCADE)
  213. project = models.ForeignKey(Project, related_name='role_mappings', on_delete=models.CASCADE)
  214. role = models.ForeignKey(Role, on_delete=models.CASCADE)
  215. created_at = models.DateTimeField(auto_now_add=True)
  216. updated_at = models.DateTimeField(auto_now=True)
  217. def clean(self):
  218. other_rolemappings = self.project.role_mappings.exclude(id=self.id)
  219. if other_rolemappings.filter(user=self.user, project=self.project).exists():
  220. raise ValidationError('This user is already assigned to a role in this project.')
  221. class Meta:
  222. unique_together = ("user", "project", "role")
  223. @receiver(post_save, sender=RoleMapping)
  224. def add_linked_project(sender, instance, created, **kwargs):
  225. if not created:
  226. return
  227. userInstance = instance.user
  228. projectInstance = instance.project
  229. if userInstance and projectInstance:
  230. user = User.objects.get(pk=userInstance.pk)
  231. project = Project.objects.get(pk=projectInstance.pk)
  232. user.projects.add(project)
  233. user.save()
  234. @receiver(post_save)
  235. def add_superusers_to_project(sender, instance, created, **kwargs):
  236. if not created:
  237. return
  238. if sender not in Project.__subclasses__():
  239. return
  240. superusers = User.objects.filter(is_superuser=True)
  241. admin_role = Role.objects.filter(name=settings.ROLE_PROJECT_ADMIN).first()
  242. if superusers and admin_role:
  243. RoleMapping.objects.bulk_create(
  244. [RoleMapping(role_id=admin_role.id, user_id=superuser.id, project_id=instance.id)
  245. for superuser in superusers]
  246. )
  247. @receiver(post_save, sender=User)
  248. def add_new_superuser_to_projects(sender, instance, created, **kwargs):
  249. if created and instance.is_superuser:
  250. admin_role = Role.objects.filter(name=settings.ROLE_PROJECT_ADMIN).first()
  251. projects = Project.objects.all()
  252. if admin_role and projects:
  253. RoleMapping.objects.bulk_create(
  254. [RoleMapping(role_id=admin_role.id, user_id=instance.id, project_id=project.id)
  255. for project in projects]
  256. )
  257. @receiver(pre_delete, sender=RoleMapping)
  258. def delete_linked_project(sender, instance, using, **kwargs):
  259. userInstance = instance.user
  260. projectInstance = instance.project
  261. if userInstance and projectInstance:
  262. user = User.objects.get(pk=userInstance.pk)
  263. project = Project.objects.get(pk=projectInstance.pk)
  264. user.projects.remove(project)
  265. user.save()
  266. class AutoLabelingConfig(models.Model):
  267. model_name = models.CharField(max_length=100)
  268. model_attrs = models.JSONField(default=dict)
  269. template = models.TextField(default='')
  270. label_mapping = models.JSONField(default=dict)
  271. project = models.ForeignKey(Project, related_name='auto_labeling_config', on_delete=models.CASCADE)
  272. default = models.BooleanField(default=False)
  273. created_at = models.DateTimeField(auto_now_add=True)
  274. updated_at = models.DateTimeField(auto_now=True)
  275. def __str__(self):
  276. return self.model_name
  277. def clean_fields(self, exclude=None):
  278. super().clean_fields(exclude=exclude)
  279. try:
  280. RequestModelFactory.find(self.model_name)
  281. except NameError:
  282. raise ValidationError(f'The specified model name {self.model_name} does not exist.')
  283. except Exception:
  284. raise ValidationError('The attributes does not match the model.')