You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

285 lines
11 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. import abc
  2. import json
  3. from typing import List
  4. import botocore.exceptions
  5. import requests
  6. from auto_labeling_pipeline.mappings import MappingTemplate
  7. from auto_labeling_pipeline.menu import Options
  8. from auto_labeling_pipeline.models import RequestModelFactory
  9. from auto_labeling_pipeline.postprocessing import PostProcessor
  10. from auto_labeling_pipeline.task import TaskFactory
  11. from django.shortcuts import get_object_or_404
  12. from django_drf_filepond.models import TemporaryUpload
  13. from rest_framework import generics, status
  14. from rest_framework.exceptions import ValidationError
  15. from rest_framework.permissions import IsAuthenticated
  16. from rest_framework.request import Request
  17. from rest_framework.response import Response
  18. from rest_framework.views import APIView
  19. from api.models import Example, Project, Category, CategoryType, Annotation, Span, SpanType
  20. from members.permissions import IsInProjectOrAdmin, IsProjectAdmin
  21. from .pipeline.execution import execute_pipeline
  22. from .exceptions import (AutoLabelingPermissionDenied,
  23. AWSTokenError, SampleDataException,
  24. TemplateMappingError, URLConnectionError)
  25. from .models import AutoLabelingConfig
  26. from .serializers import (AutoLabelingConfigSerializer, get_annotation_serializer)
  class TemplateListAPI(APIView):
      """List the names of the auto-labeling templates matching a task.

      The task comes from the ``task_name`` query parameter (``None`` is passed
      through when the parameter is absent) and is handed to
      ``Options.filter_by_task``. The response body is a list of option names.
      """

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def get(self, request: Request, *args, **kwargs):
          requested_task = request.query_params.get('task_name')
          matching = Options.filter_by_task(task_name=requested_task)
          names = [option.name for option in matching]
          return Response(names, status=status.HTTP_200_OK)
  class TemplateDetailAPI(APIView):
      """Return the dictionary form of one auto-labeling template.

      The template is looked up by the ``option_name`` URL kwarg via
      ``Options.find`` and serialized with its ``to_dict()`` method.
      """

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def get(self, request, *args, **kwargs):
          wanted = self.kwargs['option_name']
          template_option = Options.find(option_name=wanted)
          payload = template_option.to_dict()
          return Response(payload, status=status.HTTP_200_OK)
  class ConfigList(generics.ListCreateAPIView):
      """List the auto-labeling configs of a project, or create one for it.

      Both operations are scoped to the ``project_id`` URL kwarg; results are
      returned unpaginated.
      """

      permission_classes = [IsAuthenticated & IsProjectAdmin]
      serializer_class = AutoLabelingConfigSerializer
      pagination_class = None

      def get_queryset(self):
          owning_project = self.kwargs['project_id']
          return AutoLabelingConfig.objects.filter(project=owning_project)

      def perform_create(self, serializer):
          # The project is taken from the URL, not from the request body.
          owning_project = self.kwargs['project_id']
          serializer.save(project_id=owning_project)
  class ConfigDetail(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete a single auto-labeling config.

      The config is looked up by the ``config_id`` URL kwarg, restricted to the
      project named by the ``project_id`` URL kwarg.
      """

      # Kept for backward compatibility (e.g. schema generation); the effective
      # queryset is narrowed per request in ``get_queryset``.
      queryset = AutoLabelingConfig.objects.all()
      serializer_class = AutoLabelingConfigSerializer
      lookup_url_kwarg = 'config_id'
      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def get_queryset(self):
          """Return only the configs that belong to the project in the URL.

          The permission check presumably validates admin rights on the URL's
          project only. Without this filter, an admin of one project could read,
          modify or delete another project's config simply by knowing its id.
          """
          queryset = super().get_queryset()
          project_id = self.kwargs.get('project_id')
          if project_id is None:
              # NOTE(review): assumes the detail route carries project_id like
              # ConfigList's route does; fall back to the unscoped queryset otherwise.
              return queryset
          return queryset.filter(project=project_id)
  class FullPipelineTesting(APIView):
      """Dry-run a candidate auto-labeling config against a sample input.

      The request body must contain ``config`` (validated with
      ``AutoLabelingConfigSerializer``) and ``input`` (passed to
      ``execute_pipeline`` as ``text``). The extracted labels are returned as
      ``{'valid': True, 'labels': ...}``; nothing is saved.
      """

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def post(self, *args, **kwargs):
          """Validate the config, run the pipeline and return the labels.

          Raises:
              ValidationError: the config does not pass serializer validation.
              URLConnectionError: the model endpoint could not be reached.
              AWSTokenError: an AWS client call failed.
              SampleDataException: the pipeline produced no labels.
          """
          try:
              serializer = self.pass_config_validation()
              output = self.pass_pipeline_call(serializer)
          except requests.exceptions.ConnectionError as err:
              raise URLConnectionError() from err
          except botocore.exceptions.ClientError as err:
              raise AWSTokenError() from err
          # Any other exception (including DRF ValidationError) propagates unchanged.
          if not output:
              raise SampleDataException()
          return Response(
              data={'valid': True, 'labels': output},
              status=status.HTTP_200_OK
          )

      def pass_config_validation(self):
          """Validate ``request.data['config']`` and return the bound serializer.

          Raises DRF ``ValidationError`` when the config is invalid.
          """
          config = self.request.data['config']
          serializer = AutoLabelingConfigSerializer(data=config)
          serializer.is_valid(raise_exception=True)
          return serializer

      def pass_pipeline_call(self, serializer):
          """Run ``execute_pipeline`` on ``request.data['input']`` using the validated config fields."""
          test_input = self.request.data['input']
          config = serializer.data
          return execute_pipeline(
              text=test_input,
              task_type=config.get('task_type'),
              model_name=config.get('model_name'),
              model_attrs=config.get('model_attrs'),
              template=config.get('template'),
              label_mapping=config.get('label_mapping')
          )
  class RestAPIRequestTesting(APIView):
      """Send a sample to a request model and return the model's raw response.

      The request body must contain ``model_name``, ``model_attrs`` and
      ``text``. For text projects, ``text`` is sent as-is. Otherwise it is
      treated as a filepond ``TemporaryUpload`` id and the upload's file path
      is sent.
      """

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      @property
      def project(self):
          """The project named by the ``project_id`` URL kwarg (404 if missing)."""
          return get_object_or_404(Project, pk=self.kwargs['project_id'])

      def create_model(self):
          """Build a request model from ``model_name`` and ``model_attrs``.

          Raises:
              ValidationError: the attributes were rejected. The message lists the
                  model schema's required fields.
          """
          model_name = self.request.data['model_name']
          model_attrs = self.request.data['model_attrs']
          try:
              return RequestModelFactory.create(model_name, model_attrs)
          except Exception as err:
              model = RequestModelFactory.find(model_name)
              schema = model.schema()
              required_fields = ', '.join(schema.get('required', []))
              # The trailing space fixes the previously glued "model.You need" message.
              raise ValidationError(
                  'The attributes does not match the model. '
                  'You need to correctly specify the required fields: {}'.format(required_fields)
              ) from err

      def send_request(self, model, example):
          """Send ``example`` through ``model`` and translate transport errors.

          Raises:
              URLConnectionError: the endpoint could not be reached.
              AWSTokenError: an AWS client call failed.
          """
          try:
              return model.send(example)
          except requests.exceptions.ConnectionError as err:
              raise URLConnectionError() from err
          except botocore.exceptions.ClientError as err:
              raise AWSTokenError() from err

      def prepare_example(self):
          """Return the payload to send: the raw text, or an uploaded file's path."""
          text = self.request.data['text']
          if self.project.is_text_project:
              return text
          # Non-text projects send an upload id. An unknown id now yields 404
          # instead of an unhandled DoesNotExist (HTTP 500).
          upload = get_object_or_404(TemporaryUpload, upload_id=text)
          return upload.get_file_path()

      def post(self, *args, **kwargs):
          model = self.create_model()
          example = self.prepare_example()
          response = self.send_request(model=model, example=example)
          return Response(response, status=status.HTTP_200_OK)
  class LabelExtractorTesting(APIView):
      """Render a mapping template against a sample model response.

      The request body must contain ``response`` (the raw model output) and
      ``template``. The labels produced for the project's task are returned via
      their ``dict()`` form.
      """

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def post(self, *args, **kwargs):
          """Return the extracted labels.

          Raises:
              TemplateMappingError: rendering hit a JSON decode error.
              SampleDataException: the template produced no labels.
          """
          response = self.request.data['response']
          template_text = self.request.data['template']
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          task = TaskFactory.create(project.project_type)
          template = MappingTemplate(
              label_collection=task.label_collection,
              template=template_text
          )
          try:
              labels = template.render(response)
          except json.decoder.JSONDecodeError as err:
              raise TemplateMappingError() from err
          # Serialize once and reuse for both the emptiness check and the response.
          payload = labels.dict()
          if not payload:
              raise SampleDataException()
          return Response(payload, status=status.HTTP_200_OK)
  class LabelMapperTesting(APIView):
      """Apply a label mapping to a sample model response.

      The request body must contain ``response`` and ``label_mapping``. The
      response is wrapped in the task's label collection, passed through
      ``PostProcessor.transform``, and returned via ``dict()``.
      """

      permission_classes = [IsAuthenticated & IsProjectAdmin]

      def post(self, *args, **kwargs):
          body = self.request.data
          raw_response = body['response']
          mapping = body['label_mapping']
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          collection_cls = TaskFactory.create(project.project_type).label_collection
          collected = collection_cls(raw_response)
          processor = PostProcessor(mapping)
          mapped = processor.transform(collected)
          return Response(mapped.dict(), status=status.HTTP_200_OK)
  class AutomatedDataLabeling(generics.CreateAPIView):
      """Run the project's first auto-labeling config on one example and save the labels.

      The serializer class is picked by ``get_annotation_serializer`` according
      to the project's type. Saved annotations are attributed to the requesting
      user.
      """

      pagination_class = None
      permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
      swagger_schema = None

      def get_serializer_class(self):
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          self.serializer_class = get_annotation_serializer(task=project.project_type)
          return self.serializer_class

      def create(self, request, *args, **kwargs):
          labels = self.extract()
          labels = self.transform(labels)
          serializer = self.get_serializer(data=labels, many=True)
          serializer.is_valid(raise_exception=True)
          self.perform_create(serializer)
          headers = self.get_success_headers(serializer.data)
          return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

      def perform_create(self, serializer):
          serializer.save(user=self.request.user)

      def extract(self):
          """Run ``execute_pipeline`` on the example's data with the project's first config.

          Raises:
              Http404: the project or example is missing, or the example is not
                  in this project.
              AutoLabelingPermissionDenied: the project has no auto-labeling config.
          """
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          # Restrict the example to the URL's project. Otherwise a member of one
          # project could trigger labeling (and annotation writes) on another
          # project's example.
          example = get_object_or_404(Example, pk=self.kwargs['example_id'], project=project)
          config = project.auto_labeling_config.first()
          if not config:
              raise AutoLabelingPermissionDenied()
          return execute_pipeline(
              text=example.data,
              task_type=project.project_type,
              model_name=config.model_name,
              model_attrs=config.model_attrs,
              template=config.template,
              label_mapping=config.label_mapping
          )

      def transform(self, labels):
          """Attach the example id and replace each label text with its label id.

          Label texts are resolved with a single query. A dict whose label text
          is not defined in the project is dropped, consistent with
          ``AutomatedLabeling.transform``, instead of surfacing as a
          ``DoesNotExist`` server error. Dicts without a ``'label'`` key are kept.
          """
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          label_ids = {label.text: label.id for label in project.labels.all()}
          example_id = self.kwargs['example_id']
          kept = []
          for label in labels:
              label['example'] = example_id
              if 'label' in label:
                  text = label.pop('label')
                  if text not in label_ids:
                      continue
                  label['label'] = label_ids[text]
              kept.append(label)
          return kept
  class AutomatedLabeling(abc.ABC, generics.CreateAPIView):
      """Base view: run the project's matching auto-labeling configs on one example.

      Subclasses set:
          model: annotation model to create (its manager must provide
              ``filter_annotatable_labels``).
          label_type: the project's label model, matched by its ``text`` field.
          task_type: value compared against ``AutoLabelingConfig.task_type``.
      """

      permission_classes = [IsAuthenticated & IsInProjectOrAdmin]
      swagger_schema = None
      model = None
      label_type = None
      task_type = None

      def create(self, request, *args, **kwargs):
          """Run each matching config, convert its output and bulk-create annotations."""
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          # Scope the example to the URL's project so users cannot annotate
          # examples that belong to other projects.
          example = get_object_or_404(Example, pk=self.kwargs['example_id'], project=project)
          # Filter by project as well. Filtering by task_type alone runs every
          # project's configs (their endpoints and credentials) on this example.
          configs = AutoLabelingConfig.objects.filter(task_type=self.task_type, project=project)
          for config in configs:
              labels = execute_pipeline(
                  text=example.data,
                  task_type=config.task_type,
                  model_name=config.model_name,
                  model_attrs=config.model_attrs,
                  template=config.template,
                  label_mapping=config.label_mapping
              )
              labels = self.transform(labels, example, project)
              labels = self.model.objects.filter_annotatable_labels(labels, project)
              self.model.objects.bulk_create(labels)
          return Response({'ok': True}, status=status.HTTP_201_CREATED)

      def transform(self, labels, example: Example, project: Project) -> List[Annotation]:
          """Convert label dicts into unsaved ``self.model`` instances.

          Dicts whose ``'label'`` text is not a label of ``project`` are skipped.
          The remaining dicts are mutated in place so that ``example``, ``label``
          (the label object) and ``user`` (the requesting user) are set before
          being passed to ``self.model``.
          """
          mapping = {
              c.text: c for c in self.label_type.objects.filter(project=project)
          }
          annotations = []
          for label in labels:
              if label['label'] not in mapping:
                  continue
              label['example'] = example
              label['label'] = mapping[label['label']]
              label['user'] = self.request.user
              annotations.append(self.model(**label))
          return annotations
  class AutomatedCategoryLabeling(AutomatedLabeling):
      """Auto-labeling that creates ``Category`` annotations from ``CategoryType`` labels."""

      task_type = 'Category'
      label_type = CategoryType
      model = Category
  class AutomatedSpanLabeling(AutomatedLabeling):
      """Auto-labeling that creates ``Span`` annotations from ``SpanType`` labels."""

      task_type = 'Span'
      label_type = SpanType
      model = Span