Browse Source

Add reduce_user to RelationExtractionRepository

pull/1764/head
Hironsan 3 years ago
parent
commit
5d39a7004e
2 changed files with 169 additions and 67 deletions
  1. 32
      backend/data_export/pipeline/repositories.py
  2. 204
      backend/data_export/tests/test_repositories.py

32
backend/data_export/pipeline/repositories.py

@ -1,7 +1,7 @@
import abc
import itertools
from collections import defaultdict
from typing import Dict, Iterator, List, Tuple, Union
from typing import Any, Dict, Iterator, List, Tuple, Union
from .data import Record
from examples.models import Example
@ -54,7 +54,7 @@ class FileRepository(BaseRepository):
label_per_user[a.user.username].append(a.label.text)
return label_per_user
def reduce_user(self, label_per_user: Dict[str, List]):
def reduce_user(self, label_per_user: Dict[str, Any]):
    """Collapse the per-user labels into one list stored under the key "all".

    Args:
        label_per_user: mapping whose values are iterables of labels.

    Returns:
        A single-entry dict ``{"all": [...]}`` holding every label, in the
        iteration order of the mapping's values.
    """
    per_user_labels = label_per_user.values()
    flattened = list(itertools.chain.from_iterable(per_user_labels))
    return {"all": flattened}
@ -80,11 +80,7 @@ class TextRepository(BaseRepository):
for doc in docs:
label_per_user = self.label_per_user(doc)
if self.project.collaborative_annotation:
if getattr(self.project, "use_relation", False):
value_type = "dict"
else:
value_type = "list"
label_per_user = self.reduce_user(label_per_user, value_type)
label_per_user = self.reduce_user(label_per_user)
for user, label in label_per_user.items():
yield Record(data_id=doc.id, data=doc.text, label=label, user=user, metadata=doc.meta)
# todo:
@ -100,17 +96,9 @@ class TextRepository(BaseRepository):
def label_per_user(self, doc) -> Dict:
raise NotImplementedError()
def reduce_user(self, label_per_user: Dict, value_type):
if value_type == "list":
value_list = list(itertools.chain(*label_per_user))
return {"all": value_list}
if value_type == "dict":
value_dict = dict(
(label_type, label_per_user[user][label_type])
for user in label_per_user
for label_type in label_per_user[user]
)
return {"all": value_dict}
def reduce_user(self, label_per_user: Dict[str, Any]):
    """Concatenate every user's labels into a single list keyed by "all".

    Args:
        label_per_user: mapping whose values are iterables of labels.

    Returns:
        ``{"all": merged}`` where ``merged`` lists all labels in the
        iteration order of the mapping's values.
    """
    merged = []
    for labels in label_per_user.values():
        merged.extend(labels)
    return {"all": merged}
class TextClassificationRepository(TextRepository):
@ -173,6 +161,14 @@ class RelationExtractionRepository(TextRepository):
label_per_user[user]["entities"] = span
return label_per_user
def reduce_user(self, label_per_user: Dict[str, Any]):
    """Merge every user's entities and relations into a single "all" entry.

    Args:
        label_per_user: mapping whose values are dicts that may contain
            an "entities" list and/or a "relations" list.

    Returns:
        ``{"all": {"entities": [...], "relations": [...]}}`` with the lists
        concatenated in the iteration order of the mapping's values. A
        missing "entities" or "relations" key contributes nothing.
    """
    entities = []
    relations = []
    # The keys are not needed once labels are merged, so iterate the values only.
    for label in label_per_user.values():
        entities.extend(label.get("entities", []))
        relations.extend(label.get("relations", []))
    return {"all": {"entities": entities, "relations": relations}}
class Seq2seqRepository(TextRepository):
@property

204
backend/data_export/tests/test_repositories.py

