You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

84 lines
3.6 KiB

  1. import unittest
  2. from typing import List
  3. from data_import.pipeline import builders
  4. from data_import.pipeline.data import TextData
  5. from data_import.pipeline.exceptions import FileParseException
  6. from data_import.pipeline.labels import CategoryLabel, SpanLabel
  class TestColumnBuilder(unittest.TestCase):
      """Tests for ``builders.ColumnBuilder.build`` applied to single dict rows.

      Each test feeds one row plus a data-column / label-column configuration
      to the builder and checks either the resulting record or the raised
      ``FileParseException``.

      Test methods are described with ``#`` comments instead of docstrings so
      that unittest's verbose output keeps reporting the method names.
      """

      def assert_record(self, actual, expected):
          """Compare a built record with the expected text and labels.

          ``actual.data.text`` must equal ``expected['data']`` and
          ``actual.label`` must equal ``expected['label']``.
          """
          built_text = actual.data.text
          self.assertEqual(built_text, expected['data'])
          built_labels = actual.label
          self.assertEqual(built_labels, expected['label'])

      def create_record(
          self,
          row,
          data_column: builders.DataColumn,
          label_columns: List[builders.LabelColumn],
      ):
          """Build and return one record from ``row``.

          A fresh ``ColumnBuilder`` is configured with the given columns and
          asked to build the row with an empty filename and line number 1.
          """
          return builders.ColumnBuilder(
              data_column=data_column,
              label_columns=label_columns,
          ).build(row, filename='', line_num=1)

      def test_can_load_default_column_names(self):
          # A row keyed by 'text' and 'label' yields the text and one label.
          record = self.create_record(
              {'text': 'Text', 'label': 'Label'},
              builders.DataColumn('text', TextData),
              [builders.LabelColumn('label', CategoryLabel)],
          )
          self.assert_record(record, {'data': 'Text', 'label': [{'label': 'Label'}]})

      def test_can_specify_any_column_names(self):
          # Column names are configurable; the integer label 5 is expected
          # to come back as the string '5'.
          record = self.create_record(
              {'body': 'Text', 'star': 5},
              builders.DataColumn('body', TextData),
              [builders.LabelColumn('star', CategoryLabel)],
          )
          self.assert_record(record, {'data': 'Text', 'label': [{'label': '5'}]})

      def test_can_load_only_text_column(self):
          # A None label value is expected to produce an empty label list.
          record = self.create_record(
              {'text': 'Text', 'label': None},
              builders.DataColumn('text', TextData),
              [builders.LabelColumn('label', CategoryLabel)],
          )
          self.assert_record(record, {'data': 'Text', 'label': []})

      def test_denies_no_data_column(self):
          # The row has no 'text' key, so building must fail.
          source_row = {'label': 'Label'}
          text_column = builders.DataColumn('text', TextData)
          category_columns = [builders.LabelColumn('label', CategoryLabel)]
          self.assertRaises(
              FileParseException,
              self.create_record,
              source_row,
              text_column,
              category_columns,
          )

      def test_denies_empty_text(self):
          # An empty string in the data column must be rejected.
          source_row = {'text': '', 'label': 'Label'}
          text_column = builders.DataColumn('text', TextData)
          category_columns = [builders.LabelColumn('label', CategoryLabel)]
          self.assertRaises(
              FileParseException,
              self.create_record,
              source_row,
              text_column,
              category_columns,
          )

      def test_can_load_int_as_text(self):
          # An integer in the data column is expected as the string '5'.
          record = self.create_record(
              {'text': 5, 'label': 'Label'},
              builders.DataColumn('text', TextData),
              [builders.LabelColumn('label', CategoryLabel)],
          )
          self.assert_record(record, {'data': '5', 'label': [{'label': 'Label'}]})

      def test_can_build_multiple_labels(self):
          # Two label columns (categories and spans) contribute to a single
          # label list, in the order the columns are configured.
          source_row = {'text': 'Text', 'cats': ['Label'], 'entities': [(0, 1, 'LOC')]}
          text_column = builders.DataColumn('text', TextData)
          mixed_columns = [
              builders.LabelColumn('cats', CategoryLabel),
              builders.LabelColumn('entities', SpanLabel),
          ]
          record = self.create_record(source_row, text_column, mixed_columns)
          category_entry = {'label': 'Label'}
          span_entry = {'label': 'LOC', 'start_offset': 0, 'end_offset': 1}
          self.assert_record(
              record,
              {'data': 'Text', 'label': [category_entry, span_entry]},
          )