Browse Source

Merge pull request #1845 from doccano/enhancement/refactorDataImport

[Enhancement] refactor data import
release-1.8.0
Hiroki Nakayama 2 years ago
committed by GitHub
parent
commit
bf9b62f716
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 87 additions and 38 deletions
  1. 32
      backend/data_import/datasets.py
  2. 21
      backend/data_import/pipeline/examples.py
  3. 21
      backend/data_import/pipeline/labels.py
  4. 25
      backend/data_import/tests/test_examples.py
  5. 26
      backend/data_import/tests/test_labels.py

32
backend/data_import/datasets.py

@ -6,6 +6,7 @@ from django.contrib.auth.models import User
from .models import DummyLabelType
from .pipeline.catalog import RELATION_EXTRACTION, Format
from .pipeline.data import BaseData, BinaryData, TextData
from .pipeline.examples import Examples
from .pipeline.exceptions import FileParseException
from .pipeline.factories import create_parser
from .pipeline.label import CategoryLabel, Label, RelationLabel, SpanLabel, TextLabel
@ -18,7 +19,6 @@ from .pipeline.readers import (
FileName,
Reader,
)
from examples.models import Example
from label_types.models import CategoryType, LabelType, RelationType, SpanType
from projects.models import (
DOCUMENT_CLASSIFICATION,
@ -52,8 +52,8 @@ class PlainDataset(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
@property
def errors(self) -> List[FileParseException]:
@ -82,8 +82,8 @@ class DatasetWithSingleLabelType(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
# create examples
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
# create label types
labels = self.labels_class(self.label_maker.make(records), self.types)
@ -91,7 +91,7 @@ class DatasetWithSingleLabelType(Dataset):
labels.save_types(self.project)
# create Labels
labels.save(user)
labels.save(user, examples)
@property
def errors(self) -> List[FileParseException]:
@ -105,8 +105,8 @@ class BinaryDataset(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
@property
def errors(self) -> List[FileParseException]:
@ -151,8 +151,8 @@ class RelationExtractionDataset(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
# create examples
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
# create label types
spans = Spans(self.span_maker.make(records), self.span_types)
@ -164,8 +164,8 @@ class RelationExtractionDataset(Dataset):
relations.save_types(self.project)
# create Labels
spans.save(user)
relations.save(user, spans=spans)
spans.save(user, examples)
relations.save(user, examples, spans=spans)
@property
def errors(self) -> List[FileParseException]:
@ -189,8 +189,8 @@ class CategoryAndSpanDataset(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
# create examples
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
# create label types
categories = Categories(self.category_maker.make(records), self.category_types)
@ -202,8 +202,8 @@ class CategoryAndSpanDataset(Dataset):
spans.save_types(self.project)
# create Labels
categories.save(user)
spans.save(user)
categories.save(user, examples)
spans.save(user, examples)
@property
def errors(self) -> List[FileParseException]:

21
backend/data_import/pipeline/examples.py

@ -0,0 +1,21 @@
from typing import Dict, List
from pydantic import UUID4
from examples.models import Example
class Examples:
    """A batch of ``Example`` objects that can be saved and then looked up by UUID.

    The UUID index starts out empty and is only populated by :meth:`save`,
    so membership tests and subscripting find nothing before the batch has
    been saved.
    """

    def __init__(self, examples: List[Example]):
        """Keep a reference to *examples*; nothing is written at this point."""
        self.uuid_to_example: Dict[UUID4, Example] = {}
        self.examples = examples

    def __contains__(self, uuid: UUID4) -> bool:
        """Return whether *uuid* is a key of the index built by ``save``."""
        return uuid in self.uuid_to_example

    def __getitem__(self, uuid: UUID4) -> Example:
        """Return the indexed example for *uuid*.

        Raises:
            KeyError: if *uuid* is not in the index.
        """
        return self.uuid_to_example[uuid]

    def save(self):
        """Bulk-create the held examples and rebuild the UUID index from the result."""
        created = Example.objects.bulk_create(self.examples)
        # NOTE(review): the index holds exactly the objects bulk_create
        # returns; whether those carry primary keys may depend on the
        # database backend -- confirm for the backends in use.
        index = {}
        for item in created:
            index[item.uuid] = item
        self.uuid_to_example = index

21
backend/data_import/pipeline/labels.py

@ -2,11 +2,9 @@ import abc
from itertools import groupby
from typing import Dict, List
from pydantic import UUID4
from .examples import Examples
from .label import Label
from .label_types import LabelTypes
from examples.models import Example
from labels.models import Category as CategoryModel
from labels.models import Label as LabelModel
from labels.models import Relation as RelationModel
@ -34,18 +32,11 @@ class Labels(abc.ABC):
self.types.save(filtered_types)
self.types.update(project)
@property
def uuid_to_example(self) -> Dict[UUID4, Example]:
example_uuids = {str(label.example_uuid) for label in self.labels}
examples = Example.objects.filter(uuid__in=example_uuids)
return {example.uuid: example for example in examples}
def save(self, user, **kwargs):
uuid_to_example = self.uuid_to_example
def save(self, user, examples: Examples, **kwargs):
labels = [
label.create(user, uuid_to_example[label.example_uuid], self.types, **kwargs)
label.create(user, examples[label.example_uuid], self.types, **kwargs)
for label in self.labels
if label.example_uuid in uuid_to_example
if label.example_uuid in examples
]
self.label_model.objects.bulk_create(labels)
@ -93,6 +84,6 @@ class Texts(Labels):
class Relations(Labels):
label_model = RelationModel
def save(self, user, **kwargs):
def save(self, user, examples: Examples, **kwargs):
id_to_span = kwargs["spans"].id_to_span
super().save(user, id_to_span=id_to_span)
super().save(user, examples, id_to_span=id_to_span)

25
backend/data_import/tests/test_examples.py

@ -0,0 +1,25 @@
import uuid
from django.test import TestCase
from data_import.pipeline.examples import Examples
from examples.models import Example
from projects.models import DOCUMENT_CLASSIFICATION
from projects.tests.utils import prepare_project
class TestExamples(TestCase):
    """Checks that ``Examples`` saves its batch and supports lookup by UUID."""

    def setUp(self):
        """Wrap a single unsaved example with a known UUID in ``Examples``."""
        self.project = prepare_project(DOCUMENT_CLASSIFICATION)
        self.example_uuid = uuid.uuid4()
        unsaved = Example(
            text="A",
            uuid=self.example_uuid,
            project=self.project.item,
        )
        self.examples = Examples([unsaved])

    def test_save(self):
        """Saving the batch leaves exactly one ``Example`` row."""
        self.examples.save()
        row_count = Example.objects.count()
        self.assertEqual(row_count, 1)

    def test_getitem(self):
        """After saving, subscripting by UUID returns the matching example."""
        self.examples.save()
        self.assertEqual(self.examples[self.example_uuid].uuid, self.example_uuid)

26
backend/data_import/tests/test_labels.py

@ -30,7 +30,10 @@ class TestCategories(TestCase):
CategoryLabel(example_uuid=example_uuid, label="A"),
CategoryLabel(example_uuid=example_uuid, label="B"),
]
mommy.make("Example", project=self.project.item, uuid=example_uuid)
example = mommy.make("Example", project=self.project.item, uuid=example_uuid)
self.examples = MagicMock()
self.examples.__getitem__.return_value = example
self.examples.__contains__.return_value = True
self.categories = Categories(labels, self.types)
def test_clean(self):
@ -45,7 +48,7 @@ class TestCategories(TestCase):
def test_save(self):
self.categories.save_types(self.project.item)
self.categories.save(self.user)
self.categories.save(self.user, self.examples)
self.assertEqual(Category.objects.count(), 2)
def test_save_types(self):
@ -64,7 +67,10 @@ class TestSpans(TestCase):
SpanLabel(example_uuid=example_uuid, label="B", start_offset=0, end_offset=3),
SpanLabel(example_uuid=example_uuid, label="B", start_offset=3, end_offset=4),
]
mommy.make("Example", project=self.project.item, uuid=example_uuid)
example = mommy.make("Example", project=self.project.item, uuid=example_uuid)
self.examples = MagicMock()
self.examples.__getitem__.return_value = example
self.examples.__contains__.return_value = True
self.spans = Spans(labels, self.types)
def disable_overlapping(self):
@ -96,7 +102,7 @@ class TestSpans(TestCase):
def test_save(self):
self.spans.save_types(self.project.item)
self.spans.save(self.user)
self.spans.save(self.user, self.examples)
self.assertEqual(Span.objects.count(), 3)
def test_save_types(self):
@ -114,7 +120,10 @@ class TestTexts(TestCase):
TextLabel(example_uuid=example_uuid, text="A"),
TextLabel(example_uuid=example_uuid, text="B"),
]
mommy.make("Example", project=self.project.item, uuid=example_uuid)
example = mommy.make("Example", project=self.project.item, uuid=example_uuid)
self.examples = MagicMock()
self.examples.__getitem__.return_value = example
self.examples.__contains__.return_value = True
self.texts = Texts(labels, self.types)
def test_clean(self):
@ -123,7 +132,7 @@ class TestTexts(TestCase):
def test_save(self):
self.texts.save_types(self.project.item)
self.texts.save(self.user)
self.texts.save(self.user, self.examples)
self.assertEqual(TextLabelModel.objects.count(), 2)
def test_save_types(self):
@ -146,6 +155,9 @@ class TestRelations(TestCase):
self.relations = Relations(labels, self.types)
self.spans = MagicMock()
self.spans.id_to_span = {from_span.id: from_span, to_span.id: to_span}
self.examples = MagicMock()
self.examples.__getitem__.return_value = example
self.examples.__contains__.return_value = True
def test_clean(self):
self.relations.clean(self.project.item)
@ -153,7 +165,7 @@ class TestRelations(TestCase):
def test_save(self):
self.relations.save_types(self.project.item)
self.relations.save(self.user, spans=self.spans)
self.relations.save(self.user, self.examples, spans=self.spans)
self.assertEqual(Relation.objects.count(), 1)
def test_save_types(self):

Loading…
Cancel
Save