
Merge pull request #1832 from doccano/refactoring/dataImport

[Refactoring] data import
Hiroki Nakayama, committed by GitHub, 2 years ago
commit 249926b0e6
10 changed files with 112 additions and 65 deletions

  1. backend/data_import/celery_tasks.py (43)
  2. backend/data_import/datasets.py (8)
  3. backend/data_import/pipeline/catalog.py (29)
  4. backend/data_import/pipeline/exceptions.py (9)
  5. backend/data_import/pipeline/factories.py (7)
  6. backend/data_import/pipeline/label.py (45)
  7. backend/data_import/pipeline/label_types.py (6)
  8. backend/data_import/tests/test_label.py (16)
  9. backend/data_import/tests/test_label_types.py (4)
  10. backend/data_import/tests/test_tasks.py (10)

backend/data_import/celery_tasks.py (43 changed lines)

@@ -9,7 +9,7 @@ from django_drf_filepond.api import store_upload
 from django_drf_filepond.models import TemporaryUpload
 from .datasets import load_dataset
-from .pipeline.catalog import AudioFile, ImageFile
+from .pipeline.catalog import Format, create_file_format
 from .pipeline.exceptions import (
     FileImportException,
     FileTypeException,
@@ -19,21 +19,15 @@ from .pipeline.readers import FileName
 from projects.models import Project


-def check_file_type(filename, file_format: str, filepath: str):
+def check_file_type(filename, file_format: Format, filepath: str):
     if not settings.ENABLE_FILE_TYPE_CHECK:
         return
     kind = filetype.guess(filepath)
-    if file_format == ImageFile.name:
-        accept_types = ImageFile.accept_types.replace(" ", "").split(",")
-    elif file_format == AudioFile.name:
-        accept_types = AudioFile.accept_types.replace(" ", "").split(",")
-    else:
-        return
-    if kind.mime not in accept_types:
-        raise FileTypeException(filename, kind.mime, accept_types)
+    if not file_format.validate_mime(kind.mime):
+        raise FileTypeException(filename, kind.mime, file_format.accept_types)


-def check_uploaded_files(upload_ids: List[str], file_format: str):
+def check_uploaded_files(upload_ids: List[str], file_format: Format):
     errors: List[FileImportException] = []
     cleaned_ids = []
     temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
@@ -56,19 +50,22 @@ def check_uploaded_files(upload_ids: List[str], file_format: str):
 def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], task: str, **kwargs):
     project = get_object_or_404(Project, pk=project_id)
     user = get_object_or_404(get_user_model(), pk=user_id)
-    upload_ids, errors = check_uploaded_files(upload_ids, file_format)
-    temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
-    filenames = [
-        FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
-        for tu in temporary_uploads
-    ]
-    dataset = load_dataset(task, file_format, filenames, project, **kwargs)
-    dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
-    upload_to_store(temporary_uploads)
-    errors.extend(dataset.errors)
-    return {"error": [e.dict() for e in errors]}
+    try:
+        fmt = create_file_format(file_format)
+        upload_ids, errors = check_uploaded_files(upload_ids, fmt)
+        temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
+        filenames = [
+            FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
+            for tu in temporary_uploads
+        ]
+        dataset = load_dataset(task, fmt, filenames, project, **kwargs)
+        dataset.save(user, batch_size=settings.IMPORT_BATCH_SIZE)
+        upload_to_store(temporary_uploads)
+        errors.extend(dataset.errors)
+        return {"error": [e.dict() for e in errors]}
+    except FileImportException as e:
+        return {"error": [e.dict()]}


 def upload_to_store(temporary_uploads):
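
The refactored task resolves the format name into a Format object once and routes every FileImportException (unknown format, MIME mismatch, parse error) into the returned error payload. A minimal sketch of that path, assuming it runs inside the backend's Django environment; the helper name probe_upload is hypothetical, while check_file_type and create_file_format are the functions changed above:

from data_import.celery_tasks import check_file_type
from data_import.pipeline.catalog import create_file_format
from data_import.pipeline.exceptions import FileImportException


def probe_upload(filename: str, file_format: str, filepath: str) -> dict:
    # Hypothetical helper mirroring the try/except added above.
    try:
        fmt = create_file_format(file_format)      # raises FileFormatException for an unknown name
        check_file_type(filename, fmt, filepath)   # raises FileTypeException on a MIME mismatch
        return {"error": []}
    except FileImportException as e:
        return {"error": [e.dict()]}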

backend/data_import/datasets.py (8 changed lines)

@@ -4,7 +4,7 @@ from typing import List, Type
 from django.contrib.auth.models import User

 from .models import DummyLabelType
-from .pipeline.catalog import RELATION_EXTRACTION, TextFile, TextLine
+from .pipeline.catalog import RELATION_EXTRACTION, Format
 from .pipeline.data import BaseData, BinaryData, TextData
 from .pipeline.exceptions import FileParseException
 from .pipeline.factories import create_parser
@@ -210,7 +210,7 @@ class CategoryAndSpanDataset(Dataset):
         return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors


-def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]:
+def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
     mapping = {
         DOCUMENT_CLASSIFICATION: TextClassificationDataset,
         SEQUENCE_LABELING: SequenceLabelingDataset,
@@ -222,12 +222,12 @@ def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]:
     }
     if task not in mapping:
         task = project.project_type
-    if project.is_text_project and file_format in [TextLine.name, TextFile.name]:
+    if project.is_text_project and file_format.is_plain_text():
         return PlainDataset
     return mapping[task]


-def load_dataset(task: str, file_format: str, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
+def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
     parser = create_parser(file_format, **kwargs)
     reader = Reader(data_files, parser)
     dataset_class = select_dataset(project, task, file_format)
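
select_dataset now asks the format itself whether it is plain text instead of comparing names. A small sketch of the dispatch; the helper and its Project argument are placeholders for illustration:

from data_import.datasets import PlainDataset, select_dataset
from data_import.pipeline.catalog import create_file_format
from projects.models import Project


def pick_dataset_class(project: Project):
    # Illustration only: for a text project, any plain-text format maps to
    # PlainDataset, regardless of the requested task.
    fmt = create_file_format("TextLine")
    assert fmt.is_plain_text()
    return select_dataset(project, task="Unknown", file_format=fmt)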

backend/data_import/pipeline/catalog.py (29 changed lines)

@@ -6,6 +6,7 @@ from typing import Dict, List, Type
 from pydantic import BaseModel
 from typing_extensions import Literal

+from .exceptions import FileFormatException
 from projects.models import (
     DOCUMENT_CLASSIFICATION,
     IMAGE_CLASSIFICATION,
@@ -140,6 +141,13 @@ class Format:
     def dict(cls):
         return {"name": cls.name, "accept_types": cls.accept_types}

+    def validate_mime(self, mime: str):
+        return True
+
+    @staticmethod
+    def is_plain_text():
+        return False
+

 class CSV(Format):
     name = "CSV"
@@ -170,11 +178,19 @@ class TextFile(Format):
     name = "TextFile"
     accept_types = "text/*"

+    @staticmethod
+    def is_plain_text():
+        return True
+

 class TextLine(Format):
     name = "TextLine"
     accept_types = "text/*"

+    @staticmethod
+    def is_plain_text():
+        return True
+

 class CoNLL(Format):
     name = "CoNLL"
@@ -185,11 +201,17 @@ class ImageFile(Format):
     name = "ImageFile"
     accept_types = "image/png, image/jpeg, image/bmp, image/gif"

+    def validate_mime(self, mime: str):
+        return mime in self.accept_types
+

 class AudioFile(Format):
     name = "AudioFile"
     accept_types = "audio/ogg, audio/aac, audio/mpeg, audio/wav"

+    def validate_mime(self, mime: str):
+        return mime in self.accept_types
+

 class ArgColumn(BaseModel):
     encoding: encodings = "utf_8"
@@ -239,6 +261,13 @@ class Option:
     }


+def create_file_format(file_format: str) -> Format:
+    for format_class in Format.__subclasses__():
+        if format_class.name == file_format:
+            return format_class()
+    raise FileFormatException(file_format)
+
+
 class Options:
     options: Dict[str, List] = defaultdict(list)
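
Together these hooks let callers interrogate a format instead of branching on its name. A short usage sketch of the API added in this file:

from data_import.pipeline.catalog import create_file_format
from data_import.pipeline.exceptions import FileFormatException

fmt = create_file_format("ImageFile")
print(fmt.validate_mime("image/png"))   # True: listed in ImageFile.accept_types
print(fmt.validate_mime("text/csv"))    # False: check_file_type would raise FileTypeException
print(create_file_format("CSV").validate_mime("text/csv"))  # True: the base class accepts any MIME type

try:
    create_file_format("UNKNOWN")
except FileFormatException as e:
    print(e.dict())   # {'message': 'Unknown file format: UNKNOWN'}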

backend/data_import/pipeline/exceptions.py (9 changed lines)

@@ -42,3 +42,12 @@ class FileTypeException(FileImportException):

     def dict(self):
         return {"filename": self.filename, "line": -1, "message": str(self)}
+
+
+class FileFormatException(FileImportException):
+    def __init__(self, file_format: str):
+        self.file_format = file_format
+
+    def dict(self):
+        message = f"Unknown file format: {self.file_format}"
+        return {"message": message}

backend/data_import/pipeline/factories.py (7 changed lines)

@@ -6,6 +6,7 @@ from .catalog import (
     CoNLL,
     Excel,
     FastText,
+    Format,
     ImageFile,
     TextFile,
     TextLine,
@@ -23,7 +24,7 @@ from .parsers import (
 )


-def create_parser(file_format: str, **kwargs):
+def create_parser(file_format: Format, **kwargs):
     mapping = {
         TextFile.name: TextFileParser,
         TextLine.name: LineParser,
@@ -36,6 +37,4 @@ def create_parser(file_format: str, **kwargs):
         ImageFile.name: PlainParser,
         AudioFile.name: PlainParser,
     }
-    if file_format not in mapping:
-        raise ValueError(f"Invalid format: {file_format}")
-    return mapping[file_format](**kwargs)
+    return mapping[file_format.name](**kwargs)
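
With a Format instance in hand, the parser lookup keys on its class-level name; unknown formats have already been rejected by create_file_format. A sketch, assuming the selected parser's keyword arguments (such as encoding) are all optional:

from data_import.pipeline.catalog import create_file_format
from data_import.pipeline.factories import create_parser

fmt = create_file_format("TextLine")
parser = create_parser(fmt)   # mapping[fmt.name] selects LineParser here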

backend/data_import/pipeline/label.py (45 changed lines)

@@ -2,7 +2,7 @@ import abc
 import uuid
 from typing import Any, Optional

-from pydantic import UUID4, BaseModel, validator
+from pydantic import UUID4, BaseModel, ConstrainedStr, NonNegativeInt, root_validator

 from .label_types import LabelTypes
 from examples.models import Example
@@ -15,6 +15,10 @@ from labels.models import TextLabel as TextLabelModel
 from projects.models import Project


+class NonEmptyStr(ConstrainedStr):
+    min_length = 1
+
+
 class Label(BaseModel, abc.ABC):
     id: int = -1
     uuid: UUID4
@@ -45,18 +49,11 @@ class Label(BaseModel, abc.ABC):

 class CategoryLabel(Label):
-    label: str
+    label: NonEmptyStr

     def __lt__(self, other):
         return self.label < other.label

-    @validator("label")
-    def label_is_not_empty(cls, value: str):
-        if value:
-            return value
-        else:
-            raise ValueError("is not empty.")
-
     @classmethod
     def parse(cls, example_uuid: UUID4, obj: Any):
         return cls(example_uuid=example_uuid, label=obj)
@@ -65,17 +62,24 @@ class CategoryLabel(Label):
         return CategoryType(text=self.label, project=project)

     def create(self, user, example: Example, types: LabelTypes, **kwargs):
-        return CategoryModel(uuid=self.uuid, user=user, example=example, label=types.get_by_text(self.label))
+        return CategoryModel(uuid=self.uuid, user=user, example=example, label=types[self.label])


 class SpanLabel(Label):
-    label: str
-    start_offset: int
-    end_offset: int
+    label: NonEmptyStr
+    start_offset: NonNegativeInt
+    end_offset: NonNegativeInt

     def __lt__(self, other):
         return self.start_offset < other.start_offset

+    @root_validator
+    def check_start_offset_is_less_than_end_offset(cls, values):
+        start_offset, end_offset = values.get("start_offset"), values.get("end_offset")
+        if start_offset >= end_offset:
+            raise ValueError("start_offset must be less than end_offset.")
+        return values
+
     @classmethod
     def parse(cls, example_uuid: UUID4, obj: Any):
         if isinstance(obj, list) or isinstance(obj, tuple):
@@ -96,23 +100,16 @@ class SpanLabel(Label):
             example=example,
             start_offset=self.start_offset,
             end_offset=self.end_offset,
-            label=types.get_by_text(self.label),
+            label=types[self.label],
         )


 class TextLabel(Label):
-    text: str
+    text: NonEmptyStr

     def __lt__(self, other):
         return self.text < other.text

-    @validator("text")
-    def text_is_not_empty(cls, value: str):
-        if value:
-            return value
-        else:
-            raise ValueError("is not empty.")
-
     @classmethod
     def parse(cls, example_uuid: UUID4, obj: Any):
         return cls(example_uuid=example_uuid, text=obj)
@@ -127,7 +124,7 @@ class TextLabel(Label):
 class RelationLabel(Label):
     from_id: int
     to_id: int
-    type: str
+    type: NonEmptyStr

     def __lt__(self, other):
         return self.from_id < other.from_id
@@ -144,7 +141,7 @@ class RelationLabel(Label):
             uuid=self.uuid,
             user=user,
             example=example,
-            type=types.get_by_text(self.type),
+            type=types[self.type],
             from_id=kwargs["id_to_span"][self.from_id],
             to_id=kwargs["id_to_span"][self.to_id],
         )
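
Declarative field types and a root validator replace the hand-rolled emptiness checks, so invalid labels are rejected when the pydantic model is constructed. A sketch of what is now accepted and rejected (pydantic's ValidationError subclasses ValueError, which is what the updated tests assert):

import uuid

from data_import.pipeline.label import CategoryLabel, SpanLabel

example_uuid = uuid.uuid4()

CategoryLabel(label="A", example_uuid=example_uuid)                             # accepted
SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=example_uuid)   # accepted

for kwargs in (
    {"label": "", "start_offset": 0, "end_offset": 1},    # NonEmptyStr rejects the empty label
    {"label": "A", "start_offset": -1, "end_offset": 1},  # NonNegativeInt rejects the negative offset
    {"label": "A", "start_offset": 1, "end_offset": 0},   # root validator rejects start >= end
):
    try:
        SpanLabel(example_uuid=example_uuid, **kwargs)
    except ValueError:
        pass  # each invalid combination raises a ValidationError (a ValueError subclass)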

backend/data_import/pipeline/label_types.py (6 changed lines)

@@ -12,12 +12,12 @@ class LabelTypes:
     def __contains__(self, text: str) -> bool:
         return text in self.types

+    def __getitem__(self, text: str) -> LabelType:
+        return self.types[text]
+
     def save(self, label_types: List[LabelType]):
         self.label_type_class.objects.bulk_create(label_types, ignore_conflicts=True)

     def update(self, project: Project):
         types = self.label_type_class.objects.filter(project=project)
         self.types = {label_type.text: label_type for label_type in types}
-
-    def get_by_text(self, text: str) -> LabelType:
-        return self.types[text]
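
Dict-style access replaces get_by_text and pairs with the existing __contains__. A usage sketch; the helper and the CategoryType import path are assumptions:

from data_import.pipeline.label_types import LabelTypes
from label_types.models import CategoryType
from projects.models import Project


def lookup_category_type(project: Project, text: str):
    # Hypothetical helper: dict-style access replaces get_by_text().
    types = LabelTypes(CategoryType)
    types.update(project)      # loads {text: label_type} for this project
    if text in types:          # __contains__
        return types[text]     # __getitem__
    return None                # an unknown text would otherwise raise KeyError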

backend/data_import/tests/test_label.py (16 changed lines)

@@ -55,7 +55,7 @@ class TestCategoryLabel(TestLabel):
     def test_create(self):
         category = CategoryLabel(label="A", example_uuid=uuid.uuid4())
         types = MagicMock()
-        types.get_by_text.return_value = mommy.make(CategoryType, project=self.project.item)
+        types.__getitem__.return_value = mommy.make(CategoryType, project=self.project.item)
         category_model = category.create(self.user, self.example, types)
         self.assertIsInstance(category_model, CategoryModel)
@@ -65,7 +65,7 @@ class TestSpanLabel(TestLabel):
     def test_comparison(self):
         span1 = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
-        span2 = SpanLabel(label="A", start_offset=1, end_offset=1, example_uuid=uuid.uuid4())
+        span2 = SpanLabel(label="A", start_offset=1, end_offset=2, example_uuid=uuid.uuid4())
         self.assertLess(span1, span2)

     def test_parse_tuple(self):
@@ -82,6 +82,14 @@ class TestSpanLabel(TestLabel):
         self.assertEqual(span.start_offset, 0)
         self.assertEqual(span.end_offset, 1)

+    def test_invalid_negative_offset(self):
+        with self.assertRaises(ValueError):
+            SpanLabel(label="A", start_offset=-1, end_offset=1, example_uuid=uuid.uuid4())
+
+    def test_invalid_offset(self):
+        with self.assertRaises(ValueError):
+            SpanLabel(label="A", start_offset=1, end_offset=0, example_uuid=uuid.uuid4())
+
     def test_parse_invalid_dict(self):
         example_uuid = uuid.uuid4()
         with self.assertRaises(ValueError):
@@ -96,7 +104,7 @@ class TestSpanLabel(TestLabel):
     def test_create(self):
         span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
         types = MagicMock()
-        types.get_by_text.return_value = mommy.make(SpanType, project=self.project.item)
+        types.__getitem__.return_value = mommy.make(SpanType, project=self.project.item)
         span_model = span.create(self.user, self.example, types)
         self.assertIsInstance(span_model, SpanModel)
@@ -160,7 +168,7 @@ class TestRelationLabel(TestLabel):
     def test_create(self):
         relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
         types = MagicMock()
-        types.get_by_text.return_value = mommy.make(RelationType, project=self.project.item)
+        types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item)
         id_to_span = {
             0: mommy.make(SpanModel, start_offset=0, end_offset=1),
             1: mommy.make(SpanModel, start_offset=2, end_offset=3),

backend/data_import/tests/test_label_types.py (4 changed lines)

@@ -22,10 +22,8 @@ class TestCategoryLabel(TestCase):
     def test_update(self):
         label_types = LabelTypes(CategoryType)
-        with self.assertRaises(KeyError):
-            label_types.get_by_text("A")
         category_types = [CategoryType(text="A", project=self.project.item)]
         label_types.save(category_types)
         label_types.update(self.project.item)
-        category_type = label_types.get_by_text("A")
+        category_type = label_types["A"]
         self.assertEqual(category_type.text, "A")

backend/data_import/tests/test_tasks.py (10 changed lines)

@@ -68,6 +68,16 @@ class TestMaxFileSize(TestImportData):
         self.assertIn("maximum file size", response["error"][0]["message"])


+class TestInvalidFileFormat(TestImportData):
+    task = DOCUMENT_CLASSIFICATION
+
+    def test_invalid_file_format(self):
+        filename = "text_classification/example.csv"
+        file_format = "INVALID_FORMAT"
+        response = self.import_dataset(filename, file_format, self.task)
+        self.assertEqual(len(response["error"]), 1)
+
+
 class TestImportClassificationData(TestImportData):
     task = DOCUMENT_CLASSIFICATION
