Browse Source

Replace file_format string with Format class

pull/1832/head
Hironsan 2 years ago
parent
commit
f3b40ebcb3
6 changed files with 75 additions and 31 deletions
  1. 43
      backend/data_import/celery_tasks.py
  2. 8
      backend/data_import/datasets.py
  3. 29
      backend/data_import/pipeline/catalog.py
  4. 9
      backend/data_import/pipeline/exceptions.py
  5. 7
      backend/data_import/pipeline/factories.py
  6. 10
      backend/data_import/tests/test_tasks.py

43
backend/data_import/celery_tasks.py

@ -9,7 +9,7 @@ from django_drf_filepond.api import store_upload
from django_drf_filepond.models import TemporaryUpload
from .datasets import load_dataset
from .pipeline.catalog import AudioFile, ImageFile
from .pipeline.catalog import Format, create_file_format
from .pipeline.exceptions import (
FileImportException,
FileTypeException,
@ -19,21 +19,15 @@ from .pipeline.readers import FileName
from projects.models import Project
def check_file_type(filename, file_format: str, filepath: str):
def check_file_type(filename, file_format: Format, filepath: str):
if not settings.ENABLE_FILE_TYPE_CHECK:
return
kind = filetype.guess(filepath)
if file_format == ImageFile.name:
accept_types = ImageFile.accept_types.replace(" ", "").split(",")
elif file_format == AudioFile.name:
accept_types = AudioFile.accept_types.replace(" ", "").split(",")
else:
return
if kind.mime not in accept_types:
raise FileTypeException(filename, kind.mime, accept_types)
if not file_format.validate_mime(kind.mime):
raise FileTypeException(filename, kind.mime, file_format.accept_types)
def check_uploaded_files(upload_ids: List[str], file_format: str):
def check_uploaded_files(upload_ids: List[str], file_format: Format):
errors: List[FileImportException] = []
cleaned_ids = []
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
@ -56,19 +50,22 @@ def check_uploaded_files(upload_ids: List[str], file_format: str):
def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], task: str, **kwargs):
project = get_object_or_404(Project, pk=project_id)
user = get_object_or_404(get_user_model(), pk=user_id)
try:
fmt = create_file_format(file_format)
upload_ids, errors = check_uploaded_files(upload_ids, fmt)
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
filenames = [
FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
for tu in temporary_uploads
]
upload_ids, errors = check_uploaded_files(upload_ids, file_format)
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
filenames = [
FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
for tu in temporary_uploads
]
dataset = load_dataset(task, file_format, filenames, project, **kwargs)
dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
upload_to_store(temporary_uploads)
errors.extend(dataset.errors)
return {"error": [e.dict() for e in errors]}
dataset = load_dataset(task, fmt, filenames, project, **kwargs)
dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
upload_to_store(temporary_uploads)
errors.extend(dataset.errors)
return {"error": [e.dict() for e in errors]}
except FileImportException as e:
return {"error": [e.dict()]}
def upload_to_store(temporary_uploads):

8
backend/data_import/datasets.py

