You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

369 lines
12 KiB

5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import string
  2. from auto_labeling_pipeline.models import RequestModelFactory
  3. from django.db import models
  4. from django.dispatch import receiver
  5. from django.db.models.signals import post_save, pre_delete
  6. from django.urls import reverse
  7. from django.conf import settings
  8. from django.contrib.auth.models import User
  9. from django.core.exceptions import ValidationError
  10. from polymorphic.models import PolymorphicModel
  11. from .managers import AnnotationManager, Seq2seqAnnotationManager
  12. DOCUMENT_CLASSIFICATION = 'DocumentClassification'
  13. SEQUENCE_LABELING = 'SequenceLabeling'
  14. SEQ2SEQ = 'Seq2seq'
  15. SPEECH2TEXT = 'Speech2text'
  16. PROJECT_CHOICES = (
  17. (DOCUMENT_CLASSIFICATION, 'document classification'),
  18. (SEQUENCE_LABELING, 'sequence labeling'),
  19. (SEQ2SEQ, 'sequence to sequence'),
  20. (SPEECH2TEXT, 'speech to text'),
  21. )
  22. class Project(PolymorphicModel):
  23. name = models.CharField(max_length=100)
  24. description = models.TextField(default='')
  25. guideline = models.TextField(default='')
  26. created_at = models.DateTimeField(auto_now_add=True)
  27. updated_at = models.DateTimeField(auto_now=True)
  28. users = models.ManyToManyField(User, related_name='projects')
  29. project_type = models.CharField(max_length=30, choices=PROJECT_CHOICES)
  30. randomize_document_order = models.BooleanField(default=False)
  31. collaborative_annotation = models.BooleanField(default=False)
  32. single_class_classification = models.BooleanField(default=False)
  33. def get_absolute_url(self):
  34. return reverse('upload', args=[self.id])
  35. def get_bundle_name(self):
  36. raise NotImplementedError()
  37. def get_bundle_name_upload(self):
  38. raise NotImplementedError()
  39. def get_bundle_name_download(self):
  40. raise NotImplementedError()
  41. def get_annotation_serializer(self):
  42. raise NotImplementedError()
  43. def get_annotation_class(self):
  44. raise NotImplementedError()
  45. def get_storage(self, data):
  46. raise NotImplementedError()
  47. def __str__(self):
  48. return self.name
  49. class TextClassificationProject(Project):
  50. def get_bundle_name(self):
  51. return 'document_classification'
  52. def get_bundle_name_upload(self):
  53. return 'upload_text_classification'
  54. def get_bundle_name_download(self):
  55. return 'download_text_classification'
  56. def get_annotation_serializer(self):
  57. from .serializers import DocumentAnnotationSerializer
  58. return DocumentAnnotationSerializer
  59. def get_annotation_class(self):
  60. return DocumentAnnotation
  61. def get_storage(self, data):
  62. from .utils import ClassificationStorage
  63. return ClassificationStorage(data, self)
  64. class SequenceLabelingProject(Project):
  65. def get_bundle_name(self):
  66. return 'sequence_labeling'
  67. def get_bundle_name_upload(self):
  68. return 'upload_sequence_labeling'
  69. def get_bundle_name_download(self):
  70. return 'download_sequence_labeling'
  71. def get_annotation_serializer(self):
  72. from .serializers import SequenceAnnotationSerializer
  73. return SequenceAnnotationSerializer
  74. def get_annotation_class(self):
  75. return SequenceAnnotation
  76. def get_storage(self, data):
  77. from .utils import SequenceLabelingStorage
  78. return SequenceLabelingStorage(data, self)
  79. class Seq2seqProject(Project):
  80. def get_bundle_name(self):
  81. return 'seq2seq'
  82. def get_bundle_name_upload(self):
  83. return 'upload_seq2seq'
  84. def get_bundle_name_download(self):
  85. return 'download_seq2seq'
  86. def get_annotation_serializer(self):
  87. from .serializers import Seq2seqAnnotationSerializer
  88. return Seq2seqAnnotationSerializer
  89. def get_annotation_class(self):
  90. return Seq2seqAnnotation
  91. def get_storage(self, data):
  92. from .utils import Seq2seqStorage
  93. return Seq2seqStorage(data, self)
  94. class Speech2textProject(Project):
  95. def get_bundle_name(self):
  96. return 'speech2text'
  97. def get_bundle_name_upload(self):
  98. return 'upload_speech2text'
  99. def get_bundle_name_download(self):
  100. return 'download_speech2text'
  101. def get_annotation_serializer(self):
  102. from .serializers import Speech2textAnnotationSerializer
  103. return Speech2textAnnotationSerializer
  104. def get_annotation_class(self):
  105. return Speech2textAnnotation
  106. def get_storage(self, data):
  107. from .utils import Speech2textStorage
  108. return Speech2textStorage(data, self)
  109. class Label(models.Model):
  110. PREFIX_KEYS = (
  111. ('ctrl', 'ctrl'),
  112. ('shift', 'shift'),
  113. ('ctrl shift', 'ctrl shift')
  114. )
  115. SUFFIX_KEYS = tuple(
  116. (c, c) for c in string.digits + string.ascii_lowercase
  117. )
  118. text = models.CharField(max_length=100)
  119. prefix_key = models.CharField(max_length=10, blank=True, null=True, choices=PREFIX_KEYS)
  120. suffix_key = models.CharField(max_length=1, blank=True, null=True, choices=SUFFIX_KEYS)
  121. project = models.ForeignKey(Project, related_name='labels', on_delete=models.CASCADE)
  122. background_color = models.CharField(max_length=7, default='#209cee')
  123. text_color = models.CharField(max_length=7, default='#ffffff')
  124. created_at = models.DateTimeField(auto_now_add=True)
  125. updated_at = models.DateTimeField(auto_now=True)
  126. def __str__(self):
  127. return self.text
  128. def clean(self):
  129. # Don't allow shortcut key not to have a suffix key.
  130. if self.prefix_key and not self.suffix_key:
  131. raise ValidationError('Shortcut key may not have a suffix key.')
  132. # each shortcut (prefix key + suffix key) can only be assigned to one label
  133. if self.suffix_key or self.prefix_key:
  134. other_labels = self.project.labels.exclude(id=self.id)
  135. if other_labels.filter(suffix_key=self.suffix_key, prefix_key=self.prefix_key).exists():
  136. raise ValidationError('A label with this shortcut already exists in the project')
  137. super().clean()
  138. class Meta:
  139. unique_together = (
  140. ('project', 'text'),
  141. )
  142. class Document(models.Model):
  143. text = models.TextField()
  144. project = models.ForeignKey(Project, related_name='documents', on_delete=models.CASCADE)
  145. meta = models.TextField(default='{}')
  146. created_at = models.DateTimeField(auto_now_add=True)
  147. updated_at = models.DateTimeField(auto_now=True)
  148. annotations_approved_by = models.ForeignKey(User, on_delete=models.SET_NULL, null=True)
  149. def __str__(self):
  150. return self.text[:50]
