You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

140 lines
4.5 KiB

3 years ago
3 years ago
  1. import itertools
  2. import uuid
  3. from celery import shared_task
  4. from celery.utils.log import get_task_logger
  5. from django.conf import settings
  6. from django.contrib.auth import get_user_model
  7. from django.shortcuts import get_object_or_404
  8. from .models import Example, Label, Project
  9. from .views.download.factory import create_repository, create_writer
  10. from .views.download.service import ExportApplicationService
  11. from .views.upload.exception import FileParseException
  12. from .views.upload.factory import (get_data_class, get_dataset_class,
  13. get_label_class)
  14. from .views.upload.utils import append_field
  15. logger = get_task_logger(__name__)
  16. class Buffer:
  17. def __init__(self, buffer_size=settings.IMPORT_BATCH_SIZE):
  18. self.buffer_size = buffer_size
  19. self.buffer = []
  20. def __len__(self):
  21. return len(self.buffer)
  22. @property
  23. def data(self):
  24. return self.buffer
  25. def add(self, data):
  26. self.buffer.append(data)
  27. def clear(self):
  28. self.buffer = []
  29. def is_full(self):
  30. return len(self) >= self.buffer_size
  31. def is_empty(self):
  32. return len(self) == 0
  33. class DataFactory:
  34. def __init__(self, data_class, label_class, annotation_class):
  35. self.data_class = data_class
  36. self.label_class = label_class
  37. self.annotation_class = annotation_class
  38. def create_label(self, examples, project):
  39. flatten = itertools.chain(*[example.label for example in examples])
  40. labels = {
  41. label['text'] for label in flatten
  42. if not project.labels.filter(text=label['text']).exists()
  43. }
  44. labels = [self.label_class(text=text, project=project) for text in labels]
  45. self.label_class.objects.bulk_create(labels)
  46. def create_data(self, examples, project):
  47. uuids = sorted(uuid.uuid4() for _ in range(len(examples)))
  48. dataset = [
  49. self.data_class(uuid=uid, project=project, **example.data)
  50. for uid, example in zip(uuids, examples)
  51. ]
  52. self.data_class.objects.bulk_create(dataset)
  53. data = self.data_class.objects.in_bulk(uuids, field_name='uuid')
  54. return [data[uid] for uid in uuids]
  55. def create_annotation(self, examples, ids, user, project):
  56. mapping = {label.text: label.id for label in project.labels.all()}
  57. annotation = [example.annotation(mapping) for example in examples]
  58. for a, id in zip(annotation, ids):
  59. append_field(a, example=id)
  60. annotation = list(itertools.chain(*annotation))
  61. for a in annotation:
  62. if 'label' in a:
  63. a['label_id'] = a.pop('label')
  64. annotation = [self.annotation_class(**a, user=user) for a in annotation]
  65. self.annotation_class.objects.bulk_create(annotation)
  66. def create(self, examples, user, project):
  67. self.create_label(examples, project)
  68. ids = self.create_data(examples, project)
  69. self.create_annotation(examples, ids, user, project)
  70. @shared_task
  71. def injest_data(user_id, project_id, filenames, format: str, **kwargs):
  72. project = get_object_or_404(Project, pk=project_id)
  73. user = get_object_or_404(get_user_model(), pk=user_id)
  74. response = {'error': []}
  75. # Prepare dataset.
  76. dataset_class = get_dataset_class(format)
  77. dataset = dataset_class(
  78. filenames=filenames,
  79. label_class=get_label_class(project.project_type),
  80. data_class=get_data_class(project.project_type),
  81. **kwargs
  82. )
  83. it = iter(dataset)
  84. buffer = Buffer()
  85. factory = DataFactory(
  86. data_class=Example,
  87. label_class=Label,
  88. annotation_class=project.get_annotation_class()
  89. )
  90. while True:
  91. try:
  92. example = next(it)
  93. except StopIteration:
  94. break
  95. except FileParseException as err:
  96. response['error'].append(err.dict())
  97. continue
  98. buffer.add(example)
  99. if buffer.is_full():
  100. factory.create(buffer.data, user, project)
  101. buffer.clear()
  102. if not buffer.is_empty():
  103. logger.debug(f'BUFFER LEN {len(buffer)}')
  104. factory.create(buffer.data, user, project)
  105. buffer.clear()
  106. return response
  107. @shared_task
  108. def export_dataset(project_id, format: str, export_approved=False):
  109. project = get_object_or_404(Project, pk=project_id)
  110. repository = create_repository(project)
  111. writer = create_writer(format)(settings.MEDIA_ROOT)
  112. service = ExportApplicationService(repository, writer)
  113. filepath = service.export(export_approved)
  114. return filepath