diff --git a/backend/api/views/upload/factory.py b/backend/api/views/upload/factory.py index b09bd716..e7020444 100644 --- a/backend/api/views/upload/factory.py +++ b/backend/api/views/upload/factory.py @@ -1,6 +1,6 @@ from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING, SPEECH2TEXT) -from . import catalog, cleaners, data, dataset, label +from . import catalog, cleaners, data, dataset, label, parsers def get_data_class(project_type: str): @@ -29,6 +29,24 @@ def get_dataset_class(format: str): return mapping[format] +def get_parser(file_format: str): + mapping = { + catalog.TextFile.name: parsers.TextFileParser, + catalog.TextLine.name: parsers.LineParser, + catalog.CSV.name: parsers.CSVParser, + catalog.JSONL.name: parsers.JSONLParser, + catalog.JSON.name: parsers.JSONParser, + catalog.FastText.name: parsers.FastTextParser, + catalog.Excel.name: parsers.ExcelParser, + catalog.CoNLL.name: parsers.CoNLLParser, + catalog.ImageFile.name: parsers.PlainParser, + catalog.AudioFile.name: parsers.PlainParser, + } + if file_format not in mapping: + raise ValueError(f'Invalid format: {file_format}') + return mapping[file_format] + + def get_label_class(project_type: str): mapping = { DOCUMENT_CLASSIFICATION: label.CategoryLabel,