diff --git a/backend/data_import/celery_tasks.py b/backend/data_import/celery_tasks.py
index 7f3c4249..c51b7557 100644
--- a/backend/data_import/celery_tasks.py
+++ b/backend/data_import/celery_tasks.py
@@ -9,7 +9,7 @@ from django_drf_filepond.api import store_upload
 from django_drf_filepond.models import TemporaryUpload
 
 from .datasets import load_dataset
-from .pipeline.catalog import AudioFile, ImageFile
+from .pipeline.catalog import Format, create_file_format
 from .pipeline.exceptions import (
     FileImportException,
     FileTypeException,
@@ -19,21 +19,15 @@ from .pipeline.readers import FileName
 from projects.models import Project
 
 
-def check_file_type(filename, file_format: str, filepath: str):
+def check_file_type(filename, file_format: Format, filepath: str):
     if not settings.ENABLE_FILE_TYPE_CHECK:
         return
     kind = filetype.guess(filepath)
-    if file_format == ImageFile.name:
-        accept_types = ImageFile.accept_types.replace(" ", "").split(",")
-    elif file_format == AudioFile.name:
-        accept_types = AudioFile.accept_types.replace(" ", "").split(",")
-    else:
-        return
-    if kind.mime not in accept_types:
-        raise FileTypeException(filename, kind.mime, accept_types)
+    if not file_format.validate_mime(kind.mime):
+        raise FileTypeException(filename, kind.mime, file_format.accept_types)
 
 
-def check_uploaded_files(upload_ids: List[str], file_format: str):
+def check_uploaded_files(upload_ids: List[str], file_format: Format):
     errors: List[FileImportException] = []
     cleaned_ids = []
     temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
@@ -56,19 +50,22 @@ def check_uploaded_files(upload_ids: List[str], file_format: str):
 def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], task: str, **kwargs):
     project = get_object_or_404(Project, pk=project_id)
     user = get_object_or_404(get_user_model(), pk=user_id)
+    try:
+        fmt = create_file_format(file_format)
+        upload_ids, errors = check_uploaded_files(upload_ids, fmt)
+        temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
+        filenames = [
+            FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
+            for tu in temporary_uploads
+        ]
 
-    upload_ids, errors = check_uploaded_files(upload_ids, file_format)
-    temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
-    filenames = [
-        FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
-        for tu in temporary_uploads
-    ]
-
-    dataset = load_dataset(task, file_format, filenames, project, **kwargs)
-    dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
-    upload_to_store(temporary_uploads)
-    errors.extend(dataset.errors)
-    return {"error": [e.dict() for e in errors]}
+        dataset = load_dataset(task, fmt, filenames, project, **kwargs)
+        dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
+        upload_to_store(temporary_uploads)
+        errors.extend(dataset.errors)
+        return {"error": [e.dict() for e in errors]}
+    except FileImportException as e:
+        return {"error": [e.dict()]}
 
 
 def upload_to_store(temporary_uploads):
diff --git a/backend/data_import/datasets.py b/backend/data_import/datasets.py
index 76e8b3e2..e1512dd5 100644
--- a/backend/data_import/datasets.py
+++ b/backend/data_import/datasets.py
@@ -4,7 +4,7 @@ from typing import List, Type
 from django.contrib.auth.models import User
 
 from .models import DummyLabelType
-from .pipeline.catalog import RELATION_EXTRACTION, TextFile, TextLine
+from .pipeline.catalog import RELATION_EXTRACTION, Format
 from .pipeline.data import BaseData, BinaryData, TextData
 from .pipeline.exceptions import FileParseException
 from .pipeline.factories import create_parser
@@ -210,7 +210,7 @@ class CategoryAndSpanDataset(Dataset):
         return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors
 
 
-def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]:
+def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
     mapping = {
         DOCUMENT_CLASSIFICATION: TextClassificationDataset,
         SEQUENCE_LABELING: SequenceLabelingDataset,
@@ -222,12 +222,12 @@ def select_dataset(project: Project, task: str, file_format: str) -> Type[Datase
     }
     if task not in mapping:
         task = project.project_type
-    if project.is_text_project and file_format in [TextLine.name, TextFile.name]:
+    if project.is_text_project and file_format.is_plain_text():
         return PlainDataset
     return mapping[task]
 
 
-def load_dataset(task: str, file_format: str, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
+def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
     parser = create_parser(file_format, **kwargs)
     reader = Reader(data_files, parser)
     dataset_class = select_dataset(project, task, file_format)
diff --git a/backend/data_import/pipeline/catalog.py b/backend/data_import/pipeline/catalog.py
index ba1330ef..1feb24dc 100644
--- a/backend/data_import/pipeline/catalog.py
+++ b/backend/data_import/pipeline/catalog.py
@@ -6,6 +6,7 @@ from typing import Dict, List, Type
 from pydantic import BaseModel
 from typing_extensions import Literal
 
+from .exceptions import FileFormatException
 from projects.models import (
     DOCUMENT_CLASSIFICATION,
     IMAGE_CLASSIFICATION,
@@ -140,6 +141,13 @@ class Format:
     def dict(cls):
         return {"name": cls.name, "accept_types": cls.accept_types}
 
+    def validate_mime(self, mime: str):
+        return True
+
+    @staticmethod
+    def is_plain_text():
+        return False
+
 
 class CSV(Format):
     name = "CSV"
@@ -170,11 +178,19 @@ class TextFile(Format):
     name = "TextFile"
     accept_types = "text/*"
 
+    @staticmethod
+    def is_plain_text():
+        return True
+
 
 class TextLine(Format):
     name = "TextLine"
     accept_types = "text/*"
 
+    @staticmethod
+    def is_plain_text():
+        return True
+
 
 class CoNLL(Format):
     name = "CoNLL"
@@ -185,11 +201,17 @@ class ImageFile(Format):
     name = "ImageFile"
     accept_types = "image/png, image/jpeg, image/bmp, image/gif"
 
+    def validate_mime(self, mime: str):
+        return mime in self.accept_types
+
 
 class AudioFile(Format):
     name = "AudioFile"
     accept_types = "audio/ogg, audio/aac, audio/mpeg, audio/wav"
 
+    def validate_mime(self, mime: str):
+        return mime in self.accept_types
+
 
 class ArgColumn(BaseModel):
     encoding: encodings = "utf_8"
@@ -239,6 +261,13 @@ class Option:
         }
 
 
+def create_file_format(file_format: str) -> Format:
+    for format_class in Format.__subclasses__():
+        if format_class.name == file_format:
+            return format_class()
+    raise FileFormatException(file_format)
+
+
 class Options:
     options: Dict[str, List] = defaultdict(list)
 
diff --git a/backend/data_import/pipeline/exceptions.py b/backend/data_import/pipeline/exceptions.py
index efab9ca7..93fd157c 100644
--- a/backend/data_import/pipeline/exceptions.py
+++ b/backend/data_import/pipeline/exceptions.py
@@ -42,3 +42,12 @@ class FileTypeException(FileImportException):
 
     def dict(self):
         return {"filename": self.filename, "line": -1, "message": str(self)}
+
+
+class FileFormatException(FileImportException):
+    def __init__(self, file_format: str):
+        self.file_format = file_format
+
+    def dict(self):
+        message = f"Unknown file format: {self.file_format}"
+        return {"message": message}
diff --git a/backend/data_import/pipeline/factories.py b/backend/data_import/pipeline/factories.py
index 05d8b0a1..b9a68bac 100644
--- a/backend/data_import/pipeline/factories.py
+++ b/backend/data_import/pipeline/factories.py
@@ -6,6 +6,7 @@ from .catalog import (
     CoNLL,
     Excel,
     FastText,
+    Format,
     ImageFile,
     TextFile,
     TextLine,
@@ -23,7 +24,7 @@ from .parsers import (
 )
 
 
-def create_parser(file_format: str, **kwargs):
+def create_parser(file_format: Format, **kwargs):
     mapping = {
         TextFile.name: TextFileParser,
         TextLine.name: LineParser,
@@ -36,6 +37,4 @@ def create_parser(file_format: str, **kwargs):
         ImageFile.name: PlainParser,
         AudioFile.name: PlainParser,
     }
-    if file_format not in mapping:
-        raise ValueError(f"Invalid format: {file_format}")
-    return mapping[file_format](**kwargs)
+    return mapping[file_format.name](**kwargs)
diff --git a/backend/data_import/pipeline/label.py b/backend/data_import/pipeline/label.py
index a70cfd9a..bdff2d43 100644
--- a/backend/data_import/pipeline/label.py
+++ b/backend/data_import/pipeline/label.py
@@ -2,7 +2,7 @@ import abc
 import uuid
 from typing import Any, Optional
 
-from pydantic import UUID4, BaseModel, validator
+from pydantic import UUID4, BaseModel, ConstrainedStr, NonNegativeInt, root_validator
 
 from .label_types import LabelTypes
 from examples.models import Example
@@ -15,6 +15,10 @@ from labels.models import TextLabel as TextLabelModel
 from projects.models import Project
 
 
+class NonEmptyStr(ConstrainedStr):
+    min_length = 1
+
+
 class Label(BaseModel, abc.ABC):
     id: int = -1
     uuid: UUID4
@@ -45,18 +49,11 @@ class Label(BaseModel, abc.ABC):
 
 
 class CategoryLabel(Label):
-    label: str
+    label: NonEmptyStr
 
     def __lt__(self, other):
         return self.label < other.label
 
-    @validator("label")
-    def label_is_not_empty(cls, value: str):
-        if value:
-            return value
-        else:
-            raise ValueError("is not empty.")
-
     @classmethod
     def parse(cls, example_uuid: UUID4, obj: Any):
         return cls(example_uuid=example_uuid, label=obj)
@@ -65,17 +62,24 @@ class CategoryLabel(Label):
         return CategoryType(text=self.label, project=project)
 
     def create(self, user, example: Example, types: LabelTypes, **kwargs):
-        return CategoryModel(uuid=self.uuid, user=user, example=example, label=types.get_by_text(self.label))
+        return CategoryModel(uuid=self.uuid, user=user, example=example, label=types[self.label])
 
 
 class SpanLabel(Label):
-    label: str
-    start_offset: int
-    end_offset: int
+    label: NonEmptyStr
+    start_offset: NonNegativeInt
+    end_offset: NonNegativeInt
 
     def __lt__(self, other):
         return self.start_offset < other.start_offset
 
+    @root_validator
+    def check_start_offset_is_less_than_end_offset(cls, values):
+        start_offset, end_offset = values.get("start_offset"), values.get("end_offset")
+        if start_offset >= end_offset:
+            raise ValueError("start_offset must be less than end_offset.")
+        return values
+
     @classmethod
     def parse(cls, example_uuid: UUID4, obj: Any):
         if isinstance(obj, list) or isinstance(obj, tuple):
@@ -96,23 +100,16 @@ class SpanLabel(Label):
             example=example,
             start_offset=self.start_offset,
             end_offset=self.end_offset,
-            label=types.get_by_text(self.label),
+            label=types[self.label],
         )
 
 
 class TextLabel(Label):
-    text: str
+    text: NonEmptyStr
 
     def __lt__(self, other):
         return self.text < other.text
 
-    @validator("text")
-    def text_is_not_empty(cls, value: str):
-        if value:
-            return value
-        else:
-            raise ValueError("is not empty.")
-
     @classmethod
     def parse(cls, example_uuid: UUID4, obj: Any):
         return cls(example_uuid=example_uuid, text=obj)
@@ -127,7 +124,7 @@ class TextLabel(Label):
 class RelationLabel(Label):
     from_id: int
     to_id: int
-    type: str
+    type: NonEmptyStr
 
     def __lt__(self, other):
         return self.from_id < other.from_id
@@ -144,7 +141,7 @@ class RelationLabel(Label):
             uuid=self.uuid,
             user=user,
             example=example,
-            type=types.get_by_text(self.type),
+            type=types[self.type],
             from_id=kwargs["id_to_span"][self.from_id],
             to_id=kwargs["id_to_span"][self.to_id],
         )
diff --git a/backend/data_import/pipeline/label_types.py b/backend/data_import/pipeline/label_types.py
index a83f1ef5..037df5ba 100644
--- a/backend/data_import/pipeline/label_types.py
+++ b/backend/data_import/pipeline/label_types.py
@@ -12,12 +12,12 @@ class LabelTypes:
     def __contains__(self, text: str) -> bool:
         return text in self.types
 
+    def __getitem__(self, text: str) -> LabelType:
+        return self.types[text]
+
     def save(self, label_types: List[LabelType]):
         self.label_type_class.objects.bulk_create(label_types, ignore_conflicts=True)
 
     def update(self, project: Project):
         types = self.label_type_class.objects.filter(project=project)
         self.types = {label_type.text: label_type for label_type in types}
-
-    def get_by_text(self, text: str) -> LabelType:
-        return self.types[text]
diff --git a/backend/data_import/tests/test_label.py b/backend/data_import/tests/test_label.py
index 010a5d48..da79ca49 100644
--- a/backend/data_import/tests/test_label.py
+++ b/backend/data_import/tests/test_label.py
@@ -55,7 +55,7 @@ class TestCategoryLabel(TestLabel):
     def test_create(self):
         category = CategoryLabel(label="A", example_uuid=uuid.uuid4())
         types = MagicMock()
-        types.get_by_text.return_value = mommy.make(CategoryType, project=self.project.item)
+        types.__getitem__.return_value = mommy.make(CategoryType, project=self.project.item)
         category_model = category.create(self.user, self.example, types)
         self.assertIsInstance(category_model, CategoryModel)
 
@@ -65,7 +65,7 @@ class TestSpanLabel(TestLabel):
     def test_comparison(self):
         span1 = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
-        span2 = SpanLabel(label="A", start_offset=1, end_offset=1, example_uuid=uuid.uuid4())
+        span2 = SpanLabel(label="A", start_offset=1, end_offset=2, example_uuid=uuid.uuid4())
         self.assertLess(span1, span2)
 
     def test_parse_tuple(self):
@@ -82,6 +82,14 @@ class TestSpanLabel(TestLabel):
         self.assertEqual(span.start_offset, 0)
         self.assertEqual(span.end_offset, 1)
 
+    def test_invalid_negative_offset(self):
+        with self.assertRaises(ValueError):
+            SpanLabel(label="A", start_offset=-1, end_offset=1, example_uuid=uuid.uuid4())
+
+    def test_invalid_offset(self):
+        with self.assertRaises(ValueError):
+            SpanLabel(label="A", start_offset=1, end_offset=0, example_uuid=uuid.uuid4())
+
     def test_parse_invalid_dict(self):
         example_uuid = uuid.uuid4()
         with self.assertRaises(ValueError):
@@ -96,7 +104,7 @@ class TestSpanLabel(TestLabel):
     def test_create(self):
         span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
         types = MagicMock()
-        types.get_by_text.return_value = mommy.make(SpanType, project=self.project.item)
+        types.__getitem__.return_value = mommy.make(SpanType, project=self.project.item)
         span_model = span.create(self.user, self.example, types)
         self.assertIsInstance(span_model, SpanModel)
 
@@ -160,7 +168,7 @@ class TestRelationLabel(TestLabel):
     def test_create(self):
         relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
         types = MagicMock()
-        types.get_by_text.return_value = mommy.make(RelationType, project=self.project.item)
+        types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item)
         id_to_span = {
             0: mommy.make(SpanModel, start_offset=0, end_offset=1),
             1: mommy.make(SpanModel, start_offset=2, end_offset=3),
diff --git a/backend/data_import/tests/test_label_types.py b/backend/data_import/tests/test_label_types.py
index 099f08c6..d1195fdf 100644
--- a/backend/data_import/tests/test_label_types.py
+++ b/backend/data_import/tests/test_label_types.py
@@ -22,10 +22,8 @@ class TestCategoryLabel(TestCase):
 
     def test_update(self):
         label_types = LabelTypes(CategoryType)
-        with self.assertRaises(KeyError):
-            label_types.get_by_text("A")
         category_types = [CategoryType(text="A", project=self.project.item)]
         label_types.save(category_types)
         label_types.update(self.project.item)
-        category_type = label_types.get_by_text("A")
+        category_type = label_types["A"]
         self.assertEqual(category_type.text, "A")
diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py
index 888a0bd8..d6696916 100644
--- a/backend/data_import/tests/test_tasks.py
+++ b/backend/data_import/tests/test_tasks.py
@@ -68,6 +68,16 @@ class TestMaxFileSize(TestImportData):
         self.assertIn("maximum file size", response["error"][0]["message"])
 
 
+class TestInvalidFileFormat(TestImportData):
+    task = DOCUMENT_CLASSIFICATION
+
+    def test_invalid_file_format(self):
+        filename = "text_classification/example.csv"
+        file_format = "INVALID_FORMAT"
+        response = self.import_dataset(filename, file_format, self.task)
+        self.assertEqual(len(response["error"]), 1)
+
+
 class TestImportClassificationData(TestImportData):
     task = DOCUMENT_CLASSIFICATION
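
For review context, a minimal usage sketch of the Format API this patch introduces. It is not part of the patch; the import paths assume the backend package root is on the path, and the format name and MIME value are illustrative only.

# Sketch: how a Format object is expected to flow through the import path.
from data_import.pipeline.catalog import create_file_format
from data_import.pipeline.exceptions import FileFormatException

try:
    fmt = create_file_format("ImageFile")   # looks up the Format subclass by its name
    print(fmt.validate_mime("image/png"))   # True: ImageFile/AudioFile restrict MIME types
    print(fmt.is_plain_text())              # False: only TextFile/TextLine report plain text
except FileImportException as e:
    # An unknown name raises FileFormatException, which import_dataset now
    # catches and reports as {"error": [{"message": "Unknown file format: ..."}]}.
    print(e.dict())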