import abc
from typing import List, Type

from django.contrib.auth.models import User

from .models import DummyLabelType
from .pipeline.catalog import RELATION_EXTRACTION, Format
from .pipeline.data import BaseData, BinaryData, TextData
from .pipeline.exceptions import FileParseException
from .pipeline.factories import create_parser
from .pipeline.label import CategoryLabel, Label, RelationLabel, SpanLabel, TextLabel
from .pipeline.label_types import LabelTypes
from .pipeline.labels import Categories, Labels, Relations, Spans, Texts
from .pipeline.makers import BinaryExampleMaker, ExampleMaker, LabelMaker
from .pipeline.readers import (
    DEFAULT_LABEL_COLUMN,
    DEFAULT_TEXT_COLUMN,
    FileName,
    Reader,
)
from examples.models import Example
from label_types.models import CategoryType, LabelType, RelationType, SpanType
from projects.models import (
    DOCUMENT_CLASSIFICATION,
    IMAGE_CLASSIFICATION,
    INTENT_DETECTION_AND_SLOT_FILLING,
    SEQ2SEQ,
    SEQUENCE_LABELING,
    SPEECH2TEXT,
    Project,
)


class Dataset(abc.ABC):
    """Common base for the dataset importers defined in this module.

    Keeps hold of the reader, the target project and any extra keyword
    options. ``save`` and ``errors`` always raise here, so concrete
    subclasses have to override both.
    """

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Remember the reader, the project and the remaining keyword options."""
        self.reader, self.project, self.kwargs = reader, project, kwargs

    def save(self, user: User, batch_size: int = 1000):
        """Persist the dataset.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @property
    def errors(self) -> List[FileParseException]:
        """Errors gathered during import.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError


class PlainDataset(Dataset):
    """Dataset that only creates examples (built with ``TextData``), no labels."""

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Set up the base state and the example maker for ``project``."""
        super().__init__(reader, project, **kwargs)
        self.example_maker = ExampleMaker(data_class=TextData, project=project)

    def save(self, user: User, batch_size: int = 1000):
        """Bulk-create examples one reader batch at a time.

        ``user`` is accepted for interface compatibility and is not used.
        """
        for batch in self.reader.batch(batch_size):
            Example.objects.bulk_create(self.example_maker.make(batch))

    @property
    def errors(self) -> List[FileParseException]:
        """Reader errors followed by example-maker errors."""
        reader_errors = self.reader.errors
        maker_errors = self.example_maker.errors
        return reader_errors + maker_errors


class DatasetWithSingleLabelType(Dataset):
    """Dataset whose records carry a single kind of label.

    Subclasses plug in the concrete classes through the class attributes
    below; ``data_class`` and ``label_class`` have no default here.
    """

    data_class: Type[BaseData]
    label_class: Type[Label]
    label_type = LabelType
    labels_class = Labels

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Build the label-type registry plus the example and label makers.

        The ``column_data`` and ``column_label`` keyword options select the
        columns to read; a missing or falsy value falls back to the default
        column names.
        """
        super().__init__(reader, project, **kwargs)
        self.types = LabelTypes(self.label_type)

        data_column = kwargs.get("column_data") or DEFAULT_TEXT_COLUMN
        label_column = kwargs.get("column_label") or DEFAULT_LABEL_COLUMN

        self.example_maker = ExampleMaker(
            project=project,
            data_class=self.data_class,
            column_data=data_column,
            # The label column is passed as an excluded column.
            exclude_columns=[label_column],
        )
        self.label_maker = LabelMaker(column=label_column, label_class=self.label_class)

    def save(self, user: User, batch_size: int = 1000):
        """Store examples, label types and labels for every reader batch."""
        for batch in self.reader.batch(batch_size):
            # Examples are created first.
            Example.objects.bulk_create(self.example_maker.make(batch))

            # Label types are cleaned and saved before the labels themselves.
            batch_labels = self.labels_class(self.label_maker.make(batch), self.types)
            batch_labels.clean(self.project)
            batch_labels.save_types(self.project)

            batch_labels.save(user)

    @property
    def errors(self) -> List[FileParseException]:
        """Reader, example-maker and label-maker errors, in that order."""
        collected = self.reader.errors + self.example_maker.errors
        return collected + self.label_maker.errors


class BinaryDataset(Dataset):
    """Dataset that only creates examples, built with ``BinaryData``."""

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Set up the base state and a binary example maker for ``project``."""
        super().__init__(reader, project, **kwargs)
        self.example_maker = BinaryExampleMaker(data_class=BinaryData, project=project)

    def save(self, user: User, batch_size: int = 1000):
        """Bulk-create examples one reader batch at a time.

        ``user`` is accepted for interface compatibility and is not used.
        """
        for batch in self.reader.batch(batch_size):
            Example.objects.bulk_create(self.example_maker.make(batch))

    @property
    def errors(self) -> List[FileParseException]:
        """Reader errors followed by example-maker errors."""
        reader_errors = self.reader.errors
        maker_errors = self.example_maker.errors
        return reader_errors + maker_errors


class TextClassificationDataset(DatasetWithSingleLabelType):
    """Single-label-type dataset configured with text data and category labels."""

    labels_class = Categories
    label_type = CategoryType
    label_class = CategoryLabel
    data_class = TextData


class SequenceLabelingDataset(DatasetWithSingleLabelType):
    """Single-label-type dataset configured with text data and span labels."""

    labels_class = Spans
    label_type = SpanType
    label_class = SpanLabel
    data_class = TextData


