diff --git a/src/udprelay.c b/src/udprelay.c index aaad030b..6d3a0475 100644 --- a/src/udprelay.c +++ b/src/udprelay.c @@ -86,8 +86,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents); static void remote_recv_cb(EV_P_ ev_io *w, int revents); static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents); -static char *hash_key(const char *header, const int header_len, - const struct sockaddr_storage *addr); +static char *hash_key(const int af, const struct sockaddr_storage *addr); #ifdef UDPRELAY_REMOTE static void query_resolve_cb(struct sockaddr *addr, void *data); #endif @@ -169,19 +168,17 @@ static int get_dstaddr(struct msghdr *msg, struct sockaddr_storage *dstaddr) } #endif -static char *hash_key(const char *header, const int header_len, - const struct sockaddr_storage *addr) +static char *hash_key(const int af, const struct sockaddr_storage *addr) { - char key[384]; + int addr_len = sizeof(struct sockaddr_storage); + int key_len = addr_len + sizeof(int); + char key[key_len]; - // calculate hash key - // assert header_len < 256 - memset(key, 0, 384); - memcpy(key, addr, sizeof(struct sockaddr_storage)); - memcpy(key + sizeof(struct sockaddr_storage), header, header_len); + memset(key, 0, key_len); + memcpy(key, &af, sizeof(int)); + memcpy(key + sizeof(int), (const uint8_t *)addr, addr_len); - return (char *)enc_md5((const uint8_t *)key, - sizeof(struct sockaddr_storage) + header_len, NULL); + return (char *)enc_md5((const uint8_t *)key, key_len, NULL); } static int parse_udprealy_header(const char * buf, const int buf_len, @@ -432,6 +429,7 @@ struct remote_ctx *new_remote(int fd, struct server_ctx *server_ctx) { struct remote_ctx *ctx = malloc(sizeof(struct remote_ctx)); memset(ctx, 0, sizeof(struct remote_ctx)); + ctx->fd = fd; ctx->server_ctx = server_ctx; ev_io_init(&ctx->io, remote_recv_cb, fd, EV_READ); @@ -498,8 +496,7 @@ static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) LOGI("[udp] connection timeout"); } - char *key = hash_key(remote_ctx->addr_header, - remote_ctx->addr_header_len, &remote_ctx->src_addr); + char *key = hash_key(remote_ctx->af, &remote_ctx->src_addr); cache_remove(remote_ctx->server_ctx->conn_cache, key); } @@ -518,59 +515,69 @@ static void query_resolve_cb(struct sockaddr *addr, void *data) if (addr == NULL) { LOGE("[udp] udns returned an error"); } else { - int remotefd = create_remote_socket(addr->sa_family == AF_INET6); - if (remotefd != -1) { - setnonblocking(remotefd); + struct remote_ctx *remote_ctx = query_ctx->remote_ctx; + int cache_hit = 0; + + if (remote_ctx == NULL) { + int remotefd = create_remote_socket(addr->sa_family == AF_INET6); + if (remotefd != -1) { + setnonblocking(remotefd); #ifdef SO_BROADCAST - set_broadcast(remotefd); + set_broadcast(remotefd); #endif #ifdef SO_NOSIGPIPE - set_nosigpipe(remotefd); + set_nosigpipe(remotefd); #endif #ifdef SET_INTERFACE - if (query_ctx->server_ctx->iface) { - setinterface(remotefd, query_ctx->server_ctx->iface); - } + if (query_ctx->server_ctx->iface) { + setinterface(remotefd, query_ctx->server_ctx->iface); + } #endif - - struct remote_ctx *remote_ctx = new_remote(remotefd, - query_ctx->server_ctx); - remote_ctx->src_addr = query_ctx->src_addr; + remote_ctx = new_remote(remotefd, query_ctx->server_ctx); + remote_ctx->src_addr = query_ctx->src_addr; + if (addr->sa_family == AF_INET) { + memcpy(&(remote_ctx->dst_addr), addr, sizeof(struct sockaddr_in)); + } else if (addr->sa_family == AF_INET6) { + memcpy(&(remote_ctx->dst_addr), addr, sizeof(struct sockaddr_in6)); + } + remote_ctx->server_ctx = query_ctx->server_ctx; + remote_ctx->addr_header_len = query_ctx->addr_header_len; + memcpy(remote_ctx->addr_header, query_ctx->addr_header, + query_ctx->addr_header_len); + } else { + ERROR("[udp] bind() error"); + } + } else { if (addr->sa_family == AF_INET) { - memcpy(&(remote_ctx->dst_addr), addr, - sizeof(struct sockaddr_in)); + memcpy(&(remote_ctx->dst_addr), addr, sizeof(struct sockaddr_in)); } else if (addr->sa_family == AF_INET6) { - memcpy(&(remote_ctx->dst_addr), addr, - sizeof(struct sockaddr_in6)); + memcpy(&(remote_ctx->dst_addr), addr, sizeof(struct sockaddr_in6)); } - remote_ctx->server_ctx = query_ctx->server_ctx; - remote_ctx->addr_header_len = query_ctx->addr_header_len; - memcpy(remote_ctx->addr_header, query_ctx->addr_header, - query_ctx->addr_header_len); + cache_hit = 1; + } + if (remote_ctx != NULL) { size_t addr_len = get_sockaddr_len((struct sockaddr *)&remote_ctx->dst_addr); int s = sendto(remote_ctx->fd, query_ctx->buf, query_ctx->buf_len, - 0, (struct sockaddr *)&remote_ctx->dst_addr, - addr_len); + 0, (struct sockaddr *)&remote_ctx->dst_addr, addr_len); if (s == -1) { ERROR("[udp] sendto_remote"); - close_and_free_remote(EV_A_ remote_ctx); + if (!cache_hit) { + close_and_free_remote(EV_A_ remote_ctx); + } } else { // Add to conn cache - char *key = hash_key(remote_ctx->addr_header, - remote_ctx->addr_header_len, - &remote_ctx->src_addr); - cache_insert(query_ctx->server_ctx->conn_cache, key, - (void *)remote_ctx); - - ev_io_start(EV_A_ & remote_ctx->io); - ev_timer_start(EV_A_ & remote_ctx->watcher); + if (!cache_hit) { + char *key = hash_key(0, &remote_ctx->src_addr); + cache_insert(query_ctx->server_ctx->conn_cache, key, + (void *)remote_ctx); + ev_io_start(EV_A_ & remote_ctx->io); + ev_timer_start(EV_A_ & remote_ctx->watcher); + } } - - } else { - ERROR("[udp] bind() error"); } + } // clean up @@ -599,8 +606,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) char *buf = malloc(BUF_SIZE); // recv - ssize_t buf_len = recvfrom(remote_ctx->fd, buf, BUF_SIZE, 0, &src_addr, - &src_addr_len); + ssize_t buf_len = recvfrom(remote_ctx->fd, buf, BUF_SIZE, 0, &src_addr, &src_addr_len); if (buf_len == -1) { // error on recv @@ -851,6 +857,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) memcpy(buf, addr_header, addr_header_len); buf_len += addr_header_len; + char *key = hash_key(dst_addr.ss_family, &src_addr); + #elif UDPRELAY_TUNNEL char addr_header[256] = { 0 }; @@ -905,6 +913,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) memcpy(buf, addr_header, addr_header_len); buf_len += addr_header_len; + char *key = hash_key(ip.version == 4 ? AF_INET : AF_INET6, &src_addr); + #else char host[256] = { 0 }; @@ -920,18 +930,17 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) goto CLEAN_UP; } char *addr_header = buf + offset; + + char *key = hash_key(storage.ss_family, &src_addr); #endif - char *key = hash_key(addr_header, addr_header_len, &src_addr); struct cache *conn_cache = server_ctx->conn_cache; struct remote_ctx *remote_ctx = NULL; cache_lookup(conn_cache, key, (void *)&remote_ctx); if (remote_ctx != NULL) { - if (memcmp(&src_addr, &remote_ctx->src_addr, sizeof(src_addr)) - || remote_ctx->addr_header_len != addr_header_len - || memcmp(addr_header, remote_ctx->addr_header, addr_header_len) != 0) { + if (memcmp(&src_addr, &remote_ctx->src_addr, sizeof(src_addr))) { remote_ctx = NULL; } } @@ -1000,6 +1009,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) #ifdef UDPRELAY_REDIR memcpy(&(remote_ctx->dst_addr), &dst_addr, get_sockaddr_len((struct sockaddr *)&dst_addr)); #endif + remote_ctx->af = remote_addr->sa_family; remote_ctx->addr_header_len = addr_header_len; memcpy(remote_ctx->addr_header, addr_header, addr_header_len); @@ -1026,7 +1036,21 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) #else - if (remote_ctx == NULL) { + int cache_hit = 0; + int need_query = 0; + + if (remote_ctx != NULL) { + cache_hit = 1; + // detect destination mismatch + if (remote_ctx->addr_header_len != addr_header_len + || memcmp(addr_header, remote_ctx->addr_header, addr_header_len) != 0) { + if (storage.ss_family == AF_INET || storage.ss_family == AF_INET6) { + remote_ctx->dst_addr = storage; + } else { + need_query = 1; + } + } + } else { if (storage.ss_family == AF_INET || storage.ss_family == AF_INET6) { int remotefd = create_remote_socket(storage.ss_family == AF_INET6); if (remotefd != -1) { @@ -1042,70 +1066,69 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) setinterface(remotefd, server_ctx->iface); } #endif - struct remote_ctx *remote_ctx = new_remote(remotefd, server_ctx); + remote_ctx = new_remote(remotefd, server_ctx); remote_ctx->src_addr = src_addr; remote_ctx->dst_addr = storage; remote_ctx->server_ctx = server_ctx; remote_ctx->addr_header_len = addr_header_len; memcpy(remote_ctx->addr_header, addr_header, addr_header_len); - - size_t addr_len = get_sockaddr_len((struct sockaddr *)&remote_ctx->dst_addr); - int s = sendto(remote_ctx->fd, buf + addr_header_len, - buf_len - addr_header_len, 0, - (struct sockaddr *)&remote_ctx->dst_addr, - addr_len); - - if (s == -1) { - ERROR("[udp] sendto_remote"); - close_and_free_remote(EV_A_ remote_ctx); - } else { - // Add to conn cache - char *key = hash_key(remote_ctx->addr_header, - remote_ctx->addr_header_len, - &remote_ctx->src_addr); - cache_insert(server_ctx->conn_cache, key, - (void *)remote_ctx); - - ev_io_start(EV_A_ & remote_ctx->io); - ev_timer_start(EV_A_ & remote_ctx->watcher); - } } else { ERROR("[udp] bind() error"); - } - } else { - struct addrinfo hints; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_DGRAM; - hints.ai_protocol = IPPROTO_UDP; - - struct query_ctx *query_ctx = new_query_ctx(buf + addr_header_len, - buf_len - - addr_header_len); - query_ctx->server_ctx = server_ctx; - query_ctx->addr_header_len = addr_header_len; - query_ctx->src_addr = src_addr; - memcpy(query_ctx->addr_header, addr_header, addr_header_len); - - struct ResolvQuery *query = resolv_query(host, query_resolve_cb, - NULL, query_ctx, - htons(atoi(port))); - if (query == NULL) { - ERROR("[udp] unable to create DNS query"); - close_and_free_query(EV_A_ query_ctx); goto CLEAN_UP; } - query_ctx->query = query; } - } else { + } + + if (remote_ctx != NULL && !need_query) { size_t addr_len = get_sockaddr_len((struct sockaddr *)&remote_ctx->dst_addr); int s = sendto(remote_ctx->fd, buf + addr_header_len, - buf_len - addr_header_len, 0, - (struct sockaddr *)&remote_ctx->dst_addr, addr_len); + buf_len - addr_header_len, 0, + (struct sockaddr *)&remote_ctx->dst_addr, addr_len); if (s == -1) { ERROR("[udp] sendto_remote"); + if (!cache_hit) { + close_and_free_remote(EV_A_ remote_ctx); + } + } else { + if (!cache_hit) { + // Add to conn cache + remote_ctx->af = remote_ctx->dst_addr.ss_family; + char *key = hash_key(remote_ctx->af, &remote_ctx->src_addr); + cache_insert(server_ctx->conn_cache, key, + (void *)remote_ctx); + + ev_io_start(EV_A_ & remote_ctx->io); + ev_timer_start(EV_A_ & remote_ctx->watcher); + } + } + } else { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_DGRAM; + hints.ai_protocol = IPPROTO_UDP; + + struct query_ctx *query_ctx = new_query_ctx(buf + addr_header_len, + buf_len - + addr_header_len); + query_ctx->server_ctx = server_ctx; + query_ctx->addr_header_len = addr_header_len; + query_ctx->src_addr = src_addr; + memcpy(query_ctx->addr_header, addr_header, addr_header_len); + + if (need_query) { + query_ctx->remote_ctx = remote_ctx; + } + + struct ResolvQuery *query = resolv_query(host, query_resolve_cb, + NULL, query_ctx, htons(atoi(port))); + if (query == NULL) { + ERROR("[udp] unable to create DNS query"); + close_and_free_query(EV_A_ query_ctx); + goto CLEAN_UP; } + query_ctx->query = query; } #endif diff --git a/src/udprelay.h b/src/udprelay.h index a876c18f..1affdaa9 100644 --- a/src/udprelay.h +++ b/src/udprelay.h @@ -67,12 +67,14 @@ struct query_ctx { int addr_header_len; char addr_header[384]; struct server_ctx *server_ctx; + struct remote_ctx *remote_ctx; }; #endif struct remote_ctx { ev_io io; ev_timer watcher; + int af; int fd; int src_fd; int addr_header_len;