Browse Source

Remove AutoLabeling APIs

pull/1650/head
Hironsan 2 years ago
parent
commit
0325201405
4 changed files with 135 additions and 216 deletions
  1. 67
      backend/auto_labeling/pipeline/execution.py
  2. 159
      backend/auto_labeling/tests/test_views.py
  3. 25
      backend/auto_labeling/urls.py
  4. 100
      backend/auto_labeling/views.py

67
backend/auto_labeling/pipeline/execution.py

@ -1,4 +1,5 @@
from typing import Type
import abc
from typing import List, Type
from auto_labeling_pipeline.labels import SequenceLabels, Seq2seqLabels, ClassificationLabels, Labels
from auto_labeling_pipeline.mappings import MappingTemplate
@ -6,6 +7,10 @@ from auto_labeling_pipeline.models import RequestModelFactory
from auto_labeling_pipeline.pipeline import pipeline
from auto_labeling_pipeline.postprocessing import PostProcessor
from api.models import Example, Project, User
from api.models import CategoryType, SpanType
from api.models import Annotation, Category, Span, TextLabel
def get_label_collection(task_type: str) -> Type[Labels]:
return {
@ -15,6 +20,63 @@ def get_label_collection(task_type: str) -> Type[Labels]:
}[task_type]
class LabelCollection(abc.ABC):
label_type = None
model = None
def __init__(self, labels):
self.labels = labels
def transform(self, project: Project, example: Example, user: User) -> List[Annotation]:
mapping = {
c.text: c for c in self.label_type.objects.filter(project=project)
}
annotations = []
for label in self.labels:
if label['label'] not in mapping:
continue
label['example'] = example
label['label'] = mapping[label['label']]
label['user'] = user
annotations.append(self.model(**label))
return annotations
def save(self, project: Project, example: Example, user: User):
labels = self.transform(project, example, user)
labels = self.model.objects.filter_annotatable_labels(labels, project)
self.model.objects.bulk_create(labels)
class Categories(LabelCollection):
label_type = CategoryType
model = Category
class Spans(LabelCollection):
label_type = SpanType
model = Span
class Texts(LabelCollection):
model = TextLabel
def transform(self, project: Project, example: Example, user: User) -> List[Annotation]:
annotations = []
for label in self.labels:
label['example'] = example
label['user'] = user
annotations.append(self.model(**label))
return annotations
def create_labels(task_type: str, labels: Labels) -> LabelCollection:
return {
'Category': Categories,
'Span': Spans,
'Text': Texts
}[task_type](labels.dict())
def execute_pipeline(text: str,
task_type: str,
model_name: str,
@ -37,4 +99,5 @@ def execute_pipeline(text: str,
mapping_template=template,
post_processing=post_processor
)
return labels.dict()
labels = create_labels(task_type, labels)
return labels

159
backend/auto_labeling/tests/test_views.py

@ -7,10 +7,10 @@ from model_mommy import mommy
from rest_framework import status
from rest_framework.reverse import reverse
from api.models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
from api.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ
from api.models import Category, Span, TextLabel
from api.tests.api.utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image,
prepare_project)
from api.tests.api.utils import CRUDMixin, make_doc, prepare_project
from auto_labeling.pipeline.execution import Categories, Spans, Texts
data_dir = pathlib.Path(__file__).parent / 'data'
@ -148,38 +148,6 @@ class TestConfigCreation(CRUDMixin):
self.assertEqual(len(response.data), 1)
class TestAutoLabelingText(CRUDMixin):
def setUp(self):
self.project = prepare_project(task=DOCUMENT_CLASSIFICATION)
make_auto_labeling_config(self.project.item)
self.example = make_doc(self.project.item)
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[])
def test_text_task(self, mock):
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
_, kwargs = mock.call_args
self.assertEqual(kwargs['text'], self.example.text)
class TestAutoLabelingImage(CRUDMixin):
def setUp(self):
self.project = prepare_project(task=IMAGE_CLASSIFICATION)
make_auto_labeling_config(self.project.item)
filepath = data_dir / 'images/1500x500.jpeg'
self.example = make_image(self.project.item, str(filepath))
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[])
def test_text_task(self, mock):
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
_, kwargs = mock.call_args
expected = str(self.example.filename)
self.assertEqual(kwargs['text'], expected)
class TestAutomatedCategoryLabeling(CRUDMixin):
def setUp(self):
@ -191,30 +159,63 @@ class TestAutomatedCategoryLabeling(CRUDMixin):
self.category_neg = mommy.make(
'CategoryType', project=self.project.item, text='NEG'
)
self.url = reverse(viewname='automated_category_labeling', args=[self.project.item.id, self.example.id])
self.loc = mommy.make('SpanType', project=self.project.item, text='LOC')
self.url = reverse(viewname='automated_labeling', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[{'label': 'POS'}])
@patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}]))
def test_category_labeling(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 1)
self.assertEqual(Category.objects.first().label, self.category_pos)
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'label': 'POS'}], [{'label': 'NEG'}]])
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
Categories([{'label': 'POS'}]),
Categories([{'label': 'NEG'}])
]
)
def test_multiple_configs(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 2)
self.assertEqual(Category.objects.first().label, self.category_pos)
self.assertEqual(Category.objects.last().label, self.category_neg)
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'label': 'POS'}], [{'label': 'POS'}]])
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
Categories([{'label': 'POS'}]),
Categories([{'label': 'POS'}])
]
)
def test_cannot_label_same_category_type(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category')
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 1)
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
Categories([{'label': 'POS'}]),
Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]),
]
)
def test_allow_multi_type_configs(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 1)
self.assertEqual(Span.objects.count(), 1)
@patch('auto_labeling.views.execute_pipeline', return_value=Categories([{'label': 'POS'}]))
def test_cannot_use_other_project_config(self, mock):
mommy.make('AutoLabelingConfig', task_type='Category')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Category.objects.count(), 0)
class TestAutomatedSpanLabeling(CRUDMixin):
@ -223,46 +224,18 @@ class TestAutomatedSpanLabeling(CRUDMixin):
self.project = prepare_project(task=SEQUENCE_LABELING)
self.example = make_doc(self.project.item)
self.loc = mommy.make('SpanType', project=self.project.item, text='LOC')
self.url = reverse(viewname='automated_span_labeling', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}])
def test_span_labeling(self, mock):
mommy.make('AutoLabelingConfig', task_type='Span')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Span.objects.count(), 1)
self.assertEqual(Span.objects.first().label, self.loc)
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}],
[{'label': 'LOC', 'start_offset': 5, 'end_offset': 10}]
]
)
def test_multiple_configs(self, mock):
mommy.make('AutoLabelingConfig', task_type='Span')
mommy.make('AutoLabelingConfig', task_type='Span')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
expected_spans = [
{'label': 'LOC', 'start_offset': 0, 'end_offset': 5},
{'label': 'LOC', 'start_offset': 5, 'end_offset': 10}
]
self.assertEqual(Span.objects.count(), len(expected_spans))
for actual, expected in zip(Span.objects.all(), expected_spans):
self.assertEqual(actual.label, self.loc)
self.assertEqual(actual.start_offset, expected['start_offset'])
self.assertEqual(actual.end_offset, expected['end_offset'])
self.url = reverse(viewname='automated_labeling', args=[self.project.item.id, self.example.id])
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
[{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}],
[{'label': 'LOC', 'start_offset': 4, 'end_offset': 10}]
Spans([{'label': 'LOC', 'start_offset': 0, 'end_offset': 5}]),
Spans([{'label': 'LOC', 'start_offset': 4, 'end_offset': 10}])
]
)
def test_cannot_label_overlapping_span(self, mock):
mommy.make('AutoLabelingConfig', task_type='Span')
mommy.make('AutoLabelingConfig', task_type='Span')
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Span', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(Span.objects.count(), 1)
@ -272,27 +245,17 @@ class TestAutomatedTextLabeling(CRUDMixin):
def setUp(self):
self.project = prepare_project(task=SEQ2SEQ)
self.example = make_doc(self.project.item)
self.url = reverse(viewname='automated_text_labeling', args=[self.project.item.id, self.example.id])
self.url = reverse(viewname='automated_labeling', args=[self.project.item.id, self.example.id])
@patch('auto_labeling.views.execute_pipeline', return_value=[{'text': 'foo'}])
def test_category_labeling(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(TextLabel.objects.count(), 1)
self.assertEqual(TextLabel.objects.first().text, 'foo')
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'bar'}]])
def test_multiple_configs(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text')
mommy.make('AutoLabelingConfig', task_type='Text')
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(TextLabel.objects.count(), 2)
self.assertEqual(TextLabel.objects.first().text, 'foo')
self.assertEqual(TextLabel.objects.last().text, 'bar')
@patch('auto_labeling.views.execute_pipeline', side_effect=[[{'text': 'foo'}], [{'text': 'foo'}]])
def test_cannot_label_same_category_type(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text')
mommy.make('AutoLabelingConfig', task_type='Text')
@patch(
'auto_labeling.views.execute_pipeline',
side_effect=[
Texts([{'text': 'foo'}]),
Texts([{'text': 'foo'}])
]
)
def test_cannot_label_same_text(self, mock):
mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item)
mommy.make('AutoLabelingConfig', task_type='Text', project=self.project.item)
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
self.assertEqual(TextLabel.objects.count(), 1)

25
backend/auto_labeling/urls.py

@ -1,7 +1,7 @@
from django.urls import path
from .views import (ConfigDetail, FullPipelineTesting, AutomatedDataLabeling, AutomatedCategoryLabeling,
AutomatedSpanLabeling, AutomatedTextLabeling, LabelMapperTesting, TemplateListAPI,
from .views import (ConfigDetail, FullPipelineTesting, AutomatedLabeling,
LabelMapperTesting, TemplateListAPI,
TemplateDetailAPI, ConfigList, RestAPIRequestTesting, LabelExtractorTesting)
urlpatterns = [
@ -30,11 +30,6 @@ urlpatterns = [
view=FullPipelineTesting.as_view(),
name='auto_labeling_config_test'
),
path(
route='examples/<int:example_id>/auto-labeling',
view=AutomatedDataLabeling.as_view(),
name='auto_labeling_annotation'
),
path(
route='auto-labeling-parameter-testing',
view=RestAPIRequestTesting.as_view(),
@ -51,18 +46,8 @@ urlpatterns = [
name='auto_labeling_mapping_test'
),
path(
route='examples/<int:example_id>/auto-labeling/categories',
view=AutomatedCategoryLabeling.as_view(),
name='automated_category_labeling'
),
path(
route='examples/<int:example_id>/auto-labeling/spans',
view=AutomatedSpanLabeling.as_view(),
name='automated_span_labeling'
route='examples/<int:example_id>/auto-labeling',
view=AutomatedLabeling.as_view(),
name='automated_labeling'
),
path(
route='examples/<int:example_id>/auto-labeling/texts',
view=AutomatedTextLabeling.as_view(),
name='automated_text_labeling'
)
]

100
backend/auto_labeling/views.py

@ -1,6 +1,4 @@
import abc
import json
from typing import List
import botocore.exceptions
import requests
@ -18,7 +16,7 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from api.models import Example, Project, Category, CategoryType, Annotation, Span, SpanType, TextLabel
from api.models import Example, Project
from members.permissions import IsInProjectOrAdmin, IsProjectAdmin
from .pipeline.execution import execute_pipeline
from .exceptions import (AutoLabelingPermissionDenied,
@ -187,63 +185,14 @@ class LabelMapperTesting(APIView):
return Response(labels.dict(), status=status.HTTP_200_OK)
class AutomatedDataLabeling(generics.CreateAPIView):
pagination_class = None
permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
swagger_schema = None
def get_serializer_class(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
self.serializer_class = get_annotation_serializer(task=project.project_type)
return self.serializer_class
def create(self, request, *args, **kwargs):
labels = self.extract()
labels = self.transform(labels)
serializer = self.get_serializer(data=labels, many=True)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def perform_create(self, serializer):
serializer.save(user=self.request.user)
def extract(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
example = get_object_or_404(Example, pk=self.kwargs['example_id'])
config = project.auto_labeling_config.first()
if not config:
raise AutoLabelingPermissionDenied()
return execute_pipeline(
text=example.data,
task_type=project.project_type,
model_name=config.model_name,
model_attrs=config.model_attrs,
template=config.template,
label_mapping=config.label_mapping
)
def transform(self, labels):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
for label in labels:
label['example'] = self.kwargs['example_id']
if 'label' in label:
label['label'] = project.labels.get(text=label.pop('label')).id
return labels
class AutomatedLabeling(abc.ABC, generics.CreateAPIView):
class AutomatedLabeling(generics.CreateAPIView):
permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
swagger_schema = None
model = None
label_type = None
task_type = None
def create(self, request, *args, **kwargs):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
example = get_object_or_404(Example, pk=self.kwargs['example_id'])
configs = AutoLabelingConfig.objects.filter(task_type=self.task_type)
configs = AutoLabelingConfig.objects.filter(project=project)
for config in configs:
labels = execute_pipeline(
text=example.data,
@ -253,46 +202,5 @@ class AutomatedLabeling(abc.ABC, generics.CreateAPIView):
template=config.template,
label_mapping=config.label_mapping
)
labels = self.transform(labels, example, project)
labels = self.model.objects.filter_annotatable_labels(labels, project)
self.model.objects.bulk_create(labels)
labels.save(project, example, self.request.user)
return Response({'ok': True}, status=status.HTTP_201_CREATED)
def transform(self, labels, example: Example, project: Project) -> List[Annotation]:
mapping = {
c.text: c for c in self.label_type.objects.filter(project=project)
}
annotations = []
for label in labels:
if label['label'] not in mapping:
continue
label['example'] = example
label['label'] = mapping[label['label']]
label['user'] = self.request.user
annotations.append(self.model(**label))
return annotations
class AutomatedCategoryLabeling(AutomatedLabeling):
model = Category
label_type = CategoryType
task_type = 'Category'
class AutomatedSpanLabeling(AutomatedLabeling):
model = Span
label_type = SpanType
task_type = 'Span'
class AutomatedTextLabeling(AutomatedLabeling):
model = TextLabel
task_type = 'Text'
def transform(self, labels, example: Example, project: Project) -> List[Annotation]:
annotations = []
for label in labels:
label['example'] = example
label['user'] = self.request.user
annotations.append(self.model(**label))
return annotations
Loading…
Cancel
Save