You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

84 lines
3.6 KiB

import unittest
from typing import List
from data_import.pipeline import builders
from data_import.pipeline.data import TextData
from data_import.pipeline.exceptions import FileParseException
from data_import.pipeline.labels import CategoryLabel, SpanLabel
class TestColumnBuilder(unittest.TestCase):
    """Tests for ``builders.ColumnBuilder`` turning one parsed row into a record."""

    def assert_record(self, actual, expected):
        """Compare a built record against an expected ``{'data', 'label'}`` dict.

        ``expected['data']`` is compared with ``actual.data.text`` first, then
        ``expected['label']`` with ``actual.label``.
        """
        expected_text = expected['data']
        expected_labels = expected['label']
        self.assertEqual(actual.data.text, expected_text)
        self.assertEqual(actual.label, expected_labels)

    def create_record(
        self,
        row,
        data_column: builders.DataColumn,
        label_columns: List[builders.LabelColumn],
    ):
        """Run ``row`` through a ``ColumnBuilder`` configured with the given columns.

        The row is built with an empty filename and ``line_num=1``; whatever
        ``ColumnBuilder.build`` returns (or raises) is passed straight through.
        """
        return builders.ColumnBuilder(
            data_column=data_column,
            label_columns=label_columns,
        ).build(row, filename='', line_num=1)

    def _default_columns(self):
        """Return a ``'text'`` data column and a single ``'label'`` category column."""
        text_column = builders.DataColumn('text', TextData)
        category_columns = [builders.LabelColumn('label', CategoryLabel)]
        return text_column, category_columns

    def test_can_load_default_column_names(self):
        # The conventional 'text' / 'label' columns yield one category label.
        text_column, category_columns = self._default_columns()
        record = self.create_record(
            {'text': 'Text', 'label': 'Label'}, text_column, category_columns
        )
        self.assert_record(record, {'data': 'Text', 'label': [{'label': 'Label'}]})

    def test_can_specify_any_column_names(self):
        # Arbitrary column names work, and a non-string label value (5) comes back as '5'.
        record = self.create_record(
            {'body': 'Text', 'star': 5},
            builders.DataColumn('body', TextData),
            [builders.LabelColumn('star', CategoryLabel)],
        )
        self.assert_record(record, {'data': 'Text', 'label': [{'label': '5'}]})

    def test_can_load_only_text_column(self):
        # A None label value produces a record with no labels at all.
        text_column, category_columns = self._default_columns()
        record = self.create_record(
            {'text': 'Text', 'label': None}, text_column, category_columns
        )
        self.assert_record(record, {'data': 'Text', 'label': []})

    def test_disallow_no_data_column(self):
        # A row missing the configured data column must be rejected.
        text_column, category_columns = self._default_columns()
        with self.assertRaises(FileParseException):
            self.create_record({'label': 'Label'}, text_column, category_columns)

    def test_disallow_empty_text(self):
        # An empty string in the data column must be rejected.
        text_column, category_columns = self._default_columns()
        with self.assertRaises(FileParseException):
            self.create_record({'text': '', 'label': 'Label'}, text_column, category_columns)

    def test_can_load_int_as_text(self):
        # A non-string data value (5) is exposed as the text '5'.
        text_column, category_columns = self._default_columns()
        record = self.create_record(
            {'text': 5, 'label': 'Label'}, text_column, category_columns
        )
        self.assert_record(record, {'data': '5', 'label': [{'label': 'Label'}]})

    def test_can_build_multiple_labels(self):
        # Category and span label columns are combined, in column order, into one label list.
        label_columns = [
            builders.LabelColumn('cats', CategoryLabel),
            builders.LabelColumn('entities', SpanLabel),
        ]
        record = self.create_record(
            {'text': 'Text', 'cats': ['Label'], 'entities': [(0, 1, 'LOC')]},
            builders.DataColumn('text', TextData),
            label_columns,
        )
        self.assert_record(
            record,
            {
                'data': 'Text',
                'label': [
                    {'label': 'Label'},
                    {'label': 'LOC', 'start_offset': 0, 'end_offset': 1},
                ],
            },
        )