@ -5,70 +5,176 @@ from model_mommy import mommy
from ..pipeline.repositories import (
IntentDetectionSlotFillingRepository,
RelationExtractionRepository,
SequenceLabelingRepository,
)
from projects.models import INTENT_DETECTION_AND_SLOT_FILLING, SEQUENCE_LABELING
from projects.tests.utils import prepare_project
class TestCSVWriter(unittest.TestCase):
def setUp(self):
self.project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING)
class TestRepository(unittest.TestCase):
    """Base test case providing a record-comparison helper for repository tests."""

    def assert_records(self, repository, expected):
        """Assert that ``repository.list()`` yields exactly the expected records.

        Args:
            repository: object whose ``list()`` returns an iterable of records
                exposing ``data``, ``label`` and ``user`` attributes.
            expected: sequence of dicts with "data", "label" and "user" keys,
                in the order the records are expected.
        """
        actual = list(repository.list())
        self.assertEqual(len(actual), len(expected))
        for want, got in zip(expected, actual):
            for field in ("data", "label", "user"):
                self.assertEqual(getattr(got, field), want[field])
class TestIntentDetectionSlotFillingRepository(TestRepository):
def prepare_data(self, project):
    """Create the fixtures shared by the tests and store them on ``self``.

    Builds one example with a category from the admin, a category from the
    annotator, and a single admin span covering offsets 0-1.
    """
    example = mommy.make("Example", project=project.item, text="example")
    self.example = example
    self.category1 = mommy.make("Category", example=example, user=project.admin)
    self.category2 = mommy.make("Category", example=example, user=project.annotator)
    self.span = mommy.make(
        "Span",
        example=example,
        user=project.admin,
        start_offset=0,
        end_offset=1,
    )
def test_list(self):
example = mommy.make("Example", project=self.project.item, text="example")
category = mommy.make("Category", example=example, user=self.project.admin)
span = mommy.make("Span", example=example, user=self.project.admin, start_offset=0, end_offset=1)
repository = IntentDetectionSlotFillingRepository(self.project.item)
project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING)
repository = IntentDetectionSlotFillingRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": example.text,
"data": self.example.text,
"label": {
"cats": [self.category1.label.text],
"entities": [(self.span.start_offset, self.span.end_offset, self.span.label.text)],
},
"user": project.admin.username,
},
{
"data": self.example.text,
"label": {
"cats": [self.category2.label.text],
"entities": [],
},
"user": project.annotator.username,
},
]
self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING, collaborative_annotation=True)
repository = IntentDetectionSlotFillingRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.text,
"label": {
"cats": [category.label.text],
"entities": [(span.start_offset, span.end_offset, span.label.text)],
"cats": [self.category1.label.text, self.category2.label.text],
"entities": [(self.span.start_offset, self.span.end_offset, self.span.label.text)],
},
"user": "all",
}
]
records = list(repository.list())
self.assertEqual(len(records), len(expected))
for record, expect in zip(records, expected):
self.assertEqual(record.data, expect["data"])
self.assertEqual(record.label["cats"], expect["label"]["cats"])
self.assertEqual(record.label["entities"], expect["label"]["entities"])
self.assert_records(repository, expected)
class TestRelationExtractionRepository(unittest.TestCase):
def setUp(self):
self.project = prepare_project(SEQUENCE_LABELING, use_relation=True)
class TestSequenceLabelingRepository(TestRepository):
def prepare_data(self, project):
    """Create the fixtures shared by the tests and store them on ``self``.

    Builds one example with an admin span at offsets 0-1 and an annotator
    span at offsets 1-2.
    """
    example = mommy.make("Example", project=project.item, text="example")
    self.example = example
    self.span1 = mommy.make(
        "Span", example=example, user=project.admin, start_offset=0, end_offset=1
    )
    self.span2 = mommy.make(
        "Span", example=example, user=project.annotator, start_offset=1, end_offset=2
    )
