Browse Source

Add session key following SIP007

pull/1243/head
Max Lv 7 years ago
parent
commit
6c62647a66
3 changed files with 224 additions and 113 deletions
  1. 217
      src/aead.c
  2. 104
      src/crypto.c
  3. 16
      src/crypto.h

217
src/aead.c

@ -185,7 +185,7 @@ static const int supported_aead_ciphers_tag_size[AEAD_CIPHER_NUM] = {
};
static int
cipher_aead_encrypt(cipher_ctx_t *cipher_ctx,
aead_cipher_encrypt(cipher_ctx_t *cipher_ctx,
uint8_t *c,
size_t *clen,
uint8_t *m,
@ -193,13 +193,14 @@ cipher_aead_encrypt(cipher_ctx_t *cipher_ctx,
uint8_t *ad,
size_t adlen,
uint8_t *n,
uint8_t *k,
size_t nlen,
size_t tlen)
uint8_t *k)
{
int err = CRYPTO_OK;
unsigned long long long_clen = 0;
size_t nlen = cipher_ctx->cipher->nonce_len;
size_t tlen = cipher_ctx->cipher->tag_len;
switch (cipher_ctx->cipher->method) {
case AES128GCM:
case AES192GCM:
@ -228,21 +229,18 @@ cipher_aead_encrypt(cipher_ctx_t *cipher_ctx,
}
static int
cipher_aead_decrypt(cipher_ctx_t *cipher_ctx,
uint8_t *p,
size_t *plen,
uint8_t *m,
size_t mlen,
uint8_t *ad,
size_t adlen,
uint8_t *n,
uint8_t *k,
size_t nlen,
size_t tlen)
aead_cipher_decrypt(cipher_ctx_t *cipher_ctx,
uint8_t *p, size_t *plen,
uint8_t *m, size_t mlen,
uint8_t *ad, size_t adlen,
uint8_t *n, uint8_t *k)
{
int err = CRYPTO_ERROR;
unsigned long long long_plen = 0;
size_t nlen = cipher_ctx->cipher->nonce_len;
size_t tlen = cipher_ctx->cipher->tag_len;
switch (cipher_ctx->cipher->method) {
case AES128GCM:
case AES192GCM:
@ -296,6 +294,34 @@ aead_get_cipher_type(int method)
return mbedtls_cipher_info_from_string(mbedtlsname);
}
static void
aead_cipher_ctx_set_key(cipher_ctx_t *cipher_ctx, int enc)
{
const digest_type_t *md = mbedtls_md_info_from_string("SHA1");
if (md == NULL) {
FATAL("SHA1 Digest not found in crypto library");
}
int err = crypto_hkdf(md,
cipher_ctx->salt, cipher_ctx->cipher->key_len,
cipher_ctx->cipher->key, cipher_ctx->cipher->key_len,
(uint8_t *)SUBKEY_INFO, strlen(SUBKEY_INFO),
cipher_ctx->skey, cipher_ctx->cipher->key_len);
if (err) {
FATAL("Unable to generate subkey");
}
memset(cipher_ctx->nonce, 0, cipher_ctx->cipher->nonce_len);
if (mbedtls_cipher_setkey(cipher_ctx->evp, cipher_ctx->skey,
cipher_ctx->cipher->key_len * 8, enc) != 0) {
FATAL("Cannot set mbed TLS cipher key");
}
if (mbedtls_cipher_reset(cipher_ctx->evp) != 0) {
FATAL("Cannot finish preparation of mbed TLS cipher context");
}
}
static void
aead_cipher_ctx_init(cipher_ctx_t *cipher_ctx, int method, int enc)
{
@ -324,15 +350,6 @@ aead_cipher_ctx_init(cipher_ctx_t *cipher_ctx, int method, int enc)
if (mbedtls_cipher_setup(evp, cipher) != 0) {
FATAL("Cannot initialize mbed TLS cipher context");
}
if (mbedtls_cipher_setkey(evp, cipher_ctx->cipher->key,
cipher_ctx->cipher->key_len * 8, enc) != 0) {
mbedtls_cipher_free(evp);
FATAL("Cannot set mbed TLS cipher key");
}
if (mbedtls_cipher_reset(evp) != 0) {
mbedtls_cipher_free(evp);
FATAL("Cannot finish preparation of mbed TLS cipher context");
}
#ifdef DEBUG
dump("KEY", (char *)cipher_ctx->cipher->key, cipher_ctx->cipher->key_len);
@ -348,7 +365,7 @@ aead_ctx_init(cipher_t *cipher, cipher_ctx_t *cipher_ctx, int enc)
aead_cipher_ctx_init(cipher_ctx, cipher->method, enc);
if (enc) {
rand_bytes(cipher_ctx->nonce, cipher->nonce_len);
rand_bytes(cipher_ctx->salt, cipher->key_len);
}
}
@ -375,32 +392,25 @@ aead_encrypt_all(buffer_t *plaintext, cipher_t *cipher, size_t capacity)
cipher_ctx_t cipher_ctx;
aead_ctx_init(cipher, &cipher_ctx, 1);
size_t nonce_len = cipher->nonce_len;
size_t salt_len = cipher->key_len;
size_t tag_len = cipher->tag_len;
int err = CRYPTO_OK;
static buffer_t tmp = { 0, 0, 0, NULL };
brealloc(&tmp, nonce_len + tag_len + plaintext->len, capacity);
brealloc(&tmp, salt_len + tag_len + plaintext->len, capacity);
buffer_t *ciphertext = &tmp;
ciphertext->len = tag_len + plaintext->len;
// generate nonce
uint8_t *nonce = cipher_ctx.nonce;
/* copy nonce to first pos */
memcpy(ciphertext->data, nonce, nonce_len);
/* copy salt to first pos */
memcpy(ciphertext->data, cipher_ctx.salt, salt_len);
aead_cipher_ctx_set_key(&cipher_ctx, 1);
size_t clen = ciphertext->len;
err = cipher_aead_encrypt(&cipher_ctx,
(uint8_t *)ciphertext->data + nonce_len,
&clen,
(uint8_t *)plaintext->data,
plaintext->len,
NULL,
0,
nonce,
cipher->key,
nonce_len,
tag_len);
err = aead_cipher_encrypt(&cipher_ctx,
(uint8_t *)ciphertext->data + salt_len, &clen,
(uint8_t *)plaintext->data, plaintext->len,
NULL, 0, cipher_ctx.nonce, cipher_ctx.skey);
if (err) {
bfree(plaintext);
@ -408,18 +418,12 @@ aead_encrypt_all(buffer_t *plaintext, cipher_t *cipher, size_t capacity)
return CRYPTO_ERROR;
}
#ifdef DEBUG
dump("PLAIN", plaintext->data, plaintext->len);
dump("CIPHER", ciphertext->data + nonce_len, ciphertext->len);
#endif
aead_ctx_release(&cipher_ctx);
assert(ciphertext->len == clen);
brealloc(plaintext, nonce_len + ciphertext->len, capacity);
memcpy(plaintext->data, ciphertext->data, nonce_len + ciphertext->len);
plaintext->len = nonce_len + ciphertext->len;
brealloc(plaintext, salt_len + ciphertext->len, capacity);
memcpy(plaintext->data, ciphertext->data, salt_len + ciphertext->len);
plaintext->len = salt_len + ciphertext->len;
return CRYPTO_OK;
}
@ -427,11 +431,11 @@ aead_encrypt_all(buffer_t *plaintext, cipher_t *cipher, size_t capacity)
int
aead_decrypt_all(buffer_t *ciphertext, cipher_t *cipher, size_t capacity)
{
size_t nonce_len = cipher->nonce_len;
size_t salt_len = cipher->key_len;
size_t tag_len = cipher->tag_len;
int err = CRYPTO_OK;
if (ciphertext->len <= nonce_len + tag_len) {
if (ciphertext->len <= salt_len + tag_len) {
return CRYPTO_ERROR;
}
@ -441,24 +445,20 @@ aead_decrypt_all(buffer_t *ciphertext, cipher_t *cipher, size_t capacity)
static buffer_t tmp = { 0, 0, 0, NULL };
brealloc(&tmp, ciphertext->len, capacity);
buffer_t *plaintext = &tmp;
plaintext->len = ciphertext->len - nonce_len - tag_len;
plaintext->len = ciphertext->len - salt_len - tag_len;
/* get nonce */
uint8_t *nonce = cipher_ctx.nonce;
memcpy(nonce, ciphertext->data, nonce_len);
/* get salt */
uint8_t *salt = cipher_ctx.salt;
memcpy(salt, ciphertext->data, salt_len);
aead_cipher_ctx_set_key(&cipher_ctx, 0);
size_t plen = plaintext->len;
err = cipher_aead_decrypt(&cipher_ctx,
(uint8_t *)plaintext->data,
&plen,
(uint8_t *)ciphertext->data + nonce_len,
ciphertext->len - nonce_len,
NULL,
0,
nonce,
cipher->key,
nonce_len,
tag_len);
err = aead_cipher_decrypt(&cipher_ctx,
(uint8_t *)plaintext->data, &plen,
(uint8_t *)ciphertext->data + salt_len,
ciphertext->len - salt_len, NULL, 0,
cipher_ctx.nonce, cipher_ctx.skey);
if (err) {
bfree(ciphertext);
@ -466,11 +466,6 @@ aead_decrypt_all(buffer_t *ciphertext, cipher_t *cipher, size_t capacity)
return CRYPTO_ERROR;
}
#ifdef DEBUG
dump("PLAIN", plaintext->data, plaintext->len);
dump("CIPHER", ciphertext->data + nonce_len, ciphertext->len - nonce_len);
#endif
aead_ctx_release(&cipher_ctx);
brealloc(ciphertext, plaintext->len, capacity);
@ -481,9 +476,12 @@ aead_decrypt_all(buffer_t *ciphertext, cipher_t *cipher, size_t capacity)
}
static int
aead_chunk_encrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n,
uint16_t plen, size_t nlen, size_t tlen)
aead_chunk_encrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c,
uint8_t *n, uint16_t plen)
{
size_t nlen = ctx->cipher->nonce_len;
size_t tlen = ctx->cipher->tag_len;
assert(plen + tlen < CHUNK_SIZE_MASK);
int err;
@ -493,8 +491,8 @@ aead_chunk_encrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n,
memcpy(len_buf, &t, CHUNK_SIZE_LEN);
clen = CHUNK_SIZE_LEN + tlen;
err = cipher_aead_encrypt(ctx, c, &clen, len_buf, CHUNK_SIZE_LEN,
NULL, 0, n, ctx->cipher->key, nlen, tlen);
err = aead_cipher_encrypt(ctx, c, &clen, len_buf, CHUNK_SIZE_LEN,
NULL, 0, n, ctx->skey);
if (err)
return CRYPTO_ERROR;
assert(clen == CHUNK_SIZE_LEN + tlen);
@ -502,8 +500,8 @@ aead_chunk_encrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n,
sodium_increment(n, nlen);
clen = plen + tlen;
err = cipher_aead_encrypt(ctx, c + CHUNK_SIZE_LEN + tlen, &clen, p, plen,
NULL, 0, n, ctx->cipher->key, nlen, tlen);
err = aead_cipher_encrypt(ctx, c + CHUNK_SIZE_LEN + tlen, &clen,p, plen,
NULL, 0, n, ctx->skey);
if (err)
return CRYPTO_ERROR;
assert(clen == plen + tlen);
@ -529,36 +527,32 @@ aead_encrypt(buffer_t *plaintext, cipher_ctx_t *cipher_ctx, size_t capacity)
cipher_t *cipher = cipher_ctx->cipher;
int err = CRYPTO_ERROR;
size_t nonce_ofst = 0;
size_t nonce_len = cipher->nonce_len;
size_t salt_ofst = 0;
size_t salt_len = cipher->key_len;
size_t tag_len = cipher->tag_len;
if (!cipher_ctx->init) {
nonce_ofst = nonce_len;
salt_ofst = salt_len;
}
size_t out_len = nonce_ofst + 2 * tag_len + plaintext->len + CHUNK_SIZE_LEN;
size_t out_len = salt_ofst + 2 * tag_len + plaintext->len + CHUNK_SIZE_LEN;
brealloc(&tmp, out_len, capacity);
ciphertext = &tmp;
ciphertext->len = out_len;
if (!cipher_ctx->init) {
memcpy(ciphertext->data, cipher_ctx->nonce, nonce_len);
memcpy(ciphertext->data, cipher_ctx->salt, salt_len);
aead_cipher_ctx_set_key(cipher_ctx, 1);
cipher_ctx->init = 1;
}
err = aead_chunk_encrypt(cipher_ctx,
(uint8_t *)plaintext->data,
(uint8_t *)ciphertext->data + nonce_ofst,
cipher_ctx->nonce, plaintext->len, nonce_len, tag_len);
(uint8_t *)ciphertext->data + salt_ofst,
cipher_ctx->nonce, plaintext->len);
if (err)
return err;
#ifdef DEBUG
dump("PLAIN", plaintext->data, plaintext->len);
dump("CIPHER", ciphertext->data + nonce_ofst, ciphertext->len);
#endif
brealloc(plaintext, ciphertext->len, capacity);
memcpy(plaintext->data, ciphertext->data, ciphertext->len);
plaintext->len = ciphertext->len;
@ -568,17 +562,19 @@ aead_encrypt(buffer_t *plaintext, cipher_ctx_t *cipher_ctx, size_t capacity)
static int
aead_chunk_decrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n,
size_t *plen, size_t *clen, size_t nlen, size_t tlen)
size_t *plen, size_t *clen)
{
int err;
size_t mlen;
size_t nlen = ctx->cipher->nonce_len;
size_t tlen = ctx->cipher->tag_len;
if (*clen <= 2 * tlen + CHUNK_SIZE_LEN)
return CRYPTO_NEED_MORE;
uint8_t len_buf[2];
err = cipher_aead_decrypt(ctx, len_buf, plen, c, CHUNK_SIZE_LEN + tlen,
NULL, 0, n, ctx->cipher->key, nlen, tlen);
err = aead_cipher_decrypt(ctx, len_buf, plen, c, CHUNK_SIZE_LEN + tlen,
NULL, 0, n, ctx->skey);
if (err)
return CRYPTO_ERROR;
assert(*plen == CHUNK_SIZE_LEN);
@ -596,8 +592,8 @@ aead_chunk_decrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n,
sodium_increment(n, nlen);
err = cipher_aead_decrypt(ctx, p, plen, c + CHUNK_SIZE_LEN + tlen, mlen + tlen,
NULL, 0, n, ctx->cipher->key, nlen, tlen);
err = aead_cipher_decrypt(ctx, p, plen, c + CHUNK_SIZE_LEN + tlen, mlen + tlen,
NULL, 0, n, ctx->skey);
if (err)
return CRYPTO_ERROR;
assert(*plen == mlen);
@ -620,8 +616,7 @@ aead_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity)
cipher_t *cipher = cipher_ctx->cipher;
size_t nonce_len = cipher->nonce_len;
size_t tag_len = cipher->tag_len;
size_t salt_len = cipher->key_len;
if (cipher_ctx->chunk == NULL) {
cipher_ctx->chunk = (buffer_t *)ss_malloc(sizeof(buffer_t));
@ -639,20 +634,23 @@ aead_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity)
buffer_t *plaintext = &tmp;
if (!cipher_ctx->init) {
if (cipher_ctx->chunk->len <= nonce_len)
if (cipher_ctx->chunk->len <= salt_len)
return CRYPTO_NEED_MORE;
memcpy(cipher_ctx->nonce, cipher_ctx->chunk->data, nonce_len);
if (cache_key_exist(nonce_cache, (char *)cipher_ctx->nonce, nonce_len)) {
memcpy(cipher_ctx->salt, cipher_ctx->chunk->data, salt_len);
aead_cipher_ctx_set_key(cipher_ctx, 0);
if (cache_key_exist(nonce_cache, (char *)cipher_ctx->salt, salt_len)) {
bfree(ciphertext);
return CRYPTO_ERROR;
} else {
cache_insert(nonce_cache, (char *)cipher_ctx->nonce, nonce_len, NULL);
cache_insert(nonce_cache, (char *)cipher_ctx->salt, salt_len, NULL);
}
memmove(cipher_ctx->chunk->data, cipher_ctx->chunk->data + nonce_len,
cipher_ctx->chunk->len - nonce_len);
cipher_ctx->chunk->len -= nonce_len;
memmove(cipher_ctx->chunk->data, cipher_ctx->chunk->data + salt_len,
cipher_ctx->chunk->len - salt_len);
cipher_ctx->chunk->len -= salt_len;
cipher_ctx->init = 1;
}
@ -664,9 +662,7 @@ aead_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity)
err = aead_chunk_decrypt(cipher_ctx,
(uint8_t *)plaintext->data + plen,
(uint8_t *)cipher_ctx->chunk->data,
cipher_ctx->nonce,
&chunk_plen, &chunk_clen,
nonce_len, tag_len);
cipher_ctx->nonce, &chunk_plen, &chunk_clen);
if (err == CRYPTO_ERROR) {
return err;
} else if (err == CRYPTO_NEED_MORE) {
@ -680,11 +676,6 @@ aead_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity)
}
plaintext->len = plen;
#ifdef DEBUG
dump("PLAIN", plaintext->data, plaintext->len);
dump("CIPHER", ciphertext->data + nonce_len, ciphertext->len - nonce_len);
#endif
brealloc(ciphertext, plaintext->len, capacity);
memcpy(ciphertext->data, plaintext->data, plaintext->len);
ciphertext->len = plaintext->len;

