You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

42 lines
1.5 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from typing import Type
  2. from . import catalog, repositories, writers
  3. from projects.models import (
  4. DOCUMENT_CLASSIFICATION,
  5. IMAGE_CLASSIFICATION,
  6. INTENT_DETECTION_AND_SLOT_FILLING,
  7. SEQ2SEQ,
  8. SEQUENCE_LABELING,
  9. SPEECH2TEXT,
  10. )
  11. def create_repository(project):
  12. if getattr(project, "use_relation", False):
  13. return repositories.RelationExtractionRepository(project)
  14. mapping = {
  15. DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,
  16. SEQUENCE_LABELING: repositories.SequenceLabelingRepository,
  17. SEQ2SEQ: repositories.Seq2seqRepository,
  18. IMAGE_CLASSIFICATION: repositories.FileRepository,
  19. SPEECH2TEXT: repositories.Speech2TextRepository,
  20. INTENT_DETECTION_AND_SLOT_FILLING: repositories.IntentDetectionSlotFillingRepository,
  21. }
  22. if project.project_type not in mapping:
  23. ValueError(f"Invalid project type: {project.project_type}")
  24. repository = mapping[project.project_type](project)
  25. return repository
  26. def create_writer(file_format: str) -> Type[writers.BaseWriter]:
  27. mapping = {
  28. catalog.CSV.name: writers.CsvWriter,
  29. catalog.JSON.name: writers.JSONWriter,
  30. catalog.JSONL.name: writers.JSONLWriter,
  31. catalog.FastText.name: writers.FastTextWriter,
  32. catalog.IntentAndSlot.name: writers.IntentAndSlotWriter,
  33. catalog.JSONLRelation.name: writers.EntityAndRelationWriter,
  34. }
  35. if file_format not in mapping:
  36. ValueError(f"Invalid format: {file_format}")
  37. return mapping[file_format]