You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

370 lines
11 KiB

6 years ago
5 years ago
5 years ago
3 years ago
3 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
4 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.conf import settings
  3. from django.contrib.auth import get_user_model
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import serializers
  6. from rest_framework.exceptions import ValidationError
  7. from rest_polymorphic.serializers import PolymorphicSerializer
  8. from .models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
  9. SEQUENCE_LABELING, SPEECH2TEXT, AnnotationRelations,
  10. AutoLabelingConfig, Category, Comment, Example,
  11. ExampleState, ImageClassificationProject, Label, Project,
  12. RelationTypes, Role, RoleMapping, Seq2seqProject,
  13. SequenceLabelingProject, Span, Speech2textProject, Tag,
  14. TextClassificationProject, TextLabel)
  15. class UserSerializer(serializers.ModelSerializer):
  16. class Meta:
  17. model = get_user_model()
  18. fields = ('id', 'username', 'is_superuser', 'is_staff')
  19. class LabelSerializer(serializers.ModelSerializer):
  20. def validate(self, attrs):
  21. prefix_key = attrs.get('prefix_key')
  22. suffix_key = attrs.get('suffix_key')
  23. # In the case of user don't set any shortcut key.
  24. if prefix_key is None and suffix_key is None:
  25. return super().validate(attrs)
  26. # Don't allow shortcut key not to have a suffix key.
  27. if prefix_key and not suffix_key:
  28. raise ValidationError('Shortcut key may not have a suffix key.')
  29. # Don't allow to save same shortcut key when prefix_key is null.
  30. try:
  31. context = self.context['request'].parser_context
  32. project_id = context['kwargs']['project_id']
  33. label_id = context['kwargs'].get('label_id')
  34. except (AttributeError, KeyError):
  35. pass # unit tests don't always have the correct context set up
  36. else:
  37. conflicting_labels = Label.objects.filter(
  38. suffix_key=suffix_key,
  39. prefix_key=prefix_key,
  40. project=project_id,
  41. )
  42. if label_id is not None:
  43. conflicting_labels = conflicting_labels.exclude(id=label_id)
  44. if conflicting_labels.exists():
  45. raise ValidationError('Duplicate shortcut key.')
  46. return super().validate(attrs)
  47. class Meta:
  48. model = Label
  49. fields = ('id', 'text', 'prefix_key', 'suffix_key', 'background_color', 'text_color')
  50. class CommentSerializer(serializers.ModelSerializer):
  51. class Meta:
  52. model = Comment
  53. fields = ('id', 'user', 'username', 'example', 'text', 'created_at', )
  54. read_only_fields = ('user', 'example')
  55. class TagSerializer(serializers.ModelSerializer):
  56. class Meta:
  57. model = Tag
  58. fields = ('id', 'project', 'text', )
  59. read_only_fields = ('id', 'project')
  60. class ExampleSerializer(serializers.ModelSerializer):
  61. annotations = serializers.SerializerMethodField()
  62. annotation_approver = serializers.SerializerMethodField()
  63. is_confirmed = serializers.SerializerMethodField()
  64. def get_annotations(self, instance):
  65. request = self.context.get('request')
  66. project = instance.project
  67. model = project.get_annotation_class()
  68. serializer = get_annotation_serializer(task=project.project_type)
  69. annotations = model.objects.filter(example=instance.id)
  70. if request and not project.collaborative_annotation:
  71. annotations = annotations.filter(user=request.user)
  72. serializer = serializer(annotations, many=True)
  73. return serializer.data
  74. @classmethod
  75. def get_annotation_approver(cls, instance):
  76. approver = instance.annotations_approved_by
  77. return approver.username if approver else None
  78. def get_is_confirmed(self, instance):
  79. user = self.context.get('request').user
  80. if instance.project.collaborative_annotation:
  81. current_user_role = RoleMapping.objects.get(user_id=user.id, project_id=instance.project.id).role
  82. state_ids = [state.id for state in instance.states.all() if state.confirmed_user_role == current_user_role]
  83. states = instance.states.filter(id__in=state_ids)
  84. else:
  85. states = instance.states.filter(confirmed_by_id=user.id)
  86. return states.count() > 0
  87. class Meta:
  88. model = Example
  89. fields = [
  90. 'id',
  91. 'filename',
  92. 'annotations',
  93. 'meta',
  94. 'annotation_approver',
  95. 'comment_count',
  96. 'text',
  97. 'is_confirmed'
  98. ]
  99. read_only_fields = ['filename', 'is_confirmed']
  100. class ExampleStateSerializer(serializers.ModelSerializer):
  101. class Meta:
  102. model = ExampleState
  103. fields = ('id', 'example', 'confirmed_by')
  104. read_only_fields = ('id', 'example', 'confirmed_by')
  105. class ApproverSerializer(ExampleSerializer):
  106. class Meta:
  107. model = Example
  108. fields = ('id', 'annotation_approver')
  109. class ProjectSerializer(serializers.ModelSerializer):
  110. current_users_role = serializers.SerializerMethodField()
  111. tags = TagSerializer(many=True, required=False)
  112. def get_current_users_role(self, instance):
  113. role_abstractor = {
  114. "is_project_admin": settings.ROLE_PROJECT_ADMIN,
  115. "is_annotator": settings.ROLE_ANNOTATOR,
  116. "is_annotation_approver": settings.ROLE_ANNOTATION_APPROVER,
  117. }
  118. queryset = RoleMapping.objects.values("role_id__name")
  119. if queryset:
  120. users_role = get_object_or_404(
  121. queryset, project=instance.id, user=self.context.get("request").user.id
  122. )
  123. for key, val in role_abstractor.items():
  124. role_abstractor[key] = users_role["role_id__name"] == val
  125. return role_abstractor
  126. class Meta:
  127. model = Project
  128. fields = (
  129. 'id',
  130. 'name',
  131. 'description',
  132. 'guideline',
  133. 'users',
  134. 'current_users_role',
  135. 'project_type',
  136. 'updated_at',
  137. 'random_order',
  138. 'collaborative_annotation',
  139. 'single_class_classification',
  140. 'tags'
  141. )
  142. read_only_fields = (
  143. 'updated_at',
  144. 'users',
  145. 'current_users_role',
  146. 'tags'
  147. )
  148. class TextClassificationProjectSerializer(ProjectSerializer):
  149. class Meta(ProjectSerializer.Meta):
  150. model = TextClassificationProject
  151. class SequenceLabelingProjectSerializer(ProjectSerializer):
  152. class Meta(ProjectSerializer.Meta):
  153. model = SequenceLabelingProject
  154. class Seq2seqProjectSerializer(ProjectSerializer):
  155. class Meta(ProjectSerializer.Meta):
  156. model = Seq2seqProject
  157. class Speech2textProjectSerializer(ProjectSerializer):
  158. class Meta(ProjectSerializer.Meta):
  159. model = Speech2textProject
  160. class ImageClassificationProjectSerializer(ProjectSerializer):
  161. class Meta(ProjectSerializer.Meta):
  162. model = ImageClassificationProject
  163. class ProjectPolymorphicSerializer(PolymorphicSerializer):
  164. model_serializer_mapping = {
  165. Project: ProjectSerializer,
  166. **{
  167. cls.Meta.model: cls for cls in ProjectSerializer.__subclasses__()
  168. }
  169. }
  170. class CategorySerializer(serializers.ModelSerializer):
  171. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  172. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  173. class Meta:
  174. model = Category
  175. fields = (
  176. 'id',
  177. 'prob',
  178. 'user',
  179. 'example',
  180. 'created_at',
  181. 'updated_at',
  182. 'label',
  183. )
  184. read_only_fields = ('user',)
  185. class SpanSerializer(serializers.ModelSerializer):
  186. label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
  187. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  188. class Meta:
  189. model = Span
  190. fields = (
  191. 'id',
  192. 'prob',
  193. 'user',
  194. 'example',
  195. 'created_at',
  196. 'updated_at',
  197. 'label',
  198. 'start_offset',
  199. 'end_offset',
  200. )
  201. read_only_fields = ('user',)
  202. class TextLabelSerializer(serializers.ModelSerializer):
  203. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  204. class Meta:
  205. model = TextLabel
  206. fields = (
  207. 'id',
  208. 'prob',
  209. 'user',
  210. 'example',
  211. 'created_at',
  212. 'updated_at',
  213. 'text',
  214. )
  215. read_only_fields = ('user',)
  216. class RoleSerializer(serializers.ModelSerializer):
  217. class Meta:
  218. model = Role
  219. fields = ('id', 'name')
  220. class RoleMappingSerializer(serializers.ModelSerializer):
  221. username = serializers.SerializerMethodField()
  222. rolename = serializers.SerializerMethodField()
  223. @classmethod
  224. def get_username(cls, instance):
  225. user = instance.user
  226. return user.username if user else None
  227. @classmethod
  228. def get_rolename(cls, instance):
  229. role = instance.role
  230. return role.name if role else None
  231. class Meta:
  232. model = RoleMapping
  233. fields = ('id', 'user', 'role', 'username', 'rolename')
  234. class AutoLabelingConfigSerializer(serializers.ModelSerializer):
  235. class Meta:
  236. model = AutoLabelingConfig
  237. fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
  238. read_only_fields = ('created_at', 'updated_at')
  239. def validate_model_name(self, value):
  240. try:
  241. RequestModelFactory.find(value)
  242. except NameError:
  243. raise serializers.ValidationError(f'The specified model name {value} does not exist.')
  244. return value
  245. def valid_label_mapping(self, value):
  246. if isinstance(value, dict):
  247. return value
  248. else:
  249. raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
  250. def validate(self, data):
  251. try:
  252. RequestModelFactory.create(data['model_name'], data['model_attrs'])
  253. except Exception:
  254. model = RequestModelFactory.find(data['model_name'])
  255. schema = model.schema()
  256. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  257. raise serializers.ValidationError(
  258. 'The attributes does not match the model.'
  259. 'You need to correctly specify the required fields: {}'.format(required_fields)
  260. )
  261. return data
  262. def get_annotation_serializer(task: str):
  263. mapping = {
  264. DOCUMENT_CLASSIFICATION: CategorySerializer,
  265. SEQUENCE_LABELING: SpanSerializer,
  266. SEQ2SEQ: TextLabelSerializer,
  267. SPEECH2TEXT: TextLabelSerializer,
  268. IMAGE_CLASSIFICATION: CategorySerializer,
  269. }
  270. try:
  271. return mapping[task]
  272. except KeyError:
  273. raise ValueError(f'{task} is not implemented.')
  274. class RelationTypesSerializer(serializers.ModelSerializer):
  275. def validate(self, attrs):
  276. return super().validate(attrs)
  277. class Meta:
  278. model = RelationTypes
  279. fields = ('id', 'color', 'name')
  280. class AnnotationRelationsSerializer(serializers.ModelSerializer):
  281. def validate(self, attrs):
  282. return super().validate(attrs)
  283. class Meta:
  284. model = AnnotationRelations
  285. fields = ('id', 'annotation_id_1', 'annotation_id_2', 'type', 'user', 'timestamp')