import os
import pathlib
import shutil

from django.core.files import File
from django.test import TestCase, override_settings
from django_drf_filepond.models import StoredUpload, TemporaryUpload
from django_drf_filepond.utils import _get_file_id

from data_import.celery_tasks import import_dataset
from examples.models import Example
from label_types.models import CategoryType, SpanType
from labels.models import Category, Span
from projects.models import (
    DOCUMENT_CLASSIFICATION,
    IMAGE_CLASSIFICATION,
    INTENT_DETECTION_AND_SLOT_FILLING,
    SEQ2SEQ,
    SEQUENCE_LABELING,
)
from projects.tests.utils import prepare_project


@override_settings(MEDIA_ROOT=os.path.join(os.path.dirname(__file__), "data"))
class TestImportData(TestCase):
    """Shared fixture and helper for the dataset-import tests.

    ``MEDIA_ROOT`` is pointed at the ``data`` directory next to this module.
    Subclasses override ``task`` to choose the project type that ``setUp``
    creates through ``prepare_project``.
    """

    task = "Any"
    annotation_class = Category

    def setUp(self):
        """Create a project for ``self.task`` and a fresh upload id."""
        self.project = prepare_project(self.task)
        self.user = self.project.admin
        self.data_path = pathlib.Path(__file__).parent / "data"
        self.upload_id = _get_file_id()

    def tearDown(self):
        """Remove the directory of this test's stored upload, if there is one."""
        # Only the lookup is expected to raise DoesNotExist, so keep the
        # try body limited to it.
        try:
            stored = StoredUpload.objects.get(upload_id=self.upload_id)
        except StoredUpload.DoesNotExist:
            pass
        else:
            directory = pathlib.Path(stored.get_absolute_file_path()).parent
            shutil.rmtree(directory)

    def import_dataset(self, filename, file_format, kwargs=None):
        """Register ``filename`` as a temporary upload and run the import task.

        Args:
            filename: Fixture path relative to ``self.data_path``, using
                ``/`` as the separator.
            file_format: Format name handed to the ``import_dataset`` task.
            kwargs: Optional dict of extra keyword arguments for the task.

        Returns:
            Whatever the ``import_dataset`` task returns.
        """
        file_path = self.data_path / filename
        # The file content is written to the upload storage while the model
        # is saved, so the fixture handle can be closed right after create()
        # instead of being leaked for the rest of the test run.
        with open(file_path, mode="rb") as fixture:
            TemporaryUpload.objects.create(
                upload_id=self.upload_id,
                file_id="1",
                file=File(fixture, filename.split("/")[-1]),
                upload_name=filename,
                upload_type="F",
            )
        upload_ids = [self.upload_id]
        kwargs = kwargs or {}
        return import_dataset(self.user.id, self.project.item.id, file_format, upload_ids, **kwargs)


@override_settings(MAX_UPLOAD_SIZE=0)
class TestMaxFileSize(TestImportData):
    """Import with ``MAX_UPLOAD_SIZE`` overridden to 0."""

    task = DOCUMENT_CLASSIFICATION

    def test_jsonl(self):
        # Exactly one error is expected, and it must mention the size limit.
        response = self.import_dataset(
            "text_classification/example.jsonl",
            "JSONL",
            {"column_label": "labels"},
        )
        errors = response["error"]
        self.assertEqual(len(errors), 1)
        self.assertIn("maximum file size", errors[0]["message"])


