You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

101 lines
3.6 KiB

  1. import unittest
  2. from unittest.mock import MagicMock
  3. import pandas as pd
  4. from pandas.testing import assert_frame_equal
  5. from data_export.models import DATA
  6. from data_export.pipeline.formatters import (
  7. DictFormatter,
  8. FastTextCategoryFormatter,
  9. JoinedCategoryFormatter,
  10. ListedCategoryFormatter,
  11. RenameFormatter,
  12. TupledSpanFormatter,
  13. )
# Name of the DataFrame column that holds the label objects; it is used both
# as the column key in every fixture and as the argument handed to each
# formatter under test.
TARGET_COLUMN = "labels"
  15. class TestDictFormatter(unittest.TestCase):
  16. def setUp(self):
  17. self.return_value = {"label": "Label"}
  18. label = MagicMock()
  19. label.to_dict.return_value = self.return_value
  20. self.dataset = pd.DataFrame([{TARGET_COLUMN: [label]}])
  21. def test_format(self):
  22. formatter = DictFormatter(TARGET_COLUMN)
  23. dataset = formatter.format(self.dataset)
  24. expected_dataset = pd.DataFrame([{TARGET_COLUMN: [self.return_value]}])
  25. assert_frame_equal(dataset, expected_dataset)
  26. class TestJoinedCategoryFormatter(unittest.TestCase):
  27. def setUp(self):
  28. self.return_value = "Label"
  29. label = MagicMock()
  30. label.to_string.return_value = self.return_value
  31. self.dataset = pd.DataFrame([{TARGET_COLUMN: [label]}])
  32. def test_format(self):
  33. formatter = JoinedCategoryFormatter(TARGET_COLUMN)
  34. dataset = formatter.format(self.dataset)
  35. expected_dataset = pd.DataFrame([{TARGET_COLUMN: self.return_value}])
  36. assert_frame_equal(dataset, expected_dataset)
  37. class TestListedCategoryFormatter(unittest.TestCase):
  38. def setUp(self):
  39. self.return_value = "Label"
  40. label = MagicMock()
  41. label.to_string.return_value = self.return_value
  42. self.dataset = pd.DataFrame([{TARGET_COLUMN: [label]}])
  43. def test_format(self):
  44. formatter = ListedCategoryFormatter(TARGET_COLUMN)
  45. dataset = formatter.format(self.dataset)
  46. expected_dataset = pd.DataFrame([{TARGET_COLUMN: [self.return_value]}])
  47. assert_frame_equal(dataset, expected_dataset)
  48. class TestTupledSpanFormatter(unittest.TestCase):
  49. def setUp(self):
  50. self.return_value = (0, 1, "Label")
  51. label = MagicMock()
  52. label.to_tuple.return_value = self.return_value
  53. self.dataset = pd.DataFrame([{TARGET_COLUMN: [label]}])
  54. def test_format(self):
  55. formatter = TupledSpanFormatter(TARGET_COLUMN)
  56. dataset = formatter.format(self.dataset)
  57. expected_dataset = pd.DataFrame([{TARGET_COLUMN: [self.return_value]}])
  58. assert_frame_equal(dataset, expected_dataset)
  59. class TestFastTextFormatter(unittest.TestCase):
  60. def setUp(self):
  61. self.return_value_label = "Label"
  62. self.return_value_comment = "Comment"
  63. label = MagicMock()
  64. comment = MagicMock()
  65. label.to_string.return_value = self.return_value_label
  66. comment.to_string.return_value = self.return_value_comment
  67. self.dataset = pd.DataFrame([{TARGET_COLUMN: [label], DATA: "example", "Comments": [comment]}])
  68. def test_format(self):
  69. formatter = FastTextCategoryFormatter(TARGET_COLUMN)
  70. dataset = formatter.format(self.dataset)
  71. expected_dataset = pd.DataFrame(
  72. [f"__label__{self.return_value_label} example __comment__{self.return_value_comment}"]
  73. )
  74. self.assertEqual(dataset.to_csv(index=False, header=None), expected_dataset.to_csv(index=False, header=None))
  75. class TestRenameFormatter(unittest.TestCase):
  76. def test_format(self):
  77. dataset = pd.DataFrame([{"data": "example"}])
  78. formatter = RenameFormatter(**{"data": "text"})
  79. dataset = formatter.format(dataset)
  80. expected_dataset = pd.DataFrame([{"text": "example"}])
  81. assert_frame_equal(dataset, expected_dataset)