You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

119 lines
3.8 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import abc
  2. import itertools
  3. from collections import defaultdict
  4. from typing import Any, Dict, List
  5. from django.conf import settings
  6. from projects.models import Project
  7. from examples.models import Example
  8. from label_types.models import CategoryType, SpanType
  9. from .exceptions import FileParseException
  10. from .readers import BaseReader, Record
  11. class Writer(abc.ABC):
  12. @abc.abstractmethod
  13. def save(self, reader: BaseReader, project: Project, user, cleaner):
  14. """Save the read contents to DB."""
  15. raise NotImplementedError("Please implement this method in the subclass.")
  16. def errors(self) -> List[Dict[Any, Any]]:
  17. """Return errors."""
  18. raise NotImplementedError("Please implement this method in the subclass.")
  19. def group_by_class(instances):
  20. groups = defaultdict(list)
  21. for instance in instances:
  22. groups[instance.__class__].append(instance)
  23. return groups
  24. class Examples:
  25. def __init__(self, buffer_size: int = settings.IMPORT_BATCH_SIZE):
  26. self.buffer_size = buffer_size
  27. self.buffer: List[Record] = []
  28. def __len__(self):
  29. return len(self.buffer)
  30. @property
  31. def data(self):
  32. return self.buffer
  33. def add(self, data):
  34. self.buffer.append(data)
  35. def clear(self):
  36. self.buffer = []
  37. def is_full(self):
  38. return len(self) >= self.buffer_size
  39. def is_empty(self):
  40. return len(self) == 0
  41. def save_label(self, project: Project):
  42. labels = list(itertools.chain.from_iterable([example.create_label(project) for example in self.buffer]))
  43. labels = list(filter(None, labels))
  44. groups = group_by_class(labels)
  45. for klass, instances in groups.items():
  46. klass.objects.bulk_create(instances, ignore_conflicts=True)
  47. def save_data(self, project: Project) -> List[Example]:
  48. examples = [example.create_data(project) for example in self.buffer]
  49. return Example.objects.bulk_create(examples)
  50. def save_annotation(self, project: Project, user, examples):
  51. # mapping = {label.text: label for label in project.labels.all()}
  52. # Todo: move annotation class
  53. mapping = {}
  54. for model in [CategoryType, SpanType]:
  55. for label in model.objects.all():
  56. mapping[label.text] = label
  57. annotations = list(
  58. itertools.chain.from_iterable(
  59. [data.create_annotation(user, example, mapping) for data, example in zip(self.buffer, examples)]
  60. )
  61. )
  62. groups = group_by_class(annotations)
  63. for klass, instances in groups.items():
  64. klass.objects.bulk_create(instances)
  65. class BulkWriter(Writer):
  66. def __init__(self, batch_size: int):
  67. self.examples = Examples(batch_size)
  68. self._errors: List[FileParseException] = []
  69. def save(self, reader: BaseReader, project: Project, user, cleaner):
  70. it = iter(reader)
  71. while True:
  72. try:
  73. example = next(it)
  74. except StopIteration:
  75. break
  76. try:
  77. example.clean(cleaner)
  78. except FileParseException as err:
  79. self._errors.append(err)
  80. self.examples.add(example)
  81. if self.examples.is_full():
  82. self.create(project, user)
  83. self.examples.clear()
  84. if not self.examples.is_empty():
  85. self.create(project, user)
  86. self.examples.clear()
  87. self._errors.extend(reader.errors)
  88. @property
  89. def errors(self) -> List[Dict[Any, Any]]:
  90. self._errors.sort(key=lambda e: e.line_num)
  91. return [error.dict() for error in self._errors]
  92. def create(self, project: Project, user):
  93. self.examples.save_label(project)
  94. ids = self.examples.save_data(project)
  95. self.examples.save_annotation(project, user, ids)