|
|
@ -1,34 +1,36 @@ |
|
|
|
from . import data, dataset, label |
|
|
|
from . import data, dataset, label, catalog |
|
|
|
from ...models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ |
|
|
|
|
|
|
|
|
|
|
|
def get_data_class(project_type: str): |
|
|
|
if project_type in ['DocumentClassification', 'SequenceLabeling', 'Seq2seq']: |
|
|
|
text_projects = [DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ] |
|
|
|
if project_type in text_projects: |
|
|
|
return data.TextData |
|
|
|
else: |
|
|
|
return data.FileData |
|
|
|
|
|
|
|
|
|
|
|
def get_dataset_class(format: str): |
|
|
|
if format == 'csv': |
|
|
|
if format == catalog.CSV: |
|
|
|
return dataset.CsvDataset |
|
|
|
elif format == 'jsonl': |
|
|
|
elif format == catalog.JSONL: |
|
|
|
return dataset.JSONLDataset |
|
|
|
elif format == 'json': |
|
|
|
elif format == catalog.JSONL: |
|
|
|
return dataset.JSONDataset |
|
|
|
elif format == 'fasttext': |
|
|
|
elif format == catalog.FastText: |
|
|
|
return dataset.FastTextDataset |
|
|
|
elif format == 'excel': |
|
|
|
elif format == catalog.EXCEL: |
|
|
|
return dataset.ExcelDataset |
|
|
|
else: |
|
|
|
ValueError(f'Invalid format: {format}') |
|
|
|
|
|
|
|
|
|
|
|
def get_label_class(project_type: str): |
|
|
|
if project_type == 'DocumentClassification': |
|
|
|
if project_type == DOCUMENT_CLASSIFICATION: |
|
|
|
return label.CategoryLabel |
|
|
|
elif project_type == 'SequenceLabeling': |
|
|
|
elif project_type == SEQUENCE_LABELING: |
|
|
|
return label.OffsetLabel |
|
|
|
elif project_type == 'Seq2seq': |
|
|
|
elif project_type == SEQ2SEQ: |
|
|
|
return label.TextLabel |
|
|
|
else: |
|
|
|
ValueError(f'Invalid project type: {project_type}') |