diff --git a/app/api/tasks.py b/app/api/tasks.py index 7e9e3235..e1a4f711 100644 --- a/app/api/tasks.py +++ b/app/api/tasks.py @@ -1,28 +1,40 @@ from celery import shared_task from django.conf import settings -from django_drf_filepond.api import store_upload -from django_drf_filepond.models import TemporaryUpload +from django.shortcuts import get_object_or_404 -from .models import Document, Label +from .models import Document, Label, Project +from .serializers import LabelSerializer, DocumentSerializer +from .views.upload.factory import get_data_class, get_dataset_class, get_label_class +from .views.upload.utils import append_field @shared_task -def add(x, y): - return x + y - - -@shared_task -def mul(x, y): - return x * y - - -@shared_task -def xsum(numbers): - return sum(numbers) - - -@shared_task -def parse(upload_id): - tu = TemporaryUpload.objects.get(upload_id=upload_id) - su = store_upload(upload_id, destination_file_path=tu.upload_name) - return su.file.path +def injest_data(project_id, filenames, format: str, **kwargs): + project = get_object_or_404(Project, pk=project_id) + data_class = get_data_class(project.project_type) + dataset_class = get_dataset_class(format) + label_class = get_label_class(project.project_type) + dataset = dataset_class( + filenames=filenames, + label_class=label_class, + data_class=data_class + ) + annotation_serializer_class = project.get_annotation_serializer() + for batch in dataset.batch(settings.IMPORT_BATCH_SIZE): + data_serializer = DocumentSerializer(data=batch.data(), many=True) + data_serializer.is_valid() + data = data_serializer.save(project=project) + + label_serializer = LabelSerializer(data=batch.label(), many=True) + label_serializer.is_valid() + label_serializer.save(project=project) + + mapping = {label['text']: label['id'] for label in project.labels.values()} + annotation = batch.annotation(mapping) + for a, d in zip(annotation, data): + append_field(a, document=d.id) + annotation_serializer = annotation_serializer_class(data=annotation, many=True) + annotation_serializer.is_valid() + annotation_serializer.save() + + return {'error': []}