From b6f369e3c2071de642f7a63b0de86242dac64089 Mon Sep 17 00:00:00 2001
From: Hironsan
Date: Wed, 11 May 2022 08:41:09 +0900
Subject: [PATCH] Move examples methods to BulkWriter

---
 backend/data_import/pipeline/writers.py | 65 ++++++++++++-------------
 1 file changed, 32 insertions(+), 33 deletions(-)

diff --git a/backend/data_import/pipeline/writers.py b/backend/data_import/pipeline/writers.py
index cafc9bc7..9ee9fc99 100644
--- a/backend/data_import/pipeline/writers.py
+++ b/backend/data_import/pipeline/writers.py
@@ -38,9 +38,8 @@ class Examples:
     def __len__(self):
         return len(self.buffer)
 
-    @property
-    def data(self):
-        return self.buffer
+    def __getitem__(self, item):
+        return self.buffer[item]
 
     def add(self, data):
         self.buffer.append(data)
@@ -54,33 +53,6 @@ class Examples:
     def is_empty(self):
         return len(self) == 0
 
-    def save_label(self, project: Project):
-        labels = list(itertools.chain.from_iterable([example.create_label(project) for example in self.buffer]))
-        labels = list(filter(None, labels))
-        groups = group_by_class(labels)
-        for klass, instances in groups.items():
-            klass.objects.bulk_create(instances, ignore_conflicts=True)
-
-    def save_data(self, project: Project) -> List[Example]:
-        examples = [example.create_data(project) for example in self.buffer]
-        return Example.objects.bulk_create(examples)
-
-    def save_annotation(self, project: Project, user, examples):
-        # Todo: move annotation class
-        mapping = {}
-        label_types: List[Type[LabelType]] = [CategoryType, SpanType]
-        for model in label_types:
-            for label in model.objects.filter(project=project):
-                mapping[label.text] = label
-        annotations = list(
-            itertools.chain.from_iterable(
-                [data.create_annotation(user, example, mapping) for data, example in zip(self.buffer, examples)]
-            )
-        )
-        groups = group_by_class(annotations)
-        for klass, instances in groups.items():
-            klass.objects.bulk_create(instances)
-
 
 class BulkWriter(Writer):
     def __init__(self, batch_size: int):
@@ -114,6 +86,33 @@ class BulkWriter(Writer):
         return [error.dict() for error in self._errors]
 
     def create(self, project: Project, user):
-        self.examples.save_label(project)
-        ids = self.examples.save_data(project)
-        self.examples.save_annotation(project, user, ids)
+        self.save_label(project)
+        ids = self.save_data(project)
+        self.save_annotation(project, user, ids)
+
+    def save_label(self, project: Project):
+        labels = list(itertools.chain.from_iterable([example.create_label(project) for example in self.examples]))
+        labels = list(filter(None, labels))
+        groups = group_by_class(labels)
+        for klass, instances in groups.items():
+            klass.objects.bulk_create(instances, ignore_conflicts=True)
+
+    def save_data(self, project: Project) -> List[Example]:
+        examples = [example.create_data(project) for example in self.examples]
+        return Example.objects.bulk_create(examples)
+
+    def save_annotation(self, project: Project, user, examples):
+        # Todo: move annotation class
+        mapping = {}
+        label_types: List[Type[LabelType]] = [CategoryType, SpanType]
+        for model in label_types:
+            for label in model.objects.filter(project=project):
+                mapping[label.text] = label
+        annotations = list(
+            itertools.chain.from_iterable(
+                [data.create_annotation(user, example, mapping) for data, example in zip(self.examples, examples)]
+            )
+        )
+        groups = group_by_class(annotations)
+        for klass, instances in groups.items():
+            klass.objects.bulk_create(instances)