# Unit tests for FastTextDataset (views.upload.dataset).

import os
import shutil
import tempfile
import unittest
from ...views.upload.data import TextData
from ...views.upload.dataset import FastTextDataset
from ...views.upload.label import CategoryLabel
class TestFastTextDataset(unittest.TestCase):
    """Check ``FastTextDataset.load`` against fixture files in a temp directory."""

    def setUp(self):
        """Create a scratch directory and the path of the fixture file in it."""
        self.test_dir = tempfile.mkdtemp()
        self.test_file = os.path.join(self.test_dir, 'test_file.txt')

    def tearDown(self):
        """Remove the scratch directory and everything written into it."""
        shutil.rmtree(self.test_dir)

    def create_file(self, content):
        """Write ``content`` to the fixture file, replacing what was there.

        Args:
            content: Text to store in ``self.test_file``.
        """
        # Pin the encoding: the locale default differs between platforms, so
        # without it the bytes on disk would not be reproducible.
        with open(self.test_file, 'w', encoding='utf-8') as f:
            f.write(content)

    def assert_record(self, content, dataset, data='Text', label=None):
        """Load ``content`` through ``dataset`` and check its first record.

        Args:
            content: Raw file contents to write to the fixture file.
            dataset: Object whose ``load(path)`` returns an iterator of
                records exposing ``data`` and ``label``.
            data: Expected value of ``record.data['text']``.
            label: Expected value of ``record.label``; defaults to
                ``[{'text': 'Label'}]`` when omitted.
        """
        if label is None:
            label = [{'text': 'Label'}]
        self.create_file(content)
        # A default of None turns "loader produced nothing" into a readable
        # assertion failure instead of a bare StopIteration error.
        record = next(dataset.load(self.test_file), None)
        self.assertIsNotNone(record, 'dataset.load() yielded no records')
        self.assertEqual(record.data['text'], data)
        self.assertEqual(record.label, label)

    def test_can_load_default_column_names(self):
        """Two ``__label__`` tokens before the text give two labels."""
        content = '__label__sauce __label__cheese Text'
        dataset = FastTextDataset(
            filenames=[],
            label_class=CategoryLabel,
            data_class=TextData,
        )
        label = [{'text': 'sauce'}, {'text': 'cheese'}]
        self.assert_record(content, dataset, label=label)