You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

196 lines
7.3 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import abc
  2. import itertools
  3. from collections import defaultdict
  4. from typing import Dict, Iterator, List, Tuple, Union
  5. from .data import Record
  6. from examples.models import Example
  7. from projects.models import Project
# A span label as exported: (start_offset, end_offset, label_text).
SpanType = Tuple[int, int, str]
  9. class BaseRepository(abc.ABC):
  10. def __init__(self, project: Project):
  11. self.project = project
  12. @abc.abstractmethod
  13. def list(self, export_approved=False) -> Iterator[Record]:
  14. pass
  15. class FileRepository(BaseRepository):
  16. def list(self, export_approved=False) -> Iterator[Record]:
  17. examples = self.project.examples.all()
  18. if export_approved:
  19. examples = examples.exclude(annotations_approved_by=None)
  20. for example in examples:
  21. label_per_user = self.label_per_user(example)
  22. if self.project.collaborative_annotation:
  23. label_per_user = self.reduce_user(label_per_user)
  24. for user, label in label_per_user.items():
  25. yield Record(
  26. data_id=example.id,
  27. data=str(example.filename).split("/")[-1],
  28. label=label,
  29. user=user,
  30. metadata=example.meta,
  31. )
  32. # todo:
  33. # If there is no label, export the doc with `unknown` user.
  34. # This is a quick solution.
  35. # In the future, the doc without label will be exported
  36. # with the user who approved the doc.
  37. # This means I will allow each user to be able to approve the doc.
  38. if len(label_per_user) == 0:
  39. yield Record(
  40. data_id=example.id, data=str(example.filename).split("/")[-1], label=[], user="unknown", metadata={}
  41. )
  42. def label_per_user(self, example) -> Dict:
  43. label_per_user = defaultdict(list)
  44. for a in example.categories.all():
  45. label_per_user[a.user.username].append(a.label.text)
  46. return label_per_user
  47. def reduce_user(self, label_per_user: Dict[str, List]):
  48. value = list(itertools.chain(*label_per_user.values()))
  49. return {"all": value}
  50. class Speech2TextRepository(FileRepository):
  51. def label_per_user(self, example) -> Dict:
  52. label_per_user = defaultdict(list)
  53. for a in example.texts.all():
  54. label_per_user[a.user.username].append(a.text)
  55. return label_per_user
  56. class TextRepository(BaseRepository):
  57. @property
  58. def docs(self):
  59. return Example.objects.filter(project=self.project)
  60. def list(self, export_approved=False):
  61. docs = self.docs
  62. if export_approved:
  63. docs = docs.exclude(annotations_approved_by=None)
  64. for doc in docs:
  65. label_per_user = self.label_per_user(doc)
  66. if self.project.collaborative_annotation:
  67. label_per_user = self.reduce_user(label_per_user)
  68. for user, label in label_per_user.items():
  69. yield Record(data_id=doc.id, data=doc.text, label=label, user=user, metadata=doc.meta)
  70. # todo:
  71. # If there is no label, export the doc with `unknown` user.
  72. # This is a quick solution.
  73. # In the future, the doc without label will be exported
  74. # with the user who approved the doc.
  75. # This means I will allow each user to be able to approve the doc.
  76. if len(label_per_user) == 0:
  77. yield Record(data_id=doc.id, data=doc.text, label=[], user="unknown", metadata={})
  78. @abc.abstractmethod
  79. def label_per_user(self, doc) -> Dict:
  80. raise NotImplementedError()
  81. def reduce_user(self, label_per_user: Dict[str, List]):
  82. value = list(itertools.chain(*label_per_user.values()))
  83. return {"all": value}
  84. class TextClassificationRepository(TextRepository):
  85. @property
  86. def docs(self):
  87. return Example.objects.filter(project=self.project).prefetch_related("categories__user", "categories__label")
  88. def label_per_user(self, doc) -> Dict:
  89. label_per_user = defaultdict(list)
  90. for a in doc.categories.all():
  91. label_per_user[a.user.username].append(a.label.text)
  92. return label_per_user
  93. class SequenceLabelingRepository(TextRepository):
  94. @property
  95. def docs(self):
  96. return Example.objects.filter(project=self.project).prefetch_related("spans__user", "spans__label")
  97. def label_per_user(self, doc) -> Dict:
  98. label_per_user = defaultdict(list)
  99. for a in doc.spans.all():
  100. label = (a.start_offset, a.end_offset, a.label.text)
  101. label_per_user[a.user.username].append(label)
  102. return label_per_user
  103. class RelationExtractionRepository(TextRepository):
  104. @property
  105. def docs(self):
  106. return Example.objects.filter(project=self.project).prefetch_related(
  107. "spans__user", "spans__label", "relations__user", "relations__type"
  108. )
  109. def label_per_user(self, doc) -> Dict:
  110. relation_per_user: Dict = defaultdict(list)
  111. span_per_user: Dict = defaultdict(list)
  112. label_per_user: Dict = defaultdict(dict)
  113. for relation in doc.relations.all():
  114. relation_per_user[relation.user.username].append(
  115. {
  116. "id": relation.id,
  117. "from_id": relation.from_id.id,
  118. "to_id": relation.to_id.id,
  119. "type": relation.type.text,
  120. }
  121. )
  122. for span in doc.spans.all():
  123. span_per_user[span.user.username].append(
  124. {
  125. "id": span.id,
  126. "start_offset": span.start_offset,
  127. "end_offset": span.end_offset,
  128. "label": span.label.text,
  129. }
  130. )
  131. for user, relations in relation_per_user.items():
  132. label_per_user[user]["relations"] = relations
  133. for user, span in span_per_user.items():
  134. label_per_user[user]["entities"] = span
  135. return label_per_user
  136. class Seq2seqRepository(TextRepository):
  137. @property
  138. def docs(self):
  139. return Example.objects.filter(project=self.project).prefetch_related("texts__user")
  140. def label_per_user(self, doc) -> Dict:
  141. label_per_user = defaultdict(list)
  142. for a in doc.texts.all():
  143. label_per_user[a.user.username].append(a.text)
  144. return label_per_user
  145. class IntentDetectionSlotFillingRepository(TextRepository):
  146. @property
  147. def docs(self):
  148. return Example.objects.filter(project=self.project).prefetch_related(
  149. "categories__user", "categories__label", "spans__user", "spans__label"
  150. )
  151. def label_per_user(self, doc) -> Dict:
  152. category_per_user: Dict[str, List[str]] = defaultdict(list)
  153. span_per_user: Dict[str, List[SpanType]] = defaultdict(list)
  154. label_per_user: Dict[str, Dict[str, Union[List[str], List[SpanType]]]] = defaultdict(dict)
  155. for a in doc.categories.all():
  156. category_per_user[a.user.username].append(a.label.text)
  157. for a in doc.spans.all():
  158. span_per_user[a.user.username].append((a.start_offset, a.end_offset, a.label.text))
  159. for user, cats in category_per_user.items():
  160. label_per_user[user]["cats"] = cats
  161. for user, span in span_per_user.items():
  162. label_per_user[user]["entities"] = span
  163. return label_per_user