You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

39 lines
1.3 KiB

2 years ago
2 years ago
2 years ago
2 years ago
  1. from typing import Type
  2. from projects.models import (
  3. DOCUMENT_CLASSIFICATION,
  4. SEQUENCE_LABELING,
  5. SEQ2SEQ,
  6. SPEECH2TEXT,
  7. IMAGE_CLASSIFICATION,
  8. INTENT_DETECTION_AND_SLOT_FILLING,
  9. )
  10. from . import catalog, repositories, writers
  def create_repository(project):
      """Build the repository that matches the project's type.

      Args:
          project: Object exposing a ``project_type`` attribute. It is also
              passed as the sole argument to the selected repository class.

      Returns:
          An instance of the repository class mapped to
          ``project.project_type``.

      Raises:
          ValueError: If ``project.project_type`` has no mapped repository.
      """
      mapping = {
          DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,
          SEQUENCE_LABELING: repositories.SequenceLabelingRepository,
          SEQ2SEQ: repositories.Seq2seqRepository,
          IMAGE_CLASSIFICATION: repositories.FileRepository,
          SPEECH2TEXT: repositories.Speech2TextRepository,
          INTENT_DETECTION_AND_SLOT_FILLING: repositories.IntentDetectionSlotFillingRepository,
      }
      if project.project_type not in mapping:
          # The exception must be raised, not merely constructed; otherwise
          # the lookup below fails with an uninformative KeyError instead.
          raise ValueError(f"Invalid project type: {project.project_type}")
      return mapping[project.project_type](project)
  def create_writer(file_format: str) -> Type[writers.BaseWriter]:
      """Look up the writer class registered for an export format name.

      Args:
          file_format: Format name, compared against the ``name`` attribute
              of the formats declared in ``catalog``.

      Returns:
          The writer class (not an instance) mapped to ``file_format``.

      Raises:
          ValueError: If ``file_format`` has no mapped writer.
      """
      mapping = {
          catalog.CSV.name: writers.CsvWriter,
          catalog.JSON.name: writers.JSONWriter,
          catalog.JSONL.name: writers.JSONLWriter,
          catalog.FastText.name: writers.FastTextWriter,
          catalog.IntentAndSlot.name: writers.IntentAndSlotWriter,
      }
      if file_format not in mapping:
          # The exception must be raised, not merely constructed; otherwise
          # the lookup below fails with an uninformative KeyError instead.
          raise ValueError(f"Invalid format: {file_format}")
      return mapping[file_format]