diff --git a/backend/api/views/upload/factory.py b/backend/api/views/upload/factory.py index 22d1ec02..52f6b739 100644 --- a/backend/api/views/upload/factory.py +++ b/backend/api/views/upload/factory.py @@ -1,4 +1,5 @@ -from ...models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING +from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ, + SEQUENCE_LABELING) from . import catalog, data, dataset, label @@ -19,7 +20,8 @@ def get_dataset_class(format: str): catalog.JSON.name: dataset.JSONDataset, catalog.FastText.name: dataset.FastTextDataset, catalog.Excel.name: dataset.ExcelDataset, - catalog.CoNLL.name: dataset.CoNLLDataset + catalog.CoNLL.name: dataset.CoNLLDataset, + catalog.ImageFile.name: dataset.FileBaseDataset, } if format not in mapping: ValueError(f'Invalid format: {format}') @@ -30,7 +32,8 @@ def get_label_class(project_type: str): mapping = { DOCUMENT_CLASSIFICATION: label.CategoryLabel, SEQUENCE_LABELING: label.OffsetLabel, - SEQ2SEQ: label.TextLabel + SEQ2SEQ: label.TextLabel, + IMAGE_CLASSIFICATION: label.CategoryLabel, } if project_type not in mapping: ValueError(f'Invalid project type: {project_type}')