From 62a1a6228ea1dcd0840d66950c341fdb2265d975 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Mon, 4 Apr 2022 09:48:49 +0900 Subject: [PATCH] Pass upload_ids to import_dataset --- backend/data_import/celery_tasks.py | 10 ++++++---- backend/data_import/tests/test_tasks.py | 18 +++++++++++++++--- backend/data_import/views.py | 8 +------- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/backend/data_import/celery_tasks.py b/backend/data_import/celery_tasks.py index 7ff8703c..57f2bd90 100644 --- a/backend/data_import/celery_tasks.py +++ b/backend/data_import/celery_tasks.py @@ -14,15 +14,17 @@ from projects.models import Project @shared_task -def import_dataset(user_id, project_id, filenames, file_format: str, save_names=None, **kwargs): - if not save_names: - save_names = {filename: filename for filename in filenames} +def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], **kwargs): + temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids) + file_names = [tu.get_file_path() for tu in temporary_uploads] + save_names = {tu.get_file_path(): tu.file.name for tu in temporary_uploads} + project = get_object_or_404(Project, pk=project_id) user = get_object_or_404(get_user_model(), pk=user_id) parser = create_parser(file_format, **kwargs) builder = create_builder(project, **kwargs) - reader = Reader(filenames=filenames, parser=parser, builder=builder) + reader = Reader(filenames=file_names, parser=parser, builder=builder) cleaner = create_cleaner(project) writer = BulkWriter(batch_size=settings.IMPORT_BATCH_SIZE, save_names=save_names) writer.save(reader, project, user, cleaner) diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py index 3eba3cc8..5eb7b83f 100644 --- a/backend/data_import/tests/test_tasks.py +++ b/backend/data_import/tests/test_tasks.py @@ -1,6 +1,9 @@ +import os import pathlib -from django.test import TestCase +from django.core.files import File +from django.test import TestCase, override_settings +from django_drf_filepond.models import TemporaryUpload from data_import.celery_tasks import import_dataset from examples.models import Example @@ -16,6 +19,7 @@ from projects.models import ( from projects.tests.utils import prepare_project +@override_settings(MEDIA_ROOT=os.path.join(os.path.dirname(__file__), "data")) class TestImportData(TestCase): task = "Any" annotation_class = Category @@ -26,9 +30,17 @@ class TestImportData(TestCase): self.data_path = pathlib.Path(__file__).parent / "data" def import_dataset(self, filename, file_format, kwargs=None): - filenames = [str(self.data_path / filename)] + file_path = str(self.data_path / filename) + TemporaryUpload.objects.create( + upload_id="1", + file_id="1", + file=File(open(file_path, mode="rb"), filename), + upload_name=filename, + upload_type="F", + ) + upload_ids = ["1"] kwargs = kwargs or {} - return import_dataset(self.user.id, self.project.item.id, filenames, file_format, **kwargs) + return import_dataset(self.user.id, self.project.item.id, file_format, upload_ids, **kwargs) class TestImportClassificationData(TestImportData): diff --git a/backend/data_import/views.py b/backend/data_import/views.py index 280261db..fcb7b428 100644 --- a/backend/data_import/views.py +++ b/backend/data_import/views.py @@ -1,5 +1,4 @@ from django.shortcuts import get_object_or_404 -from django_drf_filepond.models import TemporaryUpload from rest_framework import status from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response @@ -29,16 +28,11 @@ class DatasetImportAPI(APIView): upload_ids = request.data.pop("uploadIds") file_format = request.data.pop("format") - temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids) - file_paths = [tu.get_file_path() for tu in temporary_uploads] - save_names = {tu.get_file_path(): tu.file.name for tu in temporary_uploads} - task = import_dataset.delay( user_id=request.user.id, project_id=project_id, - filenames=file_paths, file_format=file_format, - save_names=save_names, + upload_ids=upload_ids, **request.data, ) upload_task = upload_to_store.delay(upload_ids)