Browse Source

Move auto labeling code to auto labeling app

pull/1646/head
Hironsan 2 years ago
parent
commit
7a3f9b616b
25 changed files with 200 additions and 172 deletions
  1. 2
      Pipfile
  2. 18
      backend/api/admin.py
  3. 30
      backend/api/exceptions.py
  4. 53
      backend/api/serializers.py
  5. 5
      backend/api/tests/api/test_annotation.py
  6. 3
      backend/api/tests/api/test_document.py
  7. 3
      backend/api/tests/api/test_label.py
  8. 7
      backend/api/tests/api/utils.py
  9. 5
      backend/api/tests/test_filters.py
  10. 6
      backend/api/tests/test_models.py
  11. 49
      backend/api/urls.py
  12. 1
      backend/app/settings.py
  13. 1
      backend/app/urls.py
  14. 0
      backend/auto_labeling/__init__.py
  15. 17
      backend/auto_labeling/admin.py
  16. 6
      backend/auto_labeling/apps.py
  17. 29
      backend/auto_labeling/exceptions.py
  18. 0
      backend/auto_labeling/migrations/__init__.py
  19. 0
      backend/auto_labeling/models.py
  20. 54
      backend/auto_labeling/serializers.py
  21. 0
      backend/auto_labeling/tests/__init__.py
  22. 16
      backend/auto_labeling/tests/test_views.py
  23. 53
      backend/auto_labeling/urls.py
  24. 12
      backend/auto_labeling/views.py
  25. 2
      backend/members/tests.py

2
Pipfile

@ -60,6 +60,6 @@ python_version = "3.8"
isort = "isort api -c --skip migrations"
flake8 = "flake8 --filename \"*.py\" --extend-exclude \"api/migrations\""
wait_for_db = "python manage.py wait_for_db"
test = "python manage.py test api.tests roles.tests members.tests metrics.tests users.tests data_import.tests data_export.tests"
test = "python manage.py test --pattern=\"test*.py\""
migrate = "python manage.py migrate"
collectstatic = "python manage.py collectstatic --noinput"

18
backend/api/admin.py

@ -1,8 +1,8 @@
from django.contrib import admin
from .models import (AutoLabelingConfig, Category, CategoryType, Comment,
Example, Project, Seq2seqProject, SequenceLabelingProject,
Span, SpanType, Tag, TextClassificationProject, TextLabel)
from .models import (Category, CategoryType, Comment, Example, Project,
Seq2seqProject, SequenceLabelingProject, Span, SpanType,
Tag, TextClassificationProject, TextLabel)
class LabelAdmin(admin.ModelAdmin):
@ -58,18 +58,6 @@ class CommentAdmin(admin.ModelAdmin):
search_fields = ('user',)
class AutoLabelingConfigAdmin(admin.ModelAdmin):
list_display = ('project', 'model_name', 'model_attrs',)
ordering = ('project',)
def get_readonly_fields(self, request, obj=None):
if obj:
return ["model_name"]
else:
return []
admin.site.register(AutoLabelingConfig, AutoLabelingConfigAdmin)
admin.site.register(Category, CategoryAdmin)
admin.site.register(Span, SpanAdmin)
admin.site.register(TextLabel, TextLabelAdmin)

30
backend/api/exceptions.py

