diff --git a/app/api/views/upload/dataset.py b/app/api/views/upload/dataset.py
index 26d6735f..097b4078 100644
--- a/app/api/views/upload/dataset.py
+++ b/app/api/views/upload/dataset.py
@@ -1,11 +1,13 @@
 import csv
 import json
+from itertools import chain
 from typing import Dict, Iterator, List, Optional, Type
 
 import pyexcel
 
 from .data import BaseData
 from .label import Label
+from .labels import Labels
 
 
 class Record:
@@ -15,23 +17,56 @@ class Record:
                  label: List[Label] = None):
         if label is None:
             label = []
-        self.data = data
-        self.label = label
+        self._data = data
+        self._label = label
 
     def __str__(self):
-        return f'{self.data}\t{self.label}'
+        return f'{self._data}\t{self._label}'
 
-    def dict(self):
-        label_names = [
-            {
-                'text': label.name
-            } for label in self.label if label.has_name()
+    @property
+    def data(self):
+        """The data object supplied to the constructor."""
+        return self._data
+
+    @property
+    def annotation(self):
+        """The record's labels wrapped in a ``Labels`` collection."""
+        return Labels(self._label)
+
+    @property
+    def label(self):
+        """Names of the labels for which ``has_name()`` is true."""
+        return [label.name for label in self._label if label.has_name()]
+
+
+class Records:
+    """A batch of records exposing data, annotations and labels as lists."""
+
+    def __init__(self, records: List[Record]):
+        self.records = records
+
+    def data(self):
+        """Return ``data.dict()`` for every record, in order."""
+        return [r.data.dict() for r in self.records]
+
+    def annotation(self, mapping: Dict[str, int]):
+        """Return ``annotation.replace_label(mapping).dict()`` per record."""
+        return [
+            r.annotation.replace_label(mapping).dict() for r in self.records
+        ]
+
+    def label(self):
+        """Return the distinct label names as ``{'text': name}`` dicts.
+
+        Names are de-duplicated in first-seen order so that the output
+        order is deterministic.
+        """
+        names = dict.fromkeys(
+            chain.from_iterable(r.label for r in self.records)
+        )
+        return [
+            {'text': name} for name in names
         ]
-        return {
-            'data': self.data.dict(),
-            'annotation': [label.dict() for label in self.label],
-            'label': label_names
-        }
 
 
 class Dataset:
@@ -41,21 +76,32 @@ class Dataset:
                  data_class: Type[BaseData],
                  label_class: Type[Label],
                  encoding: Optional[str] = None,
-                 column_data: str = 'text',
-                 column_label: str = 'label',
                  **kwargs):
         self.filenames = filenames
         self.data_class = data_class
         self.label_class = label_class
         self.encoding = encoding
-        self.column_data = column_data
-        self.column_label = column_label
         self.kwargs = kwargs
 
     def __iter__(self) -> Iterator[Record]:
         for filename in self.filenames:
             yield from self.load(filename)
 
+    def batch(self, batch_size) -> Iterator[Records]:
+        """Yield the records in ``Records`` batches of ``batch_size``.
+
+        The final batch may be shorter; an empty batch is never yielded.
+        """
+        records = []
+        for record in self:
+            records.append(record)
+            if len(records) == batch_size:
+                yield Records(records)
+                records = []
+        if records:
+            # Flush the short remainder only when there is one.
+            yield Records(records)
+
     def load(self, filename: str) -> Iterator[Record]:
         """Loads a file content."""
         with open(filename, encoding=self.encoding) as f:
@@ -64,8 +110,8 @@ class Dataset:
             yield record
 
     def from_row(self, filename: str, row: Dict) -> Record:
-        text = row.pop(self.column_data)
-        label = row.pop(self.column_label, [])
+        text = row.pop(self.kwargs.get('column_data', 'text'))
+        label = row.pop(self.kwargs.get('column_label', 'label'), [])
         label = [label] if isinstance(label, str) else label
         label = [self.label_class.parse(o) for o in label]
         data = self.data_class.parse(text=text, filename=filename, metadata=row)