import os
import shutil
import tempfile
import unittest

from ...views.upload.data import TextData
from ...views.upload.dataset import FastTextDataset
from ...views.upload.label import CategoryLabel


class TestFastTextDataset(unittest.TestCase):
    """Tests for reading fastText-style lines through ``FastTextDataset``."""

    def setUp(self):
        # Each test works inside its own scratch directory; tearDown deletes it.
        scratch_dir = tempfile.mkdtemp()
        self.test_dir = scratch_dir
        self.test_file = os.path.join(scratch_dir, 'test_file.txt')

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def create_file(self, content):
        """Overwrite ``self.test_file`` with ``content``."""
        with open(self.test_file, 'w') as handle:
            handle.write(content)

    def assert_record(self, content, dataset, data='Text', label=None):
        """Load ``content`` with ``dataset`` and check the first record it yields.

        Args:
            content: Raw file contents written to ``self.test_file``.
            dataset: Object whose ``load(path)`` returns an iterator of records;
                only the first record is inspected.
            data: Expected value of ``record.data['text']``.
            label: Expected value of ``record.label``. ``None`` means
                ``[{'text': 'Label'}]``.
        """
        expected_label = [{'text': 'Label'}] if label is None else label
        self.create_file(content)
        first_record = next(dataset.load(self.test_file))
        self.assertEqual(first_record.data['text'], data)
        self.assertEqual(first_record.label, expected_label)

    def test_can_load_default_column_names(self):
        # Two ``__label__`` prefixes followed by the text body.
        line = '__label__sauce __label__cheese Text'
        expected_labels = [{'text': 'sauce'}, {'text': 'cheese'}]
        dataset = FastTextDataset(filenames=[], label_class=CategoryLabel, data_class=TextData)
        self.assert_record(line, dataset, label=expected_labels)