You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

86 lines
3.1 KiB

2 years ago
2 years ago
2 years ago
"""
Convert a dataset to the specified format.

Each ``Formatter`` subclass in this module converts the objects stored in one
column of a ``pandas.DataFrame`` (``target_column``, ``"labels"`` by default)
into plain strings, lists, tuples, or dicts; ``RenameFormatter`` renames
columns instead. ``Formatter.format`` leaves a frame untouched when it has no
target column.
"""
  4. import abc
  5. import pandas as pd
  6. from data_export.models import DATA
  7. class Formatter(abc.ABC):
  8. def __init__(self, target_column: str = "labels", **kwargs):
  9. self.target_column = target_column
  10. self.mapper = kwargs
  11. def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
  12. if self.target_column not in dataset.columns:
  13. return dataset
  14. return self.apply(dataset)
  15. @abc.abstractmethod
  16. def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
  17. raise NotImplementedError("Please implement this method in the subclass.")
  18. class JoinedCategoryFormatter(Formatter):
  19. def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
  20. """Format the label column to `LabelA#LabelB` format."""
  21. dataset[self.target_column] = dataset[self.target_column].apply(
  22. lambda labels: "#".join(sorted(label.to_string() for label in labels))
  23. )
  24. return dataset
  25. class ListedCategoryFormatter(Formatter):
  26. def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
  27. """Format the label column to `['LabelA', 'LabelB']` format."""
  28. dataset[self.target_column] = dataset[self.target_column].apply(
  29. lambda labels: sorted([label.to_string() for label in labels])
  30. )
  31. return dataset
  32. class FastTextCategoryFormatter(Formatter):
  33. def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
  34. """Format the label column to `__label__LabelA __label__LabelB` format.
  35. Also, drop the columns except for `data` and `self.target_column`.
  36. """
  37. dataset = dataset[[DATA, self.target_column, "Comments"]]
  38. dataset[self.target_column] = dataset[self.target_column].apply(
  39. lambda labels: " ".join(sorted(f"__label__{label.to_string()}" for label in labels))
  40. )
  41. dataset[self.target_column] = dataset[self.target_column].fillna("")
  42. dataset["Comments"] = dataset["Comments"].apply(
  43. lambda comments: " ".join(f"__comment__{comment.to_string()}" for comment in comments)
  44. )
  45. dataset = dataset[self.target_column] + " " + dataset[DATA] + " " + dataset["Comments"]
  46. return dataset
  47. class TupledSpanFormatter(Formatter):
  48. def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
  49. """Format the span column to `(start_offset, end_offset, label)` format"""
  50. dataset[self.target_column] = dataset[self.target_column].apply(
  51. lambda spans: sorted(span.to_tuple() for span in spans)
  52. )
  53. return dataset
  54. class DictFormatter(Formatter):
  55. def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
  56. """Format the column to `{key: value}` format"""
  57. dataset[self.target_column] = dataset[self.target_column].apply(
  58. lambda labels: [label.to_dict() for label in labels]
  59. )
  60. return dataset
  61. class RenameFormatter(Formatter):
  62. def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
  63. return self.apply(dataset)
  64. def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
  65. """Rename columns"""
  66. dataset.rename(columns=self.mapper, inplace=True)
  67. return dataset