mirror of https://github.com/doccano/doccano.git
Topics: python, annotation-tool, datasets, active-learning, text-annotation, dataset, natural-language-processing, data-labeling, machine-learning
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
250 lines
11 KiB
250 lines
11 KiB
import pathlib
|
|
from unittest.mock import patch
|
|
|
|
from auto_labeling_pipeline.mappings import AmazonComprehendSentimentTemplate
|
|
from auto_labeling_pipeline.models import RequestModelFactory
|
|
from model_mommy import mommy
|
|
from rest_framework import status
|
|
from rest_framework.reverse import reverse
|
|
|
|
from api.tests.utils import CRUDMixin
|
|
from auto_labeling.pipeline.labels import Categories, Spans, Texts
|
|
from examples.tests.utils import make_doc
|
|
from labels.models import Category, Span, TextLabel
|
|
from projects.models import ProjectType
|
|
from projects.tests.utils import prepare_project
|
|
|
|
# "data" directory located next to this test module; test_called_with_image
# reads the path of "images/1500x500.jpeg" from it.
data_dir = pathlib.Path(__file__).parent / "data"
|
|
|
|
|
|
class TestTemplateList(CRUDMixin):
    """Tests for the ``auto_labeling_templates`` endpoint on a document-classification project."""

    def setUp(self):
        # Every test gets its own project; the endpoint URL is built from its id.
        self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
        project_id = self.project.item.id
        self.url = reverse(viewname="auto_labeling_templates", args=[project_id])

    def test_allow_admin_to_fetch_template_list(self):
        # The admin queries with a task name and must at least see the custom REST template.
        self.url = self.url + "?task_name=DocumentClassification"
        res = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
        templates = res.data
        self.assertIn("Custom REST Request", templates)
        self.assertGreaterEqual(len(templates), 1)

    def test_deny_project_staff_to_fetch_template_list(self):
        # Each staff member of the project is expected to receive 403.
        self.url = self.url + "?task_name=DocumentClassification"
        for member in self.project.staffs:
            self.assert_fetch(member, status.HTTP_403_FORBIDDEN)

    def test_return_only_default_template_with_empty_task_name(self):
        # No task_name query parameter is appended to the URL here.
        res = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
        templates = res.data
        self.assertEqual(len(templates), 1)
        self.assertIn("Custom REST Request", templates)

    def test_return_only_default_template_with_wrong_task_name(self):
        # "foobar" is used as a task name that should match no specific template.
        self.url = self.url + "?task_name=foobar"
        res = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
        templates = res.data
        self.assertEqual(len(templates), 1)
        self.assertIn("Custom REST Request", templates)
|
|
|
|
|
|
class TestConfigParameter(CRUDMixin):
    """Tests for the ``auto_labeling_parameter_testing`` endpoint.

    ``RestAPIRequestTesting.send_request`` is replaced by a mock returning ``{}``
    in every test; the tests inspect the keyword arguments it was called with.
    """

    def setUp(self):
        self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
        model_attrs = {"key": "hoge", "type": "PLAIN_TEXT", "language": "en"}
        # Payload posted by assert_create.
        self.data = {
            "model_name": "GCP Entity Analysis",
            "model_attrs": model_attrs,
            "text": "example",
        }
        project_id = self.project.item.id
        self.url = reverse(viewname="auto_labeling_parameter_testing", args=[project_id])

    @patch("auto_labeling.views.RestAPIRequestTesting.send_request", return_value={})
    def test_called_with_proper_model(self, mock):
        self.assert_create(self.project.admin, status.HTTP_200_OK)
        # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
        passed = mock.call_args[1]
        expected_model = RequestModelFactory.create(self.data["model_name"], self.data["model_attrs"])
        self.assertEqual(passed["model"], expected_model)

    @patch("auto_labeling.views.RestAPIRequestTesting.send_request", return_value={})
    def test_called_with_text(self, mock):
        self.assert_create(self.project.admin, status.HTTP_200_OK)
        passed = mock.call_args[1]
        self.assertEqual(passed["example"], self.data["text"])

    @patch("auto_labeling.views.RestAPIRequestTesting.send_request", return_value={})
    def test_called_with_image(self, mock):
        # The "text" field carries the path of an image fixture instead of plain text.
        image_path = data_dir / "images/1500x500.jpeg"
        self.data["text"] = str(image_path)
        self.assert_create(self.project.admin, status.HTTP_200_OK)
        passed = mock.call_args[1]
        self.assertEqual(passed["example"], self.data["text"])
