You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

401 lines
13 KiB

5 years ago
3 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
3 years ago
6 years ago
3 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import string
  2. from auto_labeling_pipeline.models import RequestModelFactory
  3. from django.conf import settings
  4. from django.contrib.auth.models import User
  5. from django.core.exceptions import ValidationError
  6. from django.db import models
  7. from django.db.models.signals import m2m_changed, post_save, pre_delete
  8. from django.dispatch import receiver
  9. from django.urls import reverse
  10. from polymorphic.models import PolymorphicModel
  11. from .managers import (AnnotationManager, RoleMappingManager,
  12. Seq2seqAnnotationManager)
  # Stored identifiers for each supported project type.
  DOCUMENT_CLASSIFICATION = 'DocumentClassification'
  SEQUENCE_LABELING = 'SequenceLabeling'
  SEQ2SEQ = 'Seq2seq'
  SPEECH2TEXT = 'Speech2text'

  # (stored value, display label) pairs offered by Project.project_type.
  PROJECT_CHOICES = (
      (DOCUMENT_CLASSIFICATION, 'document classification'),
      (SEQUENCE_LABELING, 'sequence labeling'),
      (SEQ2SEQ, 'sequence to sequence'),
      (SPEECH2TEXT, 'speech to text'),
  )
  class Project(PolymorphicModel):
      """Polymorphic base model for an annotation project.

      The ``get_*`` hooks defined here all raise ``NotImplementedError``; the
      concrete project subclasses in this module override them.
      """

      name = models.CharField(max_length=100)
      description = models.TextField(default='')
      guideline = models.TextField(default='', blank=True)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)
      # Reverse accessor on User is ``user.projects``.
      users = models.ManyToManyField(User, related_name='projects')
      project_type = models.CharField(max_length=30, choices=PROJECT_CHOICES)
      randomize_document_order = models.BooleanField(default=False)
      collaborative_annotation = models.BooleanField(default=False)
      single_class_classification = models.BooleanField(default=False)

      def get_absolute_url(self):
          """Return the URL of the ``upload`` route for this project's id."""
          url_args = [self.id]
          return reverse('upload', args=url_args)

      def get_bundle_name(self):
          """Return the bundle name; must be overridden by subclasses."""
          raise NotImplementedError

      def get_bundle_name_upload(self):
          """Return the upload bundle name; must be overridden by subclasses."""
          raise NotImplementedError

      def get_bundle_name_download(self):
          """Return the download bundle name; must be overridden by subclasses."""
          raise NotImplementedError

      def get_annotation_serializer(self):
          """Return the annotation serializer class; must be overridden by subclasses."""
          raise NotImplementedError

      def get_annotation_class(self):
          """Return the annotation model class; must be overridden by subclasses."""
          raise NotImplementedError

      def get_storage(self, data):
          """Return a storage object for ``data``; must be overridden by subclasses."""
          raise NotImplementedError

      def __str__(self):
          """Return the project name."""
          return self.name
  class TextClassificationProject(Project):
      """Project subclass wiring up the document-classification components."""

      def get_bundle_name(self):
          """Return the bundle name for this project type."""
          return 'document_classification'

      def get_bundle_name_upload(self):
          """Return the upload bundle name for this project type."""
          return 'upload_text_classification'

      def get_bundle_name_download(self):
          """Return the download bundle name for this project type."""
          return 'download_text_classification'

      def get_annotation_serializer(self):
          """Return the serializer class for document annotations."""
          # Imported at call time rather than at module import time.
          from .serializers import DocumentAnnotationSerializer
          serializer_class = DocumentAnnotationSerializer
          return serializer_class

      def get_annotation_class(self):
          """Return the annotation model used by this project type."""
          annotation_class = DocumentAnnotation
          return annotation_class

      def get_storage(self, data):
          """Return a ``ClassificationStorage`` built from ``data`` and this project."""
          # Imported at call time rather than at module import time.
          from .utils import ClassificationStorage
          storage = ClassificationStorage(data, self)
          return storage
  class SequenceLabelingProject(Project):
      """Project subclass wiring up the sequence-labeling components."""

      def get_bundle_name(self):
          """Return the bundle name for this project type."""
          return 'sequence_labeling'

      def get_bundle_name_upload(self):
          """Return the upload bundle name for this project type."""
          return 'upload_sequence_labeling'

      def get_bundle_name_download(self):
          """Return the download bundle name for this project type."""
          return 'download_sequence_labeling'

      def get_annotation_serializer(self):
          """Return the serializer class for sequence annotations."""
          # Imported at call time rather than at module import time.
          from .serializers import SequenceAnnotationSerializer
          serializer_class = SequenceAnnotationSerializer
          return serializer_class

      def get_annotation_class(self):
          """Return the annotation model used by this project type."""
          annotation_class = SequenceAnnotation
          return annotation_class

      def get_storage(self, data):
          """Return a ``SequenceLabelingStorage`` built from ``data`` and this project."""
          # Imported at call time rather than at module import time.
          from .utils import SequenceLabelingStorage
          storage = SequenceLabelingStorage(data, self)
          return storage
  class Seq2seqProject(Project):
      """Project subclass wiring up the sequence-to-sequence components."""

      def get_bundle_name(self):
          """Return the bundle name for this project type."""
          return 'seq2seq'

      def get_bundle_name_upload(self):
          """Return the upload bundle name for this project type."""
          return 'upload_seq2seq'

      def get_bundle_name_download(self):
          """Return the download bundle name for this project type."""
          return 'download_seq2seq'

      def get_annotation_serializer(self):
          """Return the serializer class for seq2seq annotations."""
          # Imported at call time rather than at module import time.
          from .serializers import Seq2seqAnnotationSerializer
          serializer_class = Seq2seqAnnotationSerializer
          return serializer_class

      def get_annotation_class(self):
          """Return the annotation model used by this project type."""
          annotation_class = Seq2seqAnnotation
          return annotation_class

      def get_storage(self, data):
          """Return a ``Seq2seqStorage`` built from ``data`` and this project."""
          # Imported at call time rather than at module import time.
          from .utils import Seq2seqStorage
          storage = Seq2seqStorage(data, self)
          return storage
  class Speech2textProject(Project):
      """Project subclass wiring up the speech-to-text components."""

      def get_bundle_name(self):
          """Return the bundle name for this project type."""
          return 'speech2text'

      def get_bundle_name_upload(self):
          """Return the upload bundle name for this project type."""
          return 'upload_speech2text'

      def get_bundle_name_download(self):
          """Return the download bundle name for this project type."""
          return 'download_speech2text'

      def get_annotation_serializer(self):
          """Return the serializer class for speech2text annotations."""
          # Imported at call time rather than at module import time.
          from .serializers import Speech2textAnnotationSerializer
          serializer_class = Speech2textAnnotationSerializer
          return serializer_class

      def get_annotation_class(self):
          """Return the annotation model used by this project type."""
          annotation_class = Speech2textAnnotation
          return annotation_class

      def get_storage(self, data):
          """Return a ``Speech2textStorage`` built from ``data`` and this project."""
          # Imported at call time rather than at module import time.
          from .utils import Speech2textStorage
          storage = Speech2textStorage(data, self)
          return storage
  class Label(models.Model):
      """A label belonging to a project, with an optional keyboard shortcut.

      A shortcut is made of an optional modifier (``prefix_key``) and a single
      character (``suffix_key``).  Label text is unique within a project.
      """

      # Allowed modifier combinations for a shortcut.
      PREFIX_KEYS = (
          ('ctrl', 'ctrl'),
          ('shift', 'shift'),
          ('ctrl shift', 'ctrl shift')
      )
      # Allowed shortcut characters: the digits 0-9 and the letters a-z.
      SUFFIX_KEYS = tuple(
          (c, c) for c in string.digits + string.ascii_lowercase
      )

      text = models.CharField(max_length=100)
      prefix_key = models.CharField(max_length=10, blank=True, null=True, choices=PREFIX_KEYS)
      suffix_key = models.CharField(max_length=1, blank=True, null=True, choices=SUFFIX_KEYS)
      project = models.ForeignKey(Project, related_name='labels', on_delete=models.CASCADE)
      background_color = models.CharField(max_length=7, default='#209cee')
      text_color = models.CharField(max_length=7, default='#ffffff')
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      def __str__(self):
          """Return the label text."""
          return self.text

      def clean(self):
          """Validate the shortcut assigned to this label.

          Raises:
              ValidationError: if a prefix key is set without a suffix key, or
                  if another label of the same project already uses the same
                  prefix/suffix combination.
          """
          # A modifier on its own is not a complete shortcut: a prefix key
          # requires a suffix key.
          if self.prefix_key and not self.suffix_key:
              raise ValidationError('Shortcut key must have a suffix key.')

          # Each shortcut (prefix key + suffix key) can only be assigned to one
          # label; the label being validated is excluded from the lookup.
          if self.suffix_key or self.prefix_key:
              other_labels = self.project.labels.exclude(id=self.id)
              if other_labels.filter(suffix_key=self.suffix_key, prefix_key=self.prefix_key).exists():
                  raise ValidationError('A label with this shortcut already exists in the project')

          super().clean()

      class Meta:
          unique_together = (
              ('project', 'text'),
          )
  class Document(models.Model):
      """A text document that belongs to a project."""

      text = models.TextField()
      project = models.ForeignKey(
          Project,
          related_name='documents',
          on_delete=models.CASCADE,
      )
      # Stored as text; the default is the string '{}'.
      meta = models.TextField(default='{}')
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)
      annotations_approved_by = models.ForeignKey(
          User,
          on_delete=models.SET_NULL,
          null=True,
      )

      def __str__(self):
          """Return the first 50 characters of the document text."""
          preview = self.text[:50]
          return preview

      @property
      def comment_count(self):
          """Return the number of comments whose document is this one."""
          related_comments = Comment.objects.filter(document=self.id)
          return related_comments.count()
  class Comment(models.Model):
      """A comment left on a document; newest comments are ordered first."""

      text = models.TextField()
      document = models.ForeignKey(Document, related_name='comments', on_delete=models.CASCADE)
      # Nullable: a comment may exist without an associated user.
      user = models.ForeignKey(User, on_delete=models.CASCADE, null=True)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      @property
      def username(self):
          """Return the author's username, or ``None`` when no user is set."""
          # ``user`` is declared with null=True, so guard against a missing
          # author instead of raising AttributeError on ``None.username``.
          author = self.user
          if author is None:
              return None
          return author.username

      @property
      def document_text(self):
          """Return the text of the document this comment is attached to."""
          return self.document.text

      class Meta:
          ordering = ('-created_at', )
  class Annotation(models.Model):
      """Abstract base holding the fields shared by all annotation models."""

      objects = AnnotationManager()

      prob = models.FloatField(default=0.0)
      manual = models.BooleanField(default=False)
      user = models.ForeignKey(User, on_delete=models.CASCADE)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      class Meta:
          # No table is created for this model; subclasses inherit its fields.
          abstract = True
  class DocumentAnnotation(Annotation):
      """Annotation that assigns a label to a whole document."""

      document = models.ForeignKey(
          Document,
          related_name='doc_annotations',
          on_delete=models.CASCADE,
      )
      label = models.ForeignKey(Label, on_delete=models.CASCADE)

      class Meta:
          # A user can attach a given label to a document only once.
          unique_together = ('document', 'user', 'label')
  class SequenceAnnotation(Annotation):
      """Annotation that assigns a label to the span [start_offset, end_offset) of a document."""

      document = models.ForeignKey(Document, related_name='seq_annotations', on_delete=models.CASCADE)
      label = models.ForeignKey(Label, on_delete=models.CASCADE)
      start_offset = models.IntegerField()
      end_offset = models.IntegerField()

      def clean(self):
          """Validate that the span is non-empty and correctly ordered.

          Raises:
              ValidationError: if ``start_offset`` is greater than or equal to
                  ``end_offset``.
          """
          # Model.full_clean() still calls clean() after clean_fields() has
          # failed, so an offset can be None here.  The missing value is
          # already reported by field validation; skip the comparison rather
          # than raising TypeError on ``None >= None``.
          if self.start_offset is None or self.end_offset is None:
              return

          # Equal offsets (an empty span) are rejected as well.
          if self.start_offset >= self.end_offset:
              raise ValidationError('start_offset is after end_offset')

          super().clean()

      class Meta:
          unique_together = ('document', 'user', 'label', 'start_offset', 'end_offset')
  class Seq2seqAnnotation(Annotation):
      """Annotation holding a text of up to 500 characters for a document."""

      # Replaces the AnnotationManager inherited from Annotation.
      objects = Seq2seqAnnotationManager()

      document = models.ForeignKey(
          Document,
          related_name='seq2seq_annotations',
          on_delete=models.CASCADE,
      )
      text = models.CharField(max_length=500)

      class Meta:
          # A user cannot store the same text twice for one document.
          unique_together = ('document', 'user', 'text')
  class Speech2textAnnotation(Annotation):
      """Annotation holding free-form text for a document."""

      document = models.ForeignKey(
          Document,
          related_name='speech2text_annotations',
          on_delete=models.CASCADE,
      )
      text = models.TextField()

      class Meta:
          # At most one annotation per user and document.
          unique_together = ('document', 'user')
  class Role(models.Model):
      """A uniquely named role that can be assigned through RoleMapping."""

      name = models.CharField(max_length=100, unique=True)
      description = models.TextField(default='')
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      def __str__(self):
          """Return the role name."""
          role_name = self.name
          return role_name
  class RoleMapping(models.Model):
      """Assignment of a role to a user within a project."""

      user = models.ForeignKey(
          User,
          related_name='role_mappings',
          on_delete=models.CASCADE,
      )
      project = models.ForeignKey(
          Project,
          related_name='role_mappings',
          on_delete=models.CASCADE,
      )
      role = models.ForeignKey(Role, on_delete=models.CASCADE)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)
      objects = RoleMappingManager()

      def clean(self):
          """Reject a second mapping for the same user in the same project.

          Raises:
              ValidationError: if another RoleMapping of this project already
                  references this user.
          """
          # Exclude this row so that re-validating an existing mapping passes.
          conflicting = (
              self.project.role_mappings
              .exclude(id=self.id)
              .filter(user=self.user, project=self.project)
          )
          if conflicting.exists():
              raise ValidationError('This user is already assigned to a role in this project.')

      class Meta:
          unique_together = ("user", "project")
  @receiver(post_save, sender=RoleMapping)
  def add_linked_project(sender, instance, created, **kwargs):
      """Add the mapping's project to the user's projects when a RoleMapping is created.

      Does nothing when an existing RoleMapping is saved again.
      """
      if not created:
          return

      mapped_user = instance.user
      mapped_project = instance.project
      if not mapped_user or not mapped_project:
          return

      # Both rows are re-read from the database by primary key before linking.
      user = User.objects.get(pk=mapped_user.pk)
      project = Project.objects.get(pk=mapped_project.pk)
      user.projects.add(project)
      user.save()
  # @receiver(post_save)
  # def add_superusers_to_project(sender, instance, created, **kwargs):
  #     if not created:
  #         return
  #     if sender not in Project.__subclasses__():
  #         return
  #     superusers = User.objects.filter(is_superuser=True)
  #     admin_role = Role.objects.filter(name=settings.ROLE_PROJECT_ADMIN).first()
  #     if superusers and admin_role:
  #         RoleMapping.objects.bulk_create(
  #             [RoleMapping(role_id=admin_role.id, user_id=superuser.id, project_id=instance.id)
  #              for superuser in superusers]
  #         )
  #
  #
  # @receiver(post_save, sender=User)
  # def add_new_superuser_to_projects(sender, instance, created, **kwargs):
  #     if created and instance.is_superuser:
  #         admin_role = Role.objects.filter(name=settings.ROLE_PROJECT_ADMIN).first()
  #         projects = Project.objects.all()
  #         if admin_role and projects:
  #             RoleMapping.objects.bulk_create(
  #                 [RoleMapping(role_id=admin_role.id, user_id=instance.id, project_id=project.id)
  #                  for project in projects]
  #             )
  @receiver(m2m_changed, sender=Project.users.through)
  def remove_mapping_on_remove_user_from_project(sender, instance, action, reverse, **kwargs):
      """Keep RoleMapping rows in sync with changes made through ``Project.users``.

      Only changes made from the project side (``reverse`` is False) are
      handled:

      * after users are removed from a project, their role mappings for that
        project are deleted;
      * after users are added to a project, each added user without a role
        mapping for that project receives one with the role named
        ``settings.ROLE_PROJECT_ADMIN``.

      Raises:
          Role.DoesNotExist: on ``post_add`` when no role with that name exists.
      """
      # If reverse is True, pk_set holds project ids and instance is a user;
      # otherwise pk_set holds user ids and instance is a project.
      user_ids = kwargs['pk_set']
      if action.startswith('post_remove') and not reverse:
          RoleMapping.objects.filter(user__in=user_ids, project=instance).delete()
      elif action.startswith('post_add') and not reverse:
          admin_role = Role.objects.get(name=settings.ROLE_PROJECT_ADMIN)
          # Look up the users that already have a mapping with one query
          # instead of issuing an exists() query per added user.
          already_mapped = set(
              RoleMapping.objects
              .filter(project=instance, user_id__in=user_ids)
              .values_list('user_id', flat=True)
          )
          RoleMapping.objects.bulk_create(
              [RoleMapping(role=admin_role, project=instance, user_id=user_id)
               for user_id in user_ids
               if user_id not in already_mapped]
          )
  @receiver(pre_delete, sender=RoleMapping)
  def delete_linked_project(sender, instance, using, **kwargs):
      """Remove the mapping's project from the user's projects before a RoleMapping is deleted."""
      mapped_user = instance.user
      mapped_project = instance.project
      if not mapped_user or not mapped_project:
          return

      # Both rows are re-read from the database by primary key before unlinking.
      user = User.objects.get(pk=mapped_user.pk)
      project = Project.objects.get(pk=mapped_project.pk)
      user.projects.remove(project)
      user.save()
  class AutoLabelingConfig(models.Model):
      """Auto-labeling configuration attached to a project."""

      model_name = models.CharField(max_length=100)
      model_attrs = models.JSONField(default=dict)
      template = models.TextField(default='')
      label_mapping = models.JSONField(default=dict)
      project = models.ForeignKey(
          Project,
          related_name='auto_labeling_config',
          on_delete=models.CASCADE,
      )
      default = models.BooleanField(default=False)
      created_at = models.DateTimeField(auto_now_add=True)
      updated_at = models.DateTimeField(auto_now=True)

      def __str__(self):
          """Return the configured model name."""
          return self.model_name

      def clean_fields(self, exclude=None):
          """Run field validation, then check ``model_name`` via ``RequestModelFactory.find``.

          Raises:
              ValidationError: if field validation fails, or if the lookup of
                  ``model_name`` raises.  A ``NameError`` is reported as an
                  unknown model name; any other exception is reported as an
                  attribute mismatch.
          """
          super().clean_fields(exclude=exclude)
          try:
              RequestModelFactory.find(self.model_name)
          except NameError:
              unknown_model = f'The specified model name {self.model_name} does not exist.'
              raise ValidationError(unknown_model)
          except Exception:
              mismatch = 'The attributes does not match the model.'
              raise ValidationError(mismatch)