You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

508 lines
16 KiB

5 years ago
5 years ago
5 years ago
  1. import csv
  2. import io
  3. import itertools
  4. import json
  5. import re
  6. from collections import defaultdict
  7. from random import Random
  8. import conllu
  9. from django.db import transaction
  10. from django.conf import settings
  11. import pyexcel
  12. from rest_framework.renderers import JSONRenderer
  13. from seqeval.metrics.sequence_labeling import get_entities
  14. from .exceptions import FileParseException
  15. from .models import Label
  16. from .serializers import DocumentSerializer, LabelSerializer
  17. def extract_label(tag):
  18. ptn = re.compile(r'(B|I|E|S)-(.+)')
  19. m = ptn.match(tag)
  20. if m:
  21. return m.groups()[1]
  22. else:
  23. return tag
  24. class BaseStorage(object):
  25. def __init__(self, data, project):
  26. self.data = data
  27. self.project = project
  28. @transaction.atomic
  29. def save(self, user):
  30. raise NotImplementedError()
  31. def save_doc(self, data):
  32. serializer = DocumentSerializer(data=data, many=True)
  33. serializer.is_valid(raise_exception=True)
  34. doc = serializer.save(project=self.project)
  35. return doc
  36. def save_label(self, data):
  37. serializer = LabelSerializer(data=data, many=True)
  38. serializer.is_valid(raise_exception=True)
  39. label = serializer.save(project=self.project)
  40. return label
  41. def save_annotation(self, data, user):
  42. annotation_serializer = self.project.get_annotation_serializer()
  43. serializer = annotation_serializer(data=data, many=True)
  44. serializer.is_valid(raise_exception=True)
  45. annotation = serializer.save(user=user)
  46. return annotation
  47. @classmethod
  48. def extract_label(cls, data):
  49. return [d.get('labels', []) for d in data]
  50. @classmethod
  51. def exclude_created_labels(cls, labels, created):
  52. return [label for label in labels if label not in created]
  53. @classmethod
  54. def to_serializer_format(cls, labels, created, random_seed=None):
  55. existing_shortkeys = {(label.suffix_key, label.prefix_key)
  56. for label in created.values()}
  57. serializer_labels = []
  58. for label in sorted(labels):
  59. serializer_label = {'text': label}
  60. shortkey = cls.get_shortkey(label, existing_shortkeys)
  61. if shortkey:
  62. serializer_label['suffix_key'] = shortkey[0]
  63. serializer_label['prefix_key'] = shortkey[1]
  64. existing_shortkeys.add(shortkey)
  65. color = Color.random(seed=random_seed)
  66. serializer_label['background_color'] = color.hex
  67. serializer_label['text_color'] = color.contrast_color.hex
  68. serializer_labels.append(serializer_label)
  69. return serializer_labels
  70. @classmethod
  71. def get_shortkey(cls, label, existing_shortkeys):
  72. model_prefix_keys = [key for (key, _) in Label.PREFIX_KEYS]
  73. prefix_keys = [None] + model_prefix_keys
  74. model_suffix_keys = {key for (key, _) in Label.SUFFIX_KEYS}
  75. suffix_keys = [key for key in label.lower() if key in model_suffix_keys]
  76. for shortkey in itertools.product(suffix_keys, prefix_keys):
  77. if shortkey not in existing_shortkeys:
  78. return shortkey
  79. return None
  80. @classmethod
  81. def update_saved_labels(cls, saved, new):
  82. for label in new:
  83. saved[label.text] = label
  84. return saved
  85. class PlainStorage(BaseStorage):
  86. @transaction.atomic
  87. def save(self, user):
  88. for text in self.data:
  89. self.save_doc(text)
  90. class ClassificationStorage(BaseStorage):
  91. """Store json for text classification.
  92. The format is as follows:
  93. {"text": "Python is awesome!", "labels": ["positive"]}
  94. ...
  95. """
  96. @transaction.atomic
  97. def save(self, user):
  98. saved_labels = {label.text: label for label in self.project.labels.all()}
  99. for data in self.data:
  100. docs = self.save_doc(data)
  101. labels = self.extract_label(data)
  102. unique_labels = self.extract_unique_labels(labels)
  103. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  104. unique_labels = self.to_serializer_format(unique_labels, saved_labels)
  105. new_labels = self.save_label(unique_labels)
  106. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  107. annotations = self.make_annotations(docs, labels, saved_labels)
  108. self.save_annotation(annotations, user)
  109. @classmethod
  110. def extract_unique_labels(cls, labels):
  111. return set(itertools.chain(*labels))
  112. @classmethod
  113. def make_annotations(cls, docs, labels, saved_labels):
  114. annotations = []
  115. for doc, label in zip(docs, labels):
  116. for name in label:
  117. label = saved_labels[name]
  118. annotations.append({'document': doc.id, 'label': label.id})
  119. return annotations
  120. class SequenceLabelingStorage(BaseStorage):
  121. """Upload jsonl for sequence labeling.
  122. The format is as follows:
  123. {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
  124. ...
  125. """
  126. @transaction.atomic
  127. def save(self, user):
  128. saved_labels = {label.text: label for label in self.project.labels.all()}
  129. for data in self.data:
  130. docs = self.save_doc(data)
  131. labels = self.extract_label(data)
  132. unique_labels = self.extract_unique_labels(labels)
  133. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  134. unique_labels = self.to_serializer_format(unique_labels, saved_labels)
  135. new_labels = self.save_label(unique_labels)
  136. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  137. annotations = self.make_annotations(docs, labels, saved_labels)
  138. self.save_annotation(annotations, user)
  139. @classmethod
  140. def extract_unique_labels(cls, labels):
  141. return set([label for _, _, label in itertools.chain(*labels)])
  142. @classmethod
  143. def make_annotations(cls, docs, labels, saved_labels):
  144. annotations = []
  145. for doc, spans in zip(docs, labels):
  146. for span in spans:
  147. start_offset, end_offset, name = span
  148. label = saved_labels[name]
  149. annotations.append({'document': doc.id,
  150. 'label': label.id,
  151. 'start_offset': start_offset,
  152. 'end_offset': end_offset})
  153. return annotations
  154. class Seq2seqStorage(BaseStorage):
  155. """Store json for seq2seq.
  156. The format is as follows:
  157. {"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
  158. ...
  159. """
  160. @transaction.atomic
  161. def save(self, user):
  162. for data in self.data:
  163. doc = self.save_doc(data)
  164. labels = self.extract_label(data)
  165. annotations = self.make_annotations(doc, labels)
  166. self.save_annotation(annotations, user)
  167. @classmethod
  168. def make_annotations(cls, docs, labels):
  169. annotations = []
  170. for doc, texts in zip(docs, labels):
  171. for text in texts:
  172. annotations.append({'document': doc.id, 'text': text})
  173. return annotations
  174. class FileParser(object):
  175. def parse(self, file):
  176. raise NotImplementedError()
  177. class CoNLLParser(FileParser):
  178. """Uploads CoNLL format file.
  179. The file format is tab-separated values.
  180. A blank line is required at the end of a sentence.
  181. For example:
  182. ```
  183. EU B-ORG
  184. rejects O
  185. German B-MISC
  186. call O
  187. to O
  188. boycott O
  189. British B-MISC
  190. lamb O
  191. . O
  192. Peter B-PER
  193. Blackburn I-PER
  194. ...
  195. ```
  196. """
  197. def parse(self, file):
  198. data = []
  199. file = io.TextIOWrapper(file, encoding='utf-8')
  200. # Add check exception
  201. field_parsers = {
  202. "ne": lambda line, i: conllu.parser.parse_nullable_value(line[i]),
  203. }
  204. gen_parser = conllu.parse_incr(
  205. file,
  206. fields=("form", "ne"),
  207. field_parsers=field_parsers
  208. )
  209. try:
  210. for sentence in gen_parser:
  211. if not sentence:
  212. continue
  213. if len(data) >= settings.IMPORT_BATCH_SIZE:
  214. yield data
  215. data = []
  216. words, labels = [], []
  217. for item in sentence:
  218. word = item.get("form")
  219. tag = item.get("ne")
  220. if tag is not None:
  221. char_left = sum(map(len, words)) + len(words)
  222. char_right = char_left + len(word)
  223. span = [char_left, char_right, tag]
  224. labels.append(span)
  225. words.append(word)
  226. # Create and add JSONL
  227. data.append({'text': ' '.join(words), 'labels': labels})
  228. except conllu.parser.ParseException as e:
  229. raise FileParseException(line_num=-1, line=str(e))
  230. if data:
  231. yield data
  232. class PlainTextParser(FileParser):
  233. """Uploads plain text.
  234. The file format is as follows:
  235. ```
  236. EU rejects German call to boycott British lamb.
  237. President Obama is speaking at the White House.
  238. ...
  239. ```
  240. """
  241. def parse(self, file):
  242. file = io.TextIOWrapper(file, encoding='utf-8')
  243. while True:
  244. batch = list(itertools.islice(file, settings.IMPORT_BATCH_SIZE))
  245. if not batch:
  246. break
  247. yield [{'text': line.strip()} for line in batch]
  248. class CSVParser(FileParser):
  249. """Uploads csv file.
  250. The file format is comma separated values.
  251. Column names are required at the top of a file.
  252. For example:
  253. ```
  254. text, label
  255. "EU rejects German call to boycott British lamb.",Politics
  256. "President Obama is speaking at the White House.",Politics
  257. "He lives in Newark, Ohio.",Other
  258. ...
  259. ```
  260. """
  261. def parse(self, file):
  262. file = io.TextIOWrapper(file, encoding='utf-8')
  263. reader = csv.reader(file)
  264. yield from ExcelParser.parse_excel_csv_reader(reader)
  265. class ExcelParser(FileParser):
  266. def parse(self, file):
  267. excel_book = pyexcel.iget_book(file_type="xlsx", file_content=file.read())
  268. # Handle multiple sheets
  269. for sheet_name in excel_book.sheet_names():
  270. reader = excel_book[sheet_name].to_array()
  271. yield from self.parse_excel_csv_reader(reader)
  272. @staticmethod
  273. def parse_excel_csv_reader(reader):
  274. columns = next(reader)
  275. data = []
  276. if len(columns) == 1 and columns[0] != 'text':
  277. data.append({'text': columns[0]})
  278. for i, row in enumerate(reader, start=2):
  279. if len(data) >= settings.IMPORT_BATCH_SIZE:
  280. yield data
  281. data = []
  282. # Only text column
  283. if len(row) == len(columns) and len(row) == 1:
  284. data.append({'text': row[0]})
  285. # Text, labels and metadata columns
  286. elif len(row) == len(columns) and len(row) >= 2:
  287. text, label = row[:2]
  288. meta = json.dumps(dict(zip(columns[2:], row[2:])))
  289. j = {'text': text, 'labels': [label], 'meta': meta}
  290. data.append(j)
  291. else:
  292. raise FileParseException(line_num=i, line=row)
  293. if data:
  294. yield data
  295. class JSONParser(FileParser):
  296. def parse(self, file):
  297. file = io.TextIOWrapper(file, encoding='utf-8')
  298. data = []
  299. for i, line in enumerate(file, start=1):
  300. if len(data) >= settings.IMPORT_BATCH_SIZE:
  301. yield data
  302. data = []
  303. try:
  304. j = json.loads(line)
  305. j['meta'] = json.dumps(j.get('meta', {}))
  306. data.append(j)
  307. except json.decoder.JSONDecodeError:
  308. raise FileParseException(line_num=i, line=line)
  309. if data:
  310. yield data
  311. class JSONLRenderer(JSONRenderer):
  312. def render(self, data, accepted_media_type=None, renderer_context=None):
  313. """
  314. Render `data` into JSON, returning a bytestring.
  315. """
  316. if data is None:
  317. return bytes()
  318. if not isinstance(data, list):
  319. data = [data]
  320. for d in data:
  321. yield json.dumps(d,
  322. cls=self.encoder_class,
  323. ensure_ascii=self.ensure_ascii,
  324. allow_nan=not self.strict) + '\n'
  325. class JSONPainter(object):
  326. def paint(self, documents):
  327. serializer = DocumentSerializer(documents, many=True)
  328. data = []
  329. for d in serializer.data:
  330. d['meta'] = json.loads(d['meta'])
  331. for a in d['annotations']:
  332. a.pop('id')
  333. a.pop('prob')
  334. a.pop('document')
  335. data.append(d)
  336. return data
  337. @staticmethod
  338. def paint_labels(documents, labels):
  339. serializer_labels = LabelSerializer(labels, many=True)
  340. serializer = DocumentSerializer(documents, many=True)
  341. data = []
  342. for d in serializer.data:
  343. labels = []
  344. for a in d['annotations']:
  345. label_obj = [x for x in serializer_labels.data if x['id'] == a['label']][0]
  346. label_text = label_obj['text']
  347. label_start = a['start_offset']
  348. label_end = a['end_offset']
  349. labels.append([label_start, label_end, label_text])
  350. d.pop('annotations')
  351. d['labels'] = labels
  352. d['meta'] = json.loads(d['meta'])
  353. data.append(d)
  354. return data
  355. class CSVPainter(JSONPainter):
  356. def paint(self, documents):
  357. data = super().paint(documents)
  358. res = []
  359. for d in data:
  360. annotations = d.pop('annotations')
  361. for a in annotations:
  362. res.append({**d, **a})
  363. return res
  364. class Color:
  365. def __init__(self, red, green, blue):
  366. self.red = red
  367. self.green = green
  368. self.blue = blue
  369. @property
  370. def contrast_color(self):
  371. """Generate black or white color.
  372. Ensure that text and background color combinations provide
  373. sufficient contrast when viewed by someone having color deficits or
  374. when viewed on a black and white screen.
  375. Algorithm from w3c:
  376. * https://www.w3.org/TR/AERT/#color-contrast
  377. """
  378. return Color.white() if self.brightness < 128 else Color.black()
  379. @property
  380. def brightness(self):
  381. return ((self.red * 299) + (self.green * 587) + (self.blue * 114)) / 1000
  382. @property
  383. def hex(self):
  384. return '#{:02x}{:02x}{:02x}'.format(self.red, self.green, self.blue)
  385. @classmethod
  386. def white(cls):
  387. return cls(red=255, green=255, blue=255)
  388. @classmethod
  389. def black(cls):
  390. return cls(red=0, green=0, blue=0)
  391. @classmethod
  392. def random(cls, seed=None):
  393. rgb = Random(seed).choices(range(256), k=3)
  394. return cls(*rgb)
  395. def iterable_to_io(iterable, buffer_size=io.DEFAULT_BUFFER_SIZE):
  396. """See https://stackoverflow.com/a/20260030/3817588."""
  397. class IterStream(io.RawIOBase):
  398. def __init__(self):
  399. self.leftover = None
  400. def readable(self):
  401. return True
  402. def readinto(self, b):
  403. try:
  404. l = len(b) # We're supposed to return at most this much
  405. chunk = self.leftover or next(iterable)
  406. output, self.leftover = chunk[:l], chunk[l:]
  407. b[:len(output)] = output
  408. return len(output)
  409. except StopIteration:
  410. return 0 # indicate EOF
  411. return io.BufferedReader(IterStream(), buffer_size=buffer_size)