class Comment(models.Model):
    """A free-text comment attached to a document."""
    text = models.TextField()
    # Deleted together with its document; reachable as ``document.comments``.
    document = models.ForeignKey(Document, related_name='comments', on_delete=models.CASCADE)
    # User linked to the comment (presumably its author — confirm); nullable.
    user = models.ForeignKey(User, on_delete=models.CASCADE, null=True)
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)
class Annotation(models.Model):
    """Abstract base with the fields shared by the annotation models.

    The concrete subclasses in this module add a ``document`` foreign key and
    their task-specific payload (label, offsets or text).
    """
    objects = AnnotationManager()
    # NOTE(review): presumably a confidence score (e.g. from auto-labeling) — confirm.
    prob = models.FloatField(default=0.0)
    # NOTE(review): presumably marks annotations entered by a person rather
    # than generated automatically — confirm.
    manual = models.BooleanField(default=False)
    user = models.ForeignKey(User, on_delete=models.CASCADE)
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    class Meta:
        # No table of its own; fields are copied into each subclass.
        abstract = True
class DocumentAnnotation(Annotation):
    """A label applied to a whole document by a user."""
    document = models.ForeignKey(Document, related_name='doc_annotations', on_delete=models.CASCADE)
    label = models.ForeignKey(Label, on_delete=models.CASCADE)

    class Meta:
        # A user can apply a given label to a given document at most once.
        unique_together = ('document', 'user', 'label')
  171. class SequenceAnnotation(Annotation):
  172. document = models.ForeignKey(Document, related_name='seq_annotations', on_delete=models.CASCADE)
  173. label = models.ForeignKey(Label, on_delete=models.CASCADE)
  174. start_offset = models.IntegerField()
  175. end_offset = models.IntegerField()
  176. def clean(self):
  177. if self.start_offset >= self.end_offset:
  178. raise ValidationError('start_offset is after end_offset')
  179. class Meta:
  180. unique_together = ('document', 'user', 'label', 'start_offset', 'end_offset')
