Browse Source

Merge pull request #2099 from doccano/fix/2039

Update the way to import relational dataset
pull/2112/head
Hiroki Nakayama 1 year ago
committed by GitHub
parent
commit
32319fe91f
No known key found for this signature in database. GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 13 deletions
  1. 4
      backend/data_import/pipeline/label.py
  2. 10
      backend/data_import/pipeline/labels.py
  3. 8
      backend/data_import/tests/test_label.py
  4. 4
      backend/data_import/tests/test_labels.py

4
backend/data_import/pipeline/label.py

@ -142,6 +142,6 @@ class RelationLabel(Label):
user=user, user=user,
example=example, example=example,
type=types[self.type], type=types[self.type],
from_id=kwargs["id_to_span"][self.from_id],
to_id=kwargs["id_to_span"][self.to_id],
from_id=kwargs["id_to_span"][(self.from_id, str(self.example_uuid))],
to_id=kwargs["id_to_span"][(self.to_id, str(self.example_uuid))],
) )

10
backend/data_import/pipeline/labels.py

@ -1,6 +1,6 @@
import abc import abc
from itertools import groupby from itertools import groupby
from typing import Dict, List
from typing import Dict, List, Tuple
from .examples import Examples from .examples import Examples
from .label import Label from .label import Label
@ -70,11 +70,11 @@ class Spans(Labels):
self.labels = spans self.labels = spans
@property @property
def id_to_span(self) -> Dict[int, SpanModel]:
span_uuids = [str(label.uuid) for label in self.labels]
spans = SpanModel.objects.filter(uuid__in=span_uuids)
def id_to_span(self) -> Dict[Tuple[int, str], SpanModel]:
uuids = [str(span.uuid) for span in self.labels]
spans = SpanModel.objects.filter(uuid__in=uuids)
uuid_to_span = {span.uuid: span for span in spans} uuid_to_span = {span.uuid: span for span in spans}
return {span.id: uuid_to_span[span.uuid] for span in self.labels}
return {(span.id, str(span.example_uuid)): uuid_to_span[span.uuid] for span in self.labels}
class Texts(Labels): class Texts(Labels):

8
backend/data_import/tests/test_label.py

@ -25,7 +25,7 @@ class TestLabel(TestCase):
def setUp(self): def setUp(self):
self.project = prepare_project(self.task) self.project = prepare_project(self.task)
self.user = self.project.admin self.user = self.project.admin
self.example = mommy.make("Example", project=self.project.item)
self.example = mommy.make("Example", project=self.project.item, text="hello world")
class TestCategoryLabel(TestLabel): class TestCategoryLabel(TestLabel):
@ -166,12 +166,12 @@ class TestRelationLabel(TestLabel):
self.assertEqual(relation_type.text, "A") self.assertEqual(relation_type.text, "A")
def test_create(self): def test_create(self):
relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=self.example.uuid)
types = MagicMock() types = MagicMock()
types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item) types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item)
id_to_span = { id_to_span = {
0: mommy.make(SpanModel, start_offset=0, end_offset=1),
1: mommy.make(SpanModel, start_offset=2, end_offset=3),
(0, str(self.example.uuid)): mommy.make(SpanModel, start_offset=0, end_offset=1, example=self.example),
(1, str(self.example.uuid)): mommy.make(SpanModel, start_offset=2, end_offset=3, example=self.example),
} }
relation_model = relation.create(self.user, self.example, types, id_to_span=id_to_span) relation_model = relation.create(self.user, self.example, types, id_to_span=id_to_span)
self.assertIsInstance(relation_model, RelationModel) self.assertIsInstance(relation_model, RelationModel)

4
backend/data_import/tests/test_labels.py

@ -146,7 +146,7 @@ class TestRelations(TestCase):
self.project = prepare_project(SEQUENCE_LABELING, use_relation=True) self.project = prepare_project(SEQUENCE_LABELING, use_relation=True)
self.user = self.project.admin self.user = self.project.admin
example_uuid = uuid.uuid4() example_uuid = uuid.uuid4()
example = mommy.make("Example", project=self.project.item, uuid=example_uuid)
example = mommy.make("Example", project=self.project.item, uuid=example_uuid, text="hello world")
from_span = mommy.make("Span", example=example, start_offset=0, end_offset=1) from_span = mommy.make("Span", example=example, start_offset=0, end_offset=1)
to_span = mommy.make("Span", example=example, start_offset=2, end_offset=3) to_span = mommy.make("Span", example=example, start_offset=2, end_offset=3)
labels = [ labels = [
@ -154,7 +154,7 @@ class TestRelations(TestCase):
] ]
self.relations = Relations(labels, self.types) self.relations = Relations(labels, self.types)
self.spans = MagicMock() self.spans = MagicMock()
self.spans.id_to_span = {from_span.id: from_span, to_span.id: to_span}
self.spans.id_to_span = {(from_span.id, str(example_uuid)): from_span, (to_span.id, str(example_uuid)): to_span}
self.examples = MagicMock() self.examples = MagicMock()
self.examples.__getitem__.return_value = example self.examples.__getitem__.return_value = example
self.examples.__contains__.return_value = True self.examples.__contains__.return_value = True

Loading…
Cancel
Save