You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

116 lines
3.4 KiB

import abc
import itertools
from typing import List
from django.conf import settings
from ...models import Example, Label, Project
from .exception import FileParseException, FileParseExceptions
from .readers import BaseReader
class Writer(abc.ABC):
    """Abstract interface for objects that persist what a reader provides."""

    @abc.abstractmethod
    def save(self, reader: BaseReader):
        """Persist the contents read by *reader* to the database.

        Concrete subclasses must override this method.

        Raises:
            NotImplementedError: always, when this base implementation runs.
        """
        raise NotImplementedError("Please implement this method in the subclass.")
class BulkWriterOld(Writer):
    """Writer that is configured with a batch size.

    The ``save`` implementation below currently performs no work.
    """

    def __init__(self, batch_size: int):
        # Only stored here; nothing in this class reads it back.
        self.batch_size = batch_size

    def save(self, reader: BaseReader):
        """Bulk save the read contents.

        The current implementation does nothing with *reader* and
        returns ``None``.
        """
        return None
def group_by_class(instances):
    """Bucket objects by their class.

    Args:
        instances: iterable of arbitrary objects.

    Returns:
        A ``defaultdict(list)`` mapping each object's ``__class__`` to the
        list of objects of that class, in the order they were encountered.
    """
    from collections import defaultdict

    buckets = defaultdict(list)
    for obj in instances:
        bucket = buckets[obj.__class__]
        bucket.append(obj)
    return buckets
class Examples:
    """In-memory buffer of parsed records that are written to the DB in bulk.

    Records are collected with ``add`` and persisted through the ``save_*``
    methods. Saving never empties the buffer; callers do that with ``clear``.
    """

    def __init__(self, buffer_size=settings.IMPORT_BATCH_SIZE):
        self.buffer_size = buffer_size
        self.buffer = []

    def __len__(self):
        """Return the number of records currently buffered."""
        return len(self.buffer)

    @property
    def data(self):
        """The buffered records (the live list, not a copy)."""
        return self.buffer

    def add(self, data):
        """Append one record to the buffer."""
        self.buffer.append(data)

    def clear(self):
        """Discard the buffered records by rebinding to a new empty list."""
        self.buffer = []

    def is_full(self):
        """Return True once the buffer holds at least ``buffer_size`` records."""
        return len(self) >= self.buffer_size

    def is_empty(self):
        """Return True when nothing is buffered."""
        return not len(self)

    def save_label(self, project: Project):
        """Bulk-create the labels produced by every buffered record.

        Each record's ``create_label(project)`` result is flattened into one
        list and falsy entries are dropped. The insert passes
        ``ignore_conflicts=True`` to ``bulk_create``.
        """
        # Collect every per-record result first, then flatten, so all
        # ``create_label`` calls happen before any result is iterated.
        per_record = []
        for item in self.buffer:
            per_record.append(item.create_label(project))
        labels = [label for group in per_record for label in group if label]
        Label.objects.bulk_create(labels, ignore_conflicts=True)

    def save_data(self, project: Project) -> List[Example]:
        """Bulk-create one ``Example`` per buffered record.

        Returns:
            Whatever ``Example.objects.bulk_create`` returns for the objects
            built by each record's ``create_data(project)``.
        """
        pending = []
        for item in self.buffer:
            pending.append(item.create_data(project))
        return Example.objects.bulk_create(pending)

    def save_annotation(self, project, user, examples):
        """Bulk-create the annotations for the buffered records.

        Args:
            project: project whose labels are looked up.
            user: passed through to each record's ``create_annotation``.
            examples: objects paired positionally (via ``zip``) with the
                buffered records.
        """
        # Lookup table of the project's labels keyed by (text, task_type).
        mapping = {}
        for label in project.labels.all():
            mapping[(label.text, label.task_type)] = label

        per_record = [
            item.create_annotation(user, example, mapping)
            for item, example in zip(self.buffer, examples)
        ]
        annotations = [annotation for group in per_record for annotation in group]

        # One bulk insert per concrete annotation class.
        for klass, members in group_by_class(annotations).items():
            klass.objects.bulk_create(members)
class BulkWriter:
    """Consume a dataset iterator and persist its records in batches.

    Parse and cleaning errors are collected in ``self.errors`` rather than
    raised.
    """

    def __init__(self, batch_size):
        self.examples = Examples(batch_size)
        self.errors = []

    def save(self, dataset, project, user, cleaner):
        """Read every record from *dataset*, buffering and saving in batches.

        Args:
            dataset: iterator advanced with ``next`` until ``StopIteration``.
            project: passed through to ``create``.
            user: passed through to ``create``.
            cleaner: passed to each record's ``clean`` method.
        """
        while True:
            try:
                record = next(dataset)
            except StopIteration:
                break
            except FileParseException as err:
                # The record could not be parsed: note the error, move on.
                self.errors.append(err.dict())
                continue
            except FileParseExceptions as err:
                # Several errors at once are stored as a single list entry.
                self.errors.append(list(err))
                continue

            try:
                record.clean(cleaner)
            except FileParseException as err:
                # A cleaning failure is recorded, but the record is still
                # buffered below.
                self.errors.append(err.dict())

            self.examples.add(record)
            if self.examples.is_full():
                self._flush(project, user)

        # Persist whatever is left over after the iterator is exhausted.
        if not self.examples.is_empty():
            self._flush(project, user)

    def _flush(self, project, user):
        """Persist the buffered records, then empty the buffer."""
        self.create(project, user)
        self.examples.clear()

    def create(self, project, user):
        """Save labels, then data, then annotations for the current buffer."""
        self.examples.save_label(project)
        saved = self.examples.save_data(project)
        self.examples.save_annotation(project, user, saved)