diff --git a/backend/api/tasks.py b/backend/api/tasks.py index b3c43a50..3ddb5caf 100644 --- a/backend/api/tasks.py +++ b/backend/api/tasks.py @@ -17,28 +17,6 @@ from .views.upload.factory import (create_cleaner, get_data_class, logger = get_task_logger(__name__) -class Labels: - - def __init__(self): - self.items = [] - - def add(self, label: Label): - self.items.append(label) - - def dedupe(self, project: Project): - labels = [] - existing_labels = {(label.text, label.task_type) for label in project.labels.all()} - for label in self.items: - if label and label.text and (label.text, label.task_type) not in existing_labels: - labels.append(label) - existing_labels.add((label.text, label.task_type)) - self.items = labels - - def save(self, project: Project): - self.dedupe(project) - Label.objects.bulk_create(self.items) - - def group_by_class(instances): from collections import defaultdict groups = defaultdict(list) @@ -73,11 +51,9 @@ class Examples: return len(self) == 0 def save_label(self, project: Project): - labels = Labels() - for example in self.buffer: - for label in example.create_label(project): - labels.add(label) - labels.save(project) + labels = list(itertools.chain.from_iterable([example.create_label(project) for example in self.buffer])) + labels = list(filter(None, labels)) + Label.objects.bulk_create(labels, ignore_conflicts=True) def save_data(self, project: Project) -> List[Example]: examples = [example.create_data(project) for example in self.buffer]