You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

272 lines
9.8 KiB

6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.conf import settings
  3. from django.contrib.auth import get_user_model
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import serializers
  6. from rest_polymorphic.serializers import PolymorphicSerializer
  7. from rest_framework.exceptions import ValidationError
  8. from .models import Label, Project, Document, RoleMapping, Role, Comment, AutoLabelingConfig
  9. from .models import TextClassificationProject, SequenceLabelingProject, Seq2seqProject, Speech2textProject
  10. from .models import DocumentAnnotation, SequenceAnnotation, Seq2seqAnnotation, Speech2textAnnotation
  11. class UserSerializer(serializers.ModelSerializer):
  12. class Meta:
  13. model = get_user_model()
  14. fields = ('id', 'username', 'first_name', 'last_name', 'email', 'is_superuser')
  15. class LabelSerializer(serializers.ModelSerializer):
  16. def validate(self, attrs):
  17. prefix_key = attrs.get('prefix_key')
  18. suffix_key = attrs.get('suffix_key')
  19. # In the case of user don't set any shortcut key.
  20. if prefix_key is None and suffix_key is None:
  21. return super().validate(attrs)
  22. # Don't allow shortcut key not to have a suffix key.
  23. if prefix_key and not suffix_key:
  24. raise ValidationError('Shortcut key may not have a suffix key.')
  25. # Don't allow to save same shortcut key when prefix_key is null.
  26. try:
  27. context = self.context['request'].parser_context
  28. project_id = context['kwargs']['project_id']
  29. label_id = context['kwargs'].get('label_id')
  30. except (AttributeError, KeyError):
  31. pass # unit tests don't always have the correct context set up
  32. else:
  33. conflicting_labels = Label.objects.filter(
  34. suffix_key=suffix_key,
  35. prefix_key=prefix_key,
  36. project=project_id,
  37. )
  38. if label_id is not None:
  39. conflicting_labels = conflicting_labels.exclude(id=label_id)
  40. if conflicting_labels.exists():
  41. raise ValidationError('Duplicate shortcut key.')
  42. return super().validate(attrs)
  43. class Meta:
  44. model = Label
  45. fields = ('id', 'text', 'prefix_key', 'suffix_key', 'background_color', 'text_color')
  46. class CommentSerializer(serializers.ModelSerializer):
  47. class Meta:
  48. model = Comment
  49. fields = ('id', 'text', 'user')
  50. read_only_fields = ('user', 'document')
  51. class DocumentSerializer(serializers.ModelSerializer):
  52. annotations = serializers.SerializerMethodField()
  53. annotation_approver = serializers.SerializerMethodField()
  54. comments = CommentSerializer(many=True, read_only=True)
  55. def get_annotations(self, instance):
  56. request = self.context.get('request')
  57. project = instance.project
  58. model = project.get_annotation_class()
  59. serializer = project.get_annotation_serializer()
  60. annotations = model.objects.filter(document=instance.id)
  61. if request and not project.collaborative_annotation:
  62. annotations = annotations.filter(user=request.user)
  63. serializer = serializer(annotations, many=True)
  64. return serializer.data
  65. @classmethod
  66. def get_annotation_approver(cls, instance):
  67. approver = instance.annotations_approved_by
  68. return approver.username if approver else None
  69. class Meta:
  70. model = Document
  71. fields = ('id', 'text', 'annotations', 'meta', 'annotation_approver', 'comments')
  72. class ApproverSerializer(DocumentSerializer):
  73. class Meta:
  74. model = Document
  75. fields = ('id', 'annotation_approver')
  76. class ProjectSerializer(serializers.ModelSerializer):
  77. current_users_role = serializers.SerializerMethodField()
  78. def get_current_users_role(self, instance):
  79. role_abstractor = {
  80. "is_project_admin": settings.ROLE_PROJECT_ADMIN,
  81. "is_annotator": settings.ROLE_ANNOTATOR,
  82. "is_annotation_approver": settings.ROLE_ANNOTATION_APPROVER,
  83. }
  84. queryset = RoleMapping.objects.values("role_id__name")
  85. if queryset:
  86. users_role = get_object_or_404(
  87. queryset, project=instance.id, user=self.context.get("request").user.id
  88. )
  89. for key, val in role_abstractor.items():
  90. role_abstractor[key] = users_role["role_id__name"] == val
  91. return role_abstractor
  92. class Meta:
  93. model = Project
  94. fields = ('id', 'name', 'description', 'guideline', 'users', 'current_users_role', 'project_type',
  95. 'updated_at', 'randomize_document_order', 'collaborative_annotation', 'single_class_classification')
  96. read_only_fields = ('updated_at', 'users', 'current_users_role')
  97. class TextClassificationProjectSerializer(ProjectSerializer):
  98. class Meta:
  99. model = TextClassificationProject
  100. fields = ProjectSerializer.Meta.fields
  101. read_only_fields = ProjectSerializer.Meta.read_only_fields
  102. class SequenceLabelingProjectSerializer(ProjectSerializer):
  103. class Meta:
  104. model = SequenceLabelingProject
  105. fields = ProjectSerializer.Meta.fields
  106. read_only_fields = ProjectSerializer.Meta.read_only_fields
  107. class Seq2seqProjectSerializer(ProjectSerializer):
  108. class Meta:
  109. model = Seq2seqProject
  110. fields = ProjectSerializer.Meta.fields
  111. read_only_fields = ProjectSerializer.Meta.read_only_fields
  112. class Speech2textProjectSerializer(ProjectSerializer):
  113. class Meta:
  114. model = Speech2textProject
  115. fields = ('id', 'name', 'description', 'guideline', 'users', 'current_users_role', 'project_type',
  116. 'updated_at', 'randomize_document_order')
  117. read_only_fields = ('updated_at', 'users', 'current_users_role')
  118. class ProjectPolymorphicSerializer(PolymorphicSerializer):
  119. model_serializer_mapping = {
  120. Project: ProjectSerializer,
  121. TextClassificationProject: TextClassificationProjectSerializer,
  122. SequenceLabelingProject: SequenceLabelingProjectSerializer,
  123. Seq2seqProject: Seq2seqProjectSerializer,
  124. Speech2textProject: Speech2textProjectSerializer,
  125. }
  126. class ProjectFilteredPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField):
  127. def get_queryset(self):
  128. view = self.context.get('view', None)
  129. request = self.context.get('request', None)
  130. queryset = super(ProjectFilteredPrimaryKeyRelatedField, self).get_queryset()
  131. if not request or not queryset or not view:
  132. return None
  133. return queryset.filter(project=view.kwargs['project_id'])
  134. class DocumentAnnotationSerializer(serializers.ModelSerializer):
  135. # label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
  136. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  137. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  138. class Meta:
  139. model = DocumentAnnotation
  140. fields = ('id', 'prob', 'label', 'user', 'document', 'created_at', 'updated_at')
  141. read_only_fields = ('user', )
  142. class SequenceAnnotationSerializer(serializers.ModelSerializer):
  143. #label = ProjectFilteredPrimaryKeyRelatedField(queryset=Label.objects.all())
  144. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  145. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  146. class Meta:
  147. model = SequenceAnnotation
  148. fields = ('id', 'prob', 'label', 'start_offset', 'end_offset', 'user', 'document', 'created_at', 'updated_at')
  149. read_only_fields = ('user',)
  150. class Seq2seqAnnotationSerializer(serializers.ModelSerializer):
  151. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  152. class Meta:
  153. model = Seq2seqAnnotation
  154. fields = ('id', 'text', 'user', 'document', 'prob', 'created_at', 'updated_at')
  155. read_only_fields = ('user',)
  156. class Speech2textAnnotationSerializer(serializers.ModelSerializer):
  157. document = serializers.PrimaryKeyRelatedField(queryset=Document.objects.all())
  158. class Meta:
  159. model = Speech2textAnnotation
  160. fields = ('id', 'prob', 'text', 'user', 'document', 'created_at', 'updated_at')
  161. read_only_fields = ('user',)
  162. class RoleSerializer(serializers.ModelSerializer):
  163. class Meta:
  164. model = Role
  165. fields = ('id', 'name')
  166. class RoleMappingSerializer(serializers.ModelSerializer):
  167. username = serializers.SerializerMethodField()
  168. rolename = serializers.SerializerMethodField()
  169. @classmethod
  170. def get_username(cls, instance):
  171. user = instance.user
  172. return user.username if user else None
  173. @classmethod
  174. def get_rolename(cls, instance):
  175. role = instance.role
  176. return role.name if role else None
  177. class Meta:
  178. model = RoleMapping
  179. fields = ('id', 'user', 'role', 'username', 'rolename')
  180. class AutoLabelingConfigSerializer(serializers.ModelSerializer):
  181. class Meta:
  182. model = AutoLabelingConfig
  183. fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
  184. read_only_fields = ('created_at', 'updated_at')
  185. def validate_model_name(self, value):
  186. try:
  187. RequestModelFactory.find(value)
  188. except NameError:
  189. raise serializers.ValidationError(f'The specified model name {value} does not exist.')
  190. return value
  191. def valid_label_mapping(self, value):
  192. if isinstance(value, dict):
  193. return value
  194. else:
  195. raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
  196. def validate(self, data):
  197. try:
  198. RequestModelFactory.create(data['model_name'], data['model_attrs'])
  199. except Exception:
  200. raise serializers.ValidationError('The attributes does not match the model.')
  201. return data