You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

90 lines
2.6 KiB

2 years ago
2 years ago
2 years ago
2 years ago
  1. import abc
  2. import collections.abc
  3. import dataclasses
  4. import uuid
  5. from typing import Any, Dict, Iterator, List
  6. import pandas as pd
  7. from .exceptions import FileParseException
  8. DEFAULT_TEXT_COLUMN = "text"
  9. DEFAULT_LABEL_COLUMN = "label"
  10. FILE_NAME_COLUMN = "filename"
  11. UPLOAD_NAME_COLUMN = "upload_name"
  12. UUID_COLUMN = "example_uuid"
  13. LINE_NUMBER_COLUMN = "#line_number"
  14. class BaseReader(collections.abc.Iterable):
  15. """Reader has a role to parse files and return a Record iterator."""
  16. @abc.abstractmethod
  17. def __iter__(self) -> Iterator[Dict[Any, Any]]:
  18. """Creates an iterator for elements of this dataset.
  19. Returns:
  20. A `dict` for the elements of this dataset.
  21. """
  22. raise NotImplementedError("Please implement this method in the subclass.")
  23. @property
  24. @abc.abstractmethod
  25. def errors(self):
  26. raise NotImplementedError("Please implement this method in the subclass.")
  27. @abc.abstractmethod
  28. def batch(self, batch_size: int) -> Iterator[pd.DataFrame]:
  29. raise NotImplementedError("Please implement this method in the subclass.")
  30. class Parser(abc.ABC):
  31. """The abstract file parser."""
  32. @abc.abstractmethod
  33. def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
  34. """Parses the file and returns the dictionary."""
  35. raise NotImplementedError("Please implement this method in the subclass.")
  36. @property
  37. def errors(self) -> List[FileParseException]:
  38. """Returns parsing errors."""
  39. return []
  40. @dataclasses.dataclass
  41. class FileName:
  42. full_path: str
  43. generated_name: str
  44. upload_name: str
  45. class Reader(BaseReader):
  46. def __init__(self, filenames: List[FileName], parser: Parser):
  47. self.filenames = filenames
  48. self.parser = parser
  49. def __iter__(self) -> Iterator[Dict[Any, Any]]:
  50. for filename in self.filenames:
  51. rows = self.parser.parse(filename.full_path)
  52. for row in rows:
  53. yield {
  54. UUID_COLUMN: uuid.uuid4(),
  55. FILE_NAME_COLUMN: filename.generated_name,
  56. UPLOAD_NAME_COLUMN: filename.upload_name,
  57. **row,
  58. }
  59. def batch(self, batch_size: int) -> Iterator[pd.DataFrame]:
  60. batch = []
  61. for record in self:
  62. batch.append(record)
  63. if len(batch) == batch_size:
  64. yield pd.DataFrame(batch)
  65. batch = []
  66. if batch:
  67. yield pd.DataFrame(batch)
  68. @property
  69. def errors(self) -> List[FileParseException]:
  70. return self.parser.errors