|
|
|
|
|
|
class TestTemplateMapping(CRUDMixin):
    """Tests for the ``auto_labeling_template_test`` endpoint."""

    def setUp(self):
        self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
        # Sample sentiment response fed to the Amazon Comprehend sentiment template.
        sentiment_scores = {
            "Positive": 0.004438233096152544,
            "Negative": 0.0005306027014739811,
            "Neutral": 0.9950305223464966,
            "Mixed": 5.80838445785048e-7,
        }
        sentiment_response = {"Sentiment": "NEUTRAL", "SentimentScore": sentiment_scores}
        self.data = {
            "response": sentiment_response,
            "template": AmazonComprehendSentimentTemplate().load(),
            "task_type": "Category",
        }
        project_id = self.project.item.id
        self.url = reverse(viewname="auto_labeling_template_test", args=[project_id])

    def test_template_mapping(self):
        # The posted response/template pair is expected to yield a single NEUTRAL label.
        res = self.assert_create(self.project.admin, status.HTTP_200_OK)
        want = [{"label": "NEUTRAL"}]
        self.assertEqual(res.json(), want)

    def test_json_decode_error(self):
        # An empty template string must be answered with 400.
        self.data["template"] = ""
        self.assert_create(self.project.admin, status.HTTP_400_BAD_REQUEST)
|
|
|
|
|
|
class TestLabelMapping(CRUDMixin):
    """Tests for the ``auto_labeling_mapping_test`` endpoint."""

    def setUp(self):
        self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
        # One label plus a mapping that renames it.
        labels = [{"label": "NEGATIVE"}]
        mapping = {"NEGATIVE": "Negative"}
        self.data = {
            "response": labels,
            "label_mapping": mapping,
            "task_type": "Category",
        }
        project_id = self.project.item.id
        self.url = reverse(viewname="auto_labeling_mapping_test", args=[project_id])

    def test_label_mapping(self):
        # "NEGATIVE" should come back under its mapped name "Negative".
        res = self.assert_create(self.project.admin, status.HTTP_200_OK)
        want = [{"label": "Negative"}]
        self.assertEqual(res.json(), want)
|
|
|
|
|
|
class TestConfigCreation(CRUDMixin):
    """Tests for creating and listing configs through the ``auto_labeling_configs`` endpoint."""

    def setUp(self):
        self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION)
        # Placeholder attributes for the Amazon Comprehend model.
        model_attrs = {
            "aws_access_key": "str",
            "aws_secret_access_key": "str",
            "region_name": "us-east-1",
            "language_code": "en",
        }
        self.data = {
            "model_name": "Amazon Comprehend Sentiment Analysis",
            "model_attrs": model_attrs,
            "template": AmazonComprehendSentimentTemplate().load(),
            "label_mapping": {"NEGATIVE": "Negative"},
            "task_type": "Category",
        }
        project_id = self.project.item.id
        self.url = reverse(viewname="auto_labeling_configs", args=[project_id])

    def test_create_config(self):
        # The created config must echo the model name that was posted.
        res = self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(res.data["model_name"], self.data["model_name"])

    def test_list_config(self):
        # One config is stored for this project, so the listing must contain one entry.
        mommy.make("AutoLabelingConfig", project=self.project.item)
        res = self.assert_fetch(self.project.admin, status.HTTP_200_OK)
        self.assertEqual(len(res.data), 1)
