You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

509 lines
16 KiB

5 years ago
5 years ago
5 years ago
  1. import csv
  2. import io
  3. import itertools
  4. import json
  5. import re
  6. from collections import defaultdict
  7. import conllu
  8. from chardet import UniversalDetector
  9. from django.db import transaction
  10. from django.conf import settings
  11. from colour import Color
  12. import pyexcel
  13. from rest_framework.renderers import JSONRenderer
  14. from seqeval.metrics.sequence_labeling import get_entities
  15. from .exceptions import FileParseException
  16. from .models import Label
  17. from .serializers import DocumentSerializer, LabelSerializer
  18. def extract_label(tag):
  19. ptn = re.compile(r'(B|I|E|S)-(.+)')
  20. m = ptn.match(tag)
  21. if m:
  22. return m.groups()[1]
  23. else:
  24. return tag
  25. class BaseStorage(object):
  26. def __init__(self, data, project):
  27. self.data = data
  28. self.project = project
  29. @transaction.atomic
  30. def save(self, user):
  31. raise NotImplementedError()
  32. def save_doc(self, data):
  33. serializer = DocumentSerializer(data=data, many=True)
  34. serializer.is_valid(raise_exception=True)
  35. doc = serializer.save(project=self.project)
  36. return doc
  37. def save_label(self, data):
  38. serializer = LabelSerializer(data=data, many=True)
  39. serializer.is_valid(raise_exception=True)
  40. label = serializer.save(project=self.project)
  41. return label
  42. def save_annotation(self, data, user):
  43. annotation_serializer = self.project.get_annotation_serializer()
  44. serializer = annotation_serializer(data=data, many=True)
  45. serializer.is_valid(raise_exception=True)
  46. annotation = serializer.save(user=user)
  47. return annotation
  48. @classmethod
  49. def extract_label(cls, data):
  50. return [d.get('labels', []) for d in data]
  51. @classmethod
  52. def exclude_created_labels(cls, labels, created):
  53. return [label for label in labels if label not in created]
  54. @classmethod
  55. def to_serializer_format(cls, labels, created):
  56. existing_shortkeys = {(label.suffix_key, label.prefix_key)
  57. for label in created.values()}
  58. serializer_labels = []
  59. for label in sorted(labels):
  60. serializer_label = {'text': label}
  61. shortkey = cls.get_shortkey(label, existing_shortkeys)
  62. if shortkey:
  63. serializer_label['suffix_key'] = shortkey[0]
  64. serializer_label['prefix_key'] = shortkey[1]
  65. existing_shortkeys.add(shortkey)
  66. background_color = Color(pick_for=label)
  67. text_color = Color('white') if background_color.get_luminance() < 0.5 else Color('black')
  68. serializer_label['background_color'] = background_color.hex
  69. serializer_label['text_color'] = text_color.hex
  70. serializer_labels.append(serializer_label)
  71. return serializer_labels
  72. @classmethod
  73. def get_shortkey(cls, label, existing_shortkeys):
  74. model_prefix_keys = [key for (key, _) in Label.PREFIX_KEYS]
  75. prefix_keys = [None] + model_prefix_keys
  76. model_suffix_keys = {key for (key, _) in Label.SUFFIX_KEYS}
  77. suffix_keys = [key for key in label.lower() if key in model_suffix_keys]
  78. for shortkey in itertools.product(suffix_keys, prefix_keys):
  79. if shortkey not in existing_shortkeys:
  80. return shortkey
  81. return None
  82. @classmethod
  83. def update_saved_labels(cls, saved, new):
  84. for label in new:
  85. saved[label.text] = label
  86. return saved
  87. class PlainStorage(BaseStorage):
  88. @transaction.atomic
  89. def save(self, user):
  90. for text in self.data:
  91. self.save_doc(text)
  92. class ClassificationStorage(BaseStorage):
  93. """Store json for text classification.
  94. The format is as follows:
  95. {"text": "Python is awesome!", "labels": ["positive"]}
  96. ...
  97. """
  98. @transaction.atomic
  99. def save(self, user):
  100. saved_labels = {label.text: label for label in self.project.labels.all()}
  101. for data in self.data:
  102. docs = self.save_doc(data)
  103. labels = self.extract_label(data)
  104. unique_labels = self.extract_unique_labels(labels)
  105. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  106. unique_labels = self.to_serializer_format(unique_labels, saved_labels)
  107. new_labels = self.save_label(unique_labels)
  108. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  109. annotations = self.make_annotations(docs, labels, saved_labels)
  110. self.save_annotation(annotations, user)
  111. @classmethod
  112. def extract_unique_labels(cls, labels):
  113. return set(itertools.chain(*labels))
  114. @classmethod
  115. def make_annotations(cls, docs, labels, saved_labels):
  116. annotations = []
  117. for doc, label in zip(docs, labels):
  118. for name in label:
  119. label = saved_labels[name]
  120. annotations.append({'document': doc.id, 'label': label.id})
  121. return annotations
  122. class SequenceLabelingStorage(BaseStorage):
  123. """Upload jsonl for sequence labeling.
  124. The format is as follows:
  125. {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
  126. ...
  127. """
  128. @transaction.atomic
  129. def save(self, user):
  130. saved_labels = {label.text: label for label in self.project.labels.all()}
  131. for data in self.data:
  132. docs = self.save_doc(data)
  133. labels = self.extract_label(data)
  134. unique_labels = self.extract_unique_labels(labels)
  135. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  136. unique_labels = self.to_serializer_format(unique_labels, saved_labels)
  137. new_labels = self.save_label(unique_labels)
  138. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  139. annotations = self.make_annotations(docs, labels, saved_labels)
  140. self.save_annotation(annotations, user)
  141. @classmethod
  142. def extract_unique_labels(cls, labels):
  143. return set([label for _, _, label in itertools.chain(*labels)])
  144. @classmethod
  145. def make_annotations(cls, docs, labels, saved_labels):
  146. annotations = []
  147. for doc, spans in zip(docs, labels):
  148. for span in spans:
  149. start_offset, end_offset, name = span
  150. label = saved_labels[name]
  151. annotations.append({'document': doc.id,
  152. 'label': label.id,
  153. 'start_offset': start_offset,
  154. 'end_offset': end_offset})
  155. return annotations
  156. class Seq2seqStorage(BaseStorage):
  157. """Store json for seq2seq.
  158. The format is as follows:
  159. {"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
  160. ...
  161. """
  162. @transaction.atomic
  163. def save(self, user):
  164. for data in self.data:
  165. doc = self.save_doc(data)
  166. labels = self.extract_label(data)
  167. annotations = self.make_annotations(doc, labels)
  168. self.save_annotation(annotations, user)
  169. @classmethod
  170. def make_annotations(cls, docs, labels):
  171. annotations = []
  172. for doc, texts in zip(docs, labels):
  173. for text in texts:
  174. annotations.append({'document': doc.id, 'text': text})
  175. return annotations
  176. class FileParser(object):
  177. def parse(self, file):
  178. raise NotImplementedError()
  179. @staticmethod
  180. def encode_metadata(data):
  181. return json.dumps(data, ensure_ascii=False)
  182. class CoNLLParser(FileParser):
  183. """Uploads CoNLL format file.
  184. The file format is tab-separated values.
  185. A blank line is required at the end of a sentence.
  186. For example:
  187. ```
  188. EU B-ORG
  189. rejects O
  190. German B-MISC
  191. call O
  192. to O
  193. boycott O
  194. British B-MISC
  195. lamb O
  196. . O
  197. Peter B-PER
  198. Blackburn I-PER
  199. ...
  200. ```
  201. """
  202. def parse(self, file):
  203. data = []
  204. file = EncodedIO(file)
  205. file = io.TextIOWrapper(file, encoding=file.encoding)
  206. # Add check exception
  207. field_parsers = {
  208. "ne": lambda line, i: conllu.parser.parse_nullable_value(line[i]),
  209. }
  210. gen_parser = conllu.parse_incr(
  211. file,
  212. fields=("form", "ne"),
  213. field_parsers=field_parsers
  214. )
  215. try:
  216. for sentence in gen_parser:
  217. if not sentence:
  218. continue
  219. if len(data) >= settings.IMPORT_BATCH_SIZE:
  220. yield data
  221. data = []
  222. words, labels = [], []
  223. for item in sentence:
  224. word = item.get("form")
  225. tag = item.get("ne")
  226. if tag is not None:
  227. char_left = sum(map(len, words)) + len(words)
  228. char_right = char_left + len(word)
  229. span = [char_left, char_right, tag]
  230. labels.append(span)
  231. words.append(word)
  232. # Create and add JSONL
  233. data.append({'text': ' '.join(words), 'labels': labels})
  234. except conllu.parser.ParseException as e:
  235. raise FileParseException(line_num=-1, line=str(e))
  236. if data:
  237. yield data
  238. class PlainTextParser(FileParser):
  239. """Uploads plain text.
  240. The file format is as follows:
  241. ```
  242. EU rejects German call to boycott British lamb.
  243. President Obama is speaking at the White House.
  244. ...
  245. ```
  246. """
  247. def parse(self, file):
  248. file = EncodedIO(file)
  249. file = io.TextIOWrapper(file, encoding=file.encoding)
  250. while True:
  251. batch = list(itertools.islice(file, settings.IMPORT_BATCH_SIZE))
  252. if not batch:
  253. break
  254. yield [{'text': line.strip()} for line in batch]
  255. class CSVParser(FileParser):
  256. """Uploads csv file.
  257. The file format is comma separated values.
  258. Column names are required at the top of a file.
  259. For example:
  260. ```
  261. text, label
  262. "EU rejects German call to boycott British lamb.",Politics
  263. "President Obama is speaking at the White House.",Politics
  264. "He lives in Newark, Ohio.",Other
  265. ...
  266. ```
  267. """
  268. def parse(self, file):
  269. file = EncodedIO(file)
  270. file = io.TextIOWrapper(file, encoding=file.encoding)
  271. reader = csv.reader(file)
  272. yield from ExcelParser.parse_excel_csv_reader(reader)
  273. class ExcelParser(FileParser):
  274. def parse(self, file):
  275. excel_book = pyexcel.iget_book(file_type="xlsx", file_content=file.read())
  276. # Handle multiple sheets
  277. for sheet_name in excel_book.sheet_names():
  278. reader = excel_book[sheet_name].to_array()
  279. yield from self.parse_excel_csv_reader(reader)
  280. @staticmethod
  281. def parse_excel_csv_reader(reader):
  282. columns = next(reader)
  283. data = []
  284. if len(columns) == 1 and columns[0] != 'text':
  285. data.append({'text': columns[0]})
  286. for i, row in enumerate(reader, start=2):
  287. if len(data) >= settings.IMPORT_BATCH_SIZE:
  288. yield data
  289. data = []
  290. # Only text column
  291. if len(row) == len(columns) and len(row) == 1:
  292. data.append({'text': row[0]})
  293. # Text, labels and metadata columns
  294. elif len(row) == len(columns) and len(row) >= 2:
  295. datum = dict(zip(columns, row))
  296. text, label = datum.pop('text'), datum.pop('label')
  297. meta = FileParser.encode_metadata(datum)
  298. j = {'text': text, 'labels': [label], 'meta': meta}
  299. data.append(j)
  300. else:
  301. raise FileParseException(line_num=i, line=row)
  302. if data:
  303. yield data
  304. class JSONParser(FileParser):
  305. def parse(self, file):
  306. file = EncodedIO(file)
  307. file = io.TextIOWrapper(file, encoding=file.encoding)
  308. data = []
  309. for i, line in enumerate(file, start=1):
  310. if len(data) >= settings.IMPORT_BATCH_SIZE:
  311. yield data
  312. data = []
  313. try:
  314. j = json.loads(line)
  315. j['meta'] = FileParser.encode_metadata(j.get('meta', {}))
  316. data.append(j)
  317. except json.decoder.JSONDecodeError:
  318. raise FileParseException(line_num=i, line=line)
  319. if data:
  320. yield data
