Browse Source

Add BulkWriter

pull/1619/head
Hironsan 2 years ago
parent
commit
1e372f2243
2 changed files with 121 additions and 95 deletions
  1. 100
      backend/api/tasks.py
  2. 116
      backend/api/views/upload/writers.py

100
backend/api/tasks.py

@ -1,89 +1,24 @@
import itertools
from typing import List
from celery import shared_task
from celery.utils.log import get_task_logger
from django.conf import settings
from django.contrib.auth import get_user_model
from django.shortcuts import get_object_or_404
from .models import Example, Label, Project
from .models import Project
from .views.download.factory import create_repository, create_writer
from .views.download.service import ExportApplicationService
from .views.upload.exception import FileParseException, FileParseExceptions
from .views.upload.factory import (create_cleaner, get_data_class,
get_dataset_class, get_label_class)
from .views.upload.writers import BulkWriter
logger = get_task_logger(__name__)
def group_by_class(instances):
    """Bucket objects by their class.

    Returns a ``defaultdict(list)`` keyed by each object's ``__class__``.
    Members of a bucket keep the order in which they appeared in
    ``instances``.
    """
    import collections

    buckets = collections.defaultdict(list)
    for obj in instances:
        buckets[obj.__class__].append(obj)
    return buckets
class Examples:
    """In-memory batch of parsed examples waiting to be bulk-written.

    Items are appended to ``buffer``; the batch counts as full once it holds
    at least ``buffer_size`` items (default ``settings.IMPORT_BATCH_SIZE``).
    The ``save_*`` methods turn the buffered items into model instances and
    insert them with ``bulk_create``.
    """

    def __init__(self, buffer_size=settings.IMPORT_BATCH_SIZE):
        self.buffer_size = buffer_size
        self.buffer = []

    def __len__(self):
        """Number of items currently buffered."""
        return len(self.buffer)

    @property
    def data(self):
        """The underlying buffer list itself (not a copy)."""
        return self.buffer

    def add(self, data):
        """Append one item to the batch."""
        self.buffer.append(data)

    def clear(self):
        """Drop all buffered items by rebinding ``buffer`` to a new list."""
        self.buffer = []

    def is_full(self):
        """True once the batch holds ``buffer_size`` items or more."""
        return len(self) >= self.buffer_size

    def is_empty(self):
        """True when nothing is buffered."""
        return not len(self)

    def save_label(self, project: Project):
        """Bulk-create the labels produced by every buffered item.

        Each item's ``create_label(project)`` result is flattened, falsy
        entries are discarded, and the remainder is inserted with
        ``ignore_conflicts=True``.
        """
        per_item = [item.create_label(project) for item in self.buffer]
        labels = [label for label in itertools.chain.from_iterable(per_item) if label]
        Label.objects.bulk_create(labels, ignore_conflicts=True)

    def save_data(self, project: Project) -> List[Example]:
        """Bulk-create one ``Example`` per buffered item and return the result."""
        to_create = [item.create_data(project) for item in self.buffer]
        return Example.objects.bulk_create(to_create)

    def save_annotation(self, project, user, examples):
        """Bulk-create annotations for the buffered items.

        ``examples`` is paired with the buffer by position. Labels of the
        project are looked up through a ``(text, task_type)`` mapping, and
        the resulting annotations are inserted one model class at a time.
        """
        label_by_key = {}
        for label in project.labels.all():
            label_by_key[(label.text, label.task_type)] = label

        per_item = [
            item.create_annotation(user, example, label_by_key)
            for item, example in zip(self.buffer, examples)
        ]
        annotations = list(itertools.chain.from_iterable(per_item))

        for klass, instances in group_by_class(annotations).items():
            klass.objects.bulk_create(instances)
class DataFactory:
    """Runs the three save steps of a batch in the required order."""

    def create(self, examples, user, project):
        """Persist one batch: labels first, then data, then annotations.

        The value returned by ``examples.save_data`` is handed on to
        ``examples.save_annotation`` so annotations can be tied to the rows
        that were just created.
        """
        examples.save_label(project)
        saved = examples.save_data(project)
        examples.save_annotation(project, user, saved)
@shared_task
def ingest_data(user_id, project_id, filenames, format: str, **kwargs):
project = get_object_or_404(Project, pk=project_id)
user = get_object_or_404(get_user_model(), pk=user_id)
response = {'error': []}
# Prepare dataset.
dataset_class = get_dataset_class(format)
dataset = dataset_class(
filenames=filenames,
@ -92,35 +27,10 @@ def ingest_data(user_id, project_id, filenames, format: str, **kwargs):
**kwargs
)
it = iter(dataset)
buffer = Examples()
factory = DataFactory()
cleaner = create_cleaner(project)
while True:
try:
example = next(it)
except StopIteration:
break
except FileParseException as err:
response['error'].append(err.dict())
continue
except FileParseExceptions as err:
response['error'].extend(list(err))
continue
try:
example.clean(cleaner)
except FileParseException as err:
response['error'].append(err.dict())
buffer.add(example)
if buffer.is_full():
factory.create(buffer, user, project)
buffer.clear()
if not buffer.is_empty():
logger.debug(f'BUFFER LEN {len(buffer)}')
factory.create(buffer, user, project)
buffer.clear()
return response
writer = BulkWriter(batch_size=settings.IMPORT_BATCH_SIZE)
writer.save(it, project, user, cleaner)
return {'error': writer.errors}
@shared_task

116
backend/api/views/upload/writers.py

