|
|
@ -1,6 +1,7 @@ |
|
|
|
import abc |
|
|
|
import itertools |
|
|
|
from collections import defaultdict |
|
|
|
from typing import Dict, Iterator |
|
|
|
from typing import Dict, Iterator, List |
|
|
|
|
|
|
|
from ...models import Project |
|
|
|
from .data import Record |
|
|
@ -29,6 +30,8 @@ class TextRepository(BaseRepository): |
|
|
|
|
|
|
|
for doc in docs: |
|
|
|
label_per_user = self.label_per_user(doc) |
|
|
|
if self.project.collaborative_annotation: |
|
|
|
label_per_user = self.reduce_user(label_per_user) |
|
|
|
for user, label in label_per_user.items(): |
|
|
|
yield Record( |
|
|
|
id=doc.id, |
|
|
@ -42,6 +45,10 @@ class TextRepository(BaseRepository): |
|
|
|
def label_per_user(self, doc) -> Dict: |
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
def reduce_user(self, label_per_user: Dict[str, List]): |
|
|
|
value = list(itertools.chain(*label_per_user.values())) |
|
|
|
return {'all': value} |
|
|
|
|
|
|
|
|
|
|
|
class TextClassificationRepository(TextRepository): |
|
|
|
|
|
|
|