You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

140 lines
3.8 KiB

  1. import json
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. from data_import.pipeline import parsers
  7. class TestParser(unittest.TestCase):
  8. def setUp(self):
  9. self.test_dir = tempfile.mkdtemp()
  10. self.test_file = os.path.join(self.test_dir, 'test_file.csv')
  11. def tearDown(self):
  12. shutil.rmtree(self.test_dir)
  13. def create_file(self, content):
  14. with open(self.test_file, 'w') as f:
  15. f.write(content)
  16. def assert_record(self, content, parser, expected):
  17. self.create_file(content)
  18. it = parser.parse(self.test_file)
  19. for expect in expected:
  20. row = next(it)
  21. self.assertEqual(row, expect)
  22. with self.assertRaises(StopIteration):
  23. next(it)
  24. class TestPlainParser(TestParser):
  25. def test_read(self):
  26. content = 'example'
  27. parser = parsers.PlainParser()
  28. expected = [{}]
  29. self.assert_record(content, parser, expected)
  30. class TestLineParser(TestParser):
  31. def test_read(self):
  32. content = 'Hello, World!\nこんにちは'
  33. parser = parsers.LineParser()
  34. expected = [{'text': 'Hello, World!'}, {'text': 'こんにちは'}]
  35. self.assert_record(content, parser, expected)
  36. class TestTextFileParser(TestParser):
  37. def test_read(self):
  38. content = 'Hello, World!\nこんにちは'
  39. parser = parsers.TextFileParser()
  40. expected = [{'text': content}]
  41. self.assert_record(content, parser, expected)
  42. class TestCsvParser(TestParser):
  43. def test_read(self):
  44. content = 'label,text\nLabel,Text'
  45. parser = parsers.CSVParser(delimiter=',')
  46. expected = [{'label': 'Label', 'text': 'Text'}]
  47. self.assert_record(content, parser, expected)
  48. def test_can_change_delimiter(self):
  49. content = 'label\ttext\nLabel\tText'
  50. parser = parsers.CSVParser(delimiter='\t')
  51. expected = [{'label': 'Label', 'text': 'Text'}]
  52. self.assert_record(content, parser, expected)
  53. def test_can_read_null_value(self):
  54. content = 'text,label\nText'
  55. parser = parsers.CSVParser(delimiter=',')
  56. expected = [{'text': 'Text', 'label': None}]
  57. self.assert_record(content, parser, expected)
  58. class TestJSONParser(TestParser):
  59. def test_read(self):
  60. content = json.dumps([
  61. {'text': 'line1', 'labels': 'Label1'},
  62. {'text': 'line2', 'labels': 'Label2'}
  63. ])
  64. parser = parsers.JSONParser()
  65. expected = json.loads(content)
  66. self.assert_record(content, parser, expected)
  67. class TestJSONLParser(TestParser):
  68. def test_read(self):
  69. line1 = json.dumps({'text': 'line1', 'labels': 'Label1'})
  70. line2 = json.dumps({'text': 'line2', 'labels': 'Label2'})
  71. content = f"{line1}\n{line2}"
  72. parser = parsers.JSONLParser()
  73. expected = [json.loads(line1), json.loads(line2)]
  74. self.assert_record(content, parser, expected)
  75. class TestFastTextParser(TestParser):
  76. def test_read(self):
  77. content = '__label__sauce __label__cheese Text'
  78. parser = parsers.FastTextParser()
  79. expected = [{'text': 'Text', 'label': ['sauce', 'cheese']}]
  80. self.assert_record(content, parser, expected)
  81. class TestCoNLLParser(TestParser):
  82. def test_can_read(self):
  83. content = """EU\tB-ORG
  84. rejects\tO
  85. German\tB-MISC
  86. call\tO
  87. to\tO
  88. boycott\tO
  89. British\tB-MISC
  90. lamb\tO
  91. .\tO
  92. Peter\tB-PER
  93. Blackburn\tI-PER
  94. """
  95. parser = parsers.CoNLLParser()
  96. expected = [
  97. {
  98. 'text': 'EU rejects German call to boycott British lamb .',
  99. 'label': [(0, 2, 'ORG'), (11, 17, 'MISC'), (34, 41, 'MISC')]
  100. },
  101. {
  102. 'text': 'Peter Blackburn',
  103. 'label': [(0, 15, 'PER')]
  104. }
  105. ]
  106. self.assert_record(content, parser, expected)