You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

329 lines
10 KiB

6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
3 years ago
6 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.conf import settings
  3. from django.contrib.auth import get_user_model
  4. from django.shortcuts import get_object_or_404
  5. from rest_framework import serializers
  6. from rest_framework.exceptions import ValidationError
  7. from rest_polymorphic.serializers import PolymorphicSerializer
  8. from .models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
  9. SEQUENCE_LABELING, SPEECH2TEXT, AutoLabelingConfig,
  10. Category, Comment, Example, ImageClassificationProject,
  11. Label, Project, Role, RoleMapping, Seq2seqProject,
  12. SequenceLabelingProject, Span, Speech2textProject, Tag,
  13. TextClassificationProject, TextLabel)
  class UserSerializer(serializers.ModelSerializer):
      """Expose a user's id, username and superuser flag."""

      class Meta:
          # The active user model, whatever the project has configured.
          model = get_user_model()
          fields = (
              'id',
              'username',
              'is_superuser',
          )
  class LabelSerializer(serializers.ModelSerializer):
      """Serialize a Label and validate its keyboard-shortcut keys."""

      def validate(self, attrs):
          """Validate the (prefix_key, suffix_key) shortcut of a label.

          Rules:
          * A label may have no shortcut at all (both keys unset).
          * A prefix key is only allowed together with a suffix key.
          * Within the project taken from the request URL kwargs, no other
            label may use the same (prefix_key, suffix_key) pair. When a
            ``label_id`` URL kwarg is present, that label is excluded so an
            update does not conflict with itself.

          Raises:
              ValidationError: if one of the rules above is violated.
          """
          prefix_key = attrs.get('prefix_key')
          suffix_key = attrs.get('suffix_key')

          # The user did not set any shortcut key.
          if prefix_key is None and suffix_key is None:
              return super().validate(attrs)

          # A prefix key on its own is not a usable shortcut.
          if prefix_key and not suffix_key:
              raise ValidationError('Shortcut key must have a suffix key.')

          # Reject a shortcut already used by another label in the same project.
          # A None prefix_key is matched as NULL by the ORM filter below.
          try:
              context = self.context['request'].parser_context
              project_id = context['kwargs']['project_id']
              label_id = context['kwargs'].get('label_id')
          except (AttributeError, KeyError):
              pass  # unit tests don't always have the correct context set up
          else:
              conflicting_labels = Label.objects.filter(
                  suffix_key=suffix_key,
                  prefix_key=prefix_key,
                  project=project_id,
              )
              if label_id is not None:
                  conflicting_labels = conflicting_labels.exclude(id=label_id)
              if conflicting_labels.exists():
                  raise ValidationError('Duplicate shortcut key.')
          return super().validate(attrs)

      class Meta:
          model = Label
          fields = ('id', 'text', 'prefix_key', 'suffix_key', 'background_color', 'text_color')
  class CommentSerializer(serializers.ModelSerializer):
      """Serialize a Comment; its author and target example are read-only."""

      class Meta:
          model = Comment
          fields = (
              'id',
              'user',
              'username',
              'example',
              'text',
              'created_at',
          )
          read_only_fields = (
              'user',
              'example',
          )
  class TagSerializer(serializers.ModelSerializer):
      """Serialize a project Tag; only its text is writable."""

      class Meta:
          model = Tag
          fields = (
              'id',
              'project',
              'text',
          )
          read_only_fields = (
              'id',
              'project',
          )
  class ExampleSerializer(serializers.ModelSerializer):
      """Serialize an Example with its annotations and approver's username."""

      annotations = serializers.SerializerMethodField()
      annotation_approver = serializers.SerializerMethodField()

      def get_annotations(self, instance):
          """Return the serialized annotations attached to ``instance``.

          The annotation model and serializer are chosen from the example's
          project. When a request is in the serializer context and the project
          does not enable collaborative annotation, only the requesting user's
          annotations are returned.
          """
          project = instance.project
          annotation_model = project.get_annotation_class()
          serializer_class = get_annotation_serializer(task=project.project_type)
          queryset = annotation_model.objects.filter(example=instance.id)

          request = self.context.get('request')
          only_own = bool(request) and not project.collaborative_annotation
          if only_own:
              queryset = queryset.filter(user=request.user)

          return serializer_class(queryset, many=True).data

      @classmethod
      def get_annotation_approver(cls, instance):
          """Return the approver's username, or None when nobody approved."""
          approved_by = instance.annotations_approved_by
          if not approved_by:
              return None
          return approved_by.username

      class Meta:
          model = Example
          fields = [
              'id',
              'filename',
              'annotations',
              'meta',
              'annotation_approver',
              'comment_count',
              'text',
          ]
          read_only_fields = ['filename']
  class ApproverSerializer(ExampleSerializer):
      """Reduced ExampleSerializer exposing only the id and approver name."""

      class Meta:
          model = Example
          fields = (
              'id',
              'annotation_approver',
          )
  class ProjectSerializer(serializers.ModelSerializer):
      """Serialize a Project with its tags and the requesting user's role flags."""

      current_users_role = serializers.SerializerMethodField()
      tags = TagSerializer(many=True, required=False)

      def get_current_users_role(self, instance):
          """Describe the requesting user's role in ``instance``.

          Returns a dict with keys ``is_project_admin``, ``is_annotator`` and
          ``is_annotation_approver``. When role mappings exist, each value is a
          boolean telling whether the user's role name for this project equals
          the corresponding ``settings.ROLE_*`` value.

          NOTE(review): if the RoleMapping table is completely empty, the dict
          is returned with the raw ``settings.ROLE_*`` strings as values rather
          than booleans — confirm whether any caller relies on this.

          Raises:
              Http404: role mappings exist, but none matches this project and
                  the requesting user (from ``get_object_or_404``).
              MultipleObjectsReturned: the user has several mappings for this
                  project.
          """
          role_abstractor = {
              "is_project_admin": settings.ROLE_PROJECT_ADMIN,
              "is_annotator": settings.ROLE_ANNOTATOR,
              "is_annotation_approver": settings.ROLE_ANNOTATION_APPROVER,
          }
          queryset = RoleMapping.objects.values("role_id__name")
          # exists() runs a single LIMIT 1 query instead of loading every
          # RoleMapping row in the database just to test truthiness.
          if queryset.exists():
              users_role = get_object_or_404(
                  queryset, project=instance.id, user=self.context.get("request").user.id
              )
              for key, val in role_abstractor.items():
                  role_abstractor[key] = users_role["role_id__name"] == val
          return role_abstractor

      class Meta:
          model = Project
          fields = (
              'id',
              'name',
              'description',
              'guideline',
              'users',
              'current_users_role',
              'project_type',
              'updated_at',
              'random_order',
              'collaborative_annotation',
              'single_class_classification',
              'tags'
          )
          read_only_fields = (
              'updated_at',
              'users',
              'current_users_role',
              'tags'
          )
  class TextClassificationProjectSerializer(ProjectSerializer):
      """ProjectSerializer bound to TextClassificationProject."""

      class Meta(ProjectSerializer.Meta):
          # Inherits fields / read_only_fields; only the model differs.
          model = TextClassificationProject
  class SequenceLabelingProjectSerializer(ProjectSerializer):
      """ProjectSerializer bound to SequenceLabelingProject."""

      class Meta(ProjectSerializer.Meta):
          # Inherits fields / read_only_fields; only the model differs.
          model = SequenceLabelingProject
  class Seq2seqProjectSerializer(ProjectSerializer):
      """ProjectSerializer bound to Seq2seqProject."""

      class Meta(ProjectSerializer.Meta):
          # Inherits fields / read_only_fields; only the model differs.
          model = Seq2seqProject
  class Speech2textProjectSerializer(ProjectSerializer):
      """ProjectSerializer bound to Speech2textProject."""

      class Meta(ProjectSerializer.Meta):
          # Inherits fields / read_only_fields; only the model differs.
          model = Speech2textProject
  class ImageClassificationProjectSerializer(ProjectSerializer):
      """ProjectSerializer bound to ImageClassificationProject."""

      class Meta(ProjectSerializer.Meta):
          # Inherits fields / read_only_fields; only the model differs.
          model = ImageClassificationProject
  class ProjectPolymorphicSerializer(PolymorphicSerializer):
      """Dispatch to the serializer matching each project's concrete model.

      The mapping covers ProjectSerializer itself (for the base Project model)
      followed by every direct subclass of ProjectSerializer defined above, each
      keyed by its ``Meta.model``.
      """

      model_serializer_mapping = {
          serializer.Meta.model: serializer
          for serializer in (ProjectSerializer, *ProjectSerializer.__subclasses__())
      }
  class CategorySerializer(serializers.ModelSerializer):
      """Serialize a Category annotation.

      ``label`` and ``example`` are read and written as primary keys; ``user``
      is read-only.
      """

      # NOTE(review): any Label / Example primary key is accepted here; if
      # annotations must stay within one project, that check lives elsewhere.
      label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
      example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

      class Meta:
          model = Category
          fields = (
              'id', 'prob', 'user', 'example',
              'created_at', 'updated_at', 'label',
          )
          read_only_fields = ('user',)
  class SpanSerializer(serializers.ModelSerializer):
      """Serialize a Span annotation with its start/end offsets.

      ``label`` and ``example`` are read and written as primary keys; ``user``
      is read-only.
      """

      # NOTE(review): any Label / Example primary key is accepted here; if
      # annotations must stay within one project, that check lives elsewhere.
      label = serializers.PrimaryKeyRelatedField(queryset=Label.objects.all())
      example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

      class Meta:
          model = Span
          fields = (
              'id', 'prob', 'user', 'example',
              'created_at', 'updated_at', 'label',
              'start_offset', 'end_offset',
          )
          read_only_fields = ('user',)
  class TextLabelSerializer(serializers.ModelSerializer):
      """Serialize a free-text TextLabel annotation.

      ``example`` is read and written as a primary key; ``user`` is read-only.
      """

      example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

      class Meta:
          model = TextLabel
          fields = (
              'id', 'prob', 'user', 'example',
              'created_at', 'updated_at', 'text',
          )
          read_only_fields = ('user',)
  class RoleSerializer(serializers.ModelSerializer):
      """Expose a Role's id and name."""

      class Meta:
          model = Role
          fields = (
              'id',
              'name',
          )
  class RoleMappingSerializer(serializers.ModelSerializer):
      """Serialize a RoleMapping, adding readable user and role names."""

      username = serializers.SerializerMethodField()
      rolename = serializers.SerializerMethodField()

      @classmethod
      def get_username(cls, instance):
          """Return the mapped user's username, or None if no user is set."""
          mapped_user = instance.user
          if not mapped_user:
              return None
          return mapped_user.username

      @classmethod
      def get_rolename(cls, instance):
          """Return the mapped role's name, or None if no role is set."""
          mapped_role = instance.role
          if not mapped_role:
              return None
          return mapped_role.name

      class Meta:
          model = RoleMapping
          fields = ('id', 'user', 'role', 'username', 'rolename')
  class AutoLabelingConfigSerializer(serializers.ModelSerializer):
      """Serialize an AutoLabelingConfig and check it against RequestModelFactory."""

      class Meta:
          model = AutoLabelingConfig
          fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
          # NOTE(review): neither name is listed in ``fields`` above, so this
          # setting currently has no effect.
          read_only_fields = ('created_at', 'updated_at')

      def validate_model_name(self, value):
          """Reject model names that RequestModelFactory cannot find."""
          try:
              RequestModelFactory.find(value)
          except NameError:
              raise serializers.ValidationError(f'The specified model name {value} does not exist.')
          return value

      def validate_label_mapping(self, value):
          """Require ``label_mapping`` to be a dictionary.

          DRF only invokes field-level hooks named ``validate_<field_name>``, so
          this check must carry that name to run at all.

          Raises:
              serializers.ValidationError: if ``value`` is not a dict.
          """
          if isinstance(value, dict):
              return value
          raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')

      # Backward-compatible alias for the previous method name.
      valid_label_mapping = validate_label_mapping

      def _effective_value(self, data, field_name):
          """Return ``field_name`` from incoming data, else from the saved instance."""
          if field_name in data:
              return data[field_name]
          # Partial updates (PATCH) may omit the field; validate against the
          # stored value instead of failing with a KeyError.
          return getattr(self.instance, field_name, None)

      def validate(self, data):
          """Check that ``model_attrs`` can build the model named by ``model_name``.

          Raises:
              serializers.ValidationError: if the factory cannot create the
                  model. The message lists the fields the model's schema
                  marks as required.
          """
          model_name = self._effective_value(data, 'model_name')
          model_attrs = self._effective_value(data, 'model_attrs')
          try:
              RequestModelFactory.create(model_name, model_attrs)
          except Exception:
              # The factory's error type is not visible here, so any creation
              # failure is reported to the client as an attribute mismatch.
              model = RequestModelFactory.find(model_name)
              schema = model.schema()
              required_fields = ', '.join(schema['required']) if 'required' in schema else ''
              raise serializers.ValidationError(
                  'The attributes does not match the model. '
                  f'You need to correctly specify the required fields: {required_fields}'
              )
          return data
  def get_annotation_serializer(task: str):
      """Return the annotation serializer class for a project type.

      Document and image classification use CategorySerializer, sequence
      labeling uses SpanSerializer, and seq2seq / speech-to-text use
      TextLabelSerializer.

      Raises:
          ValueError: if ``task`` has no registered serializer.
      """
      serializer_by_task = {
          DOCUMENT_CLASSIFICATION: CategorySerializer,
          SEQUENCE_LABELING: SpanSerializer,
          SEQ2SEQ: TextLabelSerializer,
          SPEECH2TEXT: TextLabelSerializer,
          IMAGE_CLASSIFICATION: CategorySerializer,
      }
      serializer_class = serializer_by_task.get(task)
      if serializer_class is None:
          raise ValueError(f'{task} is not implemented.')
      return serializer_class