diff --git a/backend/data_export/pipeline/repositories.py b/backend/data_export/pipeline/repositories.py index a45fc5b5..2d7f8db7 100644 --- a/backend/data_export/pipeline/repositories.py +++ b/backend/data_export/pipeline/repositories.py @@ -80,7 +80,11 @@ class TextRepository(BaseRepository): for doc in docs: label_per_user = self.label_per_user(doc) if self.project.collaborative_annotation: - label_per_user = self.reduce_user(label_per_user) + if getattr(self.project, "use_relation", False): + value_type = "dict" + else: + value_type = "list" + label_per_user = self.reduce_user(label_per_user, value_type) for user, label in label_per_user.items(): yield Record(data_id=doc.id, data=doc.text, label=label, user=user, metadata=doc.meta) # todo: @@ -96,9 +100,17 @@ class TextRepository(BaseRepository): def label_per_user(self, doc) -> Dict: raise NotImplementedError() - def reduce_user(self, label_per_user: Dict[str, List]): - value = list(itertools.chain(*label_per_user.values())) - return {"all": value} + def reduce_user(self, label_per_user: Dict, value_type): + if value_type == "list": + value_list = list(itertools.chain(*label_per_user)) + return {"all": value_list} + if value_type == "dict": + value_dict = dict( + (label_type, label_per_user[user][label_type]) + for user in label_per_user + for label_type in label_per_user[user] + ) + return {"all": value_dict} class TextClassificationRepository(TextRepository):