def test_label_per_user(self):
from_entity = mommy.make("Span", start_offset=0, end_offset=1, user=self.project.admin)
to_entity = mommy.make(
"Span", start_offset=1, end_offset=2, example=from_entity.example, user=self.project.admin
)
relation = mommy.make(
"Relation", from_id=from_entity, to_id=to_entity, example=from_entity.example, user=self.project.admin
)
repository = RelationExtractionRepository(self.project.item)
expected = {
"admin": {
"entities": [
{
"id": from_entity.id,
"start_offset": from_entity.start_offset,
"end_offset": from_entity.end_offset,
"label": from_entity.label.text,
},
{
"id": to_entity.id,
"start_offset": to_entity.start_offset,
"end_offset": to_entity.end_offset,
"label": to_entity.label.text,
},
],
"relations": [
{"id": relation.id, "from_id": from_entity.id, "to_id": to_entity.id, "type": relation.type.text}
def test_list(self):
    """Without the collaborative flag, each user's spans form a separate record."""
    project = prepare_project(SEQUENCE_LABELING)
    # Pass the underlying project model (project.item), as the other
    # repository tests in this module do, rather than the prepare_project wrapper.
    repository = SequenceLabelingRepository(project.item)
    self.prepare_data(project)
    expected = [
        {
            "data": self.example.text,
            "label": [(self.span1.start_offset, self.span1.end_offset, self.span1.label.text)],
            "user": project.admin.username,
        },
        {
            "data": self.example.text,
            "label": [(self.span2.start_offset, self.span2.end_offset, self.span2.label.text)],
            "user": project.annotator.username,
        },
    ]
    self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
    """With collaborative annotation on, all users' spans merge into one "all" record."""
    project = prepare_project(SEQUENCE_LABELING, collaborative_annotation=True)
    # Pass the underlying project model (project.item), as the other
    # repository tests in this module do, rather than the prepare_project wrapper.
    repository = SequenceLabelingRepository(project.item)
    self.prepare_data(project)
    expected = [
        {
            "data": self.example.text,
            "label": [
                (self.span1.start_offset, self.span1.end_offset, self.span1.label.text),
                (self.span2.start_offset, self.span2.end_offset, self.span2.label.text),
            ],
            "user": "all",
        }
    ]
    self.assert_records(repository, expected)
class TestRelationExtractionRepository(TestRepository):
def test_list(self):
    """A single admin's two spans and their relation are exported as one record."""
    project = prepare_project(SEQUENCE_LABELING, use_relation=True)
    example = mommy.make("Example", project=project.item, text="example")
    span1 = mommy.make("Span", example=example, user=project.admin, start_offset=0, end_offset=1)
    span2 = mommy.make("Span", example=example, user=project.admin, start_offset=1, end_offset=2)
    relation = mommy.make("Relation", from_id=span1, to_id=span2, example=example, user=project.admin)
    repository = RelationExtractionRepository(project.item)

    # Build the expected payload piece by piece instead of one nested literal.
    entities = [
        {
            "id": span.id,
            "start_offset": span.start_offset,
            "end_offset": span.end_offset,
            "label": span.label.text,
        }
        for span in (span1, span2)
    ]
    relations = [
        {"id": relation.id, "from_id": span1.id, "to_id": span2.id, "type": relation.type.text}
    ]
    expected = [
        {
            "data": example.text,
            "label": {"entities": entities, "relations": relations},
            "user": project.admin.username,
        }
    ]
    self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
project = prepare_project(SEQUENCE_LABELING, collaborative_annotation=True, use_relation=True)
example = mommy.make("Example", project=project.item, text="example")
span1 = mommy.make("Span", example=example, user=project.admin, start_offset=0, end_offset=1)
span2 = mommy.make("Span", example=example, user=project.annotator, start_offset=1, end_offset=2)
relation = mommy.make("Relation", from_id=span1, to_id=span2, example=example, user=project.admin)
repository = RelationExtractionRepository(project.item)
expected = [
{
"data": example.text,
"label": {
"entities": [
{
"id": span1.id,
"start_offset": span1.start_offset,
"end_offset": span1.end_offset,
"label": span1.label.text,
},
{
"id": span2.id,
"start_offset": span2.start_offset,
"end_offset": span2.end_offset,
"label": span2.label.text,
},
],
"relations": [
{"id": relation.id, "from_id": span1.id, "to_id": span2.id, "type": relation.type.text}
],
},
"user": "all",
}
}
actual = repository.label_per_user(from_entity.example)
self.assertDictEqual(actual, expected)
]
self.assert_records(repository, expected)
Loading…
Cancel
Save