from typing import List

import filetype
from celery import shared_task
from django.conf import settings
from django.contrib.auth import get_user_model
from django.shortcuts import get_object_or_404
from django_drf_filepond.api import store_upload
from django_drf_filepond.models import TemporaryUpload

from .pipeline.catalog import AudioFile, ImageFile
from .pipeline.exceptions import FileTypeException, MaximumFileSizeException
from .pipeline.factories import (
    create_builder,
    create_cleaner,
    create_parser,
    select_examples,
)
from .pipeline.readers import FileName, Reader
from .pipeline.writers import Writer
from projects.models import Project


def check_file_type(filename, file_format: str, filepath: str):
    """Reject image/audio uploads whose detected MIME type is not accepted."""
    if not settings.ENABLE_FILE_TYPE_CHECK:
        return
    kind = filetype.guess(filepath)
    if file_format == ImageFile.name:
        accept_types = ImageFile.accept_types.replace(" ", "").split(",")
    elif file_format == AudioFile.name:
        accept_types = AudioFile.accept_types.replace(" ", "").split(",")
    else:
        # Only image and audio formats are type-checked.
        return
    # filetype.guess returns None when the type cannot be detected.
    if kind is None or kind.mime not in accept_types:
        detected = kind.mime if kind else "unknown"
        raise FileTypeException(filename, detected, accept_types)


def check_uploaded_files(upload_ids: List[str], file_format: str):
    """Drop oversized uploads and collect file-type errors before importing."""
    errors = []
    cleaned_ids = []
    temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
    for tu in temporary_uploads:
        if tu.file.size > settings.MAX_UPLOAD_SIZE:
            errors.append(MaximumFileSizeException(tu.upload_name, settings.MAX_UPLOAD_SIZE).dict())
            tu.delete()
            continue
        try:
            check_file_type(tu.upload_name, file_format, tu.get_file_path())
        except FileTypeException as e:
            errors.append(e.dict())
        cleaned_ids.append(tu.upload_id)
    return cleaned_ids, errors


@shared_task
def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], **kwargs):
    """Parse the validated uploads and import them as examples of the project."""
    project = get_object_or_404(Project, pk=project_id)
    user = get_object_or_404(get_user_model(), pk=user_id)

    upload_ids, errors = check_uploaded_files(upload_ids, file_format)
    temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
    filenames = [
        FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
        for tu in temporary_uploads
    ]

    parser = create_parser(file_format, **kwargs)
    builder = create_builder(project, **kwargs)
    cleaner = create_cleaner(project)
    reader = Reader(filenames=filenames, parser=parser, builder=builder, cleaner=cleaner)
    writer = Writer(batch_size=settings.IMPORT_BATCH_SIZE)
    examples = select_examples(project)
    writer.save(reader, project, user, examples)
    upload_to_store(temporary_uploads)
    return {"error": reader.errors + errors}


def upload_to_store(temporary_uploads):
    """Move the temporary FilePond uploads to permanent storage."""
    for tu in temporary_uploads:
        store_upload(tu.upload_id, destination_file_path=tu.file.name)
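

# Usage sketch (not part of the module): a hedged example of how this task
# might be queued from a Django view or shell via Celery's .delay(). The
# upload IDs below are hypothetical django-drf-filepond temporary-upload
# identifiers, and "CSV" is assumed to be one of the file_format names
# registered in pipeline.catalog. Kept commented out so importing the module
# never executes it or references view-only names such as `request`.
#
#   import_dataset.delay(
#       user_id=request.user.id,
#       project_id=project.id,
#       file_format="CSV",
#       upload_ids=["filepond-tmp-upload-id-1", "filepond-tmp-upload-id-2"],
#   )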