You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

180 lines
5.5 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import abc
  2. import csv
  3. import itertools
  4. import json
  5. import os
  6. import uuid
  7. import zipfile
  8. from collections import defaultdict
  9. from typing import Dict, Iterable, Iterator, List
  10. from .data import Record
  11. class BaseWriter:
  12. def __init__(self, tmpdir: str):
  13. self.tmpdir = tmpdir
  14. @abc.abstractmethod
  15. def write(self, records: Iterator[Record]) -> str:
  16. raise NotImplementedError()
  17. def write_zip(self, filenames: Iterable):
  18. save_file = "{}.zip".format(os.path.join(self.tmpdir, str(uuid.uuid4())))
  19. with zipfile.ZipFile(save_file, "w", compression=zipfile.ZIP_DEFLATED) as zf:
  20. for file in filenames:
  21. zf.write(filename=file, arcname=os.path.basename(file))
  22. return save_file
  23. class LineWriter(BaseWriter):
  24. extension = "txt"
  25. def write(self, records: Iterator[Record]) -> str:
  26. files = {}
  27. for record in records:
  28. filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
  29. if filename not in files:
  30. f = open(filename, mode="a")
  31. files[filename] = f
  32. f = files[filename]
  33. line = self.create_line(record)
  34. f.write(f"{line}\n")
  35. for f in files.values():
  36. f.close()
  37. save_file = self.write_zip(files)
  38. for file in files:
  39. os.remove(file)
  40. return save_file
  41. @abc.abstractmethod
  42. def create_line(self, record) -> str:
  43. raise NotImplementedError()
  44. class CsvWriter(BaseWriter):
  45. extension = "csv"
  46. def write(self, records: Iterator[Record]) -> str:
  47. writers = {}
  48. file_handlers = set()
  49. record_list = list(records)
  50. header = self.create_header(record_list)
  51. for record in record_list:
  52. filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
  53. if filename not in writers:
  54. f = open(filename, mode="a", encoding="utf-8")
  55. writer = csv.DictWriter(f, header)
  56. writer.writeheader()
  57. writers[filename] = writer
  58. file_handlers.add(f)
  59. writer = writers[filename]
  60. line = self.create_line(record)
  61. writer.writerow(line)
  62. for f in file_handlers:
  63. f.close()
  64. save_file = self.write_zip(writers)
  65. for file in writers:
  66. os.remove(file)
  67. return save_file
  68. def create_line(self, record) -> Dict:
  69. return {"id": record.id, "data": record.data, "label": "#".join(sorted(record.label)), **record.metadata}
  70. def create_header(self, records: List[Record]) -> List[str]:
  71. header = ["id", "data", "label"]
  72. header += sorted(set(itertools.chain(*[r.metadata.keys() for r in records])))
  73. return header
  74. class JSONWriter(BaseWriter):
  75. extension = "json"
  76. def write(self, records: Iterator[Record]) -> str:
  77. writers = {}
  78. contents = defaultdict(list)
  79. for record in records:
  80. filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
  81. if filename not in writers:
  82. f = open(filename, mode="a", encoding="utf-8")
  83. writers[filename] = f
  84. line = self.create_line(record)
  85. contents[filename].append(line)
  86. for filename, f in writers.items():
  87. content = contents[filename]
  88. json.dump(content, f, ensure_ascii=False)
  89. f.close()
  90. save_file = self.write_zip(writers)
  91. for file in writers:
  92. os.remove(file)
  93. return save_file
  94. def create_line(self, record) -> Dict:
  95. return {"id": record.id, "data": record.data, "label": record.label, **record.metadata}
  96. class JSONLWriter(LineWriter):
  97. extension = "jsonl"
  98. def create_line(self, record):
  99. return json.dumps(
  100. {"id": record.id, "data": record.data, "label": record.label, **record.metadata}, ensure_ascii=False
  101. )
  102. class FastTextWriter(LineWriter):
  103. extension = "txt"
  104. def create_line(self, record):
  105. line = [f"__label__{label}" for label in record.label]
  106. line.sort()
  107. line.append(record.data)
  108. line = " ".join(line)
  109. return line
  110. class IntentAndSlotWriter(LineWriter):
  111. extension = "jsonl"
  112. def create_line(self, record):
  113. if isinstance(record.label, dict):
  114. return json.dumps(
  115. {
  116. "id": record.id,
  117. "text": record.data,
  118. "cats": record.label.get("cats", []),
  119. "entities": record.label.get("entities", []),
  120. **record.metadata,
  121. },
  122. ensure_ascii=False,
  123. )
  124. else:
  125. return json.dumps(
  126. {"id": record.id, "text": record.data, "cats": [], "entities": [], **record.metadata},
  127. ensure_ascii=False,
  128. )
  129. class EntityAndRelationWriter(LineWriter):
  130. extension = "jsonl"
  131. def create_line(self, record):
  132. if isinstance(record.label, dict):
  133. return json.dumps(
  134. {
  135. "id": record.id,
  136. "text": record.data,
  137. "relations": record.label.get("relations", []),
  138. "entities": record.label.get("entities", []),
  139. **record.metadata,
  140. },
  141. ensure_ascii=False,
  142. )
  143. else:
  144. return json.dumps(
  145. {"id": record.id, "text": record.data, "relations": [], "entities": [], **record.metadata},
  146. ensure_ascii=False,
  147. )