You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

245 lines
9.3 KiB

3 years ago
3 years ago
3 years ago
  1. import botocore.exceptions
  2. import requests
  3. from auto_labeling_pipeline.mappings import MappingTemplate
  4. from auto_labeling_pipeline.menu import Options
  5. from auto_labeling_pipeline.models import RequestModelFactory
  6. from auto_labeling_pipeline.pipeline import pipeline
  7. from auto_labeling_pipeline.postprocessing import PostProcessor
  8. from auto_labeling_pipeline.task import TaskFactory
  9. from django.shortcuts import get_object_or_404
  10. from rest_framework import generics, status
  11. from rest_framework.exceptions import ValidationError
  12. from rest_framework.permissions import IsAuthenticated
  13. from rest_framework.response import Response
  14. from rest_framework.views import APIView
  15. from ..exceptions import (AutoLabeliingPermissionDenied, AutoLabelingException,
  16. AWSTokenError, SampleDataException,
  17. URLConnectionError)
  18. from ..models import AutoLabelingConfig, Document, Project
  19. from ..permissions import IsInProjectOrAdmin, IsProjectAdmin
  20. from ..serializers import AutoLabelingConfigSerializer
  class AutoLabelingTemplateListAPI(APIView):
      """List the names of the auto-labeling options available to a project."""

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def get(self, request, *args, **kwargs):
          """Return the names of the options matching the project's type.

          The project is taken from the ``project_id`` URL kwarg; a missing
          project results in a 404 via ``get_object_or_404``.
          """
          project_id = self.kwargs['project_id']
          project = get_object_or_404(Project, pk=project_id)
          names = []
          for option in Options.filter_by_task(task_name=project.project_type):
              names.append(option.name)
          return Response(names, status=status.HTTP_200_OK)
  class AutoLabelingTemplateDetailAPI(APIView):
      """Return the definition of a single auto-labeling option."""

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def get(self, request, *args, **kwargs):
          """Look up the option named by ``option_name`` and return its dict form."""
          selected = Options.find(option_name=self.kwargs['option_name'])
          payload = selected.to_dict()
          return Response(payload, status=status.HTTP_200_OK)
  class AutoLabelingConfigList(generics.ListCreateAPIView):
      """List the auto-labeling configs of a project, or create a new one."""

      permission_classes = [IsAuthenticated & IsProjectAdmin]
      serializer_class = AutoLabelingConfigSerializer
      pagination_class = None

      def get_queryset(self):
          """Return the configs attached to the project named in the URL."""
          owner = get_object_or_404(Project, pk=self.kwargs['project_id'])
          return owner.auto_labeling_config

      def perform_create(self, serializer):
          """Save the validated config, binding it to the project in the URL."""
          serializer.save(
              project=get_object_or_404(Project, pk=self.kwargs['project_id'])
          )
  class AutoLabelingConfigDetail(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete one auto-labeling config.

      The config is selected by the ``config_id`` URL kwarg. Lookups are
      restricted to the configs of the project named by ``project_id`` so that
      an admin of one project cannot reach another project's config by id.
      """

      # Kept as the declared base queryset; get_queryset() narrows it per request.
      queryset = AutoLabelingConfig.objects.all()
      serializer_class = AutoLabelingConfigSerializer
      lookup_url_kwarg = 'config_id'
      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def get_queryset(self):
          """Return only the configs that belong to the project in the URL.

          A missing project results in a 404 via ``get_object_or_404``.
          """
          # NOTE(review): assumes this view is routed with a ``project_id``
          # kwarg like the sibling views in this module - confirm in urls.py.
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          return project.auto_labeling_config.all()
  class AutoLabelingConfigTest(APIView):
      """Dry-run a complete auto-labeling config against a sample input."""

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def post(self, *args, **kwargs):
          """Validate the posted config, run the pipeline and return its labels.

          Raises:
              ValidationError: if the posted config fails serializer validation.
              SampleDataException: if the pipeline produces no labels.
              URLConnectionError: if ``requests`` reports a connection error.
              AWSTokenError: if ``botocore`` reports a client error.
          """
          try:
              serializer = self.pass_config_validation()
              labels = self.pass_pipeline_call(serializer)
          except requests.exceptions.ConnectionError as e:
              # Chain the cause so the underlying network error stays visible.
              raise URLConnectionError() from e
          except botocore.exceptions.ClientError as e:
              raise AWSTokenError() from e
          # Any other exception (including ValidationError) propagates unchanged.

          if not labels:
              raise SampleDataException()
          return Response(
              data={'valid': True, 'labels': labels},
              status=status.HTTP_200_OK
          )

      def pass_config_validation(self):
          """Return a validated serializer built from ``request.data['config']``.

          Raises:
              ValidationError: if the config does not pass validation.
          """
          config = self.request.data['config']
          serializer = AutoLabelingConfigSerializer(data=config)
          serializer.is_valid(raise_exception=True)
          return serializer

      def pass_pipeline_call(self, serializer):
          """Run the pipeline on ``request.data['input']`` with the given config.

          Args:
              serializer: the validated serializer from pass_config_validation().

          Returns:
              Whatever ``execute_pipeline`` returns for the sample input.
          """
          test_input = self.request.data['input']
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          config = serializer.data
          return execute_pipeline(
              text=test_input,
              project_type=project.project_type,
              model_name=config.get('model_name'),
              model_attrs=config.get('model_attrs'),
              template=config.get('template'),
              label_mapping=config.get('label_mapping')
          )
  class AutoLabelingConfigParameterTest(APIView):
      """Check model parameters by building the model and sending one sample."""

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def post(self, *args, **kwargs):
          """Build the requested model and return its response to a sample text.

          Reads ``model_name``, ``model_attrs`` and ``text`` from the request
          body.

          Raises:
              ValidationError: if the model cannot be created from the given
                  attributes; the message lists the schema's required fields.
              URLConnectionError: if ``requests`` reports a connection error.
              AWSTokenError: if ``botocore`` reports a client error.
          """
          model_name = self.request.data['model_name']
          model_attrs = self.request.data['model_attrs']
          sample_text = self.request.data['text']

          try:
              model = RequestModelFactory.create(model_name, model_attrs)
          except Exception:
              # Creation failed: look the model up by name so the error can
              # tell the caller which fields its schema marks as required.
              model = RequestModelFactory.find(model_name)
              schema = model.schema()
              required_fields = ', '.join(schema['required']) if 'required' in schema else ''
              # The trailing space keeps the two sentences from running
              # together when the adjacent literals are concatenated.
              raise ValidationError(
                  'The attributes does not match the model. '
                  'You need to correctly specify the required fields: {}'.format(required_fields)
              )

          try:
              request = model.build()
              response = request.send(text=sample_text)
          except requests.exceptions.ConnectionError as e:
              # Chain the cause so the underlying network error stays visible.
              raise URLConnectionError() from e
          except botocore.exceptions.ClientError as e:
              raise AWSTokenError() from e
          # Any other exception propagates unchanged.
          return Response(response, status=status.HTTP_200_OK)
  class AutoLabelingTemplateTest(APIView):
      """Render a mapping template against a sample model response."""

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def post(self, *args, **kwargs):
          """Return the labels the posted template extracts from the response.

          Reads ``response`` and ``template`` from the request body.

          Raises:
              SampleDataException: if the rendered labels are empty.
          """
          raw_response = self.request.data['response']
          template_source = self.request.data['template']
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          task = TaskFactory.create(project.project_type)
          mapper = MappingTemplate(
              label_collection=task.label_collection,
              template=template_source
          )
          rendered = mapper.render(raw_response)
          if rendered.dict():
              return Response(rendered.dict(), status=status.HTTP_200_OK)
          raise SampleDataException()
  class AutoLabelingMappingTest(APIView):
      """Apply a label mapping to a sample set of extracted labels."""

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def post(self, *args, **kwargs):
          """Return the posted labels after the posted mapping is applied.

          Reads ``response`` and ``label_mapping`` from the request body.
          """
          raw_labels = self.request.data['response']
          mapping = self.request.data['label_mapping']
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          task = TaskFactory.create(project.project_type)
          collected = task.label_collection(raw_labels)
          mapped = PostProcessor(mapping).transform(collected)
          return Response(mapped.dict(), status=status.HTTP_200_OK)
  class AutoLabelingAnnotation(generics.CreateAPIView):
      """Create annotations for one document by running the project's pipeline.

      The project and document are taken from the ``project_id`` and ``doc_id``
      URL kwargs.
      """

      pagination_class = None
      permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
      swagger_schema = None

      def _get_project(self):
          """Return the project from the URL, fetching it at most once.

          Several methods of this view need the same project during a single
          create call; remembering it on the view instance avoids repeating
          the identical lookup. A missing project results in a 404.
          """
          project = getattr(self, '_cached_project', None)
          if project is None:
              project = get_object_or_404(Project, pk=self.kwargs['project_id'])
              self._cached_project = project
          return project

      def get_serializer_class(self):
          """Return the annotation serializer chosen by the project."""
          self.serializer_class = self._get_project().get_annotation_serializer()
          return self.serializer_class

      def get_queryset(self):
          """Return the existing annotations of the document.

          When the project is not collaborative, only the requesting user's
          annotations are included.
          """
          project = self._get_project()
          model = project.get_annotation_class()
          queryset = model.objects.filter(document=self.kwargs['doc_id'])
          if not project.collaborative_annotation:
              queryset = queryset.filter(user=self.request.user)
          return queryset

      def create(self, request, *args, **kwargs):
          """Run the pipeline on the document and store the resulting labels.

          Raises:
              AutoLabelingException: if the document already has annotations
                  in the queryset returned by get_queryset().
              AutoLabeliingPermissionDenied: if the project has no config.
              ValidationError: if the produced labels fail serializer validation.
          """
          if self.get_queryset().exists():
              raise AutoLabelingException()
          labels = self.transform(self.extract())
          serializer = self.get_serializer(data=labels, many=True)
          serializer.is_valid(raise_exception=True)
          self.perform_create(serializer)
          headers = self.get_success_headers(serializer.data)
          return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

      def perform_create(self, serializer):
          """Save the annotations on behalf of the requesting user."""
          serializer.save(user=self.request.user)

      def extract(self):
          """Run the project's first auto-labeling config on the document text.

          Raises:
              AutoLabeliingPermissionDenied: if the project has no config.
          """
          project = self._get_project()
          doc = get_object_or_404(Document, pk=self.kwargs['doc_id'])
          config = project.auto_labeling_config.first()
          if not config:
              raise AutoLabeliingPermissionDenied()
          return execute_pipeline(
              text=doc.text,
              project_type=project.project_type,
              model_name=config.model_name,
              model_attrs=config.model_attrs,
              template=config.template,
              label_mapping=config.label_mapping
          )

      def transform(self, labels):
          """Attach the document id to each label and resolve label text to ids.

          The items of ``labels`` are modified in place and the same object is
          returned.
          """
          project = self._get_project()
          for label in labels:
              label['document'] = self.kwargs['doc_id']
              if 'label' in label:
                  # Swap the label text for the id of the project label
                  # found by project.labels.get(text=...).
                  label['label'] = project.labels.get(text=label.pop('label')).id
          return labels
  def execute_pipeline(text: str,
                       project_type: str,
                       model_name: str,
                       model_attrs: dict,
                       template: str,
                       label_mapping: dict):
      """Run the auto-labeling pipeline on ``text`` and return the labels.

      Args:
          text: the text passed to the pipeline.
          project_type: passed to ``TaskFactory.create`` to select the task.
          model_name: name passed to ``RequestModelFactory.create``.
          model_attrs: attributes passed to ``RequestModelFactory.create``.
          template: template source given to ``MappingTemplate``.
          label_mapping: mapping given to ``PostProcessor``.

      Returns:
          The result of calling ``.dict()`` on the pipeline output.
      """
      task = TaskFactory.create(project_type)
      request_model = RequestModelFactory.create(model_name=model_name, attributes=model_attrs)
      mapping_template = MappingTemplate(label_collection=task.label_collection, template=template)
      result = pipeline(
          text=text,
          request_model=request_model,
          mapping_template=mapping_template,
          post_processing=PostProcessor(label_mapping),
      )
      return result.dict()