|
|
@ -10,17 +10,19 @@ class Formatter(abc.ABC): |
|
|
|
def __init__(self, target_column: str): |
|
|
|
self.target_column = target_column |
|
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|
def format(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
if self.target_column not in dataset.columns: |
|
|
|
return dataset |
|
|
|
return self.apply(dataset) |
|
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
raise NotImplementedError("Please implement this method in the subclass.") |
|
|
|
|
|
|
|
|
|
|
|
class JoinedCategoryFormatter(Formatter): |
|
|
|
def format(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
"""Format the label column to `LabelA#LabelB` format.""" |
|
|
|
if self.target_column not in dataset.columns: |
|
|
|
return dataset |
|
|
|
|
|
|
|
dataset[self.target_column] = dataset[self.target_column].apply( |
|
|
|
lambda labels: "#".join(sorted(label.to_string() for label in labels)) |
|
|
|
) |
|
|
@ -28,11 +30,8 @@ class JoinedCategoryFormatter(Formatter): |
|
|
|
|
|
|
|
|
|
|
|
class ListedCategoryFormatter(Formatter): |
|
|
|
def format(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
"""Format the label column to `['LabelA', 'LabelB']` format.""" |
|
|
|
if self.target_column not in dataset.columns: |
|
|
|
return dataset |
|
|
|
|
|
|
|
dataset[self.target_column] = dataset[self.target_column].apply( |
|
|
|
lambda labels: sorted([label.to_string() for label in labels]) |
|
|
|
) |
|
|
@ -40,13 +39,10 @@ class ListedCategoryFormatter(Formatter): |
|
|
|
|
|
|
|
|
|
|
|
class FastTextCategoryFormatter(Formatter): |
|
|
|
def format(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
"""Format the label column to `__label__LabelA __label__LabelB` format. |
|
|
|
Also, drop the columns except for `data` and `self.target_column`. |
|
|
|
""" |
|
|
|
if self.target_column not in dataset.columns: |
|
|
|
return dataset |
|
|
|
|
|
|
|
dataset = dataset[["data", self.target_column]] |
|
|
|
dataset[self.target_column] = dataset[self.target_column].apply( |
|
|
|
lambda labels: sorted(f"__label__{label.to_string()}" for label in labels) |
|
|
@ -55,11 +51,8 @@ class FastTextCategoryFormatter(Formatter): |
|
|
|
|
|
|
|
|
|
|
|
class TupledSpanFormatter(Formatter): |
|
|
|
def format(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
"""Format the span column to `(start_offset, end_offset, label)` format""" |
|
|
|
if self.target_column not in dataset.columns: |
|
|
|
return dataset |
|
|
|
|
|
|
|
dataset[self.target_column] = dataset[self.target_column].apply( |
|
|
|
lambda spans: sorted(span.to_tuple() for span in spans) |
|
|
|
) |
|
|
@ -67,10 +60,7 @@ class TupledSpanFormatter(Formatter): |
|
|
|
|
|
|
|
|
|
|
|
class DictFormatter(Formatter): |
|
|
|
def format(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
if self.target_column not in dataset.columns: |
|
|
|
return dataset |
|
|
|
|
|
|
|
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: |
|
|
|
dataset[self.target_column] = dataset[self.target_column].apply( |
|
|
|
lambda labels: [label.to_dict() for label in labels] |
|
|
|
) |
|
|
|