
Support bulk create

pull/1310/head
Hironsan 3 years ago
commit 9066a62ab9
4 changed files with 113 additions and 25 deletions
  1. app/api/migrations/0010_auto_20210413_0249.py (26 changes)
  2. app/api/models.py (3 changes)
  3. app/api/tasks.py (107 changes)
  4. app/app/settings.py (2 changes)

app/api/migrations/0010_auto_20210413_0249.py (26 changes)

@@ -0,0 +1,26 @@
+# Generated by Django 3.1.7 on 2021-04-13 02:49
+
+from django.conf import settings
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
+        ('api', '0009_auto_20210411_2330'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='document',
+            name='filename',
+            field=models.FilePathField(default=''),
+        ),
+        migrations.AlterField(
+            model_name='document',
+            name='annotations_approved_by',
+            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL),
+        ),
+    ]
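Note: the file above is an auto-generated Django migration; nothing in this commit applies it by itself. A minimal, illustrative way to run it programmatically is sketched below (equivalent to running "python manage.py migrate api", and assuming Django is already configured, e.g. inside a management command or manage.py shell).

# Illustrative only, not part of this commit: applies pending migrations
# for the "api" app, equivalent to `python manage.py migrate api`.
from django.core.management import call_command

call_command('migrate', 'api')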

app/api/models.py (3 changes)

@@ -199,9 +199,10 @@ class Document(models.Model):
     text = models.TextField()
     project = models.ForeignKey(Project, related_name='documents', on_delete=models.CASCADE)
     meta = models.JSONField(default=dict)
+    filename = models.FilePathField(default='')
     created_at = models.DateTimeField(auto_now_add=True)
     updated_at = models.DateTimeField(auto_now=True)
-    annotations_approved_by = models.ForeignKey(User, on_delete=models.SET_NULL, null=True)
+    annotations_approved_by = models.ForeignKey(User, on_delete=models.SET_NULL, null=True, blank=True)

     def __str__(self):
         return self.text[:50]

app/api/tasks.py (107 changes)

@@ -1,3 +1,4 @@
+import datetime
 import itertools

 from celery import shared_task
@@ -13,6 +14,76 @@ from .views.upload.factory import (get_data_class, get_dataset_class,
 from .views.upload.utils import append_field


+class Buffer:
+
+    def __init__(self, buffer_size=settings.IMPORT_BATCH_SIZE):
+        self.buffer_size = buffer_size
+        self.buffer = []
+
+    def __len__(self):
+        return len(self.buffer)
+
+    @property
+    def data(self):
+        return self.buffer
+
+    def add(self, data):
+        self.buffer.append(data)
+
+    def clear(self):
+        self.buffer = []
+
+    def is_full(self):
+        return len(self) >= self.buffer_size
+
+    def is_empty(self):
+        return len(self) == 0
+
+
+class DataFactory:
+
+    def __init__(self, data_class, label_class, annotation_class):
+        self.data_class = data_class
+        self.label_class = label_class
+        self.annotation_class = annotation_class
+
+    def create_label(self, examples, project):
+        flatten = itertools.chain(*[example.label for example in examples])
+        labels = {
+            label['text'] for label in flatten
+            if not project.labels.filter(text=label['text']).exists()
+        }
+        labels = [self.label_class(text=text, project=project) for text in labels]
+        self.label_class.objects.bulk_create(labels)
+
+    def create_data(self, examples, project):
+        dataset = [
+            self.data_class(project=project, **example.data)
+            for example in examples
+        ]
+        now = datetime.datetime.now()
+        self.data_class.objects.bulk_create(dataset)
+        ids = self.data_class.objects.filter(created_at__gte=now)
+        return list(ids)
+
+    def create_annotation(self, examples, ids, user, project):
+        mapping = {label.text: label.id for label in project.labels.all()}
+        annotation = [example.annotation(mapping) for example in examples]
+        for a, id in zip(annotation, ids):
+            append_field(a, document=id)
+        annotation = list(itertools.chain(*annotation))
+        for a in annotation:
+            if 'label' in a:
+                a['label_id'] = a.pop('label')
+        annotation = [self.annotation_class(**a, user=user) for a in annotation]
+        self.annotation_class.objects.bulk_create(annotation)
+
+    def create(self, examples, user, project):
+        self.create_label(examples, project)
+        ids = self.create_data(examples, project)
+        self.create_annotation(examples, ids, user, project)
+
+
 @shared_task
 def injest_data(user_id, project_id, filenames, format: str, **kwargs):
     project = get_object_or_404(Project, pk=project_id)
@@ -27,8 +98,13 @@ def injest_data(user_id, project_id, filenames, format: str, **kwargs):
         data_class=get_data_class(project.project_type),
         **kwargs
     )
-    annotation_serializer_class = project.get_annotation_serializer()
     it = iter(dataset)
+    buffer = Buffer()
+    factory = DataFactory(
+        data_class=Document,
+        label_class=Label,
+        annotation_class=project.get_annotation_class()
+    )
     while True:
         try:
             example = next(it)
@@ -38,27 +114,12 @@ def injest_data(user_id, project_id, filenames, format: str, **kwargs):
             response['error'].append(err.dict())
             continue

-        data_serializer = DocumentSerializer(data=example.data)
-        if not data_serializer.is_valid():
-            continue
-        data = data_serializer.save(project=project)
-
-        stored_labels = {label.text for label in project.labels.all()}
-        labels = [label for label in example.label if label['text'] not in stored_labels]
-        label_serializer = LabelSerializer(data=labels, many=True)
-        if not label_serializer.is_valid():
-            continue
-        label_serializer.save(project=project)
-
-        mapping = {label.text: label.id for label in project.labels.all()}
-        annotation = example.annotation(mapping)
-        append_field(annotation, document=data.id)
-        annotation_serializer = annotation_serializer_class(
-            data=annotation,
-            many=True
-        )
-        if not annotation_serializer.is_valid():
-            continue
-        annotation_serializer.save(user=user)
+        buffer.add(example)
+        if buffer.is_full():
+            factory.create(buffer.data, user, project)
+            buffer.clear()
+    if not buffer.is_empty():
+        factory.create(buffer.data, user, project)
+        buffer.clear()

     return response
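Note on the change above: the per-example serializer round trips are replaced by a fixed-size buffer that is flushed with bulk_create whenever it fills, plus a final flush once the input iterator is exhausted. The sketch below restates that batching pattern outside Django so it can be run directly; the list-backed store and the bulk_save helper are stand-ins for DataFactory.create and are not part of the commit.

# Framework-free sketch of the buffered bulk-create pattern from tasks.py.
# Persistence is faked with a plain list; in the commit it is Django's
# bulk_create on documents, labels and annotations.

IMPORT_BATCH_SIZE = 1000  # mirrors the new settings.py default


class Buffer:
    def __init__(self, buffer_size=IMPORT_BATCH_SIZE):
        self.buffer_size = buffer_size
        self.buffer = []

    @property
    def data(self):
        return self.buffer

    def add(self, item):
        self.buffer.append(item)

    def clear(self):
        self.buffer = []

    def is_full(self):
        return len(self.buffer) >= self.buffer_size

    def is_empty(self):
        return len(self.buffer) == 0


def bulk_save(items, store):
    # stand-in for factory.create(buffer.data, user, project): one write per batch
    store.append(list(items))


def ingest(examples, store, batch_size=3):
    buffer = Buffer(buffer_size=batch_size)
    for example in examples:
        buffer.add(example)
        if buffer.is_full():           # flush whenever the buffer reaches batch_size
            bulk_save(buffer.data, store)
            buffer.clear()
    if not buffer.is_empty():          # flush the remainder after the input is exhausted
        bulk_save(buffer.data, store)
        buffer.clear()


if __name__ == '__main__':
    store = []
    ingest(range(8), store, batch_size=3)
    print(store)  # [[0, 1, 2], [3, 4, 5], [6, 7]] -> three bulk writes instead of eight

One further detail in DataFactory.create_data: the ids of the freshly inserted documents are recovered by re-querying with created_at__gte=now rather than from bulk_create's return value, presumably because Django only populates primary keys after bulk_create on database backends that support returning them.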

app/app/settings.py (2 changes)

@@ -310,7 +310,7 @@ ALLOWED_HOSTS = ['*']

 # Size of the batch for creating documents
 # on the import phase
-IMPORT_BATCH_SIZE = env.int('IMPORT_BATCH_SIZE', 500)
+IMPORT_BATCH_SIZE = env.int('IMPORT_BATCH_SIZE', 1000)

 GOOGLE_TRACKING_ID = env('GOOGLE_TRACKING_ID', 'UA-125643874-2').strip()
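Note: because the setting is read through env.int, the new default of 1000 can still be overridden per deployment with the IMPORT_BATCH_SIZE environment variable. A quick, illustrative way to confirm the effective value (assuming a configured Django environment such as python manage.py shell, not part of this commit):

# Illustrative only: prints the effective import batch size, which follows the
# IMPORT_BATCH_SIZE environment variable and falls back to 1000 when unset.
from django.conf import settings

print(settings.IMPORT_BATCH_SIZE)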
