You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

55 lines
2.0 KiB

  1. from auto_labeling_pipeline.models import RequestModelFactory
  2. from rest_framework import serializers
  3. from api.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ, SPEECH2TEXT, \
  4. IMAGE_CLASSIFICATION
  5. from api.serializers import CategorySerializer, SpanSerializer, TextLabelSerializer
  6. from .models import AutoLabelingConfig
  7. class AutoLabelingConfigSerializer(serializers.ModelSerializer):
  8. class Meta:
  9. model = AutoLabelingConfig
  10. fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default', 'task_type')
  11. read_only_fields = ('created_at', 'updated_at')
  12. def validate_model_name(self, value):
  13. try:
  14. RequestModelFactory.find(value)
  15. except NameError:
  16. raise serializers.ValidationError(f'The specified model name {value} does not exist.')
  17. return value
  18. def valid_label_mapping(self, value):
  19. if isinstance(value, dict):
  20. return value
  21. else:
  22. raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
  23. def validate(self, data):
  24. try:
  25. RequestModelFactory.create(data['model_name'], data['model_attrs'])
  26. except Exception:
  27. model = RequestModelFactory.find(data['model_name'])
  28. schema = model.schema()
  29. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  30. raise serializers.ValidationError(
  31. 'The attributes does not match the model.'
  32. 'You need to correctly specify the required fields: {}'.format(required_fields)
  33. )
  34. return data
  35. def get_annotation_serializer(task: str):
  36. mapping = {
  37. DOCUMENT_CLASSIFICATION: CategorySerializer,
  38. SEQUENCE_LABELING: SpanSerializer,
  39. SEQ2SEQ: TextLabelSerializer,
  40. SPEECH2TEXT: TextLabelSerializer,
  41. IMAGE_CLASSIFICATION: CategorySerializer,
  42. }
  43. try:
  44. return mapping[task]
  45. except KeyError:
  46. raise ValueError(f'{task} is not implemented.')