Browse Source

Add FastTextFormatter and FastTextWriter

pull/1799/head
Hironsan 2 years ago
parent
commit
389e482193
5 changed files with 44 additions and 3 deletions
  1. 3
      backend/data_export/pipeline/factories.py
  2. 5
      backend/data_export/pipeline/formatters.py
  3. 8
      backend/data_export/pipeline/writers.py
  4. 16
      backend/data_export/tests/test_formatters.py
  5. 15
      backend/data_export/tests/test_writer.py

3
backend/data_export/pipeline/factories.py

@ -20,7 +20,7 @@ def select_writer(file_format: str) -> Type[writers.Writer]:
catalog.CSV.name: writers.CsvWriter,
catalog.JSON.name: writers.JsonWriter,
catalog.JSONL.name: writers.JsonlWriter,
# catalog.FastText.name: writers.FastTextWriter,
catalog.FastText.name: writers.FastTextWriter,
}
if file_format not in mapping:
ValueError(f"Invalid format: {file_format}")
@ -34,6 +34,7 @@ def select_formatter(project, file_format: str) -> List[Type[formatters.Formatte
catalog.CSV.name: [formatters.JoinedCategoryFormatter],
catalog.JSON.name: [formatters.ListedCategoryFormatter],
catalog.JSONL.name: [formatters.ListedCategoryFormatter],
catalog.FastText.name: [formatters.FastTextCategoryFormatter],
},
SEQUENCE_LABELING: {
catalog.JSONL.name: [formatters.DictFormatter, formatters.DictFormatter]

5
backend/data_export/pipeline/formatters.py

@ -47,8 +47,10 @@ class FastTextCategoryFormatter(Formatter):
"""
dataset = dataset[[DATA, self.target_column]]
dataset[self.target_column] = dataset[self.target_column].apply(
lambda labels: sorted(f"__label__{label.to_string()}" for label in labels)
lambda labels: " ".join(sorted(f"__label__{label.to_string()}" for label in labels))
)
dataset[self.target_column] = dataset[self.target_column].fillna("")
dataset = dataset[self.target_column] + " " + dataset[DATA]
return dataset
@ -63,6 +65,7 @@ class TupledSpanFormatter(Formatter):
class DictFormatter(Formatter):
def apply(self, dataset: pd.DataFrame) -> pd.DataFrame:
"""Format the column to `{key: value}` format"""
dataset[self.target_column] = dataset[self.target_column].apply(
lambda labels: [label.to_dict() for label in labels]
)

8
backend/data_export/pipeline/writers.py

@ -50,3 +50,11 @@ class JsonlWriter(Writer):
@staticmethod
def write(file, dataset: pd.DataFrame):
dataset.to_json(file, orient="records", force_ascii=False, lines=True)
class FastTextWriter(Writer):
extension = "txt"
@staticmethod
def write(file, dataset: pd.DataFrame):
dataset.to_csv(file, index=False, encoding="utf-8", header=False)

16
backend/data_export/tests/test_formatters.py

@ -4,8 +4,10 @@ from unittest.mock import MagicMock
import pandas as pd
from pandas.testing import assert_frame_equal
from data_export.models import DATA
from data_export.pipeline.formatters import (
DictFormatter,
FastTextCategoryFormatter,
JoinedCategoryFormatter,
ListedCategoryFormatter,
TupledSpanFormatter,
@ -68,3 +70,17 @@ class TestTupledSpanFormatter(unittest.TestCase):
dataset = formatter.format(self.dataset)
expected_dataset = pd.DataFrame([{TARGET_COLUMN: [self.return_value]}])
assert_frame_equal(dataset, expected_dataset)
class TestFastTextFormatter(unittest.TestCase):
def setUp(self):
self.return_value = "Label"
label = MagicMock()
label.to_string.return_value = self.return_value
self.dataset = pd.DataFrame([{TARGET_COLUMN: [label], DATA: "example"}])
def test_format(self):
formatter = FastTextCategoryFormatter(TARGET_COLUMN)
dataset = formatter.format(self.dataset)
expected_dataset = pd.DataFrame([f"__label__{self.return_value} example"])
self.assertEqual(dataset.to_csv(index=False, header=None), expected_dataset.to_csv(index=False, header=None))

15
backend/data_export/tests/test_writer.py

@ -4,7 +4,7 @@ import unittest
import pandas as pd
from pandas.testing import assert_frame_equal
from ..pipeline.writers import CsvWriter, JsonlWriter, JsonWriter
from ..pipeline.writers import CsvWriter, FastTextWriter, JsonlWriter, JsonWriter
class TestWriter(unittest.TestCase):
@ -44,3 +44,16 @@ class TestJsonlWriter(TestWriter):
writer.write(self.file, self.dataset)
loaded_dataset = pd.read_json(self.file, lines=True)
assert_frame_equal(self.dataset, loaded_dataset)
class TestFastText(unittest.TestCase):
def setUp(self):
self.expected = "__label__A exampleA\n__label__B exampleB"
self.dataset = pd.DataFrame([*zip(self.expected.split("\n"))])
def test_write(self):
file = "tmp.txt"
writer = FastTextWriter()
writer.write(file, self.dataset)
loaded_dataset = open(file, encoding="utf-8").read().strip()
self.assertEqual(loaded_dataset, self.expected)
Loading…
Cancel
Save