Browse Source

Update dataset to use data class

pull/1310/head
Hironsan 3 years ago
parent
commit
25cde888e2
1 changed file with 21 additions and 16 deletions
  1. 37
      app/api/views/upload/dataset.py

37
app/api/views/upload/dataset.py

@@ -4,22 +4,19 @@ from typing import Dict, Iterator, List, Optional, Type
import pyexcel
from .data import BaseData
from .label import Label
class Record:
def __init__(self,
filename: str,
data: str = '',
label: List[Label] = None,
metadata: Optional[Dict] = None):
if metadata is None:
metadata = {}
self.filename = filename
data: Type[BaseData],
label: List[Label] = None):
if label is None:
label = []
self.data = data
self.label = label
self.metadata = metadata
def __str__(self):
return f'{self.data}\t{self.label}'
@@ -29,12 +26,14 @@ class Dataset:
def __init__(self,
filenames: List[str],
data_class: Type[BaseData],
label_class: Type[Label],
encoding: Optional[str] = None,
column_data: str = 'text',
column_label: str = 'label',
**kwargs):
self.filenames = filenames
self.data_class = data_class
self.label_class = label_class
self.encoding = encoding
self.column_data = column_data
@@ -48,22 +47,25 @@ class Dataset:
def load(self, filename: str) -> Iterator[Record]:
"""Loads a file content."""
with open(filename, encoding=self.encoding) as f:
record = Record(filename=filename, data=f.read())
data = self.data_class.parse(filename=filename, text=f.read())
record = Record(data=data)
yield record
def from_row(self, filename: str, row: Dict) -> Record:
data = row.pop(self.column_data)
text = row.pop(self.column_data)
label = row.pop(self.column_label, [])
label = [label] if isinstance(label, str) else label
label = [self.label_class.parse(o) for o in label]
record = Record(filename=filename, data=data, label=label, metadata=row)
data = self.data_class.parse(text=text, filename=filename, metadata=row)
record = Record(data=data, label=label)
return record
class FileBaseDataset(Dataset):
def load(self, filename: str) -> Iterator[Record]:
record = Record(filename=filename, data=filename)
data = self.data_class.parse(filename=filename)
record = Record(data=data)
yield record
@@ -71,7 +73,8 @@ class TextFileDataset(Dataset):
def load(self, filename: str) -> Iterator[Record]:
with open(filename, encoding=self.encoding) as f:
record = Record(filename=filename, data=f.read())
data = self.data_class.parse(filename=filename, text=f.read())
record = Record(data=data)
yield record
@@ -80,7 +83,8 @@ class TextLineDataset(Dataset):
def load(self, filename: str) -> Iterator[Record]:
with open(filename, encoding=self.encoding) as f:
for line in f:
record = Record(filename=filename, data=line.rstrip())
data = self.data_class.parse(filename=filename, text=line.rstrip())
record = Record(data=data)
yield record
@@ -135,8 +139,9 @@ class FastTextDataset(Dataset):
labels.append(self.label_class.parse(label_name))
else:
tokens.append(token)
data = ' '.join(tokens)
record = Record(filename=filename, data=data, label=labels)
text = ' '.join(tokens)
data = self.data_class.parse(filename=filename, text=text)
record = Record(data=data, label=labels)
yield record

Loading…
Cancel
Save