You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

246 lines
9.4 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. import botocore.exceptions
  2. import requests
  3. from auto_labeling_pipeline.mappings import MappingTemplate
  4. from auto_labeling_pipeline.menu import Options
  5. from auto_labeling_pipeline.models import RequestModelFactory
  6. from auto_labeling_pipeline.pipeline import pipeline
  7. from auto_labeling_pipeline.postprocessing import PostProcessor
  8. from auto_labeling_pipeline.task import TaskFactory
  9. from django.shortcuts import get_object_or_404
  10. from rest_framework import generics, status
  11. from rest_framework.exceptions import ValidationError
  12. from rest_framework.permissions import IsAuthenticated
  13. from rest_framework.response import Response
  14. from rest_framework.views import APIView
  15. from ..exceptions import (AutoLabeliingPermissionDenied, AutoLabelingException,
  16. AWSTokenError, SampleDataException,
  17. URLConnectionError)
  18. from ..models import AutoLabelingConfig, Document, Project
  19. from ..permissions import IsInProjectOrAdmin, IsProjectAdmin
  20. from ..serializers import (AutoLabelingConfigSerializer,
  21. get_annotation_serializer)
  22. class AutoLabelingTemplateListAPI(APIView):
  23. permission_classes = [IsAuthenticated & IsProjectAdmin]
  24. def get(self, request, *args, **kwargs):
  25. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  26. options = Options.filter_by_task(task_name=project.project_type)
  27. option_names = [o.name for o in options]
  28. return Response(option_names, status=status.HTTP_200_OK)
  29. class AutoLabelingTemplateDetailAPI(APIView):
  30. permission_classes = [IsAuthenticated & IsProjectAdmin]
  31. def get(self, request, *args, **kwargs):
  32. option_name = self.kwargs['option_name']
  33. option = Options.find(option_name=option_name)
  34. return Response(option.to_dict(), status=status.HTTP_200_OK)
  35. class AutoLabelingConfigList(generics.ListCreateAPIView):
  36. serializer_class = AutoLabelingConfigSerializer
  37. pagination_class = None
  38. permission_classes = [IsAuthenticated & IsProjectAdmin]
  39. def get_queryset(self):
  40. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  41. return project.auto_labeling_config
  42. def perform_create(self, serializer):
  43. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  44. serializer.save(project=project)
  45. class AutoLabelingConfigDetail(generics.RetrieveUpdateDestroyAPIView):
  46. queryset = AutoLabelingConfig.objects.all()
  47. serializer_class = AutoLabelingConfigSerializer
  48. lookup_url_kwarg = 'config_id'
  49. permission_classes = [IsAuthenticated & IsProjectAdmin]
  50. class AutoLabelingConfigTest(APIView):
  51. permission_classes = [IsAuthenticated & IsProjectAdmin]
  52. def post(self, *args, **kwargs):
  53. try:
  54. output = self.pass_config_validation()
  55. output = self.pass_pipeline_call(output)
  56. if not output:
  57. raise SampleDataException()
  58. return Response(
  59. data={'valid': True, 'labels': output},
  60. status=status.HTTP_200_OK
  61. )
  62. except requests.exceptions.ConnectionError:
  63. raise URLConnectionError()
  64. except botocore.exceptions.ClientError:
  65. raise AWSTokenError()
  66. except ValidationError as e:
  67. raise e
  68. except Exception as e:
  69. raise e
  70. def pass_config_validation(self):
  71. config = self.request.data['config']
  72. serializer = AutoLabelingConfigSerializer(data=config)
  73. serializer.is_valid(raise_exception=True)
  74. return serializer
  75. def pass_pipeline_call(self, serializer):
  76. test_input = self.request.data['input']
  77. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  78. return execute_pipeline(
  79. text=test_input,
  80. project_type=project.project_type,
  81. model_name=serializer.data.get('model_name'),
  82. model_attrs=serializer.data.get('model_attrs'),
  83. template=serializer.data.get('template'),
  84. label_mapping=serializer.data.get('label_mapping')
  85. )
  86. class AutoLabelingConfigParameterTest(APIView):
  87. permission_classes = [IsAuthenticated & IsProjectAdmin]
  88. def post(self, *args, **kwargs):
  89. model_name = self.request.data['model_name']
  90. model_attrs = self.request.data['model_attrs']
  91. sample_text = self.request.data['text']
  92. try:
  93. model = RequestModelFactory.create(model_name, model_attrs)
  94. except Exception:
  95. model = RequestModelFactory.find(model_name)
  96. schema = model.schema()
  97. required_fields = ', '.join(schema['required']) if 'required' in schema else ''
  98. raise ValidationError(
  99. 'The attributes does not match the model.'
  100. 'You need to correctly specify the required fields: {}'.format(required_fields)
  101. )
  102. try:
  103. request = model.build()
  104. response = request.send(text=sample_text)
  105. return Response(response, status=status.HTTP_200_OK)
  106. except requests.exceptions.ConnectionError:
  107. raise URLConnectionError
  108. except botocore.exceptions.ClientError:
  109. raise AWSTokenError()
  110. except Exception as e:
  111. raise e
  112. class AutoLabelingTemplateTest(APIView):
  113. permission_classes = [IsAuthenticated & IsProjectAdmin]
  114. def post(self, *args, **kwargs):
  115. response = self.request.data['response']
  116. template = self.request.data['template']
  117. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  118. task = TaskFactory.create(project.project_type)
  119. template = MappingTemplate(
  120. label_collection=task.label_collection,
  121. template=template
  122. )
  123. labels = template.render(response)
  124. if not labels.dict():
  125. raise SampleDataException()
  126. return Response(labels.dict(), status=status.HTTP_200_OK)
  127. class AutoLabelingMappingTest(APIView):
  128. permission_classes = [IsAuthenticated & IsProjectAdmin]
  129. def post(self, *args, **kwargs):
  130. response = self.request.data['response']
  131. label_mapping = self.request.data['label_mapping']
  132. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  133. task = TaskFactory.create(project.project_type)
  134. labels = task.label_collection(response)
  135. post_processor = PostProcessor(label_mapping)
  136. labels = post_processor.transform(labels)
  137. return Response(labels.dict(), status=status.HTTP_200_OK)
  138. class AutoLabelingAnnotation(generics.CreateAPIView):
  139. pagination_class = None
  140. permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
  141. swagger_schema = None
  142. def get_serializer_class(self):
  143. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  144. self.serializer_class = get_annotation_serializer(task=project.project_type)
  145. return self.serializer_class
  146. def get_queryset(self):
  147. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  148. model = project.get_annotation_class()
  149. queryset = model.objects.filter(document=self.kwargs['doc_id'])
  150. if not project.collaborative_annotation:
  151. queryset = queryset.filter(user=self.request.user)
  152. return queryset
  153. def create(self, request, *args, **kwargs):
  154. queryset = self.get_queryset()
  155. if queryset.exists():
  156. raise AutoLabelingException()
  157. labels = self.extract()
  158. labels = self.transform(labels)
  159. serializer = self.get_serializer(data=labels, many=True)
  160. serializer.is_valid(raise_exception=True)
  161. self.perform_create(serializer)
  162. headers = self.get_success_headers(serializer.data)
  163. return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
  164. def perform_create(self, serializer):
  165. serializer.save(user=self.request.user)
  166. def extract(self):
  167. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  168. doc = get_object_or_404(Document, pk=self.kwargs['doc_id'])
  169. config = project.auto_labeling_config.first()
  170. if not config:
  171. raise AutoLabeliingPermissionDenied()
  172. return execute_pipeline(
  173. text=doc.text,
  174. project_type=project.project_type,
  175. model_name=config.model_name,
  176. model_attrs=config.model_attrs,
  177. template=config.template,
  178. label_mapping=config.label_mapping
  179. )
  180. def transform(self, labels):
  181. project = get_object_or_404(Project, pk=self.kwargs['project_id'])
  182. for label in labels:
  183. label['document'] = self.kwargs['doc_id']
  184. if 'label' in label:
  185. label['label'] = project.labels.get(text=label.pop('label')).id
  186. return labels
  187. def execute_pipeline(text: str,
  188. project_type: str,
  189. model_name: str,
  190. model_attrs: dict,
  191. template: str,
  192. label_mapping: dict):
  193. task = TaskFactory.create(project_type)
  194. model = RequestModelFactory.create(
  195. model_name=model_name,
  196. attributes=model_attrs
  197. )
  198. template = MappingTemplate(
  199. label_collection=task.label_collection,
  200. template=template
  201. )
  202. post_processor = PostProcessor(label_mapping)
  203. labels = pipeline(
  204. text=text,
  205. request_model=model,
  206. mapping_template=template,
  207. post_processing=post_processor
  208. )
  209. return labels.dict()