|
|
@ -2,8 +2,11 @@ import unittest |
|
|
|
|
|
|
|
from model_mommy import mommy |
|
|
|
|
|
|
|
from ..pipeline.repositories import IntentDetectionSlotFillingRepository |
|
|
|
from projects.models import INTENT_DETECTION_AND_SLOT_FILLING |
|
|
|
from ..pipeline.repositories import ( |
|
|
|
IntentDetectionSlotFillingRepository, |
|
|
|
RelationExtractionRepository, |
|
|
|
) |
|
|
|
from projects.models import INTENT_DETECTION_AND_SLOT_FILLING, SEQUENCE_LABELING |
|
|
|
from projects.tests.utils import prepare_project |
|
|
|
|
|
|
|
|
|
|
@ -31,3 +34,41 @@ class TestCSVWriter(unittest.TestCase): |
|
|
|
self.assertEqual(record.data, expect["data"]) |
|
|
|
self.assertEqual(record.label["cats"], expect["label"]["cats"]) |
|
|
|
self.assertEqual(record.label["entities"], expect["label"]["entities"]) |
|
|
|
|
|
|
|
|
|
|
|
class TestRelationExtractionRepository(unittest.TestCase): |
|
|
|
def setUp(self): |
|
|
|
self.project = prepare_project(SEQUENCE_LABELING, use_relation=True) |
|
|
|
|
|
|
|
def test_label_per_user(self): |
|
|
|
from_entity = mommy.make("Span", start_offset=0, end_offset=1, user=self.project.admin) |
|
|
|
to_entity = mommy.make( |
|
|
|
"Span", start_offset=1, end_offset=2, example=from_entity.example, user=self.project.admin |
|
|
|
) |
|
|
|
relation = mommy.make( |
|
|
|
"Relation", from_id=from_entity, to_id=to_entity, example=from_entity.example, user=self.project.admin |
|
|
|
) |
|
|
|
repository = RelationExtractionRepository(self.project.item) |
|
|
|
expected = { |
|
|
|
"admin": { |
|
|
|
"entities": [ |
|
|
|
{ |
|
|
|
"id": from_entity.id, |
|
|
|
"start_offset": from_entity.start_offset, |
|
|
|
"end_offset": from_entity.end_offset, |
|
|
|
"label": from_entity.label.text, |
|
|
|
}, |
|
|
|
{ |
|
|
|
"id": to_entity.id, |
|
|
|
"start_offset": to_entity.start_offset, |
|
|
|
"end_offset": to_entity.end_offset, |
|
|
|
"label": to_entity.label.text, |
|
|
|
}, |
|
|
|
], |
|
|
|
"relations": [ |
|
|
|
{"id": relation.id, "from_id": from_entity.id, "to_id": to_entity.id, "type": relation.type.text} |
|
|
|
], |
|
|
|
} |
|
|
|
} |
|
|
|
actual = repository.label_per_user(from_entity.example) |
|
|
|
self.assertDictEqual(actual, expected) |