You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

162 lines
6.0 KiB

3 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
  1. import json
  2. import botocore.exceptions
  3. import requests
  4. from auto_labeling_pipeline.mappings import MappingTemplate
  5. from auto_labeling_pipeline.menu import Options
  6. from auto_labeling_pipeline.models import RequestModelFactory
  7. from auto_labeling_pipeline.postprocessing import PostProcessor
  8. from django.shortcuts import get_object_or_404
  9. from django_drf_filepond.models import TemporaryUpload
  10. from rest_framework import generics, status
  11. from rest_framework.exceptions import ValidationError
  12. from rest_framework.permissions import IsAuthenticated
  13. from rest_framework.request import Request
  14. from rest_framework.response import Response
  15. from rest_framework.views import APIView
  16. from .exceptions import (
  17. AWSTokenError,
  18. ResponseJSONDecodeError,
  19. SampleDataException,
  20. TemplateMappingError,
  21. URLConnectionError,
  22. )
  23. from .models import AutoLabelingConfig
  24. from .pipeline.execution import execute_pipeline, get_label_collection
  25. from .serializers import AutoLabelingConfigSerializer
  26. from projects.models import Project
  27. from projects.permissions import IsProjectAdmin, IsProjectMember
  28. class TemplateListAPI(APIView):
  29. permission_classes = [IsAuthenticated & IsProjectAdmin]
  30. def get(self, request: Request, *args, **kwargs):
  31. task_name = request.query_params.get("task_name")
  32. options = Options.filter_by_task(task_name=task_name)
  33. option_names = [o.name for o in options]
  34. return Response(option_names, status=status.HTTP_200_OK)
  35. class TemplateDetailAPI(APIView):
  36. permission_classes = [IsAuthenticated & IsProjectAdmin]
  37. def get(self, request, *args, **kwargs):
  38. option = Options.find(option_name=self.kwargs["option_name"])
  39. return Response(option.to_dict(), status=status.HTTP_200_OK)
  40. class ConfigList(generics.ListCreateAPIView):
  41. serializer_class = AutoLabelingConfigSerializer
  42. pagination_class = None
  43. permission_classes = [IsAuthenticated & IsProjectAdmin]
  44. def get_queryset(self):
  45. return AutoLabelingConfig.objects.filter(project=self.kwargs["project_id"])
  46. def perform_create(self, serializer):
  47. serializer.save(project_id=self.kwargs["project_id"])
  48. class ConfigDetail(generics.RetrieveUpdateDestroyAPIView):
  49. queryset = AutoLabelingConfig.objects.all()
  50. serializer_class = AutoLabelingConfigSerializer
  51. lookup_url_kwarg = "config_id"
  52. permission_classes = [IsAuthenticated & IsProjectAdmin]
  53. class RestAPIRequestTesting(APIView):
  54. permission_classes = [IsAuthenticated & IsProjectAdmin]
  55. @property
  56. def project(self):
  57. return get_object_or_404(Project, pk=self.kwargs["project_id"])
  58. def create_model(self):
  59. model_name = self.request.data["model_name"]
  60. model_attrs = self.request.data["model_attrs"]
  61. try:
  62. model = RequestModelFactory.create(model_name, model_attrs)
  63. return model
  64. except Exception:
  65. model = RequestModelFactory.find(model_name)
  66. schema = model.schema()
  67. required_fields = ", ".join(schema["required"]) if "required" in schema else ""
  68. raise ValidationError(
  69. "The attributes does not match the model."
  70. "You need to correctly specify the required fields: {}".format(required_fields)
  71. )
  72. def send_request(self, model, example):
  73. try:
  74. return model.send(example)
  75. except requests.exceptions.ConnectionError:
  76. raise URLConnectionError
  77. except botocore.exceptions.ClientError:
  78. raise AWSTokenError()
  79. except json.decoder.JSONDecodeError:
  80. raise ResponseJSONDecodeError()
  81. except Exception as e:
  82. raise e
  83. def prepare_example(self):
  84. text = self.request.data["text"]
  85. if self.project.is_text_project:
  86. return text
  87. else:
  88. tu = TemporaryUpload.objects.get(upload_id=text)
  89. return tu.get_file_path()
  90. def post(self, *args, **kwargs):
  91. model = self.create_model()
  92. example = self.prepare_example()
  93. response = self.send_request(model=model, example=example)
  94. return Response(response, status=status.HTTP_200_OK)
  95. class LabelExtractorTesting(APIView):
  96. permission_classes = [IsAuthenticated & IsProjectAdmin]
  97. def post(self, *args, **kwargs):
  98. response = self.request.data["response"]
  99. template = self.request.data["template"]
  100. task_type = self.request.data["task_type"]
  101. label_collection = get_label_collection(task_type)
  102. template = MappingTemplate(label_collection=label_collection, template=template)
  103. try:
  104. labels = template.render(response)
  105. except json.decoder.JSONDecodeError:
  106. raise TemplateMappingError()
  107. if not labels.dict():
  108. raise SampleDataException()
  109. return Response(labels.dict(), status=status.HTTP_200_OK)
  110. class LabelMapperTesting(APIView):
  111. permission_classes = [IsAuthenticated & IsProjectAdmin]
  112. def post(self, *args, **kwargs):
  113. response = self.request.data["response"]
  114. task_type = self.request.data["task_type"]
  115. label_mapping = self.request.data["label_mapping"]
  116. label_collection = get_label_collection(task_type)
  117. labels = label_collection(response)
  118. post_processor = PostProcessor(label_mapping)
  119. labels = post_processor.transform(labels)
  120. return Response(labels.dict(), status=status.HTTP_200_OK)
  121. class AutomatedLabeling(generics.CreateAPIView):
  122. permission_classes = [IsAuthenticated & IsProjectMember]
  123. swagger_schema = None
  124. def create(self, request, *args, **kwargs):
  125. project = get_object_or_404(Project, pk=self.kwargs["project_id"])
  126. example = project.examples.get(pk=self.request.query_params["example"])
  127. configs = AutoLabelingConfig.objects.filter(project=project)
  128. # Todo: make async calls or celery tasks to reduce waiting time.
  129. for config in configs:
  130. labels = execute_pipeline(example.data, config=config)
  131. labels.save(project, example, self.request.user)
  132. return Response({"ok": True}, status=status.HTTP_201_CREATED)