You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

288 lines
10 KiB

6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
3 years ago
3 years ago
3 years ago
6 years ago
6 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.conf import settings
  3. from django.contrib.auth import get_user_model
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import serializers
  6. from rest_framework.exceptions import ValidationError
  7. from rest_polymorphic.serializers import PolymorphicSerializer
  8. from .models import (AutoLabelingConfig, Comment, Document, DocumentAnnotation,
  9. Label, Project, Role, RoleMapping, Seq2seqAnnotation,
  10. Seq2seqProject, SequenceAnnotation,
  11. SequenceLabelingProject, Speech2textAnnotation,
  12. Speech2textProject, Tag, TextClassificationProject)
  13. class UserSerializer(serializers.ModelSerializer):
  14. class Meta:
  15. model = get_user_model()
  16. fields = ('id', 'username', 'is_superuser')
  17. class LabelSerializer(serializers.ModelSerializer):
  18. def validate(self, attrs):
  19. prefix_key = attrs.get('prefix_key')
  20. suffix_key = attrs.get('suffix_key')
  21. # In the case of user don't set any shortcut key.
  22. if prefix_key is None and suffix_key is None:
  23. return super().validate(attrs)
  24. # Don't allow shortcut key not to have a suffix key.
  25. if prefix_key and not suffix_key:
  26. raise ValidationError('Shortcut key may not have a suffix key.')
  27. # Don't allow to save same shortcut key when prefix_key is null.
  28. try:
  29. context = self.context['request'].parser_context
  30. project_id = context['kwargs']['project_id']
  31. label_id = context['kwargs'].get('label_id')
  32. except (AttributeError, KeyError):
  33. pass # unit tests don't always have the correct context set up
  34. else:
  35. conflicting_labels = Label.objects.filter(
  36. suffix_key=suffix_key,
  37. prefix_key=prefix_key,
  38. project=project_id,
  39. )
  40. if label_id is not None:
  41. conflicting_labels = conflicting_labels.exclude(id=label_id)
  42. if conflicting_labels.exists():
  43. raise ValidationError('Duplicate shortcut key.')
  44. return super().validate(attrs)
  45. class Meta:
  46. model = Label
  47. fields = ('id', 'text', 'prefix_key', 'suffix_key', 'background_color', 'text_color')
  48. class CommentSerializer(serializers.ModelSerializer):
  49. class Meta:
  50. model = Comment
  51. fields = ('id', 'user', 'username', 'document', 'document_text', 'text', 'created_at', )
  52. read_only_fields = ('user', 'document')
  53. class TagSerializer(serializers.ModelSerializer):
  54. class Meta:
  55. model = Tag
  56. fields = ('id', 'project', 'text', )
  57. read_only_fields = ('id', 'project')
  58. class DocumentSerializer(serializers.ModelSerializer):
  59. annotations = serializers.SerializerMethodField()
  60. annotation_approver = serializers.SerializerMethodField()
  61. def get_annotations(self, instance):
  62. request = self.context.get('request')
  63. project = instance.project
  64. model = project.get_annotation_class()
  65. serializer = project.get_annotation_serializer()
  66. annotations = model.objects.filter(document=instance.id)
  67. if request and not project.collaborative_annotation:
  68. annotations = annotations.filter(user=request.user)
  69. serializer = serializer(annotations, many=True)
  70. return serializer.data
  71. @classmethod
  72. def get_annotation_approver(cls, instance):
  73. approver = instance.annotations_approved_by
  74. return approver.username if approver else None
  75. class Meta:
  76. model = Document
  77. fields = ('id', 'text', 'annotations', 'meta', 'annotation_approver', 'comment_count')
  78. class ApproverSerializer(DocumentSerializer):
  79. class Meta:
  80. model = Document
  81. fields = ('id', 'annotation_approver')
  82. class ProjectSerializer(serializers.ModelSerializer):
  83. current_users_role = serializers.SerializerMethodField()
  84. tags = TagSerializer(many=True, required=False)
  85. def get_current_users_role(self, instance):
  86. role_abstractor = {
  87. "is_project_admin": settings.ROLE_PROJECT_ADMIN,
  88. "is_annotator": settings.ROLE_ANNOTATOR,
  89. "is_annotation_approver": settings.ROLE_ANNOTATION_APPROVER,
  90. }
  91. queryset = RoleMapping.objects.values("role_id__name")
  92. if queryset:
  93. users_role = get_object_or_404(
  94. queryset, project=instance.id, user=self.context.get("request").user.id
  95. )
  96. for key, val in role_abstractor.items():
  97. role_abstractor[key] = users_role["role_id__name"] == val
  98. return role_abstractor
  99. class Meta:
  100. model = Project
  101. fields = ('id', 'name', 'description', 'guideline', 'users', 'current_users_role', 'project_type',
  102. 'updated_at', 'randomize_document_order', 'collaborative_annotation', 'single_class_classification', 'tags')
  103. read_only_fields = ('updated_at', 'users', 'current_users_role', 'tags')
  104. class TextClassificationProjectSerializer(ProjectSerializer):
  105. class Meta:
  106. model = TextClassificationProject
  107. fields = ProjectSerializer.Meta.fields
  108. read_only_fields = ProjectSerializer.Meta.read_only_fields
  109. class SequenceLabelingProjectSerializer(ProjectSerializer):
  110. class Meta:
  111. model = SequenceLabelingProject
  112. fields = ProjectSerializer.Meta.fields
  113. read_only_fields = ProjectSerializer.Meta.read_only_fields
  114. class Seq2seqProjectSerializer(ProjectSerializer):
  115. class Meta:
  116. model = Seq2seqProject
  117. fields = ProjectSerializer.Meta.fields
  118. read_only_fields = ProjectSerializer.Meta.read_only_fields
  119. class Speech2textProjectSerializer(ProjectSerializer):
  120. class Meta:
  121. model = Speech2textProject
  122. fields = ('id', 'name', 'description', 'guideline', 'users', 'current_users_role', 'project_type',
  123. 'updated_at', 'randomize_document_order')
  124. read_only_fields = ('updated_at', 'users', 'current_users_role')
  125. class ProjectPolymorphicSerializer(PolymorphicSerializer):
  126. model_serializer_mapping = {
  127. Project: ProjectSerializer,
  128. TextClassificationProject: TextClassificationProjectSerializer,
  129. SequenceLabelingProject: SequenceLabelingProjectSerializer,
  130. Seq2seqProject: Seq2seqProjectSerializer,
  131. Speech2textProject: Speech2textProjectSerializer,
  132. }
  133. class ProjectFilteredPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField):
  134. def get_queryset(self):
  135. view = self.context.get('view', None)
  136. request = self.context.get('request', None)
  137. queryset = super(ProjectFilteredPrimaryKeyRelatedField, self).get_queryset()
  138. if not request or not queryset or not view:
  139. return None
  140. return queryset.filter(project=view.kwargs['project_id'])
  141. class DocumentAnnotationSerializer(serializers.ModelSerializer):
  142. # label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
  143. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  144. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  145. class Meta:
  146. model = DocumentAnnotation
  147. fields = ('id', 'prob', 'label', 'user', 'document', 'created_at', 'updated_at')
  148. read_only_fields = ('user', )
  149. class SequenceAnnotationSerializer(serializers.ModelSerializer):
  150. #label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
  151. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  152. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  153. class Meta:
  154. model = SequenceAnnotation
  155. fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user', 'document', 'created_at', 'updated_at')
  156. read_only_fields = ('user',)
  157. class Seq2seqAnnotationSerializer(serializers.ModelSerializer):
  158. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  159. class Meta:
  160. model = Seq2seqAnnotation
  161. fields = ('id', 'text', 'user', 'document', 'prob', 'created_at', 'updated_at')
  162. read_only_fields = ('user',)
  163. class Speech2textAnnotationSerializer(serializers.ModelSerializer):
  164. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  165. class Meta:
  166. model = Speech2textAnnotation
  167. fields = ('id', 'prob', 'text', 'user', 'document', 'created_at', 'updated_at')
  168. read_only_fields = ('user',)
  169. class RoleSerializer(serializers.ModelSerializer):
  170. class Meta:
  171. model = Role
  172. fields = ('id', 'name')
  173. class RoleMappingSerializer(serializers.ModelSerializer):
  174. username = serializers.SerializerMethodField()
  175. rolename = serializers.SerializerMethodField()
  176. @classmethod
  177. def get_username(cls, instance):
  178. user = instance.user
  179. return user.username if user else None
  180. @classmethod
  181. def get_rolename(cls, instance):
  182. role = instance.role
  183. return role.name if role else None
  184. class Meta:
  185. model = RoleMapping
  186. fields = ('id', 'user', 'role', 'username', 'rolename')
  187. class AutoLabelingConfigSerializer(serializers.ModelSerializer):
  188. class Meta:
  189. model = AutoLabelingConfig
  190. fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
  191. read_only_fields = ('created_at', 'updated_at')
  192. def validate_model_name(self, value):
  193. try:
  194. RequestModelFactory.find(value)
  195. except NameError:
  196. raise serializers.ValidationError(f'The specified model name {value} does not exist.')
  197. return value
  198. def valid_label_mapping(self, value):
  199. if isinstance(value, dict):
  200. return value
  201. else:
  202. raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
  203. def validate(self, data):
  204. try:
  205. RequestModelFactory.create(data['model_name'], data['model_attrs'])
  206. except Exception:
  207. model = RequestModelFactory.find(data['model_name'])
  208. schema = model.schema()
  209. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  210. raise serializers.ValidationError(
  211. 'The attributes does not match the model.'
  212. 'You need to correctly specify the required fields: {}'.format(required_fields)
  213. )
  214. return data