@ -0,0 +1,116 @@
import abc
import itertools
from typing import List
from django.conf import settings
from ...models import Example, Label, Project
from .exception import FileParseException, FileParseExceptions
from .readers import BaseReader
class Writer(abc.ABC):
    """Abstract interface for objects that persist what a reader produces."""

    @abc.abstractmethod
    def save(self, reader: BaseReader):
        """Save the contents read by ``reader`` to the DB.

        Subclasses must override this method; the base implementation
        always raises.

        Raises:
            NotImplementedError: always.
        """
        message = 'Please implement this method in the subclass.'
        raise NotImplementedError(message)
class BulkWriterOld(Writer):
    """Writer stub that only remembers a batch size; ``save`` is a no-op."""

    def __init__(self, batch_size: int):
        self.batch_size = batch_size

    def save(self, reader: BaseReader):
        """Bulk save the read contents (currently does nothing and returns None)."""
def group_by_class(instances):
    """Bucket objects by their class.

    Returns a ``defaultdict(list)`` keyed by each object's ``__class__``.
    Members of a bucket keep the order in which they appeared in
    ``instances``.
    """
    import collections

    buckets = collections.defaultdict(list)
    for obj in instances:
        buckets[obj.__class__].append(obj)
    return buckets
class Examples:
    """In-memory batch of parsed examples waiting to be bulk-written.

    Items are appended to ``buffer``; the batch counts as full once it holds
    at least ``buffer_size`` items (default ``settings.IMPORT_BATCH_SIZE``).
    The ``save_*`` methods turn the buffered items into model instances and
    insert them with ``bulk_create``.
    """

    def __init__(self, buffer_size=settings.IMPORT_BATCH_SIZE):
        self.buffer_size = buffer_size
        self.buffer = []

    def __len__(self):
        """Number of items currently buffered."""
        return len(self.buffer)

    @property
    def data(self):
        """The underlying buffer list itself (not a copy)."""
        return self.buffer

    def add(self, data):
        """Append one item to the batch."""
        self.buffer.append(data)

    def clear(self):
        """Drop all buffered items by rebinding ``buffer`` to a new list."""
        self.buffer = []

    def is_full(self):
        """True once the batch holds ``buffer_size`` items or more."""
        return len(self) >= self.buffer_size

    def is_empty(self):
        """True when nothing is buffered."""
        return not len(self)

    def save_label(self, project: Project):
        """Bulk-create the labels produced by every buffered item.

        Each item's ``create_label(project)`` result is flattened, falsy
        entries are discarded, and the remainder is inserted with
        ``ignore_conflicts=True``.
        """
        per_item = [item.create_label(project) for item in self.buffer]
        labels = [label for label in itertools.chain.from_iterable(per_item) if label]
        Label.objects.bulk_create(labels, ignore_conflicts=True)

    def save_data(self, project: Project) -> List[Example]:
        """Bulk-create one ``Example`` per buffered item and return the result."""
        to_create = [item.create_data(project) for item in self.buffer]
        return Example.objects.bulk_create(to_create)

    def save_annotation(self, project, user, examples):
        """Bulk-create annotations for the buffered items.

        ``examples`` is paired with the buffer by position. Labels of the
        project are looked up through a ``(text, task_type)`` mapping, and
        the resulting annotations are inserted one model class at a time.
        """
        label_by_key = {}
        for label in project.labels.all():
            label_by_key[(label.text, label.task_type)] = label

        per_item = [
            item.create_annotation(user, example, label_by_key)
            for item, example in zip(self.buffer, examples)
        ]
        annotations = list(itertools.chain.from_iterable(per_item))

        for klass, instances in group_by_class(annotations).items():
            klass.objects.bulk_create(instances)
class BulkWriter:
    """Drain an iterator of parsed examples into the database in batches.

    Examples are buffered in an ``Examples`` batch of ``batch_size`` items
    and flushed through ``create`` whenever the batch reports it is full,
    plus once more at the end for any remainder. Parse and clean failures
    do not abort the run; they are collected in ``errors``.
    """

    def __init__(self, batch_size):
        """Create a writer whose buffer holds ``batch_size`` examples."""
        # Pending examples for the current batch.
        self.examples = Examples(batch_size)
        # Flat list of error entries gathered while saving.
        self.errors = []

    def save(self, dataset, project, user, cleaner):
        """Consume ``dataset`` and persist every example it yields.

        Args:
            dataset: an iterator (it is advanced with ``next``) yielding
                example objects that expose ``clean(cleaner)``.
            project: project passed through to the batch save methods.
            user: user passed through to the annotation save step.
            cleaner: object handed to each example's ``clean`` method.

        Errors raised while reading or cleaning are recorded in
        ``self.errors`` instead of being propagated.
        """
        while True:
            try:
                example = next(dataset)
            except StopIteration:
                break
            except FileParseException as err:
                self.errors.append(err.dict())
                continue
            except FileParseExceptions as err:
                # A single exception carries several entries. Splice them in
                # with extend so ``errors`` stays a flat list; append would
                # nest a list inside it.
                self.errors.extend(list(err))
                continue

            try:
                example.clean(cleaner)
            except FileParseException as err:
                # NOTE(review): the example is still buffered after a failed
                # clean (the source has no ``continue`` here) - confirm this
                # is intended.
                self.errors.append(err.dict())

            self.examples.add(example)
            if self.examples.is_full():
                self.create(project, user)
                self.examples.clear()

        # Flush whatever is left over from the last, partially filled batch.
        if not self.examples.is_empty():
            self.create(project, user)
            self.examples.clear()

    def create(self, project, user):
        """Persist the current batch: labels, then data, then annotations."""
        self.examples.save_label(project)
        saved = self.examples.save_data(project)
        self.examples.save_annotation(project, user, saved)
Loading…
Cancel
Save