class JSONLRenderer(JSONRenderer):
    # Streams each record as one JSON line (JSONL) instead of a single array.
    def render(self, data, accepted_media_type=None, renderer_context=None):
        """
        Yield `data` as JSON Lines: one JSON document per record, each
        terminated by a newline.

        NOTE(review): because this method contains `yield`, calling it
        always returns a generator; the `return bytes()` below merely ends
        the (empty) generator rather than handing a bytestring back to the
        caller, despite what the original docstring claimed.
        """
        if data is None:
            return bytes()
        if not isinstance(data, list):
            # A single object is rendered as a one-line stream.
            data = [data]
        for d in data:
            yield json.dumps(d,
                             cls=self.encoder_class,
                             ensure_ascii=self.ensure_ascii,
                             allow_nan=not self.strict) + '\n'
  335. class JSONPainter(object):
  336. def paint(self, documents):
  337. serializer = DocumentSerializer(documents, many=True)
  338. data = []
  339. for d in serializer.data:
  340. d['meta'] = json.loads(d['meta'])
  341. for a in d['annotations']:
  342. a.pop('id')
  343. a.pop('prob')
  344. a.pop('document')
  345. data.append(d)
  346. return data
  347. @staticmethod
  348. def paint_labels(documents, labels):
  349. serializer_labels = LabelSerializer(labels, many=True)
  350. serializer = DocumentSerializer(documents, many=True)
  351. data = []
  352. for d in serializer.data:
  353. labels = []
  354. for a in d['annotations']:
  355. label_obj = [x for x in serializer_labels.data if x['id'] == a['label']][0]
  356. label_text = label_obj['text']
  357. label_start = a['start_offset']
  358. label_end = a['end_offset']
  359. labels.append([label_start, label_end, label_text])
  360. d.pop('annotations')
  361. d['labels'] = labels
  362. d['meta'] = json.loads(d['meta'])
  363. data.append(d)
  364. return data
  365. class CSVPainter(JSONPainter):
  366. def paint(self, documents):
  367. data = super().paint(documents)
  368. res = []
  369. for d in data:
  370. annotations = d.pop('annotations')
  371. for a in annotations:
  372. res.append({**d, **a})
  373. return res
  374. def iterable_to_io(iterable, buffer_size=io.DEFAULT_BUFFER_SIZE):
  375. """See https://stackoverflow.com/a/20260030/3817588."""
  376. class IterStream(io.RawIOBase):
  377. def __init__(self):
  378. self.leftover = None
  379. def readable(self):
  380. return True
  381. def readinto(self, b):
  382. try:
  383. l = len(b) # We're supposed to return at most this much
  384. chunk = self.leftover or next(iterable)
  385. output, self.leftover = chunk[:l], chunk[l:]
  386. b[:len(output)] = output
  387. return len(output)
  388. except StopIteration:
  389. return 0 # indicate EOF
  390. return io.BufferedReader(IterStream(), buffer_size=buffer_size)
  391. class EncodedIO(io.RawIOBase):
  392. def __init__(self, fobj, buffer_size=io.DEFAULT_BUFFER_SIZE, default_encoding='utf-8'):
  393. buffer = b''
  394. detector = UniversalDetector()
  395. while True:
  396. read = fobj.read(buffer_size)
  397. detector.feed(read)
  398. buffer += read
  399. if detector.done or len(read) < buffer_size:
  400. break
  401. if detector.done:
  402. self.encoding = detector.result['encoding']
  403. else:
  404. self.encoding = default_encoding
  405. self._fobj = fobj
  406. self._buffer = buffer
  407. def readable(self):
  408. return self._fobj.readable()
  409. def readinto(self, b):
  410. l = len(b)
  411. chunk = self._buffer or self._fobj.read(l)
  412. output, self._buffer = chunk[:l], chunk[l:]
  413. b[:len(output)] = output
  414. return len(output)