diff --git a/app/api/tasks.py b/app/api/tasks.py index e1a4f711..49bf490a 100644 --- a/app/api/tasks.py +++ b/app/api/tasks.py @@ -1,5 +1,8 @@ +import itertools + from celery import shared_task from django.conf import settings +from django.contrib.auth import get_user_model from django.shortcuts import get_object_or_404 from .models import Document, Label, Project @@ -9,15 +12,17 @@ from .views.upload.utils import append_field @shared_task -def injest_data(project_id, filenames, format: str, **kwargs): +def injest_data(user_id, project_id, filenames, format: str, **kwargs): project = get_object_or_404(Project, pk=project_id) + user = get_object_or_404(get_user_model(), pk=user_id) data_class = get_data_class(project.project_type) dataset_class = get_dataset_class(format) label_class = get_label_class(project.project_type) dataset = dataset_class( - filenames=filenames, - label_class=label_class, - data_class=data_class + filenames=filenames, + label_class=label_class, + data_class=data_class, + **kwargs ) annotation_serializer_class = project.get_annotation_serializer() for batch in dataset.batch(settings.IMPORT_BATCH_SIZE): @@ -33,8 +38,11 @@ def injest_data(project_id, filenames, format: str, **kwargs): annotation = batch.annotation(mapping) for a, d in zip(annotation, data): append_field(a, document=d.id) - annotation_serializer = annotation_serializer_class(data=annotation, many=True) + annotation_serializer = annotation_serializer_class( + data=list(itertools.chain(*annotation)), + many=True + ) annotation_serializer.is_valid() - annotation_serializer.save() + annotation_serializer.save(user=user) return {'error': []}