You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

159 lines
5.9 KiB

3 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
  1. import json
  2. import botocore.exceptions
  3. import requests
  4. from auto_labeling_pipeline.mappings import MappingTemplate
  5. from auto_labeling_pipeline.menu import Options
  6. from auto_labeling_pipeline.models import RequestModelFactory
  7. from auto_labeling_pipeline.postprocessing import PostProcessor
  8. from django.shortcuts import get_object_or_404
  9. from django_drf_filepond.models import TemporaryUpload
  10. from rest_framework import generics, status
  11. from rest_framework.exceptions import ValidationError
  12. from rest_framework.permissions import IsAuthenticated
  13. from rest_framework.request import Request
  14. from rest_framework.response import Response
  15. from rest_framework.views import APIView
  16. from .exceptions import (
  17. AWSTokenError,
  18. SampleDataException,
  19. TemplateMappingError,
  20. URLConnectionError,
  21. )
  22. from .models import AutoLabelingConfig
  23. from .pipeline.execution import execute_pipeline, get_label_collection
  24. from .serializers import AutoLabelingConfigSerializer
  25. from projects.models import Project
  26. from projects.permissions import IsProjectAdmin, IsProjectMember
  27. class TemplateListAPI(APIView):
  28. permission_classes = [IsAuthenticated & IsProjectAdmin]
  29. def get(self, request: Request, *args, **kwargs):
  30. task_name = request.query_params.get("task_name")
  31. options = Options.filter_by_task(task_name=task_name)
  32. option_names = [o.name for o in options]
  33. return Response(option_names, status=status.HTTP_200_OK)
  34. class TemplateDetailAPI(APIView):
  35. permission_classes = [IsAuthenticated & IsProjectAdmin]
  36. def get(self, request, *args, **kwargs):
  37. option = Options.find(option_name=self.kwargs["option_name"])
  38. return Response(option.to_dict(), status=status.HTTP_200_OK)
  39. class ConfigList(generics.ListCreateAPIView):
  40. serializer_class = AutoLabelingConfigSerializer
  41. pagination_class = None
  42. permission_classes = [IsAuthenticated & IsProjectAdmin]
  43. def get_queryset(self):
  44. return AutoLabelingConfig.objects.filter(project=self.kwargs["project_id"])
  45. def perform_create(self, serializer):
  46. serializer.save(project_id=self.kwargs["project_id"])
  47. class ConfigDetail(generics.RetrieveUpdateDestroyAPIView):
  48. queryset = AutoLabelingConfig.objects.all()
  49. serializer_class = AutoLabelingConfigSerializer
  50. lookup_url_kwarg = "config_id"
  51. permission_classes = [IsAuthenticated & IsProjectAdmin]
  52. class RestAPIRequestTesting(APIView):
  53. permission_classes = [IsAuthenticated & IsProjectAdmin]
  54. @property
  55. def project(self):
  56. return get_object_or_404(Project, pk=self.kwargs["project_id"])
  57. def create_model(self):
  58. model_name = self.request.data["model_name"]
  59. model_attrs = self.request.data["model_attrs"]
  60. try:
  61. model = RequestModelFactory.create(model_name, model_attrs)
  62. return model
  63. except Exception:
  64. model = RequestModelFactory.find(model_name)
  65. schema = model.schema()
  66. required_fields = ", ".join(schema["required"]) if "required" in schema else ""
  67. raise ValidationError(
  68. "The attributes does not match the model."
  69. "You need to correctly specify the required fields: {}".format(required_fields)
  70. )
  71. def send_request(self, model, example):
  72. try:
  73. return model.send(example)
  74. except requests.exceptions.ConnectionError:
  75. raise URLConnectionError
  76. except botocore.exceptions.ClientError:
  77. raise AWSTokenError()
  78. except Exception as e:
  79. raise e
  80. def prepare_example(self):
  81. text = self.request.data["text"]
  82. if self.project.is_text_project:
  83. return text
  84. else:
  85. tu = TemporaryUpload.objects.get(upload_id=text)
  86. return tu.get_file_path()
  87. def post(self, *args, **kwargs):
  88. model = self.create_model()
  89. example = self.prepare_example()
  90. response = self.send_request(model=model, example=example)
  91. return Response(response, status=status.HTTP_200_OK)
  92. class LabelExtractorTesting(APIView):
  93. permission_classes = [IsAuthenticated & IsProjectAdmin]
  94. def post(self, *args, **kwargs):
  95. response = self.request.data["response"]
  96. template = self.request.data["template"]
  97. task_type = self.request.data["task_type"]
  98. label_collection = get_label_collection(task_type)
  99. template = MappingTemplate(label_collection=label_collection, template=template)
  100. try:
  101. labels = template.render(response)
  102. except json.decoder.JSONDecodeError:
  103. raise TemplateMappingError()
  104. if not labels.dict():
  105. raise SampleDataException()
  106. return Response(labels.dict(), status=status.HTTP_200_OK)
  107. class LabelMapperTesting(APIView):
  108. permission_classes = [IsAuthenticated & IsProjectAdmin]
  109. def post(self, *args, **kwargs):
  110. response = self.request.data["response"]
  111. task_type = self.request.data["task_type"]
  112. label_mapping = self.request.data["label_mapping"]
  113. label_collection = get_label_collection(task_type)
  114. labels = label_collection(response)
  115. post_processor = PostProcessor(label_mapping)
  116. labels = post_processor.transform(labels)
  117. return Response(labels.dict(), status=status.HTTP_200_OK)
  118. class AutomatedLabeling(generics.CreateAPIView):
  119. permission_classes = [IsAuthenticated & IsProjectMember]
  120. swagger_schema = None
  121. def create(self, request, *args, **kwargs):
  122. project = get_object_or_404(Project, pk=self.kwargs["project_id"])
  123. example = project.examples.get(pk=self.request.query_params["example"])
  124. configs = AutoLabelingConfig.objects.filter(project=project)
  125. # Todo: make async calls or celery tasks to reduce waiting time.
  126. for config in configs:
  127. labels = execute_pipeline(example.data, config=config)
  128. labels.save(project, example, self.request.user)
  129. return Response({"ok": True}, status=status.HTTP_201_CREATED)