|
|
@ -3,6 +3,11 @@ import json |
|
|
|
import random |
|
|
|
|
|
|
|
from auto_labeling_pipeline.menu import Options |
|
|
|
from auto_labeling_pipeline.models import RequestModelFactory |
|
|
|
from auto_labeling_pipeline.mappings import MappingTemplate |
|
|
|
from auto_labeling_pipeline.task import TaskFactory |
|
|
|
from auto_labeling_pipeline.postprocessing import PostProcessor |
|
|
|
from auto_labeling_pipeline.pipeline import pipeline |
|
|
|
from django.conf import settings |
|
|
|
from django.contrib.auth.models import User |
|
|
|
from django.db import transaction |
|
|
@ -522,3 +527,47 @@ class AutoLabelingConfigDetail(generics.RetrieveUpdateDestroyAPIView): |
|
|
|
serializer_class = AutoLabelingConfigSerializer |
|
|
|
lookup_url_kwarg = 'config_id' |
|
|
|
permission_classes = [IsAuthenticated & IsProjectAdmin] |
|
|
|
|
|
|
|
|
|
|
|
class AutoLabelingConfigTest(APIView): |
|
|
|
permission_classes = [IsAuthenticated & IsProjectAdmin] |
|
|
|
|
|
|
|
def post(self, *args, **kwargs): |
|
|
|
try: |
|
|
|
output = self.pass_config_validation() |
|
|
|
output = self.pass_pipeline_call(output) |
|
|
|
return Response( |
|
|
|
data={'valid': True, 'labels': output.dict()}, |
|
|
|
status=status.HTTP_200_OK |
|
|
|
) |
|
|
|
except Exception: |
|
|
|
return Response( |
|
|
|
data={'valid': False}, |
|
|
|
status=status.HTTP_400_BAD_REQUEST |
|
|
|
) |
|
|
|
|
|
|
|
def pass_config_validation(self): |
|
|
|
config = self.request.data['config'] |
|
|
|
serializer = AutoLabelingConfigSerializer(data=config) |
|
|
|
serializer.is_valid(raise_exception=True) |
|
|
|
return serializer |
|
|
|
|
|
|
|
def pass_pipeline_call(self, serializer): |
|
|
|
test_input = self.request.data['input'] |
|
|
|
project = get_object_or_404(Project, pk=self.kwargs['project_id']) |
|
|
|
model = RequestModelFactory.create( |
|
|
|
model_name=serializer.data.get('model_name'), |
|
|
|
attributes=serializer.data.get('model_attrs') |
|
|
|
) |
|
|
|
template = MappingTemplate( |
|
|
|
task=TaskFactory.create(project.project_type), |
|
|
|
template=serializer.data.get('template') |
|
|
|
) |
|
|
|
post_processor = PostProcessor(serializer.data.get('label_mapping')) |
|
|
|
labels = pipeline( |
|
|
|
text=test_input, |
|
|
|
request_model=model, |
|
|
|
mapping_template=template, |
|
|
|
post_processing=post_processor |
|
|
|
) |
|
|
|
return labels |