You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

301 lines
11 KiB

6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
3 years ago
3 years ago
3 years ago
6 years ago
6 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.conf import settings
  3. from django.contrib.auth import get_user_model
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import serializers
  6. from rest_framework.exceptions import ValidationError
  7. from rest_polymorphic.serializers import PolymorphicSerializer
  8. from .models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
  9. SPEECH2TEXT, AutoLabelingConfig, Comment, Document,
  10. DocumentAnnotation, Label, Project, Role, RoleMapping,
  11. Seq2seqAnnotation, Seq2seqProject, SequenceAnnotation,
  12. SequenceLabelingProject, Speech2textAnnotation,
  13. Speech2textProject, Tag, TextClassificationProject)
  14. class UserSerializer(serializers.ModelSerializer):
  15. class Meta:
  16. model = get_user_model()
  17. fields = ('id', 'username', 'is_superuser')
  18. class LabelSerializer(serializers.ModelSerializer):
  19. def validate(self, attrs):
  20. prefix_key = attrs.get('prefix_key')
  21. suffix_key = attrs.get('suffix_key')
  22. # In the case of user don't set any shortcut key.
  23. if prefix_key is None and suffix_key is None:
  24. return super().validate(attrs)
  25. # Don't allow shortcut key not to have a suffix key.
  26. if prefix_key and not suffix_key:
  27. raise ValidationError('Shortcut key may not have a suffix key.')
  28. # Don't allow to save same shortcut key when prefix_key is null.
  29. try:
  30. context = self.context['request'].parser_context
  31. project_id = context['kwargs']['project_id']
  32. label_id = context['kwargs'].get('label_id')
  33. except (AttributeError, KeyError):
  34. pass # unit tests don't always have the correct context set up
  35. else:
  36. conflicting_labels = Label.objects.filter(
  37. suffix_key=suffix_key,
  38. prefix_key=prefix_key,
  39. project=project_id,
  40. )
  41. if label_id is not None:
  42. conflicting_labels = conflicting_labels.exclude(id=label_id)
  43. if conflicting_labels.exists():
  44. raise ValidationError('Duplicate shortcut key.')
  45. return super().validate(attrs)
  46. class Meta:
  47. model = Label
  48. fields = ('id', 'text', 'prefix_key', 'suffix_key', 'background_color', 'text_color')
  49. class CommentSerializer(serializers.ModelSerializer):
  50. class Meta:
  51. model = Comment
  52. fields = ('id', 'user', 'username', 'document', 'document_text', 'text', 'created_at', )
  53. read_only_fields = ('user', 'document')
  54. class TagSerializer(serializers.ModelSerializer):
  55. class Meta:
  56. model = Tag
  57. fields = ('id', 'project', 'text', )
  58. read_only_fields = ('id', 'project')
  59. class DocumentSerializer(serializers.ModelSerializer):
  60. annotations = serializers.SerializerMethodField()
  61. annotation_approver = serializers.SerializerMethodField()
  62. def get_annotations(self, instance):
  63. request = self.context.get('request')
  64. project = instance.project
  65. model = project.get_annotation_class()
  66. serializer = get_annotation_serializer(task=project.project_type)
  67. annotations = model.objects.filter(document=instance.id)
  68. if request and not project.collaborative_annotation:
  69. annotations = annotations.filter(user=request.user)
  70. serializer = serializer(annotations, many=True)
  71. return serializer.data
  72. @classmethod
  73. def get_annotation_approver(cls, instance):
  74. approver = instance.annotations_approved_by
  75. return approver.username if approver else None
  76. class Meta:
  77. model = Document
  78. fields = ('id', 'text', 'annotations', 'meta', 'annotation_approver', 'comment_count')
  79. class ApproverSerializer(DocumentSerializer):
  80. class Meta:
  81. model = Document
  82. fields = ('id', 'annotation_approver')
  83. class ProjectSerializer(serializers.ModelSerializer):
  84. current_users_role = serializers.SerializerMethodField()
  85. tags = TagSerializer(many=True, required=False)
  86. def get_current_users_role(self, instance):
  87. role_abstractor = {
  88. "is_project_admin": settings.ROLE_PROJECT_ADMIN,
  89. "is_annotator": settings.ROLE_ANNOTATOR,
  90. "is_annotation_approver": settings.ROLE_ANNOTATION_APPROVER,
  91. }
  92. queryset = RoleMapping.objects.values("role_id__name")
  93. if queryset:
  94. users_role = get_object_or_404(
  95. queryset, project=instance.id, user=self.context.get("request").user.id
  96. )
  97. for key, val in role_abstractor.items():
  98. role_abstractor[key] = users_role["role_id__name"] == val
  99. return role_abstractor
  100. class Meta:
  101. model = Project
  102. fields = ('id', 'name', 'description', 'guideline', 'users', 'current_users_role', 'project_type',
  103. 'updated_at', 'randomize_document_order', 'collaborative_annotation', 'single_class_classification',
  104. 'tags')
  105. read_only_fields = ('updated_at', 'users', 'current_users_role', 'tags')
  106. class TextClassificationProjectSerializer(ProjectSerializer):
  107. class Meta:
  108. model = TextClassificationProject
  109. fields = ProjectSerializer.Meta.fields
  110. read_only_fields = ProjectSerializer.Meta.read_only_fields
  111. class SequenceLabelingProjectSerializer(ProjectSerializer):
  112. class Meta:
  113. model = SequenceLabelingProject
  114. fields = ProjectSerializer.Meta.fields
  115. read_only_fields = ProjectSerializer.Meta.read_only_fields
  116. class Seq2seqProjectSerializer(ProjectSerializer):
  117. class Meta:
  118. model = Seq2seqProject
  119. fields = ProjectSerializer.Meta.fields
  120. read_only_fields = ProjectSerializer.Meta.read_only_fields
  121. class Speech2textProjectSerializer(ProjectSerializer):
  122. class Meta:
  123. model = Speech2textProject
  124. fields = ('id', 'name', 'description', 'guideline', 'users', 'current_users_role', 'project_type',
  125. 'updated_at', 'randomize_document_order')
  126. read_only_fields = ('updated_at', 'users', 'current_users_role')
  127. class ProjectPolymorphicSerializer(PolymorphicSerializer):
  128. model_serializer_mapping = {
  129. Project: ProjectSerializer,
  130. TextClassificationProject: TextClassificationProjectSerializer,
  131. SequenceLabelingProject: SequenceLabelingProjectSerializer,
  132. Seq2seqProject: Seq2seqProjectSerializer,
  133. Speech2textProject: Speech2textProjectSerializer,
  134. }
  135. class ProjectFilteredPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField):
  136. def get_queryset(self):
  137. view = self.context.get('view', None)
  138. request = self.context.get('request', None)
  139. queryset = super(ProjectFilteredPrimaryKeyRelatedField, self).get_queryset()
  140. if not request or not queryset or not view:
  141. return None
  142. return queryset.filter(project=view.kwargs['project_id'])
  143. class DocumentAnnotationSerializer(serializers.ModelSerializer):
  144. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  145. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  146. class Meta:
  147. model = DocumentAnnotation
  148. fields = ('id', 'prob', 'label', 'user', 'document', 'created_at', 'updated_at')
  149. read_only_fields = ('user', )
  150. class SequenceAnnotationSerializer(serializers.ModelSerializer):
  151. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  152. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  153. class Meta:
  154. model = SequenceAnnotation
  155. fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user', 'document', 'created_at', 'updated_at')
  156. read_only_fields = ('user',)
  157. class Seq2seqAnnotationSerializer(serializers.ModelSerializer):
  158. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  159. class Meta:
  160. model = Seq2seqAnnotation
  161. fields = ('id', 'text', 'user', 'document', 'prob', 'created_at', 'updated_at')
  162. read_only_fields = ('user',)
  163. class Speech2textAnnotationSerializer(serializers.ModelSerializer):
  164. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  165. class Meta:
  166. model = Speech2textAnnotation
  167. fields = ('id', 'prob', 'text', 'user', 'document', 'created_at', 'updated_at')
  168. read_only_fields = ('user',)
  169. class RoleSerializer(serializers.ModelSerializer):
  170. class Meta:
  171. model = Role
  172. fields = ('id', 'name')
  173. class RoleMappingSerializer(serializers.ModelSerializer):
  174. username = serializers.SerializerMethodField()
  175. rolename = serializers.SerializerMethodField()
  176. @classmethod
  177. def get_username(cls, instance):
  178. user = instance.user
  179. return user.username if user else None
  180. @classmethod
  181. def get_rolename(cls, instance):
  182. role = instance.role
  183. return role.name if role else None
  184. class Meta:
  185. model = RoleMapping
  186. fields = ('id', 'user', 'role', 'username', 'rolename')
  187. class AutoLabelingConfigSerializer(serializers.ModelSerializer):
  188. class Meta:
  189. model = AutoLabelingConfig
  190. fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
  191. read_only_fields = ('created_at', 'updated_at')
  192. def validate_model_name(self, value):
  193. try:
  194. RequestModelFactory.find(value)
  195. except NameError:
  196. raise serializers.ValidationError(f'The specified model name {value} does not exist.')
  197. return value
  198. def valid_label_mapping(self, value):
  199. if isinstance(value, dict):
  200. return value
  201. else:
  202. raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
  203. def validate(self, data):
  204. try:
  205. RequestModelFactory.create(data['model_name'], data['model_attrs'])
  206. except Exception:
  207. model = RequestModelFactory.find(data['model_name'])
  208. schema = model.schema()
  209. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  210. raise serializers.ValidationError(
  211. 'The attributes does not match the model.'
  212. 'You need to correctly specify the required fields: {}'.format(required_fields)
  213. )
  214. return data
  215. def get_annotation_serializer(task: str):
  216. mapping = {
  217. DOCUMENT_CLASSIFICATION: DocumentAnnotationSerializer,
  218. SEQUENCE_LABELING: SequenceAnnotationSerializer,
  219. SEQ2SEQ: Seq2seqAnnotationSerializer,
  220. SPEECH2TEXT: Speech2textAnnotationSerializer
  221. }
  222. try:
  223. return mapping[task]
  224. except KeyError:
  225. raise ValueError(f'{task} is not implemented.')