You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

88 lines
3.2 KiB

  1. import uuid
  2. import pandas as pd
  3. from django.test import TestCase
  4. from data_import.pipeline.data import TextData
  5. from data_import.pipeline.label import CategoryLabel
  6. from data_import.pipeline.makers import ExampleMaker, LabelMaker
  7. from data_import.pipeline.readers import (
  8. FILE_NAME_COLUMN,
  9. LINE_NUMBER_COLUMN,
  10. UPLOAD_NAME_COLUMN,
  11. UUID_COLUMN,
  12. )
  13. from projects.tests.utils import prepare_project
  class TestExamplesMaker(TestCase):
      """Exercise ``ExampleMaker.make`` on DataFrames built from one record."""

      def setUp(self):
          self.project = prepare_project()
          self.label_column = "label"
          self.text_column = "text"

          # Bookkeeping columns first, then the text/label payload; the key
          # order is kept so the resulting DataFrame columns are unchanged.
          record = {
              LINE_NUMBER_COLUMN: 1,
              UUID_COLUMN: uuid.uuid4(),
              FILE_NAME_COLUMN: "file1",
              UPLOAD_NAME_COLUMN: "upload1",
          }
          record[self.text_column] = "text1"
          record[self.label_column] = ["A"]
          self.record = record

          self.maker = ExampleMaker(
              self.project.item,
              TextData,
              self.text_column,
              [self.label_column],
          )

      def _make_from_record(self):
          # Wrap the (possibly mutated) record in a one-row DataFrame and run
          # it through the maker under test.
          frame = pd.DataFrame([self.record])
          return self.maker.make(frame)

      def test_make_examples(self):
          # A complete record produces exactly one example.
          made = self._make_from_record()
          self.assertEqual(len(made), 1)

      def test_check_column_existence(self):
          # Without the text column nothing is made and one error is recorded.
          del self.record[self.text_column]
          made = self._make_from_record()
          self.assertEqual(len(made), 0)
          self.assertEqual(len(self.maker.errors), 1)

      def test_empty_text_raises_error(self):
          # An empty text value yields no example and one recorded error.
          self.record.update({self.text_column: ""})
          made = self._make_from_record()
          self.assertEqual(len(made), 0)
          self.assertEqual(len(self.maker.errors), 1)
  class TestLabelFormatter(TestCase):
      """Check what ``LabelMaker.make`` returns for a label column."""

      def setUp(self):
          self.label_column = "label"
          self.label_class = CategoryLabel
          # Two rows carrying three label values in total: "A", then "B" and "C".
          self.df = pd.DataFrame(
              [
                  {LINE_NUMBER_COLUMN: 1, UUID_COLUMN: uuid.uuid4(), self.label_column: ["A"]},
                  {LINE_NUMBER_COLUMN: 2, UUID_COLUMN: uuid.uuid4(), self.label_column: ["B", "C"]},
              ]
          )

      def test_make(self):
          # Every value in the label column becomes one label, in row order.
          label_maker = LabelMaker(column=self.label_column, label_class=self.label_class)
          labels = label_maker.make(self.df)
          self.assertEqual(len(labels), 3)
          # One subTest per label: each mismatch is reported on its own
          # instead of the first failure hiding the remaining comparisons.
          for position, (label, expected) in enumerate(zip(labels, ["A", "B", "C"])):
              with self.subTest(position=position, expected=expected):
                  self.assertEqual(label.label, expected)

      def test_format_without_specified_column(self):
          # Asking for a column that the DataFrame lacks raises KeyError.
          label_maker = LabelMaker(column="invalid_column", label_class=self.label_class)
          with self.assertRaises(KeyError):
              label_maker.make(self.df)

      def test_format_with_partially_correct_column(self):
          # Only the first row has a plain string under the label column; the
          # others use a different column, omit it, or hold ``{}``. Exactly one
          # label is expected back.
          label_maker = LabelMaker(column=self.label_column, label_class=self.label_class)
          # NOTE(review): the last two rows share line number 3 — confirm
          # whether 4 was intended for the final row.
          df = pd.DataFrame(
              [
                  {LINE_NUMBER_COLUMN: 1, UUID_COLUMN: uuid.uuid4(), self.label_column: ["A"]},
                  {LINE_NUMBER_COLUMN: 2, UUID_COLUMN: uuid.uuid4(), "invalid_column": ["B"]},
                  {LINE_NUMBER_COLUMN: 3, UUID_COLUMN: uuid.uuid4()},
                  {LINE_NUMBER_COLUMN: 3, UUID_COLUMN: uuid.uuid4(), self.label_column: [{}]},
              ]
          )
          labels = label_maker.make(df)
          self.assertEqual(len(labels), 1)