"""
Represents label collection.
"""
import abc
from collections import defaultdict
from typing import Dict, List

from django.db.models import QuerySet

from data_export.models import (
    ExportedCategory,
    ExportedLabel,
    ExportedRelation,
    ExportedSpan,
)
from examples.models import Example


class Labels(abc.ABC):
    """Collect the labels attached to a set of examples, grouped by example id.

    Subclasses configure three class attributes:

    * ``label_class`` -- the model whose rows are loaded.
    * ``field_name`` -- the key under which ``find_by`` returns the labels.
    * ``fields`` -- relation names passed to ``select_related`` when loading.
    """

    label_class = ExportedLabel
    field_name = "labels"
    fields = ("example", "label")

    def __init__(self, examples: QuerySet[Example], user=None):
        """Load the labels of ``examples`` and group them by example id.

        Args:
            examples: the examples whose labels should be collected.
            user: optional; when truthy, only labels whose ``user`` equals it
                are kept.
        """
        self.label_groups = defaultdict(list)
        labels = self.label_class.objects.filter(example__in=examples)
        if user:
            labels = labels.filter(user=user)
        # select_related eagerly joins the configured relations, so reading
        # label.example below does not issue an extra query per label.
        for label in labels.select_related(*self.fields):
            self.label_groups[label.example.id].append(label)

    def find_by(self, example_id: int) -> Dict[str, List[ExportedLabel]]:
        """Return the labels of one example keyed by ``field_name``.

        Args:
            example_id: id of the example to look up.

        Returns:
            ``{self.field_name: [labels...]}``; the list is empty when the
            example has no loaded labels.
        """
        # Use .get() instead of indexing the defaultdict: indexing would
        # insert an empty entry for every unlabeled example that is looked up.
        return {self.field_name: self.label_groups.get(example_id, [])}


class Categories(Labels):
    """Labels loaded from ``ExportedCategory``, returned under ``"categories"``."""

    label_class = ExportedCategory
    field_name = "categories"
    fields = ("example", "label")


class Spans(Labels):
    """Labels loaded from ``ExportedSpan``, returned under ``"entities"``."""

    label_class = ExportedSpan
    field_name = "entities"
    fields = ("example", "label")


class Relations(Labels):
    """Labels loaded from ``ExportedRelation``, returned under ``"relations"``.

    Eagerly loads the ``type`` relation instead of ``label``.
    """

    label_class = ExportedRelation
    field_name = "relations"
    fields = ("example", "type")