"""Represents label collection."""
import abc
from collections import defaultdict
from typing import Dict, List

from django.db.models import QuerySet

from data_export.models import ExportedCategory, ExportedLabel
from examples.models import Example


class Labels(abc.ABC):
    """Collect one user's labels on a set of examples, grouped by example id.

    Subclasses choose which label model is queried (``label_class``) and the
    key under which ``find_by`` reports the grouped labels (``field_name``).
    """

    # Model whose ``objects`` manager is queried for labels.
    label_class = ExportedLabel
    # Key of the single entry in the dict returned by ``find_by``.
    field_name = "labels"
    # Relations passed to ``select_related`` so they are fetched in the same query.
    fields = ("example", "label")

    def __init__(self, examples: QuerySet[Example], user):
        """Load and group the labels of ``user`` on ``examples``.

        Args:
            examples: Examples whose labels should be collected.
            user: Value matched against the label model's ``user`` field
                (presumably a user instance or id -- confirm against callers).
        """
        # Maps ``label.example.id`` to the list of ``label.dict()`` results.
        self.label_groups: Dict[int, List] = defaultdict(list)
        labels = self.label_class.objects.filter(example__in=examples, user=user).select_related(*self.fields)
        for label in labels:
            self.label_groups[label.example.id].append(label.dict())

    def find_by(self, example_id: int) -> Dict[str, List]:
        """Return the labels collected for one example.

        Args:
            example_id: Id of the example to look up.

        Returns:
            ``{field_name: [...]}`` holding the collected ``label.dict()``
            values for the example, or an empty list if it has no labels.
        """
        # Use ``.get`` rather than indexing: indexing a defaultdict would insert
        # an empty list for every unlabeled example that is looked up.
        return {self.field_name: self.label_groups.get(example_id, [])}


class Categories(Labels):
    """Collect exported category labels, reported under the ``"categories"`` key."""

    label_class = ExportedCategory
    field_name = "categories"
    fields = ("example", "label")