from typing import Dict, List, Type

from django.db.models import QuerySet

from . import catalog, formatters, labels, writers
from .labels import Labels
from data_export.models import ExportedExample
from projects.models import (
    DOCUMENT_CLASSIFICATION,
    IMAGE_CLASSIFICATION,
    INTENT_DETECTION_AND_SLOT_FILLING,
    SEQ2SEQ,
    SEQUENCE_LABELING,
    SPEECH2TEXT,
)


def select_writer(file_format: str) -> Type[writers.Writer]:
    """Return the writer class used to serialize an export in ``file_format``.

    Args:
        file_format: Export format name; matched against
            ``catalog.CSV.name``, ``catalog.JSON.name`` and
            ``catalog.JSONL.name``.

    Returns:
        The ``writers.Writer`` subclass registered for that format.

    Raises:
        ValueError: If no writer is registered for ``file_format``.
    """
    mapping = {
        catalog.CSV.name: writers.CsvWriter,
        catalog.JSON.name: writers.JsonWriter,
        catalog.JSONL.name: writers.JsonlWriter,
        # FastText export is currently disabled (writer not registered):
        # catalog.FastText.name: writers.FastTextWriter,
    }
    if file_format not in mapping:
        # Previously the exception was constructed but never raised, so callers
        # got a bare KeyError from the lookup below instead of this message.
        raise ValueError(f"Invalid format: {file_format}")
    return mapping[file_format]


def select_formatter(project, file_format: str) -> List[Type[formatters.Formatter]]:
    """Return the formatter classes to apply when exporting ``project``.

    The result is looked up first by ``project.project_type`` and then by
    ``file_format``. The list lengths mirror those of
    ``select_label_collection`` for the same project type (for example, two
    entries for sequence labeling with relations). This suggests the entries
    pair up positionally with the label collections; confirm against the
    caller that combines them.

    Args:
        project: Project object. ``project_type`` is read, and
            ``use_relation`` is read when present (treated as ``False``
            otherwise).
        file_format: Export format name, matching a ``catalog.<FORMAT>.name``.

    Returns:
        A freshly built list of ``formatters.Formatter`` subclasses.

    Raises:
        KeyError: If the project type, or the format for that project type,
            is not supported.
    """
    if getattr(project, "use_relation", False):
        span_formatters = [formatters.DictFormatter, formatters.DictFormatter]
    else:
        span_formatters = [formatters.TupledSpanFormatter]

    by_project_type: Dict[str, Dict[str, List[Type[formatters.Formatter]]]] = {
        DOCUMENT_CLASSIFICATION: {
            catalog.CSV.name: [formatters.JoinedCategoryFormatter],
            catalog.JSON.name: [formatters.ListedCategoryFormatter],
            catalog.JSONL.name: [formatters.ListedCategoryFormatter],
        },
        SEQUENCE_LABELING: {
            catalog.JSONL.name: span_formatters,
        },
        SEQ2SEQ: {
            catalog.CSV.name: [formatters.JoinedCategoryFormatter],
            catalog.JSON.name: [formatters.ListedCategoryFormatter],
            catalog.JSONL.name: [formatters.ListedCategoryFormatter],
        },
        IMAGE_CLASSIFICATION: {
            catalog.JSONL.name: [formatters.ListedCategoryFormatter],
        },
        SPEECH2TEXT: {
            catalog.JSONL.name: [formatters.ListedCategoryFormatter],
        },
        INTENT_DETECTION_AND_SLOT_FILLING: {
            catalog.JSONL.name: [
                formatters.ListedCategoryFormatter,
                formatters.TupledSpanFormatter,
            ],
        },
    }
    formats_for_type = by_project_type[project.project_type]
    return formats_for_type[file_format]


def select_label_collection(project):
    """Return the label collection classes to export for ``project``.

    Args:
        project: Project object. ``project_type`` is read, and
            ``use_relation`` is read when present (treated as ``False``
            otherwise).

    Returns:
        A freshly built list of label collection classes from ``labels``.

    Raises:
        KeyError: If ``project.project_type`` is not supported.
    """
    if getattr(project, "use_relation", False):
        span_collections = [labels.Spans, labels.Relations]
    else:
        span_collections = [labels.Spans]

    collections_by_type = {
        DOCUMENT_CLASSIFICATION: [labels.Categories],
        SEQUENCE_LABELING: span_collections,
        SEQ2SEQ: [labels.Texts],
        IMAGE_CLASSIFICATION: [labels.Categories],
        SPEECH2TEXT: [labels.Texts],
        INTENT_DETECTION_AND_SLOT_FILLING: [labels.Categories, labels.Spans],
    }
    project_type = project.project_type
    return collections_by_type[project_type]


def create_labels(label_collection_class: Type[Labels], examples: QuerySet[ExportedExample], user=None):
    """Instantiate ``label_collection_class`` over ``examples``.

    Args:
        label_collection_class: The ``Labels`` class to construct.
        examples: Queryset of ``ExportedExample`` rows, passed as the first
            positional argument.
        user: Passed unchanged as the second positional argument; defaults
            to ``None``.

    Returns:
        The newly constructed label collection instance.
    """
    collection = label_collection_class(examples, user)
    return collection