Browse Source

Remove AutoLabeling APIs

pull/1650/head
Hironsan 2 years ago
parent
commit
0325201405
4 changed files with 135 additions and 216 deletions
  1. 67
      backend/auto_labeling/pipeline/execution.py
  2. 159
      backend/auto_labeling/tests/test_views.py
  3. 25
      backend/auto_labeling/urls.py
  4. 100
      backend/auto_labeling/views.py

67
backend/auto_labeling/pipeline/execution.py

@ -1,4 +1,5 @@
from typing import Type
import abc
from typing import List, Type
from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels
from auto_labeling_pipeline.mappings import MappingTemplate from auto_labeling_pipeline.mappings import MappingTemplate
@ -6,6 +7,10 @@ from auto_labeling_pipeline.models import RequestModelFactory
from auto_labeling_pipeline.pipeline import pipeline from auto_labeling_pipeline.pipeline import pipeline
from auto_labeling_pipeline.postprocessing import PostProcessor from auto_labeling_pipeline.postprocessing import PostProcessor
from api.models import Example, Project, User
from api.models import CategoryType, SpanType
from api.models import Annotation, Category, Span, TextLabel
def get_label_collection(task_type: str) -> Type[Labels]: def get_label_collection(task_type: str) -> Type[Labels]:
return { return {
@ -15,6 +20,63 @@ def get_label_collection(task_type: str) -> Type[Labels]:
}[task_type] }[task_type]
class LabelCollection(abc.ABC):
    """Base class that turns raw pipeline labels into annotation objects.

    Subclasses set ``label_type`` (the model whose rows define a project's
    labels) and ``model`` (the annotation model to instantiate).
    """

    # Model queried via ``objects.filter(project=...)``; rows are matched by ``text``.
    label_type = None
    # Annotation model instantiated once per accepted label.
    model = None

    def __init__(self, labels):
        """Store ``labels``, an iterable of dicts of annotation keyword arguments."""
        self.labels = labels

    def transform(self, project: Project, example: Example, user: User) -> List[Annotation]:
        """Build unsaved annotations for the labels that exist in ``project``.

        Each dict's ``'label'`` text is replaced by the matching ``label_type``
        row; dicts whose text has no match in the project are skipped. The
        stored dicts are left untouched, so the method can be called again.

        Raises:
            KeyError: if a label dict has no ``'label'`` key.
        """
        mapping = {
            c.text: c for c in self.label_type.objects.filter(project=project)
        }
        annotations = []
        for label in self.labels:
            text = label['label']
            if text not in mapping:
                continue
            # Copy instead of assigning into ``label``: overwriting 'label' in
            # place would make a later call skip every entry.
            fields = {**label, 'example': example, 'label': mapping[text], 'user': user}
            annotations.append(self.model(**fields))
        return annotations

    def save(self, project: Project, example: Example, user: User):
        """Transform the labels, filter them, and bulk-create what remains.

        Filtering is delegated to ``model.objects.filter_annotatable_labels``.
        NOTE(review): presumably this drops labels that conflict with existing
        annotations; confirm against the manager implementation.
        """
        labels = self.transform(project, example, user)
        labels = self.model.objects.filter_annotatable_labels(labels, project)
        self.model.objects.bulk_create(labels)
class Categories(LabelCollection):
    """Label collection that builds ``Category`` annotations from ``CategoryType`` labels."""

    model = Category
    label_type = CategoryType
class Spans(LabelCollection):
    """Label collection that builds ``Span`` annotations from ``SpanType`` labels."""

    model = Span
    label_type = SpanType
class Texts(LabelCollection):
    """Label collection that builds ``TextLabel`` annotations.

    No ``label_type`` is set, so ``transform`` is overridden to skip the
    label-type lookup and keep every label.
    """

    model = TextLabel

    def transform(self, project: Project, example: Example, user: User) -> List[Annotation]:
        """Build one unsaved ``TextLabel`` per stored label dict.

        ``project`` is accepted for signature compatibility with the base
        class and is not used. The stored dicts are copied rather than
        modified, so the caller's data is left as it was passed in.
        """
        return [
            self.model(**{**label, 'example': example, 'user': user})
            for label in self.labels
        ]
def create_labels(task_type: str, labels: Labels) -> LabelCollection:
    """Wrap pipeline output in the collection class registered for ``task_type``.

    ``labels.dict()`` is handed to the collection constructor.

    Raises:
        KeyError: if ``task_type`` is not 'Category', 'Span' or 'Text'.
    """
    collections = {
        'Category': Categories,
        'Span': Spans,
        'Text': Texts
    }
    collection_class = collections[task_type]
    return collection_class(labels.dict())
