You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

31 lines
1.1 KiB

from typing import Type
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT)
from . import catalog, repositories, writer
def create_repository(project) -> repositories.BaseRepository:
    """Build the export repository that matches the project's type.

    Args:
        project: Object exposing a ``project_type`` attribute; it is also
            passed as the sole argument to the selected repository class.

    Returns:
        An instance of the repository class registered for
        ``project.project_type``.

    Raises:
        ValueError: If ``project.project_type`` has no registered repository.
    """
    mapping = {
        DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,
        SEQUENCE_LABELING: repositories.SequenceLabelingRepository,
        SEQ2SEQ: repositories.Seq2seqRepository,
        IMAGE_CLASSIFICATION: repositories.FileRepository,
        SPEECH2TEXT: repositories.Speech2TextRepository,
    }
    if project.project_type not in mapping:
        # Raise explicitly; otherwise the lookup below would end up calling
        # ``None`` and fail with an unhelpful TypeError.
        raise ValueError(f'Invalid project type: {project.project_type}')
    return mapping[project.project_type](project)
def create_writer(format: str) -> Type[writer.BaseWriter]:
    """Look up the writer class registered for an export format name.

    Args:
        format: Name of the export format, compared against the ``name``
            attribute of the catalog entries (CSV, JSON, JSONL, FastText).

    Returns:
        The writer class (not an instance) registered for ``format``.

    Raises:
        ValueError: If ``format`` does not match any registered format name.
    """
    mapping = {
        catalog.CSV.name: writer.CsvWriter,
        catalog.JSON.name: writer.JSONWriter,
        catalog.JSONL.name: writer.JSONLWriter,
        catalog.FastText.name: writer.FastTextWriter,
    }
    if format not in mapping:
        # Raise explicitly so callers get a descriptive error rather than a
        # bare KeyError from the lookup below.
        raise ValueError(f'Invalid format: {format}')
    return mapping[format]