|
|
@ -4,6 +4,7 @@ import tempfile |
|
|
|
import unittest |
|
|
|
|
|
|
|
from ...views.upload.dataset import CsvDataset |
|
|
|
from ...views.upload.label import CategoryLabel |
|
|
|
|
|
|
|
|
|
|
|
class TestCsvDataset(unittest.TestCase): |
|
|
@ -21,7 +22,7 @@ class TestCsvDataset(unittest.TestCase): |
|
|
|
|
|
|
|
def assert_record(self, content, dataset, data='Text', label=None): |
|
|
|
if label is None: |
|
|
|
label = ['Label'] |
|
|
|
label = [CategoryLabel(label='Label')] |
|
|
|
self.create_file(content) |
|
|
|
record = next(dataset.load(self.test_file)) |
|
|
|
self.assertEqual(record.data, data) |
|
|
@ -29,25 +30,25 @@ class TestCsvDataset(unittest.TestCase): |
|
|
|
|
|
|
|
def test_can_load_default_column_names(self): |
|
|
|
content = 'label,text\nLabel,Text' |
|
|
|
dataset = CsvDataset(filenames=[]) |
|
|
|
dataset = CsvDataset(filenames=[], label_class=CategoryLabel) |
|
|
|
self.assert_record(content, dataset) |
|
|
|
|
|
|
|
def test_can_change_delimiter(self): |
|
|
|
content = 'label\ttext\nLabel\tText' |
|
|
|
dataset = CsvDataset(filenames=[], delimiter='\t') |
|
|
|
dataset = CsvDataset(filenames=[], label_class=CategoryLabel, delimiter='\t') |
|
|
|
self.assert_record(content, dataset) |
|
|
|
|
|
|
|
def test_can_specify_column_name(self): |
|
|
|
content = 'star,body\nLabel,Text' |
|
|
|
dataset = CsvDataset(filenames=[], column_data='body', column_label='star') |
|
|
|
dataset = CsvDataset(filenames=[], label_class=CategoryLabel, column_data='body', column_label='star') |
|
|
|
self.assert_record(content, dataset) |
|
|
|
|
|
|
|
def test_can_load_only_text_column(self): |
|
|
|
content = 'star,text\nLabel,Text' |
|
|
|
dataset = CsvDataset(filenames=[]) |
|
|
|
dataset = CsvDataset(filenames=[], label_class=CategoryLabel) |
|
|
|
self.assert_record(content, dataset, label=[]) |
|
|
|
|
|
|
|
def test_does_not_match_column_and_row(self): |
|
|
|
content = 'text,label\nText' |
|
|
|
dataset = CsvDataset(filenames=[]) |
|
|
|
dataset = CsvDataset(filenames=[], label_class=CategoryLabel) |
|
|
|
self.assert_record(content, dataset, label=[]) |