You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

64 lines
2.2 KiB

3 years ago
3 years ago
  1. import itertools
  2. from celery import shared_task
  3. from django.conf import settings
  4. from django.contrib.auth import get_user_model
  5. from django.shortcuts import get_object_or_404
  6. from .models import Document, Label, Project
  7. from .serializers import DocumentSerializer, LabelSerializer
  8. from .views.upload.exception import FileParseException
  9. from .views.upload.factory import (get_data_class, get_dataset_class,
  10. get_label_class)
  11. from .views.upload.utils import append_field
  12. @shared_task
  13. def injest_data(user_id, project_id, filenames, format: str, **kwargs):
  14. project = get_object_or_404(Project, pk=project_id)
  15. user = get_object_or_404(get_user_model(), pk=user_id)
  16. response = {'error': []}
  17. # Prepare dataset.
  18. dataset_class = get_dataset_class(format)
  19. dataset = dataset_class(
  20. filenames=filenames,
  21. label_class=get_label_class(project.project_type),
  22. data_class=get_data_class(project.project_type),
  23. **kwargs
  24. )
  25. annotation_serializer_class = project.get_annotation_serializer()
  26. it = iter(dataset)
  27. while True:
  28. try:
  29. example = next(it)
  30. except StopIteration:
  31. break
  32. except FileParseException as err:
  33. response['error'].append(err.dict())
  34. continue
  35. data_serializer = DocumentSerializer(data=example.data)
  36. if not data_serializer.is_valid():
  37. continue
  38. data = data_serializer.save(project=project)
  39. stored_labels = {label.text for label in project.labels.all()}
  40. labels = [label for label in example.label if label['text'] not in stored_labels]
  41. label_serializer = LabelSerializer(data=labels, many=True)
  42. if not label_serializer.is_valid():
  43. continue
  44. label_serializer.save(project=project)
  45. mapping = {label.text: label.id for label in project.labels.all()}
  46. annotation = example.annotation(mapping)
  47. append_field(annotation, document=data.id)
  48. annotation_serializer = annotation_serializer_class(
  49. data=annotation,
  50. many=True
  51. )
  52. if not annotation_serializer.is_valid():
  53. continue
  54. annotation_serializer.save(user=user)
  55. return response