@ -1,7 +1,15 @@
from . import data
from . import dataset
from . import label
def get_data_class(project_type: str):
if project_type in ['DocumentClassification', 'SequenceLabeling', 'Seq2seq']:
return data.TextData
else:
return data.FileData
def get_dataset_class(format: str):
if format == 'csv':
return dataset.CsvDataset