From d4e216b188a657f9cd3efe6225407227e51223c1 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Mon, 20 Dec 2021 09:07:26 +0900 Subject: [PATCH] Update ingest_task function --- backend/api/tasks.py | 17 ++---- backend/api/tests/upload/test_builder.py | 2 +- backend/api/tests/upload/test_parser.py | 6 +- backend/api/views/upload/builders.py | 12 ++-- backend/api/views/upload/factory.py | 32 +++-------- backend/api/views/upload/parsers.py | 73 ++++++++++++++++++------ backend/api/views/upload/readers.py | 42 ++++++++++++-- backend/api/views/upload/writers.py | 39 +++++-------- 8 files changed, 131 insertions(+), 92 deletions(-) diff --git a/backend/api/tasks.py b/backend/api/tasks.py index 5a294fa5..8e2cc020 100644 --- a/backend/api/tasks.py +++ b/backend/api/tasks.py @@ -7,8 +7,8 @@ from django.shortcuts import get_object_or_404 from .models import Project from .views.download.factory import create_repository, create_writer from .views.download.service import ExportApplicationService -from .views.upload.factory import (create_cleaner, get_data_class, - get_dataset_class, get_label_class) +from .views.upload.factory import create_bulder, create_cleaner, create_parser +from .views.upload.readers import Reader from .views.upload.writers import BulkWriter logger = get_task_logger(__name__) @@ -19,17 +19,12 @@ def ingest_data(user_id, project_id, filenames, format: str, **kwargs): project = get_object_or_404(Project, pk=project_id) user = get_object_or_404(get_user_model(), pk=user_id) - dataset_class = get_dataset_class(format) - dataset = dataset_class( - filenames=filenames, - label_class=get_label_class(project.project_type), - data_class=get_data_class(project.project_type), - **kwargs - ) - it = iter(dataset) + parser = create_parser(format, **kwargs) + builder = create_bulder(project, **kwargs) + reader = Reader(filenames=filenames, parser=parser, builder=builder) cleaner = create_cleaner(project) writer = BulkWriter(batch_size=settings.IMPORT_BATCH_SIZE) - writer.save(it, project, user, cleaner) + writer.save(reader, project, user, cleaner) return {'error': writer.errors} diff --git a/backend/api/tests/upload/test_builder.py b/backend/api/tests/upload/test_builder.py index bd42f3b5..a7b9a645 100644 --- a/backend/api/tests/upload/test_builder.py +++ b/backend/api/tests/upload/test_builder.py @@ -8,7 +8,7 @@ from ...views.upload.label import CategoryLabel class TestColumnBuilder(unittest.TestCase): def assert_record(self, actual, expected): - self.assertEqual(actual.data['text'], expected['data']) + self.assertEqual(actual.data.text, expected['data']) self.assertEqual(actual.label, expected['label']) def test_can_load_default_column_names(self): diff --git a/backend/api/tests/upload/test_parser.py b/backend/api/tests/upload/test_parser.py index 403653eb..c73dc239 100644 --- a/backend/api/tests/upload/test_parser.py +++ b/backend/api/tests/upload/test_parser.py @@ -105,7 +105,7 @@ class TestFastTextParser(TestParser): def test_read(self): content = '__label__sauce __label__cheese Text' parser = parsers.FastTextParser() - expected = [{'text': 'Text', 'labels': ['sauce', 'cheese']}] + expected = [{'text': 'Text', 'label': ['sauce', 'cheese']}] self.assert_record(content, parser, expected) @@ -130,11 +130,11 @@ Blackburn\tI-PER expected = [ { 'text': 'EU rejects German call to boycott British lamb .', - 'labels': [(0, 2, 'ORG'), (11, 17, 'MISC'), (34, 41, 'MISC')] + 'label': [(0, 2, 'ORG'), (11, 17, 'MISC'), (34, 41, 'MISC')] }, { 'text': 'Peter Blackburn', - 'labels': [(0, 15, 'PER')] + 'label': [(0, 15, 'PER')] } ] self.assert_record(content, parser, expected) diff --git a/backend/api/views/upload/builders.py b/backend/api/views/upload/builders.py index 8eea86f1..7be2ec2f 100644 --- a/backend/api/views/upload/builders.py +++ b/backend/api/views/upload/builders.py @@ -52,10 +52,7 @@ class DataColumn(Column): class LabelColumn(Column): def __call__(self, row: Dict[Any, Any], filename: str) -> List[Label]: - try: - return build_label(row, self.name, self.value_class) - except (KeyError, ValidationError, TypeError): - return [] + return build_label(row, self.name, self.value_class) class ColumnBuilder(Builder): @@ -77,7 +74,10 @@ class ColumnBuilder(Builder): labels = [] for column in self.label_columns: - labels.extend(column(row, filename)) - row.pop(column.name) + try: + labels.extend(column(row, filename)) + row.pop(column.name) + except (KeyError, ValidationError, TypeError): + pass return Record(data=data, label=labels, line_num=line_num, meta=row) diff --git a/backend/api/views/upload/factory.py b/backend/api/views/upload/factory.py index a33d368d..25484957 100644 --- a/backend/api/views/upload/factory.py +++ b/backend/api/views/upload/factory.py @@ -1,6 +1,6 @@ from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING, SPEECH2TEXT) -from . import builders, catalog, cleaners, data, dataset, label, parsers +from . import builders, catalog, cleaners, data, label, parsers, readers def get_data_class(project_type: str): @@ -11,25 +11,7 @@ def get_data_class(project_type: str): return data.FileData -def get_dataset_class(format: str): - mapping = { - catalog.TextFile.name: dataset.TextFileDataset, - catalog.TextLine.name: dataset.TextLineDataset, - catalog.CSV.name: dataset.CsvDataset, - catalog.JSONL.name: dataset.JSONLDataset, - catalog.JSON.name: dataset.JSONDataset, - catalog.FastText.name: dataset.FastTextDataset, - catalog.Excel.name: dataset.ExcelDataset, - catalog.CoNLL.name: dataset.CoNLLDataset, - catalog.ImageFile.name: dataset.FileBaseDataset, - catalog.AudioFile.name: dataset.FileBaseDataset, - } - if format not in mapping: - ValueError(f'Invalid format: {format}') - return mapping[format] - - -def get_parser(file_format: str): +def create_parser(file_format: str, **kwargs): mapping = { catalog.TextFile.name: parsers.TextFileParser, catalog.TextLine.name: parsers.LineParser, @@ -44,7 +26,7 @@ def get_parser(file_format: str): } if file_format not in mapping: raise ValueError(f'Invalid format: {file_format}') - return mapping[file_format] + return mapping[file_format](**kwargs) def get_label_class(project_type: str): @@ -74,14 +56,14 @@ def create_cleaner(project): def create_bulder(project, **kwargs): data_column = builders.DataColumn( - name=kwargs.get('column_data', 'text'), + name=kwargs.get('column_data', readers.DEFAULT_TEXT_COLUMN), value_class=get_data_class(project.project_type) ) # Todo: If project is EntityClassification, # column names are fixed: entities, cats - label_column = builders.DataColumn( - name=kwargs.get('column_label', 'label'), - value_class=get_data_class(project.project_type) + label_column = builders.LabelColumn( + name=kwargs.get('column_label', readers.DEFAULT_LABEL_COLUMN), + value_class=get_label_class(project.project_type) ) builder = builders.ColumnBuilder( data_column=data_column, diff --git a/backend/api/views/upload/parsers.py b/backend/api/views/upload/parsers.py index ddd067ed..10f30049 100644 --- a/backend/api/views/upload/parsers.py +++ b/backend/api/views/upload/parsers.py @@ -79,7 +79,7 @@ class PlainParser(Parser): class LineParser(Parser): - def __init__(self, encoding: str = DEFAULT_ENCODING): + def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs): self.encoding = encoding def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: @@ -90,7 +90,7 @@ class LineParser(Parser): class TextFileParser(Parser): - def __init__(self, encoding: str = DEFAULT_ENCODING): + def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs): self.encoding = encoding def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: @@ -101,7 +101,7 @@ class TextFileParser(Parser): class CSVParser(Parser): - def __init__(self, encoding: str = DEFAULT_ENCODING, delimiter: str = ','): + def __init__(self, encoding: str = DEFAULT_ENCODING, delimiter: str = ',', **kwargs): self.encoding = encoding self.delimiter = delimiter @@ -115,39 +115,68 @@ class CSVParser(Parser): class JSONParser(Parser): - def __init__(self, encoding: str = DEFAULT_ENCODING): + def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs): self.encoding = encoding + self._errors = [] def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: encoding = decide_encoding(filename, self.encoding) with open(filename, encoding=encoding) as f: - rows = json.load(f) - for line_num, row in enumerate(rows, start=1): - yield row + try: + rows = json.load(f) + for line_num, row in enumerate(rows, start=1): + yield row + except json.decoder.JSONDecodeError as e: + error = FileParseException(filename, line_num=1, message=str(e)) + self._errors.append(error) + + @property + def errors(self) -> List[FileParseException]: + return self._errors class JSONLParser(Parser): - def __init__(self, encoding: str = DEFAULT_ENCODING): + def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs): self.encoding = encoding + self._errors = [] def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: reader = LineReader(filename, self.encoding) - for line in reader: - yield json.loads(line) + for line_num, line in enumerate(reader, start=1): + try: + yield json.loads(line) + except json.decoder.JSONDecodeError as e: + error = FileParseException(filename, line_num, str(e)) + self._errors.append(error) + + @property + def errors(self) -> List[FileParseException]: + return self._errors class ExcelParser(Parser): + def __init__(self, **kwargs): + self._errors = [] + def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: rows = pyexcel.iget_records(file_name=filename) - for row in rows: - yield row + try: + for line_num, row in enumerate(rows, start=1): + yield row + except pyexcel.exceptions.FileTypeNotSupported as e: + error = FileParseException(filename, line_num=1, message=str(e)) + self._errors.append(error) + + @property + def errors(self) -> List[FileParseException]: + return self._errors class FastTextParser(Parser): - def __init__(self, encoding: str = DEFAULT_ENCODING, label: str = '__label__'): + def __init__(self, encoding: str = DEFAULT_ENCODING, label: str = '__label__', **kwargs): self.encoding = encoding self.label = label @@ -168,7 +197,7 @@ class FastTextParser(Parser): class CoNLLParser(Parser): - def __init__(self, encoding: str = DEFAULT_ENCODING, delimiter: str = ' ', scheme: str = 'IOB2'): + def __init__(self, encoding: str = DEFAULT_ENCODING, delimiter: str = ' ', scheme: str = 'IOB2', **kwargs): self.encoding = encoding self.delimiter = delimiter mapping = { @@ -177,12 +206,23 @@ class CoNLLParser(Parser): 'IOBES': IOBES, 'BILOU': BILOU } + self._errors = [] if scheme in mapping: self.scheme = mapping[scheme] else: - raise Exception('The scheme is not supported.') + self.scheme = None + + @property + def errors(self) -> List[FileParseException]: + return self._errors def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: + if not self.scheme: + message = 'The specified scheme is not supported.' + error = FileParseException(filename, line_num=1, message=message) + self._errors.append(error) + return + reader = LineReader(filename, self.encoding) words, tags = [], [] for line_num, line in enumerate(reader, start=1): @@ -191,7 +231,8 @@ class CoNLLParser(Parser): tokens = line.split('\t') if len(tokens) != 2: message = 'A line must be separated by tab and has two columns.' - raise FileParseException(filename, line_num, message) + self._errors.append(FileParseException(filename, line_num, message)) + return word, tag = tokens words.append(word) tags.append(tag) diff --git a/backend/api/views/upload/readers.py b/backend/api/views/upload/readers.py index 45b87f0a..c8e9786b 100644 --- a/backend/api/views/upload/readers.py +++ b/backend/api/views/upload/readers.py @@ -2,11 +2,13 @@ import abc import collections.abc from typing import Any, Dict, Iterator, List, Type +from .cleaners import Cleaner from .data import BaseData +from .exception import FileParseException from .label import Label DEFAULT_TEXT_COLUMN = 'text' -DEFAULT_LABEL_COLUMN = 'labels' +DEFAULT_LABEL_COLUMN = 'label' class Record: @@ -28,9 +30,29 @@ class Record: def __str__(self): return f'{self._data}\t{self._label}' + def clean(self, cleaner: Cleaner): + label = cleaner.clean(self._label) + changed = len(label) != len(self.label) + self._label = label + if changed: + raise FileParseException( + filename=self._data.filename, + line_num=self._line_num, + message=cleaner.message + ) + @property def data(self): - return self._data.dict() + return self._data + + def create_data(self, project): + return self._data.create(project, self._meta) + + def create_label(self, project): + return [label.create(project) for label in self._label] + + def create_annotation(self, user, example, mapping): + return [label.create_annotation(user, example, mapping) for label in self._label] @property def label(self): @@ -66,6 +88,11 @@ class Parser(abc.ABC): """Parses the file and returns the dictionary.""" raise NotImplementedError('Please implement this method in the subclass.') + @property + def errors(self) -> List[FileParseException]: + """Returns parsing errors.""" + return [] + class Builder(abc.ABC): @@ -87,10 +114,13 @@ class Reader(BaseReader): for filename in self.filenames: rows = self.parser.parse(filename) for line_num, row in enumerate(rows, start=1): - record = self.builder.build(row, filename, line_num) - yield record + try: + yield self.builder.build(row, filename, line_num) + except FileParseException as e: + self._errors.append(e) @property - def errors(self): + def errors(self) -> List[FileParseException]: """Aggregates parser and builder errors.""" - return self._errors + errors = self.parser.errors + self._errors + return errors diff --git a/backend/api/views/upload/writers.py b/backend/api/views/upload/writers.py index 13b90d8d..9b6785e9 100644 --- a/backend/api/views/upload/writers.py +++ b/backend/api/views/upload/writers.py @@ -1,34 +1,24 @@ import abc import itertools +from collections import defaultdict from typing import List from django.conf import settings from ...models import Example, Label, Project -from .exception import FileParseException, FileParseExceptions +from .exception import FileParseException from .readers import BaseReader class Writer(abc.ABC): @abc.abstractmethod - def save(self, reader: BaseReader): + def save(self, reader: BaseReader, project: Project, user, cleaner): """Save the read contents to DB.""" raise NotImplementedError('Please implement this method in the subclass.') -class BulkWriterOld(Writer): - - def __init__(self, batch_size: int): - self.batch_size = batch_size - - def save(self, reader: BaseReader): - """Bulk save the read contents.""" - pass - - def group_by_class(instances): - from collections import defaultdict groups = defaultdict(list) for instance in instances: groups[instance.__class__].append(instance) @@ -79,28 +69,23 @@ class Examples: klass.objects.bulk_create(instances) -class BulkWriter: +class BulkWriter(Writer): def __init__(self, batch_size): self.examples = Examples(batch_size) - self.errors = [] + self._errors = [] - def save(self, dataset, project, user, cleaner): + def save(self, reader: BaseReader, project, user, cleaner): + it = iter(reader) while True: try: - example = next(dataset) + example = next(it) except StopIteration: break - except FileParseException as err: - self.errors.append(err.dict()) - continue - except FileParseExceptions as err: - self.errors.append(list(err)) - continue try: example.clean(cleaner) except FileParseException as err: - self.errors.append(err.dict()) + self._errors.append(err) self.examples.add(example) if self.examples.is_full(): @@ -109,6 +94,12 @@ class BulkWriter: if not self.examples.is_empty(): self.create(project, user) self.examples.clear() + self._errors.extend(reader.errors) + + @property + def errors(self) -> List[FileParseException]: + self._errors.sort(key=lambda e: e.line_num) + return self._errors def create(self, project, user): self.examples.save_label(project)