You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

155 lines
4.9 KiB

  1. import abc
  2. import itertools
  3. from collections import defaultdict
  4. from typing import Dict, Iterator, List
  5. from ...models import Document, Project
  6. from .data import Record
  7. class BaseRepository(abc.ABC):
  8. def __init__(self, project: Project):
  9. self.project = project
  10. @abc.abstractmethod
  11. def list(self, export_approved=False) -> Iterator[Record]:
  12. pass
  13. class FileRepository(BaseRepository):
  14. def list(self, export_approved=False) -> Iterator[Record]:
  15. examples = self.project.examples.all()
  16. if export_approved:
  17. examples = examples.exclude(annotations_approved_by=None)
  18. for example in examples:
  19. label_per_user = self.label_per_user(example)
  20. if self.project.collaborative_annotation:
  21. label_per_user = self.reduce_user(label_per_user)
  22. for user, label in label_per_user.items():
  23. yield Record(
  24. id=example.id,
  25. data=example.filename,
  26. label=label,
  27. user=user,
  28. metadata=example.meta
  29. )
  30. # todo:
  31. # If there is no label, export the doc with `unknown` user.
  32. # This is a quick solution.
  33. # In the future, the doc without label will be exported
  34. # with the user who approved the doc.
  35. # This means I will allow each user to be able to approve the doc.
  36. if len(label_per_user) == 0:
  37. yield Record(
  38. id=example.id,
  39. data=example.text,
  40. label=[],
  41. user='unknown',
  42. metadata={}
  43. )
  44. def label_per_user(self, example) -> Dict:
  45. label_per_user = defaultdict(list)
  46. for a in example.categories.all():
  47. label_per_user[a.user.username].append(a.label.text)
  48. return label_per_user
  49. def reduce_user(self, label_per_user: Dict[str, List]):
  50. value = list(itertools.chain(*label_per_user.values()))
  51. return {'all': value}
  52. class TextRepository(BaseRepository):
  53. @property
  54. def docs(self):
  55. return Document.objects.filter(project=self.project)
  56. def list(self, export_approved=False):
  57. docs = self.docs
  58. if export_approved:
  59. docs = docs.exclude(annotations_approved_by=None)
  60. for doc in docs:
  61. label_per_user = self.label_per_user(doc)
  62. if self.project.collaborative_annotation:
  63. label_per_user = self.reduce_user(label_per_user)
  64. for user, label in label_per_user.items():
  65. yield Record(
  66. id=doc.id,
  67. data=doc.text,
  68. label=label,
  69. user=user,
  70. metadata=doc.meta
  71. )
  72. # todo:
  73. # If there is no label, export the doc with `unknown` user.
  74. # This is a quick solution.
  75. # In the future, the doc without label will be exported
  76. # with the user who approved the doc.
  77. # This means I will allow each user to be able to approve the doc.
  78. if len(label_per_user) == 0:
  79. yield Record(
  80. id=doc.id,
  81. data=doc.text,
  82. label=[],
  83. user='unknown',
  84. metadata={}
  85. )
  86. @abc.abstractmethod
  87. def label_per_user(self, doc) -> Dict:
  88. raise NotImplementedError()
  89. def reduce_user(self, label_per_user: Dict[str, List]):
  90. value = list(itertools.chain(*label_per_user.values()))
  91. return {'all': value}
  92. class TextClassificationRepository(TextRepository):
  93. @property
  94. def docs(self):
  95. return Document.objects.filter(project=self.project).prefetch_related(
  96. 'categories__user', 'categories__label'
  97. )
  98. def label_per_user(self, doc) -> Dict:
  99. label_per_user = defaultdict(list)
  100. for a in doc.categories.all():
  101. label_per_user[a.user.username].append(a.label.text)
  102. return label_per_user
  103. class SequenceLabelingRepository(TextRepository):
  104. @property
  105. def docs(self):
  106. return Document.objects.filter(project=self.project).prefetch_related(
  107. 'spans__user', 'spans__label'
  108. )
  109. def label_per_user(self, doc) -> Dict:
  110. label_per_user = defaultdict(list)
  111. for a in doc.spans.all():
  112. label = (a.start_offset, a.end_offset, a.label.text)
  113. label_per_user[a.user.username].append(label)
  114. return label_per_user
  115. class Seq2seqRepository(TextRepository):
  116. @property
  117. def docs(self):
  118. return Document.objects.filter(project=self.project).prefetch_related(
  119. 'texts__user'
  120. )
  121. def label_per_user(self, doc) -> Dict:
  122. label_per_user = defaultdict(list)
  123. for a in doc.texts.all():
  124. label_per_user[a.user.username].append(a.text)
  125. return label_per_user