You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

343 lines
13 KiB

3 years ago
3 years ago
  1. import unittest
  2. from model_mommy import mommy
  3. from ..pipeline.repositories import (
  4. FileRepository,
  5. IntentDetectionSlotFillingRepository,
  6. RelationExtractionRepository,
  7. Seq2seqRepository,
  8. SequenceLabelingRepository,
  9. Speech2TextRepository,
  10. TextClassificationRepository,
  11. )
  12. from projects.models import (
  13. DOCUMENT_CLASSIFICATION,
  14. IMAGE_CLASSIFICATION,
  15. INTENT_DETECTION_AND_SLOT_FILLING,
  16. SEQ2SEQ,
  17. SEQUENCE_LABELING,
  18. SPEECH2TEXT,
  19. )
  20. from projects.tests.utils import prepare_project
  21. class TestRepository(unittest.TestCase):
  22. def assert_records(self, repository, expected):
  23. records = list(repository.list())
  24. self.assertEqual(len(records), len(expected))
  25. for record, expect in zip(records, expected):
  26. self.assertEqual(record.data, expect["data"])
  27. self.assertEqual(record.label, expect["label"])
  28. self.assertEqual(record.user, expect["user"])
  29. class TestTextClassificationRepository(TestRepository):
  30. def prepare_data(self, project):
  31. self.example = mommy.make("Example", project=project.item, text="example")
  32. self.category1 = mommy.make("Category", example=self.example, user=project.admin)
  33. self.category2 = mommy.make("Category", example=self.example, user=project.annotator)
  34. def test_list(self):
  35. project = prepare_project(DOCUMENT_CLASSIFICATION)
  36. repository = TextClassificationRepository(project.item)
  37. self.prepare_data(project)
  38. expected = [
  39. {
  40. "data": self.example.text,
  41. "label": [self.category1.label.text],
  42. "user": project.admin.username,
  43. },
  44. {
  45. "data": self.example.text,
  46. "label": [self.category2.label.text],
  47. "user": project.annotator.username,
  48. },
  49. ]
  50. self.assert_records(repository, expected)
  51. def test_list_on_collaborative_annotation(self):
  52. project = prepare_project(DOCUMENT_CLASSIFICATION, collaborative_annotation=True)
  53. repository = TextClassificationRepository(project.item)
  54. self.prepare_data(project)
  55. expected = [
  56. {
  57. "data": self.example.text,
  58. "label": [self.category1.label.text, self.category2.label.text],
  59. "user": "all",
  60. }
  61. ]
  62. self.assert_records(repository, expected)
  63. class TestSeq2seqRepository(TestRepository):
  64. def prepare_data(self, project):
  65. self.example = mommy.make("Example", project=project.item, text="example")
  66. self.text1 = mommy.make("TextLabel", example=self.example, user=project.admin)
  67. self.text2 = mommy.make("TextLabel", example=self.example, user=project.annotator)
  68. def test_list(self):
  69. project = prepare_project(SEQ2SEQ)
  70. repository = Seq2seqRepository(project.item)
  71. self.prepare_data(project)
  72. expected = [
  73. {
  74. "data": self.example.text,
  75. "label": [self.text1.text],
  76. "user": project.admin.username,
  77. },
  78. {
  79. "data": self.example.text,
  80. "label": [self.text2.text],
  81. "user": project.annotator.username,
  82. },
  83. ]
  84. self.assert_records(repository, expected)
  85. def test_list_on_collaborative_annotation(self):
  86. project = prepare_project(SEQ2SEQ, collaborative_annotation=True)
  87. repository = Seq2seqRepository(project.item)
  88. self.prepare_data(project)
  89. expected = [
  90. {
  91. "data": self.example.text,
  92. "label": [self.text1.text, self.text2.text],
  93. "user": "all",
  94. }
  95. ]
  96. self.assert_records(repository, expected)
  97. class TestIntentDetectionSlotFillingRepository(TestRepository):
  98. def prepare_data(self, project):
  99. self.example = mommy.make("Example", project=project.item, text="example")
  100. self.category1 = mommy.make("Category", example=self.example, user=project.admin)
  101. self.category2 = mommy.make("Category", example=self.example, user=project.annotator)
  102. self.span = mommy.make("Span", example=self.example, user=project.admin, start_offset=0, end_offset=1)
  103. def test_list(self):
  104. project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING)
  105. repository = IntentDetectionSlotFillingRepository(project.item)
  106. self.prepare_data(project)
  107. expected = [
  108. {
  109. "data": self.example.text,
  110. "label": {
  111. "cats": [self.category1.label.text],
  112. "entities": [(self.span.start_offset, self.span.end_offset, self.span.label.text)],
  113. },
  114. "user": project.admin.username,
  115. },
  116. {
  117. "data": self.example.text,
  118. "label": {
  119. "cats": [self.category2.label.text],
  120. "entities": [],
  121. },
  122. "user": project.annotator.username,
  123. },
  124. ]
  125. self.assert_records(repository, expected)
  126. def test_list_on_collaborative_annotation(self):
  127. project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING, collaborative_annotation=True)
  128. repository = IntentDetectionSlotFillingRepository(project.item)
  129. self.prepare_data(project)
  130. expected = [
  131. {
  132. "data": self.example.text,
  133. "label": {
  134. "cats": [self.category1.label.text, self.category2.label.text],
  135. "entities": [(self.span.start_offset, self.span.end_offset, self.span.label.text)],
  136. },
  137. "user": "all",
  138. }
  139. ]
  140. self.assert_records(repository, expected)
  141. class TestSequenceLabelingRepository(TestRepository):
  142. def prepare_data(self, project):
  143. self.example = mommy.make("Example", project=project.item, text="example")
  144. self.span1 = mommy.make("Span", example=self.example, user=project.admin, start_offset=0, end_offset=1)
  145. self.span2 = mommy.make("Span", example=self.example, user=project.annotator, start_offset=1, end_offset=2)
  146. def test_list(self):
  147. project = prepare_project(SEQUENCE_LABELING)
  148. repository = SequenceLabelingRepository(project.item)
  149. self.prepare_data(project)
  150. expected = [
  151. {
  152. "data": self.example.text,
  153. "label": [(self.span1.start_offset, self.span1.end_offset, self.span1.label.text)],
  154. "user": project.admin.username,
  155. },
  156. {
  157. "data": self.example.text,
  158. "label": [(self.span2.start_offset, self.span2.end_offset, self.span2.label.text)],
  159. "user": project.annotator.username,
  160. },
  161. ]
  162. self.assert_records(repository, expected)
  163. def test_list_on_collaborative_annotation(self):
  164. project = prepare_project(SEQUENCE_LABELING, collaborative_annotation=True)
  165. repository = SequenceLabelingRepository(project.item)
  166. self.prepare_data(project)
  167. expected = [
  168. {
  169. "data": self.example.text,
  170. "label": [
  171. (self.span1.start_offset, self.span1.end_offset, self.span1.label.text),
  172. (self.span2.start_offset, self.span2.end_offset, self.span2.label.text),
  173. ],
  174. "user": "all",
  175. }
  176. ]
  177. self.assert_records(repository, expected)
  178. class TestRelationExtractionRepository(TestRepository):
  179. def test_list(self):
  180. project = prepare_project(SEQUENCE_LABELING, use_relation=True)
  181. example = mommy.make("Example", project=project.item, text="example")
  182. span1 = mommy.make("Span", example=example, user=project.admin, start_offset=0, end_offset=1)
  183. span2 = mommy.make("Span", example=example, user=project.admin, start_offset=1, end_offset=2)
  184. relation = mommy.make("Relation", from_id=span1, to_id=span2, example=example, user=project.admin)
  185. repository = RelationExtractionRepository(project.item)
  186. expected = [
  187. {
  188. "data": example.text,
  189. "label": {
  190. "entities": [
  191. {
  192. "id": span1.id,
  193. "start_offset": span1.start_offset,
  194. "end_offset": span1.end_offset,
  195. "label": span1.label.text,
  196. },
  197. {
  198. "id": span2.id,
  199. "start_offset": span2.start_offset,
  200. "end_offset": span2.end_offset,
  201. "label": span2.label.text,
  202. },
  203. ],
  204. "relations": [
  205. {"id": relation.id, "from_id": span1.id, "to_id": span2.id, "type": relation.type.text}
  206. ],
  207. },
  208. "user": project.admin.username,
  209. }
  210. ]
  211. self.assert_records(repository, expected)
  212. def test_list_on_collaborative_annotation(self):
  213. project = prepare_project(SEQUENCE_LABELING, collaborative_annotation=True, use_relation=True)
  214. example = mommy.make("Example", project=project.item, text="example")
  215. span1 = mommy.make("Span", example=example, user=project.admin, start_offset=0, end_offset=1)
  216. span2 = mommy.make("Span", example=example, user=project.annotator, start_offset=1, end_offset=2)
  217. relation = mommy.make("Relation", from_id=span1, to_id=span2, example=example, user=project.admin)
  218. repository = RelationExtractionRepository(project.item)
  219. expected = [
  220. {
  221. "data": example.text,
  222. "label": {
  223. "entities": [
  224. {
  225. "id": span1.id,
  226. "start_offset": span1.start_offset,
  227. "end_offset": span1.end_offset,
  228. "label": span1.label.text,
  229. },
  230. {
  231. "id": span2.id,
  232. "start_offset": span2.start_offset,
  233. "end_offset": span2.end_offset,
  234. "label": span2.label.text,
  235. },
  236. ],
  237. "relations": [
  238. {"id": relation.id, "from_id": span1.id, "to_id": span2.id, "type": relation.type.text}
  239. ],
  240. },
  241. "user": "all",
  242. }
  243. ]
  244. self.assert_records(repository, expected)
  245. class TestSpeech2TextRepository(TestRepository):
  246. def prepare_data(self, project):
  247. self.example = mommy.make("Example", project=project.item, text="example")
  248. self.text1 = mommy.make("TextLabel", example=self.example, user=project.admin)
  249. self.text2 = mommy.make("TextLabel", example=self.example, user=project.annotator)
  250. def test_list(self):
  251. project = prepare_project(SPEECH2TEXT)
  252. repository = Speech2TextRepository(project.item)
  253. self.prepare_data(project)
  254. expected = [
  255. {
  256. "data": self.example.upload_name,
  257. "label": [self.text1.text],
  258. "user": project.admin.username,
  259. },
  260. {
  261. "data": self.example.upload_name,
  262. "label": [self.text2.text],
  263. "user": project.annotator.username,
  264. },
  265. ]
  266. self.assert_records(repository, expected)
  267. def test_list_on_collaborative_annotation(self):
  268. project = prepare_project(SPEECH2TEXT, collaborative_annotation=True)
  269. repository = Speech2TextRepository(project.item)
  270. self.prepare_data(project)
  271. expected = [
  272. {
  273. "data": self.example.upload_name,
  274. "label": [self.text1.text, self.text2.text],
  275. "user": "all",
  276. }
  277. ]
  278. self.assert_records(repository, expected)
  279. class TestFileRepository(TestRepository):
  280. def prepare_data(self, project):
  281. self.example = mommy.make("Example", project=project.item, text="example")
  282. self.category1 = mommy.make("Category", example=self.example, user=project.admin)
  283. self.category2 = mommy.make("Category", example=self.example, user=project.annotator)
  284. def test_list(self):
  285. project = prepare_project(IMAGE_CLASSIFICATION)
  286. repository = FileRepository(project.item)
  287. self.prepare_data(project)
  288. expected = [
  289. {
  290. "data": self.example.upload_name,
  291. "label": [self.category1.label.text],
  292. "user": project.admin.username,
  293. },
  294. {
  295. "data": self.example.upload_name,
  296. "label": [self.category2.label.text],
  297. "user": project.annotator.username,
  298. },
  299. ]
  300. self.assert_records(repository, expected)
  301. def test_list_on_collaborative_annotation(self):
  302. project = prepare_project(IMAGE_CLASSIFICATION, collaborative_annotation=True)
  303. repository = FileRepository(project.item)
  304. self.prepare_data(project)
  305. expected = [
  306. {
  307. "data": self.example.upload_name,
  308. "label": [self.category1.label.text, self.category2.label.text],
  309. "user": "all",
  310. }
  311. ]
  312. self.assert_records(repository, expected)