Browse Source

Add writers for export dataset

pull/1310/head
Hironsan 3 years ago
parent
commit
d06365395c
1 changed files with 113 additions and 0 deletions
  1. 113
      app/api/views/download/writer.py

113
app/api/views/download/writer.py

@@ -0,0 +1,113 @@
import abc
import csv
import itertools
import json
import os
import uuid
import zipfile
from typing import Dict, Iterable, Iterator, List
from .data import Record
class BaseWriter:
    """Base class for dataset export writers that stage files in ``tmpdir``.

    Subclasses implement :meth:`write`; :meth:`write_zip` is a shared helper
    that bundles already-written files into a single zip archive.
    """

    def __init__(self, tmpdir: str):
        """Remember the working directory.

        Args:
            tmpdir: Directory in which the zip archive is created.
        """
        self.tmpdir = tmpdir

    # NOTE: this class does not use ``abc.ABCMeta``, so ``abstractmethod`` is
    # not enforced at instantiation time; the ``NotImplementedError`` below is
    # what actually signals a missing override.
    @abc.abstractmethod
    def write(self, records: Iterator['Record']) -> str:
        """Export ``records``; subclasses are expected to return a file path.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def write_zip(self, filenames: Iterable[str]) -> str:
        """Bundle ``filenames`` into a new deflate-compressed zip archive.

        The archive is created in ``self.tmpdir`` under a random UUID name.
        Each file is stored under its base name only, so the archive neither
        exposes nor recreates the temporary directory layout on extraction.
        The input files themselves are left untouched.

        Args:
            filenames: Paths of the files to add to the archive.

        Returns:
            Path of the created ``.zip`` file.
        """
        save_file = '{}.zip'.format(os.path.join(self.tmpdir, str(uuid.uuid4())))
        with zipfile.ZipFile(save_file, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
            for file in filenames:
                # Without ``arcname`` the member name would be the full path
                # of the file, i.e. it would embed the temp directory.
                zf.write(file, arcname=os.path.basename(file))
        return save_file
class LineWriter(BaseWriter):
    """Writer that emits one text line per record, grouped into per-user files.

    Each record is appended to ``<tmpdir>/<record.user>.<extension>``. After
    all records are written the files are zipped and then deleted. Subclasses
    supply the line format through :meth:`create_line`.
    """

    extension = 'txt'

    def write(self, records: Iterator[Record]) -> str:
        """Write ``records`` to per-user files and return the zip archive path.

        Args:
            records: Records to export; each must expose ``user`` and whatever
                :meth:`create_line` reads.

        Returns:
            Path of the zip archive containing one file per user.
        """
        files = {}
        try:
            try:
                for record in records:
                    filename = os.path.join(self.tmpdir, f'{record.user}.{self.extension}')
                    if filename not in files:
                        # Explicit UTF-8 so exported text does not depend on
                        # the server's locale encoding.
                        files[filename] = open(filename, mode='a', encoding='utf-8')
                    line = self.create_line(record)
                    files[filename].write(f'{line}\n')
            finally:
                # Close (and thereby flush) every handle before zipping, and
                # also when a record fails, so no handle is leaked.
                for f in files.values():
                    f.close()
            return self.write_zip(files)
        finally:
            # The per-user files are intermediate artifacts. Remove them even
            # on failure; they are opened in append mode, so leftovers would
            # otherwise leak into a later export that uses the same directory.
            for filename in files:
                os.remove(filename)

    @abc.abstractmethod
    def create_line(self, record) -> str:
        """Return the text line (without trailing newline) for ``record``.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()
class CsvWriter(BaseWriter):
    """Writer that exports records as CSV, one file per user, bundled in a zip.

    Every file shares the same header: ``data``, ``label`` and the union of
    all metadata keys found in the exported records.
    """

    extension = 'csv'

    def write(self, records: Iterator[Record]) -> str:
        """Write ``records`` to per-user CSV files and return the archive path.

        Args:
            records: Records to export. They are materialized into a list
                because the header must be known before the first row.

        Returns:
            Path of the zip archive containing one CSV file per user.
        """
        writers = {}
        # filename -> open file object; registered right after ``open`` so
        # that cleanup covers files whose header write failed.
        file_handlers = {}
        records = list(records)
        header = self.create_header(records)
        try:
            try:
                for record in records:
                    filename = os.path.join(self.tmpdir, f'{record.user}.{self.extension}')
                    if filename not in writers:
                        # ``newline=''`` is required by the csv module so it
                        # controls line endings itself; UTF-8 keeps the output
                        # independent of the server's locale.
                        f = open(filename, mode='a', newline='', encoding='utf-8')
                        file_handlers[filename] = f
                        writer = csv.DictWriter(f, header)
                        writer.writeheader()
                        writers[filename] = writer
                    writers[filename].writerow(self.create_line(record))
            finally:
                # Files must be closed (flushed) BEFORE they are zipped,
                # otherwise buffered rows are missing from the archive.
                for f in file_handlers.values():
                    f.close()
            return self.write_zip(writers)
        finally:
            # Remove the intermediate files even on failure; they are opened
            # in append mode, so leftovers would corrupt a later export.
            for filename in file_handlers:
                os.remove(filename)

    def create_line(self, record) -> Dict:
        """Return the CSV row for ``record`` as a column-name -> value dict.

        Labels are joined with ``#``. Metadata entries are merged last, so a
        metadata key named ``data`` or ``label`` replaces the built-in value.
        """
        return {
            'data': record.data,
            'label': '#'.join(record.label),
            **record.metadata
        }

    def create_header(self, records: List[Record]) -> Iterable[str]:
        """Return the CSV column names for ``records``.

        The header starts with ``data`` and ``label``, followed by every
        metadata key in first-seen order. Each key appears once, even when
        many records share it, so the CSV has no duplicate columns.
        """
        header = ['data', 'label']
        metadata_keys = itertools.chain.from_iterable(r.metadata.keys() for r in records)
        # ``dict.fromkeys`` de-duplicates while preserving insertion order.
        header += [key for key in dict.fromkeys(metadata_keys) if key not in ('data', 'label')]
        return header
class JSONLWriter(LineWriter):
    """Line writer that serializes each record as one JSON object per line."""

    extension = 'jsonl'

    def create_line(self, record):
        """Return ``record`` encoded as a JSON object string.

        The object holds ``id``, ``data`` and ``label`` plus every metadata
        entry. Metadata is merged last, so a metadata key with one of those
        names replaces the built-in value.
        """
        payload = {
            'id': record.id,
            'data': record.data,
            'label': record.label,
            **record.metadata
        }
        return json.dumps(payload)
class FastTextWriter(LineWriter):
    """Line writer producing ``__label__<label> ... <data>`` lines."""

    extension = 'txt'

    def create_line(self, record):
        """Return the labels, each prefixed with ``__label__``, then the data.

        All parts are separated by a single space; a record without labels
        yields just its data.
        """
        tokens = [f'__label__{name}' for name in record.label]
        return ' '.join(tokens + [record.data])
Loading…
Cancel
Save