You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

266 lines
9.1 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. import csv
  2. import io
  3. import json
  4. import os
  5. from typing import Dict, Iterator, List, Optional, Type
  6. import chardet
  7. import pydantic.error_wrappers
  8. import pyexcel
  9. from chardet.universaldetector import UniversalDetector
  10. from seqeval.scheme import BILOU, IOB2, IOBES, IOE2, Tokens
  11. from .data import BaseData
  12. from .exception import FileParseException
  13. from .label import Label
  14. from .labels import Labels
  15. class Record:
  16. def __init__(self,
  17. data: Type[BaseData],
  18. label: List[Label] = None):
  19. if label is None:
  20. label = []
  21. self._data = data
  22. self._label = label
  23. def __str__(self):
  24. return f'{self._data}\t{self._label}'
  25. @property
  26. def data(self):
  27. return self._data.dict()
  28. def annotation(self, mapping: Dict[str, int]):
  29. labels = Labels(self._label)
  30. labels = labels.replace_label(mapping)
  31. return labels.dict()
  32. @property
  33. def label(self):
  34. return [
  35. {
  36. 'text': label.name
  37. } for label in self._label
  38. if label.has_name() and label.name
  39. ]
  40. class Dataset:
  41. def __init__(self,
  42. filenames: List[str],
  43. data_class: Type[BaseData],
  44. label_class: Type[Label],
  45. encoding: Optional[str] = None,
  46. **kwargs):
  47. self.filenames = filenames
  48. self.data_class = data_class
  49. self.label_class = label_class
  50. self.encoding = encoding
  51. self.kwargs = kwargs
  52. def __iter__(self) -> Iterator[Record]:
  53. for filename in self.filenames:
  54. try:
  55. yield from self.load(filename)
  56. except UnicodeDecodeError as err:
  57. message = str(err)
  58. raise FileParseException(filename, line_num=-1, message=message)
  59. def load(self, filename: str) -> Iterator[Record]:
  60. """Loads a file content."""
  61. encoding = self.detect_encoding(filename)
  62. with open(filename, encoding=encoding) as f:
  63. data = self.data_class.parse(filename=filename, text=f.read())
  64. record = Record(data=data)
  65. yield record
  66. def detect_encoding(self, filename: str, buffer_size=io.DEFAULT_BUFFER_SIZE):
  67. if self.encoding != 'Auto':
  68. return self.encoding
  69. # For a small file.
  70. if os.path.getsize(filename) < buffer_size:
  71. detected = chardet.detect(open(filename, 'rb').read())
  72. return detected.get('encoding', 'utf-8')
  73. # For a large file.
  74. with open(filename, 'rb') as f:
  75. detector = UniversalDetector()
  76. while True:
  77. binary = f.read(buffer_size)
  78. detector.feed(binary)
  79. if binary == b'':
  80. break
  81. if detector.done:
  82. break
  83. if detector.done:
  84. return detector.result['encoding']
  85. else:
  86. return 'utf-8'
  87. def from_row(self, filename: str, row: Dict, line_num: int) -> Record:
  88. column_data = self.kwargs.get('column_data', 'text')
  89. if column_data not in row:
  90. message = f'{column_data} does not exist.'
  91. raise FileParseException(filename, line_num, message)
  92. text = row.pop(column_data)
  93. label = row.pop(self.kwargs.get('column_label', 'label'), [])
  94. label = [label] if isinstance(label, str) else label
  95. try:
  96. label = [self.label_class.parse(o) for o in label]
  97. except pydantic.error_wrappers.ValidationError:
  98. label = []
  99. data = self.data_class.parse(text=text, filename=filename, meta=row)
  100. record = Record(data=data, label=label)
  101. return record
  102. class FileBaseDataset(Dataset):
  103. def load(self, filename: str) -> Iterator[Record]:
  104. data = self.data_class.parse(filename=filename)
  105. record = Record(data=data)
  106. yield record
  107. class TextFileDataset(Dataset):
  108. def load(self, filename: str) -> Iterator[Record]:
  109. encoding = self.detect_encoding(filename)
  110. with open(filename, encoding=encoding) as f:
  111. data = self.data_class.parse(filename=filename, text=f.read())
  112. record = Record(data=data)
  113. yield record
  114. class TextLineDataset(Dataset):
  115. def load(self, filename: str) -> Iterator[Record]:
  116. encoding = self.detect_encoding(filename)
  117. with open(filename, encoding=encoding) as f:
  118. for line in f:
  119. data = self.data_class.parse(filename=filename, text=line.rstrip())
  120. record = Record(data=data)
  121. yield record
  122. class CsvDataset(Dataset):
  123. def load(self, filename: str) -> Iterator[Record]:
  124. encoding = self.detect_encoding(filename)
  125. with open(filename, encoding=encoding) as f:
  126. delimiter = self.kwargs.get('delimiter', ',')
  127. reader = csv.reader(f, delimiter=delimiter)
  128. header = next(reader)
  129. column_data = self.kwargs.get('column_data', 'text')
  130. if column_data not in header:
  131. message = f'Column `{column_data}` does not exist in the header: {header}'
  132. raise FileParseException(filename, 1, message)
  133. for line_num, row in enumerate(reader, start=2):
  134. row = dict(zip(header, row))
  135. yield self.from_row(filename, row, line_num)
  136. class JSONDataset(Dataset):
  137. def load(self, filename: str) -> Iterator[Record]:
  138. encoding = self.detect_encoding(filename)
  139. with open(filename, encoding=encoding) as f:
  140. dataset = json.load(f)
  141. for line_num, row in enumerate(dataset, start=1):
  142. yield self.from_row(filename, row, line_num)
  143. class JSONLDataset(Dataset):
  144. def load(self, filename: str) -> Iterator[Record]:
  145. encoding = self.detect_encoding(filename)
  146. with open(filename, encoding=encoding) as f:
  147. for line_num, line in enumerate(f, start=1):
  148. row = json.loads(line)
  149. yield self.from_row(filename, row, line_num)
  150. class ExcelDataset(Dataset):
  151. def load(self, filename: str) -> Iterator[Record]:
  152. records = pyexcel.iget_records(file_name=filename)
  153. for line_num, row in enumerate(records, start=1):
  154. yield self.from_row(filename, row, line_num)
  155. class FastTextDataset(Dataset):
  156. def load(self, filename: str) -> Iterator[Record]:
  157. encoding = self.detect_encoding(filename)
  158. with open(filename, encoding=encoding) as f:
  159. for line_num, line in enumerate(f, start=1):
  160. labels = []
  161. tokens = []
  162. for token in line.rstrip().split(' '):
  163. if token.startswith('__label__'):
  164. if token == '__label__':
  165. message = 'Label name is empty.'
  166. raise FileParseException(filename, line_num, message)
  167. label_name = token[len('__label__'):]
  168. labels.append(self.label_class.parse(label_name))
  169. else:
  170. tokens.append(token)
  171. text = ' '.join(tokens)
  172. data = self.data_class.parse(filename=filename, text=text)
  173. record = Record(data=data, label=labels)
  174. yield record
  175. class CoNLLDataset(Dataset):
  176. def load(self, filename: str) -> Iterator[Record]:
  177. encoding = self.detect_encoding(filename)
  178. with open(filename, encoding=encoding) as f:
  179. words, tags = [], []
  180. delimiter = self.kwargs.get('delimiter', ' ')
  181. for line_num, line in enumerate(f, start=1):
  182. line = line.rstrip()
  183. if line:
  184. tokens = line.split('\t')
  185. if len(tokens) != 2:
  186. message = 'A line must be separated by tab and has two columns.'
  187. raise FileParseException(filename, line_num, message)
  188. word, tag = tokens
  189. words.append(word)
  190. tags.append(tag)
  191. else:
  192. text = delimiter.join(words)
  193. data = self.data_class.parse(filename=filename, text=text)
  194. labels = self.get_label(words, tags, delimiter)
  195. record = Record(data=data, label=labels)
  196. yield record
  197. words, tags = [], []
  198. def get_scheme(self, scheme: str):
  199. mapping = {
  200. 'IOB2': IOB2,
  201. 'IOE2': IOE2,
  202. 'IOBES': IOBES,
  203. 'BILOU': BILOU
  204. }
  205. return mapping[scheme]
  206. def get_label(self, words: List[str], tags: List[str], delimiter: str) -> List[Label]:
  207. scheme = self.get_scheme(self.kwargs.get('scheme', 'IOB2'))
  208. tokens = Tokens(tags, scheme)
  209. labels = []
  210. for entity in tokens.entities:
  211. text = delimiter.join(words[:entity.start])
  212. start = len(text) + len(delimiter) if text else len(text)
  213. chunk = words[entity.start: entity.end]
  214. text = delimiter.join(chunk)
  215. end = start + len(text)
  216. labels.append(self.label_class.parse((start, end, entity.tag)))
  217. return labels