import csv
import io
import itertools
import json
import re
from collections import defaultdict
from random import Random

from django.db import transaction
from django.conf import settings
from rest_framework.renderers import JSONRenderer
from seqeval.metrics.sequence_labeling import get_entities

from .exceptions import FileParseException
from .models import Label
from .serializers import DocumentSerializer, LabelSerializer


def extract_label(tag):
    ptn = re.compile(r'(B|I|E|S)-(.+)')
    m = ptn.match(tag)
    if m:
        return m.groups()[1]
    else:
        return tag
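
# Illustrative behaviour of extract_label() on IOB/IOBES-style tags:
#   extract_label('B-ORG')  -> 'ORG'
#   extract_label('I-PER')  -> 'PER'
#   extract_label('O')      -> 'O'   (no prefix, returned unchanged)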


class BaseStorage(object):

    def __init__(self, data, project):
        self.data = data
        self.project = project

    @transaction.atomic
    def save(self, user):
        raise NotImplementedError()

    def save_doc(self, data):
        serializer = DocumentSerializer(data=data, many=True)
        serializer.is_valid(raise_exception=True)
        doc = serializer.save(project=self.project)
        return doc

    def save_label(self, data):
        serializer = LabelSerializer(data=data, many=True)
        serializer.is_valid(raise_exception=True)
        label = serializer.save(project=self.project)
        return label

    def save_annotation(self, data, user):
        annotation_serializer = self.project.get_annotation_serializer()
        serializer = annotation_serializer(data=data, many=True)
        serializer.is_valid(raise_exception=True)
        annotation = serializer.save(user=user)
        return annotation

    @classmethod
    def extract_label(cls, data):
        return [d.get('labels', []) for d in data]

    @classmethod
    def exclude_created_labels(cls, labels, created):
        return [label for label in labels if label not in created]

    @classmethod
    def to_serializer_format(cls, labels, created, random_seed=None):
        existing_shortkeys = {(label.suffix_key, label.prefix_key)
                              for label in created.values()}
        serializer_labels = []
        for label in sorted(labels):
            serializer_label = {'text': label}
            shortkey = cls.get_shortkey(label, existing_shortkeys)
            if shortkey:
                serializer_label['suffix_key'] = shortkey[0]
                serializer_label['prefix_key'] = shortkey[1]
                existing_shortkeys.add(shortkey)
            color = Color.random(seed=random_seed)
            serializer_label['background_color'] = color.hex
            serializer_label['text_color'] = color.contrast_color.hex
            serializer_labels.append(serializer_label)
        return serializer_labels

    @classmethod
    def get_shortkey(cls, label, existing_shortkeys):
        model_prefix_keys = [key for (key, _) in Label.PREFIX_KEYS]
        prefix_keys = [None] + model_prefix_keys
        model_suffix_keys = {key for (key, _) in Label.SUFFIX_KEYS}
        suffix_keys = [key for key in label.lower() if key in model_suffix_keys]
        for shortkey in itertools.product(suffix_keys, prefix_keys):
            if shortkey not in existing_shortkeys:
                return shortkey
        return None

    @classmethod
    def update_saved_labels(cls, saved, new):
        for label in new:
            saved[label.text] = label
        return saved
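
# Illustrative shortkey assignment, assuming Label.SUFFIX_KEYS covers the
# letters 'a'-'z' and Label.PREFIX_KEYS lists modifiers such as 'ctrl'
# (the exact choices live in .models.Label):
#   BaseStorage.get_shortkey('Person', set())         -> ('p', None)
#   BaseStorage.get_shortkey('Person', {('p', None)}) -> ('p', 'ctrl')
# to_serializer_format() attaches the shortkey plus a random background
# color and a black/white text color to each new label dict.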


class PlainStorage(BaseStorage):

    @transaction.atomic
    def save(self, user):
        for text in self.data:
            self.save_doc(text)


class ClassificationStorage(BaseStorage):
    """Store json for text classification.

    The format is as follows:
    {"text": "Python is awesome!", "labels": ["positive"]}
    ...
    """
    @transaction.atomic
    def save(self, user):
        saved_labels = {label.text: label for label in self.project.labels.all()}
        for data in self.data:
            docs = self.save_doc(data)
            labels = self.extract_label(data)
            unique_labels = self.extract_unique_labels(labels)
            unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
            unique_labels = self.to_serializer_format(unique_labels, saved_labels)
            new_labels = self.save_label(unique_labels)
            saved_labels = self.update_saved_labels(saved_labels, new_labels)
            annotations = self.make_annotations(docs, labels, saved_labels)
            self.save_annotation(annotations, user)

    @classmethod
    def extract_unique_labels(cls, labels):
        return set(itertools.chain(*labels))

    @classmethod
    def make_annotations(cls, docs, labels, saved_labels):
        annotations = []
        for doc, names in zip(docs, labels):
            for name in names:
                label = saved_labels[name]
                annotations.append({'document': doc.id, 'label': label.id})
        return annotations
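
# Illustrative make_annotations() output: for an imported line such as
#   {"text": "Python is awesome!", "labels": ["positive"]}
# each label name is resolved through saved_labels and becomes
#   {'document': <doc.id>, 'label': <Label 'positive'>.id}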


class SequenceLabelingStorage(BaseStorage):
    """Upload jsonl for sequence labeling.

    The format is as follows:
    {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
    ...
    """
    @transaction.atomic
    def save(self, user):
        saved_labels = {label.text: label for label in self.project.labels.all()}
        for data in self.data:
            docs = self.save_doc(data)
            labels = self.extract_label(data)
            unique_labels = self.extract_unique_labels(labels)
            unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
            unique_labels = self.to_serializer_format(unique_labels, saved_labels)
            new_labels = self.save_label(unique_labels)
            saved_labels = self.update_saved_labels(saved_labels, new_labels)
            annotations = self.make_annotations(docs, labels, saved_labels)
            self.save_annotation(annotations, user)

    @classmethod
    def extract_unique_labels(cls, labels):
        return set([label for _, _, label in itertools.chain(*labels)])

    @classmethod
    def make_annotations(cls, docs, labels, saved_labels):
        annotations = []
        for doc, spans in zip(docs, labels):
            for span in spans:
                start_offset, end_offset, name = span
                label = saved_labels[name]
                annotations.append({'document': doc.id,
                                    'label': label.id,
                                    'start_offset': start_offset,
                                    'end_offset': end_offset})
        return annotations
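
# Illustrative make_annotations() output: for an imported line such as
#   {"text": "Python is awesome!", "labels": [[0, 6, "Product"]]}
# the span [0, 6, "Product"] becomes
#   {'document': <doc.id>, 'label': <Label 'Product'>.id,
#    'start_offset': 0, 'end_offset': 6}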


class Seq2seqStorage(BaseStorage):
    """Store json for seq2seq.

    The format is as follows:
    {"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
    ...
    """
    @transaction.atomic
    def save(self, user):
        for data in self.data:
            doc = self.save_doc(data)
            labels = self.extract_label(data)
            annotations = self.make_annotations(doc, labels)
            self.save_annotation(annotations, user)

    @classmethod
    def make_annotations(cls, docs, labels):
        annotations = []
        for doc, texts in zip(docs, labels):
            for text in texts:
                annotations.append({'document': doc.id, 'text': text})
        return annotations
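
# Illustrative make_annotations() output: every target string in "labels"
# becomes its own annotation, e.g.
#   {'document': <doc.id>, 'text': 'こんにちは、世界!'}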


class FileParser(object):

    def parse(self, file):
        raise NotImplementedError()


class CoNLLParser(FileParser):
    """Uploads CoNLL format file.

    The file format is tab-separated values.
    A blank line is required at the end of a sentence.
    For example:
    ```
    EU	B-ORG
    rejects	O
    German	B-MISC
    call	O
    to	O
    boycott	O
    British	B-MISC
    lamb	O
    .	O

    Peter	B-PER
    Blackburn	I-PER
    ...
    ```
    """
    def parse(self, file):
        words, tags = [], []
        data = []
        file = io.TextIOWrapper(file, encoding='utf-8')
        for i, line in enumerate(file, start=1):
            if len(data) >= settings.IMPORT_BATCH_SIZE:
                yield data
                data = []
            line = line.strip()
            if line:
                try:
                    word, tag = line.split('\t')
                except ValueError:
                    raise FileParseException(line_num=i, line=line)
                words.append(word)
                tags.append(tag)
            elif words and tags:
                j = self.calc_char_offset(words, tags)
                data.append(j)
                words, tags = [], []
        if len(words) > 0:
            j = self.calc_char_offset(words, tags)
            data.append(j)
        if data:
            yield data

    @classmethod
    def calc_char_offset(cls, words, tags):
        doc = ' '.join(words)
        j = {'text': ' '.join(words), 'labels': []}
        pos = defaultdict(int)
        for label, start_offset, end_offset in get_entities(tags):
            entity = ' '.join(words[start_offset: end_offset + 1])
            char_left = doc.index(entity, pos[entity])
            char_right = char_left + len(entity)
            span = [char_left, char_right, label]
            j['labels'].append(span)
            pos[entity] = char_right
        return j
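
# Illustrative calc_char_offset() result for the first sentence above:
#   words = ['EU', 'rejects', 'German', 'call', 'to', 'boycott',
#            'British', 'lamb', '.']
#   tags  = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
# yields
#   {'text': 'EU rejects German call to boycott British lamb .',
#    'labels': [[0, 2, 'ORG'], [11, 17, 'MISC'], [34, 41, 'MISC']]}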


class PlainTextParser(FileParser):
    """Uploads plain text.

    The file format is as follows:
    ```
    EU rejects German call to boycott British lamb.
    President Obama is speaking at the White House.
    ...
    ```
    """
    def parse(self, file):
        file = io.TextIOWrapper(file, encoding='utf-8')
        while True:
            batch = list(itertools.islice(file, settings.IMPORT_BATCH_SIZE))
            if not batch:
                break
            yield [{'text': line.strip()} for line in batch]


class CSVParser(FileParser):
    """Uploads csv file.

    The file format is comma separated values.
    Column names are required at the top of a file.
    For example:
    ```
    text,label
    "EU rejects German call to boycott British lamb.",Politics
    "President Obama is speaking at the White House.",Politics
    "He lives in Newark, Ohio.",Other
    ...
    ```
    """
    def parse(self, file):
        file = io.TextIOWrapper(file, encoding='utf-8')
        reader = csv.reader(file)
        columns = next(reader)
        data = []
        for i, row in enumerate(reader, start=2):
            if len(data) >= settings.IMPORT_BATCH_SIZE:
                yield data
                data = []
            if len(row) == len(columns) and len(row) >= 2:
                text, label = row[:2]
                meta = json.dumps(dict(zip(columns[2:], row[2:])))
                j = {'text': text, 'labels': [label], 'meta': meta}
                data.append(j)
            else:
                raise FileParseException(line_num=i, line=row)
        if data:
            yield data
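
# Illustrative parse() output for the rows above: each CSV row becomes
#   {'text': 'EU rejects German call to boycott British lamb.',
#    'labels': ['Politics'], 'meta': '{}'}
# Columns beyond the first two are folded into the JSON-encoded 'meta' string.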


class JSONParser(FileParser):

    def parse(self, file):
        file = io.TextIOWrapper(file, encoding='utf-8')
        data = []
        for i, line in enumerate(file, start=1):
            if len(data) >= settings.IMPORT_BATCH_SIZE:
                yield data
                data = []
            try:
                j = json.loads(line)
                j['meta'] = json.dumps(j.get('meta', {}))
                data.append(j)
            except json.decoder.JSONDecodeError:
                raise FileParseException(line_num=i, line=line)
        if data:
            yield data
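
# The parser expects JSON Lines input, one object per line; a hypothetical
# line such as
#   {"text": "EU rejects German call.", "labels": ["Politics"], "meta": {"source": "news"}}
# is kept as-is except that the nested "meta" object is re-serialized to a
# JSON string before it reaches the document serializer.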


class JSONLRenderer(JSONRenderer):

    def render(self, data, accepted_media_type=None, renderer_context=None):
        """Render `data` into JSON Lines, yielding one JSON string per record."""
        if data is None:
            return bytes()  # empty payload: the generator simply stops
        if not isinstance(data, list):
            data = [data]
        for d in data:
            yield json.dumps(d,
                             cls=self.encoder_class,
                             ensure_ascii=self.ensure_ascii,
                             allow_nan=not self.strict) + '\n'


class JSONPainter(object):

    def paint(self, documents):
        serializer = DocumentSerializer(documents, many=True)
        data = []
        for d in serializer.data:
            d['meta'] = json.loads(d['meta'])
            for a in d['annotations']:
                a.pop('id')
                a.pop('prob')
                a.pop('document')
            data.append(d)
        return data

    @staticmethod
    def paint_labels(documents, labels):
        serializer_labels = LabelSerializer(labels, many=True)
        serializer = DocumentSerializer(documents, many=True)
        data = []
        for d in serializer.data:
            labels = []
            for a in d['annotations']:
                label_obj = [x for x in serializer_labels.data if x['id'] == a['label']][0]
                label_text = label_obj['text']
                label_start = a['start_offset']
                label_end = a['end_offset']
                labels.append([label_start, label_end, label_text])
            d.pop('annotations')
            d['labels'] = labels
            d['meta'] = json.loads(d['meta'])
            data.append(d)
        return data
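
# Illustrative paint_labels() output (the exact field set depends on
# DocumentSerializer): each exported document ends up shaped like
#   {'text': 'EU rejects German call to boycott British lamb.',
#    'labels': [[0, 2, 'ORG']], 'meta': {}, ...}
# i.e. annotation rows are replaced by [start_offset, end_offset, label_text]
# triples, mirroring the sequence labeling upload format.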


class CSVPainter(JSONPainter):

    def paint(self, documents):
        data = super().paint(documents)
        res = []
        for d in data:
            annotations = d.pop('annotations')
            for a in annotations:
                res.append({**d, **a})
        return res


class Color:

    def __init__(self, red, green, blue):
        self.red = red
        self.green = green
        self.blue = blue

    @property
    def contrast_color(self):
        """Generate black or white color.

        Ensure that text and background color combinations provide
        sufficient contrast when viewed by someone having color deficits or
        when viewed on a black and white screen.

        Algorithm from w3c:
        * https://www.w3.org/TR/AERT/#color-contrast
        """
        return Color.white() if self.brightness < 128 else Color.black()

    @property
    def brightness(self):
        return ((self.red * 299) + (self.green * 587) + (self.blue * 114)) / 1000

    @property
    def hex(self):
        return '#{:02x}{:02x}{:02x}'.format(self.red, self.green, self.blue)

    @classmethod
    def white(cls):
        return cls(red=255, green=255, blue=255)

    @classmethod
    def black(cls):
        return cls(red=0, green=0, blue=0)

    @classmethod
    def random(cls, seed=None):
        rgb = Random(seed).choices(range(256), k=3)
        return cls(*rgb)
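
# Worked examples of the W3C brightness rule used by contrast_color:
#   Color(255, 255, 255).brightness == 255.0   -> contrast_color is black
#   Color(0, 0, 128).brightness == 14.592      -> contrast_color is white
#   Color(200, 200, 0).hex == '#c8c800'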


def iterable_to_io(iterable, buffer_size=io.DEFAULT_BUFFER_SIZE):
    """See https://stackoverflow.com/a/20260030/3817588."""
    class IterStream(io.RawIOBase):
        def __init__(self):
            self.leftover = None

        def readable(self):
            return True

        def readinto(self, b):
            try:
                l = len(b)  # We're supposed to return at most this much
                chunk = self.leftover or next(iterable)
                output, self.leftover = chunk[:l], chunk[l:]
                b[:len(output)] = output
                return len(output)
            except StopIteration:
                return 0  # indicate EOF

    return io.BufferedReader(IterStream(), buffer_size=buffer_size)
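
# Illustrative usage, assuming the iterable yields bytes-like chunks
# (e.g. an uploaded file streamed in pieces):
#   stream = iterable_to_io(iter([b'foo', b'bar']))
#   stream.read()  # -> b'foobar'
# The resulting buffered reader can be handed to any FileParser subclass
# above, each of which wraps it in an io.TextIOWrapper itself.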