Browse Source

Add Examples class

pull/1844/head
Hironsan 2 years ago
parent
commit
22f897d2b7
3 changed files with 54 additions and 11 deletions
  1. 22
      backend/data_import/datasets.py
  2. 18
      backend/data_import/pipeline/examples.py
  3. 25
      backend/data_import/tests/test_examples.py

22
backend/data_import/datasets.py

@ -6,6 +6,7 @@ from django.contrib.auth.models import User
from .models import DummyLabelType
from .pipeline.catalog import RELATION_EXTRACTION, Format
from .pipeline.data import BaseData, BinaryData, TextData
from .pipeline.examples import Examples
from .pipeline.exceptions import FileParseException
from .pipeline.factories import create_parser
from .pipeline.label import CategoryLabel, Label, RelationLabel, SpanLabel, TextLabel
@ -18,7 +19,6 @@ from .pipeline.readers import (
FileName,
Reader,
)
from examples.models import Example
from label_types.models import CategoryType, LabelType, RelationType, SpanType
from projects.models import (
DOCUMENT_CLASSIFICATION,
@ -52,8 +52,8 @@ class PlainDataset(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
@property
def errors(self) -> List[FileParseException]:
@ -82,8 +82,8 @@ class DatasetWithSingleLabelType(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
# create examples
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
# create label types
labels = self.labels_class(self.label_maker.make(records), self.types)
@ -105,8 +105,8 @@ class BinaryDataset(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
@property
def errors(self) -> List[FileParseException]:
@ -151,8 +151,8 @@ class RelationExtractionDataset(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
# create examples
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
# create label types
spans = Spans(self.span_maker.make(records), self.span_types)
@ -189,8 +189,8 @@ class CategoryAndSpanDataset(Dataset):
def save(self, user: User, batch_size: int = 1000):
for records in self.reader.batch(batch_size):
# create examples
examples = self.example_maker.make(records)
Example.objects.bulk_create(examples)
examples = Examples(self.example_maker.make(records))
examples.save()
# create label types
categories = Categories(self.category_maker.make(records), self.category_types)

18
backend/data_import/pipeline/examples.py

@ -0,0 +1,18 @@
from typing import Dict, List
from pydantic import UUID4
from examples.models import Example
class Examples:
def __init__(self, examples: List[Example]):
self.examples = examples
self.uuid_to_example: Dict[UUID4, Example] = {}
def __getitem__(self, uuid: UUID4):
return self.uuid_to_example[uuid]
def save(self):
examples = Example.objects.bulk_create(self.examples)
self.uuid_to_example = {example.uuid: example for example in examples}

25
backend/data_import/tests/test_examples.py

@ -0,0 +1,25 @@
import uuid
from django.test import TestCase
from data_import.pipeline.examples import Examples
from examples.models import Example
from projects.models import DOCUMENT_CLASSIFICATION
from projects.tests.utils import prepare_project
class TestCategories(TestCase):
def setUp(self):
self.project = prepare_project(DOCUMENT_CLASSIFICATION)
self.example_uuid = uuid.uuid4()
self.example = Example(uuid=self.example_uuid, text="A", project=self.project.item)
self.examples = Examples([self.example])
def test_save(self):
self.examples.save()
self.assertEqual(Example.objects.count(), 1)
def test_getitem(self):
self.examples.save()
example = self.examples[self.example_uuid]
self.assertEqual(example, self.example)
Loading…
Cancel
Save