import os
import shutil
import tempfile
import unittest

from ...views.upload.data import TextData
from ...views.upload.dataset import Dataset
from ...views.upload.label import Label


class TestDataset(unittest.TestCase):
    """Check how ``Dataset.load`` copes with differently encoded text files."""

    def setUp(self):
        """Create a scratch directory and the fixture path/content."""
        self.test_dir = tempfile.mkdtemp()
        self.test_file = os.path.join(self.test_dir, 'test_file.txt')
        # Non-ASCII content so that UTF-8 and Shift_JIS produce different bytes.
        self.content = 'こんにちは、世界!'

    def tearDown(self):
        """Remove the scratch directory together with any fixture file."""
        shutil.rmtree(self.test_dir)

    def create_file(self, encoding='utf-8'):
        """Write ``self.content`` to ``self.test_file``.

        Args:
            encoding: Codec used to write the fixture. Defaults to UTF-8
                explicitly; relying on ``open()``'s locale-dependent default
                would make the on-disk bytes vary between platforms (and the
                write itself can fail on locales that cannot encode the
                Japanese content).
        """
        with open(self.test_file, 'w', encoding=encoding) as f:
            f.write(self.content)

    def test_can_load_utf8(self):
        """A UTF-8 file loads without an explicit ``encoding`` argument."""
        self.create_file()
        dataset = Dataset(filenames=[], label_class=Label, data_class=TextData)
        record = next(dataset.load(self.test_file))
        self.assertEqual(record.data['filename'], self.test_file)

    def test_cannot_load_shiftjis_without_specifying_encoding(self):
        """A Shift_JIS file raises ``UnicodeDecodeError`` with default settings."""
        # NOTE(review): this presumes Dataset decodes as UTF-8 when no
        # encoding is given -- confirm against the Dataset implementation.
        self.create_file('shift_jis')
        dataset = Dataset(filenames=[], label_class=Label, data_class=TextData)
        with self.assertRaises(UnicodeDecodeError):
            next(dataset.load(self.test_file))

    def test_can_load_shiftjis_with_specifying_encoding(self):
        """A Shift_JIS file loads when ``encoding='shift_jis'`` is passed."""
        self.create_file('shift_jis')
        dataset = Dataset(
            filenames=[],
            label_class=Label,
            data_class=TextData,
            encoding='shift_jis',
        )
        record = next(dataset.load(self.test_file))
        self.assertEqual(record.data['filename'], self.test_file)