From ff711124c4a362c6417b919a9a65369171210f6f Mon Sep 17 00:00:00 2001 From: Hironsan Date: Thu, 8 Apr 2021 09:41:08 +0900 Subject: [PATCH] Update factory --- app/api/views/upload/factory.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/app/api/views/upload/factory.py b/app/api/views/upload/factory.py index d87cef20..ca714000 100644 --- a/app/api/views/upload/factory.py +++ b/app/api/views/upload/factory.py @@ -1,34 +1,36 @@ -from . import data, dataset, label +from . import data, dataset, label, catalog +from ...models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ def get_data_class(project_type: str): - if project_type in ['DocumentClassification', 'SequenceLabeling', 'Seq2seq']: + text_projects = [DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ] + if project_type in text_projects: return data.TextData else: return data.FileData def get_dataset_class(format: str): - if format == 'csv': + if format == catalog.CSV: return dataset.CsvDataset - elif format == 'jsonl': + elif format == catalog.JSONL: return dataset.JSONLDataset - elif format == 'json': + elif format == catalog.JSONL: return dataset.JSONDataset - elif format == 'fasttext': + elif format == catalog.FastText: return dataset.FastTextDataset - elif format == 'excel': + elif format == catalog.EXCEL: return dataset.ExcelDataset else: ValueError(f'Invalid format: {format}') def get_label_class(project_type: str): - if project_type == 'DocumentClassification': + if project_type == DOCUMENT_CLASSIFICATION: return label.CategoryLabel - elif project_type == 'SequenceLabeling': + elif project_type == SEQUENCE_LABELING: return label.OffsetLabel - elif project_type == 'Seq2seq': + elif project_type == SEQ2SEQ: return label.TextLabel else: ValueError(f'Invalid project type: {project_type}')