diff --git a/app/api/views/download/repositories.py b/app/api/views/download/repositories.py index 82d25ac8..e4ccbef4 100644 --- a/app/api/views/download/repositories.py +++ b/app/api/views/download/repositories.py @@ -1,6 +1,7 @@ import abc +import itertools from collections import defaultdict -from typing import Dict, Iterator +from typing import Dict, Iterator, List from ...models import Project from .data import Record @@ -29,6 +30,8 @@ class TextRepository(BaseRepository): for doc in docs: label_per_user = self.label_per_user(doc) + if self.project.collaborative_annotation: + label_per_user = self.reduce_user(label_per_user) for user, label in label_per_user.items(): yield Record( id=doc.id, @@ -42,6 +45,10 @@ class TextRepository(BaseRepository): def label_per_user(self, doc) -> Dict: raise NotImplementedError() + def reduce_user(self, label_per_user: Dict[str, List]): + value = list(itertools.chain(*label_per_user.values())) + return {'all': value} + class TextClassificationRepository(TextRepository):