You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

72 lines
1.7 KiB

  1. """
  2. Represents label collection.
  3. """
  4. import abc
  5. from collections import defaultdict
  6. from typing import Dict, List, Tuple
  7. from django.db.models import QuerySet
  8. from data_export.models import (
  9. ExportedBoundingBox,
  10. ExportedCategory,
  11. ExportedExample,
  12. ExportedLabel,
  13. ExportedRelation,
  14. ExportedSegmentation,
  15. ExportedSpan,
  16. ExportedText,
  17. )
  class Labels(abc.ABC):
      """Collection of exported labels grouped by the id of their example.

      Subclasses configure the collection through three class attributes:
      ``label_class`` (the model to query), ``column`` (the key used in the
      mapping returned by ``find_by``) and ``fields`` (the relations passed to
      ``select_related`` when the labels are loaded).
      """

      # Model whose rows are fetched and grouped.
      label_class = ExportedLabel
      # Key under which ``find_by`` exposes the labels of an example.
      column = "labels"
      # Relations passed to ``select_related``; kept as a performance measure.
      fields: Tuple[str, ...] = ("example", "label")

      def __init__(self, examples: QuerySet[ExportedExample], user=None):
          """Load the labels attached to ``examples`` and group them by example id.

          Args:
              examples: Examples whose labels should be collected.
              user: When truthy, only labels filtered by ``user=user`` are kept;
                  otherwise labels from every user are included.
          """
          self.label_groups: Dict[int, List[ExportedLabel]] = defaultdict(list)
          labels = self.label_class.objects.filter(example__in=examples)
          if user:
              labels = labels.filter(user=user)
          for label in labels.select_related(*self.fields):
              self.label_groups[label.example.id].append(label)

      def find_by(self, example_id: int) -> Dict[str, List[ExportedLabel]]:
          """Return the labels of one example, keyed by ``column``.

          Args:
              example_id: Id of the example to look up.

          Returns:
              A single-entry dict mapping ``self.column`` to the list of labels
              collected for ``example_id``; the list is empty when the example
              has no labels in this collection.
          """
          # Use ``.get`` rather than indexing: indexing a defaultdict would
          # insert a new empty list for every unknown id, so a read-only lookup
          # would silently grow ``label_groups``.
          return {self.column: self.label_groups.get(example_id, [])}
  class Categories(Labels):
      """Label collection over ``ExportedCategory``, exposed as ``"categories"``."""

      column = "categories"
      fields = ("example", "label")
      label_class = ExportedCategory
  class Spans(Labels):
      """Label collection over ``ExportedSpan``, exposed as ``"entities"``."""

      column = "entities"
      fields = ("example", "label")
      label_class = ExportedSpan
  class Relations(Labels):
      """Label collection over ``ExportedRelation``, exposed as ``"relations"``."""

      column = "relations"
      # Relations are joined on ``type`` rather than ``label``.
      fields = ("example", "type")
      label_class = ExportedRelation
  class Texts(Labels):
      """Label collection over ``ExportedText``, exposed as ``"labels"``."""

      column = "labels"
      # Only the example relation is joined for text labels.
      fields = ("example",)
      label_class = ExportedText
  class BoundingBoxes(Labels):
      """Label collection over ``ExportedBoundingBox``, exposed as ``"labels"``."""

      column = "labels"
      fields = ("example", "label")
      label_class = ExportedBoundingBox
  class Segments(Labels):
      """Label collection over ``ExportedSegmentation``, exposed as ``"labels"``."""

      column = "labels"
      fields = ("example", "label")
      label_class = ExportedSegmentation