|
|
@ -1,5 +1,7 @@ |
|
|
|
from typing import Type |
|
|
|
|
|
|
|
from ...models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING |
|
|
|
from . import repositories |
|
|
|
from . import catalog, repositories, writer |
|
|
|
|
|
|
|
|
|
|
|
def create_repository(project) -> repositories.BaseRepository: |
|
|
@ -12,3 +14,14 @@ def create_repository(project) -> repositories.BaseRepository: |
|
|
|
ValueError(f'Invalid project type: {project.project_type}') |
|
|
|
repository = mapping.get(project.project_type)(project) |
|
|
|
return repository |
|
|
|
|
|
|
|
|
|
|
|
def create_writer(format: str) -> Type[writer.BaseWriter]: |
|
|
|
mapping = { |
|
|
|
catalog.CSV.name: writer.CsvWriter, |
|
|
|
catalog.JSONL.name: writer.JSONLWriter, |
|
|
|
catalog.FastText.name: writer.FastTextWriter, |
|
|
|
} |
|
|
|
if format not in mapping: |
|
|
|
ValueError(f'Invalid format: {format}') |
|
|
|
return mapping[format] |