diff --git a/backend/data_import/celery_tasks.py b/backend/data_import/celery_tasks.py
index 7f3c4249..c51b7557 100644
--- a/backend/data_import/celery_tasks.py
+++ b/backend/data_import/celery_tasks.py
@@ -9,7 +9,7 @@ from django_drf_filepond.api import store_upload
 from django_drf_filepond.models import TemporaryUpload
 
 from .datasets import load_dataset
-from .pipeline.catalog import AudioFile, ImageFile
+from .pipeline.catalog import Format, create_file_format
 from .pipeline.exceptions import (
     FileImportException,
     FileTypeException,
@@ -19,21 +19,15 @@ from .pipeline.readers import FileName
 from projects.models import Project
 
 
-def check_file_type(filename, file_format: str, filepath: str):
+def check_file_type(filename, file_format: Format, filepath: str):
     if not settings.ENABLE_FILE_TYPE_CHECK:
         return
     kind = filetype.guess(filepath)
-    if file_format == ImageFile.name:
-        accept_types = ImageFile.accept_types.replace(" ", "").split(",")
-    elif file_format == AudioFile.name:
-        accept_types = AudioFile.accept_types.replace(" ", "").split(",")
-    else:
-        return
-    if kind.mime not in accept_types:
-        raise FileTypeException(filename, kind.mime, accept_types)
+    if not file_format.validate_mime(kind.mime):
+        raise FileTypeException(filename, kind.mime, file_format.accept_types)
 
 
-def check_uploaded_files(upload_ids: List[str], file_format: str):
+def check_uploaded_files(upload_ids: List[str], file_format: Format):
     errors: List[FileImportException] = []
     cleaned_ids = []
     temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
@@ -56,19 +50,22 @@ def check_uploaded_files(upload_ids: List[str], file_format: str):
 def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], task: str, **kwargs):
     project = get_object_or_404(Project, pk=project_id)
     user = get_object_or_404(get_user_model(), pk=user_id)
+    try:
+        fmt = create_file_format(file_format)
+        upload_ids, errors = check_uploaded_files(upload_ids, fmt)
+        temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
+        filenames = [
+            FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
+            for tu in temporary_uploads
+        ]
 
-    upload_ids, errors = check_uploaded_files(upload_ids, file_format)
-    temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
-    filenames = [
-        FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
-        for tu in temporary_uploads
-    ]
-
-    dataset = load_dataset(task, file_format, filenames, project, **kwargs)
-    dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
-    upload_to_store(temporary_uploads)
-    errors.extend(dataset.errors)
-    return {"error": [e.dict() for e in errors]}
+        dataset = load_dataset(task, fmt, filenames, project, **kwargs)
+        dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
+        upload_to_store(temporary_uploads)
+        errors.extend(dataset.errors)
+        return {"error": [e.dict() for e in errors]}
+    except FileImportException as e:
+        return {"error": [e.dict()]}
 
 
 def upload_to_store(temporary_uploads):
diff --git a/backend/data_import/datasets.py b/backend/data_import/datasets.py
index 76e8b3e2..e1512dd5 100644
--- a/backend/data_import/datasets.py
+++ b/backend/data_import/datasets.py
@@ -4,7 +4,7 @@ from typing import List, Type
 from django.contrib.auth.models import User
 
 from .models import DummyLabelType
-from .pipeline.catalog import RELATION_EXTRACTION, TextFile, TextLine
+from .pipeline.catalog import RELATION_EXTRACTION, Format
 from .pipeline.data import BaseData, BinaryData, TextData
 from .pipeline.exceptions import FileParseException
 from .pipeline.factories import create_parser
@@ -210,7 +210,7 @@ class CategoryAndSpanDataset(Dataset):
         return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors
 
 
