diff --git a/backend/api/tests/test_tasks.py b/backend/api/tests/test_tasks.py
index 66852436..851019a3 100644
--- a/backend/api/tests/test_tasks.py
+++ b/backend/api/tests/test_tasks.py
@@ -32,6 +32,12 @@ class TestIngestClassificationData(TestIngestData):
             labels = set(cat.label.text for cat in example.categories.all())
             self.assertEqual(labels, set(expected_labels))
 
+    def assert_parse_error(self, response):
+        self.assertGreaterEqual(len(response['error']), 1)
+        self.assertEqual(Example.objects.count(), 0)
+        self.assertEqual(Label.objects.count(), 0)
+        self.assertEqual(Category.objects.count(), 0)
+
     def test_jsonl(self):
         filename = 'text_classification/example.jsonl'
         file_format = 'JSONL'
@@ -116,14 +122,17 @@ class TestIngestClassificationData(TestIngestData):
         self.ingest_data(filename, file_format)
         self.assert_examples(dataset)
 
-    def test_wong_jsonl(self):
+    def test_wrong_jsonl(self):
         filename = 'text_classification/example.json'
         file_format = 'JSONL'
         response = self.ingest_data(filename, file_format)
-        self.assertGreaterEqual(len(response['error']), 1)
-        self.assertEqual(Example.objects.count(), 0)
-        self.assertEqual(Label.objects.count(), 0)
-        self.assertEqual(Category.objects.count(), 0)
+        self.assert_parse_error(response)
+
+    def test_wrong_json(self):
+        filename = 'text_classification/example.jsonl'
+        file_format = 'JSON'
+        response = self.ingest_data(filename, file_format)
+        self.assert_parse_error(response)
 
 
 class TestIngestSequenceLabelingData(TestIngestData):
diff --git a/backend/api/views/upload/dataset.py b/backend/api/views/upload/dataset.py
index 09283352..e5bb89af 100644
--- a/backend/api/views/upload/dataset.py
+++ b/backend/api/views/upload/dataset.py
@@ -172,9 +172,13 @@ class JSONDataset(Dataset):
     def load(self, filename: str) -> Iterator[Record]:
         encoding = self.detect_encoding(filename)
         with open(filename, encoding=encoding) as f:
-            dataset = json.load(f)
-            for line_num, row in enumerate(dataset, start=1):
-                yield self.from_row(filename, row, line_num)
+            try:
+                dataset = json.load(f)
+                for line_num, row in enumerate(dataset, start=1):
+                    yield self.from_row(filename, row, line_num)
+            except json.decoder.JSONDecodeError:
+                message = 'Failed to decode the json file.'
+                raise FileParseException(filename, line_num=-1, message=message)
 
 
 class JSONLDataset(Dataset):
@@ -187,7 +191,7 @@ class JSONLDataset(Dataset):
                     row = json.loads(line)
                     yield self.from_row(filename, row, line_num)
                 except json.decoder.JSONDecodeError:
-                    message = 'Failed to encode the line.'
+                    message = 'Failed to decode the line.'
                     raise FileParseException(filename, line_num, message)
 
 