diff --git a/backend/data_import/pipeline/labeled_examples.py b/backend/data_import/pipeline/labeled_examples.py index ac53755d..a6257072 100644 --- a/backend/data_import/pipeline/labeled_examples.py +++ b/backend/data_import/pipeline/labeled_examples.py @@ -116,15 +116,16 @@ class RelationExamples(LabeledExamples): [data.create_label(user, example, mapping, SpanLabel) for data, example in zip(self.records, examples)] ) ) + uuids = [label.uuid for label in labels] Span.objects.bulk_create(labels) # filter spans by uuid original_spans = list( itertools.chain.from_iterable(example.select_label(SpanLabel) for example in self.records) ) - spans = Span.objects.filter(uuid__in=[span.uuid for span in original_spans]) + uuid_to_span = {span.uuid: span for span in Span.objects.filter(uuid__in=uuids)} # create mapping from id to span # this is needed to create the relation - span_mapping = {original_span.id: saved_span for saved_span, original_span in zip(spans, original_spans)} + span_mapping = {span.id: uuid_to_span[span.uuid] for span in original_spans} # then, replace from_id and to_id with the span id relations = itertools.chain.from_iterable( [ diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py index be1275a4..cecb10ad 100644 --- a/backend/data_import/tests/test_tasks.py +++ b/backend/data_import/tests/test_tasks.py @@ -224,6 +224,7 @@ class TestImportRelationExtractionData(TestImportData): spans = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()] self.assertEqual(spans, expected_spans) self.assertEqual(example.relations.count(), 3) + print(example.relations.all()) def assert_parse_error(self, response): self.assertGreaterEqual(len(response["error"]), 1) diff --git a/backend/labels/models.py b/backend/labels/models.py index c07de8ff..e5336d56 100644 --- a/backend/labels/models.py +++ b/backend/labels/models.py @@ -45,6 +45,10 @@ class Span(Label): start_offset = models.IntegerField() end_offset = models.IntegerField() + def __str__(self): + text = self.example.text[self.start_offset:self.end_offset] + return f"({text}, {self.start_offset}, {self.end_offset}, {self.label.text})" + def validate_unique(self, exclude=None): allow_overlapping = getattr(self.example.project, "allow_overlapping", False) is_collaborative = self.example.project.collaborative_annotation @@ -107,7 +111,11 @@ class Relation(Label): example = models.ForeignKey(to=Example, on_delete=models.CASCADE, related_name="relations") def __str__(self): - return self.__dict__.__str__() + text = self.example.text + from_span = text[self.from_id.start_offset: self.from_id.end_offset] + to_span = text[self.to_id.start_offset: self.to_id.end_offset] + type_text = self.type.text + return f"{from_span} - ({type_text}) -> {to_span}" def save(self, force_insert=False, force_update=False, using=None, update_fields=None): self.full_clean()