You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

83 lines
2.7 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from django.db.models import Count, Manager
  2. class LabelManager(Manager):
  3. label_type_field = "label"
  4. def calc_label_distribution(self, examples, members, labels):
  5. """Calculate label distribution.
  6. Args:
  7. examples: example queryset.
  8. members: user queryset.
  9. labels: label queryset.
  10. Returns:
  11. label distribution per user.
  12. Examples:
  13. >>> self.calc_label_distribution(examples, members, labels)
  14. {'admin': {'positive': 10, 'negative': 5}}
  15. """
  16. distribution = {member.username: {label.text: 0 for label in labels} for member in members}
  17. items = (
  18. self.filter(example_id__in=examples)
  19. .values("user__username", f"{self.label_type_field}__text")
  20. .annotate(count=Count(f"{self.label_type_field}__text"))
  21. )
  22. for item in items:
  23. username = item["user__username"]
  24. label = item[f"{self.label_type_field}__text"]
  25. count = item["count"]
  26. distribution[username][label] = count
  27. return distribution
  28. def get_labels(self, label, project):
  29. if project.collaborative_annotation:
  30. return self.filter(example=label.example)
  31. else:
  32. return self.filter(example=label.example, user=label.user)
  33. def can_annotate(self, label, project) -> bool:
  34. raise NotImplementedError("Please implement this method in the subclass")
  35. def filter_annotatable_labels(self, labels, project):
  36. return [label for label in labels if self.can_annotate(label, project)]
  37. class CategoryManager(LabelManager):
  38. def can_annotate(self, label, project) -> bool:
  39. is_exclusive = project.single_class_classification
  40. categories = self.get_labels(label, project)
  41. if is_exclusive:
  42. return not categories.exists()
  43. else:
  44. return not categories.filter(label=label.label).exists()
  45. class SpanManager(LabelManager):
  46. def can_annotate(self, label, project) -> bool:
  47. overlapping = getattr(project, "allow_overlapping", False)
  48. spans = self.get_labels(label, project)
  49. if overlapping:
  50. return True
  51. for span in spans:
  52. if span.is_overlapping(label):
  53. return False
  54. return True
  55. class TextLabelManager(LabelManager):
  56. def can_annotate(self, label, project) -> bool:
  57. texts = self.get_labels(label, project)
  58. for text in texts:
  59. if text.is_same_text(label):
  60. return False
  61. return True
  62. class RelationManager(LabelManager):
  63. label_type_field = "type"
  64. def can_annotate(self, label, project) -> bool:
  65. return True