You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

361 lines
10 KiB

from auto_labeling_pipeline.models import RequestModelFactory
from django.contrib.auth import get_user_model
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_polymorphic.serializers import PolymorphicSerializer
from .models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT, AnnotationRelations,
AutoLabelingConfig, Category, CategoryType, Comment,
Example, ExampleState, ImageClassificationProject,
IntentDetectionAndSlotFillingProject, Label, Project,
RelationTypes, Seq2seqProject, SequenceLabelingProject,
Span, SpanType, Speech2textProject, Tag,
TextClassificationProject, TextLabel)
class UserSerializer(serializers.ModelSerializer):
    """Expose a few identity/permission fields of the configured user model."""

    class Meta:
        # get_user_model() resolves whatever user model the project configures.
        model = get_user_model()
        fields = (
            'id',
            'username',
            'is_superuser',
            'is_staff',
        )
class LabelSerializer(serializers.ModelSerializer):
    """Serialize a label type and validate its keyboard shortcut.

    A shortcut consists of an optional ``prefix_key`` and a ``suffix_key``.
    Validation rejects a prefix given without a suffix, and a shortcut that
    another label of the same model already uses in the same project.
    Subclasses inherit this validation; the duplicate lookup runs against
    ``self.Meta.model``, so each subclass checks its own label model.
    """

    def validate(self, attrs):
        prefix_key = attrs.get('prefix_key')
        suffix_key = attrs.get('suffix_key')

        # The user did not set any shortcut key: nothing to check.
        if prefix_key is None and suffix_key is None:
            return super().validate(attrs)

        # A prefix key on its own is not a usable shortcut.
        if prefix_key and not suffix_key:
            raise ValidationError('Shortcut key must have a suffix key.')

        # Reject a shortcut that another label in this project already uses.
        try:
            context = self.context['request'].parser_context
            project_id = context['kwargs']['project_id']
            label_id = context['kwargs'].get('label_id')
        except (AttributeError, KeyError):
            pass  # unit tests don't always have the correct context set up
        else:
            conflicting_labels = self.Meta.model.objects.filter(
                suffix_key=suffix_key,
                prefix_key=prefix_key,
                project=project_id,
            )
            # The label being edited must never conflict with itself. Exclude it
            # both by URL kwarg and by the bound instance, so the check also
            # holds when the route does not name its kwarg ``label_id``.
            if label_id is not None:
                conflicting_labels = conflicting_labels.exclude(id=label_id)
            instance_pk = getattr(self.instance, 'pk', None)
            if instance_pk is not None:
                conflicting_labels = conflicting_labels.exclude(pk=instance_pk)
            if conflicting_labels.exists():
                raise ValidationError('Duplicate shortcut key.')
        return super().validate(attrs)

    class Meta:
        model = Label
        fields = (
            'id',
            'text',
            'prefix_key',
            'suffix_key',
            'background_color',
            'text_color',
        )
class CategoryTypeSerializer(LabelSerializer):
    """Serialize category types with LabelSerializer's shortcut-key validation."""

    class Meta(LabelSerializer.Meta):
        # Only the model differs; the field tuple is inherited unchanged.
        model = CategoryType
class SpanTypeSerializer(LabelSerializer):
    """Serialize span types with LabelSerializer's shortcut-key validation."""

    class Meta(LabelSerializer.Meta):
        # Only the model differs; the field tuple is inherited unchanged.
        model = SpanType
class CommentSerializer(serializers.ModelSerializer):
    """Serialize a comment on an example.

    ``user`` and ``example`` are output-only through this serializer.
    """

    class Meta:
        model = Comment
        fields = (
            'id',
            'user',
            'username',
            'example',
            'text',
            'created_at',
        )
        read_only_fields = ('user', 'example')
class TagSerializer(serializers.ModelSerializer):
    """Serialize a project tag; only ``text`` is writable."""

    class Meta:
        model = Tag
        fields = (
            'id',
            'project',
            'text',
        )
        read_only_fields = ('id', 'project')
class ExampleSerializer(serializers.ModelSerializer):
    """Serialize an example together with its approver and confirmation state."""

    annotation_approver = serializers.SerializerMethodField()
    is_confirmed = serializers.SerializerMethodField()

    @classmethod
    def get_annotation_approver(cls, instance):
        """Return the approver's username, or ``None`` when there is no approver."""
        approver = instance.annotations_approved_by
        return approver.username if approver else None

    def get_is_confirmed(self, instance):
        """Return whether the example counts as confirmed.

        In collaborative projects, a confirmation state from any user counts.
        Otherwise, only states confirmed by the requesting user are considered.
        """
        if instance.project.collaborative_annotation:
            states = instance.states.all()
        else:
            # The request user is only needed to filter per-user confirmations.
            user = self.context.get('request').user
            states = instance.states.filter(confirmed_by_id=user.id)
        # exists() answers the question without counting every matching row.
        return states.exists()

    class Meta:
        model = Example
        fields = [
            'id',
            'filename',
            'meta',
            'annotation_approver',
            'comment_count',
            'text',
            'is_confirmed'
        ]
        read_only_fields = ['filename', 'is_confirmed']
