Browse Source

Update export writers

pull/1799/head
Hironsan 2 years ago
parent
commit
03981d391f
2 changed files with 57 additions and 232 deletions
  1. 187
      backend/data_export/pipeline/writers.py
  2. 102
      backend/data_export/tests/test_writer.py

187
backend/data_export/pipeline/writers.py

@@ -1,180 +1,39 @@
import abc
import csv
import itertools
import json
import os
import uuid
import zipfile
from collections import defaultdict
from typing import Dict, Iterable, Iterator, List
from .data import Record
import pandas as pd
class BaseWriter:
    """Abstract base for the per-user export writers.

    Subclasses dump records into one file per user inside ``tmpdir`` and
    then bundle those files into a single zip archive via ``write_zip``.
    """

    def __init__(self, tmpdir: str):
        # Directory that receives the intermediate per-user files.
        self.tmpdir = tmpdir

    @abc.abstractmethod
    def write(self, records: Iterator[Record]) -> str:
        """Consume *records* and return the path of the produced zip file."""
        raise NotImplementedError()

    def write_zip(self, filenames: Iterable):
        """Bundle *filenames* into a freshly named zip archive under ``tmpdir``."""
        archive_path = f"{os.path.join(self.tmpdir, str(uuid.uuid4()))}.zip"
        with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
            for path in filenames:
                # Store each file flat (basename only) inside the archive.
                archive.write(filename=path, arcname=os.path.basename(path))
        return archive_path
class LineWriter(BaseWriter):
    """Writes one text line per record into per-user files, then zips them.

    NOTE(review): the scraped diff view interleaved fragments of an
    unrelated ``zip_files`` helper into this region; the block below
    restores the coherent pre-commit implementation.
    """

    extension = "txt"

    def write(self, records: Iterator[Record]) -> str:
        """Serialize *records* grouped by ``record.user`` and return the zip path."""
        files = {}
        for record in records:
            filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
            if filename not in files:
                # One handle per user, opened lazily on first record.
                f = open(filename, mode="a", encoding="utf-8")
                files[filename] = f
            f = files[filename]
            line = self.create_line(record)
            f.write(f"{line}\n")
        for f in files.values():
            f.close()
        save_file = self.write_zip(files)
        for file in files:
            # The per-user files are only needed inside the archive.
            os.remove(file)
        return save_file

    @abc.abstractmethod
    def create_line(self, record) -> str:
        """Return the single text line representing *record*."""
        raise NotImplementedError()
class CsvWriter(BaseWriter):
    """Writes records into per-user CSV files, then zips them.

    NOTE(review): two stray lines from an unrelated zip helper were
    interleaved into this region by the diff view; they are removed here.
    """

    extension = "csv"

    def write(self, records: Iterator[Record]) -> str:
        """Serialize *records* grouped by user into CSV files and return the zip path."""
        writers = {}
        file_handlers = set()
        # Materialize once: the iterator is needed twice (header scan + rows).
        record_list = list(records)
        header = self.create_header(record_list)
        for record in record_list:
            filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
            if filename not in writers:
                f = open(filename, mode="a", encoding="utf-8")
                writer = csv.DictWriter(f, header)
                writer.writeheader()
                writers[filename] = writer
                file_handlers.add(f)
            writer = writers[filename]
            line = self.create_line(record)
            writer.writerow(line)
        for f in file_handlers:
            f.close()
        save_file = self.write_zip(writers)
        for file in writers:
            # The per-user files are only needed inside the archive.
            os.remove(file)
        return save_file

    def create_line(self, record) -> Dict:
        """Map a record onto a CSV row; multi-labels are joined with '#' in sorted order."""
        return {"id": record.id, "data": record.data, "label": "#".join(sorted(record.label)), **record.metadata}

    def create_header(self, records: List[Record]) -> List[str]:
        """Fixed columns first, then the union of metadata keys, sorted."""
        header = ["id", "data", "label"]
        header += sorted(set(itertools.chain.from_iterable(r.metadata.keys() for r in records)))
        return header
