diff --git a/app/api/migrations/0010_auto_20210413_0249.py b/app/api/migrations/0010_auto_20210413_0249.py
new file mode 100644
index 00000000..ddf9fb23
--- /dev/null
+++ b/app/api/migrations/0010_auto_20210413_0249.py
@@ -0,0 +1,26 @@
+# Generated by Django 3.1.7 on 2021-04-13 02:49
+
+from django.conf import settings
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
+        ('api', '0009_auto_20210411_2330'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='document',
+            name='filename',
+            field=models.FilePathField(default=''),
+        ),
+        migrations.AlterField(
+            model_name='document',
+            name='annotations_approved_by',
+            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL),
+        ),
+    ]
diff --git a/app/api/models.py b/app/api/models.py
index c0ea6ff1..a11bda5f 100644
--- a/app/api/models.py
+++ b/app/api/models.py
@@ -199,9 +199,10 @@ class Document(models.Model):
     text = models.TextField()
     project = models.ForeignKey(Project, related_name='documents', on_delete=models.CASCADE)
     meta = models.JSONField(default=dict)
+    filename = models.FilePathField(default='')
     created_at = models.DateTimeField(auto_now_add=True)
     updated_at = models.DateTimeField(auto_now=True)
-    annotations_approved_by = models.ForeignKey(User, on_delete=models.SET_NULL, null=True)
+    annotations_approved_by = models.ForeignKey(User, on_delete=models.SET_NULL, null=True, blank=True)
 
     def __str__(self):
         return self.text[:50]
diff --git a/app/api/tasks.py b/app/api/tasks.py
index badd2555..b7df380e 100644
--- a/app/api/tasks.py
+++ b/app/api/tasks.py
@@ -1,3 +1,4 @@
+import datetime
 import itertools
 
 from celery import shared_task
@@ -13,6 +14,106 @@ from .views.upload.factory import (get_data_class, get_dataset_class,
 from .views.upload.utils import append_field
 
 
+class Buffer:
+    """Accumulate examples until a batch is ready to be stored."""
+
+    def __init__(self, buffer_size=settings.IMPORT_BATCH_SIZE):
+        self.buffer_size = buffer_size
+        self.buffer = []
+
+    def __len__(self):
+        return len(self.buffer)
+
+    @property
+    def data(self):
+        """Return the list of buffered items."""
+        return self.buffer
+
+    def add(self, data):
+        """Append one item to the buffer."""
+        self.buffer.append(data)
+
+    def clear(self):
+        """Replace the buffer with a new empty list."""
+        self.buffer = []
+
+    def is_full(self):
+        """Return True once ``buffer_size`` items are buffered."""
+        return len(self) >= self.buffer_size
+
+    def is_empty(self):
+        """Return True when nothing is buffered."""
+        return len(self) == 0
+
+
+class DataFactory:
+    """Bulk-create labels, documents and annotations per batch."""
+
+    def __init__(self, data_class, label_class, annotation_class):
+        self.data_class = data_class
+        self.label_class = label_class
+        self.annotation_class = annotation_class
+
+    def create_label(self, examples, project):
+        """Bulk-create the label texts the project does not have yet."""
+        # One query for the known texts instead of one EXISTS per label.
+        existing = set(project.labels.values_list('text', flat=True))
+        flatten = itertools.chain.from_iterable(
+            example.label for example in examples
+        )
+        texts = {label['text'] for label in flatten} - existing
+        labels = [self.label_class(text=text, project=project) for text in texts]
+        self.label_class.objects.bulk_create(labels)
+
+    def create_data(self, examples, project):
+        """Bulk-create one row per example and return the stored rows.
+
+        ``create_annotation`` zips the result with ``examples``, so the
+        returned list has to follow the order of ``examples``.
+        """
+        # Local import: keeps the module-level import block untouched.
+        from django.utils import timezone
+
+        dataset = [
+            self.data_class(project=project, **example.data)
+            for example in examples
+        ]
+        # timezone.now() honours USE_TZ, unlike a naive datetime.now().
+        now = timezone.now()
+        created = self.data_class.objects.bulk_create(dataset)
+        if all(data.pk is not None for data in created):
+            # The backend returned the primary keys; no re-query needed.
+            return list(created)
+        # Otherwise re-read the rows. Scope the query to this project so
+        # rows of other projects are never picked up, and order by pk so
+        # the result is deterministic for the zip in create_annotation.
+        documents = self.data_class.objects.filter(
+            project=project, created_at__gte=now
+        ).order_by('pk')
+        return list(documents)
+
+    def create_annotation(self, examples, ids, user, project):
+        """Bulk-create the annotations of ``examples`` for the rows in ``ids``."""
+        mapping = {label.text: label.id for label in project.labels.all()}
+        annotation = [example.annotation(mapping) for example in examples]
+        for a, document in zip(annotation, ids):
+            append_field(a, document=document)
+        annotation = list(itertools.chain.from_iterable(annotation))
+        for a in annotation:
+            if 'label' in a:
+                # Presumably an id taken from ``mapping``; pass it as the
+                # raw foreign-key value.
+                a['label_id'] = a.pop('label')
+        annotation = [self.annotation_class(**a, user=user) for a in annotation]
+        self.annotation_class.objects.bulk_create(annotation)
+
+    def create(self, examples, user, project):
+        """Store labels, then documents, then annotations for one batch."""
+        self.create_label(examples, project)
+        ids = self.create_data(examples, project)
+        self.create_annotation(examples, ids, user, project)
+
+
 @shared_task
 def injest_data(user_id, project_id, filenames, format: str, **kwargs):
     project = get_object_or_404(Project, pk=project_id)
@@ -27,8 +128,13 @@ def injest_data(user_id, project_id, filenames, format: str, **kwargs):
         data_class=get_data_class(project.project_type),
         **kwargs
     )