@ -4,7 +4,7 @@ from typing import List, Type
from django.contrib.auth.models import User
from .models import DummyLabelType
from .pipeline.catalog import RELATION_EXTRACTION, TextFile, TextLine
from .pipeline.catalog import RELATION_EXTRACTION, Format
from .pipeline.data import BaseData, BinaryData, TextData
from .pipeline.exceptions import FileParseException
from .pipeline.factories import create_parser
@ -210,7 +210,7 @@ class CategoryAndSpanDataset(Dataset):
return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors
def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]:
def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
mapping = {
DOCUMENT_CLASSIFICATION: TextClassificationDataset,
SEQUENCE_LABELING: SequenceLabelingDataset,
@ -222,12 +222,12 @@ def select_dataset(project: Project, task: str, file_format: str) -> Type[Datase
}
if task not in mapping:
task = project.project_type
if project.is_text_project and file_format in [TextLine.name, TextFile.name]:
if project.is_text_project and file_format.is_plain_text():
return PlainDataset
return mapping[task]
def load_dataset(task: str, file_format: str, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
parser = create_parser(file_format, **kwargs)
reader = Reader(data_files, parser)
dataset_class = select_dataset(project, task, file_format)

29
backend/data_import/pipeline/catalog.py

@ -6,6 +6,7 @@ from typing import Dict, List, Type
from pydantic import BaseModel
from typing_extensions import Literal
from .exceptions import FileFormatException
from projects.models import (
DOCUMENT_CLASSIFICATION,
IMAGE_CLASSIFICATION,
@ -140,6 +141,13 @@ class Format:
def dict(cls):
return {"name": cls.name, "accept_types": cls.accept_types}
def validate_mime(self, mime: str):
return True
@staticmethod
def is_plain_text():
return False
class CSV(Format):
name = "CSV"
@ -170,11 +178,19 @@ class TextFile(Format):
name = "TextFile"
accept_types = "text/*"
@staticmethod
def is_plain_text():
return True
class TextLine(Format):
name = "TextLine"
accept_types = "text/*"
@staticmethod
def is_plain_text():
return True
class CoNLL(Format):
name = "CoNLL"
@ -185,11 +201,17 @@ class ImageFile(Format):
name = "ImageFile"
accept_types = "image/png, image/jpeg, image/bmp, image/gif"
def validate_mime(self, mime: str):
return mime in self.accept_types
class AudioFile(Format):
name = "AudioFile"
accept_types = "audio/ogg, audio/aac, audio/mpeg, audio/wav"
def validate_mime(self, mime: str):
return mime in self.accept_types
class ArgColumn(BaseModel):
encoding: encodings = "utf_8"
@ -239,6 +261,13 @@ class Option:
}
def create_file_format(file_format: str) -> Format:
for format_class in Format.__subclasses__():
if format_class.name == file_format:
return format_class()
raise FileFormatException(file_format)
class Options:
options: Dict[str, List] = defaultdict(list)

9
backend/data_import/pipeline/exceptions.py

@ -42,3 +42,12 @@ class FileTypeException(FileImportException):
def dict(self):
return {"filename": self.filename, "line": -1, "message": str(self)}
class FileFormatException(FileImportException):
def __init__(self, file_format: str):
self.file_format = file_format
def dict(self):
message = f"Unknown file format: {self.file_format}"
return {"message": message}

7
backend/data_import/pipeline/factories.py

@ -6,6 +6,7 @@ from .catalog import (
CoNLL,
Excel,
FastText,
Format,
ImageFile,
TextFile,
TextLine,
@ -23,7 +24,7 @@ from .parsers import (
)
def create_parser(file_format: str, **kwargs):
def create_parser(file_format: Format, **kwargs):
mapping = {
TextFile.name: TextFileParser,
TextLine.name: LineParser,
@ -36,6 +37,4 @@ def create_parser(file_format: str, **kwargs):
ImageFile.name: PlainParser,
AudioFile.name: PlainParser,
}
if file_format not in mapping:
raise ValueError(f"Invalid format: {file_format}")
return mapping[file_format](**kwargs)
return mapping[file_format.name](**kwargs)

10
backend/data_import/tests/test_tasks.py

@ -68,6 +68,16 @@ class TestMaxFileSize(TestImportData):
self.assertIn("maximum file size", response["error"][0]["message"])
class TestInvalidFileFormat(TestImportData):
task = DOCUMENT_CLASSIFICATION
def test_invalid_file_format(self):
filename = "text_classification/example.csv"
file_format = "INVALID_FORMAT"
response = self.import_dataset(filename, file_format, self.task)
self.assertEqual(len(response["error"]), 1)
class TestImportClassificationData(TestImportData):
task = DOCUMENT_CLASSIFICATION

Loading…
Cancel
Save