From eaad9e41e738712c3af3400d5a08fe94d14dcda5 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Fri, 25 Feb 2022 11:01:00 +0900 Subject: [PATCH] Add RelationExtractionRepository --- backend/data_export/pipeline/repositories.py | 36 +++++++++++++++ .../data_export/tests/test_repositories.py | 45 ++++++++++++++++++- 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/backend/data_export/pipeline/repositories.py b/backend/data_export/pipeline/repositories.py index 10174061..227e82f4 100644 --- a/backend/data_export/pipeline/repositories.py +++ b/backend/data_export/pipeline/repositories.py @@ -126,6 +126,42 @@ class SequenceLabelingRepository(TextRepository): return label_per_user +class RelationExtractionRepository(TextRepository): + @property + def docs(self): + return Example.objects.filter(project=self.project).prefetch_related( + "spans__user", "spans__label", "relations__user", "relations__type" + ) + + def label_per_user(self, doc) -> Dict: + relation_per_user: Dict = defaultdict(list) + span_per_user: Dict = defaultdict(list) + label_per_user: Dict = defaultdict(dict) + for relation in doc.relations.all(): + relation_per_user[relation.user.username].append( + { + "id": relation.id, + "from_id": relation.from_id.id, + "to_id": relation.to_id.id, + "type": relation.type.text, + } + ) + for span in doc.spans.all(): + span_per_user[span.user.username].append( + { + "id": span.id, + "start_offset": span.start_offset, + "end_offset": span.end_offset, + "label": span.label.text, + } + ) + for user, relations in relation_per_user.items(): + label_per_user[user]["relations"] = relations + for user, span in span_per_user.items(): + label_per_user[user]["entities"] = span + return label_per_user + + class Seq2seqRepository(TextRepository): @property def docs(self): diff --git a/backend/data_export/tests/test_repositories.py b/backend/data_export/tests/test_repositories.py index fa32b7ec..d7654ecc 100644 --- a/backend/data_export/tests/test_repositories.py +++ b/backend/data_export/tests/test_repositories.py @@ -2,8 +2,11 @@ import unittest from model_mommy import mommy -from ..pipeline.repositories import IntentDetectionSlotFillingRepository -from projects.models import INTENT_DETECTION_AND_SLOT_FILLING +from ..pipeline.repositories import ( + IntentDetectionSlotFillingRepository, + RelationExtractionRepository, +) +from projects.models import INTENT_DETECTION_AND_SLOT_FILLING, SEQUENCE_LABELING from projects.tests.utils import prepare_project @@ -31,3 +34,41 @@ class TestCSVWriter(unittest.TestCase): self.assertEqual(record.data, expect["data"]) self.assertEqual(record.label["cats"], expect["label"]["cats"]) self.assertEqual(record.label["entities"], expect["label"]["entities"]) + + +class TestRelationExtractionRepository(unittest.TestCase): + def setUp(self): + self.project = prepare_project(SEQUENCE_LABELING, use_relation=True) + + def test_label_per_user(self): + from_entity = mommy.make("Span", start_offset=0, end_offset=1, user=self.project.admin) + to_entity = mommy.make( + "Span", start_offset=1, end_offset=2, example=from_entity.example, user=self.project.admin + ) + relation = mommy.make( + "Relation", from_id=from_entity, to_id=to_entity, example=from_entity.example, user=self.project.admin + ) + repository = RelationExtractionRepository(self.project.item) + expected = { + "admin": { + "entities": [ + { + "id": from_entity.id, + "start_offset": from_entity.start_offset, + "end_offset": from_entity.end_offset, + "label": from_entity.label.text, + }, + { + "id": to_entity.id, + "start_offset": to_entity.start_offset, + "end_offset": to_entity.end_offset, + "label": to_entity.label.text, + }, + ], + "relations": [ + {"id": relation.id, "from_id": from_entity.id, "to_id": to_entity.id, "type": relation.type.text} + ], + } + } + actual = repository.label_per_user(from_entity.example) + self.assertDictEqual(actual, expected)