import csv
import io
import json
from typing import Dict, Iterator, List, Optional, Type

import pydantic.error_wrappers
import pyexcel
from chardet.universaldetector import UniversalDetector
from seqeval.scheme import BILOU, IOB2, IOBES, IOE2, Tokens

from .data import BaseData
from .exception import FileParseException
from .label import Label
from .labels import Labels

class Record:
    """A single parsed example: one data payload plus its labels."""

    def __init__(self,
                 data: BaseData,
                 label: Optional[List[Label]] = None):
        if label is None:
            label = []
        self._data = data
        self._label = label

    def __str__(self):
        return f'{self._data}\t{self._label}'

    @property
    def data(self):
        return self._data.dict()

    def annotation(self, mapping: Dict[str, int]):
        labels = Labels(self._label)
        labels = labels.replace_label(mapping)
        return labels.dict()

    @property
    def label(self):
        return [
            {
                'text': label.name
            } for label in self._label
            if label.has_name() and label.name
        ]
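
# Usage sketch for Record (`TextData` is a hypothetical BaseData subclass;
# the real ones are defined in .data):
#
#     record = Record(data=TextData(text='hello'))
#     record.data   # the payload as a dict, e.g. {'text': 'hello', ...}
#     record.label  # [] -- only named labels are exported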

class Dataset:

    def __init__(self,
                 filenames: List[str],
                 data_class: Type[BaseData],
                 label_class: Type[Label],
                 encoding: Optional[str] = None,
                 **kwargs):
        self.filenames = filenames
        self.data_class = data_class
        self.label_class = label_class
        self.encoding = encoding
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[Record]:
        for filename in self.filenames:
            try:
                yield from self.load(filename)
            except UnicodeDecodeError as err:
                message = str(err)
                raise FileParseException(filename, line_num=-1, message=message) from err

    def load(self, filename: str) -> Iterator[Record]:
        """Loads a file content."""
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            data = self.data_class.parse(filename=filename, text=f.read())
            record = Record(data=data)
            yield record

    def detect_encoding(self, filename: str, buffer_size=io.DEFAULT_BUFFER_SIZE):
        # Any encoding other than the sentinel 'Auto' is passed through as-is.
        if self.encoding != 'Auto':
            return self.encoding
        with open(filename, 'rb') as f:
            detector = UniversalDetector()
            while True:
                read = f.read(buffer_size)
                detector.feed(read)
                # Also stop at EOF; otherwise a file the detector never
                # becomes confident about would loop forever.
                if not read or detector.done:
                    break
            detector.close()
        # close() finalizes the result; fall back to UTF-8 if the encoding
        # could not be determined.
        return detector.result['encoding'] or 'utf-8'

    def from_row(self, filename: str, row: Dict, line_num: int) -> Record:
        column_data = self.kwargs.get('column_data', 'text')
        if column_data not in row:
            message = f'{column_data} does not exist.'
            raise FileParseException(filename, line_num, message)
        text = row.pop(column_data)
        label = row.pop(self.kwargs.get('column_label', 'label'), [])
        label = [label] if isinstance(label, str) else label
        try:
            label = [self.label_class.parse(o) for o in label]
        except pydantic.error_wrappers.ValidationError:
            # Drop labels that fail validation rather than failing the row.
            label = []
        # Whatever remains in the row is kept as metadata.
        data = self.data_class.parse(text=text, filename=filename, meta=row)
        record = Record(data=data, label=label)
        return record
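
# Each subclass below overrides `load` for one upload format; encoding
# detection, column mapping, and label parsing are inherited from Dataset.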

class FileBaseDataset(Dataset):

    def load(self, filename: str) -> Iterator[Record]:
        data = self.data_class.parse(filename=filename)
        record = Record(data=data)
        yield record

class TextFileDataset(Dataset):

    def load(self, filename: str) -> Iterator[Record]:
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            data = self.data_class.parse(filename=filename, text=f.read())
            record = Record(data=data)
            yield record

class TextLineDataset(Dataset):

    def load(self, filename: str) -> Iterator[Record]:
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            for line in f:
                data = self.data_class.parse(filename=filename, text=line.rstrip())
                record = Record(data=data)
                yield record

class CsvDataset(Dataset):

    def load(self, filename: str) -> Iterator[Record]:
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            delimiter = self.kwargs.get('delimiter', ',')
            reader = csv.reader(f, delimiter=delimiter)
            header = next(reader)

            column_data = self.kwargs.get('column_data', 'text')
            if column_data not in header:
                message = f'Column `{column_data}` does not exist in the header: {header}'
                raise FileParseException(filename, 1, message)

            for line_num, row in enumerate(reader, start=2):
                row = dict(zip(header, row))
                yield self.from_row(filename, row, line_num)
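
# With the defaults (delimiter=',', column_data='text', column_label='label'),
# CsvDataset accepts input shaped like this; any extra columns end up in the
# record's metadata via `from_row`:
#
#     text,label
#     "I love it",positive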

class JSONDataset(Dataset):

    def load(self, filename: str) -> Iterator[Record]:
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            dataset = json.load(f)
            for line_num, row in enumerate(dataset, start=1):
                yield self.from_row(filename, row, line_num)

class JSONLDataset(Dataset):

    def load(self, filename: str) -> Iterator[Record]:
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            for line_num, line in enumerate(f, start=1):
                row = json.loads(line)
                yield self.from_row(filename, row, line_num)

class ExcelDataset(Dataset):

    def load(self, filename: str) -> Iterator[Record]:
        records = pyexcel.iget_records(file_name=filename)
        for line_num, row in enumerate(records, start=1):
            yield self.from_row(filename, row, line_num)

class FastTextDataset(Dataset):

    def load(self, filename: str) -> Iterator[Record]:
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            for line_num, line in enumerate(f, start=1):
                labels = []
                tokens = []
                for token in line.rstrip().split(' '):
                    if token.startswith('__label__'):
                        if token == '__label__':
                            message = 'Label name is empty.'
                            raise FileParseException(filename, line_num, message)
                        label_name = token[len('__label__'):]
                        labels.append(self.label_class.parse(label_name))
                    else:
                        tokens.append(token)
                text = ' '.join(tokens)
                data = self.data_class.parse(filename=filename, text=text)
                record = Record(data=data, label=labels)
                yield record
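
# A fastText line mixes `__label__`-prefixed labels with plain tokens, e.g.:
#
#     __label__positive __label__en I love it
#
# parses to the text 'I love it' with the labels 'positive' and 'en'.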

class CoNLLDataset(Dataset):

    def load(self, filename: str) -> Iterator[Record]:
        encoding = self.detect_encoding(filename)
        with open(filename, encoding=encoding) as f:
            words, tags = [], []
            delimiter = self.kwargs.get('delimiter', ' ')
            for line_num, line in enumerate(f, start=1):
                line = line.rstrip()
                if line:
                    tokens = line.split('\t')
                    if len(tokens) != 2:
                        message = 'Each line must have exactly two tab-separated columns.'
                        raise FileParseException(filename, line_num, message)
                    word, tag = tokens
                    words.append(word)
                    tags.append(tag)
                else:
                    text = delimiter.join(words)
                    data = self.data_class.parse(filename=filename, text=text)
                    labels = self.get_label(words, tags, delimiter)
                    record = Record(data=data, label=labels)
                    yield record
                    words, tags = [], []
            # Flush the final sentence when the file does not end with a
            # blank line, so the last record is not silently dropped.
            if words:
                text = delimiter.join(words)
                data = self.data_class.parse(filename=filename, text=text)
                labels = self.get_label(words, tags, delimiter)
                yield Record(data=data, label=labels)
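
    # Expected input: two tab-separated columns (word and tag) per line, with
    # a blank line between sentences, tagged in one of the schemes below.
    # For example, in IOB2 (the columns separated by a tab character):
    #
    #     EU      B-ORG
    #     rejects O
    #     German  B-MISC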

    def get_scheme(self, scheme: str):
        mapping = {
            'IOB2': IOB2,
            'IOE2': IOE2,
            'IOBES': IOBES,
            'BILOU': BILOU
        }
        return mapping[scheme]

    def get_label(self, words: List[str], tags: List[str], delimiter: str) -> List[Label]:
        scheme = self.get_scheme(self.kwargs.get('scheme', 'IOB2'))
        tokens = Tokens(tags, scheme)
        labels = []
        for entity in tokens.entities:
            # Character offsets of the entity within the delimiter-joined text.
            text = delimiter.join(words[:entity.start])
            start = len(text) + len(delimiter) if text else len(text)
            chunk = words[entity.start: entity.end]
            text = delimiter.join(chunk)
            end = start + len(text)
            labels.append(self.label_class.parse((start, end, entity.tag)))
        return labels
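
# Putting it together -- a sketch, not part of this module. `TextData` and
# `CategoryLabel` stand in for concrete BaseData/Label subclasses defined
# elsewhere in this package:
#
#     dataset = CsvDataset(
#         filenames=['reviews.csv'],
#         data_class=TextData,
#         label_class=CategoryLabel,
#         encoding='Auto',          # 'Auto' triggers chardet-based detection
#         column_data='text',
#         column_label='label',
#     )
#     for record in dataset:
#         print(record.data, record.label)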