You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

85 lines
2.8 KiB

  1. import abc
  2. from logging import getLogger
  3. from typing import Any, Dict, List, Optional, Type, TypeVar
  4. from pydantic import ValidationError
  5. from .data import BaseData
  6. from .exceptions import FileParseException
  7. from .labels import Label
  8. from .readers import Builder, Record
  9. logger = getLogger(__name__)
  10. T = TypeVar('T')
  11. class PlainBuilder(Builder):
  12. def __init__(self, data_class: Type[BaseData]):
  13. self.data_class = data_class
  14. def build(self, row: Dict[Any, Any], filename: str, line_num: int) -> Record:
  15. data = self.data_class.parse(filename=filename)
  16. yield Record(data=data)
  17. def build_label(row: Dict[Any, Any], name: str, label_class: Type[Label]) -> List[Label]:
  18. labels = row[name]
  19. labels = [labels] if isinstance(labels, (str, int)) else labels
  20. return [label_class.parse(label) for label in labels]
  21. def build_data(row: Dict[Any, Any], name: str, data_class: Type[BaseData], filename: str) -> BaseData:
  22. data = row[name]
  23. return data_class.parse(text=data, filename=filename)
  24. class Column(abc.ABC):
  25. def __init__(self, name: str, value_class: Type[T]):
  26. self.name = name
  27. self.value_class = value_class
  28. @abc.abstractmethod
  29. def __call__(self, row: Dict[Any, Any], filename: str):
  30. raise NotImplementedError('Please implement this method in the subclass.')
  31. class DataColumn(Column):
  32. def __call__(self, row: Dict[Any, Any], filename: str) -> BaseData:
  33. return build_data(row, self.name, self.value_class, filename)
  34. class LabelColumn(Column):
  35. def __call__(self, row: Dict[Any, Any], filename: str) -> List[Label]:
  36. return build_label(row, self.name, self.value_class)
  37. class ColumnBuilder(Builder):
  38. def __init__(self, data_column: Column, label_columns: Optional[List[Column]] = None):
  39. self.data_column = data_column
  40. self.label_columns = label_columns or []
  41. def build(self, row: Dict[Any, Any], filename: str, line_num: int) -> Record:
  42. try:
  43. data = self.data_column(row, filename)
  44. row.pop(self.data_column.name)
  45. except KeyError:
  46. message = f'{self.data_column.name} field does not exist.'
  47. raise FileParseException(filename, line_num, message)
  48. except ValidationError:
  49. message = 'The empty text is not allowed.'
  50. raise FileParseException(filename, line_num, message)
  51. labels = []
  52. for column in self.label_columns:
  53. try:
  54. labels.extend(column(row, filename))
  55. row.pop(column.name)
  56. except (KeyError, ValidationError, TypeError) as e:
  57. logger.error('Filename: %s, Line: %s, Data: %s, Error: %s' % (filename, line_num, row, str(e)))
  58. return Record(data=data, label=labels, line_num=line_num, meta=row)