Browse Source

Pass upload_ids to import_dataset

pull/1779/head
Hironsan 2 years ago
parent
commit
62a1a6228e
3 changed files with 22 additions and 14 deletions
  1. 10
      backend/data_import/celery_tasks.py
  2. 18
      backend/data_import/tests/test_tasks.py
  3. 8
      backend/data_import/views.py

10
backend/data_import/celery_tasks.py

@ -14,15 +14,17 @@ from projects.models import Project
@shared_task
def import_dataset(user_id, project_id, filenames, file_format: str, save_names=None, **kwargs):
if not save_names:
save_names = {filename: filename for filename in filenames}
def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], **kwargs):
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
file_names = [tu.get_file_path() for tu in temporary_uploads]
save_names = {tu.get_file_path(): tu.file.name for tu in temporary_uploads}
project = get_object_or_404(Project, pk=project_id)
user = get_object_or_404(get_user_model(), pk=user_id)
parser = create_parser(file_format, **kwargs)
builder = create_builder(project, **kwargs)
reader = Reader(filenames=filenames, parser=parser, builder=builder)
reader = Reader(filenames=file_names, parser=parser, builder=builder)
cleaner = create_cleaner(project)
writer = BulkWriter(batch_size=settings.IMPORT_BATCH_SIZE, save_names=save_names)
writer.save(reader, project, user, cleaner)

18
backend/data_import/tests/test_tasks.py

@ -1,6 +1,9 @@
import os
import pathlib
from django.test import TestCase
from django.core.files import File
from django.test import TestCase, override_settings
from django_drf_filepond.models import TemporaryUpload
from data_import.celery_tasks import import_dataset
from examples.models import Example
@ -16,6 +19,7 @@ from projects.models import (
from projects.tests.utils import prepare_project
@override_settings(MEDIA_ROOT=os.path.join(os.path.dirname(__file__), "data"))
class TestImportData(TestCase):
task = "Any"
annotation_class = Category
@ -26,9 +30,17 @@ class TestImportData(TestCase):
self.data_path = pathlib.Path(__file__).parent / "data"
def import_dataset(self, filename, file_format, kwargs=None):
filenames = [str(self.data_path / filename)]
file_path = str(self.data_path / filename)
TemporaryUpload.objects.create(
upload_id="1",
file_id="1",
file=File(open(file_path, mode="rb"), filename),
upload_name=filename,
upload_type="F",
)
upload_ids = ["1"]
kwargs = kwargs or {}
return import_dataset(self.user.id, self.project.item.id, filenames, file_format, **kwargs)
return import_dataset(self.user.id, self.project.item.id, file_format, upload_ids, **kwargs)
class TestImportClassificationData(TestImportData):

8
backend/data_import/views.py

@ -1,5 +1,4 @@
from django.shortcuts import get_object_or_404
from django_drf_filepond.models import TemporaryUpload
from rest_framework import status
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
@ -29,16 +28,11 @@ class DatasetImportAPI(APIView):
upload_ids = request.data.pop("uploadIds")
file_format = request.data.pop("format")
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
file_paths = [tu.get_file_path() for tu in temporary_uploads]
save_names = {tu.get_file_path(): tu.file.name for tu in temporary_uploads}
task = import_dataset.delay(
user_id=request.user.id,
project_id=project_id,
filenames=file_paths,
file_format=file_format,
save_names=save_names,
upload_ids=upload_ids,
**request.data,
)
upload_task = upload_to_store.delay(upload_ids)

Loading…
Cancel
Save