You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

51 lines
1.3 KiB

"""
Represents label collection.
"""
import abc
from collections import defaultdict
from typing import Dict, List
from django.db.models import QuerySet
from data_export.models import (
ExportedCategory,
ExportedLabel,
ExportedRelation,
ExportedSpan,
)
from examples.models import Example
class Labels(abc.ABC):
    """Exported labels of one kind, grouped by the example they belong to.

    Subclasses customise three class attributes:

    * ``label_class`` -- the model whose ``objects`` manager is queried.
    * ``field_name`` -- the key used in the dict returned by ``find_by``.
    * ``fields`` -- the relations passed to ``select_related``.
    """

    label_class = ExportedLabel
    field_name = "labels"
    fields = ("example", "label")

    def __init__(self, examples: QuerySet[Example], user=None):
        """Load the labels attached to ``examples`` and group them by example id.

        Args:
            examples: The examples whose labels should be collected.
            user: When truthy, only labels whose ``user`` matches it are kept;
                otherwise labels from every user are included.
        """
        self.label_groups: Dict[int, List[ExportedLabel]] = defaultdict(list)
        labels = self.label_class.objects.filter(example__in=examples)
        if user:
            labels = labels.filter(user=user)
        # select_related fetches the listed relations in the same query, so
        # accessing ``label.example`` below does not issue one query per label.
        for label in labels.select_related(*self.fields):
            self.label_groups[label.example.id].append(label)

    def find_by(self, example_id: int) -> Dict[str, List[ExportedLabel]]:
        """Return the labels of ``example_id`` keyed by ``field_name``.

        An example with no labels yields an empty list. ``.get`` is used rather
        than indexing the ``defaultdict`` so that looking up an unlabeled
        example does not insert a new key into ``label_groups``.
        """
        return {self.field_name: self.label_groups.get(example_id, [])}
class Categories(Labels):
    """Category labels, returned by ``find_by`` under the "categories" key."""

    field_name = "categories"
    fields = ("example", "label")
    label_class = ExportedCategory
class Spans(Labels):
    """Span labels, returned by ``find_by`` under the "entities" key."""

    field_name = "entities"
    fields = ("example", "label")
    label_class = ExportedSpan
class Relations(Labels):
    """Relation labels, returned by ``find_by`` under the "relations" key."""

    field_name = "relations"
    # Relations pre-fetch their ``type`` relation instead of ``label``.
    fields = ("example", "type")
    label_class = ExportedRelation