Browse Source

Pass save_names parameter to Writer

pull/1779/head
Hironsan 2 years ago
parent
commit
f4c1af9891
1 changed files with 6 additions and 3 deletions
  1. 9
      backend/data_import/pipeline/writers.py

9
backend/data_import/pipeline/writers.py

@ -61,8 +61,10 @@ class Examples:
for klass, instances in groups.items():
klass.objects.bulk_create(instances, ignore_conflicts=True)
def save_data(self, project: Project) -> List[Example]:
def save_data(self, project: Project, save_names: Dict[str, str]) -> List[Example]:
examples = [example.create_data(project) for example in self.buffer]
for example in examples:
example.filename = save_names.get(example.filename, example.filename)
return Example.objects.bulk_create(examples)
def save_annotation(self, project: Project, user, examples):
@ -83,9 +85,10 @@ class Examples:
class BulkWriter(Writer):
def __init__(self, batch_size: int):
def __init__(self, batch_size: int, save_names: Dict[str, str]):
self.examples = Examples(batch_size)
self._errors: List[FileParseException] = []
self.save_names = save_names
def save(self, reader: BaseReader, project: Project, user, cleaner):
it = iter(reader)
@ -115,5 +118,5 @@ class BulkWriter(Writer):
def create(self, project: Project, user):
self.examples.save_label(project)
ids = self.examples.save_data(project)
ids = self.examples.save_data(project, self.save_names)
self.examples.save_annotation(project, user, ids)
Loading…
Cancel
Save