You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

339 lines
12 KiB

import csv
import io
import json
import os
from typing import Dict, Iterator, List, Optional, Type
import chardet
import pyexcel
import pyexcel.exceptions
from chardet.universaldetector import UniversalDetector
from pydantic import ValidationError
from seqeval.scheme import BILOU, IOB2, IOBES, IOE2, Tokens
from .cleaners import Cleaner
from .data import BaseData
from .exception import FileParseException, FileParseExceptions
from .label import Label
from .labels import Labels
class Record:
    """One parsed data item together with its labels and source line number."""

    def __init__(self,
                 data: Type[BaseData],
                 label: List[Label] = None,
                 line_num: int = -1):
        """Store the data object, its labels and the line it came from.

        Args:
            data: The parsed data object.
            label: Labels attached to the data; an empty list is used when omitted.
            line_num: Line number in the source file (defaults to -1).
        """
        self._data = data
        self._label = [] if label is None else label
        self._line_num = line_num

    def __str__(self):
        return '{}\t{}'.format(self._data, self._label)

    def clean(self, cleaner: Cleaner):
        """Replace the labels with ``cleaner.clean(...)`` and report a size change.

        Raises:
            FileParseException: When the cleaned label count differs from the
                length of the ``label`` property evaluated before the update.
        """
        cleaned = cleaner.clean(self._label)
        # Must be evaluated before the assignment below: the `label` property
        # reads `self._label`.
        same_size = len(cleaned) == len(self.label)
        self._label = cleaned
        if same_size:
            return
        raise FileParseException(
            filename=self._data.filename,
            line_num=self._line_num,
            message=cleaner.message
        )

    @property
    def data(self):
        """Return ``dict()`` of the wrapped data object."""
        return self._data.dict()

    def annotation(self, mapping: Dict[str, int]):
        """Wrap the labels in ``Labels``, apply ``replace_label(mapping)`` and return ``dict()``."""
        return Labels(self._label).replace_label(mapping).dict()

    @property
    def label(self):
        """Return ``{'text': name}`` for every label that has a non-empty name."""
        named = []
        for item in self._label:
            if item.has_name() and item.name:
                named.append({'text': item.name})
        return named
