from auto_labeling_pipeline.models import RequestModelFactory
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_polymorphic.serializers import PolymorphicSerializer

from .models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
                     SEQUENCE_LABELING, SPEECH2TEXT, AnnotationRelations,
                     AutoLabelingConfig, Category, CategoryType, Comment,
                     Example, ExampleState, ImageClassificationProject,
                     IntentDetectionAndSlotFillingProject, Label, Project,
                     RelationTypes, Seq2seqProject, SequenceLabelingProject,
                     Span, SpanType, Speech2textProject, Tag,
                     TextClassificationProject, TextLabel)


class LabelSerializer(serializers.ModelSerializer):
    """Serialize a label and validate its keyboard shortcut.

    Subclasses only need to point ``Meta.model`` at another label model;
    the duplicate check in :meth:`validate` queries ``self.Meta.model``.
    """

    def validate(self, attrs):
        """Validate the ``prefix_key`` / ``suffix_key`` shortcut pair.

        Raises:
            ValidationError: If a prefix key is given without a suffix key,
                or if another label of the same project already uses the
                same prefix/suffix combination.
        """
        prefix_key = attrs.get('prefix_key')
        suffix_key = attrs.get('suffix_key')

        # The user did not set any shortcut key: nothing to check here.
        if prefix_key is None and suffix_key is None:
            return super().validate(attrs)

        # Reject a prefix key that comes without a suffix key.
        if prefix_key and not suffix_key:
            raise ValidationError('Shortcut key may not have a suffix key.')

        # Reject a shortcut already taken by another label of the project.
        try:
            context = self.context['request'].parser_context
            project_id = context['kwargs']['project_id']
            label_id = context['kwargs'].get('label_id')
        except (AttributeError, KeyError):
            pass  # unit tests don't always have the correct context set up
        else:
            conflicting_labels = self.Meta.model.objects.filter(
                suffix_key=suffix_key,
                prefix_key=prefix_key,
                project=project_id,
            )
            # On update, the label being edited must not conflict with itself.
            if label_id is not None:
                conflicting_labels = conflicting_labels.exclude(id=label_id)
            if conflicting_labels.exists():
                raise ValidationError('Duplicate shortcut key.')

        return super().validate(attrs)

    class Meta:
        model = Label
        fields = (
            'id',
            'text',
            'prefix_key',
            'suffix_key',
            'background_color',
            'text_color',
        )


class CategoryTypeSerializer(LabelSerializer):
    """Label serializer bound to ``CategoryType``."""

    class Meta(LabelSerializer.Meta):
        # Same fields as LabelSerializer; only the model differs.
        model = CategoryType


class SpanTypeSerializer(LabelSerializer):
    """Label serializer bound to ``SpanType``."""

    class Meta(LabelSerializer.Meta):
        # Same fields as LabelSerializer; only the model differs.
        model = SpanType


class CommentSerializer(serializers.ModelSerializer):
    """Serialize comments; ``user`` and ``example`` are read-only."""

    class Meta:
        model = Comment
        fields = ('id', 'user', 'username', 'example', 'text', 'created_at', )
        read_only_fields = ('user', 'example')


class TagSerializer(serializers.ModelSerializer):
    """Serialize project tags; only ``text`` is writable."""

    class Meta:
        model = Tag
        fields = ('id', 'project', 'text', )
        read_only_fields = ('id', 'project')


class ExampleSerializer(serializers.ModelSerializer):
    """Serialize an example with its approver and confirmation state."""

    annotation_approver = serializers.SerializerMethodField()
    is_confirmed = serializers.SerializerMethodField()

    @classmethod
    def get_annotation_approver(cls, instance):
        """Return the approver's username, or ``None`` if there is none."""
        approver = instance.annotations_approved_by
        return approver.username if approver else None

    def get_is_confirmed(self, instance):
        """Return whether the example has at least one matching state.

        For projects with ``collaborative_annotation`` any state counts;
        otherwise only states confirmed by the requesting user count.
        """
        user = self.context.get('request').user
        if instance.project.collaborative_annotation:
            states = instance.states.all()
        else:
            states = instance.states.filter(confirmed_by_id=user.id)
        # exists() answers the yes/no question without counting every row.
        return states.exists()

    class Meta:
        model = Example
        fields = [
            'id',
            'filename',
            'meta',
            'annotation_approver',
            'comment_count',
            'text',
            'is_confirmed'
        ]
        read_only_fields = ['filename', 'is_confirmed']


class ExampleStateSerializer(serializers.ModelSerializer):
    """Serialize example states; every exposed field is read-only."""

    class Meta:
        model = ExampleState
        fields = ('id', 'example', 'confirmed_by')
        read_only_fields = ('id', 'example', 'confirmed_by')


class ApproverSerializer(ExampleSerializer):
    """Expose only the id and approver of an example."""

    class Meta:
        model = Example
        fields = ('id', 'annotation_approver')


class ProjectSerializer(serializers.ModelSerializer):
    """Base serializer holding the fields shared by all project types."""

    tags = TagSerializer(many=True, required=False)

    class Meta:
        model = Project
        fields = (
            'id',
            'name',
            'description',
            'guideline',
            'users',
            'project_type',
            'updated_at',
            'random_order',
            'collaborative_annotation',
            'single_class_classification',
            'is_text_project',
            'can_define_label',
            'can_define_relation',
            'can_define_category',
            'can_define_span',
            'tags'
        )
        read_only_fields = (
            'updated_at',
            'users',
            'is_text_project',
            'can_define_label',
            'can_define_relation',
            'can_define_category',
            'can_define_span',
            'tags'
        )


class TextClassificationProjectSerializer(ProjectSerializer):

    class Meta(ProjectSerializer.Meta):
        model = TextClassificationProject


