diff --git a/backend/data_export/pipeline/repositories.py b/backend/data_export/pipeline/repositories.py index 2d7f8db7..e50eafe3 100644 --- a/backend/data_export/pipeline/repositories.py +++ b/backend/data_export/pipeline/repositories.py @@ -1,7 +1,7 @@ import abc import itertools from collections import defaultdict -from typing import Dict, Iterator, List, Tuple, Union +from typing import Any, Dict, Iterator, List, Tuple, Union from .data import Record from examples.models import Example @@ -54,7 +54,7 @@ class FileRepository(BaseRepository): label_per_user[a.user.username].append(a.label.text) return label_per_user - def reduce_user(self, label_per_user: Dict[str, List]): + def reduce_user(self, label_per_user: Dict[str, Any]): value = list(itertools.chain(*label_per_user.values())) return {"all": value} @@ -80,11 +80,7 @@ class TextRepository(BaseRepository): for doc in docs: label_per_user = self.label_per_user(doc) if self.project.collaborative_annotation: - if getattr(self.project, "use_relation", False): - value_type = "dict" - else: - value_type = "list" - label_per_user = self.reduce_user(label_per_user, value_type) + label_per_user = self.reduce_user(label_per_user) for user, label in label_per_user.items(): yield Record(data_id=doc.id, data=doc.text, label=label, user=user, metadata=doc.meta) # todo: @@ -100,17 +96,9 @@ class TextRepository(BaseRepository): def label_per_user(self, doc) -> Dict: raise NotImplementedError() - def reduce_user(self, label_per_user: Dict, value_type): - if value_type == "list": - value_list = list(itertools.chain(*label_per_user)) - return {"all": value_list} - if value_type == "dict": - value_dict = dict( - (label_type, label_per_user[user][label_type]) - for user in label_per_user - for label_type in label_per_user[user] - ) - return {"all": value_dict} + def reduce_user(self, label_per_user: Dict[str, Any]): + value = list(itertools.chain(*label_per_user.values())) + return {"all": value} class TextClassificationRepository(TextRepository): @@ -173,6 +161,14 @@ class RelationExtractionRepository(TextRepository): label_per_user[user]["entities"] = span return label_per_user + def reduce_user(self, label_per_user: Dict[str, Any]): + entities = [] + relations = [] + for user, label in label_per_user.items(): + entities.extend(label.get("entities", [])) + relations.extend(label.get("relations", [])) + return {"all": {"entities": entities, "relations": relations}} + class Seq2seqRepository(TextRepository): @property diff --git a/backend/data_export/tests/test_repositories.py b/backend/data_export/tests/test_repositories.py index d7654ecc..0ff25287 100644 --- a/backend/data_export/tests/test_repositories.py +++ b/backend/data_export/tests/test_repositories.py @@ -5,70 +5,176 @@ from model_mommy import mommy from ..pipeline.repositories import ( IntentDetectionSlotFillingRepository, RelationExtractionRepository, + SequenceLabelingRepository, ) from projects.models import INTENT_DETECTION_AND_SLOT_FILLING, SEQUENCE_LABELING from projects.tests.utils import prepare_project -class TestCSVWriter(unittest.TestCase): - def setUp(self): - self.project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING) +class TestRepository(unittest.TestCase): + def assert_records(self, repository, expected): + records = list(repository.list()) + self.assertEqual(len(records), len(expected)) + for record, expect in zip(records, expected): + self.assertEqual(record.data, expect["data"]) + self.assertEqual(record.label, expect["label"]) + self.assertEqual(record.user, expect["user"]) + + +class TestIntentDetectionSlotFillingRepository(TestRepository): + def prepare_data(self, project): + self.example = mommy.make("Example", project=project.item, text="example") + self.category1 = mommy.make("Category", example=self.example, user=project.admin) + self.category2 = mommy.make("Category", example=self.example, user=project.annotator) + self.span = mommy.make("Span", example=self.example, user=project.admin, start_offset=0, end_offset=1) def test_list(self): - example = mommy.make("Example", project=self.project.item, text="example") - category = mommy.make("Category", example=example, user=self.project.admin) - span = mommy.make("Span", example=example, user=self.project.admin, start_offset=0, end_offset=1) - repository = IntentDetectionSlotFillingRepository(self.project.item) + project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING) + repository = IntentDetectionSlotFillingRepository(project.item) + self.prepare_data(project) expected = [ { - "data": example.text, + "data": self.example.text, + "label": { + "cats": [self.category1.label.text], + "entities": [(self.span.start_offset, self.span.end_offset, self.span.label.text)], + }, + "user": project.admin.username, + }, + { + "data": self.example.text, + "label": { + "cats": [self.category2.label.text], + "entities": [], + }, + "user": project.annotator.username, + }, + ] + self.assert_records(repository, expected) + + def test_list_on_collaborative_annotation(self): + project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING, collaborative_annotation=True) + repository = IntentDetectionSlotFillingRepository(project.item) + self.prepare_data(project) + expected = [ + { + "data": self.example.text, "label": { - "cats": [category.label.text], - "entities": [(span.start_offset, span.end_offset, span.label.text)], + "cats": [self.category1.label.text, self.category2.label.text], + "entities": [(self.span.start_offset, self.span.end_offset, self.span.label.text)], }, + "user": "all", } ] - records = list(repository.list()) - self.assertEqual(len(records), len(expected)) - for record, expect in zip(records, expected): - self.assertEqual(record.data, expect["data"]) - self.assertEqual(record.label["cats"], expect["label"]["cats"]) - self.assertEqual(record.label["entities"], expect["label"]["entities"]) + self.assert_records(repository, expected) -class TestRelationExtractionRepository(unittest.TestCase): - def setUp(self): - self.project = prepare_project(SEQUENCE_LABELING, use_relation=True) +class TestSequenceLabelingRepository(TestRepository): + def prepare_data(self, project): + self.example = mommy.make("Example", project=project.item, text="example") + self.span1 = mommy.make("Span", example=self.example, user=project.admin, start_offset=0, end_offset=1) + self.span2 = mommy.make("Span", example=self.example, user=project.annotator, start_offset=1, end_offset=2) - def test_label_per_user(self): - from_entity = mommy.make("Span", start_offset=0, end_offset=1, user=self.project.admin) - to_entity = mommy.make( - "Span", start_offset=1, end_offset=2, example=from_entity.example, user=self.project.admin - ) - relation = mommy.make( - "Relation", from_id=from_entity, to_id=to_entity, example=from_entity.example, user=self.project.admin - ) - repository = RelationExtractionRepository(self.project.item) - expected = { - "admin": { - "entities": [ - { - "id": from_entity.id, - "start_offset": from_entity.start_offset, - "end_offset": from_entity.end_offset, - "label": from_entity.label.text, - }, - { - "id": to_entity.id, - "start_offset": to_entity.start_offset, - "end_offset": to_entity.end_offset, - "label": to_entity.label.text, - }, - ], - "relations": [ - {"id": relation.id, "from_id": from_entity.id, "to_id": to_entity.id, "type": relation.type.text} + def test_list(self): + project = prepare_project(SEQUENCE_LABELING) + repository = SequenceLabelingRepository(project) + self.prepare_data(project) + expected = [ + { + "data": self.example.text, + "label": [(self.span1.start_offset, self.span1.end_offset, self.span1.label.text)], + "user": project.admin.username, + }, + { + "data": self.example.text, + "label": [(self.span2.start_offset, self.span2.end_offset, self.span2.label.text)], + "user": project.annotator.username, + }, + ] + self.assert_records(repository, expected) + + def test_list_on_collaborative_annotation(self): + project = prepare_project(SEQUENCE_LABELING, collaborative_annotation=True) + repository = SequenceLabelingRepository(project) + self.prepare_data(project) + expected = [ + { + "data": self.example.text, + "label": [ + (self.span1.start_offset, self.span1.end_offset, self.span1.label.text), + (self.span2.start_offset, self.span2.end_offset, self.span2.label.text), ], + "user": "all", + } + ] + self.assert_records(repository, expected) + + +class TestRelationExtractionRepository(TestRepository): + def test_list(self): + project = prepare_project(SEQUENCE_LABELING, use_relation=True) + example = mommy.make("Example", project=project.item, text="example") + span1 = mommy.make("Span", example=example, user=project.admin, start_offset=0, end_offset=1) + span2 = mommy.make("Span", example=example, user=project.admin, start_offset=1, end_offset=2) + relation = mommy.make("Relation", from_id=span1, to_id=span2, example=example, user=project.admin) + repository = RelationExtractionRepository(project.item) + expected = [ + { + "data": example.text, + "label": { + "entities": [ + { + "id": span1.id, + "start_offset": span1.start_offset, + "end_offset": span1.end_offset, + "label": span1.label.text, + }, + { + "id": span2.id, + "start_offset": span2.start_offset, + "end_offset": span2.end_offset, + "label": span2.label.text, + }, + ], + "relations": [ + {"id": relation.id, "from_id": span1.id, "to_id": span2.id, "type": relation.type.text} + ], + }, + "user": project.admin.username, + } + ] + self.assert_records(repository, expected) + + def test_list_on_collaborative_annotation(self): + project = prepare_project(SEQUENCE_LABELING, collaborative_annotation=True, use_relation=True) + example = mommy.make("Example", project=project.item, text="example") + span1 = mommy.make("Span", example=example, user=project.admin, start_offset=0, end_offset=1) + span2 = mommy.make("Span", example=example, user=project.annotator, start_offset=1, end_offset=2) + relation = mommy.make("Relation", from_id=span1, to_id=span2, example=example, user=project.admin) + repository = RelationExtractionRepository(project.item) + expected = [ + { + "data": example.text, + "label": { + "entities": [ + { + "id": span1.id, + "start_offset": span1.start_offset, + "end_offset": span1.end_offset, + "label": span1.label.text, + }, + { + "id": span2.id, + "start_offset": span2.start_offset, + "end_offset": span2.end_offset, + "label": span2.label.text, + }, + ], + "relations": [ + {"id": relation.id, "from_id": span1.id, "to_id": span2.id, "type": relation.type.text} + ], + }, + "user": "all", } - } - actual = repository.label_per_user(from_entity.example) - self.assertDictEqual(actual, expected) + ] + self.assert_records(repository, expected)