class ExampleStateSerializer(serializers.ModelSerializer):
    """Read-only representation of an example's confirmation state."""

    class Meta:
        model = ExampleState
        fields = ('id', 'example', 'confirmed_by')
        # Every exposed field is read-only.
        read_only_fields = fields
class ApproverSerializer(ExampleSerializer):
    """Expose only an example's id and its approver's username.

    ``annotation_approver`` and its getter are inherited from ExampleSerializer.
    """

    class Meta:
        model = Example
        fields = (
            'id',
            'annotation_approver',
        )
class ProjectSerializer(serializers.ModelSerializer):
    """Serialize a project, including its tags, for the project API.

    ``tags`` is declared explicitly on the class. DRF ignores
    ``Meta.read_only_fields`` for explicitly declared fields, so listing
    ``'tags'`` there had no effect. The field was writable, and any ``tags``
    payload reached ``create()``/``update()``, which reject writable nested
    serializers. Marking the field read-only enforces what ``Meta`` declares.
    """

    tags = TagSerializer(many=True, read_only=True)

    class Meta:
        model = Project
        fields = (
            'id',
            'name',
            'description',
            'guideline',
            'users',
            'project_type',
            'updated_at',
            'random_order',
            'collaborative_annotation',
            'single_class_classification',
            'is_text_project',
            'can_define_label',
            'can_define_relation',
            'can_define_category',
            'can_define_span',
            'tags'
        )
        read_only_fields = (
            'updated_at',
            'users',
            'is_text_project',
            'can_define_label',
            'can_define_relation',
            'can_define_category',
            'can_define_span',
            'tags'
        )
class TextClassificationProjectSerializer(ProjectSerializer):
    """Project serializer bound to TextClassificationProject (fields inherited)."""

    class Meta(ProjectSerializer.Meta):
        model = TextClassificationProject
class SequenceLabelingProjectSerializer(ProjectSerializer):
    """Project serializer that additionally exposes sequence-labeling options."""

    class Meta(ProjectSerializer.Meta):
        model = SequenceLabelingProject
        # The base project fields, followed by the two model-specific options.
        fields = (*ProjectSerializer.Meta.fields, 'allow_overlapping', 'grapheme_mode')
class Seq2seqProjectSerializer(ProjectSerializer):
    """Project serializer bound to Seq2seqProject (fields inherited)."""

    class Meta(ProjectSerializer.Meta):
        model = Seq2seqProject
class IntentDetectionAndSlotFillingProjectSerializer(ProjectSerializer):
    """Project serializer bound to IntentDetectionAndSlotFillingProject (fields inherited)."""

    class Meta(ProjectSerializer.Meta):
        model = IntentDetectionAndSlotFillingProject
class Speech2textProjectSerializer(ProjectSerializer):
    """Project serializer bound to Speech2textProject (fields inherited)."""

    class Meta(ProjectSerializer.Meta):
        model = Speech2textProject
class ImageClassificationProjectSerializer(ProjectSerializer):
    """Project serializer bound to ImageClassificationProject (fields inherited)."""

    class Meta(ProjectSerializer.Meta):
        model = ImageClassificationProject
class ProjectPolymorphicSerializer(PolymorphicSerializer):
    """Map each project model to the serializer declared for it.

    The mapping is built once, when this class body runs. It covers
    ProjectSerializer plus its *direct* subclasses defined before this point.
    Later or indirect subclasses are not picked up.
    """

    # Base serializer first, so a subclass with the same model would override it.
    model_serializer_mapping = {
        serializer_cls.Meta.model: serializer_cls
        for serializer_cls in (ProjectSerializer, *ProjectSerializer.__subclasses__())
    }
class CategorySerializer(serializers.ModelSerializer):
    """Serialize a category annotation: a CategoryType label attached to an example.

    ``label`` and ``example`` are accepted as primary keys; ``user`` is output-only.
    """

    # NOTE(review): both querysets are unscoped (``.all()``), so any CategoryType or
    # Example pk validates here. Confirm the view restricts them to the current project.
    label = serializers.PrimaryKeyRelatedField(queryset=CategoryType.objects.all())
    example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

    class Meta:
        model = Category
        fields = (
            'id',
            'prob',
            'user',
            'example',
            'created_at',
            'updated_at',
            'label',
        )
        read_only_fields = ('user',)
