You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

350 lines
11 KiB

6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
3 years ago
6 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.conf import settings
  3. from django.contrib.auth import get_user_model
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import serializers
  6. from rest_framework.exceptions import ValidationError
  7. from rest_polymorphic.serializers import PolymorphicSerializer
  8. from .models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
  9. SPEECH2TEXT, AutoLabelingConfig, Category, Comment,
  10. Document, Example, Image, ImageClassificationProject,
  11. Label, Project, Role, RoleMapping, Seq2seqProject,
  12. SequenceLabelingProject, Span, Speech2textProject, Tag,
  13. TextClassificationProject, TextLabel)
  class UserSerializer(serializers.ModelSerializer):
      """Serialize the id, username and superuser flag of the active user model."""

      class Meta:
          # get_user_model() returns whichever user model the project has configured.
          model = get_user_model()
          fields = (
              'id',
              'username',
              'is_superuser',
          )
  class LabelSerializer(serializers.ModelSerializer):
      """Serialize ``Label`` objects and validate their keyboard shortcut."""

      class Meta:
          model = Label
          fields = (
              'id',
              'text',
              'prefix_key',
              'suffix_key',
              'background_color',
              'text_color',
          )

      def validate(self, attrs):
          """Validate the shortcut (``prefix_key`` + ``suffix_key``) of a label.

          A payload with neither key set is passed straight to the parent
          validator. Otherwise a prefix without a suffix is rejected, and the
          combination must not already be used by another label of the project.

          Raises:
              ValidationError: if the shortcut is incomplete or duplicated.
          """
          prefix = attrs.get('prefix_key')
          suffix = attrs.get('suffix_key')

          shortcut_given = not (prefix is None and suffix is None)
          if shortcut_given:
              # A modifier (prefix) on its own is not a usable shortcut.
              if prefix and not suffix:
                  raise ValidationError('Shortcut key may not have a suffix key.')
              self._reject_duplicate_shortcut(prefix, suffix)

          return super().validate(attrs)

      def _reject_duplicate_shortcut(self, prefix_key, suffix_key):
          """Raise ``ValidationError`` if another label in the project uses the shortcut."""
          try:
              url_kwargs = self.context['request'].parser_context['kwargs']
              project_id = url_kwargs['project_id']
              label_id = url_kwargs.get('label_id')
          except (AttributeError, KeyError):
              # Unit tests don't always have the correct context set up.
              return

          clashes = Label.objects.filter(
              suffix_key=suffix_key,
              prefix_key=prefix_key,
              project=project_id,
          )
          # When updating an existing label, it must not clash with itself.
          if label_id is not None:
              clashes = clashes.exclude(id=label_id)
          if clashes.exists():
              raise ValidationError('Duplicate shortcut key.')
  class CommentSerializer(serializers.ModelSerializer):
      """Serialize ``Comment`` objects; ``user`` and ``example`` are read-only."""

      class Meta:
          model = Comment
          fields = (
              'id',
              'user',
              'username',
              'example',
              'text',
              'created_at',
          )
          read_only_fields = (
              'user',
              'example',
          )
  class TagSerializer(serializers.ModelSerializer):
      """Serialize ``Tag`` objects; only ``text`` is writable."""

      class Meta:
          model = Tag
          fields = (
              'id',
              'project',
              'text',
          )
          read_only_fields = (
              'id',
              'project',
          )
  class BaseDataSerializer(serializers.ModelSerializer):
      """Base serializer for ``Example`` rows, including annotations and approver."""

      annotations = serializers.SerializerMethodField()
      annotation_approver = serializers.SerializerMethodField()

      class Meta:
          model = Example
          # Kept as a list: subclasses in this module extend it with ``+``.
          fields = [
              'id',
              'filename',
              'annotations',
              'meta',
              'annotation_approver',
              'comment_count',
          ]
          read_only_fields = ['filename']

      def get_annotations(self, instance):
          """Return the serialized annotations attached to ``instance``.

          The annotation model and serializer are chosen from the example's
          project. Unless the project allows collaborative annotation, only the
          requesting user's annotations are returned (when a request is present
          in the serializer context).
          """
          project = instance.project
          annotation_model = project.get_annotation_class()
          serializer_class = get_annotation_serializer(task=project.project_type)

          queryset = annotation_model.objects.filter(example=instance.id)
          request = self.context.get('request')
          if request and not project.collaborative_annotation:
              queryset = queryset.filter(user=request.user)

          return serializer_class(queryset, many=True).data

      @classmethod
      def get_annotation_approver(cls, instance):
          """Return the username of the approver, or ``None`` if there is none."""
          approved_by = instance.annotations_approved_by
          if approved_by:
              return approved_by.username
          return None
  class DocumentSerializer(BaseDataSerializer):
      """Serialize ``Document`` examples: the base fields plus ``text``."""

      class Meta:
          model = Document
          fields = [*BaseDataSerializer.Meta.fields, 'text']
  class ImageSerializer(BaseDataSerializer):
      """Serialize ``Image`` examples using exactly the base field list."""

      class Meta:
          model = Image
          # Same list object as the base serializer; no extra fields for images.
          fields = BaseDataSerializer.Meta.fields
  class ExampleSerializer(PolymorphicSerializer):
      """Dispatch each example to the serializer registered for its model class."""

      model_serializer_mapping = {
          Example: BaseDataSerializer,
          Document: DocumentSerializer,
          Image: ImageSerializer,
      }
  class ApproverSerializer(DocumentSerializer):
      """Serialize only the id and approver username of a ``Document``."""

      class Meta:
          model = Document
          fields = (
              'id',
              'annotation_approver',
          )
  class ProjectSerializer(serializers.ModelSerializer):
      """Serialize ``Project`` objects together with the requesting user's role."""

      current_users_role = serializers.SerializerMethodField()
      tags = TagSerializer(many=True, required=False)

      def get_current_users_role(self, instance):
          """Return boolean role flags of the requesting user for ``instance``.

          Returns:
              dict: keys ``is_project_admin``, ``is_annotator`` and
              ``is_annotation_approver``; each value tells whether the user's
              role name on this project equals the corresponding setting.

          Raises:
              Http404: (via ``get_object_or_404``) if role mappings exist but
              none matches this project and user.
          """
          role_names = {
              "is_project_admin": settings.ROLE_PROJECT_ADMIN,
              "is_annotator": settings.ROLE_ANNOTATOR,
              "is_annotation_approver": settings.ROLE_ANNOTATION_APPROVER,
          }
          role_mappings = RoleMapping.objects.values("role_id__name")

          # ``exists()`` asks the database whether any row is present instead of
          # loading every role mapping into memory just to test for emptiness.
          if not role_mappings.exists():
              # NOTE(review): with no role mappings at all, the raw role-name
              # settings are returned instead of booleans. This is preserved
              # from the previous implementation; confirm it is intended.
              return role_names

          users_role = get_object_or_404(
              role_mappings,
              project=instance.id,
              user=self.context.get("request").user.id,
          )
          current_role_name = users_role["role_id__name"]
          return {
              flag: current_role_name == role_name
              for flag, role_name in role_names.items()
          }

      class Meta:
          model = Project
          fields = (
              'id',
              'name',
              'description',
              'guideline',
              'users',
              'current_users_role',
              'project_type',
              'updated_at',
              'random_order',
              'collaborative_annotation',
              'single_class_classification',
              'tags',
          )
          read_only_fields = (
              'updated_at',
              'users',
              'current_users_role',
              'tags',
          )
  class TextClassificationProjectSerializer(ProjectSerializer):
      """Project serializer bound to ``TextClassificationProject``."""

      class Meta(ProjectSerializer.Meta):
          # ``fields`` and ``read_only_fields`` are inherited from the parent Meta.
          model = TextClassificationProject
  class SequenceLabelingProjectSerializer(ProjectSerializer):
      """Project serializer bound to ``SequenceLabelingProject``."""

      class Meta(ProjectSerializer.Meta):
          # ``fields`` and ``read_only_fields`` are inherited from the parent Meta.
          model = SequenceLabelingProject
  class Seq2seqProjectSerializer(ProjectSerializer):
      """Project serializer bound to ``Seq2seqProject``."""

      class Meta(ProjectSerializer.Meta):
          # ``fields`` and ``read_only_fields`` are inherited from the parent Meta.
          model = Seq2seqProject
  class Speech2textProjectSerializer(ProjectSerializer):
      """Project serializer bound to ``Speech2textProject`` with its own field list."""

      class Meta:
          model = Speech2textProject
          # Spelled out explicitly: this list is shorter than the generic
          # project field list.
          fields = (
              'id',
              'name',
              'description',
              'guideline',
              'users',
              'current_users_role',
              'project_type',
              'updated_at',
              'random_order',
          )
          read_only_fields = (
              'updated_at',
              'users',
              'current_users_role',
          )
  class ImageClassificationProjectSerializer(ProjectSerializer):
      """Project serializer bound to ``ImageClassificationProject``."""

      class Meta(ProjectSerializer.Meta):
          # ``fields`` and ``read_only_fields`` are inherited from the parent Meta.
          model = ImageClassificationProject
  class ProjectPolymorphicSerializer(PolymorphicSerializer):
      """Dispatch each project to the serializer registered for its model class."""

      model_serializer_mapping = {
          Project: ProjectSerializer,
          TextClassificationProject: TextClassificationProjectSerializer,
          SequenceLabelingProject: SequenceLabelingProjectSerializer,
          Seq2seqProject: Seq2seqProjectSerializer,
          Speech2textProject: Speech2textProjectSerializer,
          ImageClassificationProject: ImageClassificationProjectSerializer,
      }
  class ProjectFilteredPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField):
      """Primary-key related field whose choices are limited to the current project."""

      def get_queryset(self):
          """Return the field's queryset filtered by the view's ``project_id``.

          Returns:
              The filtered queryset, or ``None`` when the field has no queryset
              or the serializer context lacks a request or a view.

          Raises:
              KeyError: if the view's URL kwargs contain no ``project_id``.
          """
          queryset = super().get_queryset()
          request = self.context.get('request')
          view = self.context.get('view')

          # Compare against None rather than testing truthiness: ``not queryset``
          # would run the query and fetch every row, and would return None for
          # a merely empty table, leaving nothing to look up against.
          if queryset is None or not request or not view:
              return None
          return queryset.filter(project=view.kwargs['project_id'])
  class CategorySerializer(serializers.ModelSerializer):
      """Serialize ``Category`` annotations (a label attached to an example)."""

      # Both relations are written as primary keys.
      label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
      example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

      class Meta:
          model = Category
          fields = ('id', 'prob', 'user', 'example', 'created_at', 'updated_at', 'label')
          read_only_fields = ('user',)
  class SpanSerializer(serializers.ModelSerializer):
      """Serialize ``Span`` annotations (a label plus start/end offsets)."""

      # Both relations are written as primary keys.
      label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
      example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

      class Meta:
          model = Span
          fields = (
              'id', 'prob', 'user', 'example', 'created_at', 'updated_at',
              'label', 'start_offset', 'end_offset',
          )
          read_only_fields = ('user',)
  class TextLabelSerializer(serializers.ModelSerializer):
      """Serialize ``TextLabel`` annotations (free text attached to an example)."""

      # The example relation is written as a primary key.
      example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

      class Meta:
          model = TextLabel
          fields = ('id', 'prob', 'user', 'example', 'created_at', 'updated_at', 'text')
          read_only_fields = ('user',)
  class RoleSerializer(serializers.ModelSerializer):
      """Serialize the id and name of a ``Role``."""

      class Meta:
          model = Role
          fields = (
              'id',
              'name',
          )
  class RoleMappingSerializer(serializers.ModelSerializer):
      """Serialize ``RoleMapping`` rows with the user and role names included."""

      username = serializers.SerializerMethodField()
      rolename = serializers.SerializerMethodField()

      class Meta:
          model = RoleMapping
          fields = (
              'id',
              'user',
              'role',
              'username',
              'rolename',
          )

      @classmethod
      def get_username(cls, instance):
          """Return the mapped user's username, or ``None`` if no user is set."""
          mapped_user = instance.user
          if mapped_user:
              return mapped_user.username
          return None

      @classmethod
      def get_rolename(cls, instance):
          """Return the mapped role's name, or ``None`` if no role is set."""
          mapped_role = instance.role
          if mapped_role:
              return mapped_role.name
          return None
  class AutoLabelingConfigSerializer(serializers.ModelSerializer):
      """Serialize and validate ``AutoLabelingConfig`` objects."""

      class Meta:
          model = AutoLabelingConfig
          fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
          read_only_fields = ('created_at', 'updated_at')

      def validate_model_name(self, value):
          """Ensure ``value`` names a model known to ``RequestModelFactory``.

          Raises:
              serializers.ValidationError: if the factory lookup raises ``NameError``.
          """
          try:
              RequestModelFactory.find(value)
          except NameError as err:
              raise serializers.ValidationError(
                  f'The specified model name {value} does not exist.'
              ) from err
          return value

      def validate_label_mapping(self, value):
          """Ensure the label mapping is a dictionary.

          Named ``validate_<field_name>`` so that the serializer runs it as the
          field-level validator for ``label_mapping``.

          Raises:
              serializers.ValidationError: if ``value`` is not a ``dict``.
          """
          if not isinstance(value, dict):
              raise serializers.ValidationError(
                  f'The {value} is not a dictionary. Please specify it as a dictionary.'
              )
          return value

      # Backward-compatible alias: the validator used to be published under this
      # name, which the serializer never invoked automatically.
      valid_label_mapping = validate_label_mapping

      def validate(self, data):
          """Check that ``model_attrs`` can build the model named by ``model_name``.

          Raises:
              serializers.ValidationError: if the model cannot be created from
              the attributes; the message lists the fields the model's schema
              marks as required.
          """
          try:
              RequestModelFactory.create(data['model_name'], data['model_attrs'])
          except Exception as err:
              # Any construction failure is reported as a validation error that
              # names the required fields from the model schema.
              model = RequestModelFactory.find(data['model_name'])
              schema = model.schema()
              required_fields = ', '.join(schema.get('required', []))
              raise serializers.ValidationError(
                  'The attributes does not match the model. '
                  'You need to correctly specify the required fields: {}'.format(required_fields)
              ) from err
          return data
  def get_annotation_serializer(task: str):
      """Return the annotation serializer class registered for ``task``.

      Raises:
          ValueError: if no serializer is registered for ``task``.
      """
      serializers_by_task = {
          DOCUMENT_CLASSIFICATION: CategorySerializer,
          SEQUENCE_LABELING: SpanSerializer,
          # Both text-producing tasks share the same serializer.
          SEQ2SEQ: TextLabelSerializer,
          SPEECH2TEXT: TextLabelSerializer,
      }
      try:
          serializer_class = serializers_by_task[task]
      except KeyError:
          raise ValueError(f'{task} is not implemented.')
      return serializer_class