From 25cde888e2638ecc955f3b713df6fd23f3603af8 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Tue, 6 Apr 2021 11:02:16 +0900 Subject: [PATCH] Update dataset to use data class --- app/api/views/upload/dataset.py | 37 +++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/app/api/views/upload/dataset.py b/app/api/views/upload/dataset.py index c62b3fe1..bed273db 100644 --- a/app/api/views/upload/dataset.py +++ b/app/api/views/upload/dataset.py @@ -4,22 +4,19 @@ from typing import Dict, Iterator, List, Optional, Type import pyexcel +from .data import BaseData from .label import Label class Record: def __init__(self, - filename: str, - data: str = '', - label: List[Label] = None, - metadata: Optional[Dict] = None): - if metadata is None: - metadata = {} - self.filename = filename + data: Type[BaseData], + label: List[Label] = None): + if label is None: + label = [] self.data = data self.label = label - self.metadata = metadata def __str__(self): return f'{self.data}\t{self.label}' @@ -29,12 +26,14 @@ class Dataset: def __init__(self, filenames: List[str], + data_class: Type[BaseData], label_class: Type[Label], encoding: Optional[str] = None, column_data: str = 'text', column_label: str = 'label', **kwargs): self.filenames = filenames + self.data_class = data_class self.label_class = label_class self.encoding = encoding self.column_data = column_data @@ -48,22 +47,25 @@ class Dataset: def load(self, filename: str) -> Iterator[Record]: """Loads a file content.""" with open(filename, encoding=self.encoding) as f: - record = Record(filename=filename, data=f.read()) + data = self.data_class.parse(filename=filename, text=f.read()) + record = Record(data=data) yield record def from_row(self, filename: str, row: Dict) -> Record: - data = row.pop(self.column_data) + text = row.pop(self.column_data) label = row.pop(self.column_label, []) label = [label] if isinstance(label, str) else label label = [self.label_class.parse(o) for o in label] - record = Record(filename=filename, data=data, label=label, metadata=row) + data = self.data_class.parse(text=text, filename=filename, metadata=row) + record = Record(data=data, label=label) return record class FileBaseDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: - record = Record(filename=filename, data=filename) + data = self.data_class.parse(filename=filename) + record = Record(data=data) yield record @@ -71,7 +73,8 @@ class TextFileDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: with open(filename, encoding=self.encoding) as f: - record = Record(filename=filename, data=f.read()) + data = self.data_class.parse(filename=filename, text=f.read()) + record = Record(data=data) yield record @@ -80,7 +83,8 @@ class TextLineDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: with open(filename, encoding=self.encoding) as f: for line in f: - record = Record(filename=filename, data=line.rstrip()) + data = self.data_class.parse(filename=filename, text=line.rstrip()) + record = Record(data=data) yield record @@ -135,8 +139,9 @@ class FastTextDataset(Dataset): labels.append(self.label_class.parse(label_name)) else: tokens.append(token) - data = ' '.join(tokens) - record = Record(filename=filename, data=data, label=labels) + text = ' '.join(tokens) + data = self.data_class.parse(filename=filename, text=text) + record = Record(data=data, label=labels) yield record