You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

554 lines
17 KiB

5 years ago
5 years ago
5 years ago
  1. import base64
  2. import csv
  3. import io
  4. import itertools
  5. import json
  6. import mimetypes
  7. import re
  8. from collections import defaultdict
  9. import conllu
  10. from chardet import UniversalDetector
  11. from django.db import transaction
  12. from django.conf import settings
  13. from colour import Color
  14. import pyexcel
  15. from rest_framework.renderers import JSONRenderer
  16. from seqeval.metrics.sequence_labeling import get_entities
  17. from .exceptions import FileParseException
  18. from .models import Label
  19. from .serializers import DocumentSerializer, LabelSerializer
  20. def extract_label(tag):
  21. ptn = re.compile(r'(B|I|E|S)-(.+)')
  22. m = ptn.match(tag)
  23. if m:
  24. return m.groups()[1]
  25. else:
  26. return tag
  27. class BaseStorage(object):
  28. def __init__(self, data, project):
  29. self.data = data
  30. self.project = project
  31. @transaction.atomic
  32. def save(self, user):
  33. raise NotImplementedError()
  34. def save_doc(self, data):
  35. serializer = DocumentSerializer(data=data, many=True)
  36. serializer.is_valid(raise_exception=True)
  37. doc = serializer.save(project=self.project)
  38. return doc
  39. def save_label(self, data):
  40. serializer = LabelSerializer(data=data, many=True)
  41. serializer.is_valid(raise_exception=True)
  42. label = serializer.save(project=self.project)
  43. return label
  44. def save_annotation(self, data, user):
  45. annotation_serializer = self.project.get_annotation_serializer()
  46. serializer = annotation_serializer(data=data, many=True)
  47. serializer.is_valid(raise_exception=True)
  48. annotation = serializer.save(user=user)
  49. return annotation
  50. @classmethod
  51. def extract_label(cls, data):
  52. return [d.get('labels', []) for d in data]
  53. @classmethod
  54. def exclude_created_labels(cls, labels, created):
  55. return [label for label in labels if label not in created]
  56. @classmethod
  57. def to_serializer_format(cls, labels, created):
  58. existing_shortkeys = {(label.suffix_key, label.prefix_key)
  59. for label in created.values()}
  60. serializer_labels = []
  61. for label in sorted(labels):
  62. serializer_label = {'text': label}
  63. shortkey = cls.get_shortkey(label, existing_shortkeys)
  64. if shortkey:
  65. serializer_label['suffix_key'] = shortkey[0]
  66. serializer_label['prefix_key'] = shortkey[1]
  67. existing_shortkeys.add(shortkey)
  68. background_color = Color(pick_for=label)
  69. text_color = Color('white') if background_color.get_luminance() < 0.5 else Color('black')
  70. serializer_label['background_color'] = background_color.hex
  71. serializer_label['text_color'] = text_color.hex
  72. serializer_labels.append(serializer_label)
  73. return serializer_labels
  74. @classmethod
  75. def get_shortkey(cls, label, existing_shortkeys):
  76. model_prefix_keys = [key for (key, _) in Label.PREFIX_KEYS]
  77. prefix_keys = [None] + model_prefix_keys
  78. model_suffix_keys = {key for (key, _) in Label.SUFFIX_KEYS}
  79. suffix_keys = [key for key in label.lower() if key in model_suffix_keys]
  80. for shortkey in itertools.product(suffix_keys, prefix_keys):
  81. if shortkey not in existing_shortkeys:
  82. return shortkey
  83. return None
  84. @classmethod
  85. def update_saved_labels(cls, saved, new):
  86. for label in new:
  87. saved[label.text] = label
  88. return saved
  89. class PlainStorage(BaseStorage):
  90. @transaction.atomic
  91. def save(self, user):
  92. for text in self.data:
  93. self.save_doc(text)
  94. class ClassificationStorage(BaseStorage):
  95. """Store json for text classification.
  96. The format is as follows:
  97. {"text": "Python is awesome!", "labels": ["positive"]}
  98. ...
  99. """
  100. @transaction.atomic
  101. def save(self, user):
  102. saved_labels = {label.text: label for label in self.project.labels.all()}
  103. for data in self.data:
  104. docs = self.save_doc(data)
  105. labels = self.extract_label(data)
  106. unique_labels = self.extract_unique_labels(labels)
  107. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  108. unique_labels = self.to_serializer_format(unique_labels, saved_labels)
  109. new_labels = self.save_label(unique_labels)
  110. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  111. annotations = self.make_annotations(docs, labels, saved_labels)
  112. self.save_annotation(annotations, user)
  113. @classmethod
  114. def extract_unique_labels(cls, labels):
  115. return set(itertools.chain(*labels))
  116. @classmethod
  117. def make_annotations(cls, docs, labels, saved_labels):
  118. annotations = []
  119. for doc, label in zip(docs, labels):
  120. for name in label:
  121. label = saved_labels[name]
  122. annotations.append({'document': doc.id, 'label': label.id})
  123. return annotations
  124. class SequenceLabelingStorage(BaseStorage):
  125. """Upload jsonl for sequence labeling.
  126. The format is as follows:
  127. {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
  128. ...
  129. """
  130. @transaction.atomic
  131. def save(self, user):
  132. saved_labels = {label.text: label for label in self.project.labels.all()}
  133. for data in self.data:
  134. docs = self.save_doc(data)
  135. labels = self.extract_label(data)
  136. unique_labels = self.extract_unique_labels(labels)
  137. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  138. unique_labels = self.to_serializer_format(unique_labels, saved_labels)
  139. new_labels = self.save_label(unique_labels)
  140. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  141. annotations = self.make_annotations(docs, labels, saved_labels)
  142. self.save_annotation(annotations, user)
  143. @classmethod
  144. def extract_unique_labels(cls, labels):
  145. return set([label for _, _, label in itertools.chain(*labels)])
  146. @classmethod
  147. def make_annotations(cls, docs, labels, saved_labels):
  148. annotations = []
  149. for doc, spans in zip(docs, labels):
  150. for span in spans:
  151. start_offset, end_offset, name = span
  152. label = saved_labels[name]
  153. annotations.append({'document': doc.id,
  154. 'label': label.id,
  155. 'start_offset': start_offset,
  156. 'end_offset': end_offset})
  157. return annotations
  158. class Seq2seqStorage(BaseStorage):
  159. """Store json for seq2seq.
  160. The format is as follows:
  161. {"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
  162. ...
  163. """
  164. @transaction.atomic
  165. def save(self, user):
  166. for data in self.data:
  167. doc = self.save_doc(data)
  168. labels = self.extract_label(data)
  169. annotations = self.make_annotations(doc, labels)
  170. self.save_annotation(annotations, user)
  171. @classmethod
  172. def make_annotations(cls, docs, labels):
  173. annotations = []
  174. for doc, texts in zip(docs, labels):
  175. for text in texts:
  176. annotations.append({'document': doc.id, 'text': text})
  177. return annotations
  178. class Speech2textStorage(BaseStorage):
  179. """Store json for speech2text.
  180. The format is as follows:
  181. {"audio": "data:audio/mpeg;base64,...", "transcription": "こんにちは、世界!"}
  182. ...
  183. """
  184. @transaction.atomic
  185. def save(self, user):
  186. for data in self.data:
  187. for audio in data:
  188. audio['text'] = audio.pop('audio')
  189. doc = self.save_doc(data)
  190. annotations = self.make_annotations(doc, data)
  191. self.save_annotation(annotations, user)
  192. @classmethod
  193. def make_annotations(cls, docs, data):
  194. annotations = []
  195. for doc, datum in zip(docs, data):
  196. try:
  197. annotations.append({'document': doc.id, 'text': datum['transcription']})
  198. except KeyError:
  199. continue
  200. return annotations
  201. class FileParser(object):
  202. def parse(self, file):
  203. raise NotImplementedError()
  204. @staticmethod
  205. def encode_metadata(data):
  206. return json.dumps(data, ensure_ascii=False)
  207. class CoNLLParser(FileParser):
  208. """Uploads CoNLL format file.
  209. The file format is tab-separated values.
  210. A blank line is required at the end of a sentence.
  211. For example:
  212. ```
  213. EU B-ORG
  214. rejects O
  215. German B-MISC
  216. call O
  217. to O
  218. boycott O
  219. British B-MISC
  220. lamb O
  221. . O
  222. Peter B-PER
  223. Blackburn I-PER
  224. ...
  225. ```
  226. """
  227. def parse(self, file):
  228. data = []
  229. file = EncodedIO(file)
  230. file = io.TextIOWrapper(file, encoding=file.encoding)
  231. # Add check exception
  232. field_parsers = {
  233. "ne": lambda line, i: conllu.parser.parse_nullable_value(line[i]),
  234. }
  235. gen_parser = conllu.parse_incr(
  236. file,
  237. fields=("form", "ne"),
  238. field_parsers=field_parsers
  239. )
  240. try:
  241. for sentence in gen_parser:
  242. if not sentence:
  243. continue
  244. if len(data) >= settings.IMPORT_BATCH_SIZE:
  245. yield data
  246. data = []
  247. words, labels = [], []
  248. for item in sentence:
  249. word = item.get("form")
  250. tag = item.get("ne")
  251. if tag is not None:
  252. char_left = sum(map(len, words)) + len(words)
  253. char_right = char_left + len(word)
  254. span = [char_left, char_right, tag]
  255. labels.append(span)
  256. words.append(word)
  257. # Create and add JSONL
  258. data.append({'text': ' '.join(words), 'labels': labels})
  259. except conllu.parser.ParseException as e:
  260. raise FileParseException(line_num=-1, line=str(e))
  261. if data:
  262. yield data
  263. class PlainTextParser(FileParser):
  264. """Uploads plain text.
  265. The file format is as follows:
  266. ```
  267. EU rejects German call to boycott British lamb.
  268. President Obama is speaking at the White House.
  269. ...
  270. ```
  271. """
  272. def parse(self, file):
  273. file = EncodedIO(file)
  274. file = io.TextIOWrapper(file, encoding=file.encoding)
  275. while True:
  276. batch = list(itertools.islice(file, settings.IMPORT_BATCH_SIZE))
  277. if not batch:
  278. break
  279. yield [{'text': line.strip()} for line in batch]
  280. class CSVParser(FileParser):
  281. """Uploads csv file.
  282. The file format is comma separated values.
  283. Column names are required at the top of a file.
  284. For example:
  285. ```
  286. text, label
  287. "EU rejects German call to boycott British lamb.",Politics
  288. "President Obama is speaking at the White House.",Politics
  289. "He lives in Newark, Ohio.",Other
  290. ...
  291. ```
  292. """
  293. def parse(self, file):
  294. file = EncodedIO(file)
  295. file = io.TextIOWrapper(file, encoding=file.encoding)
  296. reader = csv.reader(file)
  297. yield from ExcelParser.parse_excel_csv_reader(reader)
  298. class ExcelParser(FileParser):
  299. def parse(self, file):
  300. excel_book = pyexcel.iget_book(file_type="xlsx", file_content=file.read())
  301. # Handle multiple sheets
  302. for sheet_name in excel_book.sheet_names():
  303. reader = excel_book[sheet_name].to_array()
  304. yield from self.parse_excel_csv_reader(reader)
  305. @staticmethod
  306. def parse_excel_csv_reader(reader):
  307. columns = next(reader)
  308. data = []
  309. if len(columns) == 1 and columns[0] != 'text':
  310. data.append({'text': columns[0]})
  311. for i, row in enumerate(reader, start=2):
  312. if len(data) >= settings.IMPORT_BATCH_SIZE:
  313. yield data
  314. data = []
  315. # Only text column
  316. if len(row) <= len(columns) and len(row) == 1:
  317. data.append({'text': row[0]})
  318. # Text, labels and metadata columns
  319. elif 2 <= len(row) <= len(columns):
  320. datum = dict(zip(columns, row))
  321. text, label = datum.pop('text'), datum.pop('label')
  322. meta = FileParser.encode_metadata(datum)
  323. if label != '':
  324. j = {'text': text, 'labels': [label], 'meta': meta}
  325. else:
  326. j = {'text': text, 'meta': meta}
  327. data.append(j)
  328. else:
  329. raise FileParseException(line_num=i, line=row)
  330. if data:
  331. yield data
  332. class JSONParser(FileParser):
  333. def parse(self, file):
  334. file = EncodedIO(file)
  335. file = io.TextIOWrapper(file, encoding=file.encoding)
  336. data = []
  337. for i, line in enumerate(file, start=1):
  338. if len(data) >= settings.IMPORT_BATCH_SIZE:
  339. yield data
  340. data = []
  341. try:
  342. j = json.loads(line)
  343. j['meta'] = FileParser.encode_metadata(j.get('meta', {}))
  344. data.append(j)
  345. except json.decoder.JSONDecodeError:
  346. raise FileParseException(line_num=i, line=line)
  347. if data:
  348. yield data
  349. class AudioParser(FileParser):
  350. def parse(self, file):
  351. file_type, _ = mimetypes.guess_type(file.name, strict=False)
  352. if not file_type:
  353. raise FileParseException(line_num=1, line='Unable to guess file type')
  354. audio = base64.b64encode(file.read())
  355. yield [{
  356. 'audio': f'data:{file_type};base64,{audio.decode("ascii")}',
  357. 'meta': json.dumps({'filename': file.name}),
  358. }]
  359. class JSONLRenderer(JSONRenderer):
  360. def render(self, data, accepted_media_type=None, renderer_context=None):
  361. """
  362. Render `data` into JSON, returning a bytestring.
  363. """
  364. if data is None:
  365. return bytes()
  366. if not isinstance(data, list):
  367. data = [data]
  368. for d in data:
  369. yield json.dumps(d,
  370. cls=self.encoder_class,
  371. ensure_ascii=self.ensure_ascii,
  372. allow_nan=not self.strict) + '\n'
  373. class JSONPainter(object):
  374. def paint(self, documents):
  375. serializer = DocumentSerializer(documents, many=True)
  376. data = []
  377. for d in serializer.data:
  378. d['meta'] = json.loads(d['meta'])
  379. for a in d['annotations']:
  380. a.pop('id')
  381. a.pop('prob')
  382. a.pop('document')
  383. data.append(d)
  384. return data
  385. @staticmethod
  386. def paint_labels(documents, labels):
  387. serializer_labels = LabelSerializer(labels, many=True)
  388. serializer = DocumentSerializer(documents, many=True)
  389. data = []
  390. for d in serializer.data:
  391. labels = []
  392. for a in d['annotations']:
  393. label_obj = [x for x in serializer_labels.data if x['id'] == a['label']][0]
  394. label_text = label_obj['text']
  395. label_start = a['start_offset']
  396. label_end = a['end_offset']
  397. labels.append([label_start, label_end, label_text])
  398. d.pop('annotations')
  399. d['labels'] = labels
  400. d['meta'] = json.loads(d['meta'])
  401. data.append(d)
  402. return data
  403. class CSVPainter(JSONPainter):
  404. def paint(self, documents):
  405. data = super().paint(documents)
  406. res = []
  407. for d in data:
  408. annotations = d.pop('annotations')
  409. for a in annotations:
  410. res.append({**d, **a})
  411. return res
  412. def iterable_to_io(iterable, buffer_size=io.DEFAULT_BUFFER_SIZE):
  413. """See https://stackoverflow.com/a/20260030/3817588."""
  414. class IterStream(io.RawIOBase):
  415. def __init__(self):
  416. self.leftover = None
  417. def readable(self):
  418. return True
  419. def readinto(self, b):
  420. try:
  421. l = len(b) # We're supposed to return at most this much
  422. chunk = self.leftover or next(iterable)
  423. output, self.leftover = chunk[:l], chunk[l:]
  424. b[:len(output)] = output
  425. return len(output)
  426. except StopIteration:
  427. return 0 # indicate EOF
  428. return io.BufferedReader(IterStream(), buffer_size=buffer_size)
  429. class EncodedIO(io.RawIOBase):
  430. def __init__(self, fobj, buffer_size=io.DEFAULT_BUFFER_SIZE, default_encoding='utf-8'):
  431. buffer = b''
  432. detector = UniversalDetector()
  433. while True:
  434. read = fobj.read(buffer_size)
  435. detector.feed(read)
  436. buffer += read
  437. if detector.done or len(read) < buffer_size:
  438. break
  439. if detector.done:
  440. self.encoding = detector.result['encoding']
  441. else:
  442. self.encoding = default_encoding
  443. self._fobj = fobj
  444. self._buffer = buffer
  445. def readable(self):
  446. return self._fobj.readable()
  447. def readinto(self, b):
  448. l = len(b)
  449. chunk = self._buffer or self._fobj.read(l)
  450. output, self._buffer = chunk[:l], chunk[l:]
  451. b[:len(output)] = output
  452. return len(output)