diff --git a/backend/data_export/tests/test_repositories.py b/backend/data_export/tests/test_repositories.py index 9202b047..a4780de4 100644 --- a/backend/data_export/tests/test_repositories.py +++ b/backend/data_export/tests/test_repositories.py @@ -5,12 +5,14 @@ from model_mommy import mommy from ..pipeline.repositories import ( IntentDetectionSlotFillingRepository, RelationExtractionRepository, + Seq2seqRepository, SequenceLabelingRepository, TextClassificationRepository, ) from projects.models import ( DOCUMENT_CLASSIFICATION, INTENT_DETECTION_AND_SLOT_FILLING, + SEQ2SEQ, SEQUENCE_LABELING, ) from projects.tests.utils import prepare_project @@ -64,6 +66,44 @@ class TestTextClassificationRepository(TestRepository): self.assert_records(repository, expected) +class TestSeq2seqRepository(TestRepository): + def prepare_data(self, project): + self.example = mommy.make("Example", project=project.item, text="example") + self.text1 = mommy.make("TextLabel", example=self.example, user=project.admin) + self.text2 = mommy.make("TextLabel", example=self.example, user=project.annotator) + + def test_list(self): + project = prepare_project(SEQ2SEQ) + repository = Seq2seqRepository(project.item) + self.prepare_data(project) + expected = [ + { + "data": self.example.text, + "label": [self.text1.text], + "user": project.admin.username, + }, + { + "data": self.example.text, + "label": [self.text2.text], + "user": project.annotator.username, + }, + ] + self.assert_records(repository, expected) + + def test_list_on_collaborative_annotation(self): + project = prepare_project(SEQ2SEQ, collaborative_annotation=True) + repository = Seq2seqRepository(project.item) + self.prepare_data(project) + expected = [ + { + "data": self.example.text, + "label": [self.text1.text, self.text2.text], + "user": "all", + } + ] + self.assert_records(repository, expected) + + class TestIntentDetectionSlotFillingRepository(TestRepository): def prepare_data(self, project): self.example = mommy.make("Example", project=project.item, text="example")