diff --git a/app/api/views/download/factory.py b/app/api/views/download/factory.py
index eb750be9..acf205d1 100644
--- a/app/api/views/download/factory.py
+++ b/app/api/views/download/factory.py
@@ -19,6 +19,7 @@ def create_repository(project) -> repositories.BaseRepository:
 def create_writer(format: str) -> Type[writer.BaseWriter]:
     mapping = {
         catalog.CSV.name: writer.CsvWriter,
+        catalog.JSON.name: writer.JSONWriter,
         catalog.JSONL.name: writer.JSONLWriter,
         catalog.FastText.name: writer.FastTextWriter,
     }
diff --git a/app/api/views/download/writer.py b/app/api/views/download/writer.py
index c774a221..bcd7f463 100644
--- a/app/api/views/download/writer.py
+++ b/app/api/views/download/writer.py
@@ -5,6 +5,7 @@ import json
 import os
 import uuid
 import zipfile
+from collections import defaultdict
 from typing import Dict, Iterable, Iterator, List
 
 from .data import Record
@@ -80,17 +81,55 @@ class CsvWriter(BaseWriter):
 
     def create_line(self, record) -> Dict:
         return {
+            'id': record.id,
             'data': record.data,
             'label': '#'.join(record.label),
             **record.metadata
         }
 
     def create_header(self, records: List[Record]) -> Iterable[str]:
-        header = ['data', 'label']
+        header = ['id', 'data', 'label']
         header += list(itertools.chain(*[r.metadata.keys() for r in records]))
         return header
 
 
+class JSONWriter(BaseWriter):
+    """Write each user's records to one JSON array file and zip the files."""
+    extension = 'json'
+
+    def write(self, records: Iterator[Record]) -> str:
+        """Dump the lines of every user to ``<user>.json`` and zip them.
+
+        Returns the result of ``write_zip`` for the written files.
+        """
+        contents = defaultdict(list)
+        for record in records:
+            filename = os.path.join(self.tmpdir, f'{record.user}.{self.extension}')
+            contents[filename].append(self.create_line(record))
+
+        try:
+            for filename, content in contents.items():
+                # Truncate instead of appending: a stale file with the same
+                # name would otherwise end up holding two JSON documents.
+                with open(filename, mode='w', encoding='utf-8') as f:
+                    json.dump(content, f)
+            return self.write_zip(contents)
+        finally:
+            # Drop the intermediate files even when dumping or zipping fails.
+            for filename in contents:
+                if os.path.exists(filename):
+                    os.remove(filename)
+
+    def create_line(self, record) -> Dict:
+        """Return the JSON object for ``record``; the label is not joined."""
+        return {
+            'id': record.id,
+            'data': record.data,
+            'label': record.label,
+            **record.metadata
+        }
+
+
 class JSONLWriter(LineWriter):
     extension = 'jsonl'
 