You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

109 lines
3.4 KiB

  1. from collections import Counter
  2. from django.conf import settings
  3. from django.db.models import Count, Manager
  4. class AnnotationManager(Manager):
  5. def calc_label_frequency(self, examples):
  6. """Calculate label frequencies.
  7. Args:
  8. examples: example queryset.
  9. Returns:
  10. label frequency.
  11. Examples:
  12. >>> {'positive': 3, 'negative': 4}
  13. """
  14. freq = Counter()
  15. annotations = self.filter(example_id__in=examples)
  16. for d in annotations.values('label__text').annotate(Count('label')):
  17. freq[d['label__text']] += d['label__count']
  18. return freq
  19. def calc_user_frequency(self, examples):
  20. """Calculate user frequencies.
  21. Args:
  22. examples: example queryset.
  23. Returns:
  24. user frequency.
  25. Examples:
  26. >>> {'mary': 3, 'john': 4}
  27. """
  28. freq = Counter()
  29. annotations = self.filter(example_id__in=examples)
  30. for d in annotations.values('user__username').annotate(Count('user')):
  31. freq[d['user__username']] += d['user__count']
  32. return freq
  33. def get_label_per_data(self, project):
  34. label_count = Counter()
  35. user_count = Counter()
  36. docs = project.examples.all()
  37. annotations = self.filter(example_id__in=docs.all())
  38. for d in annotations.values('label__text', 'user__username').annotate(Count('label'), Count('user')):
  39. label_count[d['label__text']] += d['label__count']
  40. user_count[d['user__username']] += d['user__count']
  41. return label_count, user_count
  42. class Seq2seqAnnotationManager(Manager):
  43. def get_label_per_data(self, project):
  44. label_count = Counter()
  45. user_count = Counter()
  46. docs = project.examples.all()
  47. annotations = self.filter(example_id__in=docs.all())
  48. for d in annotations.values('text', 'user__username').annotate(Count('text'), Count('user')):
  49. label_count[d['text']] += d['text__count']
  50. user_count[d['user__username']] += d['user__count']
  51. return label_count, user_count
  52. class RoleMappingManager(Manager):
  53. def can_update(self, project: int, mapping_id: int, rolename: str):
  54. queryset = self.filter(
  55. project=project, role__name=settings.ROLE_PROJECT_ADMIN
  56. )
  57. if queryset.count() > 1:
  58. return True
  59. else:
  60. mapping = queryset.first()
  61. if mapping.id == mapping_id and rolename != settings.ROLE_PROJECT_ADMIN:
  62. return False
  63. return True
  64. class ExampleManager(Manager):
  65. def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
  66. super().bulk_create(objs, batch_size=batch_size, ignore_conflicts=ignore_conflicts)
  67. uuids = [data.uuid for data in objs]
  68. examples = self.in_bulk(uuids, field_name='uuid')
  69. return [examples[uid] for uid in uuids]
  70. class ExampleStateManager(Manager):
  71. def count_done(self, examples):
  72. return self.filter(example_id__in=examples).distinct().values('example').count()
  73. def count_user(self, examples):
  74. done_count = self.filter(example_id__in=examples)\
  75. .values('confirmed_by__username')\
  76. .annotate(total=Count('confirmed_by'))
  77. return {
  78. obj['confirmed_by__username']: obj['total']
  79. for obj in done_count
  80. }