class TestImportClassificationData(TestImportData):
    """Import tests for document-classification projects.

    Exercises the JSONL, CSV, fastText, Excel, JSON, TextFile and TextLine
    formats, plus imports where the declared format does not match the file.
    """

    task = DOCUMENT_CLASSIFICATION

    def assert_examples(self, dataset):
        """Compare stored examples with ``dataset``.

        ``dataset`` is a list of ``(text, labels)`` pairs; labels are
        compared as sets, so their order does not matter.
        """
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            categories = Example.objects.get(text=text).categories.all()
            actual = {category.label.text for category in categories}
            self.assertEqual(actual, set(expected_labels))

    def assert_parse_error(self, response):
        """Require at least one error and no examples, category types or categories."""
        self.assertGreaterEqual(len(response["error"]), 1)
        for model in (Example, CategoryType, Category):
            self.assertEqual(model.objects.count(), 0)

    def test_jsonl(self):
        self.import_dataset(
            "text_classification/example.jsonl",
            "JSONL",
            {"column_label": "labels"},
        )
        self.assert_examples(
            [
                ("exampleA", ["positive"]),
                ("exampleB", ["positive", "negative"]),
                ("exampleC", []),
            ]
        )

    def test_csv(self):
        self.import_dataset("text_classification/example.csv", "CSV")
        self.assert_examples([("exampleA", ["positive"]), ("exampleB", [])])

    def test_csv_out_of_order_columns(self):
        self.import_dataset("text_classification/example_out_of_order_columns.csv", "CSV")
        self.assert_examples([("exampleA", ["positive"]), ("exampleB", [])])

    def test_fasttext(self):
        self.import_dataset("text_classification/example_fasttext.txt", "fastText")
        self.assert_examples(
            [
                ("exampleA", ["positive"]),
                ("exampleB", ["positive", "negative"]),
                ("exampleC", []),
            ]
        )

    def test_excel(self):
        self.import_dataset("text_classification/example.xlsx", "Excel")
        self.assert_examples([("exampleA", ["positive"]), ("exampleB", [])])

    def test_json(self):
        self.import_dataset("text_classification/example.json", "JSON")
        self.assert_examples(
            [
                ("exampleA", ["positive"]),
                ("exampleB", ["positive", "negative"]),
                ("exampleC", []),
            ]
        )

    def test_textfile(self):
        # TextFile: the whole file is expected as a single example.
        self.import_dataset("example.txt", "TextFile")
        self.assert_examples([("exampleA\nexampleB\n\nexampleC\n", [])])

    def test_textline(self):
        # TextLine: three examples are expected from the same fixture.
        self.import_dataset("example.txt", "TextLine")
        self.assert_examples([("exampleA", []), ("exampleB", []), ("exampleC", [])])

    def test_wrong_jsonl(self):
        # A JSON fixture declared as JSONL.
        result = self.import_dataset("text_classification/example.json", "JSONL")
        self.assert_parse_error(result)

    def test_wrong_json(self):
        # A JSONL fixture declared as JSON.
        result = self.import_dataset("text_classification/example.jsonl", "JSON")
        self.assert_parse_error(result)

    def test_wrong_excel(self):
        # A JSONL fixture declared as Excel.
        result = self.import_dataset("text_classification/example.jsonl", "Excel")
        self.assert_parse_error(result)

    def test_wrong_csv(self):
        # A JSONL fixture declared as CSV.
        result = self.import_dataset("text_classification/example.jsonl", "CSV")
        self.assert_parse_error(result)


class TestImportSequenceLabelingData(TestImportData):
    """Import tests for sequence-labeling projects (JSONL and CoNLL)."""

    task = SEQUENCE_LABELING

    def assert_examples(self, dataset):
        """Compare stored examples with ``dataset``.

        ``dataset`` is a list of ``(text, spans)`` pairs, each span being
        ``[start_offset, end_offset, label_text]``. Spans are compared as
        lists, in the order the queryset yields them.
        """
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            spans = Example.objects.get(text=text).spans.all()
            actual = [[s.start_offset, s.end_offset, s.label.text] for s in spans]
            self.assertEqual(actual, expected_labels)

    def assert_parse_error(self, response):
        """Require at least one error and no examples, span types or spans."""
        self.assertGreaterEqual(len(response["error"]), 1)
        for model in (Example, SpanType, Span):
            self.assertEqual(model.objects.count(), 0)

    def test_jsonl(self):
        self.import_dataset("sequence_labeling/example.jsonl", "JSONL")
        self.assert_examples([("exampleA", [[0, 1, "LOC"]]), ("exampleB", [])])

    def test_conll(self):
        self.import_dataset("sequence_labeling/example.conll", "CoNLL")
        self.assert_examples(
            [
                ("JAPAN GET", [[0, 5, "LOC"]]),
                ("Nadim Ladki", [[0, 11, "PER"]]),
            ]
        )

    def test_wrong_conll(self):
        # A JSONL fixture declared as CoNLL.
        result = self.import_dataset("sequence_labeling/example.jsonl", "CoNLL")
        self.assert_parse_error(result)

    def test_jsonl_with_overlapping(self):
        # The overlapping-span fixture is expected to yield exactly one error.
        result = self.import_dataset("sequence_labeling/example_overlapping.jsonl", "JSONL")
        self.assertEqual(len(result["error"]), 1)


