# Patch summary: move label cleaning and error reporting from Writer to Reader.
#
# backend/data_import/celery_tasks.py (import_dataset)
#   - The cleaner is created before the Reader and passed to it as `cleaner=`.
#   - Writer.save() is called without the cleaner argument.
#   - The returned {"error": ...} list is built from reader.errors instead of
#     writer.errors.
#
# backend/data_import/pipeline/readers.py
#   - Record.clean() returns the FileParseException instead of raising it when
#     the cleaned label list has a different length than the current one.
#   - Reader.__init__ accepts and stores a `cleaner`.
#   - Reader.__iter__ builds each record, calls record.clean(self.cleaner),
#     appends a returned error to self._errors, and yields the record either
#     way. FileParseException raised inside the try is still collected.
#   - Reader.errors sorts the combined parser errors and self._errors by
#     line_num and returns [error.dict() for error in errors].
#
# backend/data_import/pipeline/writers.py
#   - Writer drops its _errors list, the `errors` property, the `cleaner`
#     parameter of save(), and the Any/Dict/FileParseException imports.
#   - Renames: save_label -> save_label_type, save_data -> save_example,
#     save_annotation -> save_label (the "Todo: move annotation class" comment
#     is removed).
#
# NOTE(review): Reader.errors is still annotated `-> List[FileParseException]`
#   but now returns the results of error.dict(); the annotation looks stale.
# NOTE(review): the Reader.errors docstring mentions only parser and builder
#   errors, while cleaner errors are now appended to self._errors as well.
# NOTE(review): errors are sorted by line_num only; Reader takes a list of
#   filenames, so errors from different files may interleave - confirm this
#   is intended.
# NOTE(review): this patch text appears to have lost its original line breaks
#   and the final writers.py hunk looks cut off; verify against the original
#   commit before applying.
diff --git a/backend/data_import/celery_tasks.py b/backend/data_import/celery_tasks.py index af1f45be..e5109164 100644 --- a/backend/data_import/celery_tasks.py +++ b/backend/data_import/celery_tasks.py @@ -61,12 +61,12 @@ def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str], parser = create_parser(file_format, **kwargs) builder = create_builder(project, **kwargs) - reader = Reader(filenames=filenames, parser=parser, builder=builder) cleaner = create_cleaner(project) + reader = Reader(filenames=filenames, parser=parser, builder=builder, cleaner=cleaner) writer = Writer(batch_size=settings.IMPORT_BATCH_SIZE) - writer.save(reader, project, user, cleaner) + writer.save(reader, project, user) upload_to_store(temporary_uploads) - return {"error": writer.errors + errors} + return {"error": reader.errors + errors} def upload_to_store(temporary_uploads): diff --git a/backend/data_import/pipeline/readers.py b/backend/data_import/pipeline/readers.py index 40303f74..e9c7d9cd 100644 --- a/backend/data_import/pipeline/readers.py +++ b/backend/data_import/pipeline/readers.py @@ -35,7 +35,7 @@ class Record: changed = len(label) != len(self.label) self._label = label if changed: - raise FileParseException(filename=self._data.filename, line_num=self._line_num, message=cleaner.message) + return FileParseException(filename=self._data.filename, line_num=self._line_num, message=cleaner.message) @property def data(self): @@ -104,10 +104,11 @@ class Builder(abc.ABC): class Reader(BaseReader): - def __init__(self, filenames: List[FileName], parser: Parser, builder: Builder): + def __init__(self, filenames: List[FileName], parser: Parser, builder: Builder, cleaner: Cleaner): self.filenames = filenames self.parser = parser self.builder = builder + self.cleaner = cleaner self._errors: List[FileParseException] = [] def __iter__(self) -> Iterator[Record]: @@ -115,7 +116,11 @@ class Reader(BaseReader): rows = self.parser.parse(filename.full_path) for line_num, row in 
enumerate(rows, start=1): try: - yield self.builder.build(row, filename, line_num) + record = self.builder.build(row, filename, line_num) + maybe_error = record.clean(self.cleaner) + if maybe_error: + self._errors.append(maybe_error) + yield record except FileParseException as e: self._errors.append(e) @@ -123,4 +128,5 @@ class Reader(BaseReader): def errors(self) -> List[FileParseException]: """Aggregates parser and builder errors.""" errors = self.parser.errors + self._errors - return errors + errors.sort(key=lambda error: error.line_num) + return [error.dict() for error in errors] diff --git a/backend/data_import/pipeline/writers.py b/backend/data_import/pipeline/writers.py index 5f4f8b53..4850fdfe 100644 --- a/backend/data_import/pipeline/writers.py +++ b/backend/data_import/pipeline/writers.py @@ -1,10 +1,9 @@ import itertools from collections import defaultdict -from typing import Any, Dict, List, Type +from typing import List, Type from django.conf import settings -from .exceptions import FileParseException from .readers import BaseReader, Record from examples.models import Example from label_types.models import CategoryType, LabelType, SpanType @@ -45,19 +44,14 @@ class Examples: class Writer: def __init__(self, batch_size: int): self.examples = Examples(batch_size) - self._errors: List[FileParseException] = [] - def save(self, reader: BaseReader, project: Project, user, cleaner): + def save(self, reader: BaseReader, project: Project, user): it = iter(reader) while True: try: example = next(it) except StopIteration: break - try: - example.clean(cleaner) - except FileParseException as err: - self._errors.append(err) self.examples.add(example) if self.examples.is_full(): @@ -66,36 +60,30 @@ class Writer: if not self.examples.is_empty(): self.create(project, user) self.examples.clear() - self._errors.extend(reader.errors) - - @property - def errors(self) -> List[Dict[Any, Any]]: - self._errors.sort(key=lambda e: e.line_num) - return [error.dict() for error in 
self._errors] def create(self, project: Project, user): - self.save_label(project) - ids = self.save_data(project) - self.save_annotation(project, user, ids) + self.save_label_type(project) + ids = self.save_example(project) + self.save_label(project, user, ids) - def save_label(self, project: Project): + def save_label_type(self, project: Project): labels = list(itertools.chain.from_iterable([example.create_label(project) for example in self.examples])) labels = list(filter(None, labels)) groups = group_by_class(labels) for klass, instances in groups.items(): klass.objects.bulk_create(instances, ignore_conflicts=True) - def save_data(self, project: Project) -> List[Example]: + def save_example(self, project: Project) -> List[Example]: examples = [example.create_data(project) for example in self.examples] return Example.objects.bulk_create(examples) - def save_annotation(self, project: Project, user, examples): - # Todo: move annotation class + def save_label(self, project: Project, user, examples): mapping = {} label_types: List[Type[LabelType]] = [CategoryType, SpanType] for model in label_types: for label in model.objects.filter(project=project): mapping[label.text] = label + annotations = list( itertools.chain.from_iterable( [data.create_annotation(user, example, mapping) for data, example in zip(self.examples, examples)]