You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

154 lines
5.9 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
  1. import json
  2. import botocore.exceptions
  3. import requests
  4. from auto_labeling_pipeline.mappings import MappingTemplate
  5. from auto_labeling_pipeline.menu import Options
  6. from auto_labeling_pipeline.models import RequestModelFactory
  7. from auto_labeling_pipeline.postprocessing import PostProcessor
  8. from django.shortcuts import get_object_or_404
  9. from django_drf_filepond.models import TemporaryUpload
  10. from rest_framework import generics, status
  11. from rest_framework.exceptions import ValidationError
  12. from rest_framework.permissions import IsAuthenticated
  13. from rest_framework.request import Request
  14. from rest_framework.response import Response
  15. from rest_framework.views import APIView
  16. from projects.models import Project
  17. from projects.permissions import IsProjectMember, IsProjectAdmin
  18. from .pipeline.execution import execute_pipeline, get_label_collection
  19. from .exceptions import AWSTokenError, SampleDataException, TemplateMappingError, URLConnectionError
  20. from .models import AutoLabelingConfig
  21. from .serializers import AutoLabelingConfigSerializer
  22. class TemplateListAPI(APIView):
  23. permission_classes = [IsAuthenticated & IsProjectAdmin]
  24. def get(self, request: Request, *args, **kwargs):
  25. task_name = request.query_params.get("task_name")
  26. options = Options.filter_by_task(task_name=task_name)
  27. option_names = [o.name for o in options]
  28. return Response(option_names, status=status.HTTP_200_OK)
  29. class TemplateDetailAPI(APIView):
  30. permission_classes = [IsAuthenticated & IsProjectAdmin]
  31. def get(self, request, *args, **kwargs):
  32. option = Options.find(option_name=self.kwargs["option_name"])
  33. return Response(option.to_dict(), status=status.HTTP_200_OK)
  34. class ConfigList(generics.ListCreateAPIView):
  35. serializer_class = AutoLabelingConfigSerializer
  36. pagination_class = None
  37. permission_classes = [IsAuthenticated & IsProjectAdmin]
  38. def get_queryset(self):
  39. return AutoLabelingConfig.objects.filter(project=self.kwargs["project_id"])
  40. def perform_create(self, serializer):
  41. serializer.save(project_id=self.kwargs["project_id"])
  42. class ConfigDetail(generics.RetrieveUpdateDestroyAPIView):
  43. queryset = AutoLabelingConfig.objects.all()
  44. serializer_class = AutoLabelingConfigSerializer
  45. lookup_url_kwarg = "config_id"
  46. permission_classes = [IsAuthenticated & IsProjectAdmin]
  47. class RestAPIRequestTesting(APIView):
  48. permission_classes = [IsAuthenticated & IsProjectAdmin]
  49. @property
  50. def project(self):
  51. return get_object_or_404(Project, pk=self.kwargs["project_id"])
  52. def create_model(self):
  53. model_name = self.request.data["model_name"]
  54. model_attrs = self.request.data["model_attrs"]
  55. try:
  56. model = RequestModelFactory.create(model_name, model_attrs)
  57. return model
  58. except Exception:
  59. model = RequestModelFactory.find(model_name)
  60. schema = model.schema()
  61. required_fields = ", ".join(schema["required"]) if "required" in schema else ""
  62. raise ValidationError(
  63. "The attributes does not match the model."
  64. "You need to correctly specify the required fields: {}".format(required_fields)
  65. )
  66. def send_request(self, model, example):
  67. try:
  68. return model.send(example)
  69. except requests.exceptions.ConnectionError:
  70. raise URLConnectionError
  71. except botocore.exceptions.ClientError:
  72. raise AWSTokenError()
  73. except Exception as e:
  74. raise e
  75. def prepare_example(self):
  76. text = self.request.data["text"]
  77. if self.project.is_text_project:
  78. return text
  79. else:
  80. tu = TemporaryUpload.objects.get(upload_id=text)
  81. return tu.get_file_path()
  82. def post(self, *args, **kwargs):
  83. model = self.create_model()
  84. example = self.prepare_example()
  85. response = self.send_request(model=model, example=example)
  86. return Response(response, status=status.HTTP_200_OK)
  87. class LabelExtractorTesting(APIView):
  88. permission_classes = [IsAuthenticated & IsProjectAdmin]
  89. def post(self, *args, **kwargs):
  90. response = self.request.data["response"]
  91. template = self.request.data["template"]
  92. task_type = self.request.data["task_type"]
  93. label_collection = get_label_collection(task_type)
  94. template = MappingTemplate(label_collection=label_collection, template=template)
  95. try:
  96. labels = template.render(response)
  97. except json.decoder.JSONDecodeError:
  98. raise TemplateMappingError()
  99. if not labels.dict():
  100. raise SampleDataException()
  101. return Response(labels.dict(), status=status.HTTP_200_OK)
  102. class LabelMapperTesting(APIView):
  103. permission_classes = [IsAuthenticated & IsProjectAdmin]
  104. def post(self, *args, **kwargs):
  105. response = self.request.data["response"]
  106. task_type = self.request.data["task_type"]
  107. label_mapping = self.request.data["label_mapping"]
  108. label_collection = get_label_collection(task_type)
  109. labels = label_collection(response)
  110. post_processor = PostProcessor(label_mapping)
  111. labels = post_processor.transform(labels)
  112. return Response(labels.dict(), status=status.HTTP_200_OK)
  113. class AutomatedLabeling(generics.CreateAPIView):
  114. permission_classes = [IsAuthenticated & IsProjectMember]
  115. swagger_schema = None
  116. def create(self, request, *args, **kwargs):
  117. project = get_object_or_404(Project, pk=self.kwargs["project_id"])
  118. example = project.examples.get(pk=self.request.query_params["example"])
  119. configs = AutoLabelingConfig.objects.filter(project=project)
  120. # Todo: make async calls or celery tasks to reduce waiting time.
  121. for config in configs:
  122. labels = execute_pipeline(example.data, config=config)
  123. labels.save(project, example, self.request.user)
  124. return Response({"ok": True}, status=status.HTTP_201_CREATED)