|
|
@ -1,5 +1,6 @@ |
|
|
|
from typing import List |
|
|
|
|
|
|
|
import filetype |
|
|
|
from celery import shared_task |
|
|
|
from django.conf import settings |
|
|
|
from django.contrib.auth import get_user_model |
|
|
@ -7,14 +8,29 @@ from django.shortcuts import get_object_or_404 |
|
|
|
from django_drf_filepond.api import store_upload |
|
|
|
from django_drf_filepond.models import TemporaryUpload |
|
|
|
|
|
|
|
from .pipeline.exceptions import MaximumFileSizeException |
|
|
|
from .pipeline.catalog import AudioFile, ImageFile |
|
|
|
from .pipeline.exceptions import FileTypeException, MaximumFileSizeException |
|
|
|
from .pipeline.factories import create_builder, create_cleaner, create_parser |
|
|
|
from .pipeline.readers import Reader |
|
|
|
from .pipeline.writers import BulkWriter |
|
|
|
from projects.models import Project |
|
|
|
|
|
|
|
|
|
|
|
def check_uploaded_files(upload_ids: List[str]): |
|
|
|
def check_file_type(filename, file_format: str, filepath: str): |
|
|
|
if not settings.ENABLE_FILE_TYPE_CHECK: |
|
|
|
return |
|
|
|
kind = filetype.guess(filepath) |
|
|
|
if file_format == ImageFile.name: |
|
|
|
accept_types = ImageFile.accept_types.replace(" ", "").split(",") |
|
|
|
elif file_format == AudioFile.name: |
|
|
|
accept_types = AudioFile.accept_types.replace(" ", "").split(",") |
|
|
|
else: |
|
|
|
return |
|
|
|
if kind.mime not in accept_types: |
|
|
|
raise FileTypeException(filename, kind.mime, accept_types) |
|
|
|
|
|
|
|
|
|
|
|
def check_uploaded_files(upload_ids: List[str], file_format: str): |
|
|
|
errors = [] |
|
|
|
cleaned_ids = [] |
|
|
|
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids) |
|
|
@ -23,6 +39,10 @@ def check_uploaded_files(upload_ids: List[str]): |
|
|
|
errors.append(MaximumFileSizeException(tu.upload_name, settings.MAX_UPLOAD_SIZE).dict()) |
|
|
|
tu.delete() |
|
|
|
continue |
|
|
|
try: |
|
|
|
check_file_type(tu.upload_name, file_format, tu.get_file_path()) |
|
|
|
except FileTypeException as e: |
|
|
|
errors.append(e.dict()) |
|
|
|
cleaned_ids.append(tu.upload_id) |
|
|
|
return cleaned_ids, errors |
|
|
|
|
|
|
@ -32,7 +52,7 @@ def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], |
|
|
|
project = get_object_or_404(Project, pk=project_id) |
|
|
|
user = get_object_or_404(get_user_model(), pk=user_id) |
|
|
|
|
|
|
|
upload_ids, errors = check_uploaded_files(upload_ids) |
|
|
|
upload_ids, errors = check_uploaded_files(upload_ids, file_format) |
|
|
|
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids) |
|
|
|
file_names = [tu.get_file_path() for tu in temporary_uploads] |
|
|
|
save_names = {tu.get_file_path(): tu.file.name for tu in temporary_uploads} |
|
|
|