-    annotation_serializer_class = project.get_annotation_serializer()
     it = iter(dataset)
+    buffer = Buffer()
+    factory = DataFactory(
+        data_class=Document,
+        label_class=Label,
+        annotation_class=project.get_annotation_class()
+    )
     while True:
         try:
             example = next(it)
@@ -38,27 +144,13 @@ def injest_data(user_id, project_id, filenames, format: str, **kwargs):
             response['error'].append(err.dict())
             continue
 
-        data_serializer = DocumentSerializer(data=example.data)
-        if not data_serializer.is_valid():
-            continue
-        data = data_serializer.save(project=project)
-
-        stored_labels = {label.text for label in project.labels.all()}
-        labels = [label for label in example.label if label['text'] not in stored_labels]
-        label_serializer = LabelSerializer(data=labels, many=True)
-        if not label_serializer.is_valid():
-            continue
-        label_serializer.save(project=project)
-
-        mapping = {label.text: label.id for label in project.labels.all()}
-        annotation = example.annotation(mapping)
-        append_field(annotation, document=data.id)
-        annotation_serializer = annotation_serializer_class(
-            data=annotation,
-            many=True
-        )
-        if not annotation_serializer.is_valid():
-            continue
-        annotation_serializer.save(user=user)
+        buffer.add(example)
+        if buffer.is_full():
+            factory.create(buffer.data, user, project)
+            buffer.clear()
+    # Store whatever is left in the last, partially filled batch.
+    if not buffer.is_empty():
+        factory.create(buffer.data, user, project)
+        buffer.clear()
 
     return response
diff --git a/app/app/settings.py b/app/app/settings.py
index daeb82ef..e301c779 100644
--- a/app/app/settings.py
+++ b/app/app/settings.py
@@ -310,7 +310,7 @@ ALLOWED_HOSTS = ['*']
 
 # Size of the batch for creating documents
 # on the import phase
-IMPORT_BATCH_SIZE = env.int('IMPORT_BATCH_SIZE', 500)
+IMPORT_BATCH_SIZE = env.int('IMPORT_BATCH_SIZE', 1000)
 
 GOOGLE_TRACKING_ID = env('GOOGLE_TRACKING_ID', 'UA-125643874-2').strip()
 