diff --git a/backend/api/tests/test_tasks.py b/backend/api/tests/test_tasks.py index a0d460c0..40011bd3 100644 --- a/backend/api/tests/test_tasks.py +++ b/backend/api/tests/test_tasks.py @@ -7,10 +7,12 @@ from ..tasks import injest_data from .api.utils import prepare_project -class TestIngestClassificationData(TestCase): +class TestIngestData(TestCase): + task = 'Any' + annotation_class = Category def setUp(self): - self.project = prepare_project(task='DocumentClassification') + self.project = prepare_project(self.task) self.user = self.project.users[0] self.data_path = pathlib.Path(__file__).parent / 'data' @@ -26,7 +28,12 @@ class TestIngestClassificationData(TestCase): injest_data(self.user.id, self.project.item.id, filenames, file_format, **kwargs) self.assertEqual(Example.objects.count(), expected_example) self.assertEqual(Label.objects.count(), expected_label) - self.assertEqual(Category.objects.count(), expected_annotation) + self.assertEqual(self.annotation_class.objects.count(), expected_annotation) + + +class TestIngestClassificationData(TestIngestData): + task = 'DocumentClassification' + annotation_class = Category def test_jsonl(self): filename = 'text_classification/example.jsonl' @@ -65,26 +72,9 @@ class TestIngestClassificationData(TestCase): self.assert_count(filename, file_format, expected_example=3, expected_label=0, expected_annotation=0) -class TestIngestSequenceLabelingData(TestCase): - - def setUp(self): - self.project = prepare_project(task='SequenceLabeling') - self.user = self.project.users[0] - self.data_path = pathlib.Path(__file__).parent / 'data' - - def assert_count(self, - filename, - file_format, - kwargs=None, - expected_example=0, - expected_label=0, - expected_annotation=0): - filenames = [str(self.data_path / filename)] - kwargs = kwargs or {} - injest_data(self.user.id, self.project.item.id, filenames, file_format, **kwargs) - self.assertEqual(Example.objects.count(), expected_example) - self.assertEqual(Label.objects.count(), expected_label) - self.assertEqual(Span.objects.count(), expected_annotation) +class TestIngestSequenceLabelingData(TestIngestData): + task = 'SequenceLabeling' + annotation_class = Span def test_jsonl(self): filename = 'sequence_labeling/example.jsonl'