diff --git a/backend/data_export/pipeline/factories.py b/backend/data_export/pipeline/factories.py index 9c930bca..44a38235 100644 --- a/backend/data_export/pipeline/factories.py +++ b/backend/data_export/pipeline/factories.py @@ -12,6 +12,8 @@ from projects.models import ( def create_repository(project): + if getattr(project, "use_relation", False): + return repositories.RelationExtractionRepository(project) mapping = { DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository, SEQUENCE_LABELING: repositories.SequenceLabelingRepository, @@ -33,6 +35,7 @@ def create_writer(file_format: str) -> Type[writers.BaseWriter]: catalog.JSONL.name: writers.JSONLWriter, catalog.FastText.name: writers.FastTextWriter, catalog.IntentAndSlot.name: writers.IntentAndSlotWriter, + catalog.JSONLRelation.name: writers.EntityAndRelationWriter, } if file_format not in mapping: ValueError(f"Invalid format: {file_format}")