You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

73 lines
3.5 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. import unittest
  2. from typing import List, Optional
  3. from data_import.pipeline import builders
  4. from data_import.pipeline.data import TextData
  5. from data_import.pipeline.exceptions import FileParseException
  6. from data_import.pipeline.labels import CategoryLabel, SpanLabel
  7. from data_import.pipeline.readers import FileName
  8. class TestColumnBuilder(unittest.TestCase):
  9. def assert_record(self, actual, expected):
  10. self.assertEqual(actual.data.text, expected["data"])
  11. self.assertEqual(actual.label, expected["label"])
  12. def create_record(self, row, data_column: builders.DataColumn, label_columns: Optional[List[builders.Column]]):
  13. builder = builders.ColumnBuilder(data_column=data_column, label_columns=label_columns)
  14. filename = FileName("", "", "")
  15. return builder.build(row, filename=filename, line_num=1)
  16. def test_can_load_default_column_names(self):
  17. row = {"text": "Text", "label": "Label"}
  18. data_column = builders.DataColumn("text", TextData)
  19. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  20. actual = self.create_record(row, data_column, label_columns)
  21. expected = {"data": "Text", "label": [{"label": "Label"}]}
  22. self.assert_record(actual, expected)
  23. def test_can_specify_any_column_names(self):
  24. row = {"body": "Text", "star": 5}
  25. data_column = builders.DataColumn("body", TextData)
  26. label_columns = [builders.LabelColumn("star", CategoryLabel)]
  27. actual = self.create_record(row, data_column, label_columns)
  28. expected = {"data": "Text", "label": [{"label": "5"}]}
  29. self.assert_record(actual, expected)
  30. def test_can_load_only_text_column(self):
  31. row = {"text": "Text", "label": None}
  32. data_column = builders.DataColumn("text", TextData)
  33. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  34. actual = self.create_record(row, data_column, label_columns)
  35. expected = {"data": "Text", "label": []}
  36. self.assert_record(actual, expected)
  37. def test_denies_no_data_column(self):
  38. row = {"label": "Label"}
  39. data_column = builders.DataColumn("text", TextData)
  40. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  41. with self.assertRaises(FileParseException):
  42. self.create_record(row, data_column, label_columns)
  43. def test_denies_empty_text(self):
  44. row = {"text": "", "label": "Label"}
  45. data_column = builders.DataColumn("text", TextData)
  46. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  47. with self.assertRaises(FileParseException):
  48. self.create_record(row, data_column, label_columns)
  49. def test_can_load_int_as_text(self):
  50. row = {"text": 5, "label": "Label"}
  51. data_column = builders.DataColumn("text", TextData)
  52. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  53. actual = self.create_record(row, data_column, label_columns)
  54. expected = {"data": "5", "label": [{"label": "Label"}]}
  55. self.assert_record(actual, expected)
  56. def test_can_build_multiple_labels(self):
  57. row = {"text": "Text", "cats": ["Label"], "entities": [(0, 1, "LOC")]}
  58. data_column = builders.DataColumn("text", TextData)
  59. label_columns = [builders.LabelColumn("cats", CategoryLabel), builders.LabelColumn("entities", SpanLabel)]
  60. actual = self.create_record(row, data_column, label_columns)
  61. expected = {"data": "Text", "label": [{"label": "Label"}, {"label": "LOC", "start_offset": 0, "end_offset": 1}]}
  62. self.assert_record(actual, expected)