Browse Source

Remove duplication from formatter

pull/1799/head
Hironsan 3 years ago
parent
commit
d2c6f0f4ac
1 changed file with 11 additions and 21 deletions
  1. 32
      backend/data_export/pipeline/formatters.py

32
backend/data_export/pipeline/formatters.py

@ -10,17 +10,19 @@ class Formatter(abc.ABC):
def __init__(self, target_column: str):
self.target_column = target_column
@abc.abstractmethod
def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
if self.target_column not in dataset.columns:
return dataset
return self.apply(dataset)
@abc.abstractmethod
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
raise NotImplementedError("Please implement this method in the subclass.")
class JoinedCategoryFormatter(Formatter):
def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
"""Format the label column to `LabelA#LabelB` format."""
if self.target_column not in dataset.columns:
return dataset
dataset[self.target_column] = dataset[self.target_column].apply(
lambda labels: "#".join(sorted(label.to_string() for label in labels))
)
@ -28,11 +30,8 @@ class JoinedCategoryFormatter(Formatter):
class ListedCategoryFormatter(Formatter):
def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
"""Format the label column to `['LabelA', 'LabelB']` format."""
if self.target_column not in dataset.columns:
return dataset
dataset[self.target_column] = dataset[self.target_column].apply(
lambda labels: sorted([label.to_string() for label in labels])
)
@ -40,13 +39,10 @@ class ListedCategoryFormatter(Formatter):
class FastTextCategoryFormatter(Formatter):
def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
"""Format the label column to `__label__LabelA __label__LabelB` format.
Also, drop the columns except for `data` and `self.target_column`.
"""
if self.target_column not in dataset.columns:
return dataset
dataset = dataset[["data", self.target_column]]
dataset[self.target_column] = dataset[self.target_column].apply(
lambda labels: sorted(f"__label__{label.to_string()}" for label in labels)
@ -55,11 +51,8 @@ class FastTextCategoryFormatter(Formatter):
class TupledSpanFormatter(Formatter):
def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
"""Format the span column to `(start_offset, end_offset, label)` format"""
if self.target_column not in dataset.columns:
return dataset
dataset[self.target_column] = dataset[self.target_column].apply(
lambda spans: sorted(span.to_tuple() for span in spans)
)
@ -67,10 +60,7 @@ class TupledSpanFormatter(Formatter):
class DictFormatter(Formatter):
def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
if self.target_column not in dataset.columns:
return dataset
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
dataset[self.target_column] = dataset[self.target_column].apply(
lambda labels: [label.to_dict() for label in labels]
)

Loading…
Cancel
Save