You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

371 lines
12 KiB

6 years ago
5 years ago
5 years ago
3 years ago
3 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
4 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.conf import settings
  3. from django.contrib.auth import get_user_model
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import serializers
  6. from rest_framework.exceptions import ValidationError
  7. from rest_polymorphic.serializers import PolymorphicSerializer
  8. from .models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
  9. SEQUENCE_LABELING, SPEECH2TEXT, AnnotationRelations,
  10. AutoLabelingConfig, Category, Comment, Example,
  11. ExampleState, ImageClassificationProject, Label, Project,
  12. RelationTypes, Role, RoleMapping, Seq2seqProject,
  13. SequenceLabelingProject, Span, Speech2textProject, Tag,
  14. TextClassificationProject, TextLabel)
  15. class UserSerializer(serializers.ModelSerializer):
  16. class Meta:
  17. model = get_user_model()
  18. fields = ('id', 'username', 'is_superuser', 'is_staff')
  19. class LabelSerializer(serializers.ModelSerializer):
  20. def validate(self, attrs):
  21. prefix_key = attrs.get('prefix_key')
  22. suffix_key = attrs.get('suffix_key')
  23. # In the case of user don't set any shortcut key.
  24. if prefix_key is None and suffix_key is None:
  25. return super().validate(attrs)
  26. # Don't allow shortcut key not to have a suffix key.
  27. if prefix_key and not suffix_key:
  28. raise ValidationError('Shortcut key may not have a suffix key.')
  29. # Don't allow to save same shortcut key when prefix_key is null.
  30. try:
  31. context = self.context['request'].parser_context
  32. project_id = context['kwargs']['project_id']
  33. label_id = context['kwargs'].get('label_id')
  34. except (AttributeError, KeyError):
  35. pass # unit tests don't always have the correct context set up
  36. else:
  37. conflicting_labels = Label.objects.filter(
  38. suffix_key=suffix_key,
  39. prefix_key=prefix_key,
  40. project=project_id,
  41. )
  42. if label_id is not None:
  43. conflicting_labels = conflicting_labels.exclude(id=label_id)
  44. if conflicting_labels.exists():
  45. raise ValidationError('Duplicate shortcut key.')
  46. return super().validate(attrs)
  47. class Meta:
  48. model = Label
  49. fields = ('id', 'text', 'prefix_key', 'suffix_key', 'background_color', 'text_color')
  50. class CommentSerializer(serializers.ModelSerializer):
  51. class Meta:
  52. model = Comment
  53. fields = ('id', 'user', 'username', 'example', 'text', 'created_at', )
  54. read_only_fields = ('user', 'example')
  55. class TagSerializer(serializers.ModelSerializer):
  56. class Meta:
  57. model = Tag
  58. fields = ('id', 'project', 'text', )
  59. read_only_fields = ('id', 'project')
  60. class ExampleSerializer(serializers.ModelSerializer):
  61. annotations = serializers.SerializerMethodField()
  62. annotation_approver = serializers.SerializerMethodField()
  63. is_confirmed = serializers.SerializerMethodField()
  64. def get_annotations(self, instance):
  65. request = self.context.get('request')
  66. project = instance.project
  67. model = project.get_annotation_class()
  68. serializer = get_annotation_serializer(task=project.project_type)
  69. annotations = model.objects.filter(example=instance.id)
  70. if request and not project.collaborative_annotation:
  71. annotations = annotations.filter(user=request.user)
  72. serializer = serializer(annotations, many=True)
  73. return serializer.data
  74. @classmethod
  75. def get_annotation_approver(cls, instance):
  76. approver = instance.annotations_approved_by
  77. return approver.username if approver else None
  78. def get_is_confirmed(self, instance):
  79. user = self.context.get('request').user
  80. if instance.project.collaborative_annotation:
  81. current_user_role = RoleMapping.objects.get(user_id=user.id, project_id=instance.project.id).role
  82. state_ids = [state.id for state in instance.states.all() if state.confirmed_user_role == current_user_role]
  83. states = instance.states.filter(id__in=state_ids)
  84. else:
  85. states = instance.states.filter(confirmed_by_id=user.id)
  86. return states.count() > 0
  87. class Meta:
  88. model = Example
  89. fields = [
  90. 'id',
  91. 'filename',
  92. 'annotations',
  93. 'meta',
  94. 'annotation_approver',
  95. 'comment_count',
  96. 'text',
  97. 'is_confirmed'
  98. ]
  99. read_only_fields = ['filename', 'is_confirmed']
  100. class ExampleStateSerializer(serializers.ModelSerializer):
  101. class Meta:
  102. model = ExampleState
  103. fields = ('id', 'example', 'confirmed_by')
  104. read_only_fields = ('id', 'example', 'confirmed_by')
  105. class ApproverSerializer(ExampleSerializer):
  106. class Meta:
  107. model = Example
  108. fields = ('id', 'annotation_approver')
  109. class ProjectSerializer(serializers.ModelSerializer):
  110. current_users_role = serializers.SerializerMethodField()
  111. tags = TagSerializer(many=True, required=False)
  112. def get_current_users_role(self, instance):
  113. role_abstractor = {
  114. "is_project_admin": settings.ROLE_PROJECT_ADMIN,
  115. "is_annotator": settings.ROLE_ANNOTATOR,
  116. "is_annotation_approver": settings.ROLE_ANNOTATION_APPROVER,
  117. }
  118. queryset = RoleMapping.objects.values("role_id__name")
  119. if queryset:
  120. users_role = get_object_or_404(
  121. queryset, project=instance.id, user=self.context.get("request").user.id
  122. )
  123. for key, val in role_abstractor.items():
  124. role_abstractor[key] = users_role["role_id__name"] == val
  125. return role_abstractor
  126. class Meta:
  127. model = Project
  128. fields = (
  129. 'id',
  130. 'name',
  131. 'description',
  132. 'guideline',
  133. 'users',
  134. 'current_users_role',
  135. 'project_type',
  136. 'updated_at',
  137. 'random_order',
  138. 'collaborative_annotation',
  139. 'single_class_classification',
  140. 'tags'
  141. )
  142. read_only_fields = (
  143. 'updated_at',
  144. 'users',
  145. 'current_users_role',
  146. 'tags'
  147. )
  148. class TextClassificationProjectSerializer(ProjectSerializer):
  149. class Meta(ProjectSerializer.Meta):
  150. model = TextClassificationProject
  151. class SequenceLabelingProjectSerializer(ProjectSerializer):
  152. class Meta(ProjectSerializer.Meta):
  153. model = SequenceLabelingProject
  154. fields = ProjectSerializer.Meta.fields + ('allow_overlapping', 'grapheme_mode')
  155. class Seq2seqProjectSerializer(ProjectSerializer):
  156. class Meta(ProjectSerializer.Meta):
  157. model = Seq2seqProject
  158. class Speech2textProjectSerializer(ProjectSerializer):
  159. class Meta(ProjectSerializer.Meta):
  160. model = Speech2textProject
  161. class ImageClassificationProjectSerializer(ProjectSerializer):
  162. class Meta(ProjectSerializer.Meta):
  163. model = ImageClassificationProject
  164. class ProjectPolymorphicSerializer(PolymorphicSerializer):
  165. model_serializer_mapping = {
  166. Project: ProjectSerializer,
  167. **{
  168. cls.Meta.model: cls for cls in ProjectSerializer.__subclasses__()
  169. }
  170. }
  171. class CategorySerializer(serializers.ModelSerializer):
  172. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  173. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  174. class Meta:
  175. model = Category
  176. fields = (
  177. 'id',
  178. 'prob',
  179. 'user',
  180. 'example',
  181. 'created_at',
  182. 'updated_at',
  183. 'label',
  184. )
  185. read_only_fields = ('user',)
  186. class SpanSerializer(serializers.ModelSerializer):
  187. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  188. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  189. class Meta:
  190. model = Span
  191. fields = (
  192. 'id',
  193. 'prob',
  194. 'user',
  195. 'example',
  196. 'created_at',
  197. 'updated_at',
  198. 'label',
  199. 'start_offset',
  200. 'end_offset',
  201. )
  202. read_only_fields = ('user',)
  203. class TextLabelSerializer(serializers.ModelSerializer):
  204. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  205. class Meta:
  206. model = TextLabel
  207. fields = (
  208. 'id',
  209. 'prob',
  210. 'user',
  211. 'example',
  212. 'created_at',
  213. 'updated_at',
  214. 'text',
  215. )
  216. read_only_fields = ('user',)
  217. class RoleSerializer(serializers.ModelSerializer):
  218. class Meta:
  219. model = Role
  220. fields = ('id', 'name')
  221. class RoleMappingSerializer(serializers.ModelSerializer):
  222. username = serializers.SerializerMethodField()
  223. rolename = serializers.SerializerMethodField()
  224. @classmethod
  225. def get_username(cls, instance):
  226. user = instance.user
  227. return user.username if user else None
  228. @classmethod
  229. def get_rolename(cls, instance):
  230. role = instance.role
  231. return role.name if role else None
  232. class Meta:
  233. model = RoleMapping
  234. fields = ('id', 'user', 'role', 'username', 'rolename')
  235. class AutoLabelingConfigSerializer(serializers.ModelSerializer):
  236. class Meta:
  237. model = AutoLabelingConfig
  238. fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
  239. read_only_fields = ('created_at', 'updated_at')
  240. def validate_model_name(self, value):
  241. try:
  242. RequestModelFactory.find(value)
  243. except NameError:
  244. raise serializers.ValidationError(f'The specified model name {value} does not exist.')
  245. return value
  246. def valid_label_mapping(self, value):
  247. if isinstance(value, dict):
  248. return value
  249. else:
  250. raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
  251. def validate(self, data):
  252. try:
  253. RequestModelFactory.create(data['model_name'], data['model_attrs'])
  254. except Exception:
  255. model = RequestModelFactory.find(data['model_name'])
  256. schema = model.schema()
  257. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  258. raise serializers.ValidationError(
  259. 'The attributes does not match the model.'
  260. 'You need to correctly specify the required fields: {}'.format(required_fields)
  261. )
  262. return data
  263. def get_annotation_serializer(task: str):
  264. mapping = {
  265. DOCUMENT_CLASSIFICATION: CategorySerializer,
  266. SEQUENCE_LABELING: SpanSerializer,
  267. SEQ2SEQ: TextLabelSerializer,
  268. SPEECH2TEXT: TextLabelSerializer,
  269. IMAGE_CLASSIFICATION: CategorySerializer,
  270. }
  271. try:
  272. return mapping[task]
  273. except KeyError:
  274. raise ValueError(f'{task} is not implemented.')
  275. class RelationTypesSerializer(serializers.ModelSerializer):
  276. def validate(self, attrs):
  277. return super().validate(attrs)
  278. class Meta:
  279. model = RelationTypes
  280. fields = ('id', 'color', 'name')
  281. class AnnotationRelationsSerializer(serializers.ModelSerializer):
  282. def validate(self, attrs):
  283. return super().validate(attrs)
  284. class Meta:
  285. model = AnnotationRelations
  286. fields = ('id', 'annotation_id_1', 'annotation_id_2', 'type', 'user', 'timestamp')