104
src/crypto.c

@ -209,6 +209,110 @@ crypto_derive_key(const char *pass, uint8_t *key, size_t key_len)
return key_len;
}
/* HKDF-Extract + HKDF-Expand */
int crypto_hkdf(const mbedtls_md_info_t *md, const unsigned char *salt,
int salt_len, const unsigned char *ikm, int ikm_len,
const unsigned char *info, int info_len, unsigned char *okm,
int okm_len)
{
unsigned char prk[MBEDTLS_MD_MAX_SIZE];
return crypto_hkdf_extract(md, salt, salt_len, ikm, ikm_len, prk) ||
crypto_hkdf_expand(md, prk, mbedtls_md_get_size(md), info, info_len,
okm, okm_len);
}
/* HKDF-Extract(salt, IKM) -> PRK */
int crypto_hkdf_extract(const mbedtls_md_info_t *md, const unsigned char *salt,
int salt_len, const unsigned char *ikm, int ikm_len,
unsigned char *prk)
{
int hash_len;
unsigned char null_salt[MBEDTLS_MD_MAX_SIZE] = { '\0' };
if (salt_len < 0) {
return CRYPTO_ERROR;
}
hash_len = mbedtls_md_get_size(md);
if (salt == NULL) {
salt = null_salt;
salt_len = hash_len;
}
return mbedtls_md_hmac(md, salt, salt_len, ikm, ikm_len, prk);
}
/* HKDF-Expand(PRK, info, L) -> OKM */
int crypto_hkdf_expand(const mbedtls_md_info_t *md, const unsigned char *prk,
int prk_len, const unsigned char *info, int info_len,
unsigned char *okm, int okm_len)
{
int hash_len;
int N;
int T_len = 0, where = 0, i, ret;
mbedtls_md_context_t ctx;
unsigned char T[MBEDTLS_MD_MAX_SIZE];
if (info_len < 0 || okm_len < 0 || okm == NULL) {
return CRYPTO_ERROR;
}
hash_len = mbedtls_md_get_size(md);
if (prk_len < hash_len) {
return CRYPTO_ERROR;
}
if (info == NULL) {
info = (const unsigned char *)"";
}
N = okm_len / hash_len;
if ((okm_len % hash_len) != 0) {
N++;
}
if (N > 255) {
return CRYPTO_ERROR;
}
mbedtls_md_init(&ctx);
if ((ret = mbedtls_md_setup(&ctx, md, 1)) != 0) {
mbedtls_md_free(&ctx);
return ret;
}
/* Section 2.3. */
for (i = 1; i <= N; i++) {
unsigned char c = i;
ret = mbedtls_md_hmac_starts(&ctx, prk, prk_len) ||
mbedtls_md_hmac_update(&ctx, T, T_len) ||
mbedtls_md_hmac_update(&ctx, info, info_len) ||
/* The constant concatenated to the end of each T(n) is a single
octet. */
mbedtls_md_hmac_update(&ctx, &c, 1) ||
mbedtls_md_hmac_finish(&ctx, T);
if (ret != 0) {
mbedtls_md_free(&ctx);
return ret;
}
memcpy(okm + where, T, (i != N) ? hash_len : (okm_len - where));
where += hash_len;
T_len = hash_len;
}
mbedtls_md_free(&ctx);
return 0;
}
int
crypto_parse_key(const char *base64, uint8_t *key, size_t key_len)
{

16
src/crypto.h

@ -63,6 +63,9 @@ typedef mbedtls_md_info_t digest_type_t;
#define min(a, b) (((a) < (b)) ? (a) : (b))
#define max(a, b) (((a) > (b)) ? (a) : (b))
#define SUBKEY_INFO "ss-subkey"
#define IV_INFO "ss-iv"
typedef struct buffer {
size_t idx;
size_t len;
@ -72,6 +75,7 @@ typedef struct buffer {
typedef struct {
int method;
int skey;
cipher_kt_t *info;
size_t nonce_len;
size_t key_len;
@ -85,6 +89,8 @@ typedef struct {
cipher_evp_t *evp;
cipher_t *cipher;
buffer_t *chunk;
uint8_t salt[MAX_KEY_LENGTH];
uint8_t skey[MAX_KEY_LENGTH];
uint8_t nonce[MAX_NONCE_LENGTH];
} cipher_ctx_t;
@ -111,6 +117,16 @@ unsigned char *crypto_md5(const unsigned char *, size_t, unsigned char *);
int crypto_derive_key(const char *, uint8_t *, size_t);
int crypto_parse_key(const char *, uint8_t *, size_t);
int crypto_hkdf(const mbedtls_md_info_t *md, const unsigned char *salt,
int salt_len, const unsigned char *ikm, int ikm_len,
const unsigned char *info, int info_len, unsigned char *okm,
int okm_len);
int crypto_hkdf_extract(const mbedtls_md_info_t *md, const unsigned char *salt,
int salt_len, const unsigned char *ikm, int ikm_len,
unsigned char *prk);
int crypto_hkdf_expand(const mbedtls_md_info_t *md, const unsigned char *prk,
int prk_len, const unsigned char *info, int info_len,
unsigned char *okm, int okm_len);
extern struct cache *nonce_cache;
extern const char *supported_stream_ciphers[];

Loading…
Cancel
Save