You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

58 lines
1.4 KiB

  1. """
  2. Represents label collection.
  3. """
  4. import abc
  5. from collections import defaultdict
  6. from typing import Dict, List, Tuple
  7. from django.db.models import QuerySet
  8. from data_export.models import (
  9. ExportedCategory,
  10. ExportedExample,
  11. ExportedLabel,
  12. ExportedRelation,
  13. ExportedSpan,
  14. ExportedText,
  15. )
  16. class Labels(abc.ABC):
  17. label_class = ExportedLabel
  18. column = "labels"
  19. fields: Tuple[str, ...] = ("example", "label") # To boost performance
  20. def __init__(self, examples: QuerySet[ExportedExample], user=None):
  21. self.label_groups = defaultdict(list)
  22. labels = self.label_class.objects.filter(example__in=examples)
  23. if user:
  24. labels = labels.filter(user=user)
  25. for label in labels.select_related(*self.fields):
  26. self.label_groups[label.example.id].append(label)
  27. def find_by(self, example_id: int) -> Dict[str, List[ExportedLabel]]:
  28. return {self.column: self.label_groups[example_id]}
  29. class Categories(Labels):
  30. label_class = ExportedCategory
  31. column = "categories"
  32. fields = ("example", "label")
  33. class Spans(Labels):
  34. label_class = ExportedSpan
  35. column = "entities"
  36. fields = ("example", "label")
  37. class Relations(Labels):
  38. label_class = ExportedRelation
  39. column = "relations"
  40. fields = ("example", "type")
  41. class Texts(Labels):
  42. label_class = ExportedText
  43. column = "labels"
  44. fields = ("example",)