class SpanSerializer(serializers.ModelSerializer):
    """Serialize a span annotation: a SpanType label with start/end offsets on an example.

    ``label`` and ``example`` are accepted as primary keys; ``user`` is output-only.
    """

    # NOTE(review): the querysets are unscoped (``.all()``); confirm project scoping
    # is enforced elsewhere.
    label = serializers.PrimaryKeyRelatedField(queryset=SpanType.objects.all())
    example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

    class Meta:
        model = Span
        fields = (
            'id',
            'prob',
            'user',
            'example',
            'created_at',
            'updated_at',
            'label',
            'start_offset',
            'end_offset',
        )
        read_only_fields = ('user',)
class TextLabelSerializer(serializers.ModelSerializer):
    """Serialize a free-text annotation attached to an example.

    ``example`` is accepted as a primary key; ``user`` is output-only.
    """

    # NOTE(review): the queryset is unscoped (``.all()``); confirm project scoping
    # is enforced elsewhere.
    example = serializers.PrimaryKeyRelatedField(queryset=Example.objects.all())

    class Meta:
        model = TextLabel
        fields = (
            'id',
            'prob',
            'user',
            'example',
            'created_at',
            'updated_at',
            'text',
        )
        read_only_fields = ('user',)
class AutoLabelingConfigSerializer(serializers.ModelSerializer):
    """Serialize an auto-labeling config and check it against its request model."""

    class Meta:
        model = AutoLabelingConfig
        fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
        # NOTE(review): neither name appears in ``fields`` above, so this has no effect.
        read_only_fields = ('created_at', 'updated_at')

    def validate_model_name(self, value):
        """Reject model names that RequestModelFactory cannot find (it raises NameError)."""
        try:
            RequestModelFactory.find(value)
        except NameError:
            raise serializers.ValidationError(f'The specified model name {value} does not exist.')
        return value

    def validate_label_mapping(self, value):
        """Require ``label_mapping`` to be a dictionary.

        DRF only runs field-level validators named ``validate_<field_name>``.
        """
        if isinstance(value, dict):
            return value
        raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')

    # Backward-compatible alias for the previous name, which DRF never invoked.
    valid_label_mapping = validate_label_mapping

    def _current_value(self, data, field_name):
        """Return ``field_name`` from the incoming data, else from the bound instance.

        Partial updates (PATCH) may omit fields; fall back to the stored value so
        the combination still gets validated instead of raising KeyError.
        """
        if field_name in data:
            return data[field_name]
        return getattr(self.instance, field_name, None)

    def validate(self, data):
        """Ensure ``model_attrs`` can actually instantiate the chosen request model.

        Raises:
            serializers.ValidationError: listing the model schema's required fields
                when instantiation fails.
        """
        model_name = self._current_value(data, 'model_name')
        model_attrs = self._current_value(data, 'model_attrs')
        try:
            RequestModelFactory.create(model_name, model_attrs)
        except Exception:
            # Any failure to build the model is reported as an attribute mismatch.
            model = RequestModelFactory.find(model_name)
            schema = model.schema()
            required_fields = ', '.join(schema['required']) if 'required' in schema else ''
            raise serializers.ValidationError(
                'The attributes does not match the model. '
                'You need to correctly specify the required fields: {}'.format(required_fields)
            )
        return data
def get_annotation_serializer(task: str):
    """Return the annotation serializer class registered for ``task``.

    Raises:
        ValueError: if no serializer is registered for ``task``.
    """
    serializer_by_task = {
        DOCUMENT_CLASSIFICATION: CategorySerializer,
        SEQUENCE_LABELING: SpanSerializer,
        SEQ2SEQ: TextLabelSerializer,
        SPEECH2TEXT: TextLabelSerializer,
        IMAGE_CLASSIFICATION: CategorySerializer,
    }
    if task not in serializer_by_task:
        raise ValueError(f'{task} is not implemented.')
    return serializer_by_task[task]
class RelationTypesSerializer(serializers.ModelSerializer):
    """Serialize a relation type: its id, color and name.

    No custom validation is needed; ``ModelSerializer.validate`` is inherited.
    """

    class Meta:
        model = RelationTypes
        fields = ('id', 'color', 'name')
class AnnotationRelationsSerializer(serializers.ModelSerializer):
    """Serialize a typed relation between two annotations.

    No custom validation is needed; ``ModelSerializer.validate`` is inherited.
    """

    class Meta:
        model = AnnotationRelations
        # NOTE(review): unlike CategorySerializer/SpanSerializer/TextLabelSerializer,
        # ``user`` is not marked read-only here, so clients can set it. Confirm this
        # is intended.
        fields = ('id', 'annotation_id_1', 'annotation_id_2', 'type', 'user', 'timestamp')