You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

361 lines
10 KiB

6 years ago
5 years ago
5 years ago
3 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
3 years ago
  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from django.contrib.auth import get_user_model
  3. from rest_framework import serializers
  4. from rest_framework.exceptions import ValidationError
  5. from rest_polymorphic.serializers import PolymorphicSerializer
  6. from .models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
  7. SEQUENCE_LABELING, SPEECH2TEXT, AnnotationRelations,
  8. AutoLabelingConfig, Category, CategoryType, Comment,
  9. Example, ExampleState, ImageClassificationProject,
  10. IntentDetectionAndSlotFillingProject, Label, Project,
  11. RelationTypes, Seq2seqProject, SequenceLabelingProject,
  12. Span, SpanType, Speech2textProject, Tag,
  13. TextClassificationProject, TextLabel)
  14. class UserSerializer(serializers.ModelSerializer):
  15. class Meta:
  16. model = get_user_model()
  17. fields = ('id', 'username', 'is_superuser', 'is_staff')
  18. class LabelSerializer(serializers.ModelSerializer):
  19. def validate(self, attrs):
  20. prefix_key = attrs.get('prefix_key')
  21. suffix_key = attrs.get('suffix_key')
  22. # In the case of user don't set any shortcut key.
  23. if prefix_key is None and suffix_key is None:
  24. return super().validate(attrs)
  25. # Don't allow shortcut key not to have a suffix key.
  26. if prefix_key and not suffix_key:
  27. raise ValidationError('Shortcut key may not have a suffix key.')
  28. # Don't allow to save same shortcut key when prefix_key is null.
  29. try:
  30. context = self.context['request'].parser_context
  31. project_id = context['kwargs']['project_id']
  32. label_id = context['kwargs'].get('label_id')
  33. except (AttributeError, KeyError):
  34. pass # unit tests don't always have the correct context set up
  35. else:
  36. conflicting_labels = self.Meta.model.objects.filter(
  37. suffix_key=suffix_key,
  38. prefix_key=prefix_key,
  39. project=project_id,
  40. )
  41. if label_id is not None:
  42. conflicting_labels = conflicting_labels.exclude(id=label_id)
  43. if conflicting_labels.exists():
  44. raise ValidationError('Duplicate shortcut key.')
  45. return super().validate(attrs)
  46. class Meta:
  47. model = Label
  48. fields = (
  49. 'id',
  50. 'text',
  51. 'prefix_key',
  52. 'suffix_key',
  53. 'background_color',
  54. 'text_color',
  55. )
  56. class CategoryTypeSerializer(LabelSerializer):
  57. class Meta:
  58. model = CategoryType
  59. fields = (
  60. 'id',
  61. 'text',
  62. 'prefix_key',
  63. 'suffix_key',
  64. 'background_color',
  65. 'text_color',
  66. )
  67. class SpanTypeSerializer(LabelSerializer):
  68. class Meta:
  69. model = SpanType
  70. fields = (
  71. 'id',
  72. 'text',
  73. 'prefix_key',
  74. 'suffix_key',
  75. 'background_color',
  76. 'text_color',
  77. )
  78. class CommentSerializer(serializers.ModelSerializer):
  79. class Meta:
  80. model = Comment
  81. fields = ('id', 'user', 'username', 'example', 'text', 'created_at', )
  82. read_only_fields = ('user', 'example')
  83. class TagSerializer(serializers.ModelSerializer):
  84. class Meta:
  85. model = Tag
  86. fields = ('id', 'project', 'text', )
  87. read_only_fields = ('id', 'project')
  88. class ExampleSerializer(serializers.ModelSerializer):
  89. annotation_approver = serializers.SerializerMethodField()
  90. is_confirmed = serializers.SerializerMethodField()
  91. @classmethod
  92. def get_annotation_approver(cls, instance):
  93. approver = instance.annotations_approved_by
  94. return approver.username if approver else None
  95. def get_is_confirmed(self, instance):
  96. user = self.context.get('request').user
  97. if instance.project.collaborative_annotation:
  98. states = instance.states.all()
  99. else:
  100. states = instance.states.filter(confirmed_by_id=user.id)
  101. return states.count() > 0
  102. class Meta:
  103. model = Example
  104. fields = [
  105. 'id',
  106. 'filename',
  107. 'meta',
  108. 'annotation_approver',
  109. 'comment_count',
  110. 'text',
  111. 'is_confirmed'
  112. ]
  113. read_only_fields = ['filename', 'is_confirmed']
  114. class ExampleStateSerializer(serializers.ModelSerializer):
  115. class Meta:
  116. model = ExampleState
  117. fields = ('id', 'example', 'confirmed_by')
  118. read_only_fields = ('id', 'example', 'confirmed_by')
  119. class ApproverSerializer(ExampleSerializer):
  120. class Meta:
  121. model = Example
  122. fields = ('id', 'annotation_approver')
  123. class ProjectSerializer(serializers.ModelSerializer):
  124. tags = TagSerializer(many=True, required=False)
  125. class Meta:
  126. model = Project
  127. fields = (
  128. 'id',
  129. 'name',
  130. 'description',
  131. 'guideline',
  132. 'users',
  133. 'project_type',
  134. 'updated_at',
  135. 'random_order',
  136. 'collaborative_annotation',
  137. 'single_class_classification',
  138. 'is_text_project',
  139. 'can_define_label',
  140. 'can_define_relation',
  141. 'can_define_category',
  142. 'can_define_span',
  143. 'tags'
  144. )
  145. read_only_fields = (
  146. 'updated_at',
  147. 'users',
  148. 'is_text_project',
  149. 'can_define_label',
  150. 'can_define_relation',
  151. 'can_define_category',
  152. 'can_define_span',
  153. 'tags'
  154. )
  155. class TextClassificationProjectSerializer(ProjectSerializer):
  156. class Meta(ProjectSerializer.Meta):
  157. model = TextClassificationProject
  158. class SequenceLabelingProjectSerializer(ProjectSerializer):
  159. class Meta(ProjectSerializer.Meta):
  160. model = SequenceLabelingProject
  161. fields = ProjectSerializer.Meta.fields + ('allow_overlapping', 'grapheme_mode')
  162. class Seq2seqProjectSerializer(ProjectSerializer):
  163. class Meta(ProjectSerializer.Meta):
  164. model = Seq2seqProject
  165. class IntentDetectionAndSlotFillingProjectSerializer(ProjectSerializer):
  166. class Meta(ProjectSerializer.Meta):
  167. model = IntentDetectionAndSlotFillingProject
  168. class Speech2textProjectSerializer(ProjectSerializer):
  169. class Meta(ProjectSerializer.Meta):
  170. model = Speech2textProject
  171. class ImageClassificationProjectSerializer(ProjectSerializer):
  172. class Meta(ProjectSerializer.Meta):
  173. model = ImageClassificationProject
  174. class ProjectPolymorphicSerializer(PolymorphicSerializer):
  175. model_serializer_mapping = {
  176. Project: ProjectSerializer,
  177. **{
  178. cls.Meta.model: cls for cls in ProjectSerializer.__subclasses__()
  179. }
  180. }
  181. class CategorySerializer(serializers.ModelSerializer):
  182. label = serializers.PrimaryKeyRelatedField(queryset=CategoryType.objects.all())
  183. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  184. class Meta:
  185. model = Category
  186. fields = (
  187. 'id',
  188. 'prob',
  189. 'user',
  190. 'example',
  191. 'created_at',
  192. 'updated_at',
  193. 'label',
  194. )
  195. read_only_fields = ('user',)
  196. class SpanSerializer(serializers.ModelSerializer):
  197. label = serializers.PrimaryKeyRelatedField(queryset=SpanType.objects.all())
  198. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  199. class Meta:
  200. model = Span
  201. fields = (
  202. 'id',
  203. 'prob',
  204. 'user',
  205. 'example',
  206. 'created_at',
  207. 'updated_at',
  208. 'label',
  209. 'start_offset',
  210. 'end_offset',
  211. )
  212. read_only_fields = ('user',)
  213. class TextLabelSerializer(serializers.ModelSerializer):
  214. example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())
  215. class Meta:
  216. model = TextLabel
  217. fields = (
  218. 'id',
  219. 'prob',
  220. 'user',
  221. 'example',
  222. 'created_at',
  223. 'updated_at',
  224. 'text',
  225. )
  226. read_only_fields = ('user',)
  227. class AutoLabelingConfigSerializer(serializers.ModelSerializer):
  228. class Meta:
  229. model = AutoLabelingConfig
  230. fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
  231. read_only_fields = ('created_at', 'updated_at')
  232. def validate_model_name(self, value):
  233. try:
  234. RequestModelFactory.find(value)
  235. except NameError:
  236. raise serializers.ValidationError(f'The specified model name {value} does not exist.')
  237. return value
  238. def valid_label_mapping(self, value):
  239. if isinstance(value, dict):
  240. return value
  241. else:
  242. raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
  243. def validate(self, data):
  244. try:
  245. RequestModelFactory.create(data['model_name'], data['model_attrs'])
  246. except Exception:
  247. model = RequestModelFactory.find(data['model_name'])
  248. schema = model.schema()
  249. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  250. raise serializers.ValidationError(
  251. 'The attributes does not match the model.'
  252. 'You need to correctly specify the required fields: {}'.format(required_fields)
  253. )
  254. return data
  255. def get_annotation_serializer(task: str):
  256. mapping = {
  257. DOCUMENT_CLASSIFICATION: CategorySerializer,
  258. SEQUENCE_LABELING: SpanSerializer,
  259. SEQ2SEQ: TextLabelSerializer,
  260. SPEECH2TEXT: TextLabelSerializer,
  261. IMAGE_CLASSIFICATION: CategorySerializer,
  262. }
  263. try:
  264. return mapping[task]
  265. except KeyError:
  266. raise ValueError(f'{task} is not implemented.')
  267. class RelationTypesSerializer(serializers.ModelSerializer):
  268. def validate(self, attrs):
  269. return super().validate(attrs)
  270. class Meta:
  271. model = RelationTypes
  272. fields = ('id', 'color', 'name')
  273. class AnnotationRelationsSerializer(serializers.ModelSerializer):
  274. def validate(self, attrs):
  275. return super().validate(attrs)
  276. class Meta:
  277. model = AnnotationRelations
  278. fields = ('id', 'annotation_id_1', 'annotation_id_2', 'type', 'user', 'timestamp')