You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

285 lines
10 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. import csv
  2. import io
  3. import json
  4. import os
  5. from typing import Dict, Iterator, List, Optional, Type
  6. import chardet
  7. import pydantic.error_wrappers
  8. import pyexcel
  9. import pyexcel.exceptions
  10. from chardet.universaldetector import UniversalDetector
  11. from seqeval.scheme import BILOU, IOB2, IOBES, IOE2, Tokens
  12. from .data import BaseData
  13. from .exception import FileParseException
  14. from .label import Label
  15. from .labels import Labels
  16. class Record:
  17. def __init__(self,
  18. data: Type[BaseData],
  19. label: List[Label] = None):
  20. if label is None:
  21. label = []
  22. self._data = data
  23. self._label = label
  24. def __str__(self):
  25. return f'{self._data}\t{self._label}'
  26. @property
  27. def data(self):
  28. return self._data.dict()
  29. def annotation(self, mapping: Dict[str, int]):
  30. labels = Labels(self._label)
  31. labels = labels.replace_label(mapping)
  32. return labels.dict()
  33. @property
  34. def label(self):
  35. return [
  36. {
  37. 'text': label.name
  38. } for label in self._label
  39. if label.has_name() and label.name
  40. ]
  41. class Dataset:
  42. def __init__(self,
  43. filenames: List[str],
  44. data_class: Type[BaseData],
  45. label_class: Type[Label],
  46. encoding: Optional[str] = None,
  47. **kwargs):
  48. self.filenames = filenames
  49. self.data_class = data_class
  50. self.label_class = label_class
  51. self.encoding = encoding
  52. self.kwargs = kwargs
  53. def __iter__(self) -> Iterator[Record]:
  54. for filename in self.filenames:
  55. try:
  56. yield from self.load(filename)
  57. except UnicodeDecodeError as err:
  58. message = str(err)
  59. raise FileParseException(filename, line_num=-1, message=message)
  60. def load(self, filename: str) -> Iterator[Record]:
  61. """Loads a file content."""
  62. encoding = self.detect_encoding(filename)
  63. with open(filename, encoding=encoding) as f:
  64. data = self.data_class.parse(filename=filename, text=f.read())
  65. record = Record(data=data)
  66. yield record
  67. def detect_encoding(self, filename: str, buffer_size=io.DEFAULT_BUFFER_SIZE):
  68. if self.encoding != 'Auto':
  69. return self.encoding
  70. # For a small file.
  71. if os.path.getsize(filename) < buffer_size:
  72. detected = chardet.detect(open(filename, 'rb').read())
  73. return detected.get('encoding', 'utf-8')
  74. # For a large file.
  75. with open(filename, 'rb') as f:
  76. detector = UniversalDetector()
  77. while True:
  78. binary = f.read(buffer_size)
  79. detector.feed(binary)
  80. if binary == b'':
  81. break
  82. if detector.done:
  83. break
  84. if detector.done:
  85. return detector.result['encoding']
  86. else:
  87. return 'utf-8'
  88. def from_row(self, filename: str, row: Dict, line_num: int) -> Record:
  89. column_data = self.kwargs.get('column_data', 'text')
  90. if column_data not in row:
  91. message = f'{column_data} does not exist.'
  92. raise FileParseException(filename, line_num, message)
  93. text = row.pop(column_data)
  94. label = row.pop(self.kwargs.get('column_label', 'label'), [])
  95. label = [label] if isinstance(label, str) else label
  96. try:
  97. label = [self.label_class.parse(o) for o in label]
  98. except (pydantic.error_wrappers.ValidationError, TypeError):
  99. label = []
  100. data = self.data_class.parse(text=text, filename=filename, meta=row)
  101. record = Record(data=data, label=label)
  102. return record
  103. class FileBaseDataset(Dataset):
  104. def load(self, filename: str) -> Iterator[Record]:
  105. data = self.data_class.parse(filename=filename)
  106. record = Record(data=data)
  107. yield record
  108. class TextFileDataset(Dataset):
  109. def load(self, filename: str) -> Iterator[Record]:
  110. encoding = self.detect_encoding(filename)
  111. with open(filename, encoding=encoding) as f:
  112. data = self.data_class.parse(filename=filename, text=f.read())
  113. record = Record(data=data)
  114. yield record
  115. class TextLineDataset(Dataset):
  116. def load(self, filename: str) -> Iterator[Record]:
  117. encoding = self.detect_encoding(filename)
  118. with open(filename, encoding=encoding) as f:
  119. for line in f:
  120. data = self.data_class.parse(filename=filename, text=line.rstrip())
  121. record = Record(data=data)
  122. yield record
  123. class CsvDataset(Dataset):
  124. def load(self, filename: str) -> Iterator[Record]:
  125. encoding = self.detect_encoding(filename)
  126. with open(filename, encoding=encoding) as f:
  127. delimiter = self.kwargs.get('delimiter', ',')
  128. reader = csv.reader(f, delimiter=delimiter)
  129. header = next(reader)
  130. column_data = self.kwargs.get('column_data', 'text')
  131. if column_data not in header:
  132. message = f'Column `{column_data}` does not exist in the header: {header}'
  133. raise FileParseException(filename, 1, message)
  134. for line_num, row in enumerate(reader, start=2):
  135. row = dict(zip(header, row))
  136. yield self.from_row(filename, row, line_num)
  137. class JSONDataset(Dataset):
  138. def load(self, filename: str) -> Iterator[Record]:
  139. encoding = self.detect_encoding(filename)
  140. with open(filename, encoding=encoding) as f:
  141. try:
  142. dataset = json.load(f)
  143. for line_num, row in enumerate(dataset, start=1):
  144. yield self.from_row(filename, row, line_num)
  145. except json.decoder.JSONDecodeError:
  146. message = 'Failed to decode the json file.'
  147. raise FileParseException(filename, line_num=-1, message=message)
  148. class JSONLDataset(Dataset):
  149. def load(self, filename: str) -> Iterator[Record]:
  150. encoding = self.detect_encoding(filename)
  151. with open(filename, encoding=encoding) as f:
  152. for line_num, line in enumerate(f, start=1):
  153. try:
  154. row = json.loads(line)
  155. yield self.from_row(filename, row, line_num)
  156. except json.decoder.JSONDecodeError:
  157. message = 'Failed to decode the line.'
  158. raise FileParseException(filename, line_num, message)
  159. class ExcelDataset(Dataset):
  160. def load(self, filename: str) -> Iterator[Record]:
  161. records = pyexcel.iget_records(file_name=filename)
  162. try:
  163. for line_num, row in enumerate(records, start=1):
  164. yield self.from_row(filename, row, line_num)
  165. except pyexcel.exceptions.FileTypeNotSupported:
  166. message = 'This file type is not supported.'
  167. raise FileParseException(filename, line_num=-1, message=message)
  168. class FastTextDataset(Dataset):
  169. def load(self, filename: str) -> Iterator[Record]:
  170. encoding = self.detect_encoding(filename)
  171. with open(filename, encoding=encoding) as f:
  172. for line_num, line in enumerate(f, start=1):
  173. labels = []
  174. tokens = []
  175. for token in line.rstrip().split(' '):
  176. if token.startswith('__label__'):
  177. if token == '__label__':
  178. message = 'Label name is empty.'
  179. raise FileParseException(filename, line_num, message)
  180. label_name = token[len('__label__'):]
  181. labels.append(self.label_class.parse(label_name))
  182. else:
  183. tokens.append(token)
  184. text = ' '.join(tokens)
  185. data = self.data_class.parse(filename=filename, text=text)
  186. record = Record(data=data, label=labels)
  187. yield record
  188. class CoNLLDataset(Dataset):
  189. def load(self, filename: str) -> Iterator[Record]:
  190. encoding = self.detect_encoding(filename)
  191. with open(filename, encoding=encoding) as f:
  192. words, tags = [], []
  193. delimiter = self.kwargs.get('delimiter', ' ')
  194. for line_num, line in enumerate(f, start=1):
  195. line = line.rstrip()
  196. if line:
  197. tokens = line.split('\t')
  198. if len(tokens) != 2:
  199. message = 'A line must be separated by tab and has two columns.'
  200. raise FileParseException(filename, line_num, message)
  201. word, tag = tokens
  202. words.append(word)
  203. tags.append(tag)
  204. else:
  205. text = delimiter.join(words)
  206. data = self.data_class.parse(filename=filename, text=text)
  207. labels = self.get_label(words, tags, delimiter)
  208. record = Record(data=data, label=labels)
  209. yield record
  210. words, tags = [], []
  211. if words:
  212. text = delimiter.join(words)
  213. data = self.data_class.parse(filename=filename, text=text)
  214. labels = self.get_label(words, tags, delimiter)
  215. record = Record(data=data, label=labels)
  216. yield record
  217. def get_scheme(self, scheme: str):
  218. mapping = {
  219. 'IOB2': IOB2,
  220. 'IOE2': IOE2,
  221. 'IOBES': IOBES,
  222. 'BILOU': BILOU
  223. }
  224. return mapping[scheme]
  225. def get_label(self, words: List[str], tags: List[str], delimiter: str) -> List[Label]:
  226. scheme = self.get_scheme(self.kwargs.get('scheme', 'IOB2'))
  227. tokens = Tokens(tags, scheme)
  228. labels = []
  229. for entity in tokens.entities:
  230. text = delimiter.join(words[:entity.start])
  231. start = len(text) + len(delimiter) if text else len(text)
  232. chunk = words[entity.start: entity.end]
  233. text = delimiter.join(chunk)
  234. end = start + len(text)
  235. labels.append(self.label_class.parse((start, end, entity.tag)))
  236. return labels