diff --git a/backend/api/managers.py b/backend/api/managers.py index 90001b0a..6c1bfdd3 100644 --- a/backend/api/managers.py +++ b/backend/api/managers.py @@ -47,3 +47,12 @@ class RoleMappingManager(Manager): if mapping.id == mapping_id and rolename != settings.ROLE_PROJECT_ADMIN: return False return True + + +class ExampleManager(Manager): + + def bulk_create(self, objs, batch_size=None, ignore_conflicts=False): + super().bulk_create(objs, batch_size=batch_size, ignore_conflicts=ignore_conflicts) + uuids = [data.uuid for data in objs] + examples = self.in_bulk(uuids, field_name='uuid') + return [examples[uid] for uid in uuids] diff --git a/backend/api/models.py b/backend/api/models.py index d380d5ad..81b6b05d 100644 --- a/backend/api/models.py +++ b/backend/api/models.py @@ -9,7 +9,7 @@ from django.core.exceptions import ValidationError from django.db import models from polymorphic.models import PolymorphicModel -from .managers import (AnnotationManager, RoleMappingManager, +from .managers import (AnnotationManager, ExampleManager, RoleMappingManager, Seq2seqAnnotationManager) DOCUMENT_CLASSIFICATION = 'DocumentClassification' @@ -167,6 +167,8 @@ class Label(models.Model): class Example(models.Model): + objects = ExampleManager() + uuid = models.UUIDField(default=uuid.uuid4, editable=False, db_index=True, unique=True) meta = models.JSONField(default=dict) filename = models.FileField(default='.', max_length=1024) diff --git a/backend/api/tasks.py b/backend/api/tasks.py index 76fdd78a..b3c43a50 100644 --- a/backend/api/tasks.py +++ b/backend/api/tasks.py @@ -80,11 +80,8 @@ class Examples: labels.save(project) def save_data(self, project: Project) -> List[Example]: - dataset = [example.create_data(project) for example in self.buffer] - Example.objects.bulk_create(dataset) - uuids = [data.uuid for data in dataset] - dataset = Example.objects.in_bulk(uuids, field_name='uuid') - return [dataset[uid] for uid in uuids] + examples = [example.create_data(project) for example in self.buffer] + return Example.objects.bulk_create(examples) def save_annotation(self, project, user, examples): mapping = {(label.text, label.task_type): label for label in project.labels.all()}