diff --git a/src/acl.c b/src/acl.c index ff81736b..34622977 100644 --- a/src/acl.c +++ b/src/acl.c @@ -25,6 +25,7 @@ #include "rule.h" #include "utils.h" +#include "cache.h" #include "acl.h" static struct ip_set white_list_ipv4; @@ -38,6 +39,34 @@ static struct cork_dllist white_list_rules; static int acl_mode = BLACK_LIST; +static struct cache *block_list; + +void init_block_list() +{ + // Initialize cache + cache_create(&block_list, 256, NULL); +} + +int check_block_list(char* addr, int max_tries) +{ + size_t addr_len = strlen(addr); + + if (cache_key_exist(block_list, addr, addr_len)) { + int *count = NULL; + cache_lookup(block_list, addr, addr_len, &count); + if (count != NULL) { + if (*count > max_tries) return 1; + (*count)++; + } + } else { + int *count = (int*)ss_malloc(sizeof(int)); + *count = 1; + cache_insert(block_list, addr, addr_len, count); + } + + return 0; +} + static void parse_addr_cidr(const char *str, char *host, int *cidr) { diff --git a/src/acl.h b/src/acl.h index a37f6eb8..9be72e65 100644 --- a/src/acl.h +++ b/src/acl.h @@ -35,4 +35,7 @@ int acl_remove_ip(const char *ip); int get_acl_mode(void); +void init_block_list(); +int check_block_list(char* addr, int max_tries); + #endif // _ACL_H diff --git a/src/cache.c b/src/cache.c index fd4152e0..5387d021 100644 --- a/src/cache.c +++ b/src/cache.c @@ -93,6 +93,8 @@ cache_delete(struct cache *cache, int keep_data) if (entry->data != NULL) { if (cache->free_cb) { cache->free_cb(entry->data); + } else { + ss_free(tmp->data); } } ss_free(entry->key); diff --git a/src/server.c b/src/server.c index 26708b47..437fda57 100644 --- a/src/server.c +++ b/src/server.c @@ -306,6 +306,10 @@ report_addr(int fd) if (peer_name != NULL) { LOGE("failed to handshake with %s", peer_name); } + // Block all requests from this IP, if the err# exceeds 128. + if (check_block_list(peer_name, 128)) { + LOGE("block all requests from %s", peer_name); + } } int @@ -735,17 +739,7 @@ server_recv_cb(EV_P_ ev_io *w, int revents) server->buf->len = offset + header_len + ONETIMEAUTH_BYTES; if (ss_onetimeauth_verify(server->buf, server->d_ctx->evp.iv)) { - char *peer_name = get_peer_name(server->fd); - if (peer_name) { - LOGE("authentication error from %s", peer_name); - if (acl) { - if (get_acl_mode() == BLACK_LIST) { - // Auto ban enabled only in black list mode - acl_add_ip(peer_name); - LOGE("add %s to the black list", peer_name); - } - } - } + report_addr(server->fd); close_and_free_server(EV_A_ server); return; } @@ -1715,6 +1709,9 @@ main(int argc, char **argv) LOGI("initializing ciphers... %s", method); int m = enc_init(password, method); + // init block list + init_block_list(); + // initialize ev loop struct ev_loop *loop = EV_DEFAULT;