|
|
@ -7,14 +7,13 @@ from typing import Any, Dict, Iterator, List |
|
|
|
import pandas as pd |
|
|
|
|
|
|
|
from .exceptions import FileParseException |
|
|
|
from .labeled_examples import Record |
|
|
|
|
|
|
|
DEFAULT_TEXT_COLUMN = "text" |
|
|
|
DEFAULT_LABEL_COLUMN = "label" |
|
|
|
LINE_NUM_COLUMN = "#line_num" |
|
|
|
FILE_NAME_COLUMN = "filename" |
|
|
|
UPLOAD_NAME_COLUMN = "upload_name" |
|
|
|
UUID_COLUMN = "uuid" |
|
|
|
UUID_COLUMN = "example_uuid" |
|
|
|
LINE_NUMBER_COLUMN = "#line_number" |
|
|
|
|
|
|
|
|
|
|
|
class BaseReader(collections.abc.Iterable): |
|
|
@ -60,35 +59,21 @@ class FileName: |
|
|
|
upload_name: str |
|
|
|
|
|
|
|
|
|
|
|
class Builder(abc.ABC): |
|
|
|
"""The abstract Record builder.""" |
|
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|
def build(self, row: Dict[Any, Any], filename: FileName, line_num: int) -> Record: |
|
|
|
"""Builds the record from the dictionary.""" |
|
|
|
raise NotImplementedError("Please implement this method in the subclass.") |
|
|
|
|
|
|
|
|
|
|
|
class Reader(BaseReader): |
|
|
|
def __init__(self, filenames: List[FileName], parser: Parser): |
|
|
|
self.filenames = filenames |
|
|
|
self.parser = parser |
|
|
|
self._errors: List[FileParseException] = [] |
|
|
|
|
|
|
|
def __iter__(self) -> Iterator[Dict[Any, Any]]: |
|
|
|
for filename in self.filenames: |
|
|
|
rows = self.parser.parse(filename.full_path) |
|
|
|
for line_num, row in enumerate(rows, start=1): |
|
|
|
try: |
|
|
|
yield { |
|
|
|
LINE_NUM_COLUMN: line_num, |
|
|
|
UUID_COLUMN: uuid.uuid4(), |
|
|
|
FILE_NAME_COLUMN: filename.generated_name, |
|
|
|
UPLOAD_NAME_COLUMN: filename.upload_name, |
|
|
|
**row, |
|
|
|
} |
|
|
|
except FileParseException as e: |
|
|
|
self._errors.append(e) |
|
|
|
for row in rows: |
|
|
|
yield { |
|
|
|
UUID_COLUMN: uuid.uuid4(), |
|
|
|
FILE_NAME_COLUMN: filename.generated_name, |
|
|
|
UPLOAD_NAME_COLUMN: filename.upload_name, |
|
|
|
**row, |
|
|
|
} |
|
|
|
|
|
|
|
def batch(self, batch_size: int) -> Iterator[pd.DataFrame]: |
|
|
|
batch = [] |
|
|
@ -102,7 +87,4 @@ class Reader(BaseReader): |
|
|
|
|
|
|
|
@property |
|
|
|
def errors(self) -> List[FileParseException]: |
|
|
|
"""Aggregates parser and builder errors.""" |
|
|
|
errors = self.parser.errors + self._errors |
|
|
|
errors.sort(key=lambda error: error.line_num) |
|
|
|
return [error.dict() for error in errors] |
|
|
|
return self.parser.errors |