You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

634 lines
20 KiB

5 years ago
5 years ago
5 years ago
  1. import base64
  2. import csv
  3. import io
  4. import itertools
  5. import json
  6. import mimetypes
  7. import re
  8. from collections import defaultdict
  9. import conllu
  10. import pyexcel
  11. from chardet import UniversalDetector
  12. from colour import Color
  13. from django.conf import settings
  14. from django.db import transaction
  15. from rest_framework.renderers import BaseRenderer, JSONRenderer
  16. from seqeval.metrics.sequence_labeling import get_entities
  17. from .exceptions import FileParseException
  18. from .models import Label
  19. from .serializers import DocumentSerializer, LabelSerializer
  20. def extract_label(tag):
  21. ptn = re.compile(r'(B|I|E|S)-(.+)')
  22. m = ptn.match(tag)
  23. if m:
  24. return m.groups()[1]
  25. else:
  26. return tag
  27. class BaseStorage(object):
  28. def __init__(self, data, project):
  29. self.data = data
  30. self.project = project
  31. @transaction.atomic
  32. def save(self, user):
  33. raise NotImplementedError()
  34. def save_doc(self, data):
  35. serializer = DocumentSerializer(data=data, many=True)
  36. serializer.is_valid(raise_exception=True)
  37. doc = serializer.save(project=self.project)
  38. return doc
  39. def save_label(self, data):
  40. serializer = LabelSerializer(data=data, many=True)
  41. serializer.is_valid(raise_exception=True)
  42. label = serializer.save(project=self.project)
  43. return label
  44. def save_annotation(self, data, user):
  45. annotation_serializer = self.project.get_annotation_serializer()
  46. serializer = annotation_serializer(data=data, many=True)
  47. serializer.is_valid(raise_exception=True)
  48. annotation = serializer.save(user=user)
  49. return annotation
  50. @classmethod
  51. def extract_label(cls, data):
  52. return [d.get('labels', []) for d in data]
  53. @classmethod
  54. def exclude_created_labels(cls, labels, created):
  55. return [label for label in labels if label not in created]
  56. @classmethod
  57. def to_serializer_format(cls, labels, created):
  58. existing_shortkeys = {(label.suffix_key, label.prefix_key)
  59. for label in created.values()}
  60. serializer_labels = []
  61. for label in sorted(labels):
  62. serializer_label = {'text': label}
  63. shortkey = cls.get_shortkey(label, existing_shortkeys)
  64. if shortkey:
  65. serializer_label['suffix_key'] = shortkey[0]
  66. serializer_label['prefix_key'] = shortkey[1]
  67. existing_shortkeys.add(shortkey)
  68. background_color = Color(pick_for=label)
  69. text_color = Color('white') if background_color.get_luminance() < 0.5 else Color('black')
  70. serializer_label['background_color'] = background_color.hex
  71. serializer_label['text_color'] = text_color.hex
  72. serializer_labels.append(serializer_label)
  73. return serializer_labels
  74. @classmethod
  75. def get_shortkey(cls, label, existing_shortkeys):
  76. model_prefix_keys = [key for (key, _) in Label.PREFIX_KEYS]
  77. prefix_keys = [None] + model_prefix_keys
  78. model_suffix_keys = {key for (key, _) in Label.SUFFIX_KEYS}
  79. suffix_keys = [key for key in label.lower() if key in model_suffix_keys]
  80. for shortkey in itertools.product(suffix_keys, prefix_keys):
  81. if shortkey not in existing_shortkeys:
  82. return shortkey
  83. return None
  84. @classmethod
  85. def update_saved_labels(cls, saved, new):
  86. for label in new:
  87. saved[label.text] = label
  88. return saved
  89. class PlainStorage(BaseStorage):
  90. @transaction.atomic
  91. def save(self, user):
  92. for text in self.data:
  93. self.save_doc(text)
  94. class ClassificationStorage(BaseStorage):
  95. """Store json for text classification.
  96. The format is as follows:
  97. {"text": "Python is awesome!", "labels": ["positive"]}
  98. ...
  99. """
  100. @transaction.atomic
  101. def save(self, user):
  102. saved_labels = {label.text: label for label in self.project.labels.all()}
  103. for data in self.data:
  104. docs = self.save_doc(data)
  105. labels = self.extract_label(data)
  106. unique_labels = self.extract_unique_labels(labels)
  107. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  108. unique_labels = self.to_serializer_format(unique_labels, saved_labels)
  109. new_labels = self.save_label(unique_labels)
  110. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  111. annotations = self.make_annotations(docs, labels, saved_labels)
  112. self.save_annotation(annotations, user)
  113. @classmethod
  114. def extract_unique_labels(cls, labels):
  115. return set(itertools.chain(*labels))
  116. @classmethod
  117. def make_annotations(cls, docs, labels, saved_labels):
  118. annotations = []
  119. for doc, label in zip(docs, labels):
  120. for name in label:
  121. label = saved_labels[name]
  122. annotations.append({'document': doc.id, 'label': label.id})
  123. return annotations
  124. class SequenceLabelingStorage(BaseStorage):
  125. """Upload jsonl for sequence labeling.
  126. The format is as follows:
  127. {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
  128. ...
  129. """
  130. @transaction.atomic
  131. def save(self, user):
  132. saved_labels = {label.text: label for label in self.project.labels.all()}
  133. for data in self.data:
  134. docs = self.save_doc(data)
  135. labels = self.extract_label(data)
  136. unique_labels = self.extract_unique_labels(labels)
  137. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  138. unique_labels = self.to_serializer_format(unique_labels, saved_labels)
  139. new_labels = self.save_label(unique_labels)
  140. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  141. annotations = self.make_annotations(docs, labels, saved_labels)
  142. self.save_annotation(annotations, user)
  143. @classmethod
  144. def extract_unique_labels(cls, labels):
  145. return set([label for _, _, label in itertools.chain(*labels)])
  146. @classmethod
  147. def make_annotations(cls, docs, labels, saved_labels):
  148. annotations = []
  149. for doc, spans in zip(docs, labels):
  150. for span in spans:
  151. start_offset, end_offset, name = span
  152. label = saved_labels[name]
  153. annotations.append({'document': doc.id,
  154. 'label': label.id,
  155. 'start_offset': start_offset,
  156. 'end_offset': end_offset})
  157. return annotations
  158. class Seq2seqStorage(BaseStorage):
  159. """Store json for seq2seq.
  160. The format is as follows:
  161. {"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
  162. ...
  163. """
  164. @transaction.atomic
  165. def save(self, user):
  166. for data in self.data:
  167. doc = self.save_doc(data)
  168. labels = self.extract_label(data)
  169. annotations = self.make_annotations(doc, labels)
  170. self.save_annotation(annotations, user)
  171. @classmethod
  172. def make_annotations(cls, docs, labels):
  173. annotations = []
  174. for doc, texts in zip(docs, labels):
  175. for text in texts:
  176. annotations.append({'document': doc.id, 'text': text})
  177. return annotations
  178. class Speech2textStorage(BaseStorage):
  179. """Store json for speech2text.
  180. The format is as follows:
  181. {"audio": "data:audio/mpeg;base64,...", "transcription": "こんにちは、世界!"}
  182. ...
  183. """
  184. @transaction.atomic
  185. def save(self, user):
  186. for data in self.data:
  187. for audio in data:
  188. audio['text'] = audio.pop('audio')
  189. doc = self.save_doc(data)
  190. annotations = self.make_annotations(doc, data)
  191. self.save_annotation(annotations, user)
  192. @classmethod
  193. def make_annotations(cls, docs, data):
  194. annotations = []
  195. for doc, datum in zip(docs, data):
  196. try:
  197. annotations.append({'document': doc.id, 'text': datum['transcription']})
  198. except KeyError:
  199. continue
  200. return annotations
  201. class FileParser(object):
  202. def parse(self, file):
  203. raise NotImplementedError()
  204. @staticmethod
  205. def encode_metadata(data):
  206. return json.dumps(data, ensure_ascii=False)
  207. class CoNLLParser(FileParser):
  208. """Uploads CoNLL format file.
  209. The file format is tab-separated values.
  210. A blank line is required at the end of a sentence.
  211. For example:
  212. ```
  213. EU B-ORG
  214. rejects O
  215. German B-MISC
  216. call O
  217. to O
  218. boycott O
  219. British B-MISC
  220. lamb O
  221. . O
  222. Peter B-PER
  223. Blackburn I-PER
  224. ...
  225. ```
  226. """
  227. def parse(self, file):
  228. data = []
  229. file = EncodedIO(file)
  230. file = io.TextIOWrapper(file, encoding=file.encoding)
  231. # Add check exception
  232. field_parsers = {
  233. "ne": lambda line, i: conllu.parser.parse_nullable_value(line[i]),
  234. }
  235. gen_parser = conllu.parse_incr(
  236. file,
  237. fields=("form", "ne"),
  238. field_parsers=field_parsers
  239. )
  240. try:
  241. for sentence in gen_parser:
  242. if not sentence:
  243. continue
  244. if len(data) >= settings.IMPORT_BATCH_SIZE:
  245. yield data
  246. data = []
  247. words, labels = [], []
  248. for item in sentence:
  249. word = item.get("form")
  250. tag = item.get("ne")
  251. if tag is not None:
  252. char_left = sum(map(len, words)) + len(words)
  253. char_right = char_left + len(word)
  254. span = [char_left, char_right, tag]
  255. labels.append(span)
  256. words.append(word)
  257. # Create and add JSONL
  258. data.append({'text': ' '.join(words), 'labels': labels})
  259. except conllu.parser.ParseException as e:
  260. raise FileParseException(line_num=-1, line=str(e))
  261. if data:
  262. yield data
  263. class PlainTextParser(FileParser):
  264. """Uploads plain text.
  265. The file format is as follows:
  266. ```
  267. EU rejects German call to boycott British lamb.
  268. President Obama is speaking at the White House.
  269. ...
  270. ```
  271. """
  272. def parse(self, file):
  273. file = EncodedIO(file)
  274. file = io.TextIOWrapper(file, encoding=file.encoding)
  275. while True:
  276. batch = list(itertools.islice(file, settings.IMPORT_BATCH_SIZE))
  277. if not batch:
  278. break
  279. yield [{'text': line.strip()} for line in batch]
  280. class CSVParser(FileParser):
  281. """Uploads csv file.
  282. The file format is comma separated values.
  283. Column names are required at the top of a file.
  284. For example:
  285. ```
  286. text, label
  287. "EU rejects German call to boycott British lamb.",Politics
  288. "President Obama is speaking at the White House.",Politics
  289. "He lives in Newark, Ohio.",Other
  290. ...
  291. ```
  292. """
  293. def parse(self, file):
  294. file = EncodedIO(file)
  295. file = io.TextIOWrapper(file, encoding=file.encoding)
  296. reader = csv.reader(file)
  297. yield from ExcelParser.parse_excel_csv_reader(reader)
  298. class ExcelParser(FileParser):
  299. def parse(self, file):
  300. excel_book = pyexcel.iget_book(file_type="xlsx", file_content=file.read())
  301. # Handle multiple sheets
  302. for sheet_name in excel_book.sheet_names():
  303. reader = excel_book[sheet_name].to_array()
  304. yield from self.parse_excel_csv_reader(reader)
  305. @staticmethod
  306. def parse_excel_csv_reader(reader):
  307. columns = next(reader)
  308. data = []
  309. if len(columns) == 1 and columns[0] != 'text':
  310. data.append({'text': columns[0]})
  311. for i, row in enumerate(reader, start=2):
  312. if len(data) >= settings.IMPORT_BATCH_SIZE:
  313. yield data
  314. data = []
  315. # Only text column
  316. if len(row) <= len(columns) and len(row) == 1:
  317. data.append({'text': row[0]})
  318. # Text, labels and metadata columns
  319. elif 2 <= len(row) <= len(columns):
  320. datum = dict(zip(columns, row))
  321. text, label = datum.pop('text'), datum.pop('label')
  322. meta = FileParser.encode_metadata(datum)
  323. if label != '':
  324. j = {'text': text, 'labels': [label], 'meta': meta}
  325. else:
  326. j = {'text': text, 'meta': meta}
  327. data.append(j)
  328. else:
  329. raise FileParseException(line_num=i, line=row)
  330. if data:
  331. yield data
  332. class JSONParser(FileParser):
  333. def parse(self, file):
  334. file = EncodedIO(file)
  335. file = io.TextIOWrapper(file, encoding=file.encoding)
  336. data = []
  337. for i, line in enumerate(file, start=1):
  338. if len(data) >= settings.IMPORT_BATCH_SIZE:
  339. yield data
  340. data = []
  341. try:
  342. j = json.loads(line)
  343. j['meta'] = FileParser.encode_metadata(j.get('meta', {}))
  344. data.append(j)
  345. except json.decoder.JSONDecodeError:
  346. raise FileParseException(line_num=i, line=line)
  347. if data:
  348. yield data
  349. class FastTextParser(FileParser):
  350. """
  351. Parse files in fastText format.
  352. Labels are marked with the __label__ prefix
  353. and the corresponding text comes afterwards in the same line
  354. For example:
  355. ```
  356. __label__dog poodle
  357. __label__house mansion
  358. ```
  359. """
  360. def parse(self, file):
  361. file = EncodedIO(file)
  362. file = io.TextIOWrapper(file, encoding=file.encoding)
  363. data = []
  364. for i, line in enumerate(file, start=0):
  365. if len(data) >= settings.IMPORT_BATCH_SIZE:
  366. yield data
  367. data = []
  368. # Search labels and text, check correct syntax and append
  369. labels = []
  370. text = []
  371. for token in line.rstrip().split(" "):
  372. if token.startswith('__label__'):
  373. if token == '__label__':
  374. raise FileParseException(line_num=i, line=line)
  375. labels.append(token[len('__label__'):])
  376. else:
  377. text.append(token)
  378. # Check if text for labels is given
  379. if not text:
  380. raise FileParseException(line_num=i, line=line)
  381. data.append({'text': " ".join(text), 'labels': labels})
  382. if data:
  383. yield data
  384. class AudioParser(FileParser):
  385. def parse(self, file):
  386. file_type, _ = mimetypes.guess_type(file.name, strict=False)
  387. if not file_type:
  388. raise FileParseException(line_num=1, line='Unable to guess file type')
  389. audio = base64.b64encode(file.read())
  390. yield [{
  391. 'audio': f'data:{file_type};base64,{audio.decode("ascii")}',
  392. 'meta': json.dumps({'filename': file.name}),
  393. }]
