You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

172 lines
5.0 KiB

3 years ago
  1. import abc
  2. import csv
  3. import itertools
  4. import json
  5. import os
  6. import uuid
  7. import zipfile
  8. from collections import defaultdict
  9. from typing import Dict, Iterable, Iterator, List
  10. from .data import Record
  11. class BaseWriter:
  12. def __init__(self, tmpdir: str):
  13. self.tmpdir = tmpdir
  14. @abc.abstractmethod
  15. def write(self, records: Iterator[Record]) -> str:
  16. raise NotImplementedError()
  17. def write_zip(self, filenames: Iterable):
  18. save_file = '{}.zip'.format(os.path.join(self.tmpdir, str(uuid.uuid4())))
  19. with zipfile.ZipFile(save_file, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
  20. for file in filenames:
  21. zf.write(filename=file, arcname=os.path.basename(file))
  22. return save_file
  23. class LineWriter(BaseWriter):
  24. extension = 'txt'
  25. def write(self, records: Iterator[Record]) -> str:
  26. files = {}
  27. for record in records:
  28. filename = os.path.join(self.tmpdir, f'{record.user}.{self.extension}')
  29. if filename not in files:
  30. f = open(filename, mode='a')
  31. files[filename] = f
  32. f = files[filename]
  33. line = self.create_line(record)
  34. f.write(f'{line}\n')
  35. for f in files.values():
  36. f.close()
  37. save_file = self.write_zip(files)
  38. for file in files:
  39. os.remove(file)
  40. return save_file
  41. @abc.abstractmethod
  42. def create_line(self, record) -> str:
  43. raise NotImplementedError()
  44. class CsvWriter(BaseWriter):
  45. extension = 'csv'
  46. def write(self, records: Iterator[Record]) -> str:
  47. writers = {}
  48. file_handlers = set()
  49. records = list(records)
  50. header = self.create_header(records)
  51. for record in records:
  52. filename = os.path.join(self.tmpdir, f'{record.user}.{self.extension}')
  53. if filename not in writers:
  54. f = open(filename, mode='a', encoding='utf-8')
  55. writer = csv.DictWriter(f, header)
  56. writer.writeheader()
  57. writers[filename] = writer
  58. file_handlers.add(f)
  59. writer = writers[filename]
  60. line = self.create_line(record)
  61. writer.writerow(line)
  62. for f in file_handlers:
  63. f.close()
  64. save_file = self.write_zip(writers)
  65. for file in writers:
  66. os.remove(file)
  67. return save_file
  68. def create_line(self, record) -> Dict:
  69. return {
  70. 'id': record.id,
  71. 'data': record.data,
  72. 'label': '#'.join(sorted(record.label)),
  73. **record.metadata
  74. }
  75. def create_header(self, records: List[Record]) -> Iterable[str]:
  76. header = ['id', 'data', 'label']
  77. header += sorted(set(itertools.chain(*[r.metadata.keys() for r in records])))
  78. return header
  79. class JSONWriter(BaseWriter):
  80. extension = 'json'
  81. def write(self, records: Iterator[Record]) -> str:
  82. writers = {}
  83. contents = defaultdict(list)
  84. for record in records:
  85. filename = os.path.join(self.tmpdir, f'{record.user}.{self.extension}')
  86. if filename not in writers:
  87. f = open(filename, mode='a', encoding='utf-8')
  88. writers[filename] = f
  89. line = self.create_line(record)
  90. contents[filename].append(line)
  91. for filename, f in writers.items():
  92. content = contents[filename]
  93. json.dump(content, f, ensure_ascii=False)
  94. f.close()
  95. save_file = self.write_zip(writers)
  96. for file in writers:
  97. os.remove(file)
  98. return save_file
  99. def create_line(self, record) -> Dict:
  100. return {
  101. 'id': record.id,
  102. 'data': record.data,
  103. 'label': record.label,
  104. **record.metadata
  105. }
  106. class JSONLWriter(LineWriter):
  107. extension = 'jsonl'
  108. def create_line(self, record):
  109. return json.dumps({
  110. 'id': record.id,
  111. 'data': record.data,
  112. 'label': record.label,
  113. **record.metadata
  114. }, ensure_ascii=False)
  115. class FastTextWriter(LineWriter):
  116. extension = 'txt'
  117. def create_line(self, record):
  118. line = [f'__label__{label}' for label in record.label]
  119. line.sort()
  120. line.append(record.data)
  121. line = ' '.join(line)
  122. return line
  123. class IntentAndSlotWriter(LineWriter):
  124. extension = 'jsonl'
  125. def create_line(self, record):
  126. if isinstance(record.label, dict):
  127. return json.dumps({
  128. 'id': record.id,
  129. 'text': record.data,
  130. 'cats': record.label.get('cats', []),
  131. 'entities': record.label.get('entities', []),
  132. **record.metadata
  133. }, ensure_ascii=False)
  134. else:
  135. return json.dumps({
  136. 'id': record.id,
  137. 'text': record.data,
  138. 'cats': [],
  139. 'entities': [],
  140. **record.metadata
  141. }, ensure_ascii=False)