diff --git a/app/api/tests/upload/__init__.py b/app/api/tests/upload/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/app/api/tests/upload/test_csv.py b/app/api/tests/upload/test_csv.py
new file mode 100644
index 00000000..e3f65fd9
--- /dev/null
+++ b/app/api/tests/upload/test_csv.py
@@ -0,0 +1,61 @@
+import os
+import shutil
+import tempfile
+import unittest
+
+from ...views.upload.dataset import CsvDataset
+
+
+class TestCsvDataset(unittest.TestCase):
+    """Checks the first record CsvDataset.load yields for small CSV files."""
+
+    def setUp(self):
+        self.test_dir = tempfile.mkdtemp()
+        self.test_file = os.path.join(self.test_dir, 'test_file.csv')
+
+    def tearDown(self):
+        shutil.rmtree(self.test_dir)
+
+    def create_file(self, content):
+        """Write ``content`` to the temporary CSV file."""
+        with open(self.test_file, 'w') as f:
+            f.write(content)
+
+    def assert_record(self, content, dataset, data='Text', label=None):
+        """Load ``content`` and compare the first record with the expectation.
+
+        ``label=None`` stands for the default expectation ``['Label']``.
+        """
+        if label is None:
+            label = ['Label']
+        self.create_file(content)
+        record = next(dataset.load(self.test_file))
+        self.assertEqual(record.data, data)
+        self.assertEqual(record.label, label)
+
+    def test_can_load_default_column_names(self):
+        content = 'label,text\nLabel,Text'
+        dataset = CsvDataset(filenames=[])
+        self.assert_record(content, dataset)
+
+    def test_can_change_delimiter(self):
+        content = 'label\ttext\nLabel\tText'
+        dataset = CsvDataset(filenames=[], delimiter='\t')
+        self.assert_record(content, dataset)
+
+    def test_can_specify_column_name(self):
+        content = 'star,body\nLabel,Text'
+        dataset = CsvDataset(filenames=[], column_data='body', column_label='star')
+        self.assert_record(content, dataset)
+
+    def test_can_load_only_text_column(self):
+        # No column named 'label' here, so an empty label list is expected.
+        content = 'star,text\nLabel,Text'
+        dataset = CsvDataset(filenames=[])
+        self.assert_record(content, dataset, label=[])
+
+    def test_does_not_match_column_and_row(self):
+        # The row has fewer fields than the header; the label should be empty.
+        content = 'text,label\nText'
+        dataset = CsvDataset(filenames=[])
+        self.assert_record(content, dataset, label=[])
diff --git a/app/api/tests/upload/test_dataset.py b/app/api/tests/upload/test_dataset.py
new file mode 100644
index 00000000..662114b7
--- /dev/null
+++ b/app/api/tests/upload/test_dataset.py
@@ -0,0 +1,46 @@
+import os
+import shutil
+import tempfile
+import unittest
+
+from ...views.upload.dataset import Dataset
+
+
+class TestDataset(unittest.TestCase):
+    """Checks how Dataset.load copes with differently encoded text files."""
+
+    def setUp(self):
+        self.test_dir = tempfile.mkdtemp()
+        self.test_file = os.path.join(self.test_dir, 'test_file.txt')
+        self.content = 'こんにちは、世界!'
+
+    def tearDown(self):
+        shutil.rmtree(self.test_dir)
+
+    def create_file(self, encoding='utf-8'):
+        """Write ``self.content`` to the temporary file using ``encoding``.
+
+        The default is pinned to UTF-8: ``encoding=None`` would make ``open``
+        fall back to the locale's preferred encoding, so the UTF-8 test would
+        write different bytes depending on the machine it runs on.
+        """
+        with open(self.test_file, 'w', encoding=encoding) as f:
+            f.write(self.content)
+
+    def test_can_load_utf8(self):
+        self.create_file()
+        dataset = Dataset(filenames=[])
+        record = next(dataset.load(self.test_file))
+        self.assertEqual(record.filename, self.test_file)
+
+    def test_cannot_load_shiftjis_without_specifying_encoding(self):
+        self.create_file('shift_jis')
+        dataset = Dataset(filenames=[])
+        with self.assertRaises(UnicodeDecodeError):
+            next(dataset.load(self.test_file))
+
+    def test_can_load_shiftjis_with_specifying_encoding(self):
+        self.create_file('shift_jis')
+        dataset = Dataset(filenames=[], encoding='shift_jis')
+        record = next(dataset.load(self.test_file))
+        self.assertEqual(record.filename, self.test_file)
diff --git a/app/api/tests/upload/test_fasttext.py b/app/api/tests/upload/test_fasttext.py
new file mode 100644
index 00000000..84b2176e
--- /dev/null
+++ b/app/api/tests/upload/test_fasttext.py
@@ -0,0 +1,41 @@
+import os
+import shutil
+import tempfile
+import unittest
+
+from ...views.upload.dataset import FastTextDataset
+
+
+class TestFastTextDataset(unittest.TestCase):
+    """Checks the record FastTextDataset.load yields for a labelled line."""
+
+    def setUp(self):
+        self.test_dir = tempfile.mkdtemp()
+        self.test_file = os.path.join(self.test_dir, 'test_file.txt')
+
+    def tearDown(self):
+        shutil.rmtree(self.test_dir)
+
+    def create_file(self, content):
+        """Write ``content`` to the temporary text file."""
+        with open(self.test_file, 'w') as f:
+            f.write(content)
+
+    def assert_record(self, content, dataset, data='Text', label=None):
+        """Load ``content`` and compare the first record with the expectation.
+
+        ``label=None`` stands for the default expectation ``['Label']``.
+        """
+        if label is None:
+            label = ['Label']
+        self.create_file(content)
+        record = next(dataset.load(self.test_file))
+        self.assertEqual(record.data, data)
+        self.assertEqual(record.label, label)
+
+    def test_can_load_default_column_names(self):
+        # Each '__label__'-prefixed token should surface as a label with the
+        # prefix removed; the remaining token is the text.
+        content = '__label__sauce __label__cheese Text'
+        dataset = FastTextDataset(filenames=[])
+        self.assert_record(content, dataset, label=['sauce', 'cheese'])