|
|
|
|
|
|
class TestAutomatedLabeling(CRUDMixin):
    """Tests for the ``auto_labeling`` endpoint on a multi-label classification project.

    ``auto_labeling.views.execute_pipeline`` is patched in every test, so the
    labels "returned by the pipeline" are the ones given in the decorator.
    """

    def setUp(self):
        self.project = prepare_project(
            task=ProjectType.DOCUMENT_CLASSIFICATION,
            single_class_classification=False,
        )
        item = self.project.item
        self.example = make_doc(item)
        # Label types available in the project: two categories and one span type.
        self.category_pos = mommy.make("CategoryType", project=item, text="POS")
        self.category_neg = mommy.make("CategoryType", project=item, text="NEG")
        self.loc = mommy.make("SpanType", project=item, text="LOC")
        base_url = reverse(viewname="auto_labeling", args=[item.id])
        self.url = base_url + f"?example={self.example.id}"

    def _register_config(self, task_type):
        # Store an AutoLabelingConfig of the given task type for this test's project.
        return mommy.make("AutoLabelingConfig", task_type=task_type, project=self.project.item)

    @patch(
        "auto_labeling.views.execute_pipeline",
        return_value=Categories([{"label": "POS"}]),
    )
    def test_category_labeling(self, mock):
        self._register_config("Category")
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 1)
        stored = Category.objects.first()
        self.assertEqual(stored.label, self.category_pos)

    @patch(
        "auto_labeling.views.execute_pipeline",
        return_value=Categories([{"label": "NEUTRAL"}]),
    )
    def test_nonexistent_category(self, mock):
        # "NEUTRAL" is not among the category types created in setUp.
        self._register_config("Category")
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 0)

    @patch(
        "auto_labeling.views.execute_pipeline",
        side_effect=[
            Categories([{"label": "POS"}]),
            Categories([{"label": "NEG"}]),
        ],
    )
    def test_multiple_configs(self, mock):
        # Two configs, each mocked pipeline call yielding a different category.
        for _ in range(2):
            self._register_config("Category")
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 2)
        first_category = Category.objects.first()
        self.assertEqual(first_category.label, self.category_pos)
        last_category = Category.objects.last()
        self.assertEqual(last_category.label, self.category_neg)

    @patch(
        "auto_labeling.views.execute_pipeline",
        side_effect=[
            Categories([{"label": "POS"}]),
            Categories([{"label": "POS"}]),
        ],
    )
    def test_cannot_label_same_category_type(self, mock):
        # Both mocked pipeline calls yield "POS"; only one Category may be stored.
        for _ in range(2):
            self._register_config("Category")
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 1)

    @patch(
        "auto_labeling.views.execute_pipeline",
        side_effect=[
            Categories([{"label": "POS"}]),
            Spans([{"label": "LOC", "start_offset": 0, "end_offset": 5}]),
        ],
    )
    def test_allow_multi_type_configs(self, mock):
        # One category config and one span config on the same project.
        self._register_config("Category")
        self._register_config("Span")
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 1)
        self.assertEqual(Span.objects.count(), 1)

    @patch(
        "auto_labeling.views.execute_pipeline",
        return_value=Categories([{"label": "POS"}]),
    )
    def test_cannot_use_other_project_config(self, mock):
        # The config is created without passing this test's project.
        mommy.make("AutoLabelingConfig", task_type="Category")
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Category.objects.count(), 0)
|
|
|
|
|
|
class TestAutomatedSpanLabeling(CRUDMixin):
    """Tests for the ``auto_labeling`` endpoint on a sequence-labeling project."""

    def setUp(self):
        self.project = prepare_project(task=ProjectType.SEQUENCE_LABELING)
        item = self.project.item
        self.example = make_doc(item)
        self.loc = mommy.make("SpanType", project=item, text="LOC")
        base_url = reverse(viewname="auto_labeling", args=[item.id])
        self.url = base_url + f"?example={self.example.id}"

    @patch(
        "auto_labeling.views.execute_pipeline",
        side_effect=[
            Spans([{"label": "LOC", "start_offset": 0, "end_offset": 5}]),
            Spans([{"label": "LOC", "start_offset": 4, "end_offset": 10}]),
        ],
    )
    def test_cannot_label_overlapping_span(self, mock):
        # The two mocked spans cover 0-5 and 4-10; only one Span may be stored.
        for _ in range(2):
            mommy.make("AutoLabelingConfig", task_type="Span", project=self.project.item)
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(Span.objects.count(), 1)
|
|
|
|
|
|
class TestAutomatedTextLabeling(CRUDMixin):
    """Tests for the ``auto_labeling`` endpoint on a seq2seq project."""

    def setUp(self):
        self.project = prepare_project(task=ProjectType.SEQ2SEQ)
        item = self.project.item
        self.example = make_doc(item)
        base_url = reverse(viewname="auto_labeling", args=[item.id])
        self.url = base_url + f"?example={self.example.id}"

    @patch(
        "auto_labeling.views.execute_pipeline",
        side_effect=[
            Texts([{"text": "foo"}]),
            Texts([{"text": "foo"}]),
        ],
    )
    def test_cannot_label_same_text(self, mock):
        # Both mocked pipeline calls yield the text "foo"; only one TextLabel may be stored.
        for _ in range(2):
            mommy.make("AutoLabelingConfig", task_type="Text", project=self.project.item)
        self.assert_create(self.project.admin, status.HTTP_201_CREATED)
        self.assertEqual(TextLabel.objects.count(), 1)
|