You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

59 lines
1.7 KiB

  1. import os
  2. import unittest
  3. import pandas as pd
  4. from pandas.testing import assert_frame_equal
  5. from ..pipeline.writers import CsvWriter, FastTextWriter, JsonlWriter, JsonWriter
  6. class TestWriter(unittest.TestCase):
  7. def setUp(self):
  8. self.dataset = pd.DataFrame(
  9. [
  10. {"id": 0, "text": "A"},
  11. {"id": 1, "text": "B"},
  12. {"id": 2, "text": "C"},
  13. ]
  14. )
  15. self.file = "tmp.csv"
  16. def tearDown(self):
  17. os.remove(self.file)
  18. class TestCSVWriter(TestWriter):
  19. def test_write(self):
  20. writer = CsvWriter()
  21. writer.write(self.file, self.dataset)
  22. loaded_dataset = pd.read_csv(self.file)
  23. assert_frame_equal(self.dataset, loaded_dataset)
  24. class TestJsonWriter(TestWriter):
  25. def test_write(self):
  26. writer = JsonWriter()
  27. writer.write(self.file, self.dataset)
  28. loaded_dataset = pd.read_json(self.file)
  29. assert_frame_equal(self.dataset, loaded_dataset)
  30. class TestJsonlWriter(TestWriter):
  31. def test_write(self):
  32. writer = JsonlWriter()
  33. writer.write(self.file, self.dataset)
  34. loaded_dataset = pd.read_json(self.file, lines=True)
  35. assert_frame_equal(self.dataset, loaded_dataset)
  36. class TestFastText(unittest.TestCase):
  37. def setUp(self):
  38. self.expected = "__label__A exampleA\n__label__B exampleB"
  39. self.dataset = pd.DataFrame([*zip(self.expected.split("\n"))])
  40. def test_write(self):
  41. file = "tmp.txt"
  42. writer = FastTextWriter()
  43. writer.write(file, self.dataset)
  44. loaded_dataset = open(file, encoding="utf-8").read().strip()
  45. self.assertEqual(loaded_dataset, self.expected)