class SequenceLabelingProjectSerializer(ProjectSerializer):

    class Meta(ProjectSerializer.Meta):
        model = SequenceLabelingProject
        fields = ProjectSerializer.Meta.fields + ('allow_overlapping', 'grapheme_mode')


class Seq2seqProjectSerializer(ProjectSerializer):

    class Meta(ProjectSerializer.Meta):
        model = Seq2seqProject


class IntentDetectionAndSlotFillingProjectSerializer(ProjectSerializer):

    class Meta(ProjectSerializer.Meta):
        model = IntentDetectionAndSlotFillingProject


class Speech2textProjectSerializer(ProjectSerializer):

    class Meta(ProjectSerializer.Meta):
        model = Speech2textProject


class ImageClassificationProjectSerializer(ProjectSerializer):

    class Meta(ProjectSerializer.Meta):
        model = ImageClassificationProject


class ProjectPolymorphicSerializer(PolymorphicSerializer):
    """Dispatch each project model to its dedicated serializer."""

    # The mapping is built once, when this class body executes, so it only
    # contains the direct ProjectSerializer subclasses defined before it.
    model_serializer_mapping = {
        Project: ProjectSerializer,
        **{
            cls.Meta.model: cls
            for cls in ProjectSerializer.__subclasses__()
        }
    }


class CategorySerializer(serializers.ModelSerializer):
    """Serialize category annotations; ``user`` is read-only."""

    label = serializers.PrimaryKeyRelatedField(queryset=CategoryType.objects.all())
    example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

    class Meta:
        model = Category
        fields = (
            'id',
            'prob',
            'user',
            'example',
            'created_at',
            'updated_at',
            'label',
        )
        read_only_fields = ('user',)


class SpanSerializer(serializers.ModelSerializer):
    """Serialize span annotations; ``user`` is read-only."""

    label = serializers.PrimaryKeyRelatedField(queryset=SpanType.objects.all())
    example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

    class Meta:
        model = Span
        fields = (
            'id',
            'prob',
            'user',
            'example',
            'created_at',
            'updated_at',
            'label',
            'start_offset',
            'end_offset',
        )
        read_only_fields = ('user',)


class TextLabelSerializer(serializers.ModelSerializer):
    """Serialize free-text annotations; ``user`` is read-only."""

    example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

    class Meta:
        model = TextLabel
        fields = (
            'id',
            'prob',
            'user',
            'example',
            'created_at',
            'updated_at',
            'text',
        )
        read_only_fields = ('user',)


class AutoLabelingConfigSerializer(serializers.ModelSerializer):
    """Serialize and validate an auto-labeling configuration."""

    class Meta:
        model = AutoLabelingConfig
        fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
        read_only_fields = ('created_at', 'updated_at')

    def validate_model_name(self, value):
        """Ensure ``value`` names a model known to ``RequestModelFactory``.

        Raises:
            serializers.ValidationError: If the factory lookup raises
                ``NameError``.
        """
        try:
            RequestModelFactory.find(value)
        except NameError as err:
            raise serializers.ValidationError(
                f'The specified model name {value} does not exist.'
            ) from err
        return value

    def validate_label_mapping(self, value):
        """Ensure the label mapping is a dictionary.

        DRF only invokes field-level validators named
        ``validate_<field_name>``, hence this name.

        Raises:
            serializers.ValidationError: If ``value`` is not a ``dict``.
        """
        if not isinstance(value, dict):
            raise serializers.ValidationError(
                f'The {value} is not a dictionary. Please specify it as a dictionary.'
            )
        return value

    # Backward-compatible alias: the validator used to be defined under this
    # name, which DRF never called automatically.
    valid_label_mapping = validate_label_mapping

    def validate(self, data):
        """Check that the model attributes fit the selected request model.

        On partial updates ``data`` may lack ``model_name`` or
        ``model_attrs``; the values stored on ``self.instance`` are used
        instead in that case.

        Raises:
            serializers.ValidationError: If the request model cannot be
                created from the given name and attributes.
        """
        model_name = data.get('model_name', getattr(self.instance, 'model_name', None))
        model_attrs = data.get('model_attrs', getattr(self.instance, 'model_attrs', None))
        try:
            RequestModelFactory.create(model_name, model_attrs)
        except Exception as err:
            # Any failure to build the model is reported together with the
            # required fields from the model's schema, if it declares any.
            model = RequestModelFactory.find(model_name)
            schema = model.schema()
            required_fields = ', '.join(schema.get('required', []))
            raise serializers.ValidationError(
                'The attributes does not match the model. '
                'You need to correctly specify the required fields: {}'.format(required_fields)
            ) from err
        return data


def get_annotation_serializer(task: str):
    """Return the annotation serializer class registered for ``task``.

    Raises:
        ValueError: If ``task`` is not one of the supported task constants.
    """
    mapping = {
        DOCUMENT_CLASSIFICATION: CategorySerializer,
        SEQUENCE_LABELING: SpanSerializer,
        SEQ2SEQ: TextLabelSerializer,
        SPEECH2TEXT: TextLabelSerializer,
        IMAGE_CLASSIFICATION: CategorySerializer,
    }
    try:
        return mapping[task]
    except KeyError as err:
        raise ValueError(f'{task} is not implemented.') from err


class RelationTypesSerializer(serializers.ModelSerializer):
    """Serialize relation types."""

    class Meta:
        model = RelationTypes
        fields = ('id', 'color', 'name')


class AnnotationRelationsSerializer(serializers.ModelSerializer):
    """Serialize relations between two annotations."""

    class Meta:
        model = AnnotationRelations
        fields = ('id', 'annotation_id_1', 'annotation_id_2', 'type', 'user', 'timestamp')