diff --git a/backend/data_import/celery_tasks.py b/backend/data_import/celery_tasks.py index 18cc42b5..0c86ff36 100644 --- a/backend/data_import/celery_tasks.py +++ b/backend/data_import/celery_tasks.py @@ -4,7 +4,7 @@ from django.contrib.auth import get_user_model from django.shortcuts import get_object_or_404 from api.models import Project -from .pipeline.factories import create_parser, create_bulder, create_cleaner +from .pipeline.factories import create_parser, create_builder, create_cleaner from .pipeline.readers import Reader from .pipeline.writers import BulkWriter @@ -15,7 +15,7 @@ def import_dataset(user_id, project_id, filenames, file_format: str, **kwargs): user = get_object_or_404(get_user_model(), pk=user_id) parser = create_parser(file_format, **kwargs) - builder = create_bulder(project, **kwargs) + builder = create_builder(project, **kwargs) reader = Reader(filenames=filenames, parser=parser, builder=builder) cleaner = create_cleaner(project) writer = BulkWriter(batch_size=settings.IMPORT_BATCH_SIZE) diff --git a/backend/data_import/pipeline/builders.py b/backend/data_import/pipeline/builders.py index dad1de98..3449acc9 100644 --- a/backend/data_import/pipeline/builders.py +++ b/backend/data_import/pipeline/builders.py @@ -16,11 +16,12 @@ T = TypeVar('T') class PlainBuilder(Builder): def __init__(self, data_class: Type[BaseData]): + print(data_class) self.data_class = data_class def build(self, row: Dict[Any, Any], filename: str, line_num: int) -> Record: data = self.data_class.parse(filename=filename) - yield Record(data=data) + return Record(data=data) def build_label(row: Dict[Any, Any], name: str, label_class: Type[Label]) -> List[Label]: diff --git a/backend/data_import/pipeline/factories.py b/backend/data_import/pipeline/factories.py index ab61eae2..5eccea89 100644 --- a/backend/data_import/pipeline/factories.py +++ b/backend/data_import/pipeline/factories.py @@ -55,12 +55,15 @@ def create_cleaner(project): IMAGE_CLASSIFICATION: cleaners.CategoryCleaner } if project.project_type not in mapping: - ValueError(f'Invalid project type: {project.project_type}') + return cleaners.Cleaner(project) cleaner_class = mapping.get(project.project_type, cleaners.Cleaner) return cleaner_class(project) -def create_bulder(project, **kwargs): +def create_builder(project, **kwargs): + if not project.is_text_project: + return builders.PlainBuilder(data_class=get_data_class(project.project_type)) + data_column = builders.DataColumn( name=kwargs.get('column_data') or readers.DEFAULT_TEXT_COLUMN, value_class=get_data_class(project.project_type) diff --git a/backend/data_import/pipeline/parsers.py b/backend/data_import/pipeline/parsers.py index f2d43bf4..80ae8754 100644 --- a/backend/data_import/pipeline/parsers.py +++ b/backend/data_import/pipeline/parsers.py @@ -96,6 +96,9 @@ class PlainParser(Parser): This is for a task without any text. """ + def __init__(self, **kwargs): + self.kwargs = kwargs + def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: yield {} diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py index 91488d1a..71628695 100644 --- a/backend/data_import/tests/test_tasks.py +++ b/backend/data_import/tests/test_tasks.py @@ -5,7 +5,7 @@ from django.test import TestCase from data_import.celery_tasks import import_dataset from api.models import (DOCUMENT_CLASSIFICATION, INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ, - SEQUENCE_LABELING) + SEQUENCE_LABELING, IMAGE_CLASSIFICATION) from examples.models import Example from label_types.models import CategoryType, SpanType from labels.models import Category, Span @@ -242,7 +242,7 @@ class TestImportSeq2seqData(TestImportData): self.assert_examples(dataset) -class TextImportIntentDetectionAndSlotFillingData(TestImportData): +class TestImportIntentDetectionAndSlotFillingData(TestImportData): task = INTENT_DETECTION_AND_SLOT_FILLING def assert_examples(self, dataset): @@ -265,3 +265,13 @@ class TextImportIntentDetectionAndSlotFillingData(TestImportData): ] self.import_dataset(filename, file_format) self.assert_examples(dataset) + + +class TestImportImageClassificationData(TestImportData): + task = IMAGE_CLASSIFICATION + + def test_example(self): + filename = 'images/1500x500.jpeg' + file_format = 'ImageFile' + self.import_dataset(filename, file_format) + self.assertEqual(Example.objects.count(), 1)