|
|
import abc import itertools from collections import defaultdict from typing import Any, Dict, List
from django.conf import settings
from projects.models import Project from examples.models import Example from label_types.models import CategoryType, SpanType from .exceptions import FileParseException from .readers import BaseReader, Record
class Writer(abc.ABC): @abc.abstractmethod def save(self, reader: BaseReader, project: Project, user, cleaner): """Save the read contents to DB.""" raise NotImplementedError("Please implement this method in the subclass.")
def errors(self) -> List[Dict[Any, Any]]: """Return errors.""" raise NotImplementedError("Please implement this method in the subclass.")
def group_by_class(instances): groups = defaultdict(list) for instance in instances: groups[instance.__class__].append(instance) return groups
class Examples: def __init__(self, buffer_size: int = settings.IMPORT_BATCH_SIZE): self.buffer_size = buffer_size self.buffer: List[Record] = []
def __len__(self): return len(self.buffer)
@property def data(self): return self.buffer
def add(self, data): self.buffer.append(data)
def clear(self): self.buffer = []
def is_full(self): return len(self) >= self.buffer_size
def is_empty(self): return len(self) == 0
def save_label(self, project: Project): labels = list(itertools.chain.from_iterable([example.create_label(project) for example in self.buffer])) labels = list(filter(None, labels)) groups = group_by_class(labels) for klass, instances in groups.items(): klass.objects.bulk_create(instances, ignore_conflicts=True)
def save_data(self, project: Project) -> List[Example]: examples = [example.create_data(project) for example in self.buffer] return Example.objects.bulk_create(examples)
def save_annotation(self, project: Project, user, examples): # mapping = {label.text: label for label in project.labels.all()} # Todo: move annotation class mapping = {} for model in [CategoryType, SpanType]: for label in model.objects.all(): mapping[label.text] = label annotations = list( itertools.chain.from_iterable( [data.create_annotation(user, example, mapping) for data, example in zip(self.buffer, examples)] ) ) groups = group_by_class(annotations) for klass, instances in groups.items(): klass.objects.bulk_create(instances)
class BulkWriter(Writer): def __init__(self, batch_size: int): self.examples = Examples(batch_size) self._errors: List[FileParseException] = []
def save(self, reader: BaseReader, project: Project, user, cleaner): it = iter(reader) while True: try: example = next(it) except StopIteration: break try: example.clean(cleaner) except FileParseException as err: self._errors.append(err)
self.examples.add(example) if self.examples.is_full(): self.create(project, user) self.examples.clear() if not self.examples.is_empty(): self.create(project, user) self.examples.clear() self._errors.extend(reader.errors)
@property def errors(self) -> List[Dict[Any, Any]]: self._errors.sort(key=lambda e: e.line_num) return [error.dict() for error in self._errors]
def create(self, project: Project, user): self.examples.save_label(project) ids = self.examples.save_data(project) self.examples.save_annotation(project, user, ids)
|