You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

279 lines
10 KiB

6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
3 years ago
3 years ago
3 years ago
6 years ago
6 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.conf import settings
  3. from django.contrib.auth import get_user_model
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import serializers
  6. from rest_framework.exceptions import ValidationError
  7. from rest_polymorphic.serializers import PolymorphicSerializer
  8. from .models import (AutoLabelingConfig, Comment, Document, DocumentAnnotation,
  9. Label, Project, Role, RoleMapping, Seq2seqAnnotation,
  10. Seq2seqProject, SequenceAnnotation,
  11. SequenceLabelingProject, Speech2textAnnotation,
  12. Speech2textProject, TextClassificationProject)
  13. class UserSerializer(serializers.ModelSerializer):
  14. class Meta:
  15. model = get_user_model()
  16. fields = ('id', 'username', 'is_superuser')
  17. class LabelSerializer(serializers.ModelSerializer):
  18. def validate(self, attrs):
  19. prefix_key = attrs.get('prefix_key')
  20. suffix_key = attrs.get('suffix_key')
  21. # In the case of user don't set any shortcut key.
  22. if prefix_key is None and suffix_key is None:
  23. return super().validate(attrs)
  24. # Don't allow shortcut key not to have a suffix key.
  25. if prefix_key and not suffix_key:
  26. raise ValidationError('Shortcut key may not have a suffix key.')
  27. # Don't allow to save same shortcut key when prefix_key is null.
  28. try:
  29. context = self.context['request'].parser_context
  30. project_id = context['kwargs']['project_id']
  31. label_id = context['kwargs'].get('label_id')
  32. except (AttributeError, KeyError):
  33. pass # unit tests don't always have the correct context set up
  34. else:
  35. conflicting_labels = Label.objects.filter(
  36. suffix_key=suffix_key,
  37. prefix_key=prefix_key,
  38. project=project_id,
  39. )
  40. if label_id is not None:
  41. conflicting_labels = conflicting_labels.exclude(id=label_id)
  42. if conflicting_labels.exists():
  43. raise ValidationError('Duplicate shortcut key.')
  44. return super().validate(attrs)
  45. class Meta:
  46. model = Label
  47. fields = ('id', 'text', 'prefix_key', 'suffix_key', 'background_color', 'text_color')
  48. class CommentSerializer(serializers.ModelSerializer):
  49. class Meta:
  50. model = Comment
  51. fields = ('id', 'user', 'username', 'document', 'document_text', 'text', 'created_at', )
  52. read_only_fields = ('user', 'document')
  53. class DocumentSerializer(serializers.ModelSerializer):
  54. annotations = serializers.SerializerMethodField()
  55. annotation_approver = serializers.SerializerMethodField()
  56. def get_annotations(self, instance):
  57. request = self.context.get('request')
  58. project = instance.project
  59. model = project.get_annotation_class()
  60. serializer = project.get_annotation_serializer()
  61. annotations = model.objects.filter(document=instance.id)
  62. if request and not project.collaborative_annotation:
  63. annotations = annotations.filter(user=request.user)
  64. serializer = serializer(annotations, many=True)
  65. return serializer.data
  66. @classmethod
  67. def get_annotation_approver(cls, instance):
  68. approver = instance.annotations_approved_by
  69. return approver.username if approver else None
  70. class Meta:
  71. model = Document
  72. fields = ('id', 'text', 'annotations', 'meta', 'annotation_approver', 'comment_count')
  73. class ApproverSerializer(DocumentSerializer):
  74. class Meta:
  75. model = Document
  76. fields = ('id', 'annotation_approver')
  77. class ProjectSerializer(serializers.ModelSerializer):
  78. current_users_role = serializers.SerializerMethodField()
  79. def get_current_users_role(self, instance):
  80. role_abstractor = {
  81. "is_project_admin": settings.ROLE_PROJECT_ADMIN,
  82. "is_annotator": settings.ROLE_ANNOTATOR,
  83. "is_annotation_approver": settings.ROLE_ANNOTATION_APPROVER,
  84. }
  85. queryset = RoleMapping.objects.values("role_id__name")
  86. if queryset:
  87. users_role = get_object_or_404(
  88. queryset, project=instance.id, user=self.context.get("request").user.id
  89. )
  90. for key, val in role_abstractor.items():
  91. role_abstractor[key] = users_role["role_id__name"] == val
  92. return role_abstractor
  93. class Meta:
  94. model = Project
  95. fields = ('id', 'name', 'description', 'guideline', 'users', 'current_users_role', 'project_type',
  96. 'updated_at', 'randomize_document_order', 'collaborative_annotation', 'single_class_classification')
  97. read_only_fields = ('updated_at', 'users', 'current_users_role')
  98. class TextClassificationProjectSerializer(ProjectSerializer):
  99. class Meta:
  100. model = TextClassificationProject
  101. fields = ProjectSerializer.Meta.fields
  102. read_only_fields = ProjectSerializer.Meta.read_only_fields
  103. class SequenceLabelingProjectSerializer(ProjectSerializer):
  104. class Meta:
  105. model = SequenceLabelingProject
  106. fields = ProjectSerializer.Meta.fields
  107. read_only_fields = ProjectSerializer.Meta.read_only_fields
  108. class Seq2seqProjectSerializer(ProjectSerializer):
  109. class Meta:
  110. model = Seq2seqProject
  111. fields = ProjectSerializer.Meta.fields
  112. read_only_fields = ProjectSerializer.Meta.read_only_fields
  113. class Speech2textProjectSerializer(ProjectSerializer):
  114. class Meta:
  115. model = Speech2textProject
  116. fields = ('id', 'name', 'description', 'guideline', 'users', 'current_users_role', 'project_type',
  117. 'updated_at', 'randomize_document_order')
  118. read_only_fields = ('updated_at', 'users', 'current_users_role')
  119. class ProjectPolymorphicSerializer(PolymorphicSerializer):
  120. model_serializer_mapping = {
  121. Project: ProjectSerializer,
  122. TextClassificationProject: TextClassificationProjectSerializer,
  123. SequenceLabelingProject: SequenceLabelingProjectSerializer,
  124. Seq2seqProject: Seq2seqProjectSerializer,
  125. Speech2textProject: Speech2textProjectSerializer,
  126. }
  127. class ProjectFilteredPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField):
  128. def get_queryset(self):
  129. view = self.context.get('view', None)
  130. request = self.context.get('request', None)
  131. queryset = super(ProjectFilteredPrimaryKeyRelatedField, self).get_queryset()
  132. if not request or not queryset or not view:
  133. return None
  134. return queryset.filter(project=view.kwargs['project_id'])
  135. class DocumentAnnotationSerializer(serializers.ModelSerializer):
  136. # label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
  137. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  138. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  139. class Meta:
  140. model = DocumentAnnotation
  141. fields = ('id', 'prob', 'label', 'user', 'document', 'created_at', 'updated_at')
  142. read_only_fields = ('user', )
  143. class SequenceAnnotationSerializer(serializers.ModelSerializer):
  144. #label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
  145. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  146. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  147. class Meta:
  148. model = SequenceAnnotation
  149. fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user', 'document', 'created_at', 'updated_at')
  150. read_only_fields = ('user',)
  151. class Seq2seqAnnotationSerializer(serializers.ModelSerializer):
  152. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  153. class Meta:
  154. model = Seq2seqAnnotation
  155. fields = ('id', 'text', 'user', 'document', 'prob', 'created_at', 'updated_at')
  156. read_only_fields = ('user',)
  157. class Speech2textAnnotationSerializer(serializers.ModelSerializer):
  158. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  159. class Meta:
  160. model = Speech2textAnnotation
  161. fields = ('id', 'prob', 'text', 'user', 'document', 'created_at', 'updated_at')
  162. read_only_fields = ('user',)
  163. class RoleSerializer(serializers.ModelSerializer):
  164. class Meta:
  165. model = Role
  166. fields = ('id', 'name')
  167. class RoleMappingSerializer(serializers.ModelSerializer):
  168. username = serializers.SerializerMethodField()
  169. rolename = serializers.SerializerMethodField()
  170. @classmethod
  171. def get_username(cls, instance):
  172. user = instance.user
  173. return user.username if user else None
  174. @classmethod
  175. def get_rolename(cls, instance):
  176. role = instance.role
  177. return role.name if role else None
  178. class Meta:
  179. model = RoleMapping
  180. fields = ('id', 'user', 'role', 'username', 'rolename')
  181. class AutoLabelingConfigSerializer(serializers.ModelSerializer):
  182. class Meta:
  183. model = AutoLabelingConfig
  184. fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
  185. read_only_fields = ('created_at', 'updated_at')
  186. def validate_model_name(self, value):
  187. try:
  188. RequestModelFactory.find(value)
  189. except NameError:
  190. raise serializers.ValidationError(f'The specified model name {value} does not exist.')
  191. return value
  192. def valid_label_mapping(self, value):
  193. if isinstance(value, dict):
  194. return value
  195. else:
  196. raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
  197. def validate(self, data):
  198. try:
  199. RequestModelFactory.create(data['model_name'], data['model_attrs'])
  200. except Exception:
  201. model = RequestModelFactory.find(data['model_name'])
  202. schema = model.schema()
  203. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  204. raise serializers.ValidationError(
  205. 'The attributes does not match the model.'
  206. 'You need to correctly specify the required fields: {}'.format(required_fields)
  207. )
  208. return data