class JSONLRenderer(JSONRenderer):
    # Streams export data as JSON Lines: one JSON object per line.

    def render(self, data, accepted_media_type=None, renderer_context=None):
        """
        Lazily render `data` as JSON Lines, one JSON string per item.

        NOTE(review): the `yield` below makes this a generator function, so
        the early `return bytes()` only terminates an (empty) generator —
        it does not hand a bytestring back to the caller — and each yielded
        value is a `str`, not bytes, despite the base class contract.
        Callers appear to iterate the result; confirm before changing.
        """
        if data is None:
            return bytes()
        if not isinstance(data, list):
            data = [data]
        for d in data:
            yield json.dumps(d,
                             cls=self.encoder_class,
                             ensure_ascii=self.ensure_ascii,
                             allow_nan=not self.strict) + '\n'
  408. class FastTextPainter(object):
  409. @staticmethod
  410. def paint_labels(documents, labels):
  411. serializer = DocumentSerializer(documents, many=True)
  412. serializer_labels = LabelSerializer(labels, many=True)
  413. data = []
  414. for d in serializer.data:
  415. labels = []
  416. for a in d['annotations']:
  417. label_obj = [x for x in serializer_labels.data if x['id'] == a['label']][0]
  418. labels.append('__label__{}'.format(label_obj['text'].replace(' ', '_')))
  419. text = d['text'].replace('\n', ' ')
  420. if labels:
  421. data.append('{} {}'.format(' '.join(labels), text))
  422. else:
  423. data.append(text)
  424. return data
  425. class PlainTextRenderer(BaseRenderer):
  426. media_type = 'text/plain'
  427. format = 'txt'
  428. charset = 'utf-8'
  429. def render(self, data, accepted_media_type=None, renderer_context=None):
  430. if data is None:
  431. return bytes()
  432. if not isinstance(data, list):
  433. data = [data]
  434. buffer = io.BytesIO()
  435. for d in data:
  436. buffer.write((d + '\n').encode(self.charset))
  437. return buffer.getvalue()
  438. class JSONPainter(object):
  439. def paint(self, documents):
  440. serializer = DocumentSerializer(documents, many=True)
  441. data = []
  442. for d in serializer.data:
  443. d['meta'] = json.loads(d['meta'])
  444. for a in d['annotations']:
  445. a.pop('id')
  446. a.pop('prob')
  447. a.pop('document')
  448. data.append(d)
  449. return data
  450. @staticmethod
  451. def paint_labels(documents, labels):
  452. serializer_labels = LabelSerializer(labels, many=True)
  453. serializer = DocumentSerializer(documents, many=True)
  454. data = []
  455. for d in serializer.data:
  456. labels = []
  457. for a in d['annotations']:
  458. label_obj = [x for x in serializer_labels.data if x['id'] == a['label']][0]
  459. label_text = label_obj['text']
  460. label_start = a['start_offset']
  461. label_end = a['end_offset']
  462. labels.append([label_start, label_end, label_text])
  463. d.pop('annotations')
  464. d['labels'] = labels
  465. d['meta'] = json.loads(d['meta'])
  466. data.append(d)
  467. return data
  468. class CSVPainter(JSONPainter):
  469. def paint(self, documents):
  470. data = super().paint(documents)
  471. res = []
  472. for d in data:
  473. annotations = d.pop('annotations')
  474. for a in annotations:
  475. res.append({**d, **a})
  476. return res
  477. def iterable_to_io(iterable, buffer_size=io.DEFAULT_BUFFER_SIZE):
  478. """See https://stackoverflow.com/a/20260030/3817588."""
  479. class IterStream(io.RawIOBase):
  480. def __init__(self):
  481. self.leftover = None
  482. def readable(self):
  483. return True
  484. def readinto(self, b):
  485. try:
  486. l = len(b) # We're supposed to return at most this much
  487. chunk = self.leftover or next(iterable)
  488. output, self.leftover = chunk[:l], chunk[l:]
  489. b[:len(output)] = output
  490. return len(output)
  491. except StopIteration:
  492. return 0 # indicate EOF
  493. return io.BufferedReader(IterStream(), buffer_size=buffer_size)
  494. class EncodedIO(io.RawIOBase):
  495. def __init__(self, fobj, buffer_size=io.DEFAULT_BUFFER_SIZE, default_encoding='utf-8'):
  496. buffer = b''
  497. detector = UniversalDetector()
  498. while True:
  499. read = fobj.read(buffer_size)
  500. detector.feed(read)
  501. buffer += read
  502. if detector.done or len(read) < buffer_size:
  503. break
  504. if detector.done:
  505. self.encoding = detector.result['encoding']
  506. else:
  507. self.encoding = default_encoding
  508. self._fobj = fobj
  509. self._buffer = buffer
  510. def readable(self):
  511. return self._fobj.readable()
  512. def readinto(self, b):
  513. l = len(b)
  514. chunk = self._buffer or self._fobj.read(l)
  515. output, self._buffer = chunk[:l], chunk[l:]
  516. b[:len(output)] = output
  517. return len(output)