You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

164 lines
5.2 KiB

  1. import abc
  2. import itertools
  3. from collections import defaultdict
  4. from typing import Dict, Iterator, List
  5. from ...models import Example, Project
  6. from .data import Record
  7. class BaseRepository(abc.ABC):
  8. def __init__(self, project: Project):
  9. self.project = project
  10. @abc.abstractmethod
  11. def list(self, export_approved=False) -> Iterator[Record]:
  12. pass
  13. class FileRepository(BaseRepository):
  14. def list(self, export_approved=False) -> Iterator[Record]:
  15. examples = self.project.examples.all()
  16. if export_approved:
  17. examples = examples.exclude(annotations_approved_by=None)
  18. for example in examples:
  19. label_per_user = self.label_per_user(example)
  20. if self.project.collaborative_annotation:
  21. label_per_user = self.reduce_user(label_per_user)
  22. for user, label in label_per_user.items():
  23. yield Record(
  24. id=example.id,
  25. data=str(example.filename).split('/')[-1],
  26. label=label,
  27. user=user,
  28. metadata=example.meta
  29. )
  30. # todo:
  31. # If there is no label, export the doc with `unknown` user.
  32. # This is a quick solution.
  33. # In the future, the doc without label will be exported
  34. # with the user who approved the doc.
  35. # This means I will allow each user to be able to approve the doc.
  36. if len(label_per_user) == 0:
  37. yield Record(
  38. id=example.id,
  39. data=str(example.filename).split('/')[-1],
  40. label=[],
  41. user='unknown',
  42. metadata={}
  43. )
  44. def label_per_user(self, example) -> Dict:
  45. label_per_user = defaultdict(list)
  46. for a in example.categories.all():
  47. label_per_user[a.user.username].append(a.label.text)
  48. return label_per_user
  49. def reduce_user(self, label_per_user: Dict[str, List]):
  50. value = list(itertools.chain(*label_per_user.values()))
  51. return {'all': value}
  52. class Speech2TextRepository(FileRepository):
  53. def label_per_user(self, example) -> Dict:
  54. label_per_user = defaultdict(list)
  55. for a in example.texts.all():
  56. label_per_user[a.user.username].append(a.text)
  57. return label_per_user
  58. class TextRepository(BaseRepository):
  59. @property
  60. def docs(self):
  61. return Example.objects.filter(project=self.project)
  62. def list(self, export_approved=False):
  63. docs = self.docs
  64. if export_approved:
  65. docs = docs.exclude(annotations_approved_by=None)
  66. for doc in docs:
  67. label_per_user = self.label_per_user(doc)
  68. if self.project.collaborative_annotation:
  69. label_per_user = self.reduce_user(label_per_user)
  70. for user, label in label_per_user.items():
  71. yield Record(
  72. id=doc.id,
  73. data=doc.text,
  74. label=label,
  75. user=user,
  76. metadata=doc.meta
  77. )
  78. # todo:
  79. # If there is no label, export the doc with `unknown` user.
  80. # This is a quick solution.
  81. # In the future, the doc without label will be exported
  82. # with the user who approved the doc.
  83. # This means I will allow each user to be able to approve the doc.
  84. if len(label_per_user) == 0:
  85. yield Record(
  86. id=doc.id,
  87. data=doc.text,
  88. label=[],
  89. user='unknown',
  90. metadata={}
  91. )
  92. @abc.abstractmethod
  93. def label_per_user(self, doc) -> Dict:
  94. raise NotImplementedError()
  95. def reduce_user(self, label_per_user: Dict[str, List]):
  96. value = list(itertools.chain(*label_per_user.values()))
  97. return {'all': value}
  98. class TextClassificationRepository(TextRepository):
  99. @property
  100. def docs(self):
  101. return Example.objects.filter(project=self.project).prefetch_related(
  102. 'categories__user', 'categories__label'
  103. )
  104. def label_per_user(self, doc) -> Dict:
  105. label_per_user = defaultdict(list)
  106. for a in doc.categories.all():
  107. label_per_user[a.user.username].append(a.label.text)
  108. return label_per_user
  109. class SequenceLabelingRepository(TextRepository):
  110. @property
  111. def docs(self):
  112. return Example.objects.filter(project=self.project).prefetch_related(
  113. 'spans__user', 'spans__label'
  114. )
  115. def label_per_user(self, doc) -> Dict:
  116. label_per_user = defaultdict(list)
  117. for a in doc.spans.all():
  118. label = (a.start_offset, a.end_offset, a.label.text)
  119. label_per_user[a.user.username].append(label)
  120. return label_per_user
  121. class Seq2seqRepository(TextRepository):
  122. @property
  123. def docs(self):
  124. return Example.objects.filter(project=self.project).prefetch_related(
  125. 'texts__user'
  126. )
  127. def label_per_user(self, doc) -> Dict:
  128. label_per_user = defaultdict(list)
  129. for a in doc.texts.all():
  130. label_per_user[a.user.username].append(a.text)
  131. return label_per_user