From 9bf6c523470739d0e43c16b758fefac96f1fceb6 Mon Sep 17 00:00:00 2001 From: zanussbaum Date: Thu, 1 Jul 2021 15:57:04 -0400 Subject: [PATCH] Fixing Data Annotation Issues When uploading datasets, the code uses `bulk_create` to upload Examples and Labels. It then filters the data from the database based on when it was created. However, [Django doesn't enforce the list order when calling filter](https://stackoverflow.com/questions/7163640/what-is-the-default-order-of-a-list-returned-from-a-django-filter-call) unless ordering is specified. The previous behavior mismatched labels and examples. When this was shown in the UI, the data would show labels for incorrect examples (e.g. a label for message #2 would be shown on message #1). This fix enforces that the data is returned in the order it was inserted so that each (example, label) pair matches as expected. --- backend/api/models.py | 16 +++++++++++----- backend/api/tasks.py | 7 +++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/backend/api/models.py b/backend/api/models.py index 18d52f45..5eb028f6 100644 --- a/backend/api/models.py +++ b/backend/api/models.py @@ -1,6 +1,8 @@ import string from typing import Literal +from django import db + from auto_labeling_pipeline.models import RequestModelFactory from django.contrib.auth.models import User from django.core.exceptions import ValidationError @@ -92,7 +94,7 @@ class ImageClassificationProject(Project): class Label(models.Model): - text = models.CharField(max_length=100) + text = models.CharField(max_length=100, db_index=True) prefix_key = models.CharField( max_length=10, blank=True, @@ -118,7 +120,7 @@ class Label(models.Model): ) background_color = models.CharField(max_length=7, default='#209cee') text_color = models.CharField(max_length=7, default='#ffffff') - created_at = models.DateTimeField(auto_now_add=True) + created_at = models.DateTimeField(auto_now_add=True, db_index=True) updated_at = models.DateTimeField(auto_now=True) def __str__(self): @@ 
-143,6 +145,7 @@ class Label(models.Model): unique_together = ( ('project', 'text'), ) + ordering = ['created_at'] class Example(models.Model): @@ -160,13 +163,16 @@ class Example(models.Model): blank=True ) text = models.TextField(null=True, blank=True) - created_at = models.DateTimeField(auto_now_add=True) + created_at = models.DateTimeField(auto_now_add=True, db_index=True) updated_at = models.DateTimeField(auto_now=True) @property def comment_count(self): return Comment.objects.filter(example=self.id).count() + class Meta: + ordering = ['created_at'] + class ExampleState(models.Model): example = models.ForeignKey( @@ -196,7 +202,7 @@ class Comment(models.Model): on_delete=models.CASCADE, null=True ) - created_at = models.DateTimeField(auto_now_add=True) + created_at = models.DateTimeField(auto_now_add=True, db_index=True) updated_at = models.DateTimeField(auto_now=True) @property @@ -204,7 +210,7 @@ class Comment(models.Model): return self.user.username class Meta: - ordering = ('-created_at', ) + ordering = ['created_at'] class Tag(models.Model): diff --git a/backend/api/tasks.py b/backend/api/tasks.py index f97f9525..963ed7a6 100644 --- a/backend/api/tasks.py +++ b/backend/api/tasks.py @@ -1,7 +1,8 @@ import datetime import itertools -from celery import shared_task +from celery import shared_task +from celery.utils.log import get_task_logger from django.conf import settings from django.contrib.auth import get_user_model from django.shortcuts import get_object_or_404 @@ -14,7 +15,7 @@ from .views.upload.factory import (get_data_class, get_dataset_class, get_label_class) from .views.upload.utils import append_field - +logger = get_task_logger(__name__) class Buffer: def __init__(self, buffer_size=settings.IMPORT_BATCH_SIZE): @@ -82,6 +83,7 @@ class DataFactory: def create(self, examples, user, project): self.create_label(examples, project) ids = self.create_data(examples, project) + logger.debug(f'IDS {[ids[i].text for i in range(15)]}') 
self.create_annotation(examples, ids, user, project) @@ -120,6 +122,7 @@ def injest_data(user_id, project_id, filenames, format: str, **kwargs): factory.create(buffer.data, user, project) buffer.clear() if not buffer.is_empty(): + logger.debug(f'BUFFER LEN {len(buffer)}') factory.create(buffer.data, user, project) buffer.clear()