import abc
import csv
import itertools
import json
import os
import uuid
import zipfile
from collections import defaultdict
from typing import Dict, Iterable, Iterator, List

from .data import Record


class BaseWriter:
    """Base class for writers that export records and bundle files into a zip.

    Subclasses implement :meth:`write`. Note that this class does not use
    ``abc.ABCMeta``, so the ``abstractmethod`` marker is not enforced at
    instantiation time; calling an unimplemented ``write`` raises
    ``NotImplementedError`` instead.
    """

    def __init__(self, tmpdir: str):
        # Directory in which the zip archive is created.
        self.tmpdir = tmpdir

    @abc.abstractmethod
    def write(self, records: Iterator[Record]) -> str:
        """Export ``records`` and return a path; must be overridden."""
        raise NotImplementedError()

    def write_zip(self, filenames: Iterable):
        """Pack the given files into a new deflate-compressed zip archive.

        The archive is created inside ``self.tmpdir`` under a random UUID
        name. Each file is stored under its base name only, so directory
        components of the input paths are dropped.

        Args:
            filenames: Paths of the files to add to the archive.

        Returns:
            The path of the created ``.zip`` file.
        """
        archive_name = str(uuid.uuid4()) + ".zip"
        archive_path = os.path.join(self.tmpdir, archive_name)
        with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
            for path in filenames:
                archive.write(filename=path, arcname=os.path.basename(path))
        return archive_path


class LineWriter(BaseWriter):
    """Writer that emits one text line per record, in one file per user.

    Subclasses provide :meth:`create_line` and may override ``extension``.
    """

    extension = "txt"

    def write(self, records: Iterator[Record]) -> str:
        """Write each record as a line to ``<tmpdir>/<user>.<extension>`` and zip the files.

        Files are opened in append mode with UTF-8 encoding, so the output does
        not depend on the platform's default encoding. File handles are always
        closed and the per-user files are always removed, even when
        ``create_line`` or the zip step raises.

        Args:
            records: Records to export; each must expose a ``user`` attribute.

        Returns:
            The path of the zip archive containing the per-user files.
        """
        files = {}
        try:
            try:
                for record in records:
                    filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
                    f = files.get(filename)
                    if f is None:
                        # Explicit encoding: subclasses emit non-ASCII text
                        # (e.g. json.dumps(..., ensure_ascii=False)).
                        f = files[filename] = open(filename, mode="a", encoding="utf-8")
                    line = self.create_line(record)
                    f.write(f"{line}\n")
            finally:
                # Close (and thereby flush) every handle before zipping.
                for f in files.values():
                    f.close()
            return self.write_zip(files)
        finally:
            # The per-user files are intermediates; only the archive is kept.
            for filename in files:
                os.remove(filename)

    @abc.abstractmethod
    def create_line(self, record) -> str:
        """Return the text line (without trailing newline) for ``record``."""
        raise NotImplementedError()


class CsvWriter(BaseWriter):
    """Writer that exports records as one CSV file per user, bundled in a zip."""

    extension = "csv"

    def write(self, records: Iterator[Record]) -> str:
        """Write records to ``<tmpdir>/<user>.csv`` files and zip them.

        All records are materialized first because the header is the union of
        the metadata keys across every record. Files are opened with
        ``newline=""`` as the ``csv`` module requires, so row terminators are
        not translated a second time by the text layer. File handles are always
        closed and the per-user files are always removed, even on error.

        Args:
            records: Records to export; each must expose ``user``, ``id``,
                ``data``, ``label`` and ``metadata``.

        Returns:
            The path of the zip archive containing the per-user CSV files.
        """
        record_list = list(records)
        header = self.create_header(record_list)
        handles = {}  # filename -> open file object
        writers = {}  # filename -> csv.DictWriter bound to that file
        try:
            try:
                for record in record_list:
                    filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
                    writer = writers.get(filename)
                    if writer is None:
                        f = open(filename, mode="a", encoding="utf-8", newline="")
                        handles[filename] = f
                        writer = writers[filename] = csv.DictWriter(f, header)
                        writer.writeheader()
                    writer.writerow(self.create_line(record))
            finally:
                # Close (and thereby flush) every handle before zipping.
                for f in handles.values():
                    f.close()
            return self.write_zip(handles)
        finally:
            # The per-user files are intermediates; only the archive is kept.
            for filename in handles:
                os.remove(filename)

    def create_line(self, record) -> Dict:
        """Build the CSV row for ``record``.

        Labels are sorted and joined with ``#``; metadata entries become
        additional columns (and override the fixed columns on key collision).
        """
        return {"id": record.id, "data": record.data, "label": "#".join(sorted(record.label)), **record.metadata}

    def create_header(self, records: List[Record]) -> List[str]:
        """Return the column names: fixed columns followed by sorted metadata keys."""
        header = ["id", "data", "label"]
        header += sorted(set(itertools.chain.from_iterable(r.metadata.keys() for r in records)))
        return header


class JSONWriter(BaseWriter):
    """Writer that exports records as one JSON array file per user, bundled in a zip."""

    extension = "json"

    def write(self, records: Iterator[Record]) -> str:
        """Write records to ``<tmpdir>/<user>.json`` files and zip them.

        Records are grouped per target file in memory first; each file is then
        opened only for the duration of its ``json.dump`` call. No file handle
        stays open while records are being consumed, and none is leaked if
        ``create_line`` raises. The per-user files are removed afterwards, even
        when dumping or zipping fails.

        Args:
            records: Records to export; each must expose ``user``, ``id``,
                ``data``, ``label`` and ``metadata``.

        Returns:
            The path of the zip archive containing the per-user JSON files.
        """
        contents = defaultdict(list)
        for record in records:
            filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
            contents[filename].append(self.create_line(record))

        written = []  # files that were actually created, in first-seen order
        try:
            for filename, content in contents.items():
                with open(filename, mode="a", encoding="utf-8") as f:
                    written.append(filename)
                    json.dump(content, f, ensure_ascii=False)
            return self.write_zip(written)
        finally:
            # The per-user files are intermediates; only the archive is kept.
            for filename in written:
                os.remove(filename)

    def create_line(self, record) -> Dict:
        """Build the JSON object for ``record``; metadata keys are merged at top level."""
        return {"id": record.id, "data": record.data, "label": record.label, **record.metadata}


class JSONLWriter(LineWriter):
    """Line writer that serializes each record as one JSON object per line."""

    extension = "jsonl"

    def create_line(self, record):
        """Return ``record`` as a JSON string with non-ASCII characters kept as-is.

        The object holds ``id``, ``data`` and ``label`` followed by the
        record's metadata entries merged at the top level.
        """
        payload = {"id": record.id, "data": record.data, "label": record.label}
        payload.update(record.metadata)
        return json.dumps(payload, ensure_ascii=False)


class FastTextWriter(LineWriter):
    """Line writer producing ``__label__<name> ... <data>`` lines."""

    extension = "txt"

    def create_line(self, record):
        """Return the sorted ``__label__``-prefixed labels followed by the record data.

        All parts are separated by single spaces.
        """
        prefixed = sorted(f"__label__{label}" for label in record.label)
        return " ".join([*prefixed, record.data])


class IntentAndSlotWriter(LineWriter):
    """Line writer emitting JSONL objects with ``cats`` and ``entities`` fields."""

    extension = "jsonl"

    def create_line(self, record):
        """Return ``record`` as a JSON string with ``text``, ``cats`` and ``entities``.

        ``cats`` and ``entities`` are read from ``record.label`` when it is a
        dict; a missing key, or a label that is not a dict, yields an empty
        list. Metadata entries are merged at the top level.
        """
        # A non-dict label is treated as having neither cats nor entities.
        label = record.label if isinstance(record.label, dict) else {}
        payload = {
            "id": record.id,
            "text": record.data,
            "cats": label.get("cats", []),
            "entities": label.get("entities", []),
        }
        payload.update(record.metadata)
        return json.dumps(payload, ensure_ascii=False)


class EntityAndRelationWriter(LineWriter):
    """Line writer emitting JSONL objects with ``relations`` and ``entities`` fields."""

    extension = "jsonl"

    def create_line(self, record):
        """Return ``record`` as a JSON string with ``text``, ``relations`` and ``entities``.

        ``relations`` and ``entities`` are read from ``record.label`` when it
        is a dict; a missing key, or a label that is not a dict, yields an
        empty list. Metadata entries are merged at the top level.
        """
        # A non-dict label is treated as having neither relations nor entities.
        label = record.label if isinstance(record.label, dict) else {}
        payload = {
            "id": record.id,
            "text": record.data,
            "relations": label.get("relations", []),
            "entities": label.get("entities", []),
        }
        payload.update(record.metadata)
        return json.dumps(payload, ensure_ascii=False)