|
|
@ -4,7 +4,7 @@ import unittest |
|
|
|
import pandas as pd |
|
|
|
from pandas.testing import assert_frame_equal |
|
|
|
|
|
|
|
from ..pipeline.writers import CsvWriter, JsonlWriter, JsonWriter |
|
|
|
from ..pipeline.writers import CsvWriter, FastTextWriter, JsonlWriter, JsonWriter |
|
|
|
|
|
|
|
|
|
|
|
class TestWriter(unittest.TestCase): |
|
|
@ -44,3 +44,16 @@ class TestJsonlWriter(TestWriter): |
|
|
|
writer.write(self.file, self.dataset) |
|
|
|
loaded_dataset = pd.read_json(self.file, lines=True) |
|
|
|
assert_frame_equal(self.dataset, loaded_dataset) |
|
|
|
|
|
|
|
|
|
|
|
class TestFastText(unittest.TestCase):
    """Round-trip test: write a dataset with ``FastTextWriter`` and read it back."""

    def setUp(self):
        """Prepare the expected file content and the matching input dataset."""
        # Two lines, each starting with a ``__label__`` prefix.
        self.expected = "__label__A exampleA\n__label__B exampleB"
        # zip() over a single iterable yields 1-tuples, so every line of the
        # expected text becomes one row of a single-column DataFrame.
        self.dataset = pd.DataFrame([*zip(self.expected.split("\n"))])

    def test_write(self):
        """Write the dataset to disk and compare the stripped file content."""
        # Local imports keep this block self-contained.
        import os
        import tempfile

        # Write into a throwaway directory so the test leaves no ``tmp.txt``
        # behind in the working directory, even when an assertion fails.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        file = os.path.join(tmp_dir.name, "tmp.txt")

        writer = FastTextWriter()
        writer.write(file, self.dataset)

        # Context manager guarantees the handle is closed (no ResourceWarning).
        with open(file, encoding="utf-8") as handle:
            loaded_dataset = handle.read().strip()

        self.assertEqual(loaded_dataset, self.expected)