You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

71 lines
3.5 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import unittest
  2. from typing import List, Optional
  3. from data_import.pipeline import builders
  4. from data_import.pipeline.data import TextData
  5. from data_import.pipeline.exceptions import FileParseException
  6. from data_import.pipeline.labels import CategoryLabel, SpanLabel
  7. class TestColumnBuilder(unittest.TestCase):
  8. def assert_record(self, actual, expected):
  9. self.assertEqual(actual.data.text, expected["data"])
  10. self.assertEqual(actual.label, expected["label"])
  11. def create_record(self, row, data_column: builders.DataColumn, label_columns: Optional[List[builders.Column]]):
  12. builder = builders.ColumnBuilder(data_column=data_column, label_columns=label_columns)
  13. return builder.build(row, filename="", line_num=1)
  14. def test_can_load_default_column_names(self):
  15. row = {"text": "Text", "label": "Label"}
  16. data_column = builders.DataColumn("text", TextData)
  17. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  18. actual = self.create_record(row, data_column, label_columns)
  19. expected = {"data": "Text", "label": [{"label": "Label"}]}
  20. self.assert_record(actual, expected)
  21. def test_can_specify_any_column_names(self):
  22. row = {"body": "Text", "star": 5}
  23. data_column = builders.DataColumn("body", TextData)
  24. label_columns = [builders.LabelColumn("star", CategoryLabel)]
  25. actual = self.create_record(row, data_column, label_columns)
  26. expected = {"data": "Text", "label": [{"label": "5"}]}
  27. self.assert_record(actual, expected)
  28. def test_can_load_only_text_column(self):
  29. row = {"text": "Text", "label": None}
  30. data_column = builders.DataColumn("text", TextData)
  31. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  32. actual = self.create_record(row, data_column, label_columns)
  33. expected = {"data": "Text", "label": []}
  34. self.assert_record(actual, expected)
  35. def test_denies_no_data_column(self):
  36. row = {"label": "Label"}
  37. data_column = builders.DataColumn("text", TextData)
  38. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  39. with self.assertRaises(FileParseException):
  40. self.create_record(row, data_column, label_columns)
  41. def test_denies_empty_text(self):
  42. row = {"text": "", "label": "Label"}
  43. data_column = builders.DataColumn("text", TextData)
  44. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  45. with self.assertRaises(FileParseException):
  46. self.create_record(row, data_column, label_columns)
  47. def test_can_load_int_as_text(self):
  48. row = {"text": 5, "label": "Label"}
  49. data_column = builders.DataColumn("text", TextData)
  50. label_columns = [builders.LabelColumn("label", CategoryLabel)]
  51. actual = self.create_record(row, data_column, label_columns)
  52. expected = {"data": "5", "label": [{"label": "Label"}]}
  53. self.assert_record(actual, expected)
  54. def test_can_build_multiple_labels(self):
  55. row = {"text": "Text", "cats": ["Label"], "entities": [(0, 1, "LOC")]}
  56. data_column = builders.DataColumn("text", TextData)
  57. label_columns = [builders.LabelColumn("cats", CategoryLabel), builders.LabelColumn("entities", SpanLabel)]
  58. actual = self.create_record(row, data_column, label_columns)
  59. expected = {"data": "Text", "label": [{"label": "Label"}, {"label": "LOC", "start_offset": 0, "end_offset": 1}]}
  60. self.assert_record(actual, expected)