You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

446 lines
14 KiB

  1. import csv
  2. import io
  3. import itertools
  4. import json
  5. import re
  6. from collections import defaultdict
  7. from django.db import transaction
  8. from rest_framework.renderers import JSONRenderer
  9. from seqeval.metrics.sequence_labeling import get_entities
  10. from app.settings import IMPORT_BATCH_SIZE
  11. from .exceptions import FileParseException
  12. from .models import Label
  13. from .serializers import DocumentSerializer, LabelSerializer
  14. def extract_label(tag):
  15. ptn = re.compile(r'(B|I|E|S)-(.+)')
  16. m = ptn.match(tag)
  17. if m:
  18. return m.groups()[1]
  19. else:
  20. return tag
  21. class BaseStorage(object):
  22. def __init__(self, data, project):
  23. self.data = data
  24. self.project = project
  25. @transaction.atomic
  26. def save(self, user):
  27. raise NotImplementedError()
  28. def save_doc(self, data):
  29. serializer = DocumentSerializer(data=data, many=True)
  30. serializer.is_valid(raise_exception=True)
  31. doc = serializer.save(project=self.project)
  32. return doc
  33. def save_label(self, data):
  34. serializer = LabelSerializer(data=data, many=True)
  35. serializer.is_valid(raise_exception=True)
  36. label = serializer.save(project=self.project)
  37. return label
  38. def save_annotation(self, data, user):
  39. annotation_serializer = self.project.get_annotation_serializer()
  40. serializer = annotation_serializer(data=data, many=True)
  41. serializer.is_valid(raise_exception=True)
  42. annotation = serializer.save(user=user)
  43. return annotation
  44. def extract_label(self, data):
  45. """Extract labels from parsed data.
  46. Example:
  47. >>> data = [{"labels": ["positive"]}, {"labels": ["negative"]}]
  48. >>> self.extract_label(data)
  49. [["positive"], ["negative"]]
  50. """
  51. return [d.get('labels', []) for d in data]
  52. def exclude_created_labels(self, labels, created):
  53. """Exclude created labels.
  54. Example:
  55. >>> labels = ["positive", "negative"]
  56. >>> created = {"positive": ...}
  57. >>> self.exclude_created_labels(labels, created)
  58. ["negative"]
  59. """
  60. return [label for label in labels if label not in created]
  61. def to_serializer_format(self, labels):
  62. """Exclude created labels.
  63. Example:
  64. >>> labels = ["positive"]
  65. >>> self.to_serializer_format(labels)
  66. [{"text": "negative"}]
  67. ```
  68. """
  69. return [{'text': label} for label in labels]
  70. def update_saved_labels(self, saved, new):
  71. """Update saved labels.
  72. Example:
  73. >>> saved = {'positive': ...}
  74. >>> new = [<Label: positive>]
  75. """
  76. for label in new:
  77. saved[label.text] = label
  78. return saved
  79. class PlainStorage(BaseStorage):
  80. @transaction.atomic
  81. def save(self, user):
  82. for text in self.data:
  83. self.save_doc(text)
  84. class ClassificationStorage(BaseStorage):
  85. """Store json for text classification.
  86. The format is as follows:
  87. {"text": "Python is awesome!", "labels": ["positive"]}
  88. ...
  89. """
  90. @transaction.atomic
  91. def save(self, user):
  92. saved_labels = {label.text: label for label in self.project.labels.all()}
  93. for data in self.data:
  94. docs = self.save_doc(data)
  95. labels = self.extract_label(data)
  96. unique_labels = self.extract_unique_labels(labels)
  97. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  98. unique_labels = self.to_serializer_format(unique_labels)
  99. new_labels = self.save_label(unique_labels)
  100. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  101. annotations = self.make_annotations(docs, labels, saved_labels)
  102. self.save_annotation(annotations, user)
  103. def extract_unique_labels(self, labels):
  104. """Extract unique labels
  105. Example:
  106. >>> labels = [["positive"], ["positive", "negative"], ["negative"]]
  107. >>> self.extract_unique_labels(labels)
  108. ["positive", "negative"]
  109. """
  110. return set(itertools.chain(*labels))
  111. def make_annotations(self, docs, labels, saved_labels):
  112. """Make list of annotation obj for serializer.
  113. Example:
  114. >>> docs = ["<Document: a>", "<Document: b>", "<Document: c>"]
  115. >>> labels = [["positive"], ["positive", "negative"], ["negative"]]
  116. >>> saved_labels = {"positive": "<Label: positive>", 'negative': "<Label: negative>"}
  117. >>> self.make_annotations(docs, labels, saved_labels)
  118. [{"document": 1, "label": 1}, {"document": 2, "label": 1}
  119. {"document": 2, "label": 2}, {"document": 3, "label": 2}]
  120. """
  121. annotations = []
  122. for doc, label in zip(docs, labels):
  123. for name in label:
  124. label = saved_labels[name]
  125. annotations.append({'document': doc.id, 'label': label.id})
  126. return annotations
  127. class SequenceLabelingStorage(BaseStorage):
  128. """Upload jsonl for sequence labeling.
  129. The format is as follows:
  130. {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
  131. ...
  132. """
  133. @transaction.atomic
  134. def save(self, user):
  135. saved_labels = {label.text: label for label in self.project.labels.all()}
  136. for data in self.data:
  137. docs = self.save_doc(data)
  138. labels = self.extract_label(data)
  139. unique_labels = self.extract_unique_labels(labels)
  140. unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
  141. unique_labels = self.to_serializer_format(unique_labels)
  142. new_labels = self.save_label(unique_labels)
  143. saved_labels = self.update_saved_labels(saved_labels, new_labels)
  144. annotations = self.make_annotations(docs, labels, saved_labels)
  145. self.save_annotation(annotations, user)
  146. def extract_unique_labels(self, labels):
  147. """Extract unique labels
  148. Example:
  149. >>> labels = [[[0, 1, "LOC"]], [[3, 4, "ORG"]]]
  150. >>> self.extract_unique_labels(labels)
  151. ["LOC", "ORG"]
  152. """
  153. return set([label for _, _, label in itertools.chain(*labels)])
  154. def make_annotations(self, docs, labels, saved_labels):
  155. """Make list of annotation obj for serializer.
  156. Example:
  157. >>> docs = ["<Document: a>", "<Document: b>"]
  158. >>> labels = labels = [[[0, 1, "LOC"]], [[3, 4, "ORG"]]]
  159. >>> saved_labels = {"LOC": "<Label: LOC>", 'ORG': "<Label: ORG>"}
  160. >>> self.make_annotations(docs, labels, saved_labels)
  161. [
  162. {"document": 1, "label": 1, "start_offset": 0, "end_offset": 1}
  163. {"document": 2, "label": 2, "start_offset": 3, "end_offset": 4}
  164. ]
  165. """
  166. annotations = []
  167. for doc, spans in zip(docs, labels):
  168. for span in spans:
  169. start_offset, end_offset, name = span
  170. label = saved_labels[name]
  171. annotations.append({'document': doc.id,
  172. 'label': label.id,
  173. 'start_offset': start_offset,
  174. 'end_offset': end_offset})
  175. return annotations
  176. class Seq2seqStorage(BaseStorage):
  177. """Store json for seq2seq.
  178. The format is as follows:
  179. {"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
  180. ...
  181. """
  182. @transaction.atomic
  183. def save(self, user):
  184. for data in self.data:
  185. doc = self.save_doc(data)
  186. labels = self.extract_label(data)
  187. annotations = self.make_annotations(doc, labels)
  188. self.save_annotation(annotations, user)
  189. def make_annotations(self, docs, labels):
  190. """Make list of annotation obj for serializer.
  191. Example:
  192. >>> docs = ["<Document: a>", "<Document: b>"]
  193. >>> labels = [["Hello!"], ["How are you?", "What's up?"]]
  194. >>> self.make_annotations(docs, labels)
  195. [{"document": 1, "text": "Hello"}, {"document": 2, "text": "How are you?"}
  196. {"document": 2, "text": "What's up?"}]
  197. """
  198. annotations = []
  199. for doc, texts in zip(docs, labels):
  200. for text in texts:
  201. annotations.append({'document': doc.id, 'text': text})
  202. return annotations
  203. class FileParser(object):
  204. def parse(self, file):
  205. raise NotImplementedError()
  206. class CoNLLParser(FileParser):
  207. """Uploads CoNLL format file.
  208. The file format is tab-separated values.
  209. A blank line is required at the end of a sentence.
  210. For example:
  211. ```
  212. EU B-ORG
  213. rejects O
  214. German B-MISC
  215. call O
  216. to O
  217. boycott O
  218. British B-MISC
  219. lamb O
  220. . O
  221. Peter B-PER
  222. Blackburn I-PER
  223. ...
  224. ```
  225. """
  226. def parse(self, file):
  227. """Store json for seq2seq.
  228. Return format:
  229. {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
  230. ...
  231. """
  232. words, tags = [], []
  233. data = []
  234. for i, line in enumerate(file, start=1):
  235. if len(data) >= IMPORT_BATCH_SIZE:
  236. yield data
  237. data = []
  238. line = line.decode('utf-8')
  239. line = line.strip()
  240. if line:
  241. try:
  242. word, tag = line.split('\t')
  243. except ValueError:
  244. raise FileParseException(line_num=i, line=line)
  245. words.append(word)
  246. tags.append(tag)
  247. else:
  248. j = self.calc_char_offset(words, tags)
  249. data.append(j)
  250. words, tags = [], []
  251. if len(words) > 0:
  252. j = self.calc_char_offset(words, tags)
  253. data.append(j)
  254. yield data
  255. def calc_char_offset(self, words, tags):
  256. """
  257. Examples:
  258. >>> words = ['EU', 'rejects', 'German', 'call']
  259. >>> tags = ['B-ORG', 'O', 'B-MISC', 'O']
  260. >>> entities = get_entities(tags)
  261. >>> entities
  262. [['ORG', 0, 0], ['MISC', 2, 2]]
  263. >>> self.calc_char_offset(words, tags)
  264. {
  265. 'text': 'EU rejects German call',
  266. 'labels': [[0, 2, 'ORG'], [11, 17, 'MISC']]
  267. }
  268. """
  269. doc = ' '.join(words)
  270. j = {'text': ' '.join(words), 'labels': []}
  271. pos = defaultdict(int)
  272. for label, start_offset, end_offset in get_entities(tags):
  273. entity = ' '.join(words[start_offset: end_offset + 1])
  274. char_left = doc.index(entity, pos[entity])
  275. char_right = char_left + len(entity)
  276. span = [char_left, char_right, label]
  277. j['labels'].append(span)
  278. pos[entity] = char_right
  279. return j
  280. class PlainTextParser(FileParser):
  281. """Uploads plain text.
  282. The file format is as follows:
  283. ```
  284. EU rejects German call to boycott British lamb.
  285. President Obama is speaking at the White House.
  286. ...
  287. ```
  288. """
  289. def parse(self, file):
  290. file = io.TextIOWrapper(file, encoding='utf-8')
  291. while True:
  292. batch = list(itertools.islice(file, IMPORT_BATCH_SIZE))
  293. if not batch:
  294. raise StopIteration
  295. yield [{'text': line.strip()} for line in batch]
  296. class CSVParser(FileParser):
  297. """Uploads csv file.
  298. The file format is comma separated values.
  299. Column names are required at the top of a file.
  300. For example:
  301. ```
  302. text, label
  303. "EU rejects German call to boycott British lamb.",Politics
  304. "President Obama is speaking at the White House.",Politics
  305. "He lives in Newark, Ohio.",Other
  306. ...
  307. ```
  308. """
  309. def parse(self, file):
  310. file = io.TextIOWrapper(file, encoding='utf-8')
  311. reader = csv.reader(file)
  312. columns = next(reader)
  313. data = []
  314. for i, row in enumerate(reader, start=2):
  315. if len(data) >= IMPORT_BATCH_SIZE:
  316. yield data
  317. data = []
  318. if len(row) == len(columns) and len(row) >= 2:
  319. text, label = row[:2]
  320. meta = json.dumps(dict(zip(columns[2:], row[2:])))
  321. j = {'text': text, 'labels': [label], 'meta': meta}
  322. data.append(j)
  323. else:
  324. raise FileParseException(line_num=i, line=row)
  325. if data:
  326. yield data
  327. class JSONParser(FileParser):
  328. def parse(self, file):
  329. data = []
  330. for i, line in enumerate(file, start=1):
  331. if len(data) >= IMPORT_BATCH_SIZE:
  332. yield data
  333. data = []
  334. try:
  335. j = json.loads(line)
  336. j['meta'] = json.dumps(j.get('meta', {}))
  337. data.append(j)
  338. except json.decoder.JSONDecodeError:
  339. raise FileParseException(line_num=i, line=line)
  340. if data:
  341. yield data
  342. class JSONLRenderer(JSONRenderer):
  343. def render(self, data, accepted_media_type=None, renderer_context=None):
  344. """
  345. Render `data` into JSON, returning a bytestring.
  346. """
  347. if data is None:
  348. return bytes()
  349. if not isinstance(data, list):
  350. data = [data]
  351. for d in data:
  352. yield json.dumps(d,
  353. cls=self.encoder_class,
  354. ensure_ascii=self.ensure_ascii,
  355. allow_nan=not self.strict) + '\n'
  356. class JSONPainter(object):
  357. def paint(self, documents):
  358. serializer = DocumentSerializer(documents, many=True)
  359. data = []
  360. for d in serializer.data:
  361. d['meta'] = json.loads(d['meta'])
  362. for a in d['annotations']:
  363. a.pop('id')
  364. a.pop('prob')
  365. a.pop('document')
  366. data.append(d)
  367. return data
  368. class CSVPainter(JSONPainter):
  369. def paint(self, documents):
  370. data = super().paint(documents)
  371. res = []
  372. for d in data:
  373. annotations = d.pop('annotations')
  374. for a in annotations:
  375. res.append({**d, **a})
  376. return res