diff --git a/app/api/views/upload/factory.py b/app/api/views/upload/factory.py index 14c3132c..2d1a8652 100644 --- a/app/api/views/upload/factory.py +++ b/app/api/views/upload/factory.py @@ -1,7 +1,15 @@ +from . import data from . import dataset from . import label +def get_data_class(project_type: str): + if project_type in ['DocumentClassification', 'SequenceLabeling', 'Seq2seq']: + return data.TextData + else: + return data.FileData + + def get_dataset_class(format: str): if format == 'csv': return dataset.CsvDataset