class TestImportSeq2seqData(TestImportData):
    """Import tests for seq2seq projects (JSONL, JSON and CSV)."""

    task = SEQ2SEQ

    def assert_examples(self, dataset):
        """Compare stored examples with ``dataset``.

        ``dataset`` is a list of ``(text, labels)`` pairs; text labels are
        compared as sets, so their order does not matter.
        """
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            text_labels = Example.objects.get(text=text).texts.all()
            actual = {item.text for item in text_labels}
            self.assertEqual(actual, set(expected_labels))

    def test_jsonl(self):
        self.import_dataset("seq2seq/example.jsonl", "JSONL")
        self.assert_examples([("exampleA", ["label1"]), ("exampleB", [])])

    def test_json(self):
        self.import_dataset("seq2seq/example.json", "JSON")
        self.assert_examples([("exampleA", ["label1"]), ("exampleB", [])])

    def test_csv(self):
        self.import_dataset("seq2seq/example.csv", "CSV")
        self.assert_examples([("exampleA", ["label1"]), ("exampleB", [])])


class TestImportIntentDetectionAndSlotFillingData(TestImportData):
    """Import tests for intent-detection-and-slot-filling projects."""

    task = INTENT_DETECTION_AND_SLOT_FILLING

    def assert_examples(self, dataset):
        """Compare stored examples with ``dataset``.

        ``dataset`` is a list of ``(text, expected)`` pairs where
        ``expected`` has a ``"cats"`` list (compared as a set) and an
        ``"entities"`` list of ``(start_offset, end_offset, label_text)``
        tuples (compared as a list).
        """
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            stored = Example.objects.get(text=text)
            actual_cats = {category.label.text for category in stored.categories.all()}
            actual_entities = [(s.start_offset, s.end_offset, s.label.text) for s in stored.spans.all()]
            self.assertEqual(actual_cats, set(expected_labels["cats"]))
            self.assertEqual(actual_entities, expected_labels["entities"])

    def test_entities_and_cats(self):
        # The four rows cover every combination of having or lacking
        # categories and entities.
        self.import_dataset("intent/example.jsonl", "JSONL")
        self.assert_examples(
            [
                ("exampleA", {"cats": ["positive"], "entities": [(0, 1, "LOC")]}),
                ("exampleB", {"cats": ["positive"], "entities": []}),
                ("exampleC", {"cats": [], "entities": [(0, 1, "LOC")]}),
                ("exampleD", {"cats": [], "entities": []}),
            ]
        )


class TestImportImageClassificationData(TestImportData):
    """Import test for image-classification projects."""

    task = IMAGE_CLASSIFICATION

    def test_example(self):
        # Importing one image file should create exactly one example.
        self.import_dataset("images/1500x500.jpeg", "ImageFile")
        example_count = Example.objects.count()
        self.assertEqual(example_count, 1)


@override_settings(ENABLE_FILE_TYPE_CHECK=True)
class TestFileTypeChecking(TestImportData):
    """Import with ``ENABLE_FILE_TYPE_CHECK`` overridden to ``True``."""

    task = IMAGE_CLASSIFICATION

    def test_example(self):
        # An .ico fixture imported as ImageFile is expected to yield exactly
        # one error whose message contains "unexpected".
        result = self.import_dataset("images/example.ico", "ImageFile")
        errors = result["error"]
        self.assertEqual(len(errors), 1)
        self.assertIn("unexpected", errors[0]["message"])