You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

74 lines
2.5 KiB

2 years ago
2 years ago
2 years ago
2 years ago
  1. from django.db.models import Manager, Count
  2. class LabelManager(Manager):
  3. def calc_label_distribution(self, examples, members, labels):
  4. """Calculate label distribution.
  5. Args:
  6. examples: example queryset.
  7. members: user queryset.
  8. labels: label queryset.
  9. Returns:
  10. label distribution per user.
  11. Examples:
  12. >>> self.calc_label_distribution(examples, members, labels)
  13. {'admin': {'positive': 10, 'negative': 5}}
  14. """
  15. distribution = {member.username: {label.text: 0 for label in labels} for member in members}
  16. items = (
  17. self.filter(example_id__in=examples)
  18. .values("user__username", "label__text")
  19. .annotate(count=Count("label__text"))
  20. )
  21. for item in items:
  22. username = item["user__username"]
  23. label = item["label__text"]
  24. count = item["count"]
  25. distribution[username][label] = count
  26. return distribution
  27. def get_labels(self, label, project):
  28. if project.collaborative_annotation:
  29. return self.filter(example=label.example)
  30. else:
  31. return self.filter(example=label.example, user=label.user)
  32. def can_annotate(self, label, project) -> bool:
  33. raise NotImplementedError("Please implement this method in the subclass")
  34. def filter_annotatable_labels(self, labels, project):
  35. return [label for label in labels if self.can_annotate(label, project)]
  36. class CategoryManager(LabelManager):
  37. def can_annotate(self, label, project) -> bool:
  38. is_exclusive = project.single_class_classification
  39. categories = self.get_labels(label, project)
  40. if is_exclusive:
  41. return not categories.exists()
  42. else:
  43. return not categories.filter(label=label.label).exists()
  44. class SpanManager(LabelManager):
  45. def can_annotate(self, label, project) -> bool:
  46. overlapping = getattr(project, "allow_overlapping", False)
  47. spans = self.get_labels(label, project)
  48. if overlapping:
  49. return True
  50. for span in spans:
  51. if span.is_overlapping(label):
  52. return False
  53. return True
  54. class TextLabelManager(LabelManager):
  55. def can_annotate(self, label, project) -> bool:
  56. texts = self.get_labels(label, project)
  57. for text in texts:
  58. if text.is_same_text(label):
  59. return False
  60. return True