def execute_pipeline(text: str, def execute_pipeline(text: str,
task_type: str, task_type: str,
model_name: str, model_name: str,
@ -37,4 +99,5 @@ def execute_pipeline(text: str,
mapping_template=template, mapping_template=template,
post_processing=post_processor post_processing=post_processor
) )
return labels.dict()
labels = create_labels(task_type, labels)
return labels

159
backend/auto_labeling/tests/test_views.py

@ -7,10 +7,10 @@ from model_mommy import mommy
from rest_framework import status from rest_framework import status
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from api.models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
from api.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
from api.models import Category, Span, TextLabel from api.models import Category, Span, TextLabel
from api.tests.api.utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image,
prepare_project)
from api.tests.api.utils import CRUDMixin, make_doc, prepare_project
from auto_labeling.pipeline.execution import Categories, Spans, Texts
data_dir = pathlib.Path(__file__).parent / 'data' data_dir = pathlib.Path(__file__).parent / 'data'
@ -148,38 +148,6 @@ class TestConfigCreation(CRUDMixin):
self.assertEqual(len(response.data), 1) self.assertEqual(len(response.data), 1)
class TestAutoLabelingText(CRUDMixin):
def setUp(self):
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
make_auto_labeling_config(self.project.item)
self.example = make_doc(self.project.item)
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[])
def test_text_task(self, mock):
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
_, kwargs = mock.call_args
self.assertEqual(kwargs['text'], self.example.text)
class TestAutoLabelingImage(CRUDMixin):
def setUp(self):
self.project = prepare_project(task=IMAGE_CLASSIFICATION)
make_auto_labeling_config(self.project.item)
filepath = data_dir / 'images/1500x500.jpeg'
self.example = make_image(self.project.item, str(filepath))
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[])
def test_text_task(self, mock):
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
_, kwargs = mock.call_args
expected = str(self.example.filename)
self.assertEqual(kwargs['text'], expected)
class TestAutomatedCategoryLabeling(CRUDMixin): class TestAutomatedCategoryLabeling(CRUDMixin):
def setUp(self): def setUp(self):
@ -191,30 +159,63 @@ class TestAutomatedCategoryLabeling(CRUDMixin):
self.category_neg = mommy.make( self.category_neg = mommy.make(
'CategoryType', project=self.project.item, text='NEG' 'CategoryType', project=self.project.item, text='NEG'
) )
self.url = reverse(viewname='automated_category_labeling', args=[self.project.item.id, self.example.id])
self.loc = mommy.make('SpanType', project=self.project.item, text='LOC')
self.url = reverse(viewname='automated_labeling', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[{'label': 'POS'}])
@patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}]))
def test_category_labeling(self, mock): def test_category_labeling(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 1) self.assertEqual(Category.objects.count(), 1)
self.assertEqual(Category.objects.first().label, self.category_pos) self.assertEqual(Category.objects.first().label, self.category_pos)
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'label': 'POS'}], [{'label': 'NEG'}]])
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
Categories([{'label': 'POS'}]),
Categories([{'label': 'NEG'}])
]
)
def test_multiple_configs(self, mock): def test_multiple_configs(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 2) self.assertEqual(Category.objects.count(), 2)
self.assertEqual(Category.objects.first().label, self.category_pos) self.assertEqual(Category.objects.first().label, self.category_pos)
self.assertEqual(Category.objects.last().label, self.category_neg) self.assertEqual(Category.objects.last().label, self.category_neg)
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'label': 'POS'}], [{'label': 'POS'}]])
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
Categories([{'label': 'POS'}]),
Categories([{'label': 'POS'}])
]
)
def test_cannot_label_same_category_type(self, mock): def test_cannot_label_same_category_type(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 1)
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
Categories([{'label': 'POS'}]),
Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]),
]
)
def test_allow_multi_type_configs(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 1) self.assertEqual(Category.objects.count(), 1)
self.assertEqual(Span.objects.count(), 1)
@patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}]))
def test_cannot_use_other_project_config(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 0)
class TestAutomatedSpanLabeling(CRUDMixin): class TestAutomatedSpanLabeling(CRUDMixin):
@ -223,46 +224,18 @@ class TestAutomatedSpanLabeling(CRUDMixin):
self.project = prepare_project(task=SEQUENCE_LABELING) self.project = prepare_project(task=SEQUENCE_LABELING)
self.example = make_doc(self.project.item) self.example = make_doc(self.project.item)
self.loc = mommy.make('SpanType', project=self.project.item, text='LOC') self.loc = mommy.make('SpanType', project=self.project.item, text='LOC')
self.url = reverse(viewname='automated_span_labeling', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}])
def test_span_labeling(self, mock):
mommy.make('AutoLabelingConfig', task_type='Span')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Span.objects.count(), 1)
self.assertEqual(Span.objects.first().label, self.loc)
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}],
[{'label': 'LOC', 'start_offset': 5, 'end_offset': 10}]
]
)
def test_multiple_configs(self, mock):
mommy.make('AutoLabelingConfig', task_type='Span')
mommy.make('AutoLabelingConfig', task_type='Span')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
expected_spans = [
{'label': 'LOC', 'start_offset': 0, 'end_offset': 5},
{'label': 'LOC', 'start_offset': 5, 'end_offset': 10}
]
self.assertEqual(Span.objects.count(), len(expected_spans))
for actual, expected in zip(Span.objects.all(), expected_spans):
self.assertEqual(actual.label, self.loc)
self.assertEqual(actual.start_offset, expected['start_offset'])
self.assertEqual(actual.end_offset, expected['end_offset'])
self.url = reverse(viewname='automated_labeling', args=[self.project.item.id, self.example.id])
@patch( @patch(
'auto_labeling.views.execute_pipeline', 'auto_labeling.views.execute_pipeline',
side_effect=[ side_effect=[
[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}],
[{'label': 'LOC', 'start_offset': 4, 'end_offset': 10}]
Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]),
Spans([{'label': 'LOC', 'start_offset': 4, 'end_offset': 10}])
] ]
) )
def test_cannot_label_overlapping_span(self, mock): def test_cannot_label_overlapping_span(self, mock):
mommy.make('AutoLabelingConfig', task_type='Span')
mommy.make('AutoLabelingConfig', task_type='Span')
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Span.objects.count(), 1) self.assertEqual(Span.objects.count(), 1)
@ -272,27 +245,17 @@ class TestAutomatedTextLabeling(CRUDMixin):
def setUp(self): def setUp(self):
self.project = prepare_project(task=SEQ2SEQ) self.project = prepare_project(task=SEQ2SEQ)
self.example = make_doc(self.project.item) self.example = make_doc(self.project.item)
self.url = reverse(viewname='automated_text_labeling', args=[self.project.item.id, self.example.id])
self.url = reverse(viewname='automated_labeling', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[{'text': 'foo'}])
def test_category_labeling(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(TextLabel.objects.count(), 1)
self.assertEqual(TextLabel.objects.first().text, 'foo')
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'bar'}]])
def test_multiple_configs(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text')
mommy.make('AutoLabelingConfig', task_type='Text')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(TextLabel.objects.count(), 2)
self.assertEqual(TextLabel.objects.first().text, 'foo')
self.assertEqual(TextLabel.objects.last().text, 'bar')
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'foo'}]])
def test_cannot_label_same_category_type(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text')
mommy.make('AutoLabelingConfig', task_type='Text')
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
Texts([{'text': 'foo'}]),
Texts([{'text': 'foo'}])
]
)
def test_cannot_label_same_text(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED) self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(TextLabel.objects.count(), 1) self.assertEqual(TextLabel.objects.count(), 1)

