Browse Source

Move examples methods to BulkWriter

pull/1823/head
Hironsan 3 years ago
parent
commit
b6f369e3c2
1 changed file with 32 additions and 33 deletions
  1. 65
      backend/data_import/pipeline/writers.py

65
backend/data_import/pipeline/writers.py

@ -38,9 +38,8 @@ class Examples:
# Size of the container: delegates to len() of the underlying self.buffer.
def __len__(self):
return len(self.buffer)
# Read-only property exposing the underlying buffer object itself (no copy is
# made, so callers receive the same object that add() appends to).
@property
def data(self):
return self.buffer
# Subscript access: forwards `item` unchanged to self.buffer[item].
# NOTE(review): presumably this is what lets callers iterate or index the
# container directly instead of reaching for the buffer — confirm against callers.
def __getitem__(self, item):
return self.buffer[item]
def add(self, data):
    """Append ``data`` to the end of ``self.buffer``."""
    buffer = self.buffer
    buffer.append(data)
@ -54,33 +53,6 @@ class Examples:
def is_empty(self):
    """Return ``True`` when ``len(self)`` is zero, ``False`` otherwise."""
    size = len(self)
    return not size
def save_label(self, project: Project):
    """Bulk-insert the label objects produced by every buffered example.

    Each item in ``self.buffer`` is asked for its labels via
    ``create_label(project)``; the per-item results are flattened, falsy
    entries are dropped, and the remainder is grouped with
    ``group_by_class`` so that each class can issue a single
    ``bulk_create(..., ignore_conflicts=True)``.
    """
    # Call create_label on every item first, then flatten the results.
    per_example = [item.create_label(project) for item in self.buffer]
    labels = [label for produced in per_example for label in produced if label]
    for klass, instances in group_by_class(labels).items():
        klass.objects.bulk_create(instances, ignore_conflicts=True)
def save_data(self, project: Project) -> List[Example]:
    """Build one ``Example`` per buffered item and insert them in bulk.

    Every item in ``self.buffer`` contributes ``create_data(project)``;
    the collected objects are handed to ``Example.objects.bulk_create``
    and its return value is passed straight back to the caller.
    """
    pending = []
    for item in self.buffer:
        pending.append(item.create_data(project))
    return Example.objects.bulk_create(pending)
def save_annotation(self, project: Project, user, examples):
    """Create and bulk-insert annotations for the buffered items.

    A ``text -> label`` lookup is built from the project's ``CategoryType``
    and ``SpanType`` rows (a later row with the same text replaces an
    earlier one). Each buffered item is then paired positionally with the
    matching entry of ``examples`` and asked for its annotations via
    ``create_annotation(user, example, mapping)``. The flattened result is
    grouped with ``group_by_class`` and inserted one class at a time.
    """
    # Todo: move annotation class
    label_types: List[Type[LabelType]] = [CategoryType, SpanType]
    mapping = {}
    for model in label_types:
        mapping.update({label.text: label for label in model.objects.filter(project=project)})

    # zip() stops at the shorter of the two sequences.
    per_example = [
        item.create_annotation(user, saved, mapping)
        for item, saved in zip(self.buffer, examples)
    ]
    annotations = [annotation for produced in per_example for annotation in produced]

    for klass, instances in group_by_class(annotations).items():
        klass.objects.bulk_create(instances)
class BulkWriter(Writer):
def __init__(self, batch_size: int):
@ -114,6 +86,33 @@ class BulkWriter(Writer):
return [error.dict() for error in self._errors]
# Persist the buffered examples for `project`: labels first, then the example
# rows, then the annotations created on behalf of `user` (the value returned by
# save_data is what gets passed on to save_annotation).
# NOTE(review): both the `self.examples.save_*` triple and the `self.save_*`
# triple are present below. This looks like a rendered diff in which removed and
# added lines are interleaved without markers — confirm which triple is live,
# because executing all six lines would invoke every save step twice.
def create(self, project: Project, user):
self.examples.save_label(project)
ids = self.examples.save_data(project)
self.examples.save_annotation(project, user, ids)
self.save_label(project)
ids = self.save_data(project)
self.save_annotation(project, user, ids)
def save_label(self, project: Project):
    """Bulk-insert the label objects produced by ``self.examples``.

    ``create_label(project)`` is called on every element of
    ``self.examples``; the results are chained together, falsy entries are
    discarded, and what remains is grouped by ``group_by_class``. Each
    group is written with ``bulk_create(..., ignore_conflicts=True)``.
    """
    # Collect every per-example result before flattening them.
    produced = [entry.create_label(project) for entry in self.examples]
    labels = [label for label in itertools.chain(*produced) if label]
    groups = group_by_class(labels)
    for klass in groups:
        klass.objects.bulk_create(groups[klass], ignore_conflicts=True)
def save_data(self, project: Project) -> List[Example]:
    """Insert one ``Example`` per element of ``self.examples`` in bulk.

    Each element supplies its unsaved object through
    ``create_data(project)``; the list is passed to
    ``Example.objects.bulk_create`` and that call's result is returned.
    """
    unsaved = []
    for entry in self.examples:
        unsaved.append(entry.create_data(project))
    created = Example.objects.bulk_create(unsaved)
    return created
def save_annotation(self, project: Project, user, examples):
    """Create and bulk-insert annotations for ``self.examples``.

    The project's ``CategoryType`` and ``SpanType`` rows are indexed by
    their ``text`` (on a clash the row seen later wins). Elements of
    ``self.examples`` are paired positionally with ``examples`` and each
    pair yields annotations through
    ``create_annotation(user, example, mapping)``. The combined annotations
    are grouped by ``group_by_class`` and each group is written with one
    ``bulk_create`` call.
    """
    # Todo: move annotation class
    label_types: List[Type[LabelType]] = [CategoryType, SpanType]
    mapping = {}
    for model in label_types:
        project_labels = model.objects.filter(project=project)
        for label in project_labels:
            mapping[label.text] = label

    # Pairing is positional; zip() silently ignores any unmatched tail.
    pairs = zip(self.examples, examples)
    produced = [entry.create_annotation(user, saved, mapping) for entry, saved in pairs]
    annotations = list(itertools.chain(*produced))

    groups = group_by_class(annotations)
    for klass in groups:
        klass.objects.bulk_create(groups[klass])
Loading…
Cancel
Save