diff --git a/src/acl.c b/src/acl.c index 28c10097..27c80524 100644 --- a/src/acl.c +++ b/src/acl.c @@ -1,7 +1,10 @@ #include +#include + #include "utils.h" -static struct ip_set set; +static struct ip_set acl_ip_set; +static struct cork_string_array acl_domain_array; static void parse_addr_cidr(const char *str, char **host, int *cidr) { @@ -35,8 +38,12 @@ static void parse_addr_cidr(const char *str, char **host, int *cidr) int init_acl(const char *path) { + // initialize ipset ipset_init_library(); - ipset_init(&set); + ipset_init(&acl_ip_set); + + // initialize array + cork_string_array_init(&acl_domain_array); FILE *f = fopen(path, "r"); if (f == NULL) FATAL("Invalid acl path."); @@ -49,13 +56,23 @@ int init_acl(const char *path) char *host = NULL; int cidr; parse_addr_cidr(line, &host, &cidr); - struct cork_ipv4 addr; - int err = cork_ipv4_init(&addr, host); - if (err) continue; - if (cidr >= 0) - ipset_ipv4_add_network(&set, &addr, cidr); + + if (cidr == -1) + { + cork_string_array_append(&acl_domain_array, host); + } else - ipset_ipv4_add(&set, &addr); + { + struct cork_ipv4 addr; + int err = cork_ipv4_init(&addr, host); + if (!err) + { + if (cidr >= 0) + ipset_ipv4_add_network(&acl_ip_set, &addr, cidr); + else + ipset_ipv4_add(&acl_ip_set, &addr); + } + } if (host != NULL) free(host); } @@ -68,15 +85,42 @@ int init_acl(const char *path) void free_acl(void) { - ipset_done(&set); + ipset_done(&acl_ip_set); +} + +int acl_contains_domain(const char* domain) +{ + const char **list = acl_domain_array.items; + const int size = acl_domain_array.size; + const int domain_len = strlen(domain); + + for (int i = 0; i < size; i++) + { + const char *acl_domain = list[i]; + const int acl_domain_len = strlen(acl_domain); + if (acl_domain_len > domain_len) continue; + int match = true; + for (int offset = 1; offset <= acl_domain_len; offset++) + { + if (domain[domain_len - offset] != acl_domain[acl_domain_len - offset]) + { + match = false; + break; + } + } + if (match) return 1; + } + + return 0; } -int acl_is_bypass(const char* host) +int acl_contains_ip(const char* host) { struct cork_ipv4 addr; int err = cork_ipv4_init(&addr, host); if (err) return 0; + struct cork_ip ip; cork_ip_from_ipv4(&ip, &addr); - return ipset_contains_ip(&set, &ip); + return ipset_contains_ip(&acl_ip_set, &ip); } diff --git a/src/acl.h b/src/acl.h index 2acb6678..2b392a3f 100644 --- a/src/acl.h +++ b/src/acl.h @@ -3,6 +3,8 @@ int init_acl(const char *path); void free_acl(void); -int acl_is_bypass(const char* host); + +int acl_contains_ip(const char* ip); +int acl_contains_domain(const char* domain); #endif // _ACL_H diff --git a/src/local.c b/src/local.c index 0abe9a77..c6f98df4 100644 --- a/src/local.c +++ b/src/local.c @@ -361,7 +361,6 @@ static void server_recv_cb (EV_P_ ev_io *w, int revents) host, INET_ADDRSTRLEN); sprintf(port, "%d", p); } - } else if (request->atyp == 3) { @@ -378,7 +377,6 @@ static void server_recv_cb (EV_P_ ev_io *w, int revents) host[name_len] = '\0'; sprintf(port, "%d", p); } - } else if (request->atyp == 4) { @@ -394,7 +392,6 @@ static void server_recv_cb (EV_P_ ev_io *w, int revents) host, INET6_ADDRSTRLEN); sprintf(port, "%d", p); } - } else { @@ -413,7 +410,8 @@ static void server_recv_cb (EV_P_ ev_io *w, int revents) LOGD("connect to %s:%s", host, port); } - if (acl_is_bypass(host)) + if ((request->atyp == 1 && acl_contains_ip(host)) + || (request->atyp = 3 && acl_contains_domain(host))) { remote = connect_to_remote(server->listener, host, port); remote->direct = 1;