diff --git a/app/api/views/download/writer.py b/app/api/views/download/writer.py new file mode 100644 index 00000000..ef061ffb --- /dev/null +++ b/app/api/views/download/writer.py @@ -0,0 +1,113 @@ +import abc +import csv +import itertools +import json +import os +import uuid +import zipfile +from typing import Dict, Iterable, Iterator, List + +from .data import Record + + +class BaseWriter: + + def __init__(self, tmpdir: str): + self.tmpdir = tmpdir + + @abc.abstractmethod + def write(self, records: Iterator[Record]) -> str: + raise NotImplementedError() + + def write_zip(self, filenames: Iterable): + save_file = '{}.zip'.format(os.path.join(self.tmpdir, str(uuid.uuid4()))) + with zipfile.ZipFile(save_file, 'w', compression=zipfile.ZIP_DEFLATED) as zf: + for file in filenames: + zf.write(file) + return save_file + + +class LineWriter(BaseWriter): + extension = 'txt' + + def write(self, records: Iterator[Record]) -> str: + files = {} + for record in records: + filename = os.path.join(self.tmpdir, f'{record.user}.{self.extension}') + if filename not in files: + f = open(filename, mode='a') + files[filename] = f + f = files[filename] + line = self.create_line(record) + f.write(f'{line}\n') + for f in files.values(): + f.close() + save_file = self.write_zip(files) + for file in files: + os.remove(file) + return save_file + + @abc.abstractmethod + def create_line(self, record) -> str: + raise NotImplementedError() + + +class CsvWriter(BaseWriter): + extension = 'csv' + + def write(self, records: Iterator[Record]) -> str: + writers = {} + file_handlers = set() + records = list(records) + header = self.create_header(records) + for record in records: + filename = os.path.join(self.tmpdir, f'{record.user}.{self.extension}') + if filename not in writers: + f = open(filename, mode='a') + writer = csv.DictWriter(f, header) + writer.writeheader() + writers[filename] = writer + file_handlers.add(f) + writer = writers[filename] + line = self.create_line(record) + writer.writerow(line) + save_file = self.write_zip(writers) + for file in writers: + os.remove(file) + for f in file_handlers: + f.close() + return save_file + + def create_line(self, record) -> Dict: + return { + 'data': record.data, + 'label': '#'.join(record.label), + **record.metadata + } + + def create_header(self, records: List[Record]) -> Iterable[str]: + header = ['data', 'label'] + header += list(itertools.chain(*[r.metadata.keys() for r in records])) + return header + + +class JSONLWriter(LineWriter): + extension = 'jsonl' + + def create_line(self, record): + return json.dumps({ + 'id': record.id, + 'data': record.data, + 'label': record.label, + **record.metadata + }) + + +class FastTextWriter(LineWriter): + extension = 'txt' + + def create_line(self, record): + line = [f'__label__{label}' for label in record.label] + line.append(record.data) + line = ' '.join(line) + return line