mirror of https://github.com/doccano/doccano.git
Topics: python, datasets, active-learning, text-annotation, dataset, natural-language-processing, data-labeling, machine-learning, annotation-tool
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
3.0 KiB
76 lines
3.0 KiB
import json
|
|
import unittest
|
|
from unittest.mock import call, patch
|
|
|
|
from ..pipeline.data import Record
|
|
from ..pipeline.writers import CsvWriter, IntentAndSlotWriter
|
|
|
|
|
|
class TestCSVWriter(unittest.TestCase):
    """Tests for ``CsvWriter``: header building, row building, and ``write``."""

    def setUp(self):
        # Three records for the same user. The first two carry a "hidden"
        # metadata key and the third a "meta" key, so the header must merge both.
        fixtures = [
            ("exampleA", "labelA", "hidden", "secretA"),
            ("exampleB", "labelB", "hidden", "secretB"),
            ("exampleC", "labelC", "meta", "secretC"),
        ]
        self.records = [
            Record(id=index, data=text, label=[tag], user="admin", metadata={key: value})
            for index, (text, tag, key, value) in enumerate(fixtures)
        ]

    def test_create_header(self):
        """The header is the fixed columns followed by every metadata key seen."""
        header = CsvWriter(".").create_header(self.records)
        self.assertEqual(header, ["id", "data", "label", "hidden", "meta"])

    def test_create_line(self):
        """A row holds the record fields, its single label, and its own metadata."""
        first = self.records[0]
        line = CsvWriter(".").create_line(first)
        self.assertEqual(
            line,
            {
                "id": first.id,
                "data": first.data,
                "label": first.label[0],
                "hidden": "secretA",
            },
        )

    def test_label_order(self):
        """Multiple labels join with '#' identically, whatever their input order."""
        writer = CsvWriter(".")
        for labels in (["labelA", "labelB"], ["labelB", "labelA"]):
            record = Record(id=0, data="", label=labels, user="", metadata={})
            self.assertEqual(writer.create_line(record)["label"], "labelA#labelB")

    @patch("os.remove")
    @patch("zipfile.ZipFile")
    @patch("csv.DictWriter.writerow")
    @patch("builtins.open")
    def test_dump(self, mock_open_file, csv_io, zip_io, mock_remove_file):
        """``write`` opens one file for the single user and writes header + rows."""
        CsvWriter(".").write(self.records)

        # Every record belongs to "admin", so exactly one file is opened.
        mock_open_file.assert_called_once_with("./admin.csv", mode="a", encoding="utf-8")

        # ``writerow`` is patched, so the header row is recorded alongside the data rows.
        header_row = {column: column for column in ("id", "data", "label", "hidden", "meta")}
        expected_rows = [
            header_row,
            {"id": 0, "data": "exampleA", "label": "labelA", "hidden": "secretA"},
            {"id": 1, "data": "exampleB", "label": "labelB", "hidden": "secretB"},
            {"id": 2, "data": "exampleC", "label": "labelC", "meta": "secretC"},
        ]
        self.assertEqual(csv_io.call_count, len(expected_rows))
        csv_io.assert_has_calls([call(row) for row in expected_rows])
|
|
|
|
|
|
class TestIntentWriter(unittest.TestCase):
    """Tests for ``IntentAndSlotWriter`` row serialization."""

    def setUp(self):
        # One record that carries both a category label and an entity span.
        self.record = Record(
            id=0,
            data="exampleA",
            label={"cats": ["positive"], "entities": [(0, 1, "LOC")]},
            user="admin",
            metadata={},
        )

    def test_create_line(self):
        """``create_line`` yields JSON; after decoding, entity tuples appear as lists."""
        decoded = json.loads(IntentAndSlotWriter(".").create_line(self.record))
        self.assertEqual(
            decoded,
            {
                "id": self.record.id,
                "text": self.record.data,
                "cats": ["positive"],
                "entities": [[0, 1, "LOC"]],
            },
        )
|