From 8c899abb55dfd082686453f4299a35ab66aef498 Mon Sep 17 00:00:00 2001
From: Hironsan
Date: Wed, 18 May 2022 09:55:10 +0900
Subject: [PATCH] Avoid removing overlapping spans with another example uuids

Spans.clean() used to sort all labels together and drop every span that
started before the end of the previously kept span, so spans belonging
to different examples could remove each other. Resolve overlaps per
example_uuid instead. The labels are sorted by example_uuid before
grouping because itertools.groupby() only merges adjacent items.
---
 backend/data_import/pipeline/labels.py   | 17 +++++++++++------
 backend/data_import/tests/test_labels.py | 19 ++++++++++++++++++-
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/backend/data_import/pipeline/labels.py b/backend/data_import/pipeline/labels.py
index 6bac0446..5824c436 100644
--- a/backend/data_import/pipeline/labels.py
+++ b/backend/data_import/pipeline/labels.py
@@ -67,13 +67,18 @@ class Spans(Labels):
         allow_overlapping = getattr(project, "allow_overlapping", False)
         if allow_overlapping:
             return
-        self.labels.sort()
-        last_offset = -1
         spans = []
-        for label in self.labels:
-            if getattr(label, "start_offset") >= last_offset:
-                last_offset = getattr(label, "end_offset")
-                spans.append(label)
+        # groupby() only merges adjacent items, so bring each example's labels together first.
+        labels_by_example = sorted(self.labels, key=lambda label: label.example_uuid)
+        groups = groupby(labels_by_example, lambda label: label.example_uuid)
+        for _, group in groups:
+            labels = sorted(group)
+            last_offset = -1
+            for label in labels:
+                # Keep a span only if it starts at or after the end of the last kept span.
+                if getattr(label, "start_offset") >= last_offset:
+                    last_offset = getattr(label, "end_offset")
+                    spans.append(label)
         self.labels = spans
 
     @property
diff --git a/backend/data_import/tests/test_labels.py b/backend/data_import/tests/test_labels.py
index 4c71c71b..1f9c82ca 100644
--- a/backend/data_import/tests/test_labels.py
+++ b/backend/data_import/tests/test_labels.py
@@ -67,9 +67,12 @@ class TestSpans(TestCase):
         mommy.make("Example", project=self.project.item, uuid=example_uuid)
         self.spans = Spans(labels, self.types)
 
-    def test_clean(self):
+    def disable_overlapping(self):
         self.project.item.allow_overlapping = False
         self.project.item.save()
+
+    def test_clean(self):
+        self.disable_overlapping()
         self.spans.clean(self.project.item)
         self.assertEqual(len(self.spans), 2)
 
@@ -77,6 +80,20 @@ class TestSpans(TestCase):
         self.spans.clean(self.project.item)
         self.assertEqual(len(self.spans), 3)
 
+    def test_clean_with_multiple_examples(self):
+        self.disable_overlapping()
+        example_uuid1 = uuid.uuid4()
+        example_uuid2 = uuid.uuid4()
+        labels = [
+            SpanLabel(example_uuid=example_uuid1, label="A", start_offset=0, end_offset=1),
+            SpanLabel(example_uuid=example_uuid2, label="B", start_offset=0, end_offset=3),
+        ]
+        mommy.make("Example", project=self.project.item, uuid=example_uuid1)
+        mommy.make("Example", project=self.project.item, uuid=example_uuid2)
+        spans = Spans(labels, self.types)
+        spans.clean(self.project.item)
+        self.assertEqual(len(spans), 2)
+
     def test_save(self):
         self.spans.save_types(self.project.item)
         self.spans.save(self.user)