Browse Source

Refactor test_tasks.py

pull/1544/head
Hironsan 3 years ago
parent
commit
f01c5e4c4f
1 changed files with 13 additions and 23 deletions
  1. 36
      backend/api/tests/test_tasks.py

36
backend/api/tests/test_tasks.py

@ -7,10 +7,12 @@ from ..tasks import injest_data
from .api.utils import prepare_project
class TestIngestClassificationData(TestCase):
class TestIngestData(TestCase):
task = 'Any'
annotation_class = Category
def setUp(self):
self.project = prepare_project(task='DocumentClassification')
self.project = prepare_project(self.task)
self.user = self.project.users[0]
self.data_path = pathlib.Path(__file__).parent / 'data'
@ -26,7 +28,12 @@ class TestIngestClassificationData(TestCase):
injest_data(self.user.id, self.project.item.id, filenames, file_format, **kwargs)
self.assertEqual(Example.objects.count(), expected_example)
self.assertEqual(Label.objects.count(), expected_label)
self.assertEqual(Category.objects.count(), expected_annotation)
self.assertEqual(self.annotation_class.objects.count(), expected_annotation)
class TestIngestClassificationData(TestIngestData):
task = 'DocumentClassification'
annotation_class = Category
def test_jsonl(self):
filename = 'text_classification/example.jsonl'
@ -65,26 +72,9 @@ class TestIngestClassificationData(TestCase):
self.assert_count(filename, file_format, expected_example=3, expected_label=0, expected_annotation=0)
class TestIngestSequenceLabelingData(TestCase):
def setUp(self):
self.project = prepare_project(task='SequenceLabeling')
self.user = self.project.users[0]
self.data_path = pathlib.Path(__file__).parent / 'data'
def assert_count(self,
filename,
file_format,
kwargs=None,
expected_example=0,
expected_label=0,
expected_annotation=0):
filenames = [str(self.data_path / filename)]
kwargs = kwargs or {}
injest_data(self.user.id, self.project.item.id, filenames, file_format, **kwargs)
self.assertEqual(Example.objects.count(), expected_example)
self.assertEqual(Label.objects.count(), expected_label)
self.assertEqual(Span.objects.count(), expected_annotation)
class TestIngestSequenceLabelingData(TestIngestData):
task = 'SequenceLabeling'
annotation_class = Span
def test_jsonl(self):
filename = 'sequence_labeling/example.jsonl'

Loading…
Cancel
Save