mirror of https://github.com/doccano/doccano.git
2 changed files with 57 additions and 232 deletions
Split View
Diff Options
@ -1,180 +1,39 @@ |
|||
import abc |
|||
import csv |
|||
import itertools |
|||
import json |
|||
import os |
|||
import uuid |
|||
import zipfile |
|||
from collections import defaultdict |
|||
from typing import Dict, Iterable, Iterator, List |
|||
|
|||
from .data import Record |
|||
import pandas as pd |
|||
|
|||
|
|||
class BaseWriter:
    """Base class for exporters that write annotation records under a temp directory."""

    def __init__(self, tmpdir: str):
        # Directory where per-user intermediate files and the final zip are created.
        self.tmpdir = tmpdir

    @abc.abstractmethod
    def write(self, records: Iterator["Record"]) -> str:
        """Write *records* to disk and return the path of the resulting archive."""
        raise NotImplementedError()

    def write_zip(self, filenames: Iterable) -> str:
        """Bundle *filenames* into a uniquely named zip inside ``self.tmpdir``.

        Each file is stored under its basename. Returns the zip's path.
        """
        # uuid4 avoids collisions between concurrent export jobs sharing tmpdir.
        archive_path = "{}.zip".format(os.path.join(self.tmpdir, str(uuid.uuid4())))
        with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for filename in filenames:
                zf.write(filename=filename, arcname=os.path.basename(filename))
        return archive_path
|||
|
|||
|
|||
class LineWriter(BaseWriter):
    """Writes one text line per record, grouped into one file per user, zipped."""

    extension = "txt"

    def write(self, records: Iterator["Record"]) -> str:
        """Write each record as a line to ``<user>.<extension>``, zip the files,
        delete the intermediates and return the zip path.

        Fixes over the scraped original: the dead nested ``zip_files`` function
        (diff residue) is removed, and the displaced cleanup/``return`` restored.
        """
        files: Dict[str, object] = {}
        for record in records:
            filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
            if filename not in files:
                # Append mode so repeated write() calls do not clobber earlier output.
                files[filename] = open(filename, mode="a", encoding="utf-8")
            f = files[filename]
            f.write(f"{self.create_line(record)}\n")
        for f in files.values():
            f.close()
        save_file = self.write_zip(files)
        # Remove the per-user intermediates now that they are archived.
        for filename in files:
            os.remove(filename)
        return save_file

    @abc.abstractmethod
    def create_line(self, record) -> str:
        """Serialize a single record to one line of text."""
        raise NotImplementedError()
|||
|
|||
|
|||
class CsvWriter(BaseWriter):
    """Writes records as CSV, one file per user, bundled into a zip.

    Stray diff-residue lines inside ``create_line`` are removed.
    """

    extension = "csv"

    def write(self, records: Iterator["Record"]) -> str:
        """Write all records to per-user CSV files, zip them and return the zip path."""
        writers = {}
        file_handlers = set()
        # Materialize the iterator: the header needs a full pass over metadata keys.
        record_list = list(records)
        header = self.create_header(record_list)
        for record in record_list:
            filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
            if filename not in writers:
                f = open(filename, mode="a", encoding="utf-8")
                writer = csv.DictWriter(f, header)
                writer.writeheader()
                writers[filename] = writer
                file_handlers.add(f)
            writer = writers[filename]
            writer.writerow(self.create_line(record))

        for f in file_handlers:
            f.close()
        save_file = self.write_zip(writers)
        for filename in writers:
            os.remove(filename)
        return save_file

    def create_line(self, record) -> Dict:
        """Flatten one record into a CSV row; labels are sorted and '#'-joined."""
        return {"id": record.id, "data": record.data, "label": "#".join(sorted(record.label)), **record.metadata}

    def create_header(self, records: List["Record"]) -> List[str]:
        """Fixed columns first, then every metadata key seen across records, sorted."""
        header = ["id", "data", "label"]
        header += sorted(set(itertools.chain(*[r.metadata.keys() for r in records])))
        return header
|||
|
|||
|
|||
class JSONWriter(BaseWriter):
    """Writes records as a single JSON array per user, bundled into a zip."""

    extension = "json"

    def write(self, records: Iterator[Record]) -> str:
        """Group records per user, dump each group as one JSON array, zip the files."""
        handles = {}
        grouped = defaultdict(list)
        for record in records:
            path = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
            if path not in handles:
                handles[path] = open(path, mode="a", encoding="utf-8")
            grouped[path].append(self.create_line(record))

        for path, handle in handles.items():
            json.dump(grouped[path], handle, ensure_ascii=False)
            handle.close()

        save_file = self.write_zip(handles)
        for path in handles:
            os.remove(path)
        return save_file

    def create_line(self, record) -> Dict:
        """Build the JSON-serializable object for one record."""
        return {"id": record.id, "data": record.data, "label": record.label, **record.metadata}
|||
|
|||
|
|||
class JSONLWriter(LineWriter):
    """Writes records as JSON Lines: one JSON object per line, per-user files."""

    extension = "jsonl"

    def create_line(self, record):
        """Serialize one record as a single-line JSON object."""
        payload = {"id": record.id, "data": record.data, "label": record.label, **record.metadata}
        return json.dumps(payload, ensure_ascii=False)
|||
|
|||
|
|||
class FastTextWriter(LineWriter):
    """Writes records in fastText format: '__label__X ... text'."""

    extension = "txt"

    def create_line(self, record):
        """Sorted '__label__<label>' tokens, then the text, space-joined."""
        tokens = sorted(f"__label__{label}" for label in record.label)
        tokens.append(record.data)
        return " ".join(tokens)