-def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]:
+def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
     mapping = {
         DOCUMENT_CLASSIFICATION: TextClassificationDataset,
         SEQUENCE_LABELING: SequenceLabelingDataset,
@@ -222,12 +222,12 @@ def select_dataset(project: Project, task: str, file_format: str) -> Type[Datase
     }
     if task not in mapping:
         task = project.project_type
-    if project.is_text_project and file_format in [TextLine.name, TextFile.name]:
+    if project.is_text_project and file_format.is_plain_text():
         return PlainDataset
     return mapping[task]
 
 
-def load_dataset(task: str, file_format: str, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
+def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
     parser = create_parser(file_format, **kwargs)
     reader = Reader(data_files, parser)
     dataset_class = select_dataset(project, task, file_format)
diff --git a/backend/data_import/pipeline/catalog.py b/backend/data_import/pipeline/catalog.py
index ba1330ef..1feb24dc 100644
--- a/backend/data_import/pipeline/catalog.py
+++ b/backend/data_import/pipeline/catalog.py
@@ -6,6 +6,7 @@ from typing import Dict, List, Type
 from pydantic import BaseModel
 from typing_extensions import Literal
 
+from .exceptions import FileFormatException
 from projects.models import (
     DOCUMENT_CLASSIFICATION,
     IMAGE_CLASSIFICATION,
@@ -140,6 +141,13 @@ class Format:
     def dict(cls):
         return {"name": cls.name, "accept_types": cls.accept_types}
 
+    def validate_mime(self, mime: str):
+        return True
+
+    @staticmethod
+    def is_plain_text():
+        return False
+
 
 class CSV(Format):
     name = "CSV"
@@ -170,11 +178,19 @@ class TextFile(Format):
     name = "TextFile"
     accept_types = "text/*"
 
+    @staticmethod
+    def is_plain_text():
+        return True
+
 
 class TextLine(Format):
     name = "TextLine"
     accept_types = "text/*"
 
+    @staticmethod
+    def is_plain_text():
+        return True
+
 
 class CoNLL(Format):
     name = "CoNLL"
@@ -185,11 +201,17 @@ class ImageFile(Format):
     name = "ImageFile"
     accept_types = "image/png, image/jpeg, image/bmp, image/gif"
 
+    def validate_mime(self, mime: str):
+        return mime in self.accept_types
+
 
 class AudioFile(Format):
     name = "AudioFile"
     accept_types = "audio/ogg, audio/aac, audio/mpeg, audio/wav"
 
+    def validate_mime(self, mime: str):
+        return mime in self.accept_types
+
 
 class ArgColumn(BaseModel):
     encoding: encodings = "utf_8"
@@ -239,6 +261,13 @@ class Option:
     }
 
 
+def create_file_format(file_format: str) -> Format:
+    for format_class in Format.__subclasses__():
+        if format_class.name == file_format:
+            return format_class()
+    raise FileFormatException(file_format)
+
+
 class Options:
     options: Dict[str, List] = defaultdict(list)
 
diff --git a/backend/data_import/pipeline/exceptions.py b/backend/data_import/pipeline/exceptions.py
index efab9ca7..93fd157c 100644
--- a/backend/data_import/pipeline/exceptions.py
+++ b/backend/data_import/pipeline/exceptions.py
@@ -42,3 +42,12 @@
 
     def dict(self):
         return {"filename": self.filename, "line": -1, "message": str(self)}
+
+
+class FileFormatException(FileImportException):
+    def __init__(self, file_format: str):
+        self.file_format = file_format
+
+    def dict(self):
+        message = f"Unknown file format: {self.file_format}"
+        return {"message": message}
diff --git a/backend/data_import/pipeline/factories.py b/backend/data_import/pipeline/factories.py
index 05d8b0a1..b9a68bac 100644
--- a/backend/data_import/pipeline/factories.py
+++ b/backend/data_import/pipeline/factories.py
@@ -6,6 +6,7 @@ from .catalog import (
     CoNLL,
     Excel,
     FastText,
+    Format,
     ImageFile,
     TextFile,
     TextLine,
@@ -23,7 +24,7 @@ from .parsers import (
 )
 
 
-def create_parser(file_format: str, **kwargs):
+def create_parser(file_format: Format, **kwargs):
     mapping = {
         TextFile.name: TextFileParser,
         TextLine.name: LineParser,
@@ -36,6 +37,4 @@
         ImageFile.name: PlainParser,
         AudioFile.name: PlainParser,
     }
-    if file_format not in mapping:
-        raise ValueError(f"Invalid format: {file_format}")
-    return mapping[file_format](**kwargs)
+    return mapping[file_format.name](**kwargs)
diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py
index 888a0bd8..d6696916 100644
--- a/backend/data_import/tests/test_tasks.py
+++ b/backend/data_import/tests/test_tasks.py
@@ -68,6 +68,16 @@
         self.assertIn("maximum file size", response["error"][0]["message"])
 
 
+class TestInvalidFileFormat(TestImportData):
+    task = DOCUMENT_CLASSIFICATION
+
+    def test_invalid_file_format(self):
+        filename = "text_classification/example.csv"
+        file_format = "INVALID_FORMAT"
+        response = self.import_dataset(filename, file_format, self.task)
+        self.assertEqual(len(response["error"]), 1)
+
+
 class TestImportClassificationData(TestImportData):
     task = DOCUMENT_CLASSIFICATION