class Dataset:
    """Base loader that turns a list of files into ``Record`` objects.

    ``load`` reads a whole file as one record; subclasses in this module
    override it for other formats. Extra keyword arguments are kept in
    ``self.kwargs`` and read as options (e.g. ``column_data``,
    ``column_label``).
    """

    def __init__(self,
                 filenames: List[str],
                 data_class: Type[BaseData],
                 label_class: Type[Label],
                 encoding: Optional[str] = None,
                 **kwargs):
        """Remember the files to read and the classes used to parse them.

        Args:
            filenames: Paths of the files to load.
            data_class: Class whose ``parse`` builds the data object.
            label_class: Class whose ``parse`` builds a label.
            encoding: Encoding passed to ``open``; the literal ``'Auto'``
                triggers detection with chardet.
            **kwargs: Format-specific options.
        """
        self.filenames = filenames
        self.data_class = data_class
        self.label_class = label_class
        self.encoding = encoding
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[Record]:
        """Yield the records of every file in ``self.filenames``.

        Raises:
            FileParseException: Immediately, when a file raises
                ``UnicodeDecodeError`` or ``FileParseException`` (re-raised
                with ``line_num=-1``).
            FileParseExceptions: After all files were visited, carrying the
                per-line errors collected from the loaders.
        """
        errors = []
        for filename in self.filenames:
            try:
                yield from self.load(filename)
            except (UnicodeDecodeError, FileParseException) as err:
                message = str(err)
                # Keep the original exception as the cause for debugging.
                raise FileParseException(filename, line_num=-1, message=message) from err
            except FileParseExceptions as err:
                errors.extend(err.exceptions)
        if errors:
            raise FileParseExceptions(errors)

    def load(self, filename: str) -> Iterator[Record]:
        """Loads a file content."""
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            data = self.data_class.parse(filename=filename, text=f.read())
            record = Record(data=data)
            yield record

    def detect_encoding(self, filename: str, buffer_size=io.DEFAULT_BUFFER_SIZE):
        """Return the encoding to open ``filename`` with.

        When ``self.encoding`` is not ``'Auto'`` it is returned unchanged.
        Otherwise chardet is used, falling back to ``'utf-8'`` when no
        encoding could be determined.
        """
        if self.encoding != 'Auto':
            return self.encoding

        # For a small file.
        if os.path.getsize(filename) < buffer_size:
            # Use a context manager so the handle is closed deterministically.
            with open(filename, 'rb') as f:
                detected = chardet.detect(f.read())
            # chardet always sets the key but uses None when detection fails,
            # so a plain ``.get(key, default)`` would never fall back.
            return detected.get('encoding') or 'utf-8'

        # For a large file.
        with open(filename, 'rb') as f:
            detector = UniversalDetector()
            while True:
                binary = f.read(buffer_size)
                detector.feed(binary)
                if binary == b'' or detector.done:
                    break
        if detector.done:
            return detector.result['encoding'] or 'utf-8'
        return 'utf-8'

    def from_row(self, filename: str, row: Dict, line_num: int) -> Record:
        """Build a ``Record`` from one row of column/value pairs.

        The text is taken from the ``column_data`` option (default ``'text'``),
        the labels from ``column_label`` (default ``'label'``); the remaining
        columns are passed to ``data_class.parse`` as ``meta``. Labels that
        fail to parse are dropped as a whole (empty label list).

        Raises:
            FileParseException: When the text column is missing or the data
                object fails validation.
        """
        column_data = self.kwargs.get('column_data', 'text')
        if column_data not in row:
            message = f'{column_data} does not exist.'
            raise FileParseException(filename, line_num, message)
        # Work on a shallow copy so the caller's row is not mutated by pop().
        meta = dict(row)
        text = meta.pop(column_data)
        label = meta.pop(self.kwargs.get('column_label', 'label'), [])
        label = [label] if isinstance(label, str) else label
        try:
            label = [self.label_class.parse(o) for o in label]
        except (ValidationError, TypeError):
            label = []

        try:
            data = self.data_class.parse(text=text, filename=filename, meta=meta)
        except ValidationError as err:
            message = 'The empty text is not allowed.'
            raise FileParseException(filename, line_num, message) from err

        return Record(data=data, label=label, line_num=line_num)
class FileBaseDataset(Dataset):
    """Dataset whose records are built from the filename alone."""

    def load(self, filename: str) -> Iterator[Record]:
        """Yield a single record parsed from ``filename`` without reading it here."""
        yield Record(data=self.data_class.parse(filename=filename))
class TextFileDataset(Dataset):
    """Dataset that treats the whole text of a file as one record."""

    def load(self, filename: str) -> Iterator[Record]:
        """Read the complete file and yield it as a single record."""
        with open(filename, encoding=self.detect_encoding(filename)) as stream:
            content = stream.read()
            parsed = self.data_class.parse(filename=filename, text=content)
            yield Record(data=parsed)
class TextLineDataset(Dataset):
    """Dataset that yields one record per line of a text file."""

    def load(self, filename: str) -> Iterator[Record]:
        """Yield a record for each line, with trailing whitespace stripped.

        Raises:
            FileParseExceptions: After the last line, when at least one line
                failed validation.
        """
        failures = []
        source_encoding = self.detect_encoding(filename)
        with open(filename, encoding=source_encoding) as stream:
            for number, raw_line in enumerate(stream, start=1):
                try:
                    parsed = self.data_class.parse(filename=filename, text=raw_line.rstrip())
                    yield Record(data=parsed, line_num=number)
                except ValidationError:
                    failures.append(
                        FileParseException(filename, number, 'The empty text is not allowed.')
                    )
        if failures:
            raise FileParseExceptions(failures)
