You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

72 lines
1.7 KiB

"""
Represents label collection.
"""
import abc
from collections import defaultdict
from typing import Dict, List, Tuple
from django.db.models import QuerySet
from data_export.models import (
ExportedBoundingBox,
ExportedCategory,
ExportedExample,
ExportedLabel,
ExportedRelation,
ExportedSegmentation,
ExportedSpan,
ExportedText,
)
class Labels(abc.ABC):
    """Collection of exported labels grouped by the example they belong to.

    The labels attached to the given examples are fetched when the collection
    is constructed and bucketed by example id, so ``find_by`` is a plain
    in-memory lookup.

    Subclasses customise the collection through three class attributes:

    * ``label_class`` -- model whose ``objects`` manager is queried.
    * ``column`` -- key under which ``find_by`` returns the labels.
    * ``fields`` -- relation names handed to ``select_related``.
    """

    label_class = ExportedLabel
    column = "labels"
    fields: Tuple[str, ...] = ("example", "label")  # To boost performance

    def __init__(self, examples: QuerySet[ExportedExample], user=None):
        """Load and group the labels of ``examples``.

        Args:
            examples: Examples whose labels should be collected.
            user: When truthy, only labels whose ``user`` matches are kept;
                otherwise labels from every user are included.
        """
        self.label_groups = defaultdict(list)
        labels = self.label_class.objects.filter(example__in=examples)
        if user:
            labels = labels.filter(user=user)
        for label in labels.select_related(*self.fields):
            self.label_groups[label.example.id].append(label)

    def find_by(self, example_id: int) -> Dict[str, List[ExportedLabel]]:
        """Return the labels of one example, keyed by ``column``.

        Args:
            example_id: Id of the example to look up.

        Returns:
            A single-entry dict mapping ``self.column`` to the list of labels
            collected for ``example_id``; the list is empty when the example
            has no labels.
        """
        # Use .get() instead of indexing: indexing a defaultdict would insert
        # a new empty list for every unlabeled example that is looked up,
        # silently growing label_groups on what should be a read-only path.
        return {self.column: self.label_groups.get(example_id, [])}
class Categories(Labels):
    """Label collection backed by ``ExportedCategory``.

    ``find_by`` returns its labels under the ``"categories"`` key.
    """

    # Key used for this collection's labels in ``find_by`` results.
    column = "categories"
    # Relations handed to ``select_related`` when the labels are loaded.
    fields = ("example", "label")
    # Model whose manager supplies the labels.
    label_class = ExportedCategory
class Spans(Labels):
    """Label collection backed by ``ExportedSpan``.

    ``find_by`` returns its labels under the ``"entities"`` key.
    """

    # Key used for this collection's labels in ``find_by`` results.
    column = "entities"
    # Relations handed to ``select_related`` when the labels are loaded.
    fields = ("example", "label")
    # Model whose manager supplies the labels.
    label_class = ExportedSpan
class Relations(Labels):
    """Label collection backed by ``ExportedRelation``.

    ``find_by`` returns its labels under the ``"relations"`` key.
    """

    # Key used for this collection's labels in ``find_by`` results.
    column = "relations"
    # Relations handed to ``select_related``; this collection follows
    # ``type`` where most sibling collections follow ``label``.
    fields = ("example", "type")
    # Model whose manager supplies the labels.
    label_class = ExportedRelation
class Texts(Labels):
    """Label collection backed by ``ExportedText``.

    ``find_by`` returns its labels under the ``"labels"`` key.
    """

    # Key used for this collection's labels in ``find_by`` results.
    column = "labels"
    # Only the ``example`` relation is handed to ``select_related`` here.
    fields = ("example",)
    # Model whose manager supplies the labels.
    label_class = ExportedText
class BoundingBoxes(Labels):
    """Label collection backed by ``ExportedBoundingBox``.

    ``find_by`` returns its labels under the ``"labels"`` key.
    """

    # Key used for this collection's labels in ``find_by`` results.
    column = "labels"
    # Relations handed to ``select_related`` when the labels are loaded.
    fields = ("example", "label")
    # Model whose manager supplies the labels.
    label_class = ExportedBoundingBox
class Segments(Labels):
    """Label collection backed by ``ExportedSegmentation``.

    ``find_by`` returns its labels under the ``"labels"`` key.
    """

    # Key used for this collection's labels in ``find_by`` results.
    column = "labels"
    # Relations handed to ``select_related`` when the labels are loaded.
    fields = ("example", "label")
    # Model whose manager supplies the labels.
    label_class = ExportedSegmentation