|
|
import json import unittest from unittest.mock import call, patch
from ..pipeline.data import Record from ..pipeline.writers import CsvWriter, IntentAndSlotWriter
class TestCSVWriter(unittest.TestCase): def setUp(self): self.records = [ Record(id=0, data="exampleA", label=["labelA"], user="admin", metadata={"hidden": "secretA"}), Record(id=1, data="exampleB", label=["labelB"], user="admin", metadata={"hidden": "secretB"}), Record(id=2, data="exampleC", label=["labelC"], user="admin", metadata={"meta": "secretC"}), ]
def test_create_header(self): writer = CsvWriter(".") header = writer.create_header(self.records) expected = ["id", "data", "label", "hidden", "meta"] self.assertEqual(header, expected)
def test_create_line(self): writer = CsvWriter(".") record = self.records[0] line = writer.create_line(record) expected = {"id": record.id, "data": record.data, "label": record.label[0], "hidden": "secretA"} self.assertEqual(line, expected)
def test_label_order(self): writer = CsvWriter(".") record1 = Record(id=0, data="", label=["labelA", "labelB"], user="", metadata={}) record2 = Record(id=0, data="", label=["labelB", "labelA"], user="", metadata={}) line1 = writer.create_line(record1) line2 = writer.create_line(record2) expected = "labelA#labelB" self.assertEqual(line1["label"], expected) self.assertEqual(line2["label"], expected)
@patch("os.remove") @patch("zipfile.ZipFile") @patch("csv.DictWriter.writerow") @patch("builtins.open") def test_dump(self, mock_open_file, csv_io, zip_io, mock_remove_file): writer = CsvWriter(".") writer.write(self.records)
self.assertEqual(mock_open_file.call_count, 1) mock_open_file.assert_called_with("./admin.csv", mode="a", encoding="utf-8")
self.assertEqual(csv_io.call_count, len(self.records) + 1) # +1 is for a header calls = [ call({"id": "id", "data": "data", "label": "label", "hidden": "hidden", "meta": "meta"}), call({"id": 0, "data": "exampleA", "label": "labelA", "hidden": "secretA"}), call({"id": 1, "data": "exampleB", "label": "labelB", "hidden": "secretB"}), call({"id": 2, "data": "exampleC", "label": "labelC", "meta": "secretC"}), ] csv_io.assert_has_calls(calls)
class TestIntentWriter(unittest.TestCase): def setUp(self): self.record = Record( id=0, data="exampleA", label={"cats": ["positive"], "entities": [(0, 1, "LOC")]}, user="admin", metadata={} )
def test_create_line(self): writer = IntentAndSlotWriter(".") actual = writer.create_line(self.record) expected = { "id": self.record.id, "text": self.record.data, "cats": ["positive"], "entities": [[0, 1, "LOC"]], } self.assertEqual(json.loads(actual), expected)
|