From 875c1fa0a736afe684a9f6c4b78e06254aa73ded Mon Sep 17 00:00:00 2001 From: Hironsan Date: Thu, 15 Dec 2022 16:02:11 +0900 Subject: [PATCH] Update the way to import relational dataset --- backend/data_import/pipeline/label.py | 4 ++-- backend/data_import/pipeline/labels.py | 10 +++++----- backend/data_import/tests/test_label.py | 8 ++++---- backend/data_import/tests/test_labels.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/backend/data_import/pipeline/label.py b/backend/data_import/pipeline/label.py index bdff2d43..501a96e4 100644 --- a/backend/data_import/pipeline/label.py +++ b/backend/data_import/pipeline/label.py @@ -142,6 +142,6 @@ class RelationLabel(Label): user=user, example=example, type=types[self.type], - from_id=kwargs["id_to_span"][self.from_id], - to_id=kwargs["id_to_span"][self.to_id], + from_id=kwargs["id_to_span"][(self.from_id, str(self.example_uuid))], + to_id=kwargs["id_to_span"][(self.to_id, str(self.example_uuid))], ) diff --git a/backend/data_import/pipeline/labels.py b/backend/data_import/pipeline/labels.py index 97045526..524bf36f 100644 --- a/backend/data_import/pipeline/labels.py +++ b/backend/data_import/pipeline/labels.py @@ -1,6 +1,6 @@ import abc from itertools import groupby -from typing import Dict, List +from typing import Dict, List, Tuple from .examples import Examples from .label import Label @@ -70,11 +70,11 @@ class Spans(Labels): self.labels = spans @property - def id_to_span(self) -> Dict[int, SpanModel]: - span_uuids = [str(label.uuid) for label in self.labels] - spans = SpanModel.objects.filter(uuid__in=span_uuids) + def id_to_span(self) -> Dict[Tuple[int, str], SpanModel]: + uuids = [str(span.uuid) for span in self.labels] + spans = SpanModel.objects.filter(uuid__in=uuids) uuid_to_span = {span.uuid: span for span in spans} - return {span.id: uuid_to_span[span.uuid] for span in self.labels} + return {(span.id, str(span.example_uuid)): uuid_to_span[span.uuid] for span in self.labels} class Texts(Labels): diff --git a/backend/data_import/tests/test_label.py b/backend/data_import/tests/test_label.py index da79ca49..362c2230 100644 --- a/backend/data_import/tests/test_label.py +++ b/backend/data_import/tests/test_label.py @@ -25,7 +25,7 @@ class TestLabel(TestCase): def setUp(self): self.project = prepare_project(self.task) self.user = self.project.admin - self.example = mommy.make("Example", project=self.project.item) + self.example = mommy.make("Example", project=self.project.item, text="hello world") class TestCategoryLabel(TestLabel): @@ -166,12 +166,12 @@ class TestRelationLabel(TestLabel): self.assertEqual(relation_type.text, "A") def test_create(self): - relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4()) + relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=self.example.uuid) types = MagicMock() types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item) id_to_span = { - 0: mommy.make(SpanModel, start_offset=0, end_offset=1), - 1: mommy.make(SpanModel, start_offset=2, end_offset=3), + (0, str(self.example.uuid)): mommy.make(SpanModel, start_offset=0, end_offset=1, example=self.example), + (1, str(self.example.uuid)): mommy.make(SpanModel, start_offset=2, end_offset=3, example=self.example), } relation_model = relation.create(self.user, self.example, types, id_to_span=id_to_span) self.assertIsInstance(relation_model, RelationModel) diff --git a/backend/data_import/tests/test_labels.py b/backend/data_import/tests/test_labels.py index 6b3a27ea..3ed395ea 100644 --- a/backend/data_import/tests/test_labels.py +++ b/backend/data_import/tests/test_labels.py @@ -146,7 +146,7 @@ class TestRelations(TestCase): self.project = prepare_project(SEQUENCE_LABELING, use_relation=True) self.user = self.project.admin example_uuid = uuid.uuid4() - example = mommy.make("Example", project=self.project.item, uuid=example_uuid) + example = mommy.make("Example", project=self.project.item, uuid=example_uuid, text="hello world") from_span = mommy.make("Span", example=example, start_offset=0, end_offset=1) to_span = mommy.make("Span", example=example, start_offset=2, end_offset=3) labels = [ @@ -154,7 +154,7 @@ class TestRelations(TestCase): ] self.relations = Relations(labels, self.types) self.spans = MagicMock() - self.spans.id_to_span = {from_span.id: from_span, to_span.id: to_span} + self.spans.id_to_span = {(from_span.id, str(example_uuid)): from_span, (to_span.id, str(example_uuid)): to_span} self.examples = MagicMock() self.examples.__getitem__.return_value = example self.examples.__contains__.return_value = True