|
|
@ -1,11 +1,11 @@ |
|
|
|
import csv |
|
|
|
import json |
|
|
|
from itertools import chain |
|
|
|
from typing import Dict, Iterator, List, Optional, Type |
|
|
|
|
|
|
|
import pyexcel |
|
|
|
|
|
|
|
from .data import BaseData |
|
|
|
from .exception import FileParseException |
|
|
|
from .label import Label |
|
|
|
from .labels import Labels |
|
|
|
|
|
|
@ -25,32 +25,20 @@ class Record: |
|
|
|
|
|
|
|
@property |
|
|
|
def data(self): |
|
|
|
return self._data |
|
|
|
|
|
|
|
@property |
|
|
|
def annotation(self): |
|
|
|
return Labels(self._label) |
|
|
|
|
|
|
|
@property |
|
|
|
def label(self): |
|
|
|
return [label.name for label in self._label if label.has_name() and label.name] |
|
|
|
|
|
|
|
|
|
|
|
class Records: |
|
|
|
|
|
|
|
def __init__(self, records: List[Record]): |
|
|
|
self.records = records |
|
|
|
|
|
|
|
def data(self): |
|
|
|
return [r.data.dict() for r in self.records] |
|
|
|
return self._data.dict() |
|
|
|
|
|
|
|
def annotation(self, mapping: Dict[str, int]): |
|
|
|
return [r.annotation.replace_label(mapping).dict() for r in self.records] |
|
|
|
labels = Labels(self._label) |
|
|
|
labels = labels.replace_label(mapping) |
|
|
|
return labels.dict() |
|
|
|
|
|
|
|
@property |
|
|
|
def label(self): |
|
|
|
labels = set(chain(*[r.label for r in self.records])) |
|
|
|
return [ |
|
|
|
{'text': label} for label in labels |
|
|
|
{ |
|
|
|
'text': label.name |
|
|
|
} for label in self._label |
|
|
|
if label.has_name() and label.name |
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
@ -72,15 +60,6 @@ class Dataset: |
|
|
|
for filename in self.filenames: |
|
|
|
yield from self.load(filename) |
|
|
|
|
|
|
|
def batch(self, batch_size) -> Records: |
|
|
|
records = [] |
|
|
|
for record in self: |
|
|
|
records.append(record) |
|
|
|
if len(records) == batch_size: |
|
|
|
yield Records(records) |
|
|
|
records = [] |
|
|
|
yield Records(records) |
|
|
|
|
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
"""Loads a file content.""" |
|
|
|
with open(filename, encoding=self.encoding) as f: |
|
|
@ -88,7 +67,11 @@ class Dataset: |
|
|
|
record = Record(data=data) |
|
|
|
yield record |
|
|
|
|
|
|
|
def from_row(self, filename: str, row: Dict) -> Record: |
|
|
|
def from_row(self, filename: str, row: Dict, line_num: int) -> Record: |
|
|
|
column_data = self.kwargs.get('column_data', 'text') |
|
|
|
if column_data not in row: |
|
|
|
message = f'{column_data} does not exist.' |
|
|
|
raise FileParseException(filename, line_num, message) |
|
|
|
text = row.pop(self.kwargs.get('column_data', 'text')) |
|
|
|
label = row.pop(self.kwargs.get('column_label', 'label'), []) |
|
|
|
label = [label] if isinstance(label, str) else label |
|
|
@ -132,9 +115,15 @@ class CsvDataset(Dataset): |
|
|
|
delimiter = self.kwargs.get('delimiter', ',') |
|
|
|
reader = csv.reader(f, delimiter=delimiter) |
|
|
|
header = next(reader) |
|
|
|
for row in reader: |
|
|
|
|
|
|
|
column_data = self.kwargs.get('column_data', 'text') |
|
|
|
if column_data not in header: |
|
|
|
message = f'{column_data} does not exist in the header: {header}' |
|
|
|
raise FileParseException(filename, 1, message) |
|
|
|
|
|
|
|
for line_num, row in enumerate(reader, start=2): |
|
|
|
row = dict(zip(header, row)) |
|
|
|
yield self.from_row(filename, row) |
|
|
|
yield self.from_row(filename, row, line_num) |
|
|
|
|
|
|
|
|
|
|
|
class JSONDataset(Dataset): |
|
|
@ -142,36 +131,39 @@ class JSONDataset(Dataset): |
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
with open(filename, encoding=self.encoding) as f: |
|
|
|
dataset = json.load(f) |
|
|
|
for row in dataset: |
|
|
|
yield self.from_row(filename, row) |
|
|
|
for line_num, row in enumerate(dataset, start=1): |
|
|
|
yield self.from_row(filename, row, line_num) |
|
|
|
|
|
|
|
|
|
|
|
class JSONLDataset(Dataset): |
|
|
|
|
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
with open(filename, encoding=self.encoding) as f: |
|
|
|
for line in f: |
|
|
|
for line_num, line in enumerate(f, start=1): |
|
|
|
row = json.loads(line) |
|
|
|
yield self.from_row(filename, row) |
|
|
|
yield self.from_row(filename, row, line_num) |
|
|
|
|
|
|
|
|
|
|
|
class ExcelDataset(Dataset): |
|
|
|
|
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
records = pyexcel.iget_records(filename) |
|
|
|
for row in records: |
|
|
|
yield self.from_row(filename, row) |
|
|
|
for line_num, row in enumerate(records, start=1): |
|
|
|
yield self.from_row(filename, row, line_num) |
|
|
|
|
|
|
|
|
|
|
|
class FastTextDataset(Dataset): |
|
|
|
|
|
|
|
def load(self, filename: str) -> Iterator[Record]: |
|
|
|
with open(filename, encoding=self.encoding) as f: |
|
|
|
for i, line in enumerate(f, start=1): |
|
|
|
for line_num, line in enumerate(f, start=1): |
|
|
|
labels = [] |
|
|
|
tokens = [] |
|
|
|
for token in line.rstrip().split(' '): |
|
|
|
if token.startswith('__label__'): |
|
|
|
if token == '__label__': |
|
|
|
message = 'Label name is empty.' |
|
|
|
raise FileParseException(filename, line_num, message) |
|
|
|
label_name = token[len('__label__'):] |
|
|
|
labels.append(self.label_class.parse(label_name)) |
|
|
|
else: |
|
|
|