class CsvDataset(Dataset):
    """Dataset that reads delimiter-separated files with a header row."""

    def load(self, filename: str) -> Iterator[Record]:
        """Yield one record per data row of a CSV file.

        The first row is the header. The delimiter comes from the
        ``delimiter`` option (default ``','``) and the text column from
        ``column_data`` (default ``'text'``).

        Raises:
            FileParseException: When the header (or the whole file) lacks the
                text column.
            FileParseExceptions: After the last row, carrying every row that
                ``from_row`` rejected.
        """
        encoding = self.detect_encoding(filename)
        delimiter = self.kwargs.get('delimiter', ',')
        column_data = self.kwargs.get('column_data', 'text')
        errors = []
        # newline='' is required by the csv module so that newlines embedded
        # in quoted fields are parsed correctly.
        with open(filename, encoding=encoding, newline='') as f:
            reader = csv.reader(f, delimiter=delimiter)
            # A bare next() on an empty file raises StopIteration, which turns
            # into RuntimeError inside a generator; use an empty header instead
            # so the regular "missing column" error is reported.
            header = next(reader, [])
            if column_data not in header:
                message = f'Column `{column_data}` does not exist in the header: {header}'
                raise FileParseException(filename, 1, message)
            for line_num, row in enumerate(reader, start=2):
                row = dict(zip(header, row))
                try:
                    yield self.from_row(filename, row, line_num)
                except FileParseException as err:
                    errors.append(err)
        if errors:
            raise FileParseExceptions(errors)
class JSONDataset(Dataset):
    """Dataset that reads a JSON file containing a list of row objects."""

    def load(self, filename: str) -> Iterator[Record]:
        """Yield one record per element of the top-level JSON list.

        Raises:
            FileParseException: When the file is not valid JSON or its
                top-level value is not a list.
            FileParseExceptions: After the last element, carrying every row
                that ``from_row`` rejected.
        """
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            # Only the decode step can raise JSONDecodeError; keep the try narrow.
            try:
                dataset = json.load(f)
            except json.decoder.JSONDecodeError as err:
                message = 'Failed to decode the json file.'
                raise FileParseException(filename, line_num=-1, message=message) from err

        # Iterating a dict or string would silently walk keys/characters, and
        # a scalar would crash; reject anything that is not a list up front.
        if not isinstance(dataset, list):
            message = 'The top-level JSON value must be a list.'
            raise FileParseException(filename, line_num=-1, message=message)

        # Collect row errors instead of aborting on the first one, as the CSV
        # and Excel loaders do.
        errors = []
        for line_num, row in enumerate(dataset, start=1):
            try:
                yield self.from_row(filename, row, line_num)
            except FileParseException as err:
                errors.append(err)
        if errors:
            raise FileParseExceptions(errors)
class JSONLDataset(Dataset):
    """Dataset that reads one JSON document per line."""

    def load(self, filename: str) -> Iterator[Record]:
        """Yield one record per line of a JSON Lines file.

        Raises:
            FileParseExceptions: After the last line, carrying every line that
                could not be decoded or that ``from_row`` rejected.
        """
        encoding = self.detect_encoding(filename)
        errors = []
        with open(filename, encoding=encoding) as f:
            for line_num, line in enumerate(f, start=1):
                try:
                    row = json.loads(line)
                    yield self.from_row(filename, row, line_num)
                except json.decoder.JSONDecodeError:
                    message = 'Failed to decode the line.'
                    errors.append(FileParseException(filename, line_num, message))
                except FileParseException as err:
                    # Without this, a rejected row aborted the load and the
                    # errors collected so far were lost; aggregate it like
                    # the CSV and Excel loaders do.
                    errors.append(err)
        if errors:
            raise FileParseExceptions(errors)
class ExcelDataset(Dataset):
    """Dataset that reads spreadsheet rows through pyexcel."""

    def load(self, filename: str) -> Iterator[Record]:
        """Yield one record per row returned by ``pyexcel.iget_records``.

        Raises:
            FileParseException: When pyexcel does not support the file type.
            FileParseExceptions: After the last row, carrying every row that
                ``from_row`` rejected.
        """
        errors = []
        try:
            records = pyexcel.iget_records(file_name=filename)
            # NOTE(review): numbering starts at 1 for the first record; if the
            # sheet's first row is a header this is one less than the sheet row.
            for line_num, row in enumerate(records, start=1):
                try:
                    yield self.from_row(filename, row, line_num)
                except FileParseException as err:
                    errors.append(err)
        except pyexcel.exceptions.FileTypeNotSupported as err:
            message = 'This file type is not supported.'
            raise FileParseException(filename, line_num=-1, message=message) from err
        finally:
            # The streaming i* functions keep file handles open until they are
            # released explicitly.
            pyexcel.free_resources()
        if errors:
            raise FileParseExceptions(errors)