|||
class Writer(abc.ABC):
    """Interface for exporters that persist a pandas DataFrame to a file."""

    @staticmethod
    @abc.abstractmethod
    def write(file, dataset: pd.DataFrame):
        """Persist *dataset* to *file*; the concrete format is subclass-specific."""
        raise NotImplementedError("Please implement this method in the subclass.")
|||
|
|||
|
|||
class IntentAndSlotWriter(LineWriter):
    """JSONL export with intent categories ('cats') and slot spans ('entities')."""

    extension = "jsonl"

    def create_line(self, record):
        """Serialize one record; non-dict labels yield empty cats/entities."""
        if isinstance(record.label, dict):
            payload = {
                "id": record.id,
                "text": record.data,
                "cats": record.label.get("cats", []),
                "entities": record.label.get("entities", []),
                **record.metadata,
            }
        else:
            payload = {"id": record.id, "text": record.data, "cats": [], "entities": [], **record.metadata}
        return json.dumps(payload, ensure_ascii=False)


class CsvWriter(Writer):
    """Writes a DataFrame as CSV without the index column."""

    @staticmethod
    def write(file, dataset: pd.DataFrame):
        dataset.to_csv(file, index=False, encoding="utf-8")
|||
|
|||
class JsonWriter(Writer):
    """Writes a DataFrame as a JSON array of record objects."""

    @staticmethod
    def write(file, dataset: pd.DataFrame):
        # force_ascii=False keeps non-ASCII text readable in the output.
        dataset.to_json(file, orient="records", force_ascii=False)
|||
|
|||
class EntityAndRelationWriter(LineWriter):
    """JSONL export with entity spans and the relations between them."""

    extension = "jsonl"

    def create_line(self, record):
        """Serialize one record; non-dict labels yield empty relations/entities."""
        if isinstance(record.label, dict):
            payload = {
                "id": record.id,
                "text": record.data,
                "relations": record.label.get("relations", []),
                "entities": record.label.get("entities", []),
                **record.metadata,
            }
        else:
            payload = {"id": record.id, "text": record.data, "relations": [], "entities": [], **record.metadata}
        return json.dumps(payload, ensure_ascii=False)
|||
class JsonlWriter(Writer):
    """Writes a DataFrame as JSON Lines (one record object per line)."""

    @staticmethod
    def write(file, dataset: pd.DataFrame):
        # lines=True emits newline-delimited JSON instead of a single array.
        dataset.to_json(file, orient="records", force_ascii=False, lines=True)
@ -1,80 +1,46 @@ |
|||
import json |
|||
import os |
|||
import unittest |
|||
from unittest.mock import call, patch |
|||
|
|||
from ..pipeline.data import Record |
|||
from ..pipeline.writers import CsvWriter, IntentAndSlotWriter |
|||
import pandas as pd |
|||
from pandas.testing import assert_frame_equal |
|||
|
|||
from ..pipeline.writers import CsvWriter, JsonlWriter, JsonWriter |
|||
|
|||
class TestWriter(unittest.TestCase):
    """Shared fixture for the DataFrame-based writers.

    NOTE(review): this range of the scraped diff interleaves the deleted
    Record-based TestCSVWriter with the added TestWriter family; the added
    (new) version is reconstructed here, as the diff marks the old as removed.
    """

    def setUp(self):
        self.dataset = pd.DataFrame(
            [
                {"id": 0, "text": "A"},
                {"id": 1, "text": "B"},
                {"id": 2, "text": "C"},
            ]
        )
        self.file = "tmp.csv"

    def tearDown(self):
        # Each concrete test writes self.file; always clean it up.
        os.remove(self.file)


class TestCSVWriter(TestWriter):
    def test_write(self):
        """Round-trip: the CSV on disk parses back into the original DataFrame."""
        writer = CsvWriter()
        writer.write(self.file, self.dataset)
        roundtripped = pd.read_csv(self.file)
        assert_frame_equal(self.dataset, roundtripped)
|||
|
|||
class TestJsonWriter(TestWriter):
    def test_write(self):
        """Round-trip: the JSON file parses back into the original DataFrame."""
        writer = JsonWriter()
        writer.write(self.file, self.dataset)
        roundtripped = pd.read_json(self.file)
        assert_frame_equal(self.dataset, roundtripped)
|||
|
|||
class TestIntentWriter(unittest.TestCase):
    """Checks IntentAndSlotWriter's per-record JSONL serialization."""

    def setUp(self):
        self.record = Record(
            data_id=0,
            data="exampleA",
            label={"cats": ["positive"], "entities": [(0, 1, "LOC")]},
            user="admin",
            metadata={},
        )

    def test_create_line(self):
        """Entity tuples come back as lists after the JSON round-trip."""
        writer = IntentAndSlotWriter(".")
        serialized = writer.create_line(self.record)
        expected = {
            "id": self.record.id,
            "text": self.record.data,
            "cats": ["positive"],
            "entities": [[0, 1, "LOC"]],
        }
        self.assertEqual(json.loads(serialized), expected)
|||
class TestJsonlWriter(TestWriter):
    def test_write(self):
        """Round-trip: the JSON Lines file parses back into the original DataFrame."""
        writer = JsonlWriter()
        writer.write(self.file, self.dataset)
        roundtripped = pd.read_json(self.file, lines=True)
        assert_frame_equal(self.dataset, roundtripped)
Write
Preview
Loading…
Cancel
Save