You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

118 lines
4.2 KiB

  1. from typing import List, Optional, Type
  2. import pandas as pd
  3. from .data import BaseData
  4. from .exceptions import FileParseException
  5. from .label import Label
  6. from .readers import (
  7. DEFAULT_TEXT_COLUMN,
  8. LINE_NUMBER_COLUMN,
  9. UPLOAD_NAME_COLUMN,
  10. UUID_COLUMN,
  11. )
  12. from examples.models import Example
  13. from projects.models import Project
  14. class ExampleMaker:
  15. def __init__(
  16. self,
  17. project: Project,
  18. data_class: Type[BaseData],
  19. column_data: str = DEFAULT_TEXT_COLUMN,
  20. exclude_columns: Optional[List[str]] = None,
  21. ):
  22. self.project = project
  23. self.data_class = data_class
  24. self.column_data = column_data
  25. self.exclude_columns = exclude_columns or []
  26. self._errors: List[FileParseException] = []
  27. def make(self, df: pd.DataFrame) -> List[Example]:
  28. if not self.check_column_existence(df):
  29. return []
  30. self.check_value_existence(df)
  31. # make dataframe without exclude columns and missing data
  32. df_with_data_column = df.loc[:, ~df.columns.isin(self.exclude_columns)]
  33. df_with_data_column = df_with_data_column.dropna(subset=[self.column_data])
  34. examples = []
  35. for row in df_with_data_column.to_dict(orient="records"):
  36. line_num = row.pop(LINE_NUMBER_COLUMN, 0)
  37. row[DEFAULT_TEXT_COLUMN] = row.pop(self.column_data) # Rename column for parsing
  38. try:
  39. data = self.data_class.parse(**row)
  40. example = data.create(self.project)
  41. examples.append(example)
  42. except ValueError:
  43. message = f"Invalid data in line {line_num}"
  44. error = FileParseException(row[UPLOAD_NAME_COLUMN], line_num, message)
  45. self._errors.append(error)
  46. return examples
  47. def check_column_existence(self, df: pd.DataFrame) -> bool:
  48. message = f"Column {self.column_data} not found in the file"
  49. if self.column_data not in df.columns:
  50. for filename in df[UPLOAD_NAME_COLUMN].unique():
  51. self._errors.append(FileParseException(filename, 0, message))
  52. return False
  53. return True
  54. def check_value_existence(self, df: pd.DataFrame):
  55. df_without_data_column = df[df[self.column_data].isnull()]
  56. for row in df_without_data_column.to_dict(orient="records"):
  57. message = f"Column {self.column_data} not found in record"
  58. error = FileParseException(row[UPLOAD_NAME_COLUMN], row.get(LINE_NUMBER_COLUMN, 0), message)
  59. self._errors.append(error)
  60. @property
  61. def errors(self) -> List[FileParseException]:
  62. self._errors.sort(key=lambda error: error.line_num)
  63. return self._errors
  64. class BinaryExampleMaker(ExampleMaker):
  65. def make(self, df: pd.DataFrame) -> List[Example]:
  66. examples = []
  67. for row in df.to_dict(orient="records"):
  68. data = self.data_class.parse(**row)
  69. example = data.create(self.project)
  70. examples.append(example)
  71. return examples
  72. class LabelMaker:
  73. def __init__(self, column: str, label_class: Type[Label]):
  74. self.column = column
  75. self.label_class = label_class
  76. self._errors: List[FileParseException] = []
  77. def make(self, df: pd.DataFrame) -> List[Label]:
  78. if not self.check_column_existence(df):
  79. return []
  80. df_label = df.explode(self.column)
  81. df_label = df_label[[UUID_COLUMN, self.column]]
  82. df_label.dropna(subset=[self.column], inplace=True)
  83. labels = []
  84. for row in df_label.to_dict(orient="records"):
  85. try:
  86. label = self.label_class.parse(row[UUID_COLUMN], row[self.column])
  87. labels.append(label)
  88. except ValueError:
  89. pass
  90. return labels
  91. def check_column_existence(self, df: pd.DataFrame) -> bool:
  92. message = f"Column {self.column} not found in the file"
  93. if self.column not in df.columns:
  94. for filename in df[UPLOAD_NAME_COLUMN].unique():
  95. self._errors.append(FileParseException(filename, 0, message))
  96. return False
  97. return True
  98. @property
  99. def errors(self) -> List[FileParseException]:
  100. self._errors.sort(key=lambda error: error.line_num)
  101. return self._errors