You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

102 lines
3.3 KiB

import abc
import collections.abc
import dataclasses
from typing import Any, Dict, Iterator, List
from .cleaners import Cleaner
from .exceptions import FileParseException
from .labeled_examples import LabeledExamples, Record
# Default column names for the text and label fields.
# NOTE(review): not referenced in this section; presumably consumed by the
# concrete parsers/builders defined elsewhere -- confirm against callers.
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_LABEL_COLUMN = "label"
class BaseReader(collections.abc.Iterable):
    """Reader has a role to parse files and return a Record iterator.

    This is the abstract interface: subclasses must implement ``__iter__``,
    the ``errors`` property and ``batch``.
    """

    @abc.abstractmethod
    def __iter__(self) -> Iterator[Record]:
        """Creates an iterator for elements of this dataset.

        Returns:
            A `Record` for the elements of this dataset.
        """
        raise NotImplementedError("Please implement this method in the subclass.")

    @property
    @abc.abstractmethod
    def errors(self):
        """Returns the errors gathered by the reader.

        The element type is left to the subclass.
        """
        raise NotImplementedError("Please implement this method in the subclass.")

    @abc.abstractmethod
    def batch(self, batch_size: int) -> Iterator[LabeledExamples]:
        """Iterates over the dataset in groups wrapped as `LabeledExamples`.

        Args:
            batch_size: Number of records requested per group.

        Returns:
            An iterator of `LabeledExamples`.
        """
        raise NotImplementedError("Please implement this method in the subclass.")
class Parser(abc.ABC):
    """The abstract file parser."""

    @abc.abstractmethod
    def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
        """Parses the file and returns the dictionary.

        Args:
            filename: Path of the file to parse.

        Returns:
            An iterator of dictionaries, one per parsed row.
        """
        raise NotImplementedError("Please implement this method in the subclass.")

    @property
    def errors(self) -> List[FileParseException]:
        """Returns parsing errors.

        The base implementation reports none; subclasses that collect errors
        while parsing are expected to override this property.
        """
        return []
@dataclasses.dataclass
class FileName:
    """Bundle of the three names associated with one input file."""

    # Path handed to the parser when the file is read.
    full_path: str
    # NOTE(review): presumably the server-side name generated for storage -- confirm.
    generated_name: str
    # NOTE(review): presumably the name the user uploaded the file under -- confirm.
    upload_name: str
class Builder(abc.ABC):
    """The abstract Record builder."""

    @abc.abstractmethod
    def build(self, row: Dict[Any, Any], filename: FileName, line_num: int) -> Record:
        """Builds the record from the dictionary.

        Args:
            row: One parsed row, as produced by a parser.
            filename: Names of the file the row came from.
            line_num: Position of the row within that file.

        Returns:
            The `Record` built from the row.
        """
        raise NotImplementedError("Please implement this method in the subclass.")
class Reader(BaseReader):
    """Reads files with a parser, turns rows into records and cleans them.

    Errors raised or reported along the way are collected rather than
    propagated, and are exposed through ``errors``.
    """

    def __init__(self, filenames: List[FileName], parser: Parser, builder: Builder, cleaner: Cleaner):
        """Stores the collaborators and starts with an empty error list.

        Args:
            filenames: Files to read, in iteration order.
            parser: Produces row dictionaries from a file path.
            builder: Converts a row dictionary into a `Record`.
            cleaner: Passed to each record's ``clean`` method.
        """
        self.filenames = filenames
        self.parser = parser
        self.builder = builder
        self.cleaner = cleaner
        self._errors: List[FileParseException] = []

    def __iter__(self) -> Iterator[Record]:
        """Yields one cleaned `Record` per row that could be built.

        Rows whose build or clean step raises `FileParseException` are
        skipped and the exception is recorded. A non-empty value returned by
        ``record.clean`` is recorded as well, but the record is still yielded.
        """
        for filename in self.filenames:
            rows = self.parser.parse(filename.full_path)
            # Row numbering restarts at 1 for every file.
            for line_num, row in enumerate(rows, start=1):
                # Only the two calls that can fail are guarded; the yield
                # stays outside the try block.
                try:
                    record = self.builder.build(row, filename, line_num)
                    maybe_error = record.clean(self.cleaner)
                except FileParseException as e:
                    self._errors.append(e)
                    continue
                if maybe_error:
                    self._errors.append(maybe_error)
                yield record

    def batch(self, batch_size: int) -> Iterator[LabeledExamples]:
        """Groups the records into `LabeledExamples` of ``batch_size`` items.

        The last group may be smaller. The size check is an equality test, so
        a ``batch_size`` below 1 never matches and all records end up in a
        single final group.
        """
        batch = []
        for record in self:
            batch.append(record)
            if len(batch) == batch_size:
                yield LabeledExamples(batch)
                batch = []
        # Flush whatever is left after the last full group.
        if batch:
            yield LabeledExamples(batch)

    @property
    def errors(self) -> List[Dict[Any, Any]]:
        """Aggregates parser and builder errors.

        Returns:
            The ``dict()`` form of every collected error, ordered by
            ``line_num``. NOTE(review): annotated as a list of dicts based on
            the method name -- confirm against ``FileParseException.dict``.
        """
        # Concatenation builds a new list, so neither source list is reordered.
        errors = sorted(self.parser.errors + self._errors, key=lambda error: error.line_num)
        return [error.dict() for error in errors]