From 0e79c93ad1744efa32b31d3fd8776795249e181a Mon Sep 17 00:00:00 2001 From: Max Lv Date: Wed, 23 Sep 2015 22:15:59 +0800 Subject: [PATCH] add auth to udprelay --- src/encrypt.c | 58 +++++++++++++++++++++++++++++++++----------------- src/encrypt.h | 8 +++---- src/local.c | 2 +- src/redir.c | 2 +- src/server.c | 2 +- src/tunnel.c | 2 +- src/udprelay.c | 8 +++---- 7 files changed, 50 insertions(+), 32 deletions(-) diff --git a/src/encrypt.c b/src/encrypt.c index 81e7c7ef..a8c507d3 100644 --- a/src/encrypt.c +++ b/src/encrypt.c @@ -790,6 +790,8 @@ void cipher_context_init(cipher_ctx_t *ctx, int method, int enc) return; } + enc_iv_len = 0; + if (method >= SALSA20) { enc_iv_len = supported_ciphers_iv_size[method]; return; @@ -1029,27 +1031,27 @@ static int cipher_context_update(cipher_ctx_t *ctx, uint8_t *output, size_t *ole #endif } -int ss_onetimeauth(char *auth, char *msg, int msg_len, struct enc_ctx *ctx) +int ss_onetimeauth(char *auth, char *msg, int msg_len, uint8_t *iv) { uint8_t auth_key[MAX_IV_LENGTH + MAX_KEY_LENGTH]; - memcpy(auth_key, ctx->evp.iv, enc_iv_len); + memcpy(auth_key, iv, enc_iv_len); memcpy(auth_key + enc_iv_len, enc_key, enc_key_len); sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, (uint8_t *)auth); return 0; } -int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, struct enc_ctx *ctx) +int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, uint8_t *iv) { uint8_t hash[ONETIMEAUTH_BYTES]; uint8_t auth_key[MAX_IV_LENGTH + MAX_KEY_LENGTH]; - memcpy(auth_key, ctx->evp.iv, enc_iv_len); + memcpy(auth_key, iv, enc_iv_len); memcpy(auth_key + enc_iv_len, enc_key, enc_key_len); sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, hash); return memcmp(auth, hash, ONETIMEAUTH_BYTES); } -char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method) +char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method, int auth) { if (method > TABLE) { cipher_ctx_t evp; @@ -1069,10 +1071,21 @@ char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method) char *ciphertext = tmp_buf; uint8_t iv[MAX_IV_LENGTH]; + rand_bytes(iv, iv_len); cipher_context_set_iv(&evp, iv, iv_len, 1); memcpy(ciphertext, iv, iv_len); + if (auth) { + char hash[ONETIMEAUTH_BYTES]; + ss_onetimeauth(hash, plaintext, p_len, iv); + if (buf_size < ONETIMEAUTH_BYTES + p_len) { + plaintext = realloc(plaintext, ONETIMEAUTH_BYTES + p_len); + memcpy(plaintext + p_len, hash, ONETIMEAUTH_BYTES); + p_len = c_len = p_len + ONETIMEAUTH_BYTES; + } + } + if (method >= SALSA20) { crypto_stream_xor_ic((uint8_t *)(ciphertext + iv_len), (const uint8_t *)plaintext, (uint64_t)(p_len), @@ -1097,8 +1110,8 @@ char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method) cipher_context_release(&evp); - if (*len < iv_len + c_len) { - plaintext = realloc(plaintext, max(iv_len + c_len, buf_size)); + if (buf_size < iv_len + c_len) { + plaintext = realloc(plaintext, iv_len + c_len); } *len = iv_len + c_len; memcpy(plaintext, ciphertext, *len); @@ -1183,8 +1196,8 @@ char * ss_encrypt(int buf_size, char *plaintext, ssize_t *len, dump("CIPHER", ciphertext + iv_len, c_len); #endif - if (*len < iv_len + c_len) { - plaintext = realloc(plaintext, max(iv_len + c_len, buf_size)); + if (buf_size < iv_len + c_len) { + plaintext = realloc(plaintext, iv_len + c_len); } *len = iv_len + c_len; memcpy(plaintext, ciphertext, *len); @@ -1200,14 +1213,14 @@ char * ss_encrypt(int buf_size, char *plaintext, ssize_t *len, } } -char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method) +char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method, int auth) { if (method > TABLE) { cipher_ctx_t evp; cipher_context_init(&evp, method, 0); size_t iv_len = enc_iv_len; size_t c_len = *len, p_len = *len - iv_len; - int err = 1; + int ret = 1; static int tmp_len = 0; static char *tmp_buf = NULL; @@ -1228,12 +1241,18 @@ char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method) (uint64_t)(c_len - iv_len), (const uint8_t *)iv, 0, enc_key, method); } else { - err = cipher_context_update(&evp, (uint8_t *)plaintext, &p_len, + ret = cipher_context_update(&evp, (uint8_t *)plaintext, &p_len, (const uint8_t *)(ciphertext + iv_len), c_len - iv_len); } - if (!err) { + if (auth) { + char hash[ONETIMEAUTH_BYTES]; + ret = !ss_onetimeauth_verify(hash, plaintext, p_len - ONETIMEAUTH_BYTES, iv); + if (ret) p_len -= ONETIMEAUTH_BYTES; + } + + if (!ret) { free(ciphertext); cipher_context_release(&evp); return NULL; @@ -1246,8 +1265,8 @@ char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method) cipher_context_release(&evp); - if (*len < p_len) { - ciphertext = realloc(ciphertext, max(p_len, buf_size)); + if (buf_size < p_len) { + ciphertext = realloc(ciphertext, p_len); } *len = p_len; memcpy(ciphertext, plaintext, *len); @@ -1308,8 +1327,7 @@ char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len, struct enc_ctx * tmp_buf = plaintext; } if (padding) { - ciphertext = - realloc(ciphertext, max(c_len + padding, buf_size)); + ciphertext = realloc(ciphertext, max(c_len + padding, buf_size)); memmove(ciphertext + iv_len + padding, ciphertext + iv_len, c_len - iv_len); memset(ciphertext + iv_len, 0, padding); @@ -1340,8 +1358,8 @@ char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len, struct enc_ctx * dump("CIPHER", ciphertext + iv_len, c_len - iv_len); #endif - if (*len < p_len) { - ciphertext = realloc(ciphertext, max(p_len, buf_size)); + if (buf_size < p_len) { + ciphertext = realloc(ciphertext, p_len); } *len = p_len; memcpy(ciphertext, plaintext, *len); @@ -1501,7 +1519,7 @@ int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct if (cidx == CLEN_BYTES) { uint16_t clen = ntohs(*((uint16_t *)chunk->buf)); - if (chunk->len < clen) { + if (buf_size < clen + AUTH_BYTES) { chunk->buf = realloc(chunk->buf, clen + AUTH_BYTES); } chunk->len = clen; diff --git a/src/encrypt.h b/src/encrypt.h index 6f697850..b3f7ed30 100644 --- a/src/encrypt.h +++ b/src/encrypt.h @@ -164,8 +164,8 @@ struct enc_ctx { cipher_ctx_t evp; }; -char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method); -char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method); +char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method, int auth); +char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method, int auth); char * ss_encrypt(int buf_size, char *plaintext, ssize_t *len, struct enc_ctx *ctx); char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len, @@ -176,8 +176,8 @@ int enc_get_iv_len(void); void cipher_context_release(cipher_ctx_t *evp); unsigned char *enc_md5(const unsigned char *d, size_t n, unsigned char *md); -int ss_onetimeauth(char *auth, char *msg, int msg_len, struct enc_ctx *ctx); -int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, struct enc_ctx *ctx); +int ss_onetimeauth(char *auth, char *msg, int msg_len, uint8_t *iv); +int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, uint8_t *iv); int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct enc_ctx *ctx, int buf_size); char *ss_gen_hash(char *buf, ssize_t *buf_len, uint32_t *counter, struct enc_ctx *ctx, int buf_size); diff --git a/src/local.c b/src/local.c index f84c2cc4..9eff43e8 100644 --- a/src/local.c +++ b/src/local.c @@ -477,7 +477,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) if (!remote->direct) { if (auth) { ss_addr_to_send[0] |= ONETIMEAUTH_FLAG; - ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx); + ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx->evp.iv); addr_len += ONETIMEAUTH_BYTES; } diff --git a/src/redir.c b/src/redir.c index 7ba3397f..37246f49 100644 --- a/src/redir.c +++ b/src/redir.c @@ -377,7 +377,7 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) if (auth) { ss_addr_to_send[0] |= ONETIMEAUTH_FLAG; - ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx); + ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx->evp.iv); addr_len += ONETIMEAUTH_BYTES; } diff --git a/src/server.c b/src/server.c index 9a681325..395156c3 100644 --- a/src/server.c +++ b/src/server.c @@ -664,7 +664,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) offset += 2; if (auth || (atyp & ONETIMEAUTH_FLAG)) { - if (ss_onetimeauth_verify(server->buf + offset, server->buf, offset, server->d_ctx)) { + if (ss_onetimeauth_verify(server->buf + offset, server->buf, offset, server->d_ctx->evp.iv)) { LOGE("authentication error %d", atyp); report_addr(server->fd); close_and_free_server(EV_A_ server); diff --git a/src/tunnel.c b/src/tunnel.c index 9e848852..f0ebbebf 100644 --- a/src/tunnel.c +++ b/src/tunnel.c @@ -416,7 +416,7 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) if (auth) { ss_addr_to_send[0] |= ONETIMEAUTH_FLAG; - ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx); + ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx->evp.iv); addr_len += ONETIMEAUTH_BYTES; } diff --git a/src/udprelay.c b/src/udprelay.c index d373e691..29e02ca7 100644 --- a/src/udprelay.c +++ b/src/udprelay.c @@ -648,7 +648,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) } #ifdef UDPRELAY_LOCAL - buf = ss_decrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method); + buf = ss_decrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, 0); if (buf == NULL) { ERROR("[udp] server_ss_decrypt_all"); goto CLEAN_UP; @@ -708,7 +708,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) memcpy(buf, addr_header, addr_header_len); buf_len += addr_header_len; - buf = ss_encrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method); + buf = ss_encrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, 0); #endif size_t remote_src_addr_len = get_sockaddr_len((struct sockaddr *)&remote_ctx->src_addr); @@ -826,7 +826,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) tx += buf_len; - buf = ss_decrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method); + buf = ss_decrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, 1); if (buf == NULL) { ERROR("[udp] server_ss_decrypt_all"); goto CLEAN_UP; @@ -1081,7 +1081,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) memmove(buf, buf + offset, buf_len); } - buf = ss_encrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method); + buf = ss_encrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, 1); int s = sendto(remote_ctx->fd, buf, buf_len, 0, remote_addr, remote_addr_len);