diff --git a/backend/data_export/pipeline/writers.py b/backend/data_export/pipeline/writers.py index 61441d1a..013eb0b6 100644 --- a/backend/data_export/pipeline/writers.py +++ b/backend/data_export/pipeline/writers.py @@ -1,180 +1,39 @@ import abc -import csv -import itertools -import json import os import uuid import zipfile -from collections import defaultdict -from typing import Dict, Iterable, Iterator, List -from .data import Record +import pandas as pd -class BaseWriter: - def __init__(self, tmpdir: str): - self.tmpdir = tmpdir - - @abc.abstractmethod - def write(self, records: Iterator[Record]) -> str: - raise NotImplementedError() - - def write_zip(self, filenames: Iterable): - save_file = "{}.zip".format(os.path.join(self.tmpdir, str(uuid.uuid4()))) - with zipfile.ZipFile(save_file, "w", compression=zipfile.ZIP_DEFLATED) as zf: - for file in filenames: - zf.write(filename=file, arcname=os.path.basename(file)) - return save_file - - -class LineWriter(BaseWriter): - extension = "txt" - - def write(self, records: Iterator[Record]) -> str: - files = {} - for record in records: - filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}") - if filename not in files: - f = open(filename, mode="a", encoding="utf-8") - files[filename] = f - f = files[filename] - line = self.create_line(record) - f.write(f"{line}\n") - for f in files.values(): - f.close() - save_file = self.write_zip(files) +def zip_files(files): + save_file = f"{uuid.uuid4()}.zip" + with zipfile.ZipFile(save_file, "w", compression=zipfile.ZIP_DEFLATED) as zf: for file in files: - os.remove(file) - return save_file - - @abc.abstractmethod - def create_line(self, record) -> str: - raise NotImplementedError() - - -class CsvWriter(BaseWriter): - extension = "csv" - - def write(self, records: Iterator[Record]) -> str: - writers = {} - file_handlers = set() - record_list = list(records) - header = self.create_header(record_list) - for record in record_list: - filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}") - if filename not in writers: - f = open(filename, mode="a", encoding="utf-8") - writer = csv.DictWriter(f, header) - writer.writeheader() - writers[filename] = writer - file_handlers.add(f) - writer = writers[filename] - line = self.create_line(record) - writer.writerow(line) - - for f in file_handlers: - f.close() - save_file = self.write_zip(writers) - for file in writers: - os.remove(file) - return save_file - - def create_line(self, record) -> Dict: - return {"id": record.id, "data": record.data, "label": "#".join(sorted(record.label)), **record.metadata} + zf.write(filename=file, arcname=os.path.basename(file)) + return save_file - def create_header(self, records: List[Record]) -> List[str]: - header = ["id", "data", "label"] - header += sorted(set(itertools.chain(*[r.metadata.keys() for r in records]))) - return header - -class JSONWriter(BaseWriter): - extension = "json" - - def write(self, records: Iterator[Record]) -> str: - writers = {} - contents = defaultdict(list) - for record in records: - filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}") - if filename not in writers: - f = open(filename, mode="a", encoding="utf-8") - writers[filename] = f - line = self.create_line(record) - contents[filename].append(line) - - for filename, f in writers.items(): - content = contents[filename] - json.dump(content, f, ensure_ascii=False) - f.close() - - save_file = self.write_zip(writers) - for file in writers: - os.remove(file) - return save_file - - def create_line(self, record) -> Dict: - return {"id": record.id, "data": record.data, "label": record.label, **record.metadata} - - -class JSONLWriter(LineWriter): - extension = "jsonl" - - def create_line(self, record): - return json.dumps( - {"id": record.id, "data": record.data, "label": record.label, **record.metadata}, ensure_ascii=False - ) - - -class FastTextWriter(LineWriter): - extension = "txt" - - def create_line(self, record): - line = [f"__label__{label}" for label in record.label] - line.sort() - line.append(record.data) - line = " ".join(line) - return line +class Writer(abc.ABC): + @staticmethod + @abc.abstractmethod + def write(file, dataset: pd.DataFrame): + raise NotImplementedError("Please implement this method in the subclass.") -class IntentAndSlotWriter(LineWriter): - extension = "jsonl" +class CsvWriter(Writer): + @staticmethod + def write(file, dataset: pd.DataFrame): + dataset.to_csv(file, index=False, encoding="utf-8") - def create_line(self, record): - if isinstance(record.label, dict): - return json.dumps( - { - "id": record.id, - "text": record.data, - "cats": record.label.get("cats", []), - "entities": record.label.get("entities", []), - **record.metadata, - }, - ensure_ascii=False, - ) - else: - return json.dumps( - {"id": record.id, "text": record.data, "cats": [], "entities": [], **record.metadata}, - ensure_ascii=False, - ) +class JsonWriter(Writer): + @staticmethod + def write(file, dataset: pd.DataFrame): + dataset.to_json(file, orient="records", force_ascii=False) -class EntityAndRelationWriter(LineWriter): - extension = "jsonl" - def create_line(self, record): - if isinstance(record.label, dict): - return json.dumps( - { - "id": record.id, - "text": record.data, - "relations": record.label.get("relations", []), - "entities": record.label.get("entities", []), - **record.metadata, - }, - ensure_ascii=False, - ) - else: - return json.dumps( - {"id": record.id, "text": record.data, "relations": [], "entities": [], **record.metadata}, - ensure_ascii=False, - ) +class JsonlWriter(Writer): + @staticmethod + def write(file, dataset: pd.DataFrame): + dataset.to_json(file, orient="records", force_ascii=False, lines=True) diff --git a/backend/data_export/tests/test_writer.py b/backend/data_export/tests/test_writer.py index 8997fbe4..2d1b2195 100644 --- a/backend/data_export/tests/test_writer.py +++ b/backend/data_export/tests/test_writer.py @@ -1,80 +1,46 @@ -import json +import os import unittest -from unittest.mock import call, patch -from ..pipeline.data import Record -from ..pipeline.writers import CsvWriter, IntentAndSlotWriter +import pandas as pd +from pandas.testing import assert_frame_equal +from ..pipeline.writers import CsvWriter, JsonlWriter, JsonWriter -class TestCSVWriter(unittest.TestCase): - def setUp(self): - self.records = [ - Record(data_id=0, data="exampleA", label=["labelA"], user="admin", metadata={"hidden": "secretA"}), - Record(data_id=1, data="exampleB", label=["labelB"], user="admin", metadata={"hidden": "secretB"}), - Record(data_id=2, data="exampleC", label=["labelC"], user="admin", metadata={"meta": "secretC"}), - ] - - def test_create_header(self): - writer = CsvWriter(".") - header = writer.create_header(self.records) - expected = ["id", "data", "label", "hidden", "meta"] - self.assertEqual(header, expected) - def test_create_line(self): - writer = CsvWriter(".") - record = self.records[0] - line = writer.create_line(record) - expected = {"id": record.id, "data": record.data, "label": record.label[0], "hidden": "secretA"} - self.assertEqual(line, expected) +class TestWriter(unittest.TestCase): + def setUp(self): + self.dataset = pd.DataFrame( + [ + {"id": 0, "text": "A"}, + {"id": 1, "text": "B"}, + {"id": 2, "text": "C"}, + ] + ) + self.file = "tmp.csv" - def test_label_order(self): - writer = CsvWriter(".") - record1 = Record(data_id=0, data="", label=["labelA", "labelB"], user="", metadata={}) - record2 = Record(data_id=0, data="", label=["labelB", "labelA"], user="", metadata={}) - line1 = writer.create_line(record1) - line2 = writer.create_line(record2) - expected = "labelA#labelB" - self.assertEqual(line1["label"], expected) - self.assertEqual(line2["label"], expected) + def tearDown(self): + os.remove(self.file) - @patch("os.remove") - @patch("zipfile.ZipFile") - @patch("csv.DictWriter.writerow") - @patch("builtins.open") - def test_dump(self, mock_open_file, csv_io, zip_io, mock_remove_file): - writer = CsvWriter(".") - writer.write(self.records) - self.assertEqual(mock_open_file.call_count, 1) - mock_open_file.assert_called_with("./admin.csv", mode="a", encoding="utf-8") +class TestCSVWriter(TestWriter): + def test_write(self): + writer = CsvWriter() + writer.write(self.file, self.dataset) + loaded_dataset = pd.read_csv(self.file) + assert_frame_equal(self.dataset, loaded_dataset) - self.assertEqual(csv_io.call_count, len(self.records) + 1) # +1 is for a header - calls = [ - call({"id": "id", "data": "data", "label": "label", "hidden": "hidden", "meta": "meta"}), - call({"id": 0, "data": "exampleA", "label": "labelA", "hidden": "secretA"}), - call({"id": 1, "data": "exampleB", "label": "labelB", "hidden": "secretB"}), - call({"id": 2, "data": "exampleC", "label": "labelC", "meta": "secretC"}), - ] - csv_io.assert_has_calls(calls) +class TestJsonWriter(TestWriter): + def test_write(self): + writer = JsonWriter() + writer.write(self.file, self.dataset) + loaded_dataset = pd.read_json(self.file) + assert_frame_equal(self.dataset, loaded_dataset) -class TestIntentWriter(unittest.TestCase): - def setUp(self): - self.record = Record( - data_id=0, - data="exampleA", - label={"cats": ["positive"], "entities": [(0, 1, "LOC")]}, - user="admin", - metadata={}, - ) - def test_create_line(self): - writer = IntentAndSlotWriter(".") - actual = writer.create_line(self.record) - expected = { - "id": self.record.id, - "text": self.record.data, - "cats": ["positive"], - "entities": [[0, 1, "LOC"]], - } - self.assertEqual(json.loads(actual), expected) +class TestJsonlWriter(TestWriter): + def test_write(self): + writer = JsonlWriter() + writer.write(self.file, self.dataset) + loaded_dataset = pd.read_json(self.file, lines=True) + assert_frame_equal(self.dataset, loaded_dataset)