You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

45 lines
933 B

6 years ago
  1. """
  2. Utilities.
  3. """
  4. import json
  5. def train_test_split(data):
  6. x_train, x_test, y_train, ids = [], [], [], []
  7. for d in data:
  8. text = d['text']
  9. label = d['label']
  10. if d['manual']:
  11. x_train.append(text)
  12. y_train.append(label)
  13. else:
  14. x_test.append(text)
  15. ids.append(d['id'])
  16. return x_train, x_test, y_train, ids
  17. def load_dataset(filename):
  18. with open(filename) as f:
  19. data = [json.loads(line) for line in f]
  20. return data
  21. def save_dataset(obj, filename):
  22. with open(filename, 'w') as f:
  23. for line in obj:
  24. f.write('{}\n'.format(json.dumps(line)))
  25. def make_output(data, ids, y_pred, y_prob):
  26. i = 0
  27. for d in data:
  28. if i == len(ids):
  29. break
  30. if d['id'] == ids[i]:
  31. d['label'] = str(y_pred[i])
  32. d['prob'] = float(y_prob[i])
  33. i += 1
  34. return data