|
|
@ -1,11 +1,13 @@ |
|
|
|
import csv |
|
|
|
import json |
|
|
|
from itertools import chain |
|
|
|
from typing import Dict, Iterator, List, Optional, Type |
|
|
|
|
|
|
|
import pyexcel |
|
|
|
|
|
|
|
from .data import BaseData |
|
|
|
from .label import Label |
|
|
|
from .labels import Labels |
|
|
|
|
|
|
|
|
|
|
|
class Record: |
|
|
@ -15,23 +17,41 @@ class Record: |
|
|
|
label: List[Label] = None): |
|
|
|
if label is None: |
|
|
|
label = [] |
|
|
|
self.data = data |
|
|
|
self.label = label |
|
|
|
self._data = data |
|
|
|
self._label = label |
|
|
|
|
|
|
|
def __str__(self): |
|
|
|
return f'{self.data}\t{self.label}' |
|
|
|
return f'{self._data}\t{self._label}' |
|
|
|
|
|
|
|
def dict(self): |
|
|
|
label_names = [ |
|
|
|
{ |
|
|
|
'text': label.name |
|
|
|
} for label in self.label if label.has_name() |
|
|
|
@property |
|
|
|
def data(self): |
|
|
|
return self._data |
|
|
|
|
|
|
|
@property |
|
|
|
def annotation(self): |
|
|
|
return Labels(self._label) |
|
|
|
|
|
|
|
@property |
|
|
|
def label(self): |
|
|
|
return [label.name for label in self._label if label.has_name()] |
|
|
|
|
|
|
|
|
|
|
|
class Records: |
|
|
|
|
|
|
|
def __init__(self, records: List[Record]): |
|
|
|
self.records = records |
|
|
|
|
|
|
|
def data(self): |
|
|
|
return [r.data.dict() for r in self.records] |
|
|
|
|
|
|
|
def annotation(self, mapping: Dict[str, int]): |
|
|
|
return [r.annotation.replace_label(mapping).dict() for r in self.records] |
|
|
|
|
|
|
|
def label(self): |
|
|
|
labels = set(chain(*[r.label for r in self.records])) |
|
|
|
return [ |
|
|
|
{'text': label} for label in labels |
|
|
|
] |
|
|
|
return { |
|
|
|
'data': self.data.dict(), |
|
|
|
'annotation': [label.dict() for label in self.label], |
|
|
|
'label': label_names |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
class Dataset: |
|
|
@ -41,21 +61,26 @@ class Dataset: |
|
|
|
data_class: Type[BaseData], |
|
|
|
label_class: Type[Label], |
|
|
|
encoding: Optional[str] = None, |
|
|
|
column_data: str = 'text', |
|
|
|
column_label: str = 'label', |
|
|
|
**kwargs): |
|
|
|
self.filenames = filenames |
|
|
|
self.data_class = data_class |
|
|
|
self.label_class = label_class |
|
|
|
self.encoding = encoding |
|
|
|
self.column_data = column_data |
|
|
|
self.column_label = column_label |
|
|
|
self.kwargs = kwargs |
|
|
|
|
|
|
|
def __iter__(self) -> Iterator[Record]: |
|
|
|
for filename in self.filenames: |
|
|
|
yield from self.load(filename) |
|
|
|
|
|
|
|
def batch(self, batch_size) -> Records: |
|
|
|
records = [] |
|
|
|
for record in self: |
|
|
|
records.append(record) |
|
|
|
if len(records) == batch_size: |
|
|
|
yield Records(records) |
|
|
|
records = [] |
|
|
|
yield Records(records) |
|
|
|
|
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
"""Loads a file content.""" |
|
|
|
with open(filename, encoding=self.encoding) as f: |
|
|
@ -64,8 +89,8 @@ class Dataset: |
|
|
|
yield record |
|
|
|
|
|
|
|
def from_row(self, filename: str, row: Dict) -> Record: |
|
|
|
text = row.pop(self.column_data) |
|
|
|
label = row.pop(self.column_label, []) |
|
|
|
text = row.pop(self.kwargs.get('column_data', 'text')) |
|
|
|
label = row.pop(self.kwargs.get('column_label', 'label'), []) |
|
|
|
label = [label] if isinstance(label, str) else label |
|
|
|
label = [self.label_class.parse(o) for o in label] |
|
|
|
data = self.data_class.parse(text=text, filename=filename, metadata=row) |
|
|
|