class FastTextDataset(Dataset):
    """Dataset that reads fastText-style lines (``__label__x some text``)."""

    def load(self, filename: str) -> Iterator[Record]:
        """Yield one record per line, splitting label tokens from text tokens.

        Tokens are separated by a single space. A token starting with
        ``__label__`` becomes a label (prefix removed); every other token is
        part of the text. A line containing a bare ``__label__`` token is
        reported as an error and produces no record.

        Raises:
            FileParseExceptions: After the last line, carrying every line with
                an empty label name or an empty text.
        """
        prefix = '__label__'
        encoding = self.detect_encoding(filename)
        errors = []
        with open(filename, encoding=encoding) as f:
            for line_num, line in enumerate(f, start=1):
                labels = []
                tokens = []
                for token in line.rstrip().split(' '):
                    if not token.startswith(prefix):
                        tokens.append(token)
                        continue
                    if token == prefix:
                        message = 'Label name is empty.'
                        errors.append(FileParseException(filename, line_num, message))
                        break
                    labels.append(self.label_class.parse(token[len(prefix):]))
                else:
                    # Reached only when the token loop did not break: a line
                    # already reported as invalid must not also be yielded as
                    # a record with truncated text.
                    text = ' '.join(tokens)
                    try:
                        data = self.data_class.parse(filename=filename, text=text)
                        record = Record(data=data, label=labels, line_num=line_num)
                        yield record
                    except ValidationError:
                        message = 'The empty text is not allowed.'
                        errors.append(FileParseException(filename, line_num, message))
        if errors:
            raise FileParseExceptions(errors)
class CoNLLDataset(Dataset):
    """Dataset that reads tab-separated ``word<TAB>tag`` lines, blank-line delimited."""

    def load(self, filename: str) -> Iterator[Record]:
        """Yield one record per block of consecutive non-blank lines.

        Raises:
            FileParseException: When a non-blank line does not split into
                exactly two tab-separated columns.
        """
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            words, tags = [], []
            for line_num, line in enumerate(f, start=1):
                line = line.rstrip()
                if line:
                    tokens = line.split('\t')
                    if len(tokens) != 2:
                        message = 'A line must be separated by tab and has two columns.'
                        raise FileParseException(filename, line_num, message)
                    word, tag = tokens
                    words.append(word)
                    tags.append(tag)
                elif words:
                    # Only flush when something was collected: a leading blank
                    # line or several blank lines in a row would otherwise
                    # build a record with empty text.
                    yield self.create_record(filename, tags, words)
                    words, tags = [], []
            if words:
                yield self.create_record(filename, tags, words)

    def create_record(self, filename, tags, words):
        """Join ``words`` with the ``delimiter`` option (default ``' '``) and attach span labels."""
        delimiter = self.kwargs.get('delimiter', ' ')
        text = delimiter.join(words)
        data = self.data_class.parse(filename=filename, text=text)
        labels = self.get_label(words, tags, delimiter)
        record = Record(data=data, label=labels)
        return record

    def get_scheme(self, scheme: str):
        """Return the seqeval scheme class for ``scheme``.

        Raises:
            KeyError: When ``scheme`` is not one of IOB2, IOE2, IOBES, BILOU.
        """
        mapping = {
            'IOB2': IOB2,
            'IOE2': IOE2,
            'IOBES': IOBES,
            'BILOU': BILOU
        }
        return mapping[scheme]

    def get_label(self, words: List[str], tags: List[str], delimiter: str) -> List[Label]:
        """Convert tags into labels with character offsets into the joined text.

        The tagging scheme comes from the ``scheme`` option (default
        ``'IOB2'``). Each entity is passed to ``label_class.parse`` as a
        ``(start, end, tag)`` tuple.
        """
        scheme = self.get_scheme(self.kwargs.get('scheme', 'IOB2'))
        tokens = Tokens(tags, scheme)
        labels = []
        for entity in tokens.entities:
            # Length of everything before the entity, plus one delimiter when
            # that prefix is non-empty.
            text = delimiter.join(words[:entity.start])
            start = len(text) + len(delimiter) if text else len(text)
            chunk = words[entity.start: entity.end]
            text = delimiter.join(chunk)
            end = start + len(text)
            labels.append(self.label_class.parse((start, end, entity.tag)))
        return labels