You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

350 lines
10 KiB

6 years ago
5 years ago
5 years ago
3 years ago
3 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
3 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.conf import settings
  3. from django.contrib.auth import get_user_model
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import serializers
  6. from rest_framework.exceptions import ValidationError
  7. from rest_polymorphic.serializers import PolymorphicSerializer
  8. from .models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
  9. SEQUENCE_LABELING, SPEECH2TEXT, AnnotationRelations,
  10. AutoLabelingConfig, Category, Comment, Example,
  11. ImageClassificationProject, Label, Project, RelationTypes,
  12. Role, RoleMapping, Seq2seqProject,
  13. SequenceLabelingProject, Span, Speech2textProject, Tag,
  14. TextClassificationProject, TextLabel)
  15. class UserSerializer(serializers.ModelSerializer):
  16. class Meta:
  17. model = get_user_model()
  18. fields = ('id', 'username', 'is_superuser')
  19. class LabelSerializer(serializers.ModelSerializer):
  20. def validate(self, attrs):
  21. prefix_key = attrs.get('prefix_key')
  22. suffix_key = attrs.get('suffix_key')
  23. # In the case of user don't set any shortcut key.
  24. if prefix_key is None and suffix_key is None:
  25. return super().validate(attrs)
  26. # Don't allow shortcut key not to have a suffix key.
  27. if prefix_key and not suffix_key:
  28. raise ValidationError('Shortcut key may not have a suffix key.')
  29. # Don't allow to save same shortcut key when prefix_key is null.
  30. try:
  31. context = self.context['request'].parser_context
  32. project_id = context['kwargs']['project_id']
  33. label_id = context['kwargs'].get('label_id')
  34. except (AttributeError, KeyError):
  35. pass # unit tests don't always have the correct context set up
  36. else:
  37. conflicting_labels = Label.objects.filter(
  38. suffix_key=suffix_key,
  39. prefix_key=prefix_key,
  40. project=project_id,
  41. )
  42. if label_id is not None:
  43. conflicting_labels = conflicting_labels.exclude(id=label_id)
  44. if conflicting_labels.exists():
  45. raise ValidationError('Duplicate shortcut key.')
  46. return super().validate(attrs)
  47. class Meta:
  48. model = Label
  49. fields = ('id', 'text', 'prefix_key', 'suffix_key', 'background_color', 'text_color')
  50. class CommentSerializer(serializers.ModelSerializer):
  51. class Meta:
  52. model = Comment
  53. fields = ('id', 'user', 'username', 'example', 'text', 'created_at', )
  54. read_only_fields = ('user', 'example')
  55. class TagSerializer(serializers.ModelSerializer):
  56. class Meta:
  57. model = Tag
  58. fields = ('id', 'project', 'text', )
  59. read_only_fields = ('id', 'project')
  60. class ExampleSerializer(serializers.ModelSerializer):
  61. annotations = serializers.SerializerMethodField()
  62. annotation_approver = serializers.SerializerMethodField()
  63. def get_annotations(self, instance):
  64. request = self.context.get('request')
  65. project = instance.project
  66. model = project.get_annotation_class()
  67. serializer = get_annotation_serializer(task=project.project_type)
  68. annotations = model.objects.filter(example=instance.id)
  69. if request and not project.collaborative_annotation:
  70. annotations = annotations.filter(user=request.user)
  71. serializer = serializer(annotations, many=True)
  72. return serializer.data
  73. @classmethod
  74. def get_annotation_approver(cls, instance):
  75. approver = instance.annotations_approved_by
  76. return approver.username if approver else None
  77. class Meta:
  78. model = Example
  79. fields = [
  80. 'id',
  81. 'filename',
  82. 'annotations',
  83. 'meta',
  84. 'annotation_approver',
  85. 'comment_count',
  86. 'text'
  87. ]
  88. read_only_fields = ['filename']
  89. class ApproverSerializer(ExampleSerializer):
  90. class Meta:
  91. model = Example
  92. fields = ('id', 'annotation_approver')
  93. class ProjectSerializer(serializers.ModelSerializer):
  94. current_users_role = serializers.SerializerMethodField()
  95. tags = TagSerializer(many=True, required=False)
  96. def get_current_users_role(self, instance):
  97. role_abstractor = {
  98. "is_project_admin": settings.ROLE_PROJECT_ADMIN,
  99. "is_annotator": settings.ROLE_ANNOTATOR,
  100. "is_annotation_approver": settings.ROLE_ANNOTATION_APPROVER,
  101. }
  102. queryset = RoleMapping.objects.values("role_id__name")
  103. if queryset:
  104. users_role = get_object_or_404(
  105. queryset, project=instance.id, user=self.context.get("request").user.id
  106. )
  107. for key, val in role_abstractor.items():
  108. role_abstractor[key] = users_role["role_id__name"] == val
  109. return role_abstractor
  110. class Meta:
  111. model = Project
  112. fields = (
  113. 'id',
  114. 'name',
  115. 'description',
  116. 'guideline',
  117. 'users',
  118. 'current_users_role',
  119. 'project_type',
  120. 'updated_at',
  121. 'random_order',
  122. 'collaborative_annotation',
  123. 'single_class_classification',
  124. 'tags'
  125. )
  126. read_only_fields = (
  127. 'updated_at',
  128. 'users',
  129. 'current_users_role',
  130. 'tags'
  131. )
  132. class TextClassificationProjectSerializer(ProjectSerializer):
  133. class Meta(ProjectSerializer.Meta):
  134. model = TextClassificationProject
  135. class SequenceLabelingProjectSerializer(ProjectSerializer):
  136. class Meta(ProjectSerializer.Meta):
  137. model = SequenceLabelingProject
  138. class Seq2seqProjectSerializer(ProjectSerializer):
  139. class Meta(ProjectSerializer.Meta):
  140. model = Seq2seqProject
  141. class Speech2textProjectSerializer(ProjectSerializer):
  142. class Meta(ProjectSerializer.Meta):
  143. model = Speech2textProject
  144. class ImageClassificationProjectSerializer(ProjectSerializer):
  145. class Meta(ProjectSerializer.Meta):
  146. model = ImageClassificationProject
  147. class ProjectPolymorphicSerializer(PolymorphicSerializer):
  148. model_serializer_mapping = {
  149. Project: ProjectSerializer,
  150. **{
  151. cls.Meta.model: cls for cls in ProjectSerializer.__subclasses__()
  152. }
  153. }
  154. class CategorySerializer(serializers.ModelSerializer):
  155. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  156. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  157. class Meta:
  158. model = Category
  159. fields = (
  160. 'id',
  161. 'prob',
  162. 'user',
  163. 'example',
  164. 'created_at',
  165. 'updated_at',
  166. 'label',
  167. )
  168. read_only_fields = ('user',)
  169. class SpanSerializer(serializers.ModelSerializer):
  170. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  171. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  172. class Meta:
  173. model = Span
  174. fields = (
  175. 'id',
  176. 'prob',
  177. 'user',
  178. 'example',
  179. 'created_at',
  180. 'updated_at',
  181. 'label',
  182. 'start_offset',
  183. 'end_offset',
  184. )
  185. read_only_fields = ('user',)
  186. class TextLabelSerializer(serializers.ModelSerializer):
  187. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  188. class Meta:
  189. model = TextLabel
  190. fields = (
  191. 'id',
  192. 'prob',
  193. 'user',
  194. 'example',
  195. 'created_at',
  196. 'updated_at',
  197. 'text',
  198. )
  199. read_only_fields = ('user',)
  200. class RoleSerializer(serializers.ModelSerializer):
  201. class Meta:
  202. model = Role
  203. fields = ('id', 'name')
  204. class RoleMappingSerializer(serializers.ModelSerializer):
  205. username = serializers.SerializerMethodField()
  206. rolename = serializers.SerializerMethodField()
  207. @classmethod
  208. def get_username(cls, instance):
  209. user = instance.user
  210. return user.username if user else None
  211. @classmethod
  212. def get_rolename(cls, instance):
  213. role = instance.role
  214. return role.name if role else None
  215. class Meta:
  216. model = RoleMapping
  217. fields = ('id', 'user', 'role', 'username', 'rolename')
  218. class AutoLabelingConfigSerializer(serializers.ModelSerializer):
  219. class Meta:
  220. model = AutoLabelingConfig
  221. fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
  222. read_only_fields = ('created_at', 'updated_at')
  223. def validate_model_name(self, value):
  224. try:
  225. RequestModelFactory.find(value)
  226. except NameError:
  227. raise serializers.ValidationError(f'The specified model name {value} does not exist.')
  228. return value
  229. def valid_label_mapping(self, value):
  230. if isinstance(value, dict):
  231. return value
  232. else:
  233. raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
  234. def validate(self, data):
  235. try:
  236. RequestModelFactory.create(data['model_name'], data['model_attrs'])
  237. except Exception:
  238. model = RequestModelFactory.find(data['model_name'])
  239. schema = model.schema()
  240. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  241. raise serializers.ValidationError(
  242. 'The attributes does not match the model.'
  243. 'You need to correctly specify the required fields: {}'.format(required_fields)
  244. )
  245. return data
  246. def get_annotation_serializer(task: str):
  247. mapping = {
  248. DOCUMENT_CLASSIFICATION: CategorySerializer,
  249. SEQUENCE_LABELING: SpanSerializer,
  250. SEQ2SEQ: TextLabelSerializer,
  251. SPEECH2TEXT: TextLabelSerializer,
  252. IMAGE_CLASSIFICATION: CategorySerializer,
  253. }
  254. try:
  255. return mapping[task]
  256. except KeyError:
  257. raise ValueError(f'{task} is not implemented.')
  258. class RelationTypesSerializer(serializers.ModelSerializer):
  259. def validate(self, attrs):
  260. return super().validate(attrs)
  261. class Meta:
  262. model = RelationTypes
  263. fields = ('id', 'color', 'name')
  264. class AnnotationRelationsSerializer(serializers.ModelSerializer):
  265. def validate(self, attrs):
  266. return super().validate(attrs)
  267. class Meta:
  268. model = AnnotationRelations
  269. fields = ('id', 'annotation_id_1', 'annotation_id_2', 'type', 'user', 'timestamp')