From 389e482193f2c71a7ed34ff90c6ecc356c7af7ff Mon Sep 17 00:00:00 2001 From: Hironsan Date: Sun, 24 Apr 2022 20:14:54 +0900 Subject: [PATCH] Add FastTextFormatter and FastTextWriter --- backend/data_export/pipeline/factories.py | 3 ++- backend/data_export/pipeline/formatters.py | 5 ++++- backend/data_export/pipeline/writers.py | 8 ++++++++ backend/data_export/tests/test_formatters.py | 16 ++++++++++++++++ backend/data_export/tests/test_writer.py | 15 ++++++++++++++- 5 files changed, 44 insertions(+), 3 deletions(-) diff --git a/backend/data_export/pipeline/factories.py b/backend/data_export/pipeline/factories.py index 239c0c83..134733f9 100644 --- a/backend/data_export/pipeline/factories.py +++ b/backend/data_export/pipeline/factories.py @@ -20,7 +20,7 @@ def select_writer(file_format: str) -> Type[writers.Writer]: catalog.CSV.name: writers.CsvWriter, catalog.JSON.name: writers.JsonWriter, catalog.JSONL.name: writers.JsonlWriter, - # catalog.FastText.name: writers.FastTextWriter, + catalog.FastText.name: writers.FastTextWriter, } if file_format not in mapping: ValueError(f"Invalid format: {file_format}") @@ -34,6 +34,7 @@ def select_formatter(project, file_format: str) -> List[Type[formatters.Formatte catalog.CSV.name: [formatters.JoinedCategoryFormatter], catalog.JSON.name: [formatters.ListedCategoryFormatter], catalog.JSONL.name: [formatters.ListedCategoryFormatter], + catalog.FastText.name: [formatters.FastTextCategoryFormatter], }, SEQUENCE_LABELING: { catalog.JSONL.name: [formatters.DictFormatter, formatters.DictFormatter] diff --git a/backend/data_export/pipeline/formatters.py b/backend/data_export/pipeline/formatters.py index 5c4d758f..9c195b0a 100644 --- a/backend/data_export/pipeline/formatters.py +++ b/backend/data_export/pipeline/formatters.py @@ -47,8 +47,10 @@ class FastTextCategoryFormatter(Formatter): """ dataset = dataset[[DATA, self.target_column]] dataset[self.target_column] = dataset[self.target_column].apply( - lambda labels: sorted(f"__label__{label.to_string()}" for label in labels) + lambda labels: " ".join(sorted(f"__label__{label.to_string()}" for label in labels)) ) + dataset[self.target_column] = dataset[self.target_column].fillna("") + dataset = dataset[self.target_column] + " " + dataset[DATA] return dataset @@ -63,6 +65,7 @@ class TupledSpanFormatter(Formatter): class DictFormatter(Formatter): def apply(self, dataset: pd.DataFrame) -> pd.DataFrame: + """Format the column to `{key: value}` format""" dataset[self.target_column] = dataset[self.target_column].apply( lambda labels: [label.to_dict() for label in labels] ) diff --git a/backend/data_export/pipeline/writers.py b/backend/data_export/pipeline/writers.py index 8df24d4f..31ff499f 100644 --- a/backend/data_export/pipeline/writers.py +++ b/backend/data_export/pipeline/writers.py @@ -50,3 +50,11 @@ class JsonlWriter(Writer): @staticmethod def write(file, dataset: pd.DataFrame): dataset.to_json(file, orient="records", force_ascii=False, lines=True) + + +class FastTextWriter(Writer): + extension = "txt" + + @staticmethod + def write(file, dataset: pd.DataFrame): + dataset.to_csv(file, index=False, encoding="utf-8", header=False) diff --git a/backend/data_export/tests/test_formatters.py b/backend/data_export/tests/test_formatters.py index 98d40d22..28a1d643 100644 --- a/backend/data_export/tests/test_formatters.py +++ b/backend/data_export/tests/test_formatters.py @@ -4,8 +4,10 @@ from unittest.mock import MagicMock import pandas as pd from pandas.testing import assert_frame_equal +from data_export.models import DATA from data_export.pipeline.formatters import ( DictFormatter, + FastTextCategoryFormatter, JoinedCategoryFormatter, ListedCategoryFormatter, TupledSpanFormatter, @@ -68,3 +70,17 @@ class TestTupledSpanFormatter(unittest.TestCase): dataset = formatter.format(self.dataset) expected_dataset = pd.DataFrame([{TARGET_COLUMN: [self.return_value]}]) assert_frame_equal(dataset, expected_dataset) + + +class TestFastTextFormatter(unittest.TestCase): + def setUp(self): + self.return_value = "Label" + label = MagicMock() + label.to_string.return_value = self.return_value + self.dataset = pd.DataFrame([{TARGET_COLUMN: [label], DATA: "example"}]) + + def test_format(self): + formatter = FastTextCategoryFormatter(TARGET_COLUMN) + dataset = formatter.format(self.dataset) + expected_dataset = pd.DataFrame([f"__label__{self.return_value} example"]) + self.assertEqual(dataset.to_csv(index=False, header=None), expected_dataset.to_csv(index=False, header=None)) diff --git a/backend/data_export/tests/test_writer.py b/backend/data_export/tests/test_writer.py index 2d1b2195..3ef11497 100644 --- a/backend/data_export/tests/test_writer.py +++ b/backend/data_export/tests/test_writer.py @@ -4,7 +4,7 @@ import unittest import pandas as pd from pandas.testing import assert_frame_equal -from ..pipeline.writers import CsvWriter, JsonlWriter, JsonWriter +from ..pipeline.writers import CsvWriter, FastTextWriter, JsonlWriter, JsonWriter class TestWriter(unittest.TestCase): @@ -44,3 +44,16 @@ class TestJsonlWriter(TestWriter): writer.write(self.file, self.dataset) loaded_dataset = pd.read_json(self.file, lines=True) assert_frame_equal(self.dataset, loaded_dataset) + + +class TestFastText(unittest.TestCase): + def setUp(self): + self.expected = "__label__A exampleA\n__label__B exampleB" + self.dataset = pd.DataFrame([*zip(self.expected.split("\n"))]) + + def test_write(self): + file = "tmp.txt" + writer = FastTextWriter() + writer.write(file, self.dataset) + loaded_dataset = open(file, encoding="utf-8").read().strip() + self.assertEqual(loaded_dataset, self.expected)