Browse Source

Update dataset classes to use label class

pull/1310/head
Hironsan 3 years ago
parent
commit
10acca1704
1 changed file with 9 additions and 3 deletions
  1. 12
      app/api/views/upload/dataset.py

12
app/api/views/upload/dataset.py

@@ -1,16 +1,18 @@
import csv
import json
from typing import Any, Dict, Iterator, List, Optional
from typing import Dict, Iterator, List, Optional, Type
import pyexcel
from .label import Label
class Record:
def __init__(self,
filename: str,
data: str = '',
label: Any = None,
label: List[Label] = None,
metadata: Optional[Dict] = None):
if metadata is None:
metadata = {}
@@ -27,11 +29,13 @@ class Dataset:
def __init__(self,
filenames: List[str],
label_class: Type[Label],
encoding: Optional[str] = None,
column_data: str = 'text',
column_label: str = 'label',
**kwargs):
self.filenames = filenames
self.label_class = label_class
self.encoding = encoding
self.column_data = column_data
self.column_label = column_label
@@ -51,6 +55,7 @@ class Dataset:
data = row.pop(self.column_data)
label = row.pop(self.column_label, [])
label = [label] if isinstance(label, str) else label
label = [self.label_class.parse(o) for o in label]
record = Record(filename=filename, data=data, label=label, metadata=row)
return record
@@ -126,7 +131,8 @@ class FastTextDataset(Dataset):
tokens = []
for token in line.rstrip().split(' '):
if token.startswith('__label__'):
labels.append(token[len('__label__'):])
label_name = token[len('__label__'):]
labels.append(self.label_class.parse(label_name))
else:
tokens.append(token)
data = ' '.join(tokens)

Loading…
Cancel
Save