You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

278 lines
10 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. import json
  2. import botocore.exceptions
  3. import requests
  4. from auto_labeling_pipeline.mappings import MappingTemplate
  5. from auto_labeling_pipeline.menu import Options
  6. from auto_labeling_pipeline.models import RequestModelFactory
  7. from auto_labeling_pipeline.pipeline import pipeline
  8. from auto_labeling_pipeline.postprocessing import PostProcessor
  9. from auto_labeling_pipeline.task import TaskFactory
  10. from django.shortcuts import get_object_or_404
  11. from django_drf_filepond.models import TemporaryUpload
  12. from rest_framework import generics, status
  13. from rest_framework.exceptions import ValidationError
  14. from rest_framework.permissions import IsAuthenticated
  15. from rest_framework.response import Response
  16. from rest_framework.views import APIView
  17. from ..exceptions import (AutoLabelingException, AutoLabelingPermissionDenied,
  18. AWSTokenError, SampleDataException,
  19. TemplateMappingError, URLConnectionError)
  20. from ..models import AutoLabelingConfig, Example, Project
  21. from ..permissions import IsInProjectOrAdmin, IsProjectAdmin
  22. from ..serializers import (AutoLabelingConfigSerializer,
  23. get_annotation_serializer)
  24. class AutoLabelingTemplateListAPI(APIView):
  25. permission_classes = [IsAuthenticated & IsProjectAdmin]
  26. def get(self, request, *args, **kwargs):
  27. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  28. options = Options.filter_by_task(task_name=project.project_type)
  29. option_names = [o.name for o in options]
  30. return Response(option_names, status=status.HTTP_200_OK)
  31. class AutoLabelingTemplateDetailAPI(APIView):
  32. permission_classes = [IsAuthenticated & IsProjectAdmin]
  33. def get(self, request, *args, **kwargs):
  34. option_name = self.kwargs['option_name']
  35. option = Options.find(option_name=option_name)
  36. return Response(option.to_dict(), status=status.HTTP_200_OK)
  37. class AutoLabelingConfigList(generics.ListCreateAPIView):
  38. serializer_class = AutoLabelingConfigSerializer
  39. pagination_class = None
  40. permission_classes = [IsAuthenticated & IsProjectAdmin]
  41. def get_queryset(self):
  42. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  43. return project.auto_labeling_config
  44. def perform_create(self, serializer):
  45. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  46. serializer.save(project=project)
  47. class AutoLabelingConfigDetail(generics.RetrieveUpdateDestroyAPIView):
  48. queryset = AutoLabelingConfig.objects.all()
  49. serializer_class = AutoLabelingConfigSerializer
  50. lookup_url_kwarg = 'config_id'
  51. permission_classes = [IsAuthenticated & IsProjectAdmin]
  52. class AutoLabelingConfigTest(APIView):
  53. permission_classes = [IsAuthenticated & IsProjectAdmin]
  54. def post(self, *args, **kwargs):
  55. try:
  56. output = self.pass_config_validation()
  57. output = self.pass_pipeline_call(output)
  58. if not output:
  59. raise SampleDataException()
  60. return Response(
  61. data={'valid': True, 'labels': output},
  62. status=status.HTTP_200_OK
  63. )
  64. except requests.exceptions.ConnectionError:
  65. raise URLConnectionError()
  66. except botocore.exceptions.ClientError:
  67. raise AWSTokenError()
  68. except ValidationError as e:
  69. raise e
  70. except Exception as e:
  71. raise e
  72. def pass_config_validation(self):
  73. config = self.request.data['config']
  74. serializer = AutoLabelingConfigSerializer(data=config)
  75. serializer.is_valid(raise_exception=True)
  76. return serializer
  77. def pass_pipeline_call(self, serializer):
  78. test_input = self.request.data['input']
  79. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  80. return execute_pipeline(
  81. text=test_input,
  82. project_type=project.project_type,
  83. model_name=serializer.data.get('model_name'),
  84. model_attrs=serializer.data.get('model_attrs'),
  85. template=serializer.data.get('template'),
  86. label_mapping=serializer.data.get('label_mapping')
  87. )
  88. class AutoLabelingConfigParameterTest(APIView):
  89. permission_classes = [IsAuthenticated & IsProjectAdmin]
  90. @property
  91. def project(self):
  92. return get_object_or_404(Project, pk=self.kwargs['project_id'])
  93. def create_model(self):
  94. model_name = self.request.data['model_name']
  95. model_attrs = self.request.data['model_attrs']
  96. try:
  97. model = RequestModelFactory.create(model_name, model_attrs)
  98. return model
  99. except Exception:
  100. model = RequestModelFactory.find(model_name)
  101. schema = model.schema()
  102. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  103. raise ValidationError(
  104. 'The attributes does not match the model.'
  105. 'You need to correctly specify the required fields: {}'.format(required_fields)
  106. )
  107. def send_request(self, model, example):
  108. try:
  109. response = model.send(example)
  110. return response
  111. except requests.exceptions.ConnectionError:
  112. raise URLConnectionError
  113. except botocore.exceptions.ClientError:
  114. raise AWSTokenError()
  115. except Exception as e:
  116. raise e
  117. def prepare_example(self):
  118. text = self.request.data['text']
  119. if self.project.is_task_of('text'):
  120. return text
  121. else:
  122. tu = TemporaryUpload.objects.get(upload_id=text)
  123. return tu.get_file_path()
  124. def post(self, *args, **kwargs):
  125. model = self.create_model()
  126. example = self.prepare_example()
  127. response = self.send_request(model=model, example=example)
  128. return Response(response, status=status.HTTP_200_OK)
  129. class AutoLabelingTemplateTest(APIView):
  130. permission_classes = [IsAuthenticated & IsProjectAdmin]
  131. def post(self, *args, **kwargs):
  132. response = self.request.data['response']
  133. template = self.request.data['template']
  134. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  135. task = TaskFactory.create(project.project_type)
  136. template = MappingTemplate(
  137. label_collection=task.label_collection,
  138. template=template
  139. )
  140. try:
  141. labels = template.render(response)
  142. except json.decoder.JSONDecodeError:
  143. raise TemplateMappingError()
  144. if not labels.dict():
  145. raise SampleDataException()
  146. return Response(labels.dict(), status=status.HTTP_200_OK)
  147. class AutoLabelingMappingTest(APIView):
  148. permission_classes = [IsAuthenticated & IsProjectAdmin]
  149. def post(self, *args, **kwargs):
  150. response = self.request.data['response']
  151. label_mapping = self.request.data['label_mapping']
  152. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  153. task = TaskFactory.create(project.project_type)
  154. labels = task.label_collection(response)
  155. post_processor = PostProcessor(label_mapping)
  156. labels = post_processor.transform(labels)
  157. return Response(labels.dict(), status=status.HTTP_200_OK)
  158. class AutoLabelingAnnotation(generics.CreateAPIView):
  159. pagination_class = None
  160. permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
  161. swagger_schema = None
  162. def get_serializer_class(self):
  163. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  164. self.serializer_class = get_annotation_serializer(task=project.project_type)
  165. return self.serializer_class
  166. def get_queryset(self):
  167. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  168. model = project.get_annotation_class()
  169. queryset = model.objects.filter(example=self.kwargs['example_id'])
  170. if not project.collaborative_annotation:
  171. queryset = queryset.filter(user=self.request.user)
  172. return queryset
  173. def create(self, request, *args, **kwargs):
  174. queryset = self.get_queryset()
  175. if queryset.exists():
  176. raise AutoLabelingException()
  177. labels = self.extract()
  178. labels = self.transform(labels)
  179. serializer = self.get_serializer(data=labels, many=True)
  180. serializer.is_valid(raise_exception=True)
  181. self.perform_create(serializer)
  182. headers = self.get_success_headers(serializer.data)
  183. return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
  184. def perform_create(self, serializer):
  185. serializer.save(user=self.request.user)
  186. def get_example(self, project):
  187. example = get_object_or_404(Example, pk=self.kwargs['example_id'])
  188. if project.is_task_of('text'):
  189. return example.text
  190. else:
  191. return str(example.filename)
  192. def extract(self):
  193. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  194. example = self.get_example(project)
  195. config = project.auto_labeling_config.first()
  196. if not config:
  197. raise AutoLabelingPermissionDenied()
  198. return execute_pipeline(
  199. text=example,
  200. project_type=project.project_type,
  201. model_name=config.model_name,
  202. model_attrs=config.model_attrs,
  203. template=config.template,
  204. label_mapping=config.label_mapping
  205. )
  206. def transform(self, labels):
  207. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  208. for label in labels:
  209. label['example'] = self.kwargs['example_id']
  210. if 'label' in label:
  211. label['label'] = project.labels.get(text=label.pop('label')).id
  212. return labels
  213. def execute_pipeline(text: str,
  214. project_type: str,
  215. model_name: str,
  216. model_attrs: dict,
  217. template: str,
  218. label_mapping: dict):
  219. task = TaskFactory.create(project_type)
  220. model = RequestModelFactory.create(
  221. model_name=model_name,
  222. attributes=model_attrs
  223. )
  224. template = MappingTemplate(
  225. label_collection=task.label_collection,
  226. template=template
  227. )
  228. post_processor = PostProcessor(label_mapping)
  229. labels = pipeline(
  230. text=text,
  231. request_model=model,
  232. mapping_template=template,
  233. post_processing=post_processor
  234. )
  235. return labels.dict()