You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

239 lines
8.8 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. import abc
  2. import itertools
  3. from collections import defaultdict
  4. from typing import Any, Dict, Iterator, List, Tuple
  5. from .data import Record
  6. from examples.models import Example
  7. from projects.models import Project
  8. SpanType = Tuple[int, int, str]
  9. class BaseRepository:
  10. def __init__(self, project: Project):
  11. self.project = project
  12. def list(self, export_approved=False) -> Iterator[Record]:
  13. raise NotImplementedError()
  14. def create_unlabeled_record(self, example: Example) -> Record:
  15. raise NotImplementedError()
  16. class FileRepository(BaseRepository):
  17. def list(self, export_approved=False) -> Iterator[Record]:
  18. examples = self.project.examples.all()
  19. if export_approved:
  20. examples = examples.exclude(states=None)
  21. for example in examples:
  22. label_per_user = self.label_per_user(example)
  23. if self.project.collaborative_annotation:
  24. label_per_user = self.reduce_user(label_per_user)
  25. for user, label in label_per_user.items():
  26. yield Record(
  27. data_id=example.id,
  28. data=example.upload_name,
  29. label=label,
  30. user=user,
  31. metadata=example.meta,
  32. )
  33. # todo:
  34. # If there is no label, export the doc with `unknown` user.
  35. # This is a quick solution.
  36. # In the future, the doc without label will be exported
  37. # with the user who approved the doc.
  38. # This means I will allow each user to be able to approve the doc.
  39. if len(label_per_user) == 0:
  40. yield self.create_unlabeled_record(example)
  41. def create_unlabeled_record(self, example: Example) -> Record:
  42. return Record(data_id=example.id, data=example.upload_name, label=[], user="unknown", metadata=example.meta)
  43. def label_per_user(self, example) -> Dict:
  44. label_per_user = defaultdict(list)
  45. for a in example.categories.all():
  46. label_per_user[a.user.username].append(a.label.text)
  47. return label_per_user
  48. def reduce_user(self, label_per_user: Dict[str, Any]):
  49. value = list(itertools.chain(*label_per_user.values()))
  50. return {"all": value}
  51. class Speech2TextRepository(FileRepository):
  52. def label_per_user(self, example) -> Dict:
  53. label_per_user = defaultdict(list)
  54. for a in example.texts.all():
  55. label_per_user[a.user.username].append(a.text)
  56. return label_per_user
  57. class TextRepository(BaseRepository):
  58. @property
  59. def docs(self):
  60. return Example.objects.filter(project=self.project)
  61. def list(self, export_approved=False):
  62. docs = self.docs
  63. if export_approved:
  64. docs = docs.exclude(states=None)
  65. for doc in docs:
  66. label_per_user = self.label_per_user(doc)
  67. if self.project.collaborative_annotation:
  68. label_per_user = self.reduce_user(label_per_user)
  69. for user, label in label_per_user.items():
  70. yield Record(data_id=doc.id, data=doc.text, label=label, user=user, metadata=doc.meta)
  71. # todo:
  72. # If there is no label, export the doc with `unknown` user.
  73. # This is a quick solution.
  74. # In the future, the doc without label will be exported
  75. # with the user who approved the doc.
  76. # This means I will allow each user to be able to approve the doc.
  77. if len(label_per_user) == 0:
  78. yield self.create_unlabeled_record(doc)
  79. def create_unlabeled_record(self, example: Example) -> Record:
  80. return Record(data_id=example.id, data=example.text, label=[], user="unknown", metadata=example.meta)
  81. @abc.abstractmethod
  82. def label_per_user(self, doc) -> Dict:
  83. raise NotImplementedError()
  84. def reduce_user(self, label_per_user: Dict[str, Any]):
  85. value = list(itertools.chain(*label_per_user.values()))
  86. return {"all": value}
  87. class TextClassificationRepository(TextRepository):
  88. @property
  89. def docs(self):
  90. return Example.objects.filter(project=self.project).prefetch_related("categories__user", "categories__label")
  91. def label_per_user(self, doc) -> Dict:
  92. label_per_user = defaultdict(list)
  93. for a in doc.categories.all():
  94. label_per_user[a.user.username].append(a.label.text)
  95. return label_per_user
  96. class SequenceLabelingRepository(TextRepository):
  97. @property
  98. def docs(self):
  99. return Example.objects.filter(project=self.project).prefetch_related("spans__user", "spans__label")
  100. def label_per_user(self, doc) -> Dict:
  101. label_per_user = defaultdict(list)
  102. for a in doc.spans.all():
  103. label = (a.start_offset, a.end_offset, a.label.text)
  104. label_per_user[a.user.username].append(label)
  105. return label_per_user
  106. class RelationExtractionRepository(TextRepository):
  107. @property
  108. def docs(self):
  109. return Example.objects.filter(project=self.project).prefetch_related(
  110. "spans__user", "spans__label", "relations__user", "relations__type"
  111. )
  112. def create_unlabeled_record(self, example: Example) -> Record:
  113. return Record(
  114. data_id=example.id,
  115. data=example.text,
  116. label={"entities": [], "relations": []},
  117. user="unknown",
  118. metadata=example.meta,
  119. )
  120. def label_per_user(self, doc) -> Dict:
  121. relation_per_user: Dict = defaultdict(list)
  122. span_per_user: Dict = defaultdict(list)
  123. label_per_user: Dict = defaultdict(dict)
  124. for relation in doc.relations.all():
  125. relation_per_user[relation.user.username].append(
  126. {
  127. "id": relation.id,
  128. "from_id": relation.from_id.id,
  129. "to_id": relation.to_id.id,
  130. "type": relation.type.text,
  131. }
  132. )
  133. for span in doc.spans.all():
  134. span_per_user[span.user.username].append(
  135. {
  136. "id": span.id,
  137. "start_offset": span.start_offset,
  138. "end_offset": span.end_offset,
  139. "label": span.label.text,
  140. }
  141. )
  142. for user, relations in relation_per_user.items():
  143. label_per_user[user]["relations"] = relations
  144. for user, span in span_per_user.items():
  145. label_per_user[user]["entities"] = span
  146. return label_per_user
  147. def reduce_user(self, label_per_user: Dict[str, Any]):
  148. entities = []
  149. relations = []
  150. for user, label in label_per_user.items():
  151. entities.extend(label.get("entities", []))
  152. relations.extend(label.get("relations", []))
  153. return {"all": {"entities": entities, "relations": relations}}
  154. class Seq2seqRepository(TextRepository):
  155. @property
  156. def docs(self):
  157. return Example.objects.filter(project=self.project).prefetch_related("texts__user")
  158. def label_per_user(self, doc) -> Dict:
  159. label_per_user = defaultdict(list)
  160. for a in doc.texts.all():
  161. label_per_user[a.user.username].append(a.text)
  162. return label_per_user
  163. class IntentDetectionSlotFillingRepository(TextRepository):
  164. @property
  165. def docs(self):
  166. return Example.objects.filter(project=self.project).prefetch_related(
  167. "categories__user", "categories__label", "spans__user", "spans__label"
  168. )
  169. def create_unlabeled_record(self, example: Example) -> Record:
  170. return Record(
  171. data_id=example.id,
  172. data=example.text,
  173. label={"entities": [], "cats": []},
  174. user="unknown",
  175. metadata=example.meta,
  176. )
  177. def label_per_user(self, doc) -> Dict:
  178. category_per_user: Dict[str, List[str]] = defaultdict(list)
  179. span_per_user: Dict[str, List[SpanType]] = defaultdict(list)
  180. label_per_user: Dict[str, Dict[str, List]] = defaultdict(dict)
  181. for a in doc.categories.all():
  182. category_per_user[a.user.username].append(a.label.text)
  183. for a in doc.spans.all():
  184. span_per_user[a.user.username].append((a.start_offset, a.end_offset, a.label.text))
  185. for user, cats in category_per_user.items():
  186. label_per_user[user]["cats"] = cats
  187. for user, span in span_per_user.items():
  188. label_per_user[user]["entities"] = span
  189. for label in label_per_user.values():
  190. label.setdefault("cats", [])
  191. label.setdefault("entities", [])
  192. return label_per_user
  193. def reduce_user(self, label_per_user: Dict[str, Any]):
  194. cats = []
  195. entities = []
  196. for user, label in label_per_user.items():
  197. cats.extend(label.get("cats", []))
  198. entities.extend(label.get("entities", []))
  199. return {"all": {"entities": entities, "cats": cats}}