You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

74 lines
3.0 KiB

2 years ago
2 years ago
2 years ago
  import unittest

  from django.test import TestCase
  from model_mommy import mommy

  from ..pipeline.repositories import (
      IntentDetectionSlotFillingRepository,
      RelationExtractionRepository,
  )
  from projects.models import INTENT_DETECTION_AND_SLOT_FILLING, SEQUENCE_LABELING
  from projects.tests.utils import prepare_project
  class TestCSVWriter(TestCase):
      """Test ``IntentDetectionSlotFillingRepository.list``.

      NOTE(review): despite its name, this class does not exercise a CSV writer;
      it tests ``IntentDetectionSlotFillingRepository``. The name is kept so that
      existing test selectors (``...TestCSVWriter``) keep working.

      Subclasses ``django.test.TestCase`` rather than ``unittest.TestCase``: the
      test creates database rows through ``mommy.make``, and Django's ``TestCase``
      rolls those rows back after each test so they cannot leak into other tests.
      """

      def setUp(self):
          # Fresh intent-detection/slot-filling project fixture for every test; the
          # test uses its ``.item`` (the project) and ``.admin`` (annotating user).
          self.project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING)

      def test_list(self):
          """One example with one category and one span is exported as one record."""
          example = mommy.make("Example", project=self.project.item, text="example")
          category = mommy.make("Category", example=example, user=self.project.admin)
          span = mommy.make("Span", example=example, user=self.project.admin, start_offset=0, end_offset=1)
          repository = IntentDetectionSlotFillingRepository(self.project.item)
          expected = [
              {
                  "data": example.text,
                  "label": {
                      "cats": [category.label.text],
                      # Spans are expected as (start_offset, end_offset, label text) tuples.
                      "entities": [(span.start_offset, span.end_offset, span.label.text)],
                  },
              }
          ]
          records = list(repository.list())
          # Compare counts first: zip() below would silently ignore extra or missing records.
          self.assertEqual(len(records), len(expected))
          for record, expect in zip(records, expected):
              self.assertEqual(record.data, expect["data"])
              self.assertEqual(record.label["cats"], expect["label"]["cats"])
              self.assertEqual(record.label["entities"], expect["label"]["entities"])
  class TestRelationExtractionRepository(TestCase):
      """Test ``RelationExtractionRepository.label_per_user``.

      Subclasses ``django.test.TestCase`` rather than ``unittest.TestCase``: the
      test creates database rows through ``mommy.make``, and Django's ``TestCase``
      rolls those rows back after each test so they cannot leak into other tests.
      """

      def setUp(self):
          # Sequence-labeling project fixture with relations enabled; the test uses
          # its ``.item`` (the project) and ``.admin`` (annotating user).
          self.project = prepare_project(SEQUENCE_LABELING, use_relation=True)

      def test_label_per_user(self):
          """Spans and the relation between them are returned grouped per user."""
          # Create the example inside the project under test. Without an explicit
          # example, mommy would invent an unrelated Example (and Project) for the span.
          example = mommy.make("Example", project=self.project.item, text="example")
          from_entity = mommy.make("Span", example=example, start_offset=0, end_offset=1, user=self.project.admin)
          to_entity = mommy.make("Span", example=example, start_offset=1, end_offset=2, user=self.project.admin)
          relation = mommy.make(
              "Relation", from_id=from_entity, to_id=to_entity, example=example, user=self.project.admin
          )
          repository = RelationExtractionRepository(self.project.item)
          # NOTE(review): the top-level key "admin" assumes that is the username of
          # the fixture's admin user — confirm against prepare_project.
          expected = {
              "admin": {
                  "entities": [
                      {
                          "id": from_entity.id,
                          "start_offset": from_entity.start_offset,
                          "end_offset": from_entity.end_offset,
                          "label": from_entity.label.text,
                      },
                      {
                          "id": to_entity.id,
                          "start_offset": to_entity.start_offset,
                          "end_offset": to_entity.end_offset,
                          "label": to_entity.label.text,
                      },
                  ],
                  "relations": [
                      {"id": relation.id, "from_id": from_entity.id, "to_id": to_entity.id, "type": relation.type.text}
                  ],
              }
          }
          actual = repository.label_per_user(example)
          self.assertDictEqual(actual, expected)