diff --git a/backend/data_import/pipeline/parsers.py b/backend/data_import/pipeline/parsers.py index dd73986b..151d3ca0 100644 --- a/backend/data_import/pipeline/parsers.py +++ b/backend/data_import/pipeline/parsers.py @@ -11,7 +11,12 @@ from chardet import UniversalDetector from seqeval.scheme import BILOU, IOB2, IOBES, IOE2, Tokens from .exceptions import FileParseException -from .readers import DEFAULT_LABEL_COLUMN, DEFAULT_TEXT_COLUMN, Parser +from .readers import ( + DEFAULT_LABEL_COLUMN, + DEFAULT_TEXT_COLUMN, + LINE_NUMBER_COLUMN, + Parser, +) DEFAULT_ENCODING = "Auto" @@ -115,8 +120,8 @@ class LineParser(Parser): def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: reader = LineReader(filename, self.encoding) - for line in reader: - yield {DEFAULT_TEXT_COLUMN: line} + for line_num, line in enumerate(reader, start=1): + yield {DEFAULT_TEXT_COLUMN: line, LINE_NUMBER_COLUMN: line_num} class TextFileParser(Parser): @@ -151,8 +156,8 @@ class CSVParser(Parser): encoding = decide_encoding(filename, self.encoding) with open(filename, encoding=encoding) as f: reader = csv.DictReader(f, delimiter=self.delimiter) - for row in reader: - yield row + for line_num, row in enumerate(reader, start=1): + yield {LINE_NUMBER_COLUMN: line_num, **row} class JSONParser(Parser): @@ -197,7 +202,8 @@ class JSONLParser(Parser): reader = LineReader(filename, self.encoding) for line_num, line in enumerate(reader, start=1): try: - yield json.loads(line) + row = json.loads(line) + yield {LINE_NUMBER_COLUMN: line_num, **row} except json.decoder.JSONDecodeError as e: error = FileParseException(filename, line_num, str(e)) self._errors.append(error) @@ -217,7 +223,7 @@ class ExcelParser(Parser): rows = pyexcel.iget_records(file_name=filename) try: for line_num, row in enumerate(rows, start=1): - yield row + yield {LINE_NUMBER_COLUMN: line_num, **row} except pyexcel.exceptions.FileTypeNotSupported as e: error = FileParseException(filename, line_num=1, message=str(e)) self._errors.append(error) @@ -246,7 +252,7 @@ class FastTextParser(Parser): def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: reader = LineReader(filename, self.encoding) - for line in reader: + for line_num, line in enumerate(reader, start=1): labels = [] tokens = [] for token in line.rstrip().split(" "): @@ -256,7 +262,7 @@ class FastTextParser(Parser): else: tokens.append(token) text = " ".join(tokens) - yield {DEFAULT_TEXT_COLUMN: text, DEFAULT_LABEL_COLUMN: labels} + yield {DEFAULT_TEXT_COLUMN: text, DEFAULT_LABEL_COLUMN: labels, LINE_NUMBER_COLUMN: line_num} class CoNLLParser(Parser): diff --git a/backend/data_import/tests/test_parser.py b/backend/data_import/tests/test_parser.py index 768908cb..9d762d87 100644 --- a/backend/data_import/tests/test_parser.py +++ b/backend/data_import/tests/test_parser.py @@ -5,6 +5,7 @@ import tempfile import unittest from data_import.pipeline import parsers +from data_import.pipeline.readers import LINE_NUMBER_COLUMN class TestParser(unittest.TestCase): @@ -24,6 +25,7 @@ class TestParser(unittest.TestCase): it = parser.parse(self.test_file) for expect in expected: row = next(it) + row.pop(LINE_NUMBER_COLUMN, None) self.assertEqual(row, expect) with self.assertRaises(StopIteration): next(it)