diff --git a/backend/config/settings/base.py b/backend/config/settings/base.py
index 15ce8044..743220a9 100644
--- a/backend/config/settings/base.py
+++ b/backend/config/settings/base.py
@@ -247,6 +247,9 @@ MEDIA_URL = "/media/"
 DJANGO_DRF_FILEPOND_UPLOAD_TMP = path.join(BASE_DIR, "filepond-temp-uploads")
 DJANGO_DRF_FILEPOND_FILE_STORE_PATH = MEDIA_ROOT
 
+# File upload setting
+MAX_UPLOAD_SIZE = env.int("MAX_UPLOAD_SIZE", pow(1024, 3))  # default: 1 GB per file
+
 # Celery settings
 DJANGO_CELERY_RESULTS_TASK_ID_MAX_LENGTH = 191
 CELERY_RESULT_BACKEND = "django-db"
diff --git a/backend/data_import/celery_tasks.py b/backend/data_import/celery_tasks.py
index 57f2bd90..dc68071d 100644
--- a/backend/data_import/celery_tasks.py
+++ b/backend/data_import/celery_tasks.py
@@ -7,28 +7,43 @@
 from django.shortcuts import get_object_or_404
 from django_drf_filepond.api import store_upload
 from django_drf_filepond.models import TemporaryUpload
 
+from .pipeline.exceptions import MaximumFileSizeException
 from .pipeline.factories import create_builder, create_cleaner, create_parser
 from .pipeline.readers import Reader
 from .pipeline.writers import BulkWriter
 from projects.models import Project
 
 
+def check_uploaded_files(upload_ids: List[str]):
+    errors = []
+    cleaned_ids = []
+    temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
+    for tu in temporary_uploads:
+        if tu.file.size > settings.MAX_UPLOAD_SIZE:
+            errors.append(MaximumFileSizeException(tu.upload_name, settings.MAX_UPLOAD_SIZE).dict())
+            tu.delete()
+            continue
+        cleaned_ids.append(tu.upload_id)
+    return cleaned_ids, errors
+
+
 @shared_task
 def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], **kwargs):
+    project = get_object_or_404(Project, pk=project_id)
+    user = get_object_or_404(get_user_model(), pk=user_id)
+
+    upload_ids, errors = check_uploaded_files(upload_ids)
     temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
     file_names = [tu.get_file_path() for tu in temporary_uploads]
     save_names = {tu.get_file_path(): tu.file.name for tu in temporary_uploads}
 
-    project = get_object_or_404(Project, pk=project_id)
-    user = get_object_or_404(get_user_model(), pk=user_id)
-
     parser = create_parser(file_format, **kwargs)
     builder = create_builder(project, **kwargs)
     reader = Reader(filenames=file_names, parser=parser, builder=builder)
     cleaner = create_cleaner(project)
     writer = BulkWriter(batch_size=settings.IMPORT_BATCH_SIZE, save_names=save_names)
     writer.save(reader, project, user, cleaner)
-    return {"error": writer.errors}
+    return {"error": writer.errors + errors}
 
 
 @shared_task
diff --git a/backend/data_import/pipeline/exceptions.py b/backend/data_import/pipeline/exceptions.py
index 1a0c957f..da5dfc7f 100644
--- a/backend/data_import/pipeline/exceptions.py
+++ b/backend/data_import/pipeline/exceptions.py
@@ -9,3 +9,15 @@ class FileParseException(Exception):
 
     def dict(self):
         return {"filename": self.filename, "line": self.line_num, "message": self.message}
+
+
+class MaximumFileSizeException(Exception):
+    def __init__(self, filename: str, max_size: int):
+        self.filename = filename
+        self.max_size = max_size
+
+    def __str__(self):
+        return f"The maximum file size that can be uploaded is {self.max_size / 1024 / 1024} MB"
+
+    def dict(self):
+        return {"filename": self.filename, "line": -1, "message": str(self)}
diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py
index 5eb7b83f..684fefd8 100644
--- a/backend/data_import/tests/test_tasks.py
+++ b/backend/data_import/tests/test_tasks.py
@@ -4,6 +4,7 @@ import pathlib
 from django.core.files import File
 from django.test import TestCase, override_settings
 from django_drf_filepond.models import TemporaryUpload
+from django_drf_filepond.utils import _get_file_id
 
 from data_import.celery_tasks import import_dataset
 from examples.models import Example
@@ -31,18 +32,32 @@ class TestImportData(TestCase):
 
     def import_dataset(self, filename, file_format, kwargs=None):
         file_path = str(self.data_path / filename)
+        upload_id = _get_file_id()
         TemporaryUpload.objects.create(
-            upload_id="1",
+            upload_id=upload_id,
             file_id="1",
-            file=File(open(file_path, mode="rb"), filename),
+            file=File(open(file_path, mode="rb"), filename.split("/")[-1]),
             upload_name=filename,
             upload_type="F",
         )
-        upload_ids = ["1"]
+        upload_ids = [upload_id]
         kwargs = kwargs or {}
         return import_dataset(self.user.id, self.project.item.id, file_format, upload_ids, **kwargs)
 
 
+@override_settings(MAX_UPLOAD_SIZE=0)
+class TestMaxFileSize(TestImportData):
+    task = DOCUMENT_CLASSIFICATION
+
+    def test_jsonl(self):
+        filename = "text_classification/example.jsonl"
+        file_format = "JSONL"
+        kwargs = {"column_label": "labels"}
+        response = self.import_dataset(filename, file_format, kwargs)
+        self.assertEqual(len(response["error"]), 1)
+        self.assertIn("maximum file size", response["error"][0]["message"])
+
+
 class TestImportClassificationData(TestImportData):
     task = DOCUMENT_CLASSIFICATION
 
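
Notes (not part of the diff):

`MaximumFileSizeException.dict()` reuses the error shape of `FileParseException`, setting `line` to `-1` to mark a file-level rather than per-line failure, so size errors can simply be concatenated onto `writer.errors` and rendered by the same client code. A minimal sketch, not part of the patch, of the payload `import_dataset` now returns for a rejected file (values are illustrative):

```python
# Sketch of the {"error": [...]} payload for an oversized upload.
MAX_UPLOAD_SIZE = 1024 ** 3  # mirrors the new default in settings/base.py (1 GB)

def size_error(filename: str, max_size: int) -> dict:
    # Mirrors MaximumFileSizeException.dict(); line == -1 means "whole file".
    message = f"The maximum file size that can be uploaded is {max_size / 1024 / 1024} MB"
    return {"filename": filename, "line": -1, "message": message}

print({"error": [size_error("example.jsonl", MAX_UPLOAD_SIZE)]})
# {'error': [{'filename': 'example.jsonl', 'line': -1,
#             'message': 'The maximum file size that can be uploaded is 1024.0 MB'}]}
```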
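Because `check_uploaded_files` calls `tu.delete()` on each rejected `TemporaryUpload`, oversized files are removed from the FilePond temporary store immediately rather than lingering on disk. The limit itself comes from the environment, so deployments can tune it without a code change; a hedged sketch of that parsing (reading the variable with `os.environ` here only approximates the `env` helper used in `settings/base.py`):

```python
# Illustrative only: resolve MAX_UPLOAD_SIZE (in bytes) with the patch's default.
import os

limit = int(os.environ.get("MAX_UPLOAD_SIZE", 1024 ** 3))  # default: 1 GB
print(f"uploads capped at {limit / 1024 / 1024:.1f} MB")   # 1024.0 MB by default
```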