You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

31 lines
1.1 KiB

  1. from typing import Type
  2. from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
  3. SEQUENCE_LABELING, SPEECH2TEXT)
  4. from . import catalog, repositories, writer
  5. def create_repository(project) -> repositories.BaseRepository:
  6. mapping = {
  7. DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,
  8. SEQUENCE_LABELING: repositories.SequenceLabelingRepository,
  9. SEQ2SEQ: repositories.Seq2seqRepository,
  10. IMAGE_CLASSIFICATION: repositories.FileRepository,
  11. SPEECH2TEXT: repositories.Speech2TextRepository,
  12. }
  13. if project.project_type not in mapping:
  14. ValueError(f'Invalid project type: {project.project_type}')
  15. repository = mapping.get(project.project_type)(project)
  16. return repository
  17. def create_writer(format: str) -> Type[writer.BaseWriter]:
  18. mapping = {
  19. catalog.CSV.name: writer.CsvWriter,
  20. catalog.JSON.name: writer.JSONWriter,
  21. catalog.JSONL.name: writer.JSONLWriter,
  22. catalog.FastText.name: writer.FastTextWriter,
  23. }
  24. if format not in mapping:
  25. ValueError(f'Invalid format: {format}')
  26. return mapping[format]