You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

402 lines
14 KiB

5 years ago
3 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
3 years ago
6 years ago
3 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import string
  2. from auto_labeling_pipeline.models import RequestModelFactory
  3. from django.conf import settings
  4. from django.contrib.auth.models import User
  5. from django.core.exceptions import ValidationError
  6. from django.db import models
  7. from django.db.models.signals import m2m_changed, post_save, pre_delete
  8. from django.dispatch import receiver
  9. from django.urls import reverse
  10. from polymorphic.models import PolymorphicModel
  11. from .managers import (AnnotationManager, RoleMappingManager,
  12. Seq2seqAnnotationManager)
  # Identifiers for the supported project types.
  DOCUMENT_CLASSIFICATION = 'DocumentClassification'
  SEQUENCE_LABELING = 'SequenceLabeling'
  SEQ2SEQ = 'Seq2seq'
  SPEECH2TEXT = 'Speech2text'

  # (stored value, human-readable label) pairs, in Django ``choices`` format.
  PROJECT_CHOICES = (
      (DOCUMENT_CLASSIFICATION, 'document classification'),
      (SEQUENCE_LABELING, 'sequence labeling'),
      (SEQ2SEQ, 'sequence to sequence'),
      (SPEECH2TEXT, 'speech to text'),
  )
  class Project(PolymorphicModel):
      """Polymorphic base model for an annotation project.

      The ``get_bundle_name*``, ``get_annotation_*`` and ``get_storage`` hooks
      all raise ``NotImplementedError`` here; the concrete project classes in
      this module override them.
      """

      name = models.CharField(max_length=100)
      description = models.TextField(default='')
      guideline = models.TextField(default='', blank=True)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)
      users = models.ManyToManyField(User, related_name='projects')
      project_type = models.CharField(max_length=30, choices=PROJECT_CHOICES)
      randomize_document_order = models.BooleanField(default=False)
      collaborative_annotation = models.BooleanField(default=False)
      single_class_classification = models.BooleanField(default=False)

      def get_absolute_url(self):
          """Return the URL of the ``upload`` route for this project's id."""
          url_args = [self.id]
          return reverse('upload', args=url_args)

      def get_bundle_name(self):
          """Return the project's bundle name; must be overridden."""
          raise NotImplementedError()

      def get_bundle_name_upload(self):
          """Return the upload bundle name; must be overridden."""
          raise NotImplementedError()

      def get_bundle_name_download(self):
          """Return the download bundle name; must be overridden."""
          raise NotImplementedError()

      def get_annotation_serializer(self):
          """Return the annotation serializer class; must be overridden."""
          raise NotImplementedError()

      def get_annotation_class(self):
          """Return the annotation model class; must be overridden."""
          raise NotImplementedError()

      def get_storage(self, data):
          """Return a storage object for ``data``; must be overridden."""
          raise NotImplementedError()

      def __str__(self):
          """Use the project name as its string representation."""
          return self.name
  class TextClassificationProject(Project):
      """Project whose annotations are ``DocumentAnnotation`` rows."""

      def get_bundle_name(self):
          """Return the bundle name for this project type."""
          return 'document_classification'

      def get_bundle_name_upload(self):
          """Return the upload bundle name for this project type."""
          return 'upload_text_classification'

      def get_bundle_name_download(self):
          """Return the download bundle name for this project type."""
          return 'download_text_classification'

      def get_annotation_serializer(self):
          """Return the serializer class for this project's annotations."""
          # Imported at call time rather than at module load.
          from .serializers import DocumentAnnotationSerializer as serializer_cls
          return serializer_cls

      def get_annotation_class(self):
          """Return the annotation model class for this project type."""
          return DocumentAnnotation

      def get_storage(self, data):
          """Return a ``ClassificationStorage`` built from ``data`` and this project."""
          # Imported at call time rather than at module load.
          from .utils import ClassificationStorage as storage_cls
          return storage_cls(data, self)
  class SequenceLabelingProject(Project):
      """Project whose annotations are ``SequenceAnnotation`` rows."""

      def get_bundle_name(self):
          """Return the bundle name for this project type."""
          return 'sequence_labeling'

      def get_bundle_name_upload(self):
          """Return the upload bundle name for this project type."""
          return 'upload_sequence_labeling'

      def get_bundle_name_download(self):
          """Return the download bundle name for this project type."""
          return 'download_sequence_labeling'

      def get_annotation_serializer(self):
          """Return the serializer class for this project's annotations."""
          # Imported at call time rather than at module load.
          from .serializers import SequenceAnnotationSerializer as serializer_cls
          return serializer_cls

      def get_annotation_class(self):
          """Return the annotation model class for this project type."""
          return SequenceAnnotation

      def get_storage(self, data):
          """Return a ``SequenceLabelingStorage`` built from ``data`` and this project."""
          # Imported at call time rather than at module load.
          from .utils import SequenceLabelingStorage as storage_cls
          return storage_cls(data, self)
  class Seq2seqProject(Project):
      """Project whose annotations are ``Seq2seqAnnotation`` rows."""

      def get_bundle_name(self):
          """Return the bundle name for this project type."""
          return 'seq2seq'

      def get_bundle_name_upload(self):
          """Return the upload bundle name for this project type."""
          return 'upload_seq2seq'

      def get_bundle_name_download(self):
          """Return the download bundle name for this project type."""
          return 'download_seq2seq'

      def get_annotation_serializer(self):
          """Return the serializer class for this project's annotations."""
          # Imported at call time rather than at module load.
          from .serializers import Seq2seqAnnotationSerializer as serializer_cls
          return serializer_cls

      def get_annotation_class(self):
          """Return the annotation model class for this project type."""
          return Seq2seqAnnotation

      def get_storage(self, data):
          """Return a ``Seq2seqStorage`` built from ``data`` and this project."""
          # Imported at call time rather than at module load.
          from .utils import Seq2seqStorage as storage_cls
          return storage_cls(data, self)
  class Speech2textProject(Project):
      """Project whose annotations are ``Speech2textAnnotation`` rows."""

      def get_bundle_name(self):
          """Return the bundle name for this project type."""
          return 'speech2text'

      def get_bundle_name_upload(self):
          """Return the upload bundle name for this project type."""
          return 'upload_speech2text'

      def get_bundle_name_download(self):
          """Return the download bundle name for this project type."""
          return 'download_speech2text'

      def get_annotation_serializer(self):
          """Return the serializer class for this project's annotations."""
          # Imported at call time rather than at module load.
          from .serializers import Speech2textAnnotationSerializer as serializer_cls
          return serializer_cls

      def get_annotation_class(self):
          """Return the annotation model class for this project type."""
          return Speech2textAnnotation

      def get_storage(self, data):
          """Return a ``Speech2textStorage`` built from ``data`` and this project."""
          # Imported at call time rather than at module load.
          from .utils import Speech2textStorage as storage_cls
          return storage_cls(data, self)
  class Label(models.Model):
      """A label belonging to a project, optionally bound to a keyboard shortcut.

      A shortcut is made of an optional ``prefix_key`` (modifier combination)
      and a ``suffix_key`` (single character). ``(project, text)`` is unique.
      """

      # Allowed modifier combinations for a shortcut.
      PREFIX_KEYS = (
          ('ctrl', 'ctrl'),
          ('shift', 'shift'),
          ('ctrl shift', 'ctrl shift')
      )
      # Allowed final keys: every digit and lowercase ASCII letter.
      SUFFIX_KEYS = tuple(
          (c, c) for c in string.digits + string.ascii_lowercase
      )

      text = models.CharField(max_length=100)
      prefix_key = models.CharField(max_length=10, blank=True, null=True, choices=PREFIX_KEYS)
      suffix_key = models.CharField(max_length=1, blank=True, null=True, choices=SUFFIX_KEYS)
      project = models.ForeignKey(Project, related_name='labels', on_delete=models.CASCADE)
      background_color = models.CharField(max_length=7, default='#209cee')
      text_color = models.CharField(max_length=7, default='#ffffff')
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      def __str__(self):
          """Use the label text as its string representation."""
          return self.text

      def clean(self):
          """Validate the shortcut configuration of this label.

          Raises:
              ValidationError: if a prefix key is set without a suffix key, or
                  if another label of the same project already uses the same
                  (prefix key, suffix key) combination.
          """
          # A prefix (modifier) on its own is not a usable shortcut.
          if self.prefix_key and not self.suffix_key:
              # The previous message ("may not have a suffix key") stated the
              # opposite of the rule being enforced.
              raise ValidationError('A shortcut with a prefix key must also have a suffix key.')

          # each shortcut (prefix key + suffix key) can only be assigned to one label
          if self.suffix_key or self.prefix_key:
              # Exclude this row so that re-validating a saved label does not
              # collide with itself.
              other_labels = self.project.labels.exclude(id=self.id)
              if other_labels.filter(suffix_key=self.suffix_key, prefix_key=self.prefix_key).exists():
                  raise ValidationError('A label with this shortcut already exists in the project')

          super().clean()

      class Meta:
          unique_together = (
              ('project', 'text'),
          )
  class Document(models.Model):
      """A text document that belongs to a project and can be annotated."""

      text = models.TextField()
      project = models.ForeignKey(Project, related_name='documents', on_delete=models.CASCADE)
      meta = models.JSONField(default=dict)
      filename = models.FilePathField(default='')
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)
      annotations_approved_by = models.ForeignKey(User, on_delete=models.SET_NULL, null=True, blank=True)

      def __str__(self):
          """Return the first 50 characters of the document text."""
          preview_length = 50
          return self.text[:preview_length]

      @property
      def comment_count(self):
          """Number of ``Comment`` rows attached to this document."""
          attached = Comment.objects.filter(document=self.id)
          return attached.count()
  class Comment(models.Model):
      """A comment on a document, listed newest first."""

      text = models.TextField()
      document = models.ForeignKey(Document, related_name='comments', on_delete=models.CASCADE)
      user = models.ForeignKey(User, on_delete=models.CASCADE, null=True)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      @property
      def username(self):
          """Username of the comment's author, or ``None`` when no user is set."""
          # ``user`` is nullable, so guard instead of raising AttributeError
          # on ``None.username``.
          if self.user is None:
              return None
          return self.user.username

      @property
      def document_text(self):
          """Text of the document this comment is attached to."""
          return self.document.text

      class Meta:
          ordering = ('-created_at', )
  class Annotation(models.Model):
      """Abstract base holding the fields shared by every annotation model."""

      # Default manager for annotation models.
      objects = AnnotationManager()

      prob = models.FloatField(default=0.0)
      manual = models.BooleanField(default=False)
      user = models.ForeignKey(User, on_delete=models.CASCADE)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      class Meta:
          # No table is created for this class; subclasses get the fields.
          abstract = True
  class DocumentAnnotation(Annotation):
      """Annotation that attaches one label to a whole document."""

      document = models.ForeignKey(Document, related_name='doc_annotations', on_delete=models.CASCADE)
      label = models.ForeignKey(Label, on_delete=models.CASCADE)

      class Meta:
          # A user can apply a given label to a given document only once.
          unique_together = ('document', 'user', 'label')
  class SequenceAnnotation(Annotation):
      """Annotation that attaches a label to an offset range of a document."""

      document = models.ForeignKey(Document, related_name='seq_annotations', on_delete=models.CASCADE)
      label = models.ForeignKey(Label, on_delete=models.CASCADE)
      start_offset = models.IntegerField()
      end_offset = models.IntegerField()

      def clean(self):
          """Validate that the range is non-empty and correctly ordered.

          Raises:
              ValidationError: if ``start_offset`` is greater than or equal to
                  ``end_offset``.
          """
          # Comparing ``None`` with ``>=`` raises TypeError; a missing offset
          # is left to field-level validation instead of crashing here.
          if self.start_offset is None or self.end_offset is None:
              return
          if self.start_offset >= self.end_offset:
              raise ValidationError('start_offset is after end_offset')

      class Meta:
          unique_together = ('document', 'user', 'label', 'start_offset', 'end_offset')
  class Seq2seqAnnotation(Annotation):
      """Annotation that stores a short text (up to 500 characters) for a document."""

      # Replaces the inherited AnnotationManager with the seq2seq-specific one.
      objects = Seq2seqAnnotationManager()

      document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE)
      text = models.CharField(max_length=500)

      class Meta:
          # A user cannot store the same text twice for one document.
          unique_together = ('document', 'user', 'text')
  class Speech2textAnnotation(Annotation):
      """Annotation that stores free text for a document."""

      document = models.ForeignKey(Document, related_name='speech2text_annotations', on_delete=models.CASCADE)
      text = models.TextField()

      class Meta:
          # At most one annotation per user and document.
          unique_together = ('document', 'user')
  class Role(models.Model):
      """A uniquely named role that can be assigned through ``RoleMapping``."""

      name = models.CharField(max_length=100, unique=True)
      description = models.TextField(default='')
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      def __str__(self):
          """Use the role name as its string representation."""
          return self.name
  class RoleMapping(models.Model):
      """Assignment of one role to one user within one project."""

      user = models.ForeignKey(User, related_name='role_mappings', on_delete=models.CASCADE)
      project = models.ForeignKey(Project, related_name='role_mappings', on_delete=models.CASCADE)
      role = models.ForeignKey(Role, on_delete=models.CASCADE)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)
      objects = RoleMappingManager()

      def clean(self):
          """Reject a second role mapping for the same user in the same project.

          Raises:
              ValidationError: if another mapping of this project already
                  references this user.
          """
          # Leave this row out so an existing mapping can be re-validated.
          siblings = self.project.role_mappings.exclude(id=self.id)
          duplicates = siblings.filter(user=self.user, project=self.project)
          if duplicates.exists():
              raise ValidationError('This user is already assigned to a role in this project.')

      class Meta:
          unique_together = ("user", "project")
  @receiver(post_save, sender=RoleMapping)
  def add_linked_project(sender, instance, created, **kwargs):
      """Add the mapped project to the user's projects when a mapping is created.

      Does nothing when an existing ``RoleMapping`` is saved again.
      """
      if not created:
          return

      mapped_user = instance.user
      mapped_project = instance.project
      if not (mapped_user and mapped_project):
          return

      # Both objects are re-read from the database by primary key.
      user = User.objects.get(pk=mapped_user.pk)
      project = Project.objects.get(pk=mapped_project.pk)
      user.projects.add(project)
      user.save()
  237. # @receiver(post_save)
  238. # def add_superusers_to_project(sender, instance, created, **kwargs):
  239. # if not created:
  240. # return
  241. # if sender not in Project.__subclasses__():
  242. # return
  243. # superusers = User.objects.filter(is_superuser=True)
  244. # admin_role = Role.objects.filter(name=settings.ROLE_PROJECT_ADMIN).first()
  245. # if superusers and admin_role:
  246. # RoleMapping.objects.bulk_create(
  247. # [RoleMapping(role_id=admin_role.id, user_id=superuser.id, project_id=instance.id)
  248. # for superuser in superusers]
  249. # )
  250. #
  251. #
  252. # @receiver(post_save, sender=User)
  253. # def add_new_superuser_to_projects(sender, instance, created, **kwargs):
  254. # if created and instance.is_superuser:
  255. # admin_role = Role.objects.filter(name=settings.ROLE_PROJECT_ADMIN).first()
  256. # projects = Project.objects.all()
  257. # if admin_role and projects:
  258. # RoleMapping.objects.bulk_create(
  259. # [RoleMapping(role_id=admin_role.id, user_id=instance.id, project_id=project.id)
  260. # for project in projects]
  261. # )
  @receiver(m2m_changed, sender=Project.users.through)
  def remove_mapping_on_remove_user_from_project(sender, instance, action, reverse, **kwargs):
      """Keep ``RoleMapping`` rows in sync with changes to ``Project.users``.

      Only changes made from the project side (``reverse`` is false) are
      handled:

      * after users are removed, their role mappings for the project are deleted;
      * after users are added, each one without a mapping for the project gets
        a mapping to the role named ``settings.ROLE_PROJECT_ADMIN``.

      Raises:
          KeyError: if the signal is sent without ``pk_set``.
          Role.DoesNotExist: if users are added and the admin role is missing.
      """
      # if reverse is True, pk_set is project_ids and instance is user.
      # else, pk_set is user_ids and instance is project.
      user_ids = kwargs['pk_set']
      if reverse:
          return

      if action.startswith('post_remove'):
          RoleMapping.objects.filter(user__in=user_ids, project=instance).delete()
      elif action.startswith('post_add'):
          admin_role = Role.objects.get(name=settings.ROLE_PROJECT_ADMIN)
          # One query for all users that already have a mapping, instead of
          # one ``exists()`` query per added user.
          already_mapped = set(
              RoleMapping.objects
              .filter(project=instance, user_id__in=user_ids)
              .values_list('user_id', flat=True)
          )
          RoleMapping.objects.bulk_create(
              [RoleMapping(role=admin_role, project=instance, user_id=user_id)
               for user_id in user_ids
               if user_id not in already_mapped]
          )
  @receiver(pre_delete, sender=RoleMapping)
  def delete_linked_project(sender, instance, using, **kwargs):
      """Remove the mapped project from the user's projects before a mapping is deleted."""
      mapped_user = instance.user
      mapped_project = instance.project
      if not (mapped_user and mapped_project):
          return

      # Both objects are re-read from the database by primary key.
      user = User.objects.get(pk=mapped_user.pk)
      project = Project.objects.get(pk=mapped_project.pk)
      user.projects.remove(project)
      user.save()
  class AutoLabelingConfig(models.Model):
      """Auto-labeling configuration attached to a project."""

      model_name = models.CharField(max_length=100)
      model_attrs = models.JSONField(default=dict)
      template = models.TextField(default='')
      label_mapping = models.JSONField(default=dict)
      project = models.ForeignKey(Project, related_name='auto_labeling_config', on_delete=models.CASCADE)
      default = models.BooleanField(default=False)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      def __str__(self):
          """Use the model name as the string representation."""
          return self.model_name

      def clean_fields(self, exclude=None):
          """Run the standard field validation, then look up ``model_name``.

          Raises:
              ValidationError: if ``RequestModelFactory.find`` raises for
                  ``model_name``; a ``NameError`` and any other exception
                  produce different messages.
          """
          super().clean_fields(exclude=exclude)
          try:
              RequestModelFactory.find(self.model_name)
          except NameError:
              unknown_model = f'The specified model name {self.model_name} does not exist.'
              raise ValidationError(unknown_model)
          except Exception:
              mismatch = 'The attributes does not match the model.'
              raise ValidationError(mismatch)