class JSONWriter(BaseWriter):
    """Writes records into one JSON array per user, then zips the files."""

    extension = "json"

    def write(self, records: Iterator[Record]) -> str:
        """Buffer records per user, dump each buffer as a JSON array, return the zip path."""
        handles = {}
        buffered = defaultdict(list)
        for record in records:
            path = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
            if path not in handles:
                # One handle per user, opened lazily on first record.
                handles[path] = open(path, mode="a", encoding="utf-8")
            buffered[path].append(self.create_line(record))
        for path, handle in handles.items():
            json.dump(buffered[path], handle, ensure_ascii=False)
            handle.close()
        archive = self.write_zip(handles)
        for path in handles:
            os.remove(path)
        return archive

    def create_line(self, record) -> Dict:
        """Map a record onto a JSON-serializable dict."""
        return {"id": record.id, "data": record.data, "label": record.label, **record.metadata}
class JSONLWriter(LineWriter):
    """Writes one JSON object per line (JSON Lines format)."""

    extension = "jsonl"

    def create_line(self, record):
        """Render *record* as a single compact JSON object."""
        payload = {"id": record.id, "data": record.data, "label": record.label, **record.metadata}
        return json.dumps(payload, ensure_ascii=False)
class FastTextWriter(LineWriter):
    """Writes records in fastText format: sorted ``__label__`` tags, then the text."""

    extension = "txt"

    def create_line(self, record):
        """Prefix each label with ``__label__``, sort, and append the record text."""
        tags = sorted(f"__label__{label}" for label in record.label)
        return " ".join(tags + [record.data])
class Writer(abc.ABC):
    """Interface for objects that persist a pandas DataFrame to a file."""

    @staticmethod
    @abc.abstractmethod
    def write(file, dataset: pd.DataFrame):
        """Write *dataset* to *file*; the concrete format is chosen by the subclass."""
        raise NotImplementedError("Please implement this method in the subclass.")
class IntentAndSlotWriter(LineWriter):
    """JSONL writer for joint intent ("cats") and slot ("entities") labels.

    NOTE(review): the diff view interleaved this class with the new
    DataFrame-based ``CsvWriter``; the two are de-interleaved below.
    """

    extension = "jsonl"

    def create_line(self, record):
        """Render cats/entities from a dict label; empty lists when unlabeled."""
        if isinstance(record.label, dict):
            return json.dumps(
                {
                    "id": record.id,
                    "text": record.data,
                    "cats": record.label.get("cats", []),
                    "entities": record.label.get("entities", []),
                    **record.metadata,
                },
                ensure_ascii=False,
            )
        else:
            return json.dumps(
                {"id": record.id, "text": record.data, "cats": [], "entities": [], **record.metadata},
                ensure_ascii=False,
            )


class CsvWriter(Writer):
    """Post-refactor writer: persists a DataFrame as CSV.

    NOTE(review): shares its name with the pre-commit per-user CsvWriter
    above because this scrape contains both sides of the diff — confirm
    against the repository which definition the file should keep.
    """

    @staticmethod
    def write(file, dataset: pd.DataFrame):
        """Write *dataset* to *file* as UTF-8 CSV without the index column."""
        dataset.to_csv(file, index=False, encoding="utf-8")
class JsonWriter(Writer):
    """Persists a DataFrame as a single JSON array of records."""

    @staticmethod
    def write(file, dataset: pd.DataFrame):
        """Write *dataset* to *file* as records-oriented JSON, keeping non-ASCII text."""
        dataset.to_json(file, force_ascii=False, orient="records")
class EntityAndRelationWriter(LineWriter):
    """JSONL writer for entity and relation annotations."""

    extension = "jsonl"

    def create_line(self, record):
        """Render relations/entities from a dict label; empty lists when unlabeled."""
        if isinstance(record.label, dict):
            relations = record.label.get("relations", [])
            entities = record.label.get("entities", [])
        else:
            relations = []
            entities = []
        payload = {
            "id": record.id,
            "text": record.data,
            "relations": relations,
            "entities": entities,
            **record.metadata,
        }
        return json.dumps(payload, ensure_ascii=False)
class JsonlWriter(Writer):
    """Persists a DataFrame as JSON Lines (one record per line)."""

    @staticmethod
    def write(file, dataset: pd.DataFrame):
        """Write *dataset* to *file* as records-oriented JSON Lines, keeping non-ASCII text."""
        dataset.to_json(file, force_ascii=False, orient="records", lines=True)

102
backend/data_export/tests/test_writer.py

