You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

55 lines
2.0 KiB

from auto_labeling_pipeline.models import RequestModelFactory
from rest_framework import serializers
from api.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ, SPEECH2TEXT, \
IMAGE_CLASSIFICATION
from api.serializers import CategorySerializer, SpanSerializer, TextLabelSerializer
from .models import AutoLabelingConfig
class AutoLabelingConfigSerializer(serializers.ModelSerializer):
    """Serialize and validate ``AutoLabelingConfig`` objects.

    Besides the usual model-field handling, this serializer checks that:

    * ``model_name`` is known to ``RequestModelFactory``;
    * ``label_mapping`` is a dictionary;
    * ``model_name`` and ``model_attrs`` together can be used to build a
      request model via ``RequestModelFactory.create``.
    """

    class Meta:
        model = AutoLabelingConfig
        fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default', 'task_type')
        # NOTE(review): neither name below appears in ``fields``, so this
        # setting probably has no effect on the output -- confirm intent.
        read_only_fields = ('created_at', 'updated_at')

    def validate_model_name(self, value):
        """Field-level check: ensure ``value`` names a known request model.

        Raises:
            serializers.ValidationError: if ``RequestModelFactory.find``
                raises ``NameError`` for ``value``.
        """
        try:
            RequestModelFactory.find(value)
        except NameError as err:
            raise serializers.ValidationError(
                f'The specified model name {value} does not exist.'
            ) from err
        return value

    def validate_label_mapping(self, value):
        """Field-level check: ensure ``label_mapping`` is a dictionary.

        The method must be named ``validate_<field_name>`` for the serializer
        machinery to invoke it; under its previous name
        (``valid_label_mapping``) it was never called during validation.

        Raises:
            serializers.ValidationError: if ``value`` is not a ``dict``.
        """
        if not isinstance(value, dict):
            raise serializers.ValidationError(
                f'The {value} is not a dictionary. Please specify it as a dictionary.'
            )
        return value

    # Backward-compatible alias for code that called the old method name.
    valid_label_mapping = validate_label_mapping

    def validate(self, data):
        """Object-level check: the attributes must fit the selected model.

        Tries to build the request model from ``data['model_name']`` and
        ``data['model_attrs']``. If that fails for any reason, the model's
        schema is consulted so the error message can list the required
        fields.

        Raises:
            serializers.ValidationError: if the request model cannot be
                created from the supplied name and attributes.
        """
        try:
            RequestModelFactory.create(data['model_name'], data['model_attrs'])
        except Exception as err:
            # Deliberately broad: the kinds of errors ``create`` may raise
            # are not visible from this module, and all of them are reported
            # to the client as a validation failure.
            model = RequestModelFactory.find(data['model_name'])
            schema = model.schema()
            required_fields = ', '.join(schema.get('required', []))
            raise serializers.ValidationError(
                'The attributes does not match the model. '
                'You need to correctly specify the required fields: {}'.format(required_fields)
            ) from err
        return data
def get_annotation_serializer(task: str):
    """Look up the annotation serializer associated with ``task``.

    Args:
        task: One of the task-type constants used as keys below.

    Returns:
        The serializer mapped to ``task``.

    Raises:
        ValueError: if ``task`` is not one of the supported task types.
    """
    serializer_by_task = {
        DOCUMENT_CLASSIFICATION: CategorySerializer,
        SEQUENCE_LABELING: SpanSerializer,
        SEQ2SEQ: TextLabelSerializer,
        SPEECH2TEXT: TextLabelSerializer,
        IMAGE_CLASSIFICATION: CategorySerializer,
    }
    try:
        selected = serializer_by_task[task]
    except KeyError:
        raise ValueError(f'{task} is not implemented.')
    else:
        return selected