Browse Source

Apply the cleaner in the reader instead of the writer

pull/1823/head
Hironsan 2 years ago
parent
commit
59b4e13111
3 changed files with 22 additions and 28 deletions
  1. 6
      backend/data_import/celery_tasks.py
  2. 14
      backend/data_import/pipeline/readers.py
  3. 30
      backend/data_import/pipeline/writers.py

6
backend/data_import/celery_tasks.py

@ -61,12 +61,12 @@ def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str],
parser = create_parser(file_format, **kwargs)
builder = create_builder(project, **kwargs)
reader = Reader(filenames=filenames, parser=parser, builder=builder)
cleaner = create_cleaner(project)
reader = Reader(filenames=filenames, parser=parser, builder=builder, cleaner=cleaner)
writer = Writer(batch_size=settings.IMPORT_BATCH_SIZE)
writer.save(reader, project, user, cleaner)
writer.save(reader, project, user)
upload_to_store(temporary_uploads)
return {"error": writer.errors + errors}
return {"error": reader.errors + errors}
def upload_to_store(temporary_uploads):

14
backend/data_import/pipeline/readers.py

@ -35,7 +35,7 @@ class Record:
changed = len(label) != len(self.label)
self._label = label
if changed:
raise FileParseException(filename=self._data.filename, line_num=self._line_num, message=cleaner.message)
return FileParseException(filename=self._data.filename, line_num=self._line_num, message=cleaner.message)
@property
def data(self):
@ -104,10 +104,11 @@ class Builder(abc.ABC):
class Reader(BaseReader):
def __init__(self, filenames: List[FileName], parser: Parser, builder: Builder):
def __init__(self, filenames: List[FileName], parser: Parser, builder: Builder, cleaner: Cleaner):
self.filenames = filenames
self.parser = parser
self.builder = builder
self.cleaner = cleaner
self._errors: List[FileParseException] = []
def __iter__(self) -> Iterator[Record]:
@ -115,7 +116,11 @@ class Reader(BaseReader):
rows = self.parser.parse(filename.full_path)
for line_num, row in enumerate(rows, start=1):
try:
yield self.builder.build(row, filename, line_num)
record = self.builder.build(row, filename, line_num)
maybe_error = record.clean(self.cleaner)
if maybe_error:
self._errors.append(maybe_error)
yield record
except FileParseException as e:
self._errors.append(e)
@ -123,4 +128,5 @@ class Reader(BaseReader):
def errors(self) -> List[FileParseException]:
"""Aggregates parser and builder errors."""
errors = self.parser.errors + self._errors
return errors
errors.sort(key=lambda error: error.line_num)
return [error.dict() for error in errors]

30
backend/data_import/pipeline/writers.py

@ -1,10 +1,9 @@
import itertools
from collections import defaultdict
from typing import Any, Dict, List, Type
from typing import List, Type
from django.conf import settings
from .exceptions import FileParseException
from .readers import BaseReader, Record
from examples.models import Example
from label_types.models import CategoryType, LabelType, SpanType
@ -45,19 +44,14 @@ class Examples:
class Writer:
def __init__(self, batch_size: int):
self.examples = Examples(batch_size)
self._errors: List[FileParseException] = []
def save(self, reader: BaseReader, project: Project, user, cleaner):
def save(self, reader: BaseReader, project: Project, user):
it = iter(reader)
while True:
try:
example = next(it)
except StopIteration:
break
try:
example.clean(cleaner)
except FileParseException as err:
self._errors.append(err)
self.examples.add(example)
if self.examples.is_full():
@ -66,36 +60,30 @@ class Writer:
if not self.examples.is_empty():
self.create(project, user)
self.examples.clear()
self._errors.extend(reader.errors)
@property
def errors(self) -> List[Dict[Any, Any]]:
self._errors.sort(key=lambda e: e.line_num)
return [error.dict() for error in self._errors]
def create(self, project: Project, user):
self.save_label(project)
ids = self.save_data(project)
self.save_annotation(project, user, ids)
self.save_label_type(project)
ids = self.save_example(project)
self.save_label(project, user, ids)
def save_label(self, project: Project):
def save_label_type(self, project: Project):
labels = list(itertools.chain.from_iterable([example.create_label(project) for example in self.examples]))
labels = list(filter(None, labels))
groups = group_by_class(labels)
for klass, instances in groups.items():
klass.objects.bulk_create(instances, ignore_conflicts=True)
def save_data(self, project: Project) -> List[Example]:
def save_example(self, project: Project) -> List[Example]:
examples = [example.create_data(project) for example in self.examples]
return Example.objects.bulk_create(examples)
def save_annotation(self, project: Project, user, examples):
# Todo: move annotation class
def save_label(self, project: Project, user, examples):
mapping = {}
label_types: List[Type[LabelType]] = [CategoryType, SpanType]
for model in label_types:
for label in model.objects.filter(project=project):
mapping[label.text] = label
annotations = list(
itertools.chain.from_iterable(
[data.create_annotation(user, example, mapping) for data, example in zip(self.examples, examples)]

Loading…
Cancel
Save