Browse Source

Add RelationExtractionRepository

pull/1703/head
Hironsan 3 years ago
parent
commit
eaad9e41e7
2 changed files with 79 additions and 2 deletions
  1. 36
      backend/data_export/pipeline/repositories.py
  2. 45
      backend/data_export/tests/test_repositories.py

36
backend/data_export/pipeline/repositories.py

@ -126,6 +126,42 @@ class SequenceLabelingRepository(TextRepository):
return label_per_user
class RelationExtractionRepository(TextRepository):
    """Repository exposing a project's examples with their spans and relations."""

    @property
    def docs(self):
        """Return the project's examples with span and relation data prefetched."""
        return Example.objects.filter(project=self.project).prefetch_related(
            "spans__user", "spans__label", "relations__user", "relations__type"
        )

    def label_per_user(self, doc) -> Dict:
        """Group the relations and spans of ``doc`` by the annotating user's username.

        Args:
            doc: An example exposing ``relations`` and ``spans`` related managers.

        Returns:
            A mapping ``username -> labels`` where ``labels`` may contain a
            ``"relations"`` list and/or an ``"entities"`` list. A key is only
            present when the user owns at least one label of that kind.
        """
        relation_per_user: Dict = defaultdict(list)
        span_per_user: Dict = defaultdict(list)
        label_per_user: Dict = defaultdict(dict)
        for relation in doc.relations.all():
            relation_per_user[relation.user.username].append(
                {
                    "id": relation.id,
                    # Use the stored foreign-key values: the endpoint spans are
                    # not among the prefetched paths, so dereferencing
                    # ``from_id``/``to_id`` would load each span individually.
                    "from_id": relation.from_id_id,
                    "to_id": relation.to_id_id,
                    "type": relation.type.text,
                }
            )
        for span in doc.spans.all():
            span_per_user[span.user.username].append(
                {
                    "id": span.id,
                    "start_offset": span.start_offset,
                    "end_offset": span.end_offset,
                    "label": span.label.text,
                }
            )
        # Relations are merged first so the per-user key order stays
        # "relations" then "entities".
        for user, relations in relation_per_user.items():
            label_per_user[user]["relations"] = relations
        for user, spans in span_per_user.items():
            label_per_user[user]["entities"] = spans
        return label_per_user
class Seq2seqRepository(TextRepository):
@property
def docs(self):

45
backend/data_export/tests/test_repositories.py

@ -2,8 +2,11 @@ import unittest
from model_mommy import mommy
from ..pipeline.repositories import IntentDetectionSlotFillingRepository
from projects.models import INTENT_DETECTION_AND_SLOT_FILLING
from ..pipeline.repositories import (
IntentDetectionSlotFillingRepository,
RelationExtractionRepository,
)
from projects.models import INTENT_DETECTION_AND_SLOT_FILLING, SEQUENCE_LABELING
from projects.tests.utils import prepare_project
@ -31,3 +34,41 @@ class TestCSVWriter(unittest.TestCase):
self.assertEqual(record.data, expect["data"])
self.assertEqual(record.label["cats"], expect["label"]["cats"])
self.assertEqual(record.label["entities"], expect["label"]["entities"])
class TestRelationExtractionRepository(unittest.TestCase):
    """Exercise the per-user label grouping of RelationExtractionRepository."""

    def setUp(self):
        # Sequence labeling project created with relation support switched on.
        self.project = prepare_project(SEQUENCE_LABELING, use_relation=True)

    @staticmethod
    def _entity_payload(span):
        """Build the dictionary the repository is expected to emit for ``span``."""
        return {
            "id": span.id,
            "start_offset": span.start_offset,
            "end_offset": span.end_offset,
            "label": span.label.text,
        }

    def test_label_per_user(self):
        """Two spans joined by one relation are grouped under their annotator."""
        head = mommy.make("Span", start_offset=0, end_offset=1, user=self.project.admin)
        tail = mommy.make(
            "Span",
            start_offset=1,
            end_offset=2,
            example=head.example,
            user=self.project.admin,
        )
        link = mommy.make(
            "Relation",
            from_id=head,
            to_id=tail,
            example=head.example,
            user=self.project.admin,
        )
        repository = RelationExtractionRepository(self.project.item)

        relation_payload = {
            "id": link.id,
            "from_id": head.id,
            "to_id": tail.id,
            "type": link.type.text,
        }
        expected = {
            "admin": {
                "entities": [self._entity_payload(head), self._entity_payload(tail)],
                "relations": [relation_payload],
            }
        }

        actual = repository.label_per_user(head.example)
        self.assertDictEqual(actual, expected)
Loading…
Cancel
Save