Browse Source

Simplify readers

pull/1823/head
Hironsan 2 years ago
parent
commit
73b3be3d19
2 changed files with 10 additions and 31 deletions
  1. backend/data_import/pipeline/readers.py (38 lines changed)
  2. backend/data_import/tests/test_reader.py (3 lines changed)

38
backend/data_import/pipeline/readers.py

@ -7,14 +7,13 @@ from typing import Any, Dict, Iterator, List
import pandas as pd
from .exceptions import FileParseException
from .labeled_examples import Record
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_LABEL_COLUMN = "label"
LINE_NUM_COLUMN = "#line_num"
FILE_NAME_COLUMN = "filename"
UPLOAD_NAME_COLUMN = "upload_name"
UUID_COLUMN = "uuid"
UUID_COLUMN = "example_uuid"
LINE_NUMBER_COLUMN = "#line_number"
class BaseReader(collections.abc.Iterable):
@ -60,35 +59,21 @@ class FileName:
upload_name: str
class Builder(abc.ABC):
"""The abstract Record builder."""
@abc.abstractmethod
def build(self, row: Dict[Any, Any], filename: FileName, line_num: int) -> Record:
"""Builds the record from the dictionary."""
raise NotImplementedError("Please implement this method in the subclass.")
class Reader(BaseReader):
def __init__(self, filenames: List[FileName], parser: Parser):
self.filenames = filenames
self.parser = parser
self._errors: List[FileParseException] = []
def __iter__(self) -> Iterator[Dict[Any, Any]]:
for filename in self.filenames:
rows = self.parser.parse(filename.full_path)
for line_num, row in enumerate(rows, start=1):
try:
yield {
LINE_NUM_COLUMN: line_num,
UUID_COLUMN: uuid.uuid4(),
FILE_NAME_COLUMN: filename.generated_name,
UPLOAD_NAME_COLUMN: filename.upload_name,
**row,
}
except FileParseException as e:
self._errors.append(e)
for row in rows:
yield {
UUID_COLUMN: uuid.uuid4(),
FILE_NAME_COLUMN: filename.generated_name,
UPLOAD_NAME_COLUMN: filename.upload_name,
**row,
}
def batch(self, batch_size: int) -> Iterator[pd.DataFrame]:
batch = []
@ -102,7 +87,4 @@ class Reader(BaseReader):
@property
def errors(self) -> List[FileParseException]:
"""Aggregates parser and builder errors."""
errors = self.parser.errors + self._errors
errors.sort(key=lambda error: error.line_num)
return [error.dict() for error in errors]
return self.parser.errors

3
backend/data_import/tests/test_reader.py

@ -6,7 +6,6 @@ from pandas.testing import assert_frame_equal
from data_import.pipeline.readers import (
FILE_NAME_COLUMN,
LINE_NUM_COLUMN,
UPLOAD_NAME_COLUMN,
UUID_COLUMN,
Reader,
@ -24,14 +23,12 @@ class TestReader(unittest.TestCase):
self.filenames.__iter__.return_value = [filename]
self.rows = [
{
LINE_NUM_COLUMN: 1,
UUID_COLUMN: "uuid",
FILE_NAME_COLUMN: filename.generated_name,
UPLOAD_NAME_COLUMN: filename.upload_name,
"a": 1,
},
{
LINE_NUM_COLUMN: 2,
UUID_COLUMN: "uuid",
FILE_NAME_COLUMN: filename.generated_name,
UPLOAD_NAME_COLUMN: filename.upload_name,

Loading…
Cancel
Save