diff --git a/app/api/views/download/factory.py b/app/api/views/download/factory.py index 5974cc99..eb750be9 100644 --- a/app/api/views/download/factory.py +++ b/app/api/views/download/factory.py @@ -1,5 +1,7 @@ +from typing import Type + from ...models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING -from . import repositories +from . import catalog, repositories, writer def create_repository(project) -> repositories.BaseRepository: @@ -12,3 +14,14 @@ def create_repository(project) -> repositories.BaseRepository: ValueError(f'Invalid project type: {project.project_type}') repository = mapping.get(project.project_type)(project) return repository + + +def create_writer(format: str) -> Type[writer.BaseWriter]: + mapping = { + catalog.CSV.name: writer.CsvWriter, + catalog.JSONL.name: writer.JSONLWriter, + catalog.FastText.name: writer.FastTextWriter, + } + if format not in mapping: + ValueError(f'Invalid format: {format}') + return mapping[format]