You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

125 lines
3.9 KiB

3 years ago
3 years ago
  1. import datetime
  2. import itertools
  3. from celery import shared_task
  4. from django.conf import settings
  5. from django.contrib.auth import get_user_model
  6. from django.shortcuts import get_object_or_404
  7. from .models import Document, Label, Project
  8. from .serializers import DocumentSerializer, LabelSerializer
  9. from .views.upload.exception import FileParseException
  10. from .views.upload.factory import (get_data_class, get_dataset_class,
  11. get_label_class)
  12. from .views.upload.utils import append_field
  13. class Buffer:
  14. def __init__(self, buffer_size=settings.IMPORT_BATCH_SIZE):
  15. self.buffer_size = buffer_size
  16. self.buffer = []
  17. def __len__(self):
  18. return len(self.buffer)
  19. @property
  20. def data(self):
  21. return self.buffer
  22. def add(self, data):
  23. self.buffer.append(data)
  24. def clear(self):
  25. self.buffer = []
  26. def is_full(self):
  27. return len(self) >= self.buffer_size
  28. def is_empty(self):
  29. return len(self) == 0
  30. class DataFactory:
  31. def __init__(self, data_class, label_class, annotation_class):
  32. self.data_class = data_class
  33. self.label_class = label_class
  34. self.annotation_class = annotation_class
  35. def create_label(self, examples, project):
  36. flatten = itertools.chain(*[example.label for example in examples])
  37. labels = {
  38. label['text'] for label in flatten
  39. if not project.labels.filter(text=label['text']).exists()
  40. }
  41. labels = [self.label_class(text=text, project=project) for text in labels]
  42. self.label_class.objects.bulk_create(labels)
  43. def create_data(self, examples, project):
  44. dataset = [
  45. self.data_class(project=project, **example.data)
  46. for example in examples
  47. ]
  48. now = datetime.datetime.now()
  49. self.data_class.objects.bulk_create(dataset)
  50. ids = self.data_class.objects.filter(created_at__gte=now)
  51. return list(ids)
  52. def create_annotation(self, examples, ids, user, project):
  53. mapping = {label.text: label.id for label in project.labels.all()}
  54. annotation = [example.annotation(mapping) for example in examples]
  55. for a, id in zip(annotation, ids):
  56. append_field(a, document=id)
  57. annotation = list(itertools.chain(*annotation))
  58. for a in annotation:
  59. if 'label' in a:
  60. a['label_id'] = a.pop('label')
  61. annotation = [self.annotation_class(**a, user=user) for a in annotation]
  62. self.annotation_class.objects.bulk_create(annotation)
  63. def create(self, examples, user, project):
  64. self.create_label(examples, project)
  65. ids = self.create_data(examples, project)
  66. self.create_annotation(examples, ids, user, project)
  67. @shared_task
  68. def injest_data(user_id, project_id, filenames, format: str, **kwargs):
  69. project = get_object_or_404(Project, pk=project_id)
  70. user = get_object_or_404(get_user_model(), pk=user_id)
  71. response = {'error': []}
  72. # Prepare dataset.
  73. dataset_class = get_dataset_class(format)
  74. dataset = dataset_class(
  75. filenames=filenames,
  76. label_class=get_label_class(project.project_type),
  77. data_class=get_data_class(project.project_type),
  78. **kwargs
  79. )
  80. it = iter(dataset)
  81. buffer = Buffer()
  82. factory = DataFactory(
  83. data_class=Document,
  84. label_class=Label,
  85. annotation_class=project.get_annotation_class()
  86. )
  87. while True:
  88. try:
  89. example = next(it)
  90. except StopIteration:
  91. break
  92. except FileParseException as err:
  93. response['error'].append(err.dict())
  94. continue
  95. buffer.add(example)
  96. if buffer.is_full():
  97. factory.create(buffer.data, user, project)
  98. buffer.clear()
  99. if not buffer.is_empty():
  100. factory.create(buffer.data, user, project)
  101. buffer.clear()
  102. return response