You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

80 lines
3.1 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import json
  2. import unittest
  3. from unittest.mock import call, patch
  4. from ..pipeline.data import Record
  5. from ..pipeline.writers import CsvWriter, IntentAndSlotWriter
  6. class TestCSVWriter(unittest.TestCase):
  7. def setUp(self):
  8. self.records = [
  9. Record(data_id=0, data="exampleA", label=["labelA"], user="admin", metadata={"hidden": "secretA"}),
  10. Record(data_id=1, data="exampleB", label=["labelB"], user="admin", metadata={"hidden": "secretB"}),
  11. Record(data_id=2, data="exampleC", label=["labelC"], user="admin", metadata={"meta": "secretC"}),
  12. ]
  13. def test_create_header(self):
  14. writer = CsvWriter(".")
  15. header = writer.create_header(self.records)
  16. expected = ["id", "data", "label", "hidden", "meta"]
  17. self.assertEqual(header, expected)
  18. def test_create_line(self):
  19. writer = CsvWriter(".")
  20. record = self.records[0]
  21. line = writer.create_line(record)
  22. expected = {"id": record.id, "data": record.data, "label": record.label[0], "hidden": "secretA"}
  23. self.assertEqual(line, expected)
  24. def test_label_order(self):
  25. writer = CsvWriter(".")
  26. record1 = Record(data_id=0, data="", label=["labelA", "labelB"], user="", metadata={})
  27. record2 = Record(data_id=0, data="", label=["labelB", "labelA"], user="", metadata={})
  28. line1 = writer.create_line(record1)
  29. line2 = writer.create_line(record2)
  30. expected = "labelA#labelB"
  31. self.assertEqual(line1["label"], expected)
  32. self.assertEqual(line2["label"], expected)
  33. @patch("os.remove")
  34. @patch("zipfile.ZipFile")
  35. @patch("csv.DictWriter.writerow")
  36. @patch("builtins.open")
  37. def test_dump(self, mock_open_file, csv_io, zip_io, mock_remove_file):
  38. writer = CsvWriter(".")
  39. writer.write(self.records)
  40. self.assertEqual(mock_open_file.call_count, 1)
  41. mock_open_file.assert_called_with("./admin.csv", mode="a", encoding="utf-8")
  42. self.assertEqual(csv_io.call_count, len(self.records) + 1) # +1 is for a header
  43. calls = [
  44. call({"id": "id", "data": "data", "label": "label", "hidden": "hidden", "meta": "meta"}),
  45. call({"id": 0, "data": "exampleA", "label": "labelA", "hidden": "secretA"}),
  46. call({"id": 1, "data": "exampleB", "label": "labelB", "hidden": "secretB"}),
  47. call({"id": 2, "data": "exampleC", "label": "labelC", "meta": "secretC"}),
  48. ]
  49. csv_io.assert_has_calls(calls)
  50. class TestIntentWriter(unittest.TestCase):
  51. def setUp(self):
  52. self.record = Record(
  53. data_id=0,
  54. data="exampleA",
  55. label={"cats": ["positive"], "entities": [(0, 1, "LOC")]},
  56. user="admin",
  57. metadata={},
  58. )
  59. def test_create_line(self):
  60. writer = IntentAndSlotWriter(".")
  61. actual = writer.create_line(self.record)
  62. expected = {
  63. "id": self.record.id,
  64. "text": self.record.data,
  65. "cats": ["positive"],
  66. "entities": [[0, 1, "LOC"]],
  67. }
  68. self.assertEqual(json.loads(actual), expected)