|
|
@ -6,9 +6,10 @@ from auto_labeling_pipeline.models import RequestModelFactory |
|
|
|
from rest_framework import status |
|
|
|
from rest_framework.reverse import reverse |
|
|
|
|
|
|
|
from ...models import DOCUMENT_CLASSIFICATION |
|
|
|
from ...models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION |
|
|
|
from ...views.auto_labeling import load_data_as_b64 |
|
|
|
from .utils import CRUDMixin, prepare_project |
|
|
|
from .utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image, |
|
|
|
prepare_project) |
|
|
|
|
|
|
|
data_dir = pathlib.Path(__file__).parent / 'data' |
|
|
|
|
|
|
@ -108,3 +109,35 @@ class TestConfigCreation(CRUDMixin): |
|
|
|
|
|
|
|
def test_create_config(self): |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
|
|
|
|
|
|
|
|
class TestAutoLabelingText(CRUDMixin): |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) |
|
|
|
make_auto_labeling_config(self.project.item) |
|
|
|
self.example = make_doc(self.project.item) |
|
|
|
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
@patch('api.views.auto_labeling.execute_pipeline', return_value=[]) |
|
|
|
def test_text_task(self, mock): |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
_, kwargs = mock.call_args |
|
|
|
self.assertEqual(kwargs['text'], self.example.text) |
|
|
|
|
|
|
|
|
|
|
|
class TestAutoLabelingImage(CRUDMixin): |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
self.project = prepare_project(task=IMAGE_CLASSIFICATION) |
|
|
|
make_auto_labeling_config(self.project.item) |
|
|
|
filepath = data_dir / 'images/1500x500.jpeg' |
|
|
|
self.example = make_image(self.project.item, str(filepath)) |
|
|
|
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id]) |
|
|
|
|
|
|
|
@patch('api.views.auto_labeling.execute_pipeline', return_value=[]) |
|
|
|
def test_text_task(self, mock): |
|
|
|
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) |
|
|
|
_, kwargs = mock.call_args |
|
|
|
expected = load_data_as_b64(str(self.example.filename)) |
|
|
|
self.assertEqual(kwargs['text'], expected) |