You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

157 lines
5.9 KiB

3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. import json
  2. import botocore.exceptions
  3. import requests
  4. from auto_labeling_pipeline.mappings import MappingTemplate
  5. from auto_labeling_pipeline.menu import Options
  6. from auto_labeling_pipeline.models import RequestModelFactory
  7. from auto_labeling_pipeline.postprocessing import PostProcessor
  8. from django.shortcuts import get_object_or_404
  9. from django_drf_filepond.models import TemporaryUpload
  10. from rest_framework import generics, status
  11. from rest_framework.exceptions import ValidationError
  12. from rest_framework.permissions import IsAuthenticated
  13. from rest_framework.request import Request
  14. from rest_framework.response import Response
  15. from rest_framework.views import APIView
  16. from api.models import Project
  17. from members.permissions import IsProjectMember, IsProjectAdmin
  18. from .pipeline.execution import execute_pipeline, get_label_collection
  19. from .exceptions import AWSTokenError, SampleDataException, TemplateMappingError, URLConnectionError
  20. from .models import AutoLabelingConfig
  21. from .serializers import AutoLabelingConfigSerializer
  22. class TemplateListAPI(APIView):
  23. permission_classes = [IsAuthenticated & IsProjectAdmin]
  24. def get(self, request: Request, *args, **kwargs):
  25. task_name = request.query_params.get('task_name')
  26. options = Options.filter_by_task(task_name=task_name)
  27. option_names = [o.name for o in options]
  28. return Response(option_names, status=status.HTTP_200_OK)
  29. class TemplateDetailAPI(APIView):
  30. permission_classes = [IsAuthenticated & IsProjectAdmin]
  31. def get(self, request, *args, **kwargs):
  32. option = Options.find(option_name=self.kwargs['option_name'])
  33. return Response(option.to_dict(), status=status.HTTP_200_OK)
  34. class ConfigList(generics.ListCreateAPIView):
  35. serializer_class = AutoLabelingConfigSerializer
  36. pagination_class = None
  37. permission_classes = [IsAuthenticated & IsProjectAdmin]
  38. def get_queryset(self):
  39. return AutoLabelingConfig.objects.filter(project=self.kwargs['project_id'])
  40. def perform_create(self, serializer):
  41. serializer.save(project_id=self.kwargs['project_id'])
  42. class ConfigDetail(generics.RetrieveUpdateDestroyAPIView):
  43. queryset = AutoLabelingConfig.objects.all()
  44. serializer_class = AutoLabelingConfigSerializer
  45. lookup_url_kwarg = 'config_id'
  46. permission_classes = [IsAuthenticated & IsProjectAdmin]
  47. class RestAPIRequestTesting(APIView):
  48. permission_classes = [IsAuthenticated & IsProjectAdmin]
  49. @property
  50. def project(self):
  51. return get_object_or_404(Project, pk=self.kwargs['project_id'])
  52. def create_model(self):
  53. model_name = self.request.data['model_name']
  54. model_attrs = self.request.data['model_attrs']
  55. try:
  56. model = RequestModelFactory.create(model_name, model_attrs)
  57. return model
  58. except Exception:
  59. model = RequestModelFactory.find(model_name)
  60. schema = model.schema()
  61. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  62. raise ValidationError(
  63. 'The attributes does not match the model.'
  64. 'You need to correctly specify the required fields: {}'.format(required_fields)
  65. )
  66. def send_request(self, model, example):
  67. try:
  68. return model.send(example)
  69. except requests.exceptions.ConnectionError:
  70. raise URLConnectionError
  71. except botocore.exceptions.ClientError:
  72. raise AWSTokenError()
  73. except Exception as e:
  74. raise e
  75. def prepare_example(self):
  76. text = self.request.data['text']
  77. if self.project.is_text_project:
  78. return text
  79. else:
  80. tu = TemporaryUpload.objects.get(upload_id=text)
  81. return tu.get_file_path()
  82. def post(self, *args, **kwargs):
  83. model = self.create_model()
  84. example = self.prepare_example()
  85. response = self.send_request(model=model, example=example)
  86. return Response(response, status=status.HTTP_200_OK)
  87. class LabelExtractorTesting(APIView):
  88. permission_classes = [IsAuthenticated & IsProjectAdmin]
  89. def post(self, *args, **kwargs):
  90. response = self.request.data['response']
  91. template = self.request.data['template']
  92. task_type = self.request.data['task_type']
  93. label_collection = get_label_collection(task_type)
  94. template = MappingTemplate(
  95. label_collection=label_collection,
  96. template=template
  97. )
  98. try:
  99. labels = template.render(response)
  100. except json.decoder.JSONDecodeError:
  101. raise TemplateMappingError()
  102. if not labels.dict():
  103. raise SampleDataException()
  104. return Response(labels.dict(), status=status.HTTP_200_OK)
  105. class LabelMapperTesting(APIView):
  106. permission_classes = [IsAuthenticated & IsProjectAdmin]
  107. def post(self, *args, **kwargs):
  108. response = self.request.data['response']
  109. task_type = self.request.data['task_type']
  110. label_mapping = self.request.data['label_mapping']
  111. label_collection = get_label_collection(task_type)
  112. labels = label_collection(response)
  113. post_processor = PostProcessor(label_mapping)
  114. labels = post_processor.transform(labels)
  115. return Response(labels.dict(), status=status.HTTP_200_OK)
  116. class AutomatedLabeling(generics.CreateAPIView):
  117. permission_classes = [IsAuthenticated & IsProjectMember]
  118. swagger_schema = None
  119. def create(self, request, *args, **kwargs):
  120. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  121. example = project.examples.get(pk=self.request.query_params['example'])
  122. configs = AutoLabelingConfig.objects.filter(project=project)
  123. # Todo: make async calls or celery tasks to reduce waiting time.
  124. for config in configs:
  125. labels = execute_pipeline(example.data, config=config)
  126. labels.save(project, example, self.request.user)
  127. return Response({'ok': True}, status=status.HTTP_201_CREATED)