Browse Source

Add AutomatedTextLabeling API

pull/1650/head
Hironsan 2 years ago
parent
commit
24e04c48b2
3 changed files with 54 additions and 5 deletions
  1. 35
      backend/auto_labeling/tests/test_views.py
  2. 9
      backend/auto_labeling/urls.py
  3. 15
      backend/auto_labeling/views.py

35
backend/auto_labeling/tests/test_views.py

@ -7,8 +7,8 @@ from model_mommy import mommy
from rest_framework import status
from rest_framework.reverse import reverse
from api.models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQUENCE_LABELING
from api.models import Category, Span
from api.models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
from api.models import Category, Span, TextLabel
from api.tests.api.utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image,
prepare_project)
@ -265,3 +265,34 @@ class TestAutomatedSpanLabeling(CRUDMixin):
mommy.make('AutoLabelingConfig', task_type='Span')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Span.objects.count(), 1)
class TestAutomatedTextLabeling(CRUDMixin):
def setUp(self):
self.project = prepare_project(task=SEQ2SEQ)
self.example = make_doc(self.project.item)
self.url = reverse(viewname='automated_text_labeling', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[{'text': 'foo'}])
def test_category_labeling(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(TextLabel.objects.count(), 1)
self.assertEqual(TextLabel.objects.first().text, 'foo')
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'bar'}]])
def test_multiple_configs(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text')
mommy.make('AutoLabelingConfig', task_type='Text')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(TextLabel.objects.count(), 2)
self.assertEqual(TextLabel.objects.first().text, 'foo')
self.assertEqual(TextLabel.objects.last().text, 'bar')
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'foo'}]])
def test_cannot_label_same_category_type(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text')
mommy.make('AutoLabelingConfig', task_type='Text')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(TextLabel.objects.count(), 1)

9
backend/auto_labeling/urls.py

@ -1,8 +1,8 @@
from django.urls import path
from .views import (ConfigDetail, FullPipelineTesting, AutomatedDataLabeling, AutomatedCategoryLabeling,
AutomatedSpanLabeling, LabelMapperTesting, TemplateListAPI, TemplateDetailAPI, ConfigList,
RestAPIRequestTesting, LabelExtractorTesting)
AutomatedSpanLabeling, AutomatedTextLabeling, LabelMapperTesting, TemplateListAPI,
TemplateDetailAPI, ConfigList, RestAPIRequestTesting, LabelExtractorTesting)
urlpatterns = [
path(
@ -59,5 +59,10 @@ urlpatterns = [
route='examples/<int:example_id>/auto-labeling/spans',
view=AutomatedSpanLabeling.as_view(),
name='automated_span_labeling'
),
path(
route='examples/<int:example_id>/auto-labeling/texts',
view=AutomatedTextLabeling.as_view(),
name='automated_text_labeling'
)
]

15
backend/auto_labeling/views.py

@ -18,7 +18,7 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from api.models import Example, Project, Category, CategoryType, Annotation, Span, SpanType
from api.models import Example, Project, Category, CategoryType, Annotation, Span, SpanType, TextLabel
from members.permissions import IsInProjectOrAdmin, IsProjectAdmin
from .pipeline.execution import execute_pipeline
from .exceptions import (AutoLabelingPermissionDenied,
@ -283,3 +283,16 @@ class AutomatedSpanLabeling(AutomatedLabeling):
model = Span
label_type = SpanType
task_type = 'Span'
class AutomatedTextLabeling(AutomatedLabeling):
model = TextLabel
task_type = 'Text'
def transform(self, labels, example: Example, project: Project) -> List[Annotation]:
annotations = []
for label in labels:
label['example'] = example
label['user'] = self.request.user
annotations.append(self.model(**label))
return annotations
Loading…
Cancel
Save