diff --git a/src/aead.c b/src/aead.c index baf2ce3e..938760c7 100644 --- a/src/aead.c +++ b/src/aead.c @@ -185,7 +185,7 @@ static const int supported_aead_ciphers_tag_size[AEAD_CIPHER_NUM] = { }; static int -cipher_aead_encrypt(cipher_ctx_t *cipher_ctx, +aead_cipher_encrypt(cipher_ctx_t *cipher_ctx, uint8_t *c, size_t *clen, uint8_t *m, @@ -193,13 +193,14 @@ cipher_aead_encrypt(cipher_ctx_t *cipher_ctx, uint8_t *ad, size_t adlen, uint8_t *n, - uint8_t *k, - size_t nlen, - size_t tlen) + uint8_t *k) { int err = CRYPTO_OK; unsigned long long long_clen = 0; + size_t nlen = cipher_ctx->cipher->nonce_len; + size_t tlen = cipher_ctx->cipher->tag_len; + switch (cipher_ctx->cipher->method) { case AES128GCM: case AES192GCM: @@ -228,21 +229,18 @@ cipher_aead_encrypt(cipher_ctx_t *cipher_ctx, } static int -cipher_aead_decrypt(cipher_ctx_t *cipher_ctx, - uint8_t *p, - size_t *plen, - uint8_t *m, - size_t mlen, - uint8_t *ad, - size_t adlen, - uint8_t *n, - uint8_t *k, - size_t nlen, - size_t tlen) +aead_cipher_decrypt(cipher_ctx_t *cipher_ctx, + uint8_t *p, size_t *plen, + uint8_t *m, size_t mlen, + uint8_t *ad, size_t adlen, + uint8_t *n, uint8_t *k) { int err = CRYPTO_ERROR; unsigned long long long_plen = 0; + size_t nlen = cipher_ctx->cipher->nonce_len; + size_t tlen = cipher_ctx->cipher->tag_len; + switch (cipher_ctx->cipher->method) { case AES128GCM: case AES192GCM: @@ -296,6 +294,34 @@ aead_get_cipher_type(int method) return mbedtls_cipher_info_from_string(mbedtlsname); } +static void +aead_cipher_ctx_set_key(cipher_ctx_t *cipher_ctx, int enc) +{ + const digest_type_t *md = mbedtls_md_info_from_string("SHA1"); + if (md == NULL) { + FATAL("SHA1 Digest not found in crypto library"); + } + + int err = crypto_hkdf(md, + cipher_ctx->salt, cipher_ctx->cipher->key_len, + cipher_ctx->cipher->key, cipher_ctx->cipher->key_len, + (uint8_t *)SUBKEY_INFO, strlen(SUBKEY_INFO), + cipher_ctx->skey, cipher_ctx->cipher->key_len); + if (err) { + FATAL("Unable to generate subkey"); + } + + memset(cipher_ctx->nonce, 0, cipher_ctx->cipher->nonce_len); + + if (mbedtls_cipher_setkey(cipher_ctx->evp, cipher_ctx->skey, + cipher_ctx->cipher->key_len * 8, enc) != 0) { + FATAL("Cannot set mbed TLS cipher key"); + } + if (mbedtls_cipher_reset(cipher_ctx->evp) != 0) { + FATAL("Cannot finish preparation of mbed TLS cipher context"); + } +} + static void aead_cipher_ctx_init(cipher_ctx_t *cipher_ctx, int method, int enc) { @@ -324,15 +350,6 @@ aead_cipher_ctx_init(cipher_ctx_t *cipher_ctx, int method, int enc) if (mbedtls_cipher_setup(evp, cipher) != 0) { FATAL("Cannot initialize mbed TLS cipher context"); } - if (mbedtls_cipher_setkey(evp, cipher_ctx->cipher->key, - cipher_ctx->cipher->key_len * 8, enc) != 0) { - mbedtls_cipher_free(evp); - FATAL("Cannot set mbed TLS cipher key"); - } - if (mbedtls_cipher_reset(evp) != 0) { - mbedtls_cipher_free(evp); - FATAL("Cannot finish preparation of mbed TLS cipher context"); - } #ifdef DEBUG dump("KEY", (char *)cipher_ctx->cipher->key, cipher_ctx->cipher->key_len); @@ -348,7 +365,7 @@ aead_ctx_init(cipher_t *cipher, cipher_ctx_t *cipher_ctx, int enc) aead_cipher_ctx_init(cipher_ctx, cipher->method, enc); if (enc) { - rand_bytes(cipher_ctx->nonce, cipher->nonce_len); + rand_bytes(cipher_ctx->salt, cipher->key_len); } } @@ -375,32 +392,25 @@ aead_encrypt_all(buffer_t *plaintext, cipher_t *cipher, size_t capacity) cipher_ctx_t cipher_ctx; aead_ctx_init(cipher, &cipher_ctx, 1); - size_t nonce_len = cipher->nonce_len; + size_t salt_len = cipher->key_len; size_t tag_len = cipher->tag_len; int err = CRYPTO_OK; static buffer_t tmp = { 0, 0, 0, NULL }; - brealloc(&tmp, nonce_len + tag_len + plaintext->len, capacity); + brealloc(&tmp, salt_len + tag_len + plaintext->len, capacity); buffer_t *ciphertext = &tmp; ciphertext->len = tag_len + plaintext->len; - // generate nonce - uint8_t *nonce = cipher_ctx.nonce; - /* copy nonce to first pos */ - memcpy(ciphertext->data, nonce, nonce_len); + /* copy salt to first pos */ + memcpy(ciphertext->data, cipher_ctx.salt, salt_len); + + aead_cipher_ctx_set_key(&cipher_ctx, 1); size_t clen = ciphertext->len; - err = cipher_aead_encrypt(&cipher_ctx, - (uint8_t *)ciphertext->data + nonce_len, - &clen, - (uint8_t *)plaintext->data, - plaintext->len, - NULL, - 0, - nonce, - cipher->key, - nonce_len, - tag_len); + err = aead_cipher_encrypt(&cipher_ctx, + (uint8_t *)ciphertext->data + salt_len, &clen, + (uint8_t *)plaintext->data, plaintext->len, + NULL, 0, cipher_ctx.nonce, cipher_ctx.skey); if (err) { bfree(plaintext); @@ -408,18 +418,12 @@ aead_encrypt_all(buffer_t *plaintext, cipher_t *cipher, size_t capacity) return CRYPTO_ERROR; } -#ifdef DEBUG - dump("PLAIN", plaintext->data, plaintext->len); - dump("CIPHER", ciphertext->data + nonce_len, ciphertext->len); -#endif - aead_ctx_release(&cipher_ctx); - assert(ciphertext->len == clen); - brealloc(plaintext, nonce_len + ciphertext->len, capacity); - memcpy(plaintext->data, ciphertext->data, nonce_len + ciphertext->len); - plaintext->len = nonce_len + ciphertext->len; + brealloc(plaintext, salt_len + ciphertext->len, capacity); + memcpy(plaintext->data, ciphertext->data, salt_len + ciphertext->len); + plaintext->len = salt_len + ciphertext->len; return CRYPTO_OK; } @@ -427,11 +431,11 @@ aead_encrypt_all(buffer_t *plaintext, cipher_t *cipher, size_t capacity) int aead_decrypt_all(buffer_t *ciphertext, cipher_t *cipher, size_t capacity) { - size_t nonce_len = cipher->nonce_len; + size_t salt_len = cipher->key_len; size_t tag_len = cipher->tag_len; int err = CRYPTO_OK; - if (ciphertext->len <= nonce_len + tag_len) { + if (ciphertext->len <= salt_len + tag_len) { return CRYPTO_ERROR; } @@ -441,24 +445,20 @@ aead_decrypt_all(buffer_t *ciphertext, cipher_t *cipher, size_t capacity) static buffer_t tmp = { 0, 0, 0, NULL }; brealloc(&tmp, ciphertext->len, capacity); buffer_t *plaintext = &tmp; - plaintext->len = ciphertext->len - nonce_len - tag_len; + plaintext->len = ciphertext->len - salt_len - tag_len; - /* get nonce */ - uint8_t *nonce = cipher_ctx.nonce; - memcpy(nonce, ciphertext->data, nonce_len); + /* get salt */ + uint8_t *salt = cipher_ctx.salt; + memcpy(salt, ciphertext->data, salt_len); + + aead_cipher_ctx_set_key(&cipher_ctx, 0); size_t plen = plaintext->len; - err = cipher_aead_decrypt(&cipher_ctx, - (uint8_t *)plaintext->data, - &plen, - (uint8_t *)ciphertext->data + nonce_len, - ciphertext->len - nonce_len, - NULL, - 0, - nonce, - cipher->key, - nonce_len, - tag_len); + err = aead_cipher_decrypt(&cipher_ctx, + (uint8_t *)plaintext->data, &plen, + (uint8_t *)ciphertext->data + salt_len, + ciphertext->len - salt_len, NULL, 0, + cipher_ctx.nonce, cipher_ctx.skey); if (err) { bfree(ciphertext); @@ -466,11 +466,6 @@ aead_decrypt_all(buffer_t *ciphertext, cipher_t *cipher, size_t capacity) return CRYPTO_ERROR; } -#ifdef DEBUG - dump("PLAIN", plaintext->data, plaintext->len); - dump("CIPHER", ciphertext->data + nonce_len, ciphertext->len - nonce_len); -#endif - aead_ctx_release(&cipher_ctx); brealloc(ciphertext, plaintext->len, capacity); @@ -481,9 +476,12 @@ aead_decrypt_all(buffer_t *ciphertext, cipher_t *cipher, size_t capacity) } static int -aead_chunk_encrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n, - uint16_t plen, size_t nlen, size_t tlen) +aead_chunk_encrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, + uint8_t *n, uint16_t plen) { + size_t nlen = ctx->cipher->nonce_len; + size_t tlen = ctx->cipher->tag_len; + assert(plen + tlen < CHUNK_SIZE_MASK); int err; @@ -493,8 +491,8 @@ aead_chunk_encrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n, memcpy(len_buf, &t, CHUNK_SIZE_LEN); clen = CHUNK_SIZE_LEN + tlen; - err = cipher_aead_encrypt(ctx, c, &clen, len_buf, CHUNK_SIZE_LEN, - NULL, 0, n, ctx->cipher->key, nlen, tlen); + err = aead_cipher_encrypt(ctx, c, &clen, len_buf, CHUNK_SIZE_LEN, + NULL, 0, n, ctx->skey); if (err) return CRYPTO_ERROR; assert(clen == CHUNK_SIZE_LEN + tlen); @@ -502,8 +500,8 @@ aead_chunk_encrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n, sodium_increment(n, nlen); clen = plen + tlen; - err = cipher_aead_encrypt(ctx, c + CHUNK_SIZE_LEN + tlen, &clen, p, plen, - NULL, 0, n, ctx->cipher->key, nlen, tlen); + err = aead_cipher_encrypt(ctx, c + CHUNK_SIZE_LEN + tlen, &clen,p, plen, + NULL, 0, n, ctx->skey); if (err) return CRYPTO_ERROR; assert(clen == plen + tlen); @@ -529,36 +527,32 @@ aead_encrypt(buffer_t *plaintext, cipher_ctx_t *cipher_ctx, size_t capacity) cipher_t *cipher = cipher_ctx->cipher; int err = CRYPTO_ERROR; - size_t nonce_ofst = 0; - size_t nonce_len = cipher->nonce_len; + size_t salt_ofst = 0; + size_t salt_len = cipher->key_len; size_t tag_len = cipher->tag_len; if (!cipher_ctx->init) { - nonce_ofst = nonce_len; + salt_ofst = salt_len; } - size_t out_len = nonce_ofst + 2 * tag_len + plaintext->len + CHUNK_SIZE_LEN; + size_t out_len = salt_ofst + 2 * tag_len + plaintext->len + CHUNK_SIZE_LEN; brealloc(&tmp, out_len, capacity); ciphertext = &tmp; ciphertext->len = out_len; if (!cipher_ctx->init) { - memcpy(ciphertext->data, cipher_ctx->nonce, nonce_len); + memcpy(ciphertext->data, cipher_ctx->salt, salt_len); + aead_cipher_ctx_set_key(cipher_ctx, 1); cipher_ctx->init = 1; } err = aead_chunk_encrypt(cipher_ctx, (uint8_t *)plaintext->data, - (uint8_t *)ciphertext->data + nonce_ofst, - cipher_ctx->nonce, plaintext->len, nonce_len, tag_len); + (uint8_t *)ciphertext->data + salt_ofst, + cipher_ctx->nonce, plaintext->len); if (err) return err; -#ifdef DEBUG - dump("PLAIN", plaintext->data, plaintext->len); - dump("CIPHER", ciphertext->data + nonce_ofst, ciphertext->len); -#endif - brealloc(plaintext, ciphertext->len, capacity); memcpy(plaintext->data, ciphertext->data, ciphertext->len); plaintext->len = ciphertext->len; @@ -568,17 +562,19 @@ aead_encrypt(buffer_t *plaintext, cipher_ctx_t *cipher_ctx, size_t capacity) static int aead_chunk_decrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n, - size_t *plen, size_t *clen, size_t nlen, size_t tlen) + size_t *plen, size_t *clen) { int err; size_t mlen; + size_t nlen = ctx->cipher->nonce_len; + size_t tlen = ctx->cipher->tag_len; if (*clen <= 2 * tlen + CHUNK_SIZE_LEN) return CRYPTO_NEED_MORE; uint8_t len_buf[2]; - err = cipher_aead_decrypt(ctx, len_buf, plen, c, CHUNK_SIZE_LEN + tlen, - NULL, 0, n, ctx->cipher->key, nlen, tlen); + err = aead_cipher_decrypt(ctx, len_buf, plen, c, CHUNK_SIZE_LEN + tlen, + NULL, 0, n, ctx->skey); if (err) return CRYPTO_ERROR; assert(*plen == CHUNK_SIZE_LEN); @@ -596,8 +592,8 @@ aead_chunk_decrypt(cipher_ctx_t *ctx, uint8_t *p, uint8_t *c, uint8_t *n, sodium_increment(n, nlen); - err = cipher_aead_decrypt(ctx, p, plen, c + CHUNK_SIZE_LEN + tlen, mlen + tlen, - NULL, 0, n, ctx->cipher->key, nlen, tlen); + err = aead_cipher_decrypt(ctx, p, plen, c + CHUNK_SIZE_LEN + tlen, mlen + tlen, + NULL, 0, n, ctx->skey); if (err) return CRYPTO_ERROR; assert(*plen == mlen); @@ -620,8 +616,7 @@ aead_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity) cipher_t *cipher = cipher_ctx->cipher; - size_t nonce_len = cipher->nonce_len; - size_t tag_len = cipher->tag_len; + size_t salt_len = cipher->key_len; if (cipher_ctx->chunk == NULL) { cipher_ctx->chunk = (buffer_t *)ss_malloc(sizeof(buffer_t)); @@ -639,20 +634,23 @@ aead_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity) buffer_t *plaintext = &tmp; if (!cipher_ctx->init) { - if (cipher_ctx->chunk->len <= nonce_len) + if (cipher_ctx->chunk->len <= salt_len) return CRYPTO_NEED_MORE; - memcpy(cipher_ctx->nonce, cipher_ctx->chunk->data, nonce_len); - if (cache_key_exist(nonce_cache, (char *)cipher_ctx->nonce, nonce_len)) { + memcpy(cipher_ctx->salt, cipher_ctx->chunk->data, salt_len); + + aead_cipher_ctx_set_key(cipher_ctx, 0); + + if (cache_key_exist(nonce_cache, (char *)cipher_ctx->salt, salt_len)) { bfree(ciphertext); return CRYPTO_ERROR; } else { - cache_insert(nonce_cache, (char *)cipher_ctx->nonce, nonce_len, NULL); + cache_insert(nonce_cache, (char *)cipher_ctx->salt, salt_len, NULL); } - memmove(cipher_ctx->chunk->data, cipher_ctx->chunk->data + nonce_len, - cipher_ctx->chunk->len - nonce_len); - cipher_ctx->chunk->len -= nonce_len; + memmove(cipher_ctx->chunk->data, cipher_ctx->chunk->data + salt_len, + cipher_ctx->chunk->len - salt_len); + cipher_ctx->chunk->len -= salt_len; cipher_ctx->init = 1; } @@ -664,9 +662,7 @@ aead_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity) err = aead_chunk_decrypt(cipher_ctx, (uint8_t *)plaintext->data + plen, (uint8_t *)cipher_ctx->chunk->data, - cipher_ctx->nonce, - &chunk_plen, &chunk_clen, - nonce_len, tag_len); + cipher_ctx->nonce, &chunk_plen, &chunk_clen); if (err == CRYPTO_ERROR) { return err; } else if (err == CRYPTO_NEED_MORE) { @@ -680,11 +676,6 @@ aead_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity) } plaintext->len = plen; -#ifdef DEBUG - dump("PLAIN", plaintext->data, plaintext->len); - dump("CIPHER", ciphertext->data + nonce_len, ciphertext->len - nonce_len); -#endif - brealloc(ciphertext, plaintext->len, capacity); memcpy(ciphertext->data, plaintext->data, plaintext->len); ciphertext->len = plaintext->len; diff --git a/src/crypto.c b/src/crypto.c index ec6f2465..ab3ea2ca 100644 --- a/src/crypto.c +++ b/src/crypto.c @@ -209,6 +209,110 @@ crypto_derive_key(const char *pass, uint8_t *key, size_t key_len) return key_len; } +/* HKDF-Extract + HKDF-Expand */ +int crypto_hkdf(const mbedtls_md_info_t *md, const unsigned char *salt, + int salt_len, const unsigned char *ikm, int ikm_len, + const unsigned char *info, int info_len, unsigned char *okm, + int okm_len) +{ + unsigned char prk[MBEDTLS_MD_MAX_SIZE]; + + return crypto_hkdf_extract(md, salt, salt_len, ikm, ikm_len, prk) || + crypto_hkdf_expand(md, prk, mbedtls_md_get_size(md), info, info_len, + okm, okm_len); +} + +/* HKDF-Extract(salt, IKM) -> PRK */ +int crypto_hkdf_extract(const mbedtls_md_info_t *md, const unsigned char *salt, + int salt_len, const unsigned char *ikm, int ikm_len, + unsigned char *prk) +{ + int hash_len; + unsigned char null_salt[MBEDTLS_MD_MAX_SIZE] = { '\0' }; + + if (salt_len < 0) { + return CRYPTO_ERROR; + } + + hash_len = mbedtls_md_get_size(md); + + if (salt == NULL) { + salt = null_salt; + salt_len = hash_len; + } + + return mbedtls_md_hmac(md, salt, salt_len, ikm, ikm_len, prk); +} + +/* HKDF-Expand(PRK, info, L) -> OKM */ +int crypto_hkdf_expand(const mbedtls_md_info_t *md, const unsigned char *prk, + int prk_len, const unsigned char *info, int info_len, + unsigned char *okm, int okm_len) +{ + int hash_len; + int N; + int T_len = 0, where = 0, i, ret; + mbedtls_md_context_t ctx; + unsigned char T[MBEDTLS_MD_MAX_SIZE]; + + if (info_len < 0 || okm_len < 0 || okm == NULL) { + return CRYPTO_ERROR; + } + + hash_len = mbedtls_md_get_size(md); + + if (prk_len < hash_len) { + return CRYPTO_ERROR; + } + + if (info == NULL) { + info = (const unsigned char *)""; + } + + N = okm_len / hash_len; + + if ((okm_len % hash_len) != 0) { + N++; + } + + if (N > 255) { + return CRYPTO_ERROR; + } + + mbedtls_md_init(&ctx); + + if ((ret = mbedtls_md_setup(&ctx, md, 1)) != 0) { + mbedtls_md_free(&ctx); + return ret; + } + + /* Section 2.3. */ + for (i = 1; i <= N; i++) { + unsigned char c = i; + + ret = mbedtls_md_hmac_starts(&ctx, prk, prk_len) || + mbedtls_md_hmac_update(&ctx, T, T_len) || + mbedtls_md_hmac_update(&ctx, info, info_len) || + /* The constant concatenated to the end of each T(n) is a single + octet. */ + mbedtls_md_hmac_update(&ctx, &c, 1) || + mbedtls_md_hmac_finish(&ctx, T); + + if (ret != 0) { + mbedtls_md_free(&ctx); + return ret; + } + + memcpy(okm + where, T, (i != N) ? hash_len : (okm_len - where)); + where += hash_len; + T_len = hash_len; + } + + mbedtls_md_free(&ctx); + + return 0; +} + int crypto_parse_key(const char *base64, uint8_t *key, size_t key_len) { diff --git a/src/crypto.h b/src/crypto.h index 424420dc..ee25e986 100644 --- a/src/crypto.h +++ b/src/crypto.h @@ -63,6 +63,9 @@ typedef mbedtls_md_info_t digest_type_t; #define min(a, b) (((a) < (b)) ? (a) : (b)) #define max(a, b) (((a) > (b)) ? (a) : (b)) +#define SUBKEY_INFO "ss-subkey" +#define IV_INFO "ss-iv" + typedef struct buffer { size_t idx; size_t len; @@ -72,6 +75,7 @@ typedef struct buffer { typedef struct { int method; + int skey; cipher_kt_t *info; size_t nonce_len; size_t key_len; @@ -85,6 +89,8 @@ typedef struct { cipher_evp_t *evp; cipher_t *cipher; buffer_t *chunk; + uint8_t salt[MAX_KEY_LENGTH]; + uint8_t skey[MAX_KEY_LENGTH]; uint8_t nonce[MAX_NONCE_LENGTH]; } cipher_ctx_t; @@ -111,6 +117,16 @@ unsigned char *crypto_md5(const unsigned char *, size_t, unsigned char *); int crypto_derive_key(const char *, uint8_t *, size_t); int crypto_parse_key(const char *, uint8_t *, size_t); +int crypto_hkdf(const mbedtls_md_info_t *md, const unsigned char *salt, + int salt_len, const unsigned char *ikm, int ikm_len, + const unsigned char *info, int info_len, unsigned char *okm, + int okm_len); +int crypto_hkdf_extract(const mbedtls_md_info_t *md, const unsigned char *salt, + int salt_len, const unsigned char *ikm, int ikm_len, + unsigned char *prk); +int crypto_hkdf_expand(const mbedtls_md_info_t *md, const unsigned char *prk, + int prk_len, const unsigned char *info, int info_len, + unsigned char *okm, int okm_len); extern struct cache *nonce_cache; extern const char *supported_stream_ciphers[];