25
backend/auto_labeling/urls.py

@ -1,7 +1,7 @@
from django.urls import path from django.urls import path
from .views import (ConfigDetail, FullPipelineTesting, AutomatedDataLabeling, AutomatedCategoryLabeling,
AutomatedSpanLabeling, AutomatedTextLabeling, LabelMapperTesting, TemplateListAPI,
from .views import (ConfigDetail, FullPipelineTesting, AutomatedLabeling,
LabelMapperTesting, TemplateListAPI,
TemplateDetailAPI, ConfigList, RestAPIRequestTesting, LabelExtractorTesting) TemplateDetailAPI, ConfigList, RestAPIRequestTesting, LabelExtractorTesting)
urlpatterns = [ urlpatterns = [
@ -30,11 +30,6 @@ urlpatterns = [
view=FullPipelineTesting.as_view(), view=FullPipelineTesting.as_view(),
name='auto_labeling_config_test' name='auto_labeling_config_test'
), ),
path(
route='examples/<int:example_id>/auto-labeling',
view=AutomatedDataLabeling.as_view(),
name='auto_labeling_annotation'
),
path( path(
route='auto-labeling-parameter-testing', route='auto-labeling-parameter-testing',
view=RestAPIRequestTesting.as_view(), view=RestAPIRequestTesting.as_view(),
@ -51,18 +46,8 @@ urlpatterns = [
name='auto_labeling_mapping_test' name='auto_labeling_mapping_test'
), ),
path( path(
route='examples/<int:example_id>/auto-labeling/categories',
view=AutomatedCategoryLabeling.as_view(),
name='automated_category_labeling'
),
path(
route='examples/<int:example_id>/auto-labeling/spans',
view=AutomatedSpanLabeling.as_view(),
name='automated_span_labeling'
route='examples/<int:example_id>/auto-labeling',
view=AutomatedLabeling.as_view(),
name='automated_labeling'
), ),
path(
route='examples/<int:example_id>/auto-labeling/texts',
view=AutomatedTextLabeling.as_view(),
name='automated_text_labeling'
)
] ]

