diff --git a/backend/data_export/pipeline/repositories.py b/backend/data_export/pipeline/repositories.py index 1423e235..9734b193 100644 --- a/backend/data_export/pipeline/repositories.py +++ b/backend/data_export/pipeline/repositories.py @@ -1,7 +1,7 @@ import abc import itertools from collections import defaultdict -from typing import Any, Dict, Iterator, List, Tuple, Union +from typing import Any, Dict, Iterator, List, Tuple from .data import Record from examples.models import Example @@ -10,13 +10,12 @@ from projects.models import Project SpanType = Tuple[int, int, str] -class BaseRepository(abc.ABC): +class BaseRepository: def __init__(self, project: Project): self.project = project - @abc.abstractmethod def list(self, export_approved=False) -> Iterator[Record]: - pass + raise NotImplementedError() class FileRepository(BaseRepository): @@ -192,7 +191,7 @@ class IntentDetectionSlotFillingRepository(TextRepository): def label_per_user(self, doc) -> Dict: category_per_user: Dict[str, List[str]] = defaultdict(list) span_per_user: Dict[str, List[SpanType]] = defaultdict(list) - label_per_user: Dict[str, Dict[str, Union[List[str], List[SpanType]]]] = defaultdict(dict) + label_per_user: Dict[str, Dict[str, List]] = defaultdict(dict) for a in doc.categories.all(): category_per_user[a.user.username].append(a.label.text) for a in doc.spans.all():