You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

34 lines
1.4 KiB

  1. from typing import Type
  2. from api.models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION,
  3. INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
  4. SEQUENCE_LABELING, SPEECH2TEXT)
  5. from . import catalog, repositories, writers
  6. def create_repository(project) -> repositories.BaseRepository:
  7. mapping = {
  8. DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,
  9. SEQUENCE_LABELING: repositories.SequenceLabelingRepository,
  10. SEQ2SEQ: repositories.Seq2seqRepository,
  11. IMAGE_CLASSIFICATION: repositories.FileRepository,
  12. SPEECH2TEXT: repositories.Speech2TextRepository,
  13. INTENT_DETECTION_AND_SLOT_FILLING: repositories.IntentDetectionSlotFillingRepository,
  14. }
  15. if project.project_type not in mapping:
  16. ValueError(f'Invalid project type: {project.project_type}')
  17. repository = mapping.get(project.project_type)(project)
  18. return repository
  19. def create_writer(file_format: str) -> Type[writers.BaseWriter]:
  20. mapping = {
  21. catalog.CSV.name: writers.CsvWriter,
  22. catalog.JSON.name: writers.JSONWriter,
  23. catalog.JSONL.name: writers.JSONLWriter,
  24. catalog.FastText.name: writers.FastTextWriter,
  25. catalog.IntentAndSlot.name: writers.IntentAndSlotWriter
  26. }
  27. if file_format not in mapping:
  28. ValueError(f'Invalid format: {file_format}')
  29. return mapping[file_format]