From 5b67cd3322de4744d1d4486c6453b34f4f9b0bc3 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Tue, 6 Apr 2021 10:07:25 +0900 Subject: [PATCH] Update test cases --- app/api/tests/upload/test_csv.py | 13 +++++++------ app/api/tests/upload/test_dataset.py | 7 ++++--- app/api/tests/upload/test_fasttext.py | 8 +++++--- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/app/api/tests/upload/test_csv.py b/app/api/tests/upload/test_csv.py index e3f65fd9..53ab82e7 100644 --- a/app/api/tests/upload/test_csv.py +++ b/app/api/tests/upload/test_csv.py @@ -4,6 +4,7 @@ import tempfile import unittest from ...views.upload.dataset import CsvDataset +from ...views.upload.label import CategoryLabel class TestCsvDataset(unittest.TestCase): @@ -21,7 +22,7 @@ class TestCsvDataset(unittest.TestCase): def assert_record(self, content, dataset, data='Text', label=None): if label is None: - label = ['Label'] + label = [CategoryLabel(label='Label')] self.create_file(content) record = next(dataset.load(self.test_file)) self.assertEqual(record.data, data) @@ -29,25 +30,25 @@ class TestCsvDataset(unittest.TestCase): def test_can_load_default_column_names(self): content = 'label,text\nLabel,Text' - dataset = CsvDataset(filenames=[]) + dataset = CsvDataset(filenames=[], label_class=CategoryLabel) self.assert_record(content, dataset) def test_can_change_delimiter(self): content = 'label\ttext\nLabel\tText' - dataset = CsvDataset(filenames=[], delimiter='\t') + dataset = CsvDataset(filenames=[], label_class=CategoryLabel, delimiter='\t') self.assert_record(content, dataset) def test_can_specify_column_name(self): content = 'star,body\nLabel,Text' - dataset = CsvDataset(filenames=[], column_data='body', column_label='star') + dataset = CsvDataset(filenames=[], label_class=CategoryLabel, column_data='body', column_label='star') self.assert_record(content, dataset) def test_can_load_only_text_column(self): content = 'star,text\nLabel,Text' - dataset = CsvDataset(filenames=[]) + dataset = CsvDataset(filenames=[], label_class=CategoryLabel) self.assert_record(content, dataset, label=[]) def test_does_not_match_column_and_row(self): content = 'text,label\nText' - dataset = CsvDataset(filenames=[]) + dataset = CsvDataset(filenames=[], label_class=CategoryLabel) self.assert_record(content, dataset, label=[]) diff --git a/app/api/tests/upload/test_dataset.py b/app/api/tests/upload/test_dataset.py index 662114b7..46f78e1e 100644 --- a/app/api/tests/upload/test_dataset.py +++ b/app/api/tests/upload/test_dataset.py @@ -4,6 +4,7 @@ import tempfile import unittest from ...views.upload.dataset import Dataset +from ...views.upload.label import Label class TestDataset(unittest.TestCase): @@ -22,18 +23,18 @@ class TestDataset(unittest.TestCase): def test_can_load_utf8(self): self.create_file() - dataset = Dataset(filenames=[]) + dataset = Dataset(filenames=[], label_class=Label) record = next(dataset.load(self.test_file)) self.assertEqual(record.filename, self.test_file) def test_cannot_load_shiftjis_without_specifying_encoding(self): self.create_file('shift_jis') - dataset = Dataset(filenames=[]) + dataset = Dataset(filenames=[], label_class=Label) with self.assertRaises(UnicodeDecodeError): next(dataset.load(self.test_file)) def test_can_load_shiftjis_with_specifying_encoding(self): self.create_file('shift_jis') - dataset = Dataset(filenames=[], encoding='shift_jis') + dataset = Dataset(filenames=[], label_class=Label, encoding='shift_jis') record = next(dataset.load(self.test_file)) self.assertEqual(record.filename, self.test_file) diff --git a/app/api/tests/upload/test_fasttext.py b/app/api/tests/upload/test_fasttext.py index 84b2176e..766e69ad 100644 --- a/app/api/tests/upload/test_fasttext.py +++ b/app/api/tests/upload/test_fasttext.py @@ -4,6 +4,7 @@ import tempfile import unittest from ...views.upload.dataset import FastTextDataset +from ...views.upload.label import CategoryLabel class TestFastTextDataset(unittest.TestCase): @@ -21,7 +22,7 @@ class TestFastTextDataset(unittest.TestCase): def assert_record(self, content, dataset, data='Text', label=None): if label is None: - label = ['Label'] + label = [CategoryLabel(label='Label')] self.create_file(content) record = next(dataset.load(self.test_file)) self.assertEqual(record.data, data) @@ -29,5 +30,6 @@ class TestFastTextDataset(unittest.TestCase): def test_can_load_default_column_names(self): content = '__label__sauce __label__cheese Text' - dataset = FastTextDataset(filenames=[]) - self.assert_record(content, dataset, label=['sauce', 'cheese']) + dataset = FastTextDataset(filenames=[], label_class=CategoryLabel) + label = [CategoryLabel(label='sauce'), CategoryLabel(label='cheese')] + self.assert_record(content, dataset, label=label)