from typing import Type from . import catalog, repositories, writers from projects.models import ( DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ, SEQUENCE_LABELING, SPEECH2TEXT, ) def create_repository(project): if getattr(project, "use_relation", False): return repositories.RelationExtractionRepository(project) mapping = { DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository, SEQUENCE_LABELING: repositories.SequenceLabelingRepository, SEQ2SEQ: repositories.Seq2seqRepository, IMAGE_CLASSIFICATION: repositories.FileRepository, SPEECH2TEXT: repositories.Speech2TextRepository, INTENT_DETECTION_AND_SLOT_FILLING: repositories.IntentDetectionSlotFillingRepository, } if project.project_type not in mapping: ValueError(f"Invalid project type: {project.project_type}") repository = mapping[project.project_type](project) return repository def create_writer(file_format: str) -> Type[writers.BaseWriter]: mapping = { catalog.CSV.name: writers.CsvWriter, catalog.JSON.name: writers.JSONWriter, catalog.JSONL.name: writers.JSONLWriter, catalog.FastText.name: writers.FastTextWriter, catalog.IntentAndSlot.name: writers.IntentAndSlotWriter, catalog.JSONLRelation.name: writers.EntityAndRelationWriter, } if file_format not in mapping: ValueError(f"Invalid format: {file_format}") return mapping[file_format]