From 10acca17041a492596088abd1f0af7fd5a90a852 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Tue, 6 Apr 2021 10:07:15 +0900 Subject: [PATCH] Update dataset classes to use label class --- app/api/views/upload/dataset.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/app/api/views/upload/dataset.py b/app/api/views/upload/dataset.py index ead819b9..c62b3fe1 100644 --- a/app/api/views/upload/dataset.py +++ b/app/api/views/upload/dataset.py @@ -1,16 +1,18 @@ import csv import json -from typing import Any, Dict, Iterator, List, Optional +from typing import Dict, Iterator, List, Optional, Type import pyexcel +from .label import Label + class Record: def __init__(self, filename: str, data: str = '', - label: Any = None, + label: List[Label] = None, metadata: Optional[Dict] = None): if metadata is None: metadata = {} @@ -27,11 +29,13 @@ class Dataset: def __init__(self, filenames: List[str], + label_class: Type[Label], encoding: Optional[str] = None, column_data: str = 'text', column_label: str = 'label', **kwargs): self.filenames = filenames + self.label_class = label_class self.encoding = encoding self.column_data = column_data self.column_label = column_label @@ -51,6 +55,7 @@ class Dataset: data = row.pop(self.column_data) label = row.pop(self.column_label, []) label = [label] if isinstance(label, str) else label + label = [self.label_class.parse(o) for o in label] record = Record(filename=filename, data=data, label=label, metadata=row) return record @@ -126,7 +131,8 @@ class FastTextDataset(Dataset): tokens = [] for token in line.rstrip().split(' '): if token.startswith('__label__'): - labels.append(token[len('__label__'):]) + label_name = token[len('__label__'):] + labels.append(self.label_class.parse(label_name)) else: tokens.append(token) data = ' '.join(tokens)