diff --git a/app/api/views/download/factory.py b/app/api/views/download/factory.py new file mode 100644 index 00000000..5974cc99 --- /dev/null +++ b/app/api/views/download/factory.py @@ -0,0 +1,14 @@ +from ...models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING +from . import repositories + + +def create_repository(project) -> repositories.BaseRepository: + mapping = { + DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository, + SEQUENCE_LABELING: repositories.SequenceLabelingRepository, + SEQ2SEQ: repositories.Seq2seqRepository, + } + if project.project_type not in mapping: + ValueError(f'Invalid project type: {project.project_type}') + repository = mapping.get(project.project_type)(project) + return repository