100
backend/auto_labeling/views.py

@ -1,6 +1,4 @@
import abc
import json import json
from typing import List
import botocore.exceptions import botocore.exceptions
import requests import requests
@ -18,7 +16,7 @@ from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from api.models import Example, Project, Category, CategoryType, Annotation, Span, SpanType, TextLabel
from api.models import Example, Project
from members.permissions import IsInProjectOrAdmin, IsProjectAdmin from members.permissions import IsInProjectOrAdmin, IsProjectAdmin
from .pipeline.execution import execute_pipeline from .pipeline.execution import execute_pipeline
from .exceptions import (AutoLabelingPermissionDenied, from .exceptions import (AutoLabelingPermissionDenied,
@ -187,63 +185,14 @@ class LabelMapperTesting(APIView):
return Response(labels.dict(), status=status.HTTP_200_OK) return Response(labels.dict(), status=status.HTTP_200_OK)
class AutomatedDataLabeling(generics.CreateAPIView):
pagination_class = None
permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
swagger_schema = None
def get_serializer_class(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
self.serializer_class = get_annotation_serializer(task=project.project_type)
return self.serializer_class
def create(self, request, *args, **kwargs):
labels = self.extract()
labels = self.transform(labels)
serializer = self.get_serializer(data=labels, many=True)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def perform_create(self, serializer):
serializer.save(user=self.request.user)
def extract(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
example = get_object_or_404(Example, pk=self.kwargs['example_id'])
config = project.auto_labeling_config.first()
if not config:
raise AutoLabelingPermissionDenied()
return execute_pipeline(
text=example.data,
task_type=project.project_type,
model_name=config.model_name,
model_attrs=config.model_attrs,
template=config.template,
label_mapping=config.label_mapping
)
def transform(self, labels):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
for label in labels:
label['example'] = self.kwargs['example_id']
if 'label' in label:
label['label'] = project.labels.get(text=label.pop('label')).id
return labels
class AutomatedLabeling(abc.ABC, generics.CreateAPIView):
class AutomatedLabeling(generics.CreateAPIView):
permission_classes = [IsAuthenticated & IsInProjectOrAdmin] permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
swagger_schema = None swagger_schema = None
model = None
label_type = None
task_type = None
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
project = get_object_or_404(Project, pk=self.kwargs['project_id']) project = get_object_or_404(Project, pk=self.kwargs['project_id'])
example = get_object_or_404(Example, pk=self.kwargs['example_id']) example = get_object_or_404(Example, pk=self.kwargs['example_id'])
configs = AutoLabelingConfig.objects.filter(task_type=self.task_type)
configs = AutoLabelingConfig.objects.filter(project=project)
for config in configs: for config in configs:
labels = execute_pipeline( labels = execute_pipeline(
text=example.data, text=example.data,
@ -253,46 +202,5 @@ class AutomatedLabeling(abc.ABC, generics.CreateAPIView):
template=config.template, template=config.template,
label_mapping=config.label_mapping label_mapping=config.label_mapping
) )
labels = self.transform(labels, example, project)
labels = self.model.objects.filter_annotatable_labels(labels, project)
self.model.objects.bulk_create(labels)
labels.save(project, example, self.request.user)
return Response({'ok': True}, status=status.HTTP_201_CREATED) return Response({'ok': True}, status=status.HTTP_201_CREATED)
def transform(self, labels, example: Example, project: Project) -> List[Annotation]:
mapping = {
c.text: c for c in self.label_type.objects.filter(project=project)
}
annotations = []
for label in labels:
if label['label'] not in mapping:
continue
label['example'] = example
label['label'] = mapping[label['label']]
label['user'] = self.request.user
annotations.append(self.model(**label))
return annotations
class AutomatedCategoryLabeling(AutomatedLabeling):
model = Category
label_type = CategoryType
task_type = 'Category'
class AutomatedSpanLabeling(AutomatedLabeling):
model = Span
label_type = SpanType
task_type = 'Span'
class AutomatedTextLabeling(AutomatedLabeling):
model = TextLabel
task_type = 'Text'
def transform(self, labels, example: Example, project: Project) -> List[Annotation]:
annotations = []
for label in labels:
label['example'] = example
label['user'] = self.request.user
annotations.append(self.model(**label))
return annotations
Loading…
Cancel
Save