class Seq2seqDataset(DatasetWithSingleLabelType):
    """Single-label-type dataset configured with text data and text labels.

    ``DummyLabelType`` is used as the label type.
    """

    labels_class = Texts
    label_type = DummyLabelType
    label_class = TextLabel
    data_class = TextData


class RelationExtractionDataset(Dataset):
    """Dataset whose records carry spans (``entities``) and ``relations``."""

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Build the type registries and the example, span and relation makers.

        The ``column_data`` keyword option selects the data column; a missing
        or falsy value falls back to the default text column.
        """
        super().__init__(reader, project, **kwargs)
        self.span_types = LabelTypes(SpanType)
        self.relation_types = LabelTypes(RelationType)

        data_column = kwargs.get("column_data") or DEFAULT_TEXT_COLUMN
        self.example_maker = ExampleMaker(
            project=project,
            data_class=TextData,
            column_data=data_column,
            exclude_columns=["entities", "relations"],
        )
        self.span_maker = LabelMaker(label_class=SpanLabel, column="entities")
        self.relation_maker = LabelMaker(label_class=RelationLabel, column="relations")

    def save(self, user: User, batch_size: int = 1000):
        """Store examples, label types, spans and relations for every batch."""
        target = self.project
        for batch in self.reader.batch(batch_size):
            # Examples are created first.
            Example.objects.bulk_create(self.example_maker.make(batch))

            # Span types, then relation types.
            batch_spans = Spans(self.span_maker.make(batch), self.span_types)
            batch_spans.clean(target)
            batch_spans.save_types(target)

            batch_relations = Relations(self.relation_maker.make(batch), self.relation_types)
            batch_relations.clean(target)
            batch_relations.save_types(target)

            # Spans are saved first; the relations' save receives the spans collection.
            batch_spans.save(user)
            batch_relations.save(user, spans=batch_spans)

    @property
    def errors(self) -> List[FileParseException]:
        """Reader, example-maker, span-maker and relation-maker errors, in order."""
        collected = self.reader.errors + self.example_maker.errors
        collected = collected + self.span_maker.errors
        return collected + self.relation_maker.errors


class CategoryAndSpanDataset(Dataset):
    """Dataset whose records carry categories (``cats``) and spans (``entities``)."""

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Build the type registries and the example, category and span makers.

        The ``column_data`` keyword option selects the data column; a missing
        or falsy value falls back to the default text column.
        """
        super().__init__(reader, project, **kwargs)
        self.category_types = LabelTypes(CategoryType)
        self.span_types = LabelTypes(SpanType)

        data_column = kwargs.get("column_data") or DEFAULT_TEXT_COLUMN
        self.example_maker = ExampleMaker(
            project=project,
            data_class=TextData,
            column_data=data_column,
            exclude_columns=["cats", "entities"],
        )
        self.category_maker = LabelMaker(label_class=CategoryLabel, column="cats")
        self.span_maker = LabelMaker(label_class=SpanLabel, column="entities")

    def save(self, user: User, batch_size: int = 1000):
        """Store examples, label types, categories and spans for every batch."""
        target = self.project
        for batch in self.reader.batch(batch_size):
            # Examples are created first.
            Example.objects.bulk_create(self.example_maker.make(batch))

            # Category types, then span types.
            batch_categories = Categories(self.category_maker.make(batch), self.category_types)
            batch_categories.clean(target)
            batch_categories.save_types(target)

            batch_spans = Spans(self.span_maker.make(batch), self.span_types)
            batch_spans.clean(target)
            batch_spans.save_types(target)

            # Labels are saved only after both sets of types are stored.
            batch_categories.save(user)
            batch_spans.save(user)

    @property
    def errors(self) -> List[FileParseException]:
        """Reader, example-maker, category-maker and span-maker errors, in order."""
        collected = self.reader.errors + self.example_maker.errors
        collected = collected + self.category_maker.errors
        return collected + self.span_maker.errors


def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
    """Pick the dataset class for an import.

    Args:
        project: project being imported into; its ``project_type`` is used
            as the task when ``task`` is not a known key.
        task: requested task name.
        file_format: format of the uploaded files.

    Returns:
        ``PlainDataset`` when the project is a text project and the format
        is plain text, otherwise the class registered for the task.

    Raises:
        KeyError: if neither ``task`` nor the project's type is registered
            and the plain-text shortcut does not apply.
    """
    dataset_by_task = {
        DOCUMENT_CLASSIFICATION: TextClassificationDataset,
        SEQUENCE_LABELING: SequenceLabelingDataset,
        RELATION_EXTRACTION: RelationExtractionDataset,
        SEQ2SEQ: Seq2seqDataset,
        INTENT_DETECTION_AND_SLOT_FILLING: CategoryAndSpanDataset,
        IMAGE_CLASSIFICATION: BinaryDataset,
        SPEECH2TEXT: BinaryDataset,
    }
    # Unknown task names fall back to the project's own type.
    resolved_task = task if task in dataset_by_task else project.project_type
    wants_plain = project.is_text_project and file_format.is_plain_text()
    return PlainDataset if wants_plain else dataset_by_task[resolved_task]


def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
    """Build a ready-to-save dataset for the given files.

    Args:
        task: requested task name, forwarded to ``select_dataset``.
        file_format: format used both to create the parser and to choose
            the dataset class.
        data_files: files handed to the reader.
        project: project the dataset is created for.
        **kwargs: forwarded to both the parser factory and the dataset
            constructor.

    Returns:
        An instance of the selected dataset class wrapping a new reader.
    """
    reader = Reader(data_files, create_parser(file_format, **kwargs))
    chosen_class = select_dataset(project, task, file_format)
    return chosen_class(reader, project, **kwargs)