diff --git a/backend/data_export/pipeline/formatters.py b/backend/data_export/pipeline/formatters.py index b72bf359..bfe6e460 100644 --- a/backend/data_export/pipeline/formatters.py +++ b/backend/data_export/pipeline/formatters.py @@ -10,17 +10,19 @@ class Formatter(abc.ABC): def __init__(self, target_column: str): self.target_column = target_column - @abc.abstractmethod def format(self, dataset: pd.DataFrame) -> pd.DataFrame: + if self.target_column not in dataset.columns: + return dataset + return self.apply(dataset) + + @abc.abstractmethod + def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: raise NotImplementedError("Please implement this method in the subclass.") class JoinedCategoryFormatter(Formatter): - def format(self, dataset: pd.DataFrame) -> pd.DataFrame: + def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: """Format the label column to `LabelA#LabelB` format.""" - if self.target_column not in dataset.columns: - return dataset - dataset[self.target_column] = dataset[self.target_column].apply( lambda labels: "#".join(sorted(label.to_string() for label in labels)) ) @@ -28,11 +30,8 @@ class JoinedCategoryFormatter(Formatter): class ListedCategoryFormatter(Formatter): - def format(self, dataset: pd.DataFrame) -> pd.DataFrame: + def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: """Format the label column to `['LabelA', 'LabelB']` format.""" - if self.target_column not in dataset.columns: - return dataset - dataset[self.target_column] = dataset[self.target_column].apply( lambda labels: sorted([label.to_string() for label in labels]) ) @@ -40,13 +39,10 @@ class ListedCategoryFormatter(Formatter): class FastTextCategoryFormatter(Formatter): - def format(self, dataset: pd.DataFrame) -> pd.DataFrame: + def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: """Format the label column to `__label__LabelA __label__LabelB` format. Also, drop the columns except for `data` and `self.target_column`. """ - if self.target_column not in dataset.columns: - return dataset - dataset = dataset[["data", self.target_column]] dataset[self.target_column] = dataset[self.target_column].apply( lambda labels: sorted(f"__label__{label.to_string()}" for label in labels) @@ -55,11 +51,8 @@ class FastTextCategoryFormatter(Formatter): class TupledSpanFormatter(Formatter): - def format(self, dataset: pd.DataFrame) -> pd.DataFrame: + def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: """Format the span column to `(start_offset, end_offset, label)` format""" - if self.target_column not in dataset.columns: - return dataset - dataset[self.target_column] = dataset[self.target_column].apply( lambda spans: sorted(span.to_tuple() for span in spans) ) @@ -67,10 +60,7 @@ class TupledSpanFormatter(Formatter): class DictFormatter(Formatter): - def format(self, dataset: pd.DataFrame) -> pd.DataFrame: - if self.target_column not in dataset.columns: - return dataset - + def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: dataset[self.target_column] = dataset[self.target_column].apply( lambda labels: [label.to_dict() for label in labels] )