class Seq2seqAnnotation(Annotation):
    """A free-text output (up to 500 characters) written for a document."""
    # Replaces the AnnotationManager inherited from Annotation with the
    # seq2seq-specific manager from .managers.
    objects = Seq2seqAnnotationManager()
    document = models.ForeignKey(Document, related_name='seq2seq_annotations', on_delete=models.CASCADE)
    text = models.CharField(max_length=500)

    class Meta:
        # A user cannot enter the same text twice for the same document.
        unique_together = ('document', 'user', 'text')
class Speech2textAnnotation(Annotation):
    """A text annotation for a document in a speech-to-text project."""
    document = models.ForeignKey(Document, related_name='speech2text_annotations', on_delete=models.CASCADE)
    text = models.TextField()

    class Meta:
        # At most one annotation per user per document.
        unique_together = ('document', 'user')
  193. class Role(models.Model):
  194. name = models.CharField(max_length=100, unique=True)
  195. description = models.TextField(default='')
  196. created_at = models.DateTimeField(auto_now_add=True)
  197. updated_at = models.DateTimeField(auto_now=True)
  198. def __str__(self):
  199. return self.name
  200. class RoleMapping(models.Model):
  201. user = models.ForeignKey(User, related_name='role_mappings', on_delete=models.CASCADE)
  202. project = models.ForeignKey(Project, related_name='role_mappings', on_delete=models.CASCADE)
  203. role = models.ForeignKey(Role, on_delete=models.CASCADE)
  204. created_at = models.DateTimeField(auto_now_add=True)
  205. updated_at = models.DateTimeField(auto_now=True)
  206. def clean(self):
  207. other_rolemappings = self.project.role_mappings.exclude(id=self.id)
  208. if other_rolemappings.filter(user=self.user, project=self.project).exists():
  209. raise ValidationError('This user is already assigned to a role in this project.')
  210. class Meta:
  211. unique_together = ("user", "project", "role")
  212. @receiver(post_save, sender=RoleMapping)
  213. def add_linked_project(sender, instance, created, **kwargs):
  214. if not created:
  215. return
  216. userInstance = instance.user
  217. projectInstance = instance.project
  218. if userInstance and projectInstance:
  219. user = User.objects.get(pk=userInstance.pk)
  220. project = Project.objects.get(pk=projectInstance.pk)
  221. user.projects.add(project)
  222. user.save()
  223. @receiver(post_save)
  224. def add_superusers_to_project(sender, instance, created, **kwargs):
  225. if not created:
  226. return
  227. if sender not in Project.__subclasses__():
  228. return
  229. superusers = User.objects.filter(is_superuser=True)
  230. admin_role = Role.objects.filter(name=settings.ROLE_PROJECT_ADMIN).first()
  231. if superusers and admin_role:
  232. RoleMapping.objects.bulk_create(
  233. [RoleMapping(role_id=admin_role.id, user_id=superuser.id, project_id=instance.id)
  234. for superuser in superusers]
  235. )
  236. @receiver(post_save, sender=User)
  237. def add_new_superuser_to_projects(sender, instance, created, **kwargs):
  238. if created and instance.is_superuser:
  239. admin_role = Role.objects.filter(name=settings.ROLE_PROJECT_ADMIN).first()
  240. projects = Project.objects.all()
  241. if admin_role and projects:
  242. RoleMapping.objects.bulk_create(
  243. [RoleMapping(role_id=admin_role.id, user_id=instance.id, project_id=project.id)
  244. for project in projects]
  245. )
  246. @receiver(pre_delete, sender=RoleMapping)
  247. def delete_linked_project(sender, instance, using, **kwargs):
  248. userInstance = instance.user
  249. projectInstance = instance.project
  250. if userInstance and projectInstance:
  251. user = User.objects.get(pk=userInstance.pk)
  252. project = Project.objects.get(pk=projectInstance.pk)
  253. user.projects.remove(project)
  254. user.save()
  255. class AutoLabelingConfig(models.Model):
  256. model_name = models.CharField(max_length=100)
  257. model_attrs = models.JSONField(default=dict)
  258. template = models.TextField(default='')
  259. label_mapping = models.JSONField(default=dict)
  260. project = models.ForeignKey(Project, related_name='auto_labeling_config', on_delete=models.CASCADE)
  261. default = models.BooleanField(default=False)
  262. created_at = models.DateTimeField(auto_now_add=True)
  263. updated_at = models.DateTimeField(auto_now=True)
  264. def __str__(self):
  265. return self.model_name
  266. def clean_fields(self, exclude=None):
  267. super().clean_fields(exclude=exclude)
  268. try:
  269. RequestModelFactory.find(self.model_name)
  270. except NameError:
  271. raise ValidationError(f'The specified model name {self.model_name} does not exist.')
  272. except Exception:
  273. raise ValidationError('The attributes does not match the model.')