You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

192 lines
6.2 KiB

  1. import abc
  2. import itertools
  3. from collections import defaultdict
  4. from typing import Dict, Iterator, List
  5. from api.models import Project
  6. from examples.models import Example
  7. from .data import Record
  8. class BaseRepository(abc.ABC):
  9. def __init__(self, project: Project):
  10. self.project = project
  11. @abc.abstractmethod
  12. def list(self, export_approved=False) -> Iterator[Record]:
  13. pass
  14. class FileRepository(BaseRepository):
  15. def list(self, export_approved=False) -> Iterator[Record]:
  16. examples = self.project.examples.all()
  17. if export_approved:
  18. examples = examples.exclude(annotations_approved_by=None)
  19. for example in examples:
  20. label_per_user = self.label_per_user(example)
  21. if self.project.collaborative_annotation:
  22. label_per_user = self.reduce_user(label_per_user)
  23. for user, label in label_per_user.items():
  24. yield Record(
  25. id=example.id,
  26. data=str(example.filename).split('/')[-1],
  27. label=label,
  28. user=user,
  29. metadata=example.meta
  30. )
  31. # todo:
  32. # If there is no label, export the doc with `unknown` user.
  33. # This is a quick solution.
  34. # In the future, the doc without label will be exported
  35. # with the user who approved the doc.
  36. # This means I will allow each user to be able to approve the doc.
  37. if len(label_per_user) == 0:
  38. yield Record(
  39. id=example.id,
  40. data=str(example.filename).split('/')[-1],
  41. label=[],
  42. user='unknown',
  43. metadata={}
  44. )
  45. def label_per_user(self, example) -> Dict:
  46. label_per_user = defaultdict(list)
  47. for a in example.categories.all():
  48. label_per_user[a.user.username].append(a.label.text)
  49. return label_per_user
  50. def reduce_user(self, label_per_user: Dict[str, List]):
  51. value = list(itertools.chain(*label_per_user.values()))
  52. return {'all': value}
  53. class Speech2TextRepository(FileRepository):
  54. def label_per_user(self, example) -> Dict:
  55. label_per_user = defaultdict(list)
  56. for a in example.texts.all():
  57. label_per_user[a.user.username].append(a.text)
  58. return label_per_user
  59. class TextRepository(BaseRepository):
  60. @property
  61. def docs(self):
  62. return Example.objects.filter(project=self.project)
  63. def list(self, export_approved=False):
  64. docs = self.docs
  65. if export_approved:
  66. docs = docs.exclude(annotations_approved_by=None)
  67. for doc in docs:
  68. label_per_user = self.label_per_user(doc)
  69. if self.project.collaborative_annotation:
  70. label_per_user = self.reduce_user(label_per_user)
  71. for user, label in label_per_user.items():
  72. yield Record(
  73. id=doc.id,
  74. data=doc.text,
  75. label=label,
  76. user=user,
  77. metadata=doc.meta
  78. )
  79. # todo:
  80. # If there is no label, export the doc with `unknown` user.
  81. # This is a quick solution.
  82. # In the future, the doc without label will be exported
  83. # with the user who approved the doc.
  84. # This means I will allow each user to be able to approve the doc.
  85. if len(label_per_user) == 0:
  86. yield Record(
  87. id=doc.id,
  88. data=doc.text,
  89. label=[],
  90. user='unknown',
  91. metadata={}
  92. )
  93. @abc.abstractmethod
  94. def label_per_user(self, doc) -> Dict:
  95. raise NotImplementedError()
  96. def reduce_user(self, label_per_user: Dict[str, List]):
  97. value = list(itertools.chain(*label_per_user.values()))
  98. return {'all': value}
  99. class TextClassificationRepository(TextRepository):
  100. @property
  101. def docs(self):
  102. return Example.objects.filter(project=self.project).prefetch_related(
  103. 'categories__user', 'categories__label'
  104. )
  105. def label_per_user(self, doc) -> Dict:
  106. label_per_user = defaultdict(list)
  107. for a in doc.categories.all():
  108. label_per_user[a.user.username].append(a.label.text)
  109. return label_per_user
  110. class SequenceLabelingRepository(TextRepository):
  111. @property
  112. def docs(self):
  113. return Example.objects.filter(project=self.project).prefetch_related(
  114. 'spans__user', 'spans__label'
  115. )
  116. def label_per_user(self, doc) -> Dict:
  117. label_per_user = defaultdict(list)
  118. for a in doc.spans.all():
  119. label = (a.start_offset, a.end_offset, a.label.text)
  120. label_per_user[a.user.username].append(label)
  121. return label_per_user
  122. class Seq2seqRepository(TextRepository):
  123. @property
  124. def docs(self):
  125. return Example.objects.filter(project=self.project).prefetch_related(
  126. 'texts__user'
  127. )
  128. def label_per_user(self, doc) -> Dict:
  129. label_per_user = defaultdict(list)
  130. for a in doc.texts.all():
  131. label_per_user[a.user.username].append(a.text)
  132. return label_per_user
  133. class IntentDetectionSlotFillingRepository(TextRepository):
  134. @property
  135. def docs(self):
  136. return Example.objects.filter(project=self.project).prefetch_related(
  137. 'categories__user',
  138. 'categories__label',
  139. 'spans__user',
  140. 'spans__label'
  141. )
  142. def label_per_user(self, doc) -> Dict:
  143. category_per_user = defaultdict(list)
  144. span_per_user = defaultdict(list)
  145. label_per_user = defaultdict(dict)
  146. for a in doc.categories.all():
  147. category_per_user[a.user.username].append(a.label.text)
  148. for a in doc.spans.all():
  149. label = (a.start_offset, a.end_offset, a.label.text)
  150. span_per_user[a.user.username].append(label)
  151. for user, label in category_per_user.items():
  152. label_per_user[user]['cats'] = label
  153. for user, label in span_per_user.items():
  154. label_per_user[user]['entities'] = label
  155. return label_per_user