@@ -1,80 +1,46 @@
import json
import os
import unittest
from unittest.mock import call, patch
from ..pipeline.data import Record
from ..pipeline.writers import CsvWriter, IntentAndSlotWriter
import pandas as pd
from pandas.testing import assert_frame_equal
from ..pipeline.writers import CsvWriter, JsonlWriter, JsonWriter
class TestCSVWriter(unittest.TestCase):
def setUp(self):
self.records = [
Record(data_id=0, data="exampleA", label=["labelA"], user="admin", metadata={"hidden": "secretA"}),
Record(data_id=1, data="exampleB", label=["labelB"], user="admin", metadata={"hidden": "secretB"}),
Record(data_id=2, data="exampleC", label=["labelC"], user="admin", metadata={"meta": "secretC"}),
]
def test_create_header(self):
writer = CsvWriter(".")
header = writer.create_header(self.records)
expected = ["id", "data", "label", "hidden", "meta"]
self.assertEqual(header, expected)
def test_create_line(self):
writer = CsvWriter(".")
record = self.records[0]
line = writer.create_line(record)
expected = {"id": record.id, "data": record.data, "label": record.label[0], "hidden": "secretA"}
self.assertEqual(line, expected)
class TestWriter(unittest.TestCase):
def setUp(self):
self.dataset = pd.DataFrame(
[
{"id": 0, "text": "A"},
{"id": 1, "text": "B"},
{"id": 2, "text": "C"},
]
)
self.file = "tmp.csv"
def test_label_order(self):
writer = CsvWriter(".")
record1 = Record(data_id=0, data="", label=["labelA", "labelB"], user="", metadata={})
record2 = Record(data_id=0, data="", label=["labelB", "labelA"], user="", metadata={})
line1 = writer.create_line(record1)
line2 = writer.create_line(record2)
expected = "labelA#labelB"
self.assertEqual(line1["label"], expected)
self.assertEqual(line2["label"], expected)
def tearDown(self):
os.remove(self.file)
@patch("os.remove")
@patch("zipfile.ZipFile")
@patch("csv.DictWriter.writerow")
@patch("builtins.open")
def test_dump(self, mock_open_file, csv_io, zip_io, mock_remove_file):
writer = CsvWriter(".")
writer.write(self.records)
self.assertEqual(mock_open_file.call_count, 1)
mock_open_file.assert_called_with("./admin.csv", mode="a", encoding="utf-8")
class TestCSVWriter(TestWriter):
def test_write(self):
writer = CsvWriter()
writer.write(self.file, self.dataset)
loaded_dataset = pd.read_csv(self.file)
assert_frame_equal(self.dataset, loaded_dataset)
self.assertEqual(csv_io.call_count, len(self.records) + 1) # +1 is for a header
calls = [
call({"id": "id", "data": "data", "label": "label", "hidden": "hidden", "meta": "meta"}),
call({"id": 0, "data": "exampleA", "label": "labelA", "hidden": "secretA"}),
call({"id": 1, "data": "exampleB", "label": "labelB", "hidden": "secretB"}),
call({"id": 2, "data": "exampleC", "label": "labelC", "meta": "secretC"}),
]
csv_io.assert_has_calls(calls)
class TestJsonWriter(TestWriter):
    def test_write(self):
        """Round-trips the DataFrame through a JSON array file."""
        writer = JsonWriter()
        writer.write(self.file, self.dataset)
        reloaded = pd.read_json(self.file)
        assert_frame_equal(self.dataset, reloaded)
class TestIntentWriter(unittest.TestCase):
    """Pre-commit test for IntentAndSlotWriter.create_line.

    NOTE(review): the diff view interleaved this class with the new
    ``TestJsonlWriter``; the two are de-interleaved below.
    """

    def setUp(self):
        self.record = Record(
            data_id=0,
            data="exampleA",
            label={"cats": ["positive"], "entities": [(0, 1, "LOC")]},
            user="admin",
            metadata={},
        )

    def test_create_line(self):
        """Tuples in the label round-trip through JSON as lists."""
        writer = IntentAndSlotWriter(".")
        actual = writer.create_line(self.record)
        expected = {
            "id": self.record.id,
            "text": self.record.data,
            "cats": ["positive"],
            "entities": [[0, 1, "LOC"]],
        }
        self.assertEqual(json.loads(actual), expected)


class TestJsonlWriter(TestWriter):
    def test_write(self):
        """Round-trips the DataFrame through a JSON Lines file."""
        writer = JsonlWriter()
        writer.write(self.file, self.dataset)
        loaded_dataset = pd.read_json(self.file, lines=True)
        assert_frame_equal(self.dataset, loaded_dataset)
Loading…
Cancel
Save