diff --git a/src/cache.c b/src/cache.c index c005fbfb..55c7017d 100644 --- a/src/cache.c +++ b/src/cache.c @@ -90,6 +90,7 @@ int cache_delete(struct cache *cache, int keep_data) if (cache->free_cb) { cache->free_cb(entry->data); } + free(entry->key); free(entry); } } @@ -107,9 +108,12 @@ int cache_delete(struct cache *cache, int keep_data) @param key The key of the entry to remove + @param key_len + The length of key + @return EINVAL if cache is NULL, 0 otherwise */ -int cache_remove(struct cache *cache, char *key) +int cache_remove(struct cache *cache, char *key, size_t key_len) { struct cache_entry *tmp; @@ -117,13 +121,14 @@ int cache_remove(struct cache *cache, char *key) return EINVAL; } - HASH_FIND_STR(cache->entries, key, tmp); + HASH_FIND(hh, cache->entries, key, key_len, tmp); if (tmp) { HASH_DEL(cache->entries, tmp); if (cache->free_cb) { cache->free_cb(tmp->data); } + free(tmp->key); free(tmp); } @@ -138,6 +143,9 @@ int cache_remove(struct cache *cache, char *key) @param key The key to look-up + @param key_len + The length of key + @param result Where to store the result if key is found. @@ -147,7 +155,7 @@ int cache_remove(struct cache *cache, char *key) @return EINVAL if cache is NULL, 0 otherwise */ -int cache_lookup(struct cache *cache, char *key, void *result) +int cache_lookup(struct cache *cache, char *key, size_t key_len, void *result) { struct cache_entry *tmp = NULL; char **dirty_hack = result; @@ -156,9 +164,8 @@ int cache_lookup(struct cache *cache, char *key, void *result) return EINVAL; } - HASH_FIND_STR(cache->entries, key, tmp); + HASH_FIND(hh, cache->entries, key, key_len, tmp); if (tmp) { - size_t key_len = strnlen(tmp->key, KEY_MAX_LENGTH); HASH_DELETE(hh, cache->entries, tmp); HASH_ADD_KEYPTR(hh, cache->entries, tmp->key, key_len, tmp); *dirty_hack = tmp->data; @@ -169,6 +176,26 @@ int cache_lookup(struct cache *cache, char *key, void *result) return 0; } +int cache_key_exist(struct cache *cache, char *key, size_t key_len) +{ + struct cache_entry *tmp = NULL; + + if (!cache || !key) { + return 0; + } + + HASH_FIND(hh, cache->entries, key, key_len, tmp); + if (tmp) { + HASH_DELETE(hh, cache->entries, tmp); + HASH_ADD_KEYPTR(hh, cache->entries, tmp->key, key_len, tmp); + return 1; + } else { + return 0; + } + + return 0; +} + /** Inserts a given pair into the cache @param cache @@ -177,16 +204,18 @@ int cache_lookup(struct cache *cache, char *key, void *result) @param key The key that identifies + @param key_len + The length of key + @param data Data associated with @return EINVAL if cache is NULL, ENOMEM if malloc fails, 0 otherwise */ -int cache_insert(struct cache *cache, char *key, void *data) +int cache_insert(struct cache *cache, char *key, size_t key_len, void *data) { struct cache_entry *entry = NULL; struct cache_entry *tmp_entry = NULL; - size_t key_len = 0; if (!cache || !data) { return EINVAL; @@ -196,9 +225,9 @@ int cache_insert(struct cache *cache, char *key, void *data) return ENOMEM; } - entry->key = key; + entry->key = malloc(key_len); + mempcpy(entry->key, key, key_len); entry->data = data; - key_len = strnlen(entry->key, KEY_MAX_LENGTH); HASH_ADD_KEYPTR(hh, cache->entries, entry->key, key_len, entry); if (HASH_COUNT(cache->entries) >= cache->max_entries) { @@ -209,7 +238,7 @@ int cache_insert(struct cache *cache, char *key, void *data) } else { free(entry->data); } - /* free(key->key) if data has been copied */ + free(entry->key); free(entry); break; } diff --git a/src/cache.h b/src/cache.h index ad415b02..5f24b993 100644 --- a/src/cache.h +++ b/src/cache.h @@ -30,8 +30,6 @@ #include "uthash.h" -#define KEY_MAX_LENGTH 32 - /** * A cache entry */ @@ -54,8 +52,9 @@ struct cache { extern int cache_create(struct cache **dst, const size_t capacity, void (*free_cb)(void *element)); extern int cache_delete(struct cache *cache, int keep_data); -extern int cache_lookup(struct cache *cache, char *key, void *result); -extern int cache_insert(struct cache *cache, char *key, void *data); -extern int cache_remove(struct cache *cache, char *key); +extern int cache_lookup(struct cache *cache, char *key, size_t key_len, void *result); +extern int cache_insert(struct cache *cache, char *key, size_t key_len, void *data); +extern int cache_remove(struct cache *cache, char *key, size_t key_len); +extern int cache_key_exist(struct cache *cache, char *key, size_t key_len); #endif diff --git a/src/encrypt.c b/src/encrypt.c index 10cb3a4a..c870f59f 100644 --- a/src/encrypt.c +++ b/src/encrypt.c @@ -67,6 +67,7 @@ #include +#include "cache.h" #include "encrypt.h" #include "utils.h" @@ -79,6 +80,8 @@ static int enc_key_len; static int enc_iv_len; static int enc_method; +static struct cache *iv_cache; + #ifdef DEBUG static void dump(char *tag, char *text, int len) { @@ -1263,6 +1266,13 @@ char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len, cipher_context_set_iv(&ctx->evp, iv, iv_len, 0); ctx->counter = 0; ctx->init = 1; + + if (cache_key_exist(iv_cache, (char *)iv, MAX_IV_LENGTH)) { + free(ciphertext); + return NULL; + } else { + cache_insert(iv_cache, (char *)iv, MAX_IV_LENGTH, NULL); + } } if (enc_method >= SALSA20) { @@ -1336,6 +1346,9 @@ void enc_key_init(int method, const char *pass) return; } + // Inilitialize cache + cache_create(&iv_cache, 256, NULL); + #if defined(USE_CRYPTO_OPENSSL) OpenSSL_add_all_algorithms(); #endif diff --git a/src/server.c b/src/server.c index c80d6a36..3c7e29a2 100644 --- a/src/server.c +++ b/src/server.c @@ -563,7 +563,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } else if (atyp == 3) { // Domain name uint8_t name_len = *(uint8_t *)(server->buf + offset); - if (name_len < r && name_len < 255 && name_len > 0) { + if (name_len < r) { memcpy(host, server->buf + offset + 1, name_len); offset += name_len + 1; } else { diff --git a/src/udprelay.c b/src/udprelay.c index 4b33e7f3..e103c6e4 100644 --- a/src/udprelay.c +++ b/src/udprelay.c @@ -528,7 +528,7 @@ static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) } char *key = hash_key(remote_ctx->af, &remote_ctx->src_addr); - cache_remove(remote_ctx->server_ctx->conn_cache, key); + cache_remove(remote_ctx->server_ctx->conn_cache, key, 32); } #ifdef UDPRELAY_REMOTE @@ -552,7 +552,7 @@ static void query_resolve_cb(struct sockaddr *addr, void *data) // Lookup in the conn cache if (remote_ctx == NULL) { char *key = hash_key(0, &query_ctx->src_addr); - cache_lookup(query_ctx->server_ctx->conn_cache, key, (void *)&remote_ctx); + cache_lookup(query_ctx->server_ctx->conn_cache, key, 32, (void *)&remote_ctx); } if (remote_ctx == NULL) { @@ -597,8 +597,7 @@ static void query_resolve_cb(struct sockaddr *addr, void *data) if (!cache_hit) { // Add to conn cache char *key = hash_key(0, &remote_ctx->src_addr); - cache_insert(query_ctx->server_ctx->conn_cache, key, - (void *)remote_ctx); + cache_insert(query_ctx->server_ctx->conn_cache, key, 32, (void *)remote_ctx); ev_io_start(EV_A_ & remote_ctx->io); ev_timer_start(EV_A_ & remote_ctx->watcher); } @@ -980,7 +979,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) struct cache *conn_cache = server_ctx->conn_cache; struct remote_ctx *remote_ctx = NULL; - cache_lookup(conn_cache, key, (void *)&remote_ctx); + cache_lookup(conn_cache, key, 32, (void *)&remote_ctx); if (remote_ctx != NULL) { if (memcmp(&src_addr, &remote_ctx->src_addr, sizeof(src_addr))) { @@ -1069,7 +1068,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) memcpy(remote_ctx->addr_header, addr_header, addr_header_len); // Add to conn cache - cache_insert(conn_cache, key, (void *)remote_ctx); + cache_insert(conn_cache, key, 32, (void *)remote_ctx); // Start remote io ev_io_start(EV_A_ & remote_ctx->io); @@ -1148,8 +1147,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) // Add to conn cache remote_ctx->af = dst_addr.ss_family; char *key = hash_key(remote_ctx->af, &remote_ctx->src_addr); - cache_insert(server_ctx->conn_cache, key, - (void *)remote_ctx); + cache_insert(server_ctx->conn_cache, key, 32, (void *)remote_ctx); ev_io_start(EV_A_ & remote_ctx->io); ev_timer_start(EV_A_ & remote_ctx->watcher);