@ -1,6 +1,5 @@
from rest_framework import status
from rest_framework.exceptions import (APIException, PermissionDenied,
ValidationError)
from rest_framework.exceptions import APIException
class FileParseException(APIException):
@ -13,33 +12,6 @@ class FileParseException(APIException):
super().__init__(detail, code)
class AutoLabelingException(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Auto labeling not allowed for the document with labels.'
class AutoLabelingPermissionDenied(PermissionDenied):
default_detail = 'You do not have permission to perform auto labeling.' \
'Please ask the project administrators to add you.'
class URLConnectionError(ValidationError):
default_detail = 'Failed to establish a connection. Please check the URL or network.'
class AWSTokenError(ValidationError):
default_detail = 'The security token included in the request is invalid.'
class SampleDataException(ValidationError):
default_detail = 'The response is empty. Maybe the sample data is not appropriate.' \
'Please specify another sample data which returns at least one label.'
class TemplateMappingError(ValidationError):
default_detail = 'The response cannot be mapped. You might need to change the template.'
class LabelValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'You cannot create a label with same name or shortcut key.'

53
backend/api/serializers.py

@ -1,11 +1,8 @@
from auto_labeling_pipeline.models import RequestModelFactory
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_polymorphic.serializers import PolymorphicSerializer
from .models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT, AnnotationRelations,
AutoLabelingConfig, Category, CategoryType, Comment,
from .models import (AnnotationRelations, Category, CategoryType, Comment,
Example, ExampleState, ImageClassificationProject,
IntentDetectionAndSlotFillingProject, Label, Project,
RelationTypes, Seq2seqProject, SequenceLabelingProject,
@ -284,54 +281,6 @@ class TextLabelSerializer(serializers.ModelSerializer):
read_only_fields = ('user',)
class AutoLabelingConfigSerializer(serializers.ModelSerializer):
class Meta:
model = AutoLabelingConfig
fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
read_only_fields = ('created_at', 'updated_at')
def validate_model_name(self, value):
try:
RequestModelFactory.find(value)
except NameError:
raise serializers.ValidationError(f'The specified model name {value} does not exist.')
return value
def valid_label_mapping(self, value):
if isinstance(value, dict):
return value
else:
raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
def validate(self, data):
try:
RequestModelFactory.create(data['model_name'], data['model_attrs'])
except Exception:
model = RequestModelFactory.find(data['model_name'])
schema = model.schema()
required_fields = ', '.join(schema['required']) if 'required' in schema else ''
raise serializers.ValidationError(
'The attributes does not match the model.'
'You need to correctly specify the required fields: {}'.format(required_fields)
)
return data
def get_annotation_serializer(task: str):
mapping = {
DOCUMENT_CLASSIFICATION: CategorySerializer,
SEQUENCE_LABELING: SpanSerializer,
SEQ2SEQ: TextLabelSerializer,
SPEECH2TEXT: TextLabelSerializer,
IMAGE_CLASSIFICATION: CategorySerializer,
}
try:
return mapping[task]
except KeyError:
raise ValueError(f'{task} is not implemented.')
class RelationTypesSerializer(serializers.ModelSerializer):
def validate(self, attrs):

5
backend/api/tests/api/test_annotation.py

@ -1,8 +1,9 @@
from rest_framework import status
from rest_framework.reverse import reverse
from ...models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
Category, Span, TextLabel)
from api.models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
Category, Span, TextLabel)
from .utils import (CRUDMixin, make_annotation, make_doc, make_label,
make_user, prepare_project)

3
backend/api/tests/api/test_document.py

@ -3,7 +3,8 @@ from django.utils.http import urlencode
from rest_framework import status
from rest_framework.reverse import reverse
from ...models import DOCUMENT_CLASSIFICATION
from api.models import DOCUMENT_CLASSIFICATION
from .utils import (CRUDMixin, assign_user_to_role, make_doc,
make_example_state, make_user, prepare_project)

3
backend/api/tests/api/test_label.py

@ -5,7 +5,8 @@ from rest_framework import status
from rest_framework.reverse import reverse
from rest_framework.test import APITestCase
from ...models import DOCUMENT_CLASSIFICATION
from api.models import DOCUMENT_CLASSIFICATION
from .utils import (DATA_DIR, CRUDMixin, make_label, make_project, make_user,
prepare_project)

7
backend/api/tests/api/utils.py

@ -8,13 +8,12 @@ from model_mommy import mommy
from rest_framework import status
from rest_framework.test import APITestCase
from api.models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT)
from members.models import Member
from roles.models import Role
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT)
DATA_DIR = os.path.join(os.path.dirname(__file__), '../../../data_import/tests/data')

5
backend/api/tests/test_filters.py

@ -2,8 +2,9 @@ from unittest.mock import MagicMock
from django.test import TestCase
from ..filters import ExampleFilter
from ..models import Example
from api.filters import ExampleFilter
from api.models import Example
from .api.utils import make_doc, make_example_state, prepare_project

6
backend/api/tests/test_models.py

@ -3,8 +3,10 @@ from django.db.utils import IntegrityError
from django.test import TestCase
from model_mommy import mommy
from ..models import (SEQUENCE_LABELING, Category, CategoryType, ExampleState,
Span, SpanType, TextLabel, generate_random_hex_color)
from api.models import (SEQUENCE_LABELING, Category, CategoryType,
ExampleState, Span, SpanType, TextLabel,
generate_random_hex_color)
from .api.utils import prepare_project

49
backend/api/urls.py

@ -1,7 +1,7 @@
from django.urls import include, path
from .views import (annotation, auto_labeling, comment, example, example_state,
health, label, project, tag, task)
from .views import (annotation, comment, example, example_state, health, label,
project, tag, task)
from .views.tasks import category, relation, span, text
urlpatterns_project = [
@ -140,51 +140,6 @@ urlpatterns_project = [
view=example_state.ExampleStateList.as_view(),
name='example_state_list'
),
path(
route='auto-labeling-templates',
view=auto_labeling.AutoLabelingTemplateListAPI.as_view(),
name='auto_labeling_templates'
),
path(
route='auto-labeling-templates/<str:option_name>',
view=auto_labeling.AutoLabelingTemplateDetailAPI.as_view(),
name='auto_labeling_template'
),
path(
route='auto-labeling-configs',
view=auto_labeling.AutoLabelingConfigList.as_view(),
name='auto_labeling_configs'
),
path(
route='auto-labeling-configs/<int:config_id>',
view=auto_labeling.AutoLabelingConfigDetail.as_view(),
name='auto_labeling_config'
),
path(
route='auto-labeling-config-testing',
view=auto_labeling.AutoLabelingConfigTest.as_view(),
name='auto_labeling_config_test'
),
path(
route='examples/<int:example_id>/auto-labeling',
view=auto_labeling.AutoLabelingAnnotation.as_view(),
name='auto_labeling_annotation'
),
path(
route='auto-labeling-parameter-testing',
view=auto_labeling.AutoLabelingConfigParameterTest.as_view(),
name='auto_labeling_parameter_testing'
),
path(
route='auto-labeling-template-testing',
view=auto_labeling.AutoLabelingTemplateTest.as_view(),
name='auto_labeling_template_test'
),
path(
route='auto-labeling-mapping-testing',
view=auto_labeling.AutoLabelingMappingTest.as_view(),
name='auto_labeling_mapping_test'
)
]
urlpatterns = [

1
backend/app/settings.py

@ -58,6 +58,7 @@ INSTALLED_APPS = [
'users.apps.UsersConfig',
'data_import.apps.DataImportConfig',
'data_export.apps.DataExportConfig',
'auto_labeling.apps.AutoLabelingConfig',
'rest_framework',
'rest_framework.authtoken',
'django_filters',

1
backend/app/urls.py

@ -47,6 +47,7 @@ urlpatterns += [
path('v1/', include('data_export.urls')),
path('v1/projects/<int:project_id>/', include('members.urls')),
path('v1/projects/<int:project_id>/metrics/', include('metrics.urls')),
path('v1/projects/<int:project_id>/', include('auto_labeling.urls')),
path('swagger/', schema_view.with_ui('swagger', cache_timeout=0), name='schema-swagger-ui'),
re_path('', TemplateView.as_view(template_name='index.html')),
]

0
backend/auto_labeling/__init__.py

17
backend/auto_labeling/admin.py

@ -0,0 +1,17 @@
from django.contrib import admin
from api.models import AutoLabelingConfig
class AutoLabelingConfigAdmin(admin.ModelAdmin):
list_display = ('project', 'model_name', 'model_attrs',)
ordering = ('project',)
def get_readonly_fields(self, request, obj=None):
if obj:
return ["model_name"]
else:
return []
admin.site.register(AutoLabelingConfig, AutoLabelingConfigAdmin)

6
backend/auto_labeling/apps.py

@ -0,0 +1,6 @@
from django.apps import AppConfig
class AutoLabelingConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'auto_labeling'

29
backend/auto_labeling/exceptions.py

@ -0,0 +1,29 @@
from rest_framework import status
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
class AutoLabelingException(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Auto labeling not allowed for the document with labels.'
class AutoLabelingPermissionDenied(PermissionDenied):
default_detail = 'You do not have permission to perform auto labeling.' \
'Please ask the project administrators to add you.'
class URLConnectionError(ValidationError):
default_detail = 'Failed to establish a connection. Please check the URL or network.'
class AWSTokenError(ValidationError):
default_detail = 'The security token included in the request is invalid.'
class SampleDataException(ValidationError):
default_detail = 'The response is empty. Maybe the sample data is not appropriate.' \
'Please specify another sample data which returns at least one label.'
class TemplateMappingError(ValidationError):
default_detail = 'The response cannot be mapped. You might need to change the template.'

0
backend/auto_labeling/migrations/__init__.py

0
backend/auto_labeling/models.py

54
backend/auto_labeling/serializers.py

@ -0,0 +1,54 @@
from auto_labeling_pipeline.models import RequestModelFactory
from rest_framework import serializers
from api.models import AutoLabelingConfig, DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ, SPEECH2TEXT, \
IMAGE_CLASSIFICATION
from api.serializers import CategorySerializer, SpanSerializer, TextLabelSerializer
class AutoLabelingConfigSerializer(serializers.ModelSerializer):
class Meta:
model = AutoLabelingConfig
fields = ('id', 'model_name', 'model_attrs', 'template', 'label_mapping', 'default')
read_only_fields = ('created_at', 'updated_at')
def validate_model_name(self, value):
try:
RequestModelFactory.find(value)
except NameError:
raise serializers.ValidationError(f'The specified model name {value} does not exist.')
return value
def valid_label_mapping(self, value):
if isinstance(value, dict):
return value
else:
raise serializers.ValidationError(f'The {value} is not a dictionary. Please specify it as a dictionary.')
def validate(self, data):
try:
RequestModelFactory.create(data['model_name'], data['model_attrs'])
except Exception:
model = RequestModelFactory.find(data['model_name'])
schema = model.schema()
required_fields = ', '.join(schema['required']) if 'required' in schema else ''
raise serializers.ValidationError(
'The attributes does not match the model.'
'You need to correctly specify the required fields: {}'.format(required_fields)
)
return data
def get_annotation_serializer(task: str):
mapping = {
DOCUMENT_CLASSIFICATION: CategorySerializer,
SEQUENCE_LABELING: SpanSerializer,
SEQ2SEQ: TextLabelSerializer,
SPEECH2TEXT: TextLabelSerializer,
IMAGE_CLASSIFICATION: CategorySerializer,
}
try:
return mapping[task]
except KeyError:
raise ValueError(f'{task} is not implemented.')

0
backend/auto_labeling/tests/__init__.py

backend/api/tests/api/test_auto_labeling.py → backend/auto_labeling/tests/test_views.py

@ -6,9 +6,9 @@ from auto_labeling_pipeline.models import RequestModelFactory
from rest_framework import status
from rest_framework.reverse import reverse
from ...models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION
from .utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image,
prepare_project)
from api.models import DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION
from api.tests.api.utils import (CRUDMixin, make_auto_labeling_config, make_doc, make_image,
prepare_project)
data_dir = pathlib.Path(__file__).parent / 'data'
@ -24,20 +24,20 @@ class TestConfigParameter(CRUDMixin):
}
self.url = reverse(viewname='auto_labeling_parameter_testing', args=[self.project.item.id])
@patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={})
@patch('auto_labeling.views.AutoLabelingConfigParameterTest.send_request', return_value={})
def test_called_with_proper_model(self, mock):
self.assert_create(self.project.users[0], status.HTTP_200_OK)
_, kwargs = mock.call_args
expected = RequestModelFactory.create(self.data['model_name'], self.data['model_attrs'])
self.assertEqual(kwargs['model'], expected)
@patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={})
@patch('auto_labeling.views.AutoLabelingConfigParameterTest.send_request', return_value={})
def test_called_with_text(self, mock):
self.assert_create(self.project.users[0], status.HTTP_200_OK)
_, kwargs = mock.call_args
self.assertEqual(kwargs['example'], self.data['text'])
@patch('api.views.auto_labeling.AutoLabelingConfigParameterTest.send_request', return_value={})
@patch('auto_labeling.views.AutoLabelingConfigParameterTest.send_request', return_value={})
def test_called_with_image(self, mock):
self.data['text'] = str(data_dir / 'images/1500x500.jpeg')
self.assert_create(self.project.users[0], status.HTTP_200_OK)
@ -118,7 +118,7 @@ class TestAutoLabelingText(CRUDMixin):
self.example = make_doc(self.project.item)
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id])
@patch('api.views.auto_labeling.execute_pipeline', return_value=[])
@patch('auto_labeling.views.execute_pipeline', return_value=[])
def test_text_task(self, mock):
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
_, kwargs = mock.call_args
@ -134,7 +134,7 @@ class TestAutoLabelingImage(CRUDMixin):
self.example = make_image(self.project.item, str(filepath))
self.url = reverse(viewname='auto_labeling_annotation', args=[self.project.item.id, self.example.id])
@patch('api.views.auto_labeling.execute_pipeline', return_value=[])
@patch('auto_labeling.views.execute_pipeline', return_value=[])
def test_text_task(self, mock):
self.assert_create(self.project.users[0], status.HTTP_201_CREATED)
_, kwargs = mock.call_args

53
backend/auto_labeling/urls.py

@ -0,0 +1,53 @@
from django.urls import path
from .views import (AutoLabelingConfigDetail, AutoLabelingConfigTest, AutoLabelingAnnotation, AutoLabelingMappingTest,
AutoLabelingTemplateListAPI, AutoLabelingTemplateDetailAPI, AutoLabelingConfigList,
AutoLabelingConfigParameterTest, AutoLabelingTemplateTest)
urlpatterns = [
path(
route='auto-labeling-templates',
view=AutoLabelingTemplateListAPI.as_view(),
name='auto_labeling_templates'
),
path(
route='auto-labeling-templates/<str:option_name>',
view=AutoLabelingTemplateDetailAPI.as_view(),
name='auto_labeling_template'
),
path(
route='auto-labeling-configs',
view=AutoLabelingConfigList.as_view(),
name='auto_labeling_configs'
),
path(
route='auto-labeling-configs/<int:config_id>',
view=AutoLabelingConfigDetail.as_view(),
name='auto_labeling_config'
),
path(
route='auto-labeling-config-testing',
view=AutoLabelingConfigTest.as_view(),
name='auto_labeling_config_test'
),
path(
route='examples/<int:example_id>/auto-labeling',
view=AutoLabelingAnnotation.as_view(),
name='auto_labeling_annotation'
),
path(
route='auto-labeling-parameter-testing',
view=AutoLabelingConfigParameterTest.as_view(),
name='auto_labeling_parameter_testing'
),
path(
route='auto-labeling-template-testing',
view=AutoLabelingTemplateTest.as_view(),
name='auto_labeling_template_test'
),
path(
route='auto-labeling-mapping-testing',
view=AutoLabelingMappingTest.as_view(),
name='auto_labeling_mapping_test'
)
]

backend/api/views/auto_labeling.py → backend/auto_labeling/views.py

@ -16,14 +16,12 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from api.models import AutoLabelingConfig, Example, Project
from members.permissions import IsInProjectOrAdmin, IsProjectAdmin
from ..exceptions import (AutoLabelingException, AutoLabelingPermissionDenied,
AWSTokenError, SampleDataException,
TemplateMappingError, URLConnectionError)
from ..models import AutoLabelingConfig, Example, Project
from ..serializers import (AutoLabelingConfigSerializer,
get_annotation_serializer)
from .exceptions import (AutoLabelingException, AutoLabelingPermissionDenied,
AWSTokenError, SampleDataException,
TemplateMappingError, URLConnectionError)
from .serializers import (AutoLabelingConfigSerializer, get_annotation_serializer)
class AutoLabelingTemplateListAPI(APIView):

2
backend/members/tests.py

@ -3,7 +3,7 @@ from rest_framework import status
from rest_framework.reverse import reverse
from roles.models import Role
from .models import Member
from members.models import Member
from api.tests.api.utils import (CRUDMixin, prepare_project, make_user)

Loading…
Cancel
Save