|
|
@ -1,16 +1,18 @@ |
|
|
|
import csv |
|
|
|
import json |
|
|
|
from typing import Any, Dict, Iterator, List, Optional |
|
|
|
from typing import Dict, Iterator, List, Optional, Type |
|
|
|
|
|
|
|
import pyexcel |
|
|
|
|
|
|
|
from .label import Label |
|
|
|
|
|
|
|
|
|
|
|
class Record: |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
filename: str, |
|
|
|
data: str = '', |
|
|
|
label: Any = None, |
|
|
|
label: List[Label] = None, |
|
|
|
metadata: Optional[Dict] = None): |
|
|
|
if metadata is None: |
|
|
|
metadata = {} |
|
|
@ -27,11 +29,13 @@ class Dataset: |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
filenames: List[str], |
|
|
|
label_class: Type[Label], |
|
|
|
encoding: Optional[str] = None, |
|
|
|
column_data: str = 'text', |
|
|
|
column_label: str = 'label', |
|
|
|
**kwargs): |
|
|
|
self.filenames = filenames |
|
|
|
self.label_class = label_class |
|
|
|
self.encoding = encoding |
|
|
|
self.column_data = column_data |
|
|
|
self.column_label = column_label |
|
|
@ -51,6 +55,7 @@ class Dataset: |
|
|
|
data = row.pop(self.column_data) |
|
|
|
label = row.pop(self.column_label, []) |
|
|
|
label = [label] if isinstance(label, str) else label |
|
|
|
label = [self.label_class.parse(o) for o in label] |
|
|
|
record = Record(filename=filename, data=data, label=label, metadata=row) |
|
|
|
return record |
|
|
|
|
|
|
@ -126,7 +131,8 @@ class FastTextDataset(Dataset): |
|
|
|
tokens = [] |
|
|
|
for token in line.rstrip().split(' '): |
|
|
|
if token.startswith('__label__'): |
|
|
|
labels.append(token[len('__label__'):]) |
|
|
|
label_name = token[len('__label__'):] |
|
|
|
labels.append(self.label_class.parse(label_name)) |
|
|
|
else: |
|
|
|
tokens.append(token) |
|
|
|
data = ' '.join(tokens) |
|
|
|