diff --git a/src/stream.c b/src/stream.c index 560802b3..c530358c 100644 --- a/src/stream.c +++ b/src/stream.c @@ -243,6 +243,11 @@ stream_cipher_ctx_init(cipher_ctx_t *ctx, int method, int enc) void stream_cipher_ctx_release(cipher_ctx_t *cipher_ctx) { + if (cipher_ctx->chunk != NULL) { + bfree(cipher_ctx->chunk); + ss_free(cipher_ctx->chunk); + cipher_ctx->chunk = NULL; + } mbedtls_cipher_free(cipher_ctx->evp); ss_free(cipher_ctx->evp); } @@ -489,22 +494,38 @@ stream_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity) static buffer_t tmp = { 0, 0, 0, NULL }; - size_t nonce_len = 0; - int err = CRYPTO_OK; + int err = CRYPTO_OK; brealloc(&tmp, ciphertext->len, capacity); buffer_t *plaintext = &tmp; plaintext->len = ciphertext->len; if (!cipher_ctx->init) { - if (ciphertext->len <= cipher->nonce_len) - return CRYPTO_ERROR; + if (cipher_ctx->chunk == NULL) { + cipher_ctx->chunk = (buffer_t *)ss_malloc(sizeof(buffer_t)); + memset(cipher_ctx->chunk, 0, sizeof(buffer_t)); + balloc(cipher_ctx->chunk, cipher->nonce_len); + } + + size_t left_len = min(cipher->nonce_len - cipher_ctx->chunk->len, + ciphertext->len); + + if (left_len > 0) { + memcpy(cipher_ctx->chunk->data, ciphertext->data, left_len); + memmove(ciphertext->data, ciphertext->data + left_len, + ciphertext->len - left_len); + cipher_ctx->chunk->len += left_len; + ciphertext->len -= left_len; + } + + if (cipher_ctx->chunk->len < cipher->nonce_len) + return CRYPTO_NEED_MORE; - uint8_t *nonce = cipher_ctx->nonce; - nonce_len = cipher->nonce_len; + uint8_t *nonce = cipher_ctx->nonce; + size_t nonce_len = cipher->nonce_len; plaintext->len -= nonce_len; - memcpy(nonce, ciphertext->data, nonce_len); + memcpy(nonce, cipher_ctx->chunk->data, nonce_len); cipher_ctx_set_nonce(cipher_ctx, nonce, nonce_len, 0); cipher_ctx->counter = 0; cipher_ctx->init = 1; @@ -519,30 +540,33 @@ stream_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity) } } + if (ciphertext->len <= 0) + return CRYPTO_NEED_MORE; + if (cipher->method >= SALSA20) { int padding = cipher_ctx->counter % SODIUM_BLOCK_SIZE; brealloc(plaintext, (plaintext->len + padding) * 2, capacity); if (padding) { brealloc(ciphertext, ciphertext->len + padding, capacity); - memmove(ciphertext->data + nonce_len + padding, ciphertext->data + nonce_len, - ciphertext->len - nonce_len); - sodium_memzero(ciphertext->data + nonce_len, padding); + memmove(ciphertext->data + padding, ciphertext->data, + ciphertext->len); + sodium_memzero(ciphertext->data, padding); } crypto_stream_xor_ic((uint8_t *)plaintext->data, - (const uint8_t *)(ciphertext->data + nonce_len), - (uint64_t)(ciphertext->len - nonce_len + padding), + (const uint8_t *)(ciphertext->data), + (uint64_t)(ciphertext->len + padding), (const uint8_t *)cipher_ctx->nonce, cipher_ctx->counter / SODIUM_BLOCK_SIZE, cipher->key, cipher->method); - cipher_ctx->counter += ciphertext->len - nonce_len; + cipher_ctx->counter += ciphertext->len; if (padding) { memmove(plaintext->data, plaintext->data + padding, plaintext->len); } } else { err = cipher_ctx_update(cipher_ctx, (uint8_t *)plaintext->data, &plaintext->len, - (const uint8_t *)(ciphertext->data + nonce_len), - ciphertext->len - nonce_len); + (const uint8_t *)(ciphertext->data), + ciphertext->len); } if (err) { @@ -552,7 +576,7 @@ stream_decrypt(buffer_t *ciphertext, cipher_ctx_t *cipher_ctx, size_t capacity) #ifdef DEBUG dump("PLAIN", plaintext->data, plaintext->len); - dump("CIPHER", ciphertext->data + nonce_len, ciphertext->len - nonce_len); + dump("CIPHER", ciphertext->data, ciphertext->len); #endif brealloc(ciphertext, plaintext->len, capacity);