You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

274 lines
10 KiB

import json
import botocore.exceptions
import requests
from auto_labeling_pipeline.mappings import MappingTemplate
from auto_labeling_pipeline.menu import Options
from auto_labeling_pipeline.models import RequestModelFactory
from auto_labeling_pipeline.pipeline import pipeline
from auto_labeling_pipeline.postprocessing import PostProcessor
from auto_labeling_pipeline.task import TaskFactory
from django.shortcuts import get_object_or_404
from django_drf_filepond.models import TemporaryUpload
from rest_framework import generics, status
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from api.models import Example, Project
from members.permissions import IsInProjectOrAdmin, IsProjectAdmin
from .exceptions import (AutoLabelingException, AutoLabelingPermissionDenied,
AWSTokenError, SampleDataException,
TemplateMappingError, URLConnectionError)
from .models import AutoLabelingConfig
from .serializers import (AutoLabelingConfigSerializer, get_annotation_serializer)
class AutoLabelingTemplateListAPI(APIView):
    """List the names of the auto-labeling template options for a project's task type."""

    permission_classes = [IsAuthenticated & IsProjectAdmin]

    def get(self, request, *args, **kwargs):
        """Return the names of options filtered by the project's ``project_type``.

        Responds with 404 when ``project_id`` does not name an existing project.
        """
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        names = [
            option.name
            for option in Options.filter_by_task(task_name=project.project_type)
        ]
        return Response(names, status=status.HTTP_200_OK)
class AutoLabelingTemplateDetailAPI(APIView):
    """Return the full definition of a single auto-labeling template option."""

    permission_classes = [IsAuthenticated & IsProjectAdmin]

    def get(self, request, *args, **kwargs):
        """Look up the option named by the ``option_name`` URL kwarg and return its ``to_dict()``."""
        found = Options.find(option_name=self.kwargs['option_name'])
        return Response(found.to_dict(), status=status.HTTP_200_OK)
class AutoLabelingConfigList(generics.ListCreateAPIView):
    """List a project's auto-labeling configs, or create a new one for it."""

    serializer_class = AutoLabelingConfigSerializer
    pagination_class = None
    permission_classes = [IsAuthenticated & IsProjectAdmin]

    def _get_project(self):
        """Fetch the project named by the ``project_id`` URL kwarg, or respond 404."""
        return get_object_or_404(Project, pk=self.kwargs['project_id'])

    def get_queryset(self):
        """Return the configs belonging to the URL's project."""
        # ``auto_labeling_config`` is a related manager; ``.all()`` hands DRF a
        # real QuerySet, which is what ``get_queryset`` is expected to return.
        return self._get_project().auto_labeling_config.all()

    def perform_create(self, serializer):
        """Save the new config attached to the project named in the URL."""
        serializer.save(project=self._get_project())
class AutoLabelingConfigDetail(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update or delete one auto-labeling config of a project."""

    queryset = AutoLabelingConfig.objects.all()
    serializer_class = AutoLabelingConfigSerializer
    lookup_url_kwarg = 'config_id'
    permission_classes = [IsAuthenticated & IsProjectAdmin]

    def get_queryset(self):
        """Restrict the ``config_id`` lookup to configs of the URL's project."""
        # NOTE(review): the permission class presumably authorizes against the
        # ``project_id`` URL kwarg (as the sibling views here assume). Without this
        # filter, an admin of one project could read, modify or delete another
        # project's config simply by supplying its ``config_id``.
        return super().get_queryset().filter(project_id=self.kwargs['project_id'])
class AutoLabelingConfigTest(APIView):
    """Validate a submitted auto-labeling config and dry-run it on sample input."""

    permission_classes = [IsAuthenticated & IsProjectAdmin]

    def post(self, *args, **kwargs):
        """Validate ``config`` from the request body and run it on ``input``.

        Returns:
            ``{'valid': True, 'labels': <pipeline output>}`` with status 200.

        Raises:
            ValidationError: ``config`` fails serializer validation.
            URLConnectionError: a ``requests`` ConnectionError occurred.
            AWSTokenError: a botocore ``ClientError`` occurred.
            SampleDataException: the pipeline produced no labels.
        """
        try:
            serializer = self.pass_config_validation()
            output = self.pass_pipeline_call(serializer)
        except requests.exceptions.ConnectionError as e:
            raise URLConnectionError() from e
        except botocore.exceptions.ClientError as e:
            raise AWSTokenError() from e
        # Any other exception (including ValidationError) propagates unchanged.
        if not output:
            raise SampleDataException()
        return Response(
            data={'valid': True, 'labels': output},
            status=status.HTTP_200_OK
        )

    def pass_config_validation(self):
        """Return an ``AutoLabelingConfigSerializer`` validated from ``request.data['config']``.

        Raises:
            ValidationError: the config is invalid (``raise_exception=True``).
        """
        config = self.request.data['config']
        serializer = AutoLabelingConfigSerializer(data=config)
        serializer.is_valid(raise_exception=True)
        return serializer

    def pass_pipeline_call(self, serializer):
        """Run ``execute_pipeline`` on ``request.data['input']`` with the validated config.

        Responds 404 when ``project_id`` does not name an existing project.
        """
        test_input = self.request.data['input']
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        return execute_pipeline(
            text=test_input,
            project_type=project.project_type,
            model_name=serializer.data.get('model_name'),
            model_attrs=serializer.data.get('model_attrs'),
            template=serializer.data.get('template'),
            label_mapping=serializer.data.get('label_mapping')
        )
class AutoLabelingConfigParameterTest(APIView):
    """Send one example to a request model built from user-supplied attributes."""

    permission_classes = [IsAuthenticated & IsProjectAdmin]

    @property
    def project(self):
        """The project named by the ``project_id`` URL kwarg (404 if missing).

        Not cached: every access performs a fresh lookup.
        """
        return get_object_or_404(Project, pk=self.kwargs['project_id'])

    def create_model(self):
        """Build a request model from ``model_name`` and ``model_attrs`` in the request body.

        Raises:
            ValidationError: the factory rejected the attributes; the message lists
                the model schema's ``required`` fields.
        """
        model_name = self.request.data['model_name']
        model_attrs = self.request.data['model_attrs']
        try:
            return RequestModelFactory.create(model_name, model_attrs)
        except Exception as e:
            # Broad on purpose: the factory's failure modes are not visible here,
            # and any failure is reported to the user as a schema mismatch.
            schema = RequestModelFactory.find(model_name).schema()
            required_fields = ', '.join(schema.get('required', []))
            raise ValidationError(
                'The attributes does not match the model. '
                'You need to correctly specify the required fields: {}'.format(required_fields)
            ) from e

    def send_request(self, model, example):
        """Send ``example`` through ``model.send`` and return its response.

        Raises:
            URLConnectionError: a ``requests`` ConnectionError occurred.
            AWSTokenError: a botocore ``ClientError`` occurred.
        """
        try:
            return model.send(example)
        except requests.exceptions.ConnectionError as e:
            raise URLConnectionError() from e
        except botocore.exceptions.ClientError as e:
            raise AWSTokenError() from e

    def prepare_example(self):
        """Return the model input: the raw ``text`` or, for non-text projects, a file path.

        For non-text projects ``text`` is treated as a filepond upload id; an
        unknown id responds 404 instead of raising an unhandled ``DoesNotExist``.
        """
        text = self.request.data['text']
        if self.project.is_text_project:
            return text
        upload = get_object_or_404(TemporaryUpload, upload_id=text)
        return upload.get_file_path()

    def post(self, *args, **kwargs):
        """Build the model, prepare the example, and return the model's raw response."""
        model = self.create_model()
        example = self.prepare_example()
        response = self.send_request(model=model, example=example)
        return Response(response, status=status.HTTP_200_OK)
class AutoLabelingTemplateTest(APIView):
    """Render a mapping template against a sample model response."""

    permission_classes = [IsAuthenticated & IsProjectAdmin]

    def post(self, *args, **kwargs):
        """Apply the submitted ``template`` to the submitted ``response``.

        Returns the rendered labels as a dict with status 200.

        Raises:
            TemplateMappingError: rendering raised a ``JSONDecodeError``.
            SampleDataException: the rendered labels are empty.
        """
        body = self.request.data
        sample_response = body['response']
        template_source = body['template']
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        collection = TaskFactory.create(project.project_type).label_collection
        mapping = MappingTemplate(label_collection=collection, template=template_source)
        try:
            rendered = mapping.render(sample_response)
        except json.decoder.JSONDecodeError:
            raise TemplateMappingError()
        payload = rendered.dict()
        if not payload:
            raise SampleDataException()
        return Response(payload, status=status.HTTP_200_OK)
class AutoLabelingMappingTest(APIView):
    """Apply a label mapping to sample labels and return the result."""

    permission_classes = [IsAuthenticated & IsProjectAdmin]

    def post(self, *args, **kwargs):
        """Wrap ``response`` in the task's label collection, then remap it with ``label_mapping``.

        Responds 404 when ``project_id`` does not name an existing project.
        """
        body = self.request.data
        sample = body['response']
        mapping = body['label_mapping']
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        collection_cls = TaskFactory.create(project.project_type).label_collection
        collected = collection_cls(sample)
        processor = PostProcessor(mapping)
        remapped = processor.transform(collected)
        return Response(remapped.dict(), status=status.HTTP_200_OK)
class AutoLabelingAnnotation(generics.CreateAPIView):
    """Run the project's auto-labeling config on one example and save the resulting labels."""

    pagination_class = None
    permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
    swagger_schema = None

    def get_serializer_class(self):
        """Select (and store on ``self``) the annotation serializer for the project's task."""
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        self.serializer_class = get_annotation_serializer(task=project.project_type)
        return self.serializer_class

    def cannot_annotate(self):
        """Return ``example.is_labeled(...)`` for the requesting user.

        NOTE(review): presumably true when the example already carries labels
        that should block auto-labeling; confirm against ``Example.is_labeled``.
        """
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        example = get_object_or_404(Example, pk=self.kwargs['example_id'])
        return example.is_labeled(project.collaborative_annotation, self.request.user)

    def create(self, request, *args, **kwargs):
        """Extract labels via the pipeline, map them to ids, validate, and save them.

        Raises:
            AutoLabelingException: ``cannot_annotate()`` returned true.
            URLConnectionError: a ``requests`` ConnectionError occurred in the pipeline.
            AWSTokenError: a botocore ``ClientError`` occurred in the pipeline.
            ValidationError: a label is unknown to the project or fails serialization.
        """
        if self.cannot_annotate():
            raise AutoLabelingException()
        # Map transport failures the same way the config-test views do, instead
        # of letting them surface as unhandled 500s.
        try:
            labels = self.extract()
        except requests.exceptions.ConnectionError as e:
            raise URLConnectionError() from e
        except botocore.exceptions.ClientError as e:
            raise AWSTokenError() from e
        labels = self.transform(labels)
        serializer = self.get_serializer(data=labels, many=True)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

    def perform_create(self, serializer):
        """Save the annotations attributed to the requesting user."""
        serializer.save(user=self.request.user)

    def get_example(self, project):
        """Return the pipeline input: the example's text, or its filename for non-text projects."""
        example = get_object_or_404(Example, pk=self.kwargs['example_id'])
        if project.is_text_project:
            return example.text
        else:
            return str(example.filename)

    def extract(self):
        """Run the project's first auto-labeling config on the example.

        Raises:
            AutoLabelingPermissionDenied: the project has no auto-labeling config.
        """
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        example = self.get_example(project)
        config = project.auto_labeling_config.first()
        if not config:
            raise AutoLabelingPermissionDenied()
        return execute_pipeline(
            text=example,
            project_type=project.project_type,
            model_name=config.model_name,
            model_attrs=config.model_attrs,
            template=config.template,
            label_mapping=config.label_mapping
        )

    def transform(self, labels):
        """Attach the example id to each label and replace label names with label ids.

        Raises:
            ValidationError: a label name returned by the model does not exist
                among the project's labels (previously an unhandled 500).
        """
        project = get_object_or_404(Project, pk=self.kwargs['project_id'])
        for label in labels:
            label['example'] = self.kwargs['example_id']
            if 'label' in label:
                name = label.pop('label')
                try:
                    label['label'] = project.labels.get(text=name).id
                except project.labels.model.DoesNotExist as e:
                    raise ValidationError(
                        'Label "{}" returned by the model does not exist in this project.'.format(name)
                    ) from e
        return labels
def execute_pipeline(text: str,
                     project_type: str,
                     model_name: str,
                     model_attrs: dict,
                     template: str,
                     label_mapping: dict):
    """Run the auto-labeling pipeline on one input and return the labels as a dict.

    Builds, in order:
    1. the task for ``project_type``;
    2. a request model from ``model_name`` and ``model_attrs``;
    3. a mapping template over the task's label collection;
    4. a post-processor from ``label_mapping``.

    It then feeds ``text`` through them. Callers in this module pass either raw
    example text or, for non-text projects, a file path.

    Returns:
        The pipeline result's ``.dict()``.
    """
    task = TaskFactory.create(project_type)
    request_model = RequestModelFactory.create(
        model_name=model_name,
        attributes=model_attrs,
    )
    mapping_template = MappingTemplate(
        label_collection=task.label_collection,
        template=template,
    )
    post_processor = PostProcessor(label_mapping)
    result = pipeline(
        text=text,
        request_model=request_model,
        mapping_template=mapping_template,
        post_processing=post_processor,
    )
    return result.dict()