import csv
import io
import itertools
import json
import re
from collections import defaultdict
from random import Random

from django.db import transaction
from rest_framework.renderers import JSONRenderer
from seqeval.metrics.sequence_labeling import get_entities

from app.settings import IMPORT_BATCH_SIZE
from .exceptions import FileParseException
from .models import Label
from .serializers import DocumentSerializer, LabelSerializer


def extract_label(tag):
    """Strip a BIOES position prefix from a sequence tag.

    ``'B-PER'`` becomes ``'PER'``. A tag that does not start with ``B-``,
    ``I-``, ``E-`` or ``S-`` followed by at least one character (for example
    ``'O'``) is returned unchanged.
    """
    match = re.match(r'(B|I|E|S)-(.+)', tag)
    if match is None:
        return tag
    return match.group(2)


class BaseStorage(object):
    """Base class for saving parsed uploads into a project.

    Subclasses implement ``save``. The helpers below validate and persist
    documents, labels and annotations through the project's serializers.
    """

    def __init__(self, data, project):
        # data: the parsed upload that ``save`` iterates over.
        # project: the project that new documents and labels are saved into.
        self.data = data
        self.project = project

    @transaction.atomic
    def save(self, user):
        """Persist ``self.data`` on behalf of ``user``; implemented by subclasses."""
        raise NotImplementedError()

    def save_doc(self, data):
        """Validate a list of document dicts and save them into the project.

        Returns the saved documents. The serializer raises its validation
        error if any item is invalid.
        """
        serializer = DocumentSerializer(data=data, many=True)
        serializer.is_valid(raise_exception=True)
        doc = serializer.save(project=self.project)
        return doc

    def save_label(self, data):
        """Validate a list of label dicts, save them into the project and return them."""
        serializer = LabelSerializer(data=data, many=True)
        serializer.is_valid(raise_exception=True)
        label = serializer.save(project=self.project)
        return label

    def save_annotation(self, data, user):
        """Validate and save annotation dicts for ``user``.

        Uses the serializer class returned by the project's
        ``get_annotation_serializer()``.
        """
        annotation_serializer = self.project.get_annotation_serializer()
        serializer = annotation_serializer(data=data, many=True)
        serializer.is_valid(raise_exception=True)
        annotation = serializer.save(user=user)
        return annotation

    @classmethod
    def extract_label(cls, data):
        """Return each item's ``labels`` value in order, using ``[]`` when it is absent."""
        return [d.get('labels', []) for d in data]

    @classmethod
    def exclude_created_labels(cls, labels, created):
        """Drop label names that are already keys of ``created``."""
        return [label for label in labels if label not in created]

    @classmethod
    def to_serializer_format(cls, labels, created, random_seed=None):
        """Build LabelSerializer input dicts for new label names.

        Args:
            labels: label names to create, processed in sorted order.
            created: mapping of name -> existing label object. Their shortcut
                keys are treated as taken, so new labels do not reuse them.
            random_seed: optional seed that makes the generated colors
                reproducible.

        Returns:
            A list of dicts with ``text``, ``background_color`` and
            ``text_color``, plus ``suffix_key``/``prefix_key`` when a free
            shortcut pair exists.
        """
        existing_shortkeys = {(label.suffix_key, label.prefix_key)
                              for label in created.values()}
        # Use one generator for the whole call. Re-seeding a fresh generator
        # with the same ``random_seed`` for every label gave all labels the
        # identical color whenever a seed was supplied.
        rng = Random(random_seed)

        serializer_labels = []

        for label in sorted(labels):
            serializer_label = {'text': label}

            shortkey = cls.get_shortkey(label, existing_shortkeys)
            if shortkey:
                serializer_label['suffix_key'] = shortkey[0]
                serializer_label['prefix_key'] = shortkey[1]
                # Reserve the pair so later labels in this call skip it.
                existing_shortkeys.add(shortkey)

            color = Color.random(seed=rng.getrandbits(64))
            serializer_label['background_color'] = color.hex
            serializer_label['text_color'] = color.contrast_color.hex

            serializer_labels.append(serializer_label)

        return serializer_labels

    @classmethod
    def get_shortkey(cls, label, existing_shortkeys):
        """Pick the first unused (suffix_key, prefix_key) pair for ``label``.

        Candidate suffix keys are the characters of the lower-cased label that
        appear in ``Label.SUFFIX_KEYS``, tried in order. For each one, prefix
        keys are tried starting with ``None`` (no prefix), then in
        ``Label.PREFIX_KEYS`` order. Returns ``None`` when every pair is taken.
        """
        model_prefix_keys = [key for (key, _) in Label.PREFIX_KEYS]
        prefix_keys = [None] + model_prefix_keys

        model_suffix_keys = {key for (key, _) in Label.SUFFIX_KEYS}
        suffix_keys = [key for key in label.lower() if key in model_suffix_keys]

        for shortkey in itertools.product(suffix_keys, prefix_keys):
            if shortkey not in existing_shortkeys:
                return shortkey

        return None

    @classmethod
    def update_saved_labels(cls, saved, new):
        """Add the ``new`` label objects to ``saved``, keyed by text.

        Mutates ``saved`` in place and also returns it.
        """
        for label in new:
            saved[label.text] = label
        return saved


class PlainStorage(BaseStorage):
    """Store plain-text uploads as documents only (no labels or annotations)."""

    @transaction.atomic
    def save(self, user):
        # Each item of self.data is handed to save_doc, which validates it
        # with many=True, so items are lists of document dicts.
        # ``user`` is unused here because nothing is annotated.
        for batch in self.data:
            self.save_doc(batch)


class ClassificationStorage(BaseStorage):
    """Store json for text classification.

    The format is as follows:
    {"text": "Python is awesome!", "labels": ["positive"]}
    ...
    """
    @transaction.atomic
    def save(self, user):
        """Save each batch of documents, its new labels and its annotations.

        Label names not yet present in the project are created first, so
        every annotation can refer to a saved label object.
        """
        saved_labels = {label.text: label for label in self.project.labels.all()}
        for data in self.data:
            docs = self.save_doc(data)
            labels = self.extract_label(data)
            unique_labels = self.extract_unique_labels(labels)
            unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
            unique_labels = self.to_serializer_format(unique_labels, saved_labels)
            new_labels = self.save_label(unique_labels)
            saved_labels = self.update_saved_labels(saved_labels, new_labels)
            annotations = self.make_annotations(docs, labels, saved_labels)
            self.save_annotation(annotations, user)

    @classmethod
    def extract_unique_labels(cls, labels):
        """Return the set of all label names across the per-document lists."""
        return set(itertools.chain.from_iterable(labels))

    @classmethod
    def make_annotations(cls, docs, labels, saved_labels):
        """Build annotation dicts linking each document to its labels.

        ``docs`` and ``labels`` are aligned by position, with one list of
        label names per document. ``saved_labels`` maps a name to its saved
        label object. A name repeated within one document yields a single
        annotation.
        """
        annotations = []
        for doc, names in zip(docs, labels):
            # dict.fromkeys removes duplicates while keeping first-seen order.
            # Repeating a category on one document adds no information and
            # would only create duplicate annotation rows.
            for name in dict.fromkeys(names):
                label = saved_labels[name]
                annotations.append({'document': doc.id, 'label': label.id})
        return annotations


class SequenceLabelingStorage(BaseStorage):
    """Upload jsonl for sequence labeling.

    The format is as follows:
    {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
    ...
    """
    @transaction.atomic
    def save(self, user):
        """Save each batch of documents, any new span labels and the span annotations."""
        saved_labels = {label.text: label for label in self.project.labels.all()}
        for batch in self.data:
            docs = self.save_doc(batch)
            spans_per_doc = self.extract_label(batch)
            # Only label names not already saved in the project are created.
            new_names = self.exclude_created_labels(
                self.extract_unique_labels(spans_per_doc), saved_labels)
            label_payload = self.to_serializer_format(new_names, saved_labels)
            saved_labels = self.update_saved_labels(
                saved_labels, self.save_label(label_payload))
            self.save_annotation(
                self.make_annotations(docs, spans_per_doc, saved_labels), user)

    @classmethod
    def extract_unique_labels(cls, labels):
        """Collect the distinct label names from ``[start, end, name]`` spans."""
        return {name for _, _, name in itertools.chain.from_iterable(labels)}

    @classmethod
    def make_annotations(cls, docs, labels, saved_labels):
        """Build one span-annotation dict per ``[start, end, name]`` span, per document."""
        return [
            {'document': doc.id,
             'label': saved_labels[name].id,
             'start_offset': start,
             'end_offset': end}
            for doc, spans in zip(docs, labels)
            for start, end, name in spans
        ]


class Seq2seqStorage(BaseStorage):
    """Store json for seq2seq.

    The format is as follows:
    {"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
    ...
    """
    @transaction.atomic
    def save(self, user):
        """Save each batch of documents together with their target-text annotations."""
        for batch in self.data:
            docs = self.save_doc(batch)
            texts_per_doc = self.extract_label(batch)
            self.save_annotation(self.make_annotations(docs, texts_per_doc), user)

    @classmethod
    def make_annotations(cls, docs, labels):
        """Build one ``{'document', 'text'}`` dict per target text, per document."""
        return [{'document': doc.id, 'text': text}
                for doc, texts in zip(docs, labels)
                for text in texts]


class FileParser(object):
    """Base class for upload parsers; subclasses override ``parse``."""

    def parse(self, file):
        """Parse ``file``. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError


class CoNLLParser(FileParser):
    """Uploads CoNLL format file.

    The file format is tab-separated values.
    A blank line is required at the end of a sentence.
    For example:
    ```
    EU	B-ORG
    rejects	O
    German	B-MISC
    call	O
    to	O
    boycott	O
    British	B-MISC
    lamb	O
    .	O

    Peter	B-PER
    Blackburn	I-PER
    ...
    ```
    """
    def parse(self, file):
        """Parse a CoNLL file into batches of sequence-labeling records.

        Each line of ``file`` is decoded as UTF-8, so iterating ``file`` must
        yield ``bytes``.

        Yields:
            Lists of at most ``IMPORT_BATCH_SIZE`` dicts of the form
            {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}

        Raises:
            FileParseException: if a non-blank line is not exactly
                ``word<TAB>tag``.
        """
        words, tags = [], []
        data = []
        for i, line in enumerate(file, start=1):
            line = line.decode('utf-8').strip()
            if line:
                try:
                    word, tag = line.split('\t')
                except ValueError:
                    raise FileParseException(line_num=i, line=line)
                words.append(word)
                tags.append(tag)
            elif words:
                # A blank line closes the current sentence. Runs of blank
                # lines must not produce empty documents.
                data.append(self.calc_char_offset(words, tags))
                words, tags = [], []
                if len(data) >= IMPORT_BATCH_SIZE:
                    yield data
                    data = []
        # The last sentence may not be followed by a blank line.
        if words:
            data.append(self.calc_char_offset(words, tags))
        # Always flush the remainder. Previously it was only yielded when the
        # file did NOT end with a blank line, so the final batch of a
        # well-formed file was silently dropped.
        if data:
            yield data

    @classmethod
    def calc_char_offset(cls, words, tags):
        """Convert token-level tags into character-offset spans.

        The words are joined with single spaces. Each span's offsets index
        into that joined ``text``, with the end offset exclusive. Offsets are
        computed from token positions rather than by searching for the entity
        string, because a search could land on an earlier occurrence inside
        another word.

        Returns:
            {"text": <joined words>, "labels": [[start, end, label], ...]}
        """
        # starts[k] is the character offset of words[k] in the joined text.
        starts = []
        offset = 0
        for word in words:
            starts.append(offset)
            offset += len(word) + 1  # +1 for the separating space
        labels = []
        # get_entities yields (label, first_token, last_token) with the last
        # token index inclusive.
        for label, start_token, end_token in get_entities(tags):
            char_left = starts[start_token]
            char_right = starts[end_token] + len(words[end_token])
            labels.append([char_left, char_right, label])
        return {'text': ' '.join(words), 'labels': labels}


class PlainTextParser(FileParser):
    """Uploads plain text.

    The file format is as follows:
    ```
    EU rejects German call to boycott British lamb.
    President Obama is speaking at the White House.
    ...
    ```
    """
    def parse(self, file):
        """Yield batches of ``{'text': ...}`` dicts, one per line of ``file``.

        ``file`` is wrapped in a UTF-8 text reader. Each batch holds at most
        ``IMPORT_BATCH_SIZE`` lines, and surrounding whitespace is stripped
        from every line.
        """
        text_stream = io.TextIOWrapper(file, encoding='utf-8')

        def next_batch():
            return list(itertools.islice(text_stream, IMPORT_BATCH_SIZE))

        # iter(callable, sentinel) stops once a read returns an empty batch.
        for lines in iter(next_batch, []):
            yield [{'text': raw.strip()} for raw in lines]


class CSVParser(FileParser):
    """Uploads csv file.

    The file format is comma separated values.
    Column names are required at the top of a file.
    For example:
    ```
    text, label
    "EU rejects German call to boycott British lamb.",Politics
    "President Obama is speaking at the White House.",Politics
    "He lives in Newark, Ohio.",Other
    ...
    ```
    """
    def parse(self, file):
        """Parse a CSV upload into batches of classification records.

        ``file`` is wrapped in a UTF-8 text reader. The first row holds the
        column names. In each data row, the first two fields are taken as
        text and label (whatever their headers are called). Any remaining
        fields are JSON-encoded into ``meta``, keyed by their header names.
        Empty rows are skipped, and an empty file yields nothing.

        Yields:
            Lists of at most ``IMPORT_BATCH_SIZE`` dicts of the form
            {"text": ..., "labels": [label], "meta": "<json string>"}

        Raises:
            FileParseException: if a non-empty row does not have as many
                fields as the header, or has fewer than two.
        """
        # newline='' lets the csv module handle line endings itself, which it
        # needs to read quoted fields that contain newlines correctly.
        file = io.TextIOWrapper(file, encoding='utf-8', newline='')
        reader = csv.reader(file)
        # A bare next() on an empty file would raise StopIteration inside this
        # generator, which Python turns into a RuntimeError.
        columns = next(reader, None)
        if columns is None:
            return
        data = []
        for row in reader:
            if not row:
                # Blank lines come back as empty rows; they carry no record.
                continue
            if len(row) != len(columns) or len(row) < 2:
                # reader.line_num counts physical source lines, so it stays
                # correct when a quoted field spans several lines.
                raise FileParseException(line_num=reader.line_num, line=row)
            text, label = row[:2]
            meta = json.dumps(dict(zip(columns[2:], row[2:])))
            data.append({'text': text, 'labels': [label], 'meta': meta})
            if len(data) >= IMPORT_BATCH_SIZE:
                yield data
                data = []
        if data:
            yield data


class JSONParser(FileParser):
    """Uploads a JSON Lines file: one JSON object per line."""

    def parse(self, file):
        """Parse JSON Lines into batches of document dicts.

        Each line must be a JSON object. Its ``meta`` value (``{}`` when
        absent) is re-encoded as a JSON string. Blank lines are skipped.

        Yields:
            Lists of at most ``IMPORT_BATCH_SIZE`` parsed objects.

        Raises:
            FileParseException: if a line is not valid JSON, cannot be
                decoded, or is valid JSON that is not an object.
        """
        data = []
        for i, line in enumerate(file, start=1):
            if not line.strip():
                # Tolerate blank lines, e.g. a trailing empty line.
                continue
            try:
                j = json.loads(line)
            except (json.JSONDecodeError, UnicodeDecodeError):
                raise FileParseException(line_num=i, line=line)
            if not isinstance(j, dict):
                # A JSON scalar or array has no fields to read.
                raise FileParseException(line_num=i, line=line)
            j['meta'] = json.dumps(j.get('meta', {}))
            data.append(j)
            if len(data) >= IMPORT_BATCH_SIZE:
                yield data
                data = []
        if data:
            yield data


class JSONLRenderer(JSONRenderer):
    """Render data as JSON Lines: one JSON document per line."""

    def render(self, data, accepted_media_type=None, renderer_context=None):
        """Render `data` as JSON Lines.

        This method is a generator. It yields one ``str`` per item: the item
        encoded as JSON, followed by a newline. It does not return a
        bytestring. A non-list ``data`` is treated as a single item, and
        ``None`` yields nothing.
        """
        if data is None:
            # The ``yield`` below makes this a generator, so any return value
            # would be discarded; just end the generator with no output.
            return

        if not isinstance(data, list):
            data = [data]

        for d in data:
            yield json.dumps(d,
                             cls=self.encoder_class,
                             ensure_ascii=self.ensure_ascii,
                             allow_nan=not self.strict) + '\n'


class JSONPainter(object):
    """Turn documents into plain dicts for JSON export."""

    def paint(self, documents):
        """Serialize ``documents`` and strip annotation bookkeeping fields.

        ``meta`` is decoded from its JSON string. Each annotation loses its
        ``id``, ``prob`` and ``document`` keys; a missing key raises KeyError.
        """
        serialized = DocumentSerializer(documents, many=True).data
        painted = []
        for doc in serialized:
            doc['meta'] = json.loads(doc['meta'])
            for annotation in doc['annotations']:
                for key in ('id', 'prob', 'document'):
                    annotation.pop(key)
            painted.append(doc)
        return painted


class CSVPainter(JSONPainter):
    """Flatten painted documents into one row per annotation for CSV export."""

    def paint(self, documents):
        """Return one dict per annotation, merging document and annotation fields.

        Annotation values win on key collisions. A document with no
        annotations contributes no rows.
        """
        rows = []
        for doc in super().paint(documents):
            doc_annotations = doc.pop('annotations')
            rows.extend({**doc, **annotation} for annotation in doc_annotations)
        return rows


class Color:
    """An RGB color; channels are expected to be ints in 0-255 (``hex`` needs ints)."""

    def __init__(self, red, green, blue):
        self.red, self.green, self.blue = red, green, blue

    @property
    def contrast_color(self):
        """Generate black or white color.

        Ensure that text and background color combinations provide
        sufficient contrast when viewed by someone having color deficits or
        when viewed on a black and white screen.

        Algorithm from w3c:
        * https://www.w3.org/TR/AERT/#color-contrast
        """
        if self.brightness < 128:
            return Color.white()
        return Color.black()

    @property
    def brightness(self):
        """Perceived brightness using the W3C 299/587/114 channel weights."""
        weighted = 299 * self.red + 587 * self.green + 114 * self.blue
        return weighted / 1000

    @property
    def hex(self):
        """Return the color as a lowercase ``#rrggbb`` string."""
        channels = (self.red, self.green, self.blue)
        return '#' + ''.join(format(channel, '02x') for channel in channels)

    @classmethod
    def white(cls):
        return cls(red=255, green=255, blue=255)

    @classmethod
    def black(cls):
        return cls(red=0, green=0, blue=0)

    @classmethod
    def random(cls, seed=None):
        """Return a random color; the same ``seed`` always gives the same color."""
        generator = Random(seed)
        red, green, blue = generator.choices(range(256), k=3)
        return cls(red, green, blue)