|
|
@ -4,22 +4,19 @@ from typing import Dict, Iterator, List, Optional, Type |
|
|
|
|
|
|
|
import pyexcel |
|
|
|
|
|
|
|
from .data import BaseData |
|
|
|
from .label import Label |
|
|
|
|
|
|
|
|
|
|
|
class Record: |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
filename: str, |
|
|
|
data: str = '', |
|
|
|
label: List[Label] = None, |
|
|
|
metadata: Optional[Dict] = None): |
|
|
|
if metadata is None: |
|
|
|
metadata = {} |
|
|
|
self.filename = filename |
|
|
|
data: Type[BaseData], |
|
|
|
label: List[Label] = None): |
|
|
|
if label is None: |
|
|
|
label = [] |
|
|
|
self.data = data |
|
|
|
self.label = label |
|
|
|
self.metadata = metadata |
|
|
|
|
|
|
|
def __str__(self): |
|
|
|
return f'{self.data}\t{self.label}' |
|
|
@ -29,12 +26,14 @@ class Dataset: |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
filenames: List[str], |
|
|
|
data_class: Type[BaseData], |
|
|
|
label_class: Type[Label], |
|
|
|
encoding: Optional[str] = None, |
|
|
|
column_data: str = 'text', |
|
|
|
column_label: str = 'label', |
|
|
|
**kwargs): |
|
|
|
self.filenames = filenames |
|
|
|
self.data_class = data_class |
|
|
|
self.label_class = label_class |
|
|
|
self.encoding = encoding |
|
|
|
self.column_data = column_data |
|
|
@ -48,22 +47,25 @@ class Dataset: |
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
"""Loads a file content.""" |
|
|
|
with open(filename, encoding=self.encoding) as f: |
|
|
|
record = Record(filename=filename, data=f.read()) |
|
|
|
data = self.data_class.parse(filename=filename, text=f.read()) |
|
|
|
record = Record(data=data) |
|
|
|
yield record |
|
|
|
|
|
|
|
def from_row(self, filename: str, row: Dict) -> Record: |
|
|
|
data = row.pop(self.column_data) |
|
|
|
text = row.pop(self.column_data) |
|
|
|
label = row.pop(self.column_label, []) |
|
|
|
label = [label] if isinstance(label, str) else label |
|
|
|
label = [self.label_class.parse(o) for o in label] |
|
|
|
record = Record(filename=filename, data=data, label=label, metadata=row) |
|
|
|
data = self.data_class.parse(text=text, filename=filename, metadata=row) |
|
|
|
record = Record(data=data, label=label) |
|
|
|
return record |
|
|
|
|
|
|
|
|
|
|
|
class FileBaseDataset(Dataset): |
|
|
|
|
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
record = Record(filename=filename, data=filename) |
|
|
|
data = self.data_class.parse(filename=filename) |
|
|
|
record = Record(data=data) |
|
|
|
yield record |
|
|
|
|
|
|
|
|
|
|
@ -71,7 +73,8 @@ class TextFileDataset(Dataset): |
|
|
|
|
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
with open(filename, encoding=self.encoding) as f: |
|
|
|
record = Record(filename=filename, data=f.read()) |
|
|
|
data = self.data_class.parse(filename=filename, text=f.read()) |
|
|
|
record = Record(data=data) |
|
|
|
yield record |
|
|
|
|
|
|
|
|
|
|
@ -80,7 +83,8 @@ class TextLineDataset(Dataset): |
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
with open(filename, encoding=self.encoding) as f: |
|
|
|
for line in f: |
|
|
|
record = Record(filename=filename, data=line.rstrip()) |
|
|
|
data = self.data_class.parse(filename=filename, text=line.rstrip()) |
|
|
|
record = Record(data=data) |
|
|
|
yield record |
|
|
|
|
|
|
|
|
|
|
@ -135,8 +139,9 @@ class FastTextDataset(Dataset): |
|
|
|
labels.append(self.label_class.parse(label_name)) |
|
|
|
else: |
|
|
|
tokens.append(token) |
|
|
|
data = ' '.join(tokens) |
|
|
|
record = Record(filename=filename, data=data, label=labels) |
|
|
|
text = ' '.join(tokens) |
|
|
|
data = self.data_class.parse(filename=filename, text=text) |
|
|
|
record = Record(data=data, label=labels) |
|
|
|
yield record |
|
|
|
|
|
|
|
|
|
|
|