diff --git a/backend/config/settings/base.py b/backend/config/settings/base.py index 743220a9..08ade658 100644 --- a/backend/config/settings/base.py +++ b/backend/config/settings/base.py @@ -248,7 +248,8 @@ DJANGO_DRF_FILEPOND_UPLOAD_TMP = path.join(BASE_DIR, "filepond-temp-uploads") DJANGO_DRF_FILEPOND_FILE_STORE_PATH = MEDIA_ROOT # File upload setting -MAX_UPLOAD_SIZE = env("MAX_UPLOAD_SIZE", pow(1024, 3)) # default: 1GB per a file +MAX_UPLOAD_SIZE = env.int("MAX_UPLOAD_SIZE", pow(1024, 3)) # default: 1GB per a file +ENABLE_FILE_TYPE_CHECK = env.bool("ENABLE_FILE_TYPE_CHECK", False) # Celery settings DJANGO_CELERY_RESULTS_TASK_ID_MAX_LENGTH = 191 diff --git a/backend/data_import/celery_tasks.py b/backend/data_import/celery_tasks.py index a095be86..8aacd16a 100644 --- a/backend/data_import/celery_tasks.py +++ b/backend/data_import/celery_tasks.py @@ -1,5 +1,6 @@ from typing import List +import filetype from celery import shared_task from django.conf import settings from django.contrib.auth import get_user_model @@ -7,14 +8,29 @@ from django.shortcuts import get_object_or_404 from django_drf_filepond.api import store_upload from django_drf_filepond.models import TemporaryUpload -from .pipeline.exceptions import MaximumFileSizeException +from .pipeline.catalog import AudioFile, ImageFile +from .pipeline.exceptions import FileTypeException, MaximumFileSizeException from .pipeline.factories import create_builder, create_cleaner, create_parser from .pipeline.readers import Reader from .pipeline.writers import BulkWriter from projects.models import Project -def check_uploaded_files(upload_ids: List[str]): +def check_file_type(filename, file_format: str, filepath: str): + if not settings.ENABLE_FILE_TYPE_CHECK: + return + kind = filetype.guess(filepath) + if file_format == ImageFile.name: + accept_types = ImageFile.accept_types.replace(" ", "").split(",") + elif file_format == AudioFile.name: + accept_types = AudioFile.accept_types.replace(" ", "").split(",") + else: + return + if kind.mime not in accept_types: + raise FileTypeException(filename, kind.mime, accept_types) + + +def check_uploaded_files(upload_ids: List[str], file_format: str): errors = [] cleaned_ids = [] temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids) @@ -23,6 +39,10 @@ def check_uploaded_files(upload_ids: List[str]): errors.append(MaximumFileSizeException(tu.upload_name, settings.MAX_UPLOAD_SIZE).dict()) tu.delete() continue + try: + check_file_type(tu.upload_name, file_format, tu.get_file_path()) + except FileTypeException as e: + errors.append(e.dict()) cleaned_ids.append(tu.upload_id) return cleaned_ids, errors @@ -32,7 +52,7 @@ def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], project = get_object_or_404(Project, pk=project_id) user = get_object_or_404(get_user_model(), pk=user_id) - upload_ids, errors = check_uploaded_files(upload_ids) + upload_ids, errors = check_uploaded_files(upload_ids, file_format) temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids) file_names = [tu.get_file_path() for tu in temporary_uploads] save_names = {tu.get_file_path(): tu.file.name for tu in temporary_uploads} diff --git a/backend/data_import/pipeline/exceptions.py b/backend/data_import/pipeline/exceptions.py index da5dfc7f..d9790ff0 100644 --- a/backend/data_import/pipeline/exceptions.py +++ b/backend/data_import/pipeline/exceptions.py @@ -21,3 +21,16 @@ class MaximumFileSizeException(Exception): def dict(self): return {"filename": self.filename, "line": -1, "message": str(self)} + + +class FileTypeException(Exception): + def __init__(self, filename: str, filetype: str, allowed_types=None): + self.filename = filename + self.filetype = filetype + self.allowed_types = allowed_types + + def __str__(self): + return f"The file type {self.filetype} is unexpected. Expected: {self.allowed_types}" + + def dict(self): + return {"filename": self.filename, "line": -1, "message": str(self)} diff --git a/backend/data_import/tests/data/images/1500x500.jpeg b/backend/data_import/tests/data/images/1500x500.jpeg new file mode 100644 index 00000000..bbada87e Binary files /dev/null and b/backend/data_import/tests/data/images/1500x500.jpeg differ diff --git a/backend/data_import/tests/data/images/example.ico b/backend/data_import/tests/data/images/example.ico new file mode 100644 index 00000000..b47d22c2 Binary files /dev/null and b/backend/data_import/tests/data/images/example.ico differ diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py index 81e29788..7e957702 100644 --- a/backend/data_import/tests/test_tasks.py +++ b/backend/data_import/tests/test_tasks.py @@ -273,3 +273,15 @@ class TestImportImageClassificationData(TestImportData): file_format = "ImageFile" self.import_dataset(filename, file_format) self.assertEqual(Example.objects.count(), 1) + + +@override_settings(ENABLE_FILE_TYPE_CHECK=True) +class TestFileTypeChecking(TestImportData): + task = IMAGE_CLASSIFICATION + + def test_example(self): + filename = "images/example.ico" + file_format = "ImageFile" + response = self.import_dataset(filename, file_format) + self.assertEqual(len(response["error"]), 1) + self.assertIn("unexpected", response["error"][0]["message"]) diff --git a/backend/poetry.lock b/backend/poetry.lock index c6cb038f..5700973b 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -535,6 +535,14 @@ category = "main" optional = false python-versions = ">=3.6" +[[package]] +name = "filetype" +version = "1.0.10" +description = "Infer file type and MIME type of any file/buffer. No external dependencies." +category = "main" +optional = false +python-versions = "*" + [[package]] name = "flake8" version = "4.0.1" @@ -1478,7 +1486,7 @@ postgresql = [] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "2b7d21429a29cfe0f07be53971292f593d955bad855f76e0d05f22cbfd60baf9" +content-hash = "aba19bd73c1d181ff5a54ea860e255055b0f9ed7ad9788605f989f03715bad64" [metadata.files] amqp = [ @@ -1714,6 +1722,10 @@ et-xmlfile = [ {file = "et_xmlfile-1.1.0-py3-none-any.whl", hash = "sha256:a2ba85d1d6a74ef63837eed693bcb89c3f752169b0e3e7ae5b16ca5e1b3deada"}, {file = "et_xmlfile-1.1.0.tar.gz", hash = "sha256:8eb9e2bc2f8c97e37a2dc85a09ecdcdec9d8a396530a6d5a33b30b9a92da0c5c"}, ] +filetype = [ + {file = "filetype-1.0.10-py2.py3-none-any.whl", hash = "sha256:63fbe6e818a3d1cfac1d62b196574a7a4b7fc8e06a6c500d53577c018ef127d9"}, + {file = "filetype-1.0.10.tar.gz", hash = "sha256:323a13500731b6c65a253bc3930bbce9a56dfba71e90b60ffd968ab69d9ae937"}, +] flake8 = [ {file = "flake8-4.0.1-py2.py3-none-any.whl", hash = "sha256:479b1304f72536a55948cb40a32dce8bb0ffe3501e26eaf292c7e60eb5e0428d"}, {file = "flake8-4.0.1.tar.gz", hash = "sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d"}, diff --git a/backend/pyproject.toml b/backend/pyproject.toml index d0f8f1a0..a315a7be 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -62,6 +62,7 @@ django-health-check = "^3.16.5" djangorestframework-xml = "^2.0.0" django-storages = {extras = ["google"], version = "^1.12.3"} django-cleanup = "^6.0.0" +filetype = "^1.0.10" [tool.poetry.dev-dependencies] model-mommy = "^2.0.0"