You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

180 lines
5.6 KiB

import abc
import csv
import itertools
import json
import os
import uuid
import zipfile
from collections import defaultdict
from typing import Dict, Iterable, Iterator, List
from .data import Record
class BaseWriter:
def __init__(self, tmpdir: str):
self.tmpdir = tmpdir
@abc.abstractmethod
def write(self, records: Iterator[Record]) -> str:
raise NotImplementedError()
def write_zip(self, filenames: Iterable):
save_file = "{}.zip".format(os.path.join(self.tmpdir, str(uuid.uuid4())))
with zipfile.ZipFile(save_file, "w", compression=zipfile.ZIP_DEFLATED) as zf:
for file in filenames:
zf.write(filename=file, arcname=os.path.basename(file))
return save_file
class LineWriter(BaseWriter):
    """Writer for line-oriented formats: one output line per record.

    Records are split into one file per ``record.user`` (named
    ``<user>.<extension>`` inside ``tmpdir``); the files are then zipped and
    removed. Subclasses supply the line format via :meth:`create_line`.
    """

    extension = "txt"

    def write(self, records: Iterator[Record]) -> str:
        """Write every record as one line and return the zip archive path.

        Args:
            records: Records to export. Each must expose ``user`` plus
                whatever :meth:`create_line` reads.

        Returns:
            Path of the zip archive containing the per-user files.

        Any exception raised while writing or zipping propagates to the
        caller, but the opened file handles are always closed and the
        intermediate per-user files are always removed.
        """
        handles = {}
        try:
            for record in records:
                filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
                handle = handles.get(filename)
                if handle is None:
                    handle = open(filename, mode="a", encoding="utf-8")
                    handles[filename] = handle
                handle.write(f"{self.create_line(record)}\n")
            # Close (and thereby flush) everything before the files are zipped.
            for handle in handles.values():
                handle.close()
            return self.write_zip(handles)
        finally:
            # Closing an already-closed file is a no-op, so this only matters
            # on the error path; files must be closed before being removed.
            for handle in handles.values():
                handle.close()
            for filename in handles:
                os.remove(filename)

    @abc.abstractmethod
    def create_line(self, record) -> str:
        """Return the text of the single line representing ``record``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()
class CsvWriter(BaseWriter):
    """Writer that exports records as one CSV file per user, zipped together."""

    extension = "csv"

    def write(self, records: Iterator[Record]) -> str:
        """Write the records to per-user CSV files and return the zip path.

        All records are materialised first because the header (the union of
        every record's metadata keys) must be known before the first row is
        written.

        Args:
            records: Records to export. Each must expose ``user``, ``id``,
                ``data``, ``label`` and ``metadata``.

        Returns:
            Path of the zip archive containing the per-user CSV files.

        Any exception raised while writing or zipping propagates to the
        caller, but the opened file handles are always closed and the
        intermediate CSV files are always removed.
        """
        record_list = list(records)
        header = self.create_header(record_list)
        handles = {}
        writers = {}
        try:
            for record in record_list:
                filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
                if filename not in writers:
                    # newline="" lets the csv module control line endings
                    # itself, as its documentation requires; without it,
                    # rows end in "\r\r\n" on platforms that translate "\n".
                    handle = open(filename, mode="a", encoding="utf-8", newline="")
                    # Register the handle right away so it is closed even if
                    # writing the header fails.
                    handles[filename] = handle
                    writer = csv.DictWriter(handle, header)
                    writer.writeheader()
                    writers[filename] = writer
                writers[filename].writerow(self.create_line(record))
            # Close (and thereby flush) everything before the files are zipped.
            for handle in handles.values():
                handle.close()
            return self.write_zip(handles)
        finally:
            # Closing an already-closed file is a no-op, so this only matters
            # on the error path; files must be closed before being removed.
            for handle in handles.values():
                handle.close()
            for filename in handles:
                os.remove(filename)

    def create_line(self, record) -> Dict:
        """Return the CSV row for ``record`` as a column-name -> value dict.

        Labels are sorted and joined with ``#``. Metadata entries are merged
        in last, so a metadata key named ``id``, ``data`` or ``label``
        replaces the corresponding built-in column value.
        """
        return {"id": record.id, "data": record.data, "label": "#".join(sorted(record.label)), **record.metadata}

    def create_header(self, records: List[Record]) -> List[str]:
        """Return the column names: fixed columns, then sorted metadata keys.

        Args:
            records: All records that will be written; the metadata keys of
                every record contribute to the header.
        """
        metadata_keys = {key for record in records for key in record.metadata.keys()}
        return ["id", "data", "label"] + sorted(metadata_keys)
class JSONWriter(BaseWriter):
    """Writer that exports records as one JSON array file per user, zipped."""

    extension = "json"

    def write(self, records: Iterator[Record]) -> str:
        """Write the records to per-user JSON files and return the zip path.

        Records are first grouped by target file; each file is then opened
        only for the duration of its own ``json.dump`` call, so no handle
        stays open while records are being consumed or if an error occurs.

        Args:
            records: Records to export. Each must expose ``user``, ``id``,
                ``data``, ``label`` and ``metadata``.

        Returns:
            Path of the zip archive containing the per-user JSON files.

        Any exception raised while writing or zipping propagates to the
        caller, but the intermediate JSON files created so far are always
        removed.
        """
        contents = defaultdict(list)
        for record in records:
            filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
            contents[filename].append(self.create_line(record))

        written = []  # files that actually exist on disk and need cleanup
        try:
            for filename, content in contents.items():
                with open(filename, mode="a", encoding="utf-8") as f:
                    written.append(filename)
                    json.dump(content, f, ensure_ascii=False)
            return self.write_zip(written)
        finally:
            for filename in written:
                os.remove(filename)

    def create_line(self, record) -> Dict:
        """Return the JSON object for ``record``.

        Metadata entries are merged in last, so a metadata key named ``id``,
        ``data`` or ``label`` replaces the corresponding built-in value.
        """
        return {"id": record.id, "data": record.data, "label": record.label, **record.metadata}
class JSONLWriter(LineWriter):
    """Line writer that emits one JSON object per line (JSON Lines)."""

    extension = "jsonl"

    def create_line(self, record):
        """Serialise ``record`` to a single-line JSON string.

        The object holds ``id``, ``data`` and ``label`` followed by the
        record's metadata entries; non-ASCII characters are kept as-is.
        """
        payload = {
            "id": record.id,
            "data": record.data,
            "label": record.label,
            **record.metadata,
        }
        return json.dumps(payload, ensure_ascii=False)
class FastTextWriter(LineWriter):
    """Line writer producing ``__label__<name> ... <text>`` lines."""

    extension = "txt"

    def create_line(self, record):
        """Return the sorted ``__label__``-prefixed labels, then the data.

        All parts are separated by single spaces.
        """
        prefixed_labels = sorted(f"__label__{label}" for label in record.label)
        return " ".join([*prefixed_labels, record.data])
class IntentAndSlotWriter(LineWriter):
    """Line writer emitting JSON lines with ``cats`` and ``entities`` fields."""

    extension = "jsonl"

    def create_line(self, record):
        """Serialise ``record`` to a single-line JSON string.

        When ``record.label`` is a dict, its ``cats`` and ``entities``
        entries are exported (missing ones default to an empty list); for
        any other label type both fields are empty lists. Metadata entries
        are appended after the fixed fields.
        """
        # A non-dict label is treated exactly like a dict with no entries.
        label = record.label if isinstance(record.label, dict) else {}
        payload = {
            "id": record.id,
            "text": record.data,
            "cats": label.get("cats", []),
            "entities": label.get("entities", []),
            **record.metadata,
        }
        return json.dumps(payload, ensure_ascii=False)
class EntityAndRelationWriter(LineWriter):
    """Line writer emitting JSON lines with ``relations`` and ``entities``."""

    extension = "jsonl"

    def create_line(self, record):
        """Serialise ``record`` to a single-line JSON string.

        When ``record.label`` is a dict, its ``relations`` and ``entities``
        entries are exported (missing ones default to an empty list); for
        any other label type both fields are empty lists. Metadata entries
        are appended after the fixed fields.
        """
        # A non-dict label is treated exactly like a dict with no entries.
        label = record.label if isinstance(record.label, dict) else {}
        payload = {
            "id": record.id,
            "text": record.data,
            "relations": label.get("relations", []),
            "entities": label.get("entities", []),
            **record.metadata,
        }
        return json.dumps(payload, ensure_ascii=False)