diff --git a/src/encrypt.c b/src/encrypt.c index 025a5053..b040ca1e 100644 --- a/src/encrypt.c +++ b/src/encrypt.c @@ -201,6 +201,35 @@ static const int supported_ciphers_key_size[CIPHER_NUM] = 0, 16, 16, 16, 24, 32, 16, 16, 24, 32, 16, 8, 16, 16, 16, 32, 32 }; +int balloc(buffer_t *ptr, size_t capacity) +{ + memset(ptr, 0 , sizeof(buffer_t)); + ptr->array = malloc(capacity); + ptr->capacity = capacity; + return capacity; +} + +int brealloc(buffer_t *ptr, size_t len, size_t capacity) +{ + int real_capacity = max(len, capacity); + if (ptr->capacity < real_capacity) { + ptr->array = realloc(ptr->array, real_capacity); + } + ptr->capacity = real_capacity; + return real_capacity; +} + +void bfree(buffer_t *ptr) +{ + ptr->idx = 0; + ptr->len = 0; + ptr->capacity = 0; + if (ptr->array != NULL) { + free(ptr->array); + ptr->array = NULL; + } +} + static int crypto_stream_xor_ic(uint8_t *c, const uint8_t *m, uint64_t mlen, const uint8_t *n, uint64_t ic, const uint8_t *k, int method) @@ -1030,25 +1059,28 @@ static int cipher_context_update(cipher_ctx_t *ctx, uint8_t *output, size_t *ole #endif } -int ss_onetimeauth(char *auth, char *msg, int msg_len, uint8_t *iv) +int ss_onetimeauth(buffer_t *buf, uint8_t *iv) { uint8_t hash[ONETIMEAUTH_BYTES * 2]; uint8_t auth_key[MAX_IV_LENGTH + MAX_KEY_LENGTH]; memcpy(auth_key, iv, enc_iv_len); memcpy(auth_key + enc_iv_len, enc_key, enc_key_len); + brealloc(buf, ONETIMEAUTH_BYTES + buf->len, buf->capacity); + #if defined(USE_CRYPTO_OPENSSL) - HMAC(EVP_sha1(), auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, (uint8_t *)hash, NULL); + HMAC(EVP_sha1(), auth_key, enc_iv_len + enc_key_len, (uint8_t *)buf->array, buf->len, (uint8_t *)hash, NULL); #else - ss_sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, (uint8_t *)hash); + ss_sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)buf->array, buf->len, (uint8_t *)hash); #endif - memcpy(auth, hash, ONETIMEAUTH_BYTES); + memcpy(buf->array + buf->len, hash, ONETIMEAUTH_BYTES); + buf->len += ONETIMEAUTH_BYTES; return 0; } -int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, uint8_t *iv) +int ss_onetimeauth_verify(buffer_t *buf, uint8_t *iv) { uint8_t hash[ONETIMEAUTH_BYTES * 2]; uint8_t auth_key[MAX_IV_LENGTH + MAX_KEY_LENGTH]; @@ -1056,219 +1088,190 @@ int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, uint8_t *iv) memcpy(auth_key + enc_iv_len, enc_key, enc_key_len); #if defined(USE_CRYPTO_OPENSSL) - HMAC(EVP_sha1(), auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, hash, NULL); + HMAC(EVP_sha1(), auth_key, enc_iv_len + enc_key_len, (uint8_t *)buf->array, buf->len, hash, NULL); #else - ss_sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, hash); + ss_sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)buf->array, buf->len, hash); #endif - return memcmp(auth, hash, ONETIMEAUTH_BYTES); + return memcmp(buf->array + buf->len - ONETIMEAUTH_BYTES, hash, ONETIMEAUTH_BYTES); } -char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method, int auth) +int ss_encrypt_all(buffer_t *plain, int method, int auth) { if (method > TABLE) { cipher_ctx_t evp; cipher_context_init(&evp, method, 1); - size_t p_len = *len, c_len = *len; + size_t p_len = plain->len, c_len = plain->len; size_t iv_len = enc_iv_len; int err = 1; - static int tmp_len = 0; - static char *tmp_buf = NULL; - int buf_len = max(iv_len + c_len, buf_size); - if (tmp_len < buf_len) { - tmp_len = buf_len; - tmp_buf = realloc(tmp_buf, buf_len); - } - char *ciphertext = tmp_buf; + static buffer_t tmp = {0}; + brealloc(&tmp, iv_len + c_len, plain->capacity); + buffer_t *cipher = &tmp; uint8_t iv[MAX_IV_LENGTH]; rand_bytes(iv, iv_len); cipher_context_set_iv(&evp, iv, iv_len, 1); - memcpy(ciphertext, iv, iv_len); + memcpy(cipher, iv, iv_len); if (auth) { - char hash[ONETIMEAUTH_BYTES * 2]; - ss_onetimeauth(hash, plaintext, p_len, iv); - if (buf_size < ONETIMEAUTH_BYTES + p_len) { - plaintext = realloc(plaintext, ONETIMEAUTH_BYTES + p_len); - } - memcpy(plaintext + p_len, hash, ONETIMEAUTH_BYTES); + ss_onetimeauth(plain, iv); p_len = c_len = p_len + ONETIMEAUTH_BYTES; } if (method >= SALSA20) { - crypto_stream_xor_ic((uint8_t *)(ciphertext + iv_len), - (const uint8_t *)plaintext, (uint64_t)(p_len), + crypto_stream_xor_ic((uint8_t *)(cipher + iv_len), + (const uint8_t *)plain->array, (uint64_t)(p_len), (const uint8_t *)iv, 0, enc_key, method); } else { - err = cipher_context_update(&evp, (uint8_t *)(ciphertext + iv_len), - &c_len, (const uint8_t *)plaintext, + err = cipher_context_update(&evp, (uint8_t *)(cipher + iv_len), + &c_len, (const uint8_t *)plain->array, p_len); } if (!err) { - free(plaintext); + bfree(plain); cipher_context_release(&evp); - return NULL; + return -1; } #ifdef DEBUG - dump("PLAIN", plaintext, *len); - dump("CIPHER", ciphertext + iv_len, c_len); + dump("PLAIN", plain->array, *len); + dump("CIPHER", cipher->array + iv_len, c_len); #endif cipher_context_release(&evp); - if (buf_size < iv_len + c_len) { - plaintext = realloc(plaintext, iv_len + c_len); - } - *len = iv_len + c_len; - memcpy(plaintext, ciphertext, *len); + brealloc(plain, iv_len + c_len, plain->capacity); + memcpy(plain, cipher, iv_len + c_len); + plain->len = iv_len + c_len; - return plaintext; + return 0; } else { - char *begin = plaintext; - while (plaintext < begin + *len) { - *plaintext = (char)enc_table[(uint8_t)*plaintext]; - plaintext++; + char *begin = plain->array; + char *ptr = plain->array; + while (ptr < begin + plain->len) { + *ptr = (char)enc_table[(uint8_t)*ptr]; + ptr++; } - return begin; + return 0; } } -char * ss_encrypt(int buf_size, char *plaintext, ssize_t *len, - struct enc_ctx *ctx) +int ss_encrypt(buffer_t *plain, enc_ctx_t *ctx) { if (ctx != NULL) { - static int tmp_len = 0; - static char *tmp_buf = NULL; + static buffer_t tmp = {0}; int err = 1; size_t iv_len = 0; - size_t p_len = *len, c_len = *len; + size_t p_len = plain->len, c_len = plain->len; if (!ctx->init) { iv_len = enc_iv_len; } - int buf_len = max(iv_len + c_len, buf_size); - if (tmp_len < buf_len) { - tmp_len = buf_len; - tmp_buf = realloc(tmp_buf, buf_len); - } - char *ciphertext = tmp_buf; + brealloc(&tmp, iv_len + c_len, plain->capacity); + + buffer_t *cipher = &tmp; if (!ctx->init) { cipher_context_set_iv(&ctx->evp, ctx->evp.iv, iv_len, 1); - memcpy(ciphertext, ctx->evp.iv, iv_len); + memcpy(cipher->array, ctx->evp.iv, iv_len); ctx->counter = 0; ctx->init = 1; } if (enc_method >= SALSA20) { int padding = ctx->counter % SODIUM_BLOCK_SIZE; - if (buf_len < iv_len + padding + c_len) { - buf_len = max(iv_len + (padding + c_len) * 2, buf_size); - ciphertext = realloc(ciphertext, buf_len); - tmp_len = buf_len; - tmp_buf = ciphertext; - } + brealloc(cipher, iv_len + (padding + c_len) * 2, cipher->capacity); if (padding) { - plaintext = realloc(plaintext, max(p_len + padding, buf_size)); - memmove(plaintext + padding, plaintext, p_len); - memset(plaintext, 0, padding); + brealloc(plain, p_len + padding, plain->capacity); + memmove(plain->array + padding, plain->array, p_len); + memset(plain->array, 0, padding); } - crypto_stream_xor_ic((uint8_t *)(ciphertext + iv_len), - (const uint8_t *)plaintext, + crypto_stream_xor_ic((uint8_t *)(cipher->array + iv_len), + (const uint8_t *)plain->array, (uint64_t)(p_len + padding), (const uint8_t *)ctx->evp.iv, ctx->counter / SODIUM_BLOCK_SIZE, enc_key, enc_method); ctx->counter += p_len; if (padding) { - memmove(ciphertext + iv_len, ciphertext + iv_len + padding, - c_len); + memmove(cipher->array + iv_len, + cipher->array + iv_len + padding, c_len); } } else { err = cipher_context_update(&ctx->evp, - (uint8_t *)(ciphertext + iv_len), - &c_len, (const uint8_t *)plaintext, + (uint8_t *)(cipher->array + iv_len), + &c_len, (const uint8_t *)plain->array, p_len); if (!err) { - free(plaintext); - return NULL; + return -1; } } #ifdef DEBUG - dump("PLAIN", plaintext, p_len); - dump("CIPHER", ciphertext + iv_len, c_len); + dump("PLAIN", plain->array, p_len); + dump("CIPHER", cipher->array + iv_len, c_len); #endif - if (buf_size < iv_len + c_len) { - plaintext = realloc(plaintext, iv_len + c_len); - } - *len = iv_len + c_len; - memcpy(plaintext, ciphertext, *len); + brealloc(plain, iv_len + c_len, plain->capacity); + memcpy(plain->array, cipher->array, iv_len + c_len); + plain->len = iv_len + c_len; - return plaintext; + return 0; } else { - char *begin = plaintext; - while (plaintext < begin + *len) { - *plaintext = (char)enc_table[(uint8_t)*plaintext]; - plaintext++; + char *begin = plain->array; + char *ptr = plain->array; + while (ptr < begin + plain->len) { + *ptr = (char)enc_table[(uint8_t)*ptr]; + ptr++; } - return begin; + return 0; } } -char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method, int auth) +int ss_decrypt_all(buffer_t *cipher, int method, int auth) { if (method > TABLE) { size_t iv_len = enc_iv_len; - size_t c_len = *len, p_len = *len - iv_len; + size_t c_len = cipher->len; + size_t p_len = cipher->len - iv_len; int ret = 1; - if (*len <= iv_len) { - return NULL; + if (c_len <= iv_len) { + return -1; } cipher_ctx_t evp; cipher_context_init(&evp, method, 0); - static int tmp_len = 0; - static char *tmp_buf = NULL; - int buf_len = max(p_len, buf_size); - if (tmp_len < buf_len) { - tmp_len = buf_len; - tmp_buf = realloc(tmp_buf, buf_len); - } - char *plaintext = tmp_buf; + static buffer_t tmp = {0}; + brealloc(&tmp, cipher->len, cipher->capacity); + buffer_t *plain = &tmp; uint8_t iv[MAX_IV_LENGTH]; - memcpy(iv, ciphertext, iv_len); + memcpy(iv, cipher, iv_len); cipher_context_set_iv(&evp, iv, iv_len, 0); if (method >= SALSA20) { - crypto_stream_xor_ic((uint8_t *)plaintext, - (const uint8_t *)(ciphertext + iv_len), + crypto_stream_xor_ic((uint8_t *)plain->array, + (const uint8_t *)(cipher->array + iv_len), (uint64_t)(c_len - iv_len), (const uint8_t *)iv, 0, enc_key, method); } else { - ret = cipher_context_update(&evp, (uint8_t *)plaintext, &p_len, - (const uint8_t *)(ciphertext + iv_len), + ret = cipher_context_update(&evp, (uint8_t *)plain->array, &p_len, + (const uint8_t *)(cipher->array + iv_len), c_len - iv_len); } - if (auth || (plaintext[0] & ONETIMEAUTH_FLAG)) { + if (auth || (plain->array[0] & ONETIMEAUTH_FLAG)) { if (p_len > ONETIMEAUTH_BYTES) { - char hash[ONETIMEAUTH_BYTES]; - memcpy(hash, plaintext + p_len - ONETIMEAUTH_BYTES, ONETIMEAUTH_BYTES); - ret = !ss_onetimeauth_verify(hash, plaintext, p_len - ONETIMEAUTH_BYTES, iv); + ret = !ss_onetimeauth_verify(plain, iv); if (ret) { p_len -= ONETIMEAUTH_BYTES; } @@ -1278,65 +1281,61 @@ char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method, } if (!ret) { - free(ciphertext); + bfree(cipher); cipher_context_release(&evp); - return NULL; + return -1; } #ifdef DEBUG - dump("PLAIN", plaintext, p_len); - dump("CIPHER", ciphertext + iv_len, c_len - iv_len); + dump("PLAIN", plain->array, p_len); + dump("CIPHER", cipher->array + iv_len, c_len - iv_len); #endif cipher_context_release(&evp); - if (buf_size < p_len) { - ciphertext = realloc(ciphertext, p_len); - } - *len = p_len; - memcpy(ciphertext, plaintext, *len); + brealloc(cipher, p_len, plain->capacity); + memcpy(cipher, plain, p_len); + cipher->len = p_len; - return ciphertext; + return 0; } else { - char *begin = ciphertext; - while (ciphertext < begin + *len) { - *ciphertext = (char)dec_table[(uint8_t)*ciphertext]; - ciphertext++; + char *begin = cipher->array; + char *ptr = cipher->array; + while (ptr < begin + cipher->len) { + *ptr = (char)dec_table[(uint8_t)*ptr]; + ptr++; } - return begin; + return 0; } } -char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len, struct enc_ctx *ctx) +int ss_decrypt(buffer_t *cipher, enc_ctx_t *ctx) { if (ctx != NULL) { - static int tmp_len = 0; - static char *tmp_buf = NULL; + static buffer_t tmp = {0}; - size_t c_len = *len, p_len = *len; + size_t c_len = cipher->len; + size_t p_len = cipher->len; size_t iv_len = 0; int err = 1; - int buf_len = max(p_len, buf_size); - if (tmp_len < buf_len) { - tmp_len = buf_len; - tmp_buf = realloc(tmp_buf, buf_len); - } - char *plaintext = tmp_buf; + brealloc(&tmp, p_len, cipher->capacity); + buffer_t *plain = &tmp; if (!ctx->init) { uint8_t iv[MAX_IV_LENGTH]; iv_len = enc_iv_len; p_len -= iv_len; - memcpy(iv, ciphertext, iv_len); + + memcpy(iv, cipher->array, iv_len); cipher_context_set_iv(&ctx->evp, iv, iv_len, 0); ctx->counter = 0; ctx->init = 1; if (enc_method >= RC4_MD5) { if (cache_key_exist(iv_cache, (char *)iv, iv_len)) { - free(ciphertext); - return NULL; + bfree(cipher); + return -1; } else { cache_insert(iv_cache, (char *)iv, iv_len, NULL); } @@ -1345,64 +1344,59 @@ char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len, struct enc_ctx * if (enc_method >= SALSA20) { int padding = ctx->counter % SODIUM_BLOCK_SIZE; - if (buf_len < (p_len + padding) * 2) { - buf_len = max((p_len + padding) * 2, buf_size); - plaintext = realloc(plaintext, buf_len); - tmp_len = buf_len; - tmp_buf = plaintext; - } + brealloc(plain, (p_len + padding) * 2, plain->capacity); + if (padding) { - ciphertext = realloc(ciphertext, max(c_len + padding, buf_size)); - memmove(ciphertext + iv_len + padding, ciphertext + iv_len, + brealloc(cipher, c_len + padding, cipher->capacity); + memmove(cipher->array + iv_len + padding, cipher->array + iv_len, c_len - iv_len); - memset(ciphertext + iv_len, 0, padding); + memset(cipher->array + iv_len, 0, padding); } - crypto_stream_xor_ic((uint8_t *)plaintext, - (const uint8_t *)(ciphertext + iv_len), + crypto_stream_xor_ic((uint8_t *)plain->array, + (const uint8_t *)(cipher->array + iv_len), (uint64_t)(c_len - iv_len + padding), (const uint8_t *)ctx->evp.iv, ctx->counter / SODIUM_BLOCK_SIZE, enc_key, enc_method); ctx->counter += c_len - iv_len; if (padding) { - memmove(plaintext, plaintext + padding, p_len); + memmove(plain->array, plain->array + padding, p_len); } } else { - err = cipher_context_update(&ctx->evp, (uint8_t *)plaintext, &p_len, - (const uint8_t *)(ciphertext + iv_len), + err = cipher_context_update(&ctx->evp, (uint8_t *)plain->array, &p_len, + (const uint8_t *)(cipher + iv_len), c_len - iv_len); } if (!err) { - free(ciphertext); - return NULL; + bfree(cipher); + return -1; } #ifdef DEBUG - dump("PLAIN", plaintext, p_len); - dump("CIPHER", ciphertext + iv_len, c_len - iv_len); + dump("PLAIN", plain->array, p_len); + dump("CIPHER", cipher->array + iv_len, c_len - iv_len); #endif - if (buf_size < p_len) { - ciphertext = realloc(ciphertext, p_len); - } - *len = p_len; - memcpy(ciphertext, plaintext, *len); + brealloc(cipher, p_len, cipher->capacity); + memcpy(cipher->array, plain->array, p_len); + cipher->len = p_len; - return ciphertext; + return 0; } else { - char *begin = ciphertext; - while (ciphertext < begin + *len) { - *ciphertext = (char)dec_table[(uint8_t)*ciphertext]; - ciphertext++; + char *begin = cipher->array; + char *ptr = cipher->array; + while (ptr < begin + cipher->len) { + *ptr = (char)dec_table[(uint8_t)*ptr]; + ptr++; } - return begin; + return 0; } } -void enc_ctx_init(int method, struct enc_ctx *ctx, int enc) +void enc_ctx_init(int method, enc_ctx_t *ctx, int enc) { - memset(ctx, 0, sizeof(struct enc_ctx)); + memset(ctx, 0, sizeof(enc_ctx_t)); cipher_context_init(&ctx->evp, method, enc); if (enc) { @@ -1520,33 +1514,22 @@ int enc_init(const char *pass, const char *method) return m; } -int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct enc_ctx *ctx, int buf_size) +int ss_check_hash(buffer_t *buf, chunk_t *chunk, enc_ctx_t *ctx) { int i, j, k; - char *buf = *buf_ptr; - ssize_t blen = *buf_len; + ssize_t blen = buf->len; uint32_t cidx = chunk->idx; - if (chunk->buf == NULL) { - chunk->buf = (char *)malloc(buf_size); - chunk->len = buf_size - AUTH_BYTES; - } - - int size = max(chunk->len + blen, buf_size); - if (buf_size < size) { - buf = realloc(buf, size); - } + brealloc(chunk->buf, blen, buf->capacity); + brealloc(buf, chunk->len + blen, buf->capacity); for (i = 0, j = 0, k = 0; i < blen; i++) { - chunk->buf[cidx++] = buf[k++]; + chunk->buf->array[cidx++] = buf->array[k++]; if (cidx == CLEN_BYTES) { - uint16_t clen = ntohs(*((uint16_t *)chunk->buf)); - - if (buf_size < clen + AUTH_BYTES) { - chunk->buf = realloc(chunk->buf, clen + AUTH_BYTES); - } + uint16_t clen = ntohs(*((uint16_t *)chunk->buf->array)); + brealloc(chunk->buf, clen + AUTH_BYTES, buf->capacity); chunk->len = clen; } @@ -1560,20 +1543,19 @@ int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct memcpy(key + enc_iv_len, &c, sizeof(uint32_t)); #if defined(USE_CRYPTO_OPENSSL) HMAC(EVP_sha1(), key, enc_iv_len + sizeof(uint32_t), - (uint8_t *)chunk->buf + AUTH_BYTES, chunk->len, hash, NULL); + (uint8_t *)chunk->buf->array + AUTH_BYTES, chunk->len, hash, NULL); #else ss_sha1_hmac(key, enc_iv_len + sizeof(uint32_t), - (uint8_t *)chunk->buf + AUTH_BYTES, chunk->len, hash); + (uint8_t *)chunk->buf->array + AUTH_BYTES, chunk->len, hash); #endif - if (memcmp(hash, chunk->buf + CLEN_BYTES, ONETIMEAUTH_BYTES) != 0) { - *buf_ptr = buf; + if (memcmp(hash, chunk->buf->array + CLEN_BYTES, ONETIMEAUTH_BYTES) != 0) { return 0; } // Copy chunk back to buffer - memmove(buf + j + chunk->len, buf + k, blen - i - 1); - memcpy(buf + j, chunk->buf + AUTH_BYTES, chunk->len); + memmove(buf->array + j + chunk->len, buf->array + k, blen - i - 1); + memcpy(buf->array + j, chunk->buf->array + AUTH_BYTES, chunk->len); // Reset the base offset j += chunk->len; @@ -1583,39 +1565,34 @@ int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct } } - *buf_ptr = buf; - *buf_len = j; + buf->len = j; chunk->idx = cidx; return 1; } -char *ss_gen_hash(char *buf, ssize_t *buf_len, uint32_t *counter, struct enc_ctx *ctx, int buf_size) +int ss_gen_hash(buffer_t *buf, uint32_t *counter, enc_ctx_t *ctx) { - ssize_t blen = *buf_len; - int size = max(AUTH_BYTES + blen, buf_size); - - if (buf_size < size) { - buf = realloc(buf, size); - } - + ssize_t blen = buf->len; uint16_t chunk_len = htons((uint16_t)blen); uint8_t hash[ONETIMEAUTH_BYTES * 2]; uint8_t key[MAX_IV_LENGTH + sizeof(uint32_t)]; - uint32_t c = htonl(*counter); + + brealloc(buf, AUTH_BYTES + blen, buf->capacity); memcpy(key, ctx->evp.iv, enc_iv_len); memcpy(key + enc_iv_len, &c, sizeof(uint32_t)); #if defined(USE_CRYPTO_OPENSSL) - HMAC(EVP_sha1(), key, enc_iv_len + sizeof(uint32_t), (uint8_t *)buf, blen, hash, NULL); + HMAC(EVP_sha1(), key, enc_iv_len + sizeof(uint32_t), (uint8_t *)buf->array, blen, hash, NULL); #else - ss_sha1_hmac(key, enc_iv_len + sizeof(uint32_t), (uint8_t *)buf, blen, hash); + ss_sha1_hmac(key, enc_iv_len + sizeof(uint32_t), (uint8_t *)buf->array, blen, hash); #endif - memmove(buf + AUTH_BYTES, buf, blen); - memcpy(buf + CLEN_BYTES, hash, ONETIMEAUTH_BYTES); - memcpy(buf, &chunk_len, CLEN_BYTES); + memmove(buf->array + AUTH_BYTES, buf->array, blen); + memcpy(buf->array + CLEN_BYTES, hash, ONETIMEAUTH_BYTES); + memcpy(buf->array, &chunk_len, CLEN_BYTES); *counter = *counter + 1; - *buf_len = blen + AUTH_BYTES; - return buf; + buf->len = blen + AUTH_BYTES; + + return 0; } diff --git a/src/encrypt.h b/src/encrypt.h index a5251bb4..456b1fc0 100644 --- a/src/encrypt.h +++ b/src/encrypt.h @@ -151,35 +151,45 @@ typedef struct { #define min(a, b) (((a) < (b)) ? (a) : (b)) #define max(a, b) (((a) > (b)) ? (a) : (b)) -struct chunk { +typedef struct buffer { + size_t idx; + size_t len; + size_t capacity; + char *array; +} buffer_t; + +typedef struct chunk { uint32_t idx; uint32_t len; uint32_t counter; - char *buf; -}; + buffer_t *buf; +} chunk_t; -struct enc_ctx { +typedef struct enc_ctx { uint8_t init; uint64_t counter; cipher_ctx_t evp; -}; - -char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method, int auth); -char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method, int auth); -char * ss_encrypt(int buf_size, char *plaintext, ssize_t *len, - struct enc_ctx *ctx); -char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len, - struct enc_ctx *ctx); -void enc_ctx_init(int method, struct enc_ctx *ctx, int enc); +} enc_ctx_t; + +int ss_encrypt_all(buffer_t *plaintext, int method, int auth); +int ss_decrypt_all(buffer_t *ciphertext, int method, int auth); +int ss_encrypt(buffer_t *plaintext, enc_ctx_t *ctx); +int ss_decrypt(buffer_t *ciphertext, enc_ctx_t *ctx); + +void enc_ctx_init(int method, enc_ctx_t *ctx, int enc); int enc_init(const char *pass, const char *method); int enc_get_iv_len(void); void cipher_context_release(cipher_ctx_t *evp); unsigned char *enc_md5(const unsigned char *d, size_t n, unsigned char *md); -int ss_onetimeauth(char *auth, char *msg, int msg_len, uint8_t *iv); -int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, uint8_t *iv); +int ss_onetimeauth(buffer_t *buf, uint8_t *iv); +int ss_onetimeauth_verify(buffer_t *buf, uint8_t *iv); + +int ss_check_hash(buffer_t *buf, chunk_t *chunk, enc_ctx_t *ctx); +int ss_gen_hash(buffer_t *buf, uint32_t *counter, enc_ctx_t *ctx); -int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct enc_ctx *ctx, int buf_size); -char *ss_gen_hash(char *buf, ssize_t *buf_len, uint32_t *counter, struct enc_ctx *ctx, int buf_size); +int balloc(buffer_t *ptr, size_t capacity); +int brealloc(buffer_t *ptr, size_t len, size_t capacity); +void bfree(buffer_t *ptr); #endif // _ENCRYPT_H diff --git a/src/local.c b/src/local.c index 1d0b4a39..eb61e8f7 100644 --- a/src/local.c +++ b/src/local.c @@ -107,14 +107,14 @@ static void accept_cb(EV_P_ ev_io *w, int revents); static void signal_cb(EV_P_ ev_signal *w, int revents); static int create_and_bind(const char *addr, const char *port); -static struct remote * create_remote(struct listen_ctx *listener, struct sockaddr *addr); -static void free_remote(struct remote *remote); -static void close_and_free_remote(EV_P_ struct remote *remote); -static void free_server(struct server *server); -static void close_and_free_server(EV_P_ struct server *server); +static remote_t * create_remote(listen_ctx_t *listener, struct sockaddr *addr); +static void free_remote(remote_t *remote); +static void close_and_free_remote(EV_P_ remote_t *remote); +static void free_server(server_t *server); +static void close_and_free_server(EV_P_ server_t *server); -static struct remote * new_remote(int fd, int timeout); -static struct server * new_server(int fd, int method); +static remote_t * new_remote(int fd, int timeout); +static server_t * new_server(int fd, int method); static struct cork_dllist connections; @@ -196,8 +196,8 @@ static void free_connections(struct ev_loop *loop) for (curr = cork_dllist_start(&connections); !cork_dllist_is_end(&connections, curr); curr = curr->next) { - struct server *server = cork_container_of(curr, struct server, entries); - struct remote *remote = server->remote; + server_t *server = cork_container_of(curr, server_t, entries); + remote_t *remote = server->remote; close_and_free_server(loop, server); close_and_free_remote(loop, remote); } @@ -205,10 +205,10 @@ static void free_connections(struct ev_loop *loop) static void server_recv_cb(EV_P_ ev_io *w, int revents) { - struct server_ctx *server_recv_ctx = (struct server_ctx *)w; - struct server *server = server_recv_ctx->server; - struct remote *remote = server->remote; - char *buf; + server_ctx_t *server_recv_ctx = (server_ctx_t *)w; + server_t *server = server_recv_ctx->server; + remote_t *remote = server->remote; + buffer_t *buf; if (remote == NULL) { buf = server->buf; @@ -218,7 +218,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) ssize_t r; - r = recv(server->fd, buf, BUF_SIZE, 0); + r = recv(server->fd, buf->array, BUF_SIZE, 0); if (r == 0) { // connection closed @@ -238,6 +238,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } } + buf->len = r; + while (1) { // local socks5 server if (server->stage == 5) { @@ -248,18 +250,17 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } if (!remote->direct && remote->send_ctx->connected && auth) { - remote->buf = ss_gen_hash(remote->buf, &r, &remote->counter, server->e_ctx, BUF_SIZE); + ss_gen_hash(remote->buf, &remote->counter, server->e_ctx); } // insert shadowsocks header if (!remote->direct) { #ifdef ANDROID - tx += r; + tx += remote->buf_len; #endif - remote->buf = ss_encrypt(BUF_SIZE, remote->buf, &r, - server->e_ctx); + int err = ss_encrypt(remote->buf, server->e_ctx); - if (remote->buf == NULL) { + if (err) { LOGE("invalid password or cipher"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); @@ -280,8 +281,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } #endif - remote->buf_idx = 0; - remote->buf_len = r; + remote->buf->idx = 0; if (!fast_open || remote->direct) { // connecting, wait until connected @@ -293,13 +293,12 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) ev_timer_start(EV_A_ & remote->send_ctx->watcher); } else { #ifdef TCP_FASTOPEN - int s = sendto(remote->fd, remote->buf, r, MSG_FASTOPEN, + int s = sendto(remote->fd, remote->buf->array, remote->buf->len, MSG_FASTOPEN, (struct sockaddr *)&(remote->addr), remote->addr_len); if (s == -1) { if (errno == EINPROGRESS) { // in progress, wait until connected - remote->buf_idx = 0; - remote->buf_len = r; + remote->buf->idx = 0; ev_io_stop(EV_A_ & server_recv_ctx->io); ev_io_start(EV_A_ & remote->send_ctx->io); return; @@ -315,9 +314,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); return; } - } else if (s < r) { - remote->buf_len = r - s; - remote->buf_idx = s; + } else if (s < remote->buf->len) { + remote->buf->len -= s; + remote->buf->idx = s; } // Just connected @@ -331,12 +330,11 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) #endif } } else { - int s = send(remote->fd, remote->buf, r, 0); + int s = send(remote->fd, remote->buf->array, remote->buf->len, 0); if (s == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // no data, wait for send - remote->buf_idx = 0; - remote->buf_len = r; + remote->buf->idx = 0; ev_io_stop(EV_A_ & server_recv_ctx->io); ev_io_start(EV_A_ & remote->send_ctx->io); return; @@ -346,9 +344,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); return; } - } else if (s < r) { - remote->buf_len = r - s; - remote->buf_idx = s; + } else if (s < remote->buf->len) { + remote->buf->len -= s; + remote->buf->idx = s; ev_io_stop(EV_A_ & server_recv_ctx->io); ev_io_start(EV_A_ & remote->send_ctx->io); return; @@ -365,16 +363,16 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) send(server->fd, send_buf, sizeof(response), 0); server->stage = 1; - int off = (buf[1] & 0xff) + 2; - if (buf[0] == 0x05 && off < r) { - memmove(buf, buf + off, r - off); - r -= off; + int off = (buf->array[1] & 0xff) + 2; + if (buf->array[0] == 0x05 && off < buf->len) { + memmove(buf->array, buf->array + off, buf->len - off); + buf->len -= off; continue; } return; } else if (server->stage == 1) { - struct socks5_request *request = (struct socks5_request *)buf; + struct socks5_request *request = (struct socks5_request *)buf->array; struct sockaddr_in sock_addr; memset(&sock_addr, 0, sizeof(sock_addr)); @@ -403,21 +401,22 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) return; } else { char host[256], port[16]; - char ss_addr_to_send[320]; - ssize_t addr_len = 0; - ss_addr_to_send[addr_len++] = request->atyp; + buffer_t ss_addr_to_send; + buffer_t *abuf = &ss_addr_to_send; + balloc(abuf, BUF_SIZE); + + abuf->array[abuf->len++] = request->atyp; // get remote addr and port if (request->atyp == 1) { // IP V4 size_t in_addr_len = sizeof(struct in_addr); - memcpy(ss_addr_to_send + addr_len, buf + 4, in_addr_len + 2); - addr_len += in_addr_len + 2; + memcpy(abuf->array + abuf->len, buf->array + 4, in_addr_len + 2); + abuf->len += in_addr_len + 2; if (acl || verbose) { - uint16_t p = - ntohs(*(uint16_t *)(buf + 4 + in_addr_len)); + uint16_t p = ntohs(*(uint16_t *)(buf + 4 + in_addr_len)); dns_ntop(AF_INET, (const void *)(buf + 4), host, INET_ADDRSTRLEN); sprintf(port, "%d", p); @@ -425,9 +424,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } else if (request->atyp == 3) { // Domain name uint8_t name_len = *(uint8_t *)(buf + 4); - ss_addr_to_send[addr_len++] = name_len; - memcpy(ss_addr_to_send + addr_len, buf + 4 + 1, name_len + 2); - addr_len += name_len + 2; + abuf->array[abuf->len++] = name_len; + memcpy(abuf->array + abuf->len, buf + 4 + 1, name_len + 2); + abuf->len += name_len + 2; if (acl || verbose) { uint16_t p = @@ -439,17 +438,17 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } else if (request->atyp == 4) { // IP V6 size_t in6_addr_len = sizeof(struct in6_addr); - memcpy(ss_addr_to_send + addr_len, buf + 4, in6_addr_len + 2); - addr_len += in6_addr_len + 2; + memcpy(abuf->array + abuf->len, buf + 4, in6_addr_len + 2); + abuf->len += in6_addr_len + 2; if (acl || verbose) { - uint16_t p = - ntohs(*(uint16_t *)(buf + 4 + in6_addr_len)); + uint16_t p = ntohs(*(uint16_t *)(buf + 4 + in6_addr_len)); dns_ntop(AF_INET6, (const void *)(buf + 4), host, INET6_ADDRSTRLEN); sprintf(port, "%d", p); } } else { + bfree(abuf); LOGE("unsupported addrtype: %d", request->atyp); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); @@ -458,8 +457,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) server->stage = 5; - r -= (3 + addr_len); - buf += (3 + addr_len); + buf->len -= (3 + abuf->len); + if (buf->len > 0) { + memmove(buf->array, buf->array + 3 + abuf->len, buf->len); + } if (verbose) { LOGI("connect to %s:%s", host, port); @@ -480,6 +481,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } if (remote == NULL) { + bfree(abuf); LOGE("invalid remote addr"); close_and_free_server(EV_A_ server); return; @@ -487,28 +489,31 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) if (!remote->direct) { if (auth) { - ss_addr_to_send[0] |= ONETIMEAUTH_FLAG; - ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx->evp.iv); - addr_len += ONETIMEAUTH_BYTES; + abuf->array[0] |= ONETIMEAUTH_FLAG; + ss_onetimeauth(abuf, server->e_ctx->evp.iv); } - memcpy(remote->buf, ss_addr_to_send, addr_len); + brealloc(remote->buf, buf->len + abuf->len, BUF_SIZE); + memcpy(remote->buf->array, abuf->array, abuf->len); + remote->buf->len = buf->len + abuf->len; - if (r > 0) { + if (buf->len > 0) { if (auth) { - buf = ss_gen_hash(buf, &r, &remote->counter, server->e_ctx, BUF_SIZE); + ss_gen_hash(buf, &remote->counter, server->e_ctx); } - memcpy(remote->buf + addr_len, buf, r); + memcpy(remote->buf->array + abuf->len, buf->array, buf->len); } - r += addr_len; } else { - if (r > 0) { - memcpy(remote->buf, buf, r); + if (buf->len > 0) { + memcpy(remote->buf->array, buf->array, buf->len); + remote->buf->len = buf->len; } } server->remote = remote; remote->server = server; + + bfree(abuf); } // Fake reply @@ -518,17 +523,17 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) response.rsv = 0; response.atyp = 1; - memcpy(server->buf, &response, sizeof(struct socks5_response)); - memcpy(server->buf + sizeof(struct socks5_response), + memcpy(server->buf->array, &response, sizeof(struct socks5_response)); + memcpy(server->buf->array + sizeof(struct socks5_response), &sock_addr.sin_addr, sizeof(sock_addr.sin_addr)); - memcpy(server->buf + sizeof(struct socks5_response) + + memcpy(server->buf->array + sizeof(struct socks5_response) + sizeof(sock_addr.sin_addr), &sock_addr.sin_port, sizeof(sock_addr.sin_port)); int reply_size = sizeof(struct socks5_response) + sizeof(sock_addr.sin_addr) + sizeof(sock_addr.sin_port); - int s = send(server->fd, server->buf, reply_size, 0); + int s = send(server->fd, server->buf->array, reply_size, 0); if (s < reply_size) { LOGE("failed to send fake reply"); close_and_free_remote(EV_A_ remote); @@ -547,18 +552,18 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) static void server_send_cb(EV_P_ ev_io *w, int revents) { - struct server_ctx *server_send_ctx = (struct server_ctx *)w; - struct server *server = server_send_ctx->server; - struct remote *remote = server->remote; - if (server->buf_len == 0) { + server_ctx_t *server_send_ctx = (server_ctx_t *)w; + server_t *server = server_send_ctx->server; + remote_t *remote = server->remote; + if (server->buf->len == 0) { // close and free close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } else { // has data to send - ssize_t s = send(server->fd, server->buf + server->buf_idx, - server->buf_len, 0); + ssize_t s = send(server->fd, server->buf->array + server->buf->idx, + server->buf->len, 0); if (s < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK) { ERROR("server_send_cb_send"); @@ -566,15 +571,15 @@ static void server_send_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); } return; - } else if (s < server->buf_len) { + } else if (s < server->buf->len) { // partly sent, move memory, wait for the next time to send - server->buf_len -= s; - server->buf_idx += s; + server->buf->len -= s; + server->buf->idx += s; return; } else { // all sent out, wait for reading - server->buf_len = 0; - server->buf_idx = 0; + server->buf->len = 0; + server->buf->idx = 0; ev_io_stop(EV_A_ & server_send_ctx->io); ev_io_start(EV_A_ & remote->recv_ctx->io); return; @@ -596,10 +601,10 @@ static void stat_update_cb(struct ev_loop *loop) static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) { - struct remote_ctx *remote_ctx = (struct remote_ctx *)(((void *)watcher) + remote_ctx_t *remote_ctx = (remote_ctx_t *)(((void *)watcher) - sizeof(ev_io)); - struct remote *remote = remote_ctx->remote; - struct server *server = remote->server; + remote_t *remote = remote_ctx->remote; + server_t *server = remote->server; if (verbose) { LOGI("TCP connection timeout"); @@ -611,9 +616,9 @@ static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) static void remote_recv_cb(EV_P_ ev_io *w, int revents) { - struct remote_ctx *remote_recv_ctx = (struct remote_ctx *)w; - struct remote *remote = remote_recv_ctx->remote; - struct server *server = remote->server; + remote_ctx_t *remote_recv_ctx = (remote_ctx_t *)w; + remote_t *remote = remote_recv_ctx->remote; + server_t *server = remote->server; ev_timer_again(EV_A_ & remote->recv_ctx->watcher); @@ -621,7 +626,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) stat_update_cb(loop); #endif - ssize_t r = recv(remote->fd, server->buf, BUF_SIZE, 0); + ssize_t r = recv(remote->fd, server->buf->array, BUF_SIZE, 0); if (r == 0) { // connection closed @@ -641,12 +646,14 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) } } + server->buf->len = r; + if (!remote->direct) { #ifdef ANDROID - rx += r; + rx += server->buf->len; #endif - server->buf = ss_decrypt(BUF_SIZE, server->buf, &r, server->d_ctx); - if (server->buf == NULL) { + int err = ss_decrypt(server->buf, server->d_ctx); + if (err) { LOGE("invalid password or cipher"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); @@ -654,13 +661,12 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) } } - int s = send(server->fd, server->buf, r, 0); + int s = send(server->fd, server->buf->array, server->buf->len, 0); if (s == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // no data, wait for send - server->buf_len = r; - server->buf_idx = 0; + server->buf->idx = 0; ev_io_stop(EV_A_ & remote_recv_ctx->io); ev_io_start(EV_A_ & server->send_ctx->io); return; @@ -670,9 +676,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); return; } - } else if (s < r) { - server->buf_len = r - s; - server->buf_idx = s; + } else if (s < server->buf->len) { + server->buf->len -= s; + server->buf->idx = s; ev_io_stop(EV_A_ & remote_recv_ctx->io); ev_io_start(EV_A_ & server->send_ctx->io); return; @@ -681,9 +687,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) static void remote_send_cb(EV_P_ ev_io *w, int revents) { - struct remote_ctx *remote_send_ctx = (struct remote_ctx *)w; - struct remote *remote = remote_send_ctx->remote; - struct server *server = remote->server; + remote_ctx_t *remote_send_ctx = (remote_ctx_t *)w; + remote_t *remote = remote_send_ctx->remote; + server_t *server = remote->server; if (!remote_send_ctx->connected) { struct sockaddr_storage addr; @@ -696,7 +702,7 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) ev_io_start(EV_A_ & remote->recv_ctx->io); // no need to send any data - if (remote->buf_len == 0) { + if (remote->buf->len == 0) { ev_io_stop(EV_A_ & remote_send_ctx->io); ev_io_start(EV_A_ & server->recv_ctx->io); return; @@ -710,15 +716,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) } } - if (remote->buf_len == 0) { + if (remote->buf->len == 0) { // close and free close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } else { // has data to send - ssize_t s = send(remote->fd, remote->buf + remote->buf_idx, - remote->buf_len, 0); + ssize_t s = send(remote->fd, remote->buf->array + remote->buf->idx, + remote->buf->len, 0); if (s < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK) { ERROR("remote_send_cb_send"); @@ -727,31 +733,33 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); } return; - } else if (s < remote->buf_len) { + } else if (s < remote->buf->len) { // partly sent, move memory, wait for the next time to send - remote->buf_len -= s; - remote->buf_idx += s; + remote->buf->len -= s; + remote->buf->idx += s; return; } else { // all sent out, wait for reading - remote->buf_len = 0; - remote->buf_idx = 0; + remote->buf->len = 0; + remote->buf->idx = 0; ev_io_stop(EV_A_ & remote_send_ctx->io); ev_io_start(EV_A_ & server->recv_ctx->io); } } } -static struct remote * new_remote(int fd, int timeout) +static remote_t * new_remote(int fd, int timeout) { - struct remote *remote; - remote = malloc(sizeof(struct remote)); + remote_t *remote; + remote = malloc(sizeof(remote_t)); - memset(remote, 0, sizeof(struct remote)); + memset(remote, 0, sizeof(remote_t)); - remote->buf = malloc(BUF_SIZE); - remote->recv_ctx = malloc(sizeof(struct remote_ctx)); - remote->send_ctx = malloc(sizeof(struct remote_ctx)); + remote->buf = malloc(sizeof(buffer_t)); + balloc(remote->buf, BUF_SIZE); + + remote->recv_ctx = malloc(sizeof(remote_ctx_t)); + remote->send_ctx = malloc(sizeof(remote_ctx_t)); remote->recv_ctx->connected = 0; remote->send_ctx->connected = 0; remote->fd = fd; @@ -766,12 +774,13 @@ static struct remote * new_remote(int fd, int timeout) return remote; } -static void free_remote(struct remote *remote) +static void free_remote(remote_t *remote) { if (remote->server != NULL) { remote->server->remote = NULL; } if (remote->buf != NULL) { + bfree(remote->buf); free(remote->buf); } free(remote->recv_ctx); @@ -779,7 +788,7 @@ static void free_remote(struct remote *remote) free(remote); } -static void close_and_free_remote(EV_P_ struct remote *remote) +static void close_and_free_remote(EV_P_ remote_t *remote) { if (remote != NULL) { ev_timer_stop(EV_A_ & remote->send_ctx->watcher); @@ -791,16 +800,18 @@ static void close_and_free_remote(EV_P_ struct remote *remote) } } -static struct server * new_server(int fd, int method) +static server_t * new_server(int fd, int method) { - struct server *server; - server = malloc(sizeof(struct server)); + server_t *server; + server = malloc(sizeof(server_t)); + + memset(server, 0, sizeof(server_t)); - memset(server, 0, sizeof(struct server)); + server->buf = malloc(sizeof(buffer_t)); + balloc(server->buf, BUF_SIZE); - server->buf = malloc(BUF_SIZE); - server->recv_ctx = malloc(sizeof(struct server_ctx)); - server->send_ctx = malloc(sizeof(struct server_ctx)); + server->recv_ctx = malloc(sizeof(server_ctx_t)); + server->send_ctx = malloc(sizeof(server_ctx_t)); server->recv_ctx->connected = 0; server->send_ctx->connected = 0; server->fd = fd; @@ -823,7 +834,7 @@ static struct server * new_server(int fd, int method) return server; } -static void free_server(struct server *server) +static void free_server(server_t *server) { cork_dllist_remove(&server->entries); @@ -839,6 +850,7 @@ static void free_server(struct server *server) free(server->d_ctx); } if (server->buf != NULL) { + bfree(server->buf); free(server->buf); } free(server->recv_ctx); @@ -846,7 +858,7 @@ static void free_server(struct server *server) free(server); } -static void close_and_free_server(EV_P_ struct server *server) +static void close_and_free_server(EV_P_ server_t *server) { if (server != NULL) { ev_io_stop(EV_A_ & server->send_ctx->io); @@ -856,7 +868,7 @@ static void close_and_free_server(EV_P_ struct server *server) } } -static struct remote * create_remote(struct listen_ctx *listener, +static remote_t * create_remote(listen_ctx_t *listener, struct sockaddr *addr) { struct sockaddr *remote_addr; @@ -889,7 +901,7 @@ static struct remote * create_remote(struct listen_ctx *listener, } #endif - struct remote *remote = new_remote(remotefd, listener->timeout); + remote_t *remote = new_remote(remotefd, listener->timeout); remote->addr_len = get_sockaddr_len(remote_addr); memcpy(&(remote->addr), remote_addr, remote->addr_len); @@ -909,7 +921,7 @@ static void signal_cb(EV_P_ ev_signal *w, int revents) void accept_cb(EV_P_ ev_io *w, int revents) { - struct listen_ctx *listener = (struct listen_ctx *)w; + listen_ctx_t *listener = (listen_ctx_t *)w; int serverfd = accept(listener->fd, NULL, NULL); if (serverfd == -1) { ERROR("accept"); @@ -922,7 +934,7 @@ void accept_cb(EV_P_ ev_io *w, int revents) setsockopt(serverfd, SOL_SOCKET, SO_NOSIGPIPE, &opt, sizeof(opt)); #endif - struct server *server = new_server(serverfd, listener->method); + server_t *server = new_server(serverfd, listener->method); server->listener = listener; ev_io_start(EV_A_ & server->recv_ctx->io); @@ -1142,7 +1154,7 @@ int main(int argc, char **argv) int m = enc_init(password, method); // Setup proxy context - struct listen_ctx listen_ctx; + listen_ctx_t listen_ctx; listen_ctx.remote_num = remote_num; listen_ctx.remote_addr = malloc(sizeof(struct sockaddr *) * remote_num); for (i = 0; i < remote_num; i++) { @@ -1287,7 +1299,7 @@ int start_ss_local_server(profile_t profile) // Setup proxy context struct ev_loop *loop = EV_DEFAULT; - struct listen_ctx listen_ctx; + listen_ctx_t listen_ctx; listen_ctx.remote_num = 1; listen_ctx.remote_addr = malloc(sizeof(struct sockaddr *)); diff --git a/src/local.h b/src/local.h index edbfa9c5..b248de7c 100644 --- a/src/local.h +++ b/src/local.h @@ -31,7 +31,7 @@ #include "common.h" -struct listen_ctx { +typedef struct listen_ctx { ev_io io; char *iface; int remote_num; @@ -39,19 +39,17 @@ struct listen_ctx { int timeout; int fd; struct sockaddr **remote_addr; -}; +} listen_ctx_t; -struct server_ctx { +typedef struct server_ctx { ev_io io; int connected; struct server *server; -}; +} server_ctx_t; -struct server { +typedef struct server { int fd; - ssize_t buf_len; - ssize_t buf_idx; - char *buf; // server send from, remote recv into + buffer_t *buf; char stage; struct enc_ctx *e_ctx; struct enc_ctx *d_ctx; @@ -61,27 +59,25 @@ struct server { struct remote *remote; struct cork_dllist_item entries; -}; +} server_t; -struct remote_ctx { +typedef struct remote_ctx { ev_io io; ev_timer watcher; int connected; struct remote *remote; -}; +} remote_ctx_t; -struct remote { +typedef struct remote { int fd; - ssize_t buf_len; - ssize_t buf_idx; + buffer_t *buf; int direct; - char *buf; // remote send from, server recv into struct remote_ctx *recv_ctx; struct remote_ctx *send_ctx; struct server *server; struct sockaddr_storage addr; int addr_len; uint32_t counter; -}; +} remote_t; #endif // _LOCAL_H diff --git a/src/redir.c b/src/redir.c index b317e6db..a52c4f74 100644 --- a/src/redir.c +++ b/src/redir.c @@ -72,13 +72,13 @@ static void server_send_cb(EV_P_ ev_io *w, int revents); static void remote_recv_cb(EV_P_ ev_io *w, int revents); static void remote_send_cb(EV_P_ ev_io *w, int revents); -static struct remote * new_remote(int fd, int timeout); -static struct server * new_server(int fd, int method); +static remote_t * new_remote(int fd, int timeout); +static server_t * new_server(int fd, int method); -static void free_remote(struct remote *remote); -static void close_and_free_remote(EV_P_ struct remote *remote); -static void free_server(struct server *server); -static void close_and_free_server(EV_P_ struct server *server); +static void free_remote(remote_t *remote); +static void close_and_free_remote(EV_P_ remote_t *remote); +static void free_server(server_t *server); +static void close_and_free_server(EV_P_ server_t *server); int verbose = 0; @@ -160,11 +160,11 @@ int create_and_bind(const char *addr, const char *port) static void server_recv_cb(EV_P_ ev_io *w, int revents) { - struct server_ctx *server_recv_ctx = (struct server_ctx *)w; - struct server *server = server_recv_ctx->server; - struct remote *remote = server->remote; + server_ctx_t *server_recv_ctx = (server_ctx_t *)w; + server_t *server = server_recv_ctx->server; + remote_t *remote = server->remote; - ssize_t r = recv(server->fd, remote->buf, BUF_SIZE, 0); + ssize_t r = recv(server->fd, remote->buf->array, BUF_SIZE, 0); if (r == 0) { // connection closed @@ -184,26 +184,27 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } } + remote->buf->len = r; + if (auth) { - remote->buf = ss_gen_hash(remote->buf, &r, &remote->counter, server->e_ctx, BUF_SIZE); + ss_gen_hash(remote->buf, &remote->counter, server->e_ctx); } - remote->buf = ss_encrypt(BUF_SIZE, remote->buf, &r, server->e_ctx); + int err = ss_encrypt(remote->buf, server->e_ctx); - if (remote->buf == NULL) { + if (err) { LOGE("invalid password or cipher"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } - int s = send(remote->fd, remote->buf, r, 0); + int s = send(remote->fd, remote->buf->array, remote->buf->len, 0); if (s == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // no data, wait for send - remote->buf_len = r; - remote->buf_idx = 0; + remote->buf->idx = 0; ev_io_stop(EV_A_ & server_recv_ctx->io); ev_io_start(EV_A_ & remote->send_ctx->io); return; @@ -213,9 +214,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); return; } - } else if (s < r) { - remote->buf_len = r - s; - remote->buf_idx = s; + } else if (s < remote->buf->len) { + remote->buf->len -= s; + remote->buf->idx = s; ev_io_stop(EV_A_ & server_recv_ctx->io); ev_io_start(EV_A_ & remote->send_ctx->io); return; @@ -225,18 +226,18 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) static void server_send_cb(EV_P_ ev_io *w, int revents) { - struct server_ctx *server_send_ctx = (struct server_ctx *)w; - struct server *server = server_send_ctx->server; - struct remote *remote = server->remote; - if (server->buf_len == 0) { + server_ctx_t *server_send_ctx = (server_ctx_t *)w; + server_t *server = server_send_ctx->server; + remote_t *remote = server->remote; + if (server->buf->len == 0) { // close and free close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } else { // has data to send - ssize_t s = send(server->fd, server->buf + server->buf_idx, - server->buf_len, 0); + ssize_t s = send(server->fd, server->buf->array + server->buf->idx, + server->buf->len, 0); if (s < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK) { ERROR("send"); @@ -244,15 +245,15 @@ static void server_send_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); } return; - } else if (s < server->buf_len) { + } else if (s < server->buf->len) { // partly sent, move memory, wait for the next time to send - server->buf_len -= s; - server->buf_idx += s; + server->buf->len -= s; + server->buf->idx += s; return; } else { // all sent out, wait for reading - server->buf_len = 0; - server->buf_idx = 0; + server->buf->len = 0; + server->buf->idx = 0; ev_io_stop(EV_A_ & server_send_ctx->io); ev_io_start(EV_A_ & remote->recv_ctx->io); } @@ -262,10 +263,10 @@ static void server_send_cb(EV_P_ ev_io *w, int revents) static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) { - struct remote_ctx *remote_ctx = (struct remote_ctx *)(((void *)watcher) + remote_ctx_t *remote_ctx = (remote_ctx_t *)(((void *)watcher) - sizeof(ev_io)); - struct remote *remote = remote_ctx->remote; - struct server *server = remote->server; + remote_t *remote = remote_ctx->remote; + server_t *server = remote->server; ev_timer_stop(EV_A_ watcher); @@ -275,11 +276,11 @@ static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) static void remote_recv_cb(EV_P_ ev_io *w, int revents) { - struct remote_ctx *remote_recv_ctx = (struct remote_ctx *)w; - struct remote *remote = remote_recv_ctx->remote; - struct server *server = remote->server; + remote_ctx_t *remote_recv_ctx = (remote_ctx_t *)w; + remote_t *remote = remote_recv_ctx->remote; + server_t *server = remote->server; - ssize_t r = recv(remote->fd, server->buf, BUF_SIZE, 0); + ssize_t r = recv(remote->fd, server->buf->array, BUF_SIZE, 0); if (r == 0) { // connection closed @@ -299,20 +300,21 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) } } - server->buf = ss_decrypt(BUF_SIZE, server->buf, &r, server->d_ctx); - if (server->buf == NULL) { + server->buf->len = r; + + int err = ss_decrypt(server->buf, server->d_ctx); + if (err) { LOGE("invalid password or cipher"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } - int s = send(server->fd, server->buf, r, 0); + int s = send(server->fd, server->buf->array, server->buf->len, 0); if (s == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // no data, wait for send - server->buf_len = r; - server->buf_idx = 0; + server->buf->idx = 0; ev_io_stop(EV_A_ & remote_recv_ctx->io); ev_io_start(EV_A_ & server->send_ctx->io); return; @@ -322,9 +324,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); return; } - } else if (s < r) { - server->buf_len = r - s; - server->buf_idx = s; + } else if (s < server->buf->len) { + server->buf->len -= s; + server->buf->idx = s; ev_io_stop(EV_A_ & remote_recv_ctx->io); ev_io_start(EV_A_ & server->send_ctx->io); return; @@ -333,9 +335,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) static void remote_send_cb(EV_P_ ev_io *w, int revents) { - struct remote_ctx *remote_send_ctx = (struct remote_ctx *)w; - struct remote *remote = remote_send_ctx->remote; - struct server *server = remote->server; + remote_ctx_t *remote_send_ctx = (remote_ctx_t *)w; + remote_t *remote = remote_send_ctx->remote; + server_t *server = remote->server; if (!remote_send_ctx->connected) { @@ -348,52 +350,52 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) ev_timer_stop(EV_A_ & remote_send_ctx->watcher); // send destaddr - char *ss_addr_to_send = malloc(BUF_SIZE); - ssize_t addr_len = 0; + buffer_t ss_addr_to_send; + buffer_t *abuf = &ss_addr_to_send; + balloc(abuf, BUF_SIZE); + if (AF_INET6 == server->destaddr.ss_family) { // IPv6 - ss_addr_to_send[addr_len++] = 4; //Type 4 is IPv6 address + abuf->array[abuf->len++] = 4; //Type 4 is IPv6 address - size_t in_addr_len = sizeof(struct in6_addr); - memcpy(ss_addr_to_send + addr_len, + size_t in6_addr_len = sizeof(struct in6_addr); + memcpy(abuf->array + abuf->len, &(((struct sockaddr_in6 *)&(server->destaddr))->sin6_addr), - in_addr_len); - addr_len += in_addr_len; - memcpy(ss_addr_to_send + addr_len, + in6_addr_len); + abuf->len += in6_addr_len; + memcpy(abuf->array + abuf->len, &(((struct sockaddr_in6 *)&(server->destaddr))->sin6_port), 2); } else { //IPv4 - ss_addr_to_send[addr_len++] = 1; //Type 1 is IPv4 address + abuf->array[abuf->len++] = 1; //Type 1 is IPv4 address size_t in_addr_len = sizeof(struct in_addr); - memcpy(ss_addr_to_send + addr_len, - &((struct sockaddr_in *)&(server->destaddr))->sin_addr, - in_addr_len); - addr_len += in_addr_len; - memcpy(ss_addr_to_send + addr_len, - &((struct sockaddr_in *)&(server->destaddr))->sin_port, - 2); + memcpy(abuf->array + abuf->len, + &((struct sockaddr_in *)&(server->destaddr))->sin_addr, in_addr_len); + abuf->len += in_addr_len; + memcpy(abuf->array + abuf->len, + &((struct sockaddr_in *)&(server->destaddr))->sin_port, 2); } - addr_len += 2; + abuf->len += 2; if (auth) { - ss_addr_to_send[0] |= ONETIMEAUTH_FLAG; - ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx->evp.iv); - addr_len += ONETIMEAUTH_BYTES; + abuf->array[0] |= ONETIMEAUTH_FLAG; + ss_onetimeauth(abuf, server->e_ctx->evp.iv); } - ss_addr_to_send = ss_encrypt(BUF_SIZE, ss_addr_to_send, &addr_len, - server->e_ctx); - if (ss_addr_to_send == NULL) { + int err = ss_encrypt(abuf, server->e_ctx); + if (err) { + bfree(abuf); LOGE("invalid password or cipher"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } - int s = send(remote->fd, ss_addr_to_send, addr_len, 0); - free(ss_addr_to_send); + int s = send(remote->fd, abuf->array, abuf->len, 0); - if (s < addr_len) { + bfree(abuf); + + if (s < abuf->len) { LOGE("failed to send addr"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); @@ -412,15 +414,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) return; } } else { - if (remote->buf_len == 0) { + if (remote->buf->len == 0) { // close and free close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } else { // has data to send - ssize_t s = send(remote->fd, remote->buf + remote->buf_idx, - remote->buf_len, 0); + ssize_t s = send(remote->fd, remote->buf->array + remote->buf->idx, + remote->buf->len, 0); if (s < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK) { ERROR("send"); @@ -429,15 +431,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); } return; - } else if (s < remote->buf_len) { + } else if (s < remote->buf->len) { // partly sent, move memory, wait for the next time to send - remote->buf_len -= s; - remote->buf_idx += s; + remote->buf->len -= s; + remote->buf->idx += s; return; } else { // all sent out, wait for reading - remote->buf_len = 0; - remote->buf_idx = 0; + remote->buf->len = 0; + remote->buf->idx = 0; ev_io_stop(EV_A_ & remote_send_ctx->io); ev_io_start(EV_A_ & server->recv_ctx->io); } @@ -446,16 +448,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) } } -static struct remote * new_remote(int fd, int timeout) +static remote_t * new_remote(int fd, int timeout) { - struct remote *remote; - remote = malloc(sizeof(struct remote)); + remote_t *remote; + remote = malloc(sizeof(remote_t)); - memset(remote, 0, sizeof(struct remote)); + memset(remote, 0, sizeof(remote_t)); - remote->buf = malloc(BUF_SIZE); - remote->recv_ctx = malloc(sizeof(struct remote_ctx)); - remote->send_ctx = malloc(sizeof(struct remote_ctx)); + remote->recv_ctx = malloc(sizeof(remote_ctx_t)); + remote->send_ctx = malloc(sizeof(remote_ctx_t)); remote->fd = fd; ev_io_init(&remote->recv_ctx->io, remote_recv_cb, fd, EV_READ); ev_io_init(&remote->send_ctx->io, remote_send_cb, fd, EV_WRITE); @@ -465,18 +466,21 @@ static struct remote * new_remote(int fd, int timeout) remote->recv_ctx->connected = 0; remote->send_ctx->remote = remote; remote->send_ctx->connected = 0; - remote->buf_len = 0; - remote->buf_idx = 0; + + remote->buf = malloc(sizeof(buffer_t)); + balloc(remote->buf, BUF_SIZE); + return remote; } -static void free_remote(struct remote *remote) +static void free_remote(remote_t *remote) { if (remote != NULL) { if (remote->server != NULL) { remote->server->remote = NULL; } if (remote->buf != NULL) { + bfree(remote->buf); free(remote->buf); } free(remote->recv_ctx); @@ -485,7 +489,7 @@ static void free_remote(struct remote *remote) } } -static void close_and_free_remote(EV_P_ struct remote *remote) +static void close_and_free_remote(EV_P_ remote_t *remote) { if (remote != NULL) { ev_timer_stop(EV_A_ & remote->send_ctx->watcher); @@ -496,13 +500,13 @@ static void close_and_free_remote(EV_P_ struct remote *remote) } } -static struct server * new_server(int fd, int method) +static server_t * new_server(int fd, int method) { - struct server *server; - server = malloc(sizeof(struct server)); - server->buf = malloc(BUF_SIZE); - server->recv_ctx = malloc(sizeof(struct server_ctx)); - server->send_ctx = malloc(sizeof(struct server_ctx)); + server_t *server; + server = malloc(sizeof(server_t)); + + server->recv_ctx = malloc(sizeof(server_ctx_t)); + server->send_ctx = malloc(sizeof(server_ctx_t)); server->fd = fd; ev_io_init(&server->recv_ctx->io, server_recv_cb, fd, EV_READ); ev_io_init(&server->send_ctx->io, server_send_cb, fd, EV_WRITE); @@ -511,20 +515,22 @@ static struct server * new_server(int fd, int method) server->send_ctx->server = server; server->send_ctx->connected = 0; if (method) { - server->e_ctx = malloc(sizeof(struct enc_ctx)); - server->d_ctx = malloc(sizeof(struct enc_ctx)); + server->e_ctx = malloc(sizeof(enc_ctx_t)); + server->d_ctx = malloc(sizeof(enc_ctx_t)); enc_ctx_init(method, server->e_ctx, 1); enc_ctx_init(method, server->d_ctx, 0); } else { server->e_ctx = NULL; server->d_ctx = NULL; } - server->buf_len = 0; - server->buf_idx = 0; + + server->buf = malloc(sizeof(buffer_t)); + balloc(server->buf, BUF_SIZE); + return server; } -static void free_server(struct server *server) +static void free_server(server_t *server) { if (server != NULL) { if (server->remote != NULL) { @@ -539,6 +545,7 @@ static void free_server(struct server *server) free(server->d_ctx); } if (server->buf != NULL) { + bfree(server->buf); free(server->buf); } free(server->recv_ctx); @@ -547,7 +554,7 @@ static void free_server(struct server *server) } } -static void close_and_free_server(EV_P_ struct server *server) +static void close_and_free_server(EV_P_ server_t *server) { if (server != NULL) { ev_io_stop(EV_A_ & server->send_ctx->io); @@ -559,7 +566,7 @@ static void close_and_free_server(EV_P_ struct server *server) static void accept_cb(EV_P_ ev_io *w, int revents) { - struct listen_ctx *listener = (struct listen_ctx *)w; + listen_ctx_t *listener = (listen_ctx_t *)w; struct sockaddr_storage destaddr; int err; @@ -599,8 +606,8 @@ static void accept_cb(EV_P_ ev_io *w, int revents) // Setup setnonblocking(remotefd); - struct server *server = new_server(serverfd, listener->method); - struct remote *remote = new_remote(remotefd, listener->timeout); + server_t *server = new_server(serverfd, listener->method); + remote_t *remote = new_remote(remotefd, listener->timeout); server->remote = remote; remote->server = server; server->destaddr = destaddr; @@ -756,7 +763,7 @@ int main(int argc, char **argv) int m = enc_init(password, method); // Setup proxy context - struct listen_ctx listen_ctx; + listen_ctx_t listen_ctx; listen_ctx.remote_num = remote_num; listen_ctx.remote_addr = malloc(sizeof(struct sockaddr *) * remote_num); for (int i = 0; i < remote_num; i++) { diff --git a/src/redir.h b/src/redir.h index 398f5700..30d7d603 100644 --- a/src/redir.h +++ b/src/redir.h @@ -27,50 +27,46 @@ #include "encrypt.h" #include "jconf.h" -struct listen_ctx { +typedef struct listen_ctx { ev_io io; int remote_num; int timeout; int fd; int method; struct sockaddr **remote_addr; -}; +} listen_ctx_t; -struct server_ctx { +typedef struct server_ctx { ev_io io; int connected; struct server *server; -}; +} server_ctx_t; -struct server { +typedef struct server { int fd; - ssize_t buf_len; - ssize_t buf_idx; - char *buf; // server send from, remote recv into + buffer_t *buf; struct sockaddr_storage destaddr; struct enc_ctx *e_ctx; struct enc_ctx *d_ctx; struct server_ctx *recv_ctx; struct server_ctx *send_ctx; struct remote *remote; -}; +} server_t; -struct remote_ctx { +typedef struct remote_ctx { ev_io io; ev_timer watcher; int connected; struct remote *remote; -}; +} remote_ctx_t; -struct remote { +typedef struct remote { int fd; - ssize_t buf_len; - ssize_t buf_idx; - char *buf; // remote send from, server recv into + buffer_t *buf; struct remote_ctx *recv_ctx; struct remote_ctx *send_ctx; struct server *server; uint32_t counter; -}; +} remote_t; #endif // _LOCAL_H diff --git a/src/server.c b/src/server.c index 30eedfd4..293f644f 100644 --- a/src/server.c +++ b/src/server.c @@ -93,15 +93,15 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents); static void remote_send_cb(EV_P_ ev_io *w, int revents); static void server_timeout_cb(EV_P_ ev_timer *watcher, int revents); -static struct remote * new_remote(int fd); -static struct server * new_server(int fd, struct listen_ctx *listener); -static struct remote *connect_to_remote(struct addrinfo *res, - struct server *server); +static remote_t * new_remote(int fd); +static server_t * new_server(int fd, listen_ctx_t *listener); +static remote_t *connect_to_remote(struct addrinfo *res, + server_t *server); -static void free_remote(struct remote *remote); -static void close_and_free_remote(EV_P_ struct remote *remote); -static void free_server(struct server *server); -static void close_and_free_server(EV_P_ struct server *server); +static void free_remote(remote_t *remote); +static void close_and_free_remote(EV_P_ remote_t *remote); +static void free_server(server_t *server); +static void close_and_free_server(EV_P_ server_t *server); static void server_resolve_cb(struct sockaddr *addr, void *data); @@ -208,8 +208,8 @@ static void free_connections(struct ev_loop *loop) for (curr = cork_dllist_start(&connections); !cork_dllist_is_end(&connections, curr); curr = curr->next) { - struct server *server = cork_container_of(curr, struct server, entries); - struct remote *remote = server->remote; + server_t *server = cork_container_of(curr, server_t, entries); + remote_t *remote = server->remote; close_and_free_server(loop, server); close_and_free_remote(loop, remote); } @@ -357,8 +357,8 @@ int create_and_bind(const char *host, const char *port) return listen_sock; } -static struct remote *connect_to_remote(struct addrinfo *res, - struct server *server) +static remote_t *connect_to_remote(struct addrinfo *res, + server_t *server) { int sockfd; #ifdef SET_INTERFACE @@ -379,7 +379,7 @@ static struct remote *connect_to_remote(struct addrinfo *res, setsockopt(sockfd, SOL_SOCKET, SO_NOSIGPIPE, &opt, sizeof(opt)); #endif - struct remote *remote = new_remote(sockfd); + remote_t *remote = new_remote(sockfd); // setup remote socks setnonblocking(sockfd); @@ -391,8 +391,8 @@ static struct remote *connect_to_remote(struct addrinfo *res, #ifdef TCP_FASTOPEN if (fast_open) { - ssize_t s = sendto(sockfd, server->buf + server->buf_idx, - server->buf_len, MSG_FASTOPEN, res->ai_addr, + ssize_t s = sendto(sockfd, server->buf->array + server->buf->idx, + server->buf->len, MSG_FASTOPEN, res->ai_addr, res->ai_addrlen); if (s == -1) { if (errno == EINPROGRESS || errno == EAGAIN @@ -408,12 +408,12 @@ static struct remote *connect_to_remote(struct addrinfo *res, } else { ERROR("sendto"); } - } else if (s < server->buf_len) { - server->buf_idx += s; - server->buf_len -= s; + } else if (s < server->buf->len) { + server->buf->idx += s; + server->buf->len -= s; } else { - server->buf_idx = 0; - server->buf_len = 0; + server->buf->idx = 0; + server->buf->len = 0; } } else #endif @@ -424,22 +424,22 @@ static struct remote *connect_to_remote(struct addrinfo *res, static void server_recv_cb(EV_P_ ev_io *w, int revents) { - struct server_ctx *server_recv_ctx = (struct server_ctx *)w; - struct server *server = server_recv_ctx->server; - struct remote *remote = NULL; + server_ctx_t *server_recv_ctx = (server_ctx_t *)w; + server_t *server = server_recv_ctx->server; + remote_t *remote = NULL; - int len = server->buf_len; - char **buf = &server->buf; + int len = server->buf->len; + buffer_t *buf = server->buf; ev_timer_again(EV_A_ & server->recv_ctx->watcher); if (server->stage != 0) { remote = server->remote; - buf = &remote->buf; + buf = remote->buf; len = 0; } - ssize_t r = recv(server->fd, *buf + len, BUF_SIZE - len, 0); + ssize_t r = recv(server->fd, buf->array + len, BUF_SIZE - len, 0); if (r == 0) { // connection closed @@ -466,8 +466,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) // handle incomplete header if (server->stage == 0) { - r += server->buf_len; - if (r <= enc_get_iv_len()) { + buf->len += r; + if (buf->len <= enc_get_iv_len()) { // wait for more if (verbose) { #ifdef __MINGW32__ @@ -476,16 +476,15 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) LOGI("imcomplete header: %zu", r); #endif } - server->buf_len = r; return; - } else { - server->buf_len = 0; } + } else { + buf->len = r; } - *buf = ss_decrypt(BUF_SIZE, *buf, &r, server->d_ctx); + int err = ss_decrypt(buf, server->d_ctx); - if (*buf == NULL) { + if (err) { LOGE("invalid password or cipher"); report_addr(server->fd); close_and_free_remote(EV_A_ remote); @@ -495,19 +494,18 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) // handshake and transmit data if (server->stage == 5) { - if (server->auth && !ss_check_hash(&remote->buf, &r, server->chunk, server->d_ctx, BUF_SIZE)) { + if (server->auth && !ss_check_hash(remote->buf, server->chunk, server->d_ctx)) { LOGE("hash error"); report_addr(server->fd); close_and_free_server(EV_A_ server); close_and_free_remote(EV_A_ remote); return; } - int s = send(remote->fd, remote->buf, r, 0); + int s = send(remote->fd, remote->buf->array, remote->buf->len, 0); if (s == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // no data, wait for send - remote->buf_len = r; - remote->buf_idx = 0; + remote->buf->idx = 0; ev_io_stop(EV_A_ & server_recv_ctx->io); ev_io_start(EV_A_ & remote->send_ctx->io); } else { @@ -515,9 +513,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); } - } else if (s < r) { - remote->buf_len = r - s; - remote->buf_idx = s; + } else if (s < remote->buf->len) { + remote->buf->len -= s; + remote->buf->idx = s; ev_io_stop(EV_A_ & server_recv_ctx->io); ev_io_start(EV_A_ & remote->send_ctx->io); } @@ -555,7 +553,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) int offset = 0; int need_query = 0; - char atyp = server->buf[offset++]; + char atyp = server->buf->array[offset++]; char host[256] = { 0 }; uint16_t port = 0; struct addrinfo info; @@ -569,9 +567,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) struct sockaddr_in *addr = (struct sockaddr_in *)&storage; size_t in_addr_len = sizeof(struct in_addr); addr->sin_family = AF_INET; - if (r >= in_addr_len + 3) { - addr->sin_addr = *(struct in_addr *)(server->buf + offset); - dns_ntop(AF_INET, (const void *)(server->buf + offset), + if (server->buf->len >= in_addr_len + 3) { + addr->sin_addr = *(struct in_addr *)(server->buf->array + offset); + dns_ntop(AF_INET, (const void *)(server->buf->array + offset), host, INET_ADDRSTRLEN); offset += in_addr_len; } else { @@ -580,7 +578,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); return; } - addr->sin_port = *(uint16_t *)(server->buf + offset); + addr->sin_port = *(uint16_t *)(server->buf->array + offset); info.ai_family = AF_INET; info.ai_socktype = SOCK_STREAM; info.ai_protocol = IPPROTO_TCP; @@ -588,9 +586,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) info.ai_addr = (struct sockaddr *)addr; } else if ((atyp & ADDRTYPE_MASK) == 3) { // Domain name - uint8_t name_len = *(uint8_t *)(server->buf + offset); - if (name_len + 4 <= r) { - memcpy(host, server->buf + offset + 1, name_len); + uint8_t name_len = *(uint8_t *)(server->buf->array + offset); + if (name_len + 4 <= server->buf->len) { + memcpy(host, server->buf->array + offset + 1, name_len); offset += name_len + 1; } else { LOGE("invalid name length: %d", name_len); @@ -605,7 +603,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) if (ip.version == 4) { struct sockaddr_in *addr = (struct sockaddr_in *)&storage; dns_pton(AF_INET, host, &(addr->sin_addr)); - addr->sin_port = *(uint16_t *)(server->buf + offset); + addr->sin_port = *(uint16_t *)(server->buf->array + offset); addr->sin_family = AF_INET; info.ai_family = AF_INET; info.ai_addrlen = sizeof(struct sockaddr_in); @@ -613,7 +611,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } else if (ip.version == 6) { struct sockaddr_in6 *addr = (struct sockaddr_in6 *)&storage; dns_pton(AF_INET6, host, &(addr->sin6_addr)); - addr->sin6_port = *(uint16_t *)(server->buf + offset); + addr->sin6_port = *(uint16_t *)(server->buf->array + offset); addr->sin6_family = AF_INET6; info.ai_family = AF_INET6; info.ai_addrlen = sizeof(struct sockaddr_in6); @@ -627,9 +625,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) struct sockaddr_in6 *addr = (struct sockaddr_in6 *)&storage; size_t in6_addr_len = sizeof(struct in6_addr); addr->sin6_family = AF_INET6; - if (r >= in6_addr_len + 3) { - addr->sin6_addr = *(struct in6_addr *)(server->buf + offset); - dns_ntop(AF_INET6, (const void *)(server->buf + offset), + if (server->buf->len >= in6_addr_len + 3) { + addr->sin6_addr = *(struct in6_addr *)(server->buf->array + offset); + dns_ntop(AF_INET6, (const void *)(server->buf->array + offset), host, INET6_ADDRSTRLEN); offset += in6_addr_len; } else { @@ -638,7 +636,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); return; } - addr->sin6_port = *(uint16_t *)(server->buf + offset); + addr->sin6_port = *(uint16_t *)(server->buf->array + offset); info.ai_family = AF_INET6; info.ai_socktype = SOCK_STREAM; info.ai_protocol = IPPROTO_TCP; @@ -661,12 +659,12 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) return; } - port = (*(uint16_t *)(server->buf + offset)); + port = (*(uint16_t *)(server->buf->array + offset)); offset += 2; if (auth || (atyp & ONETIMEAUTH_FLAG)) { - if (ss_onetimeauth_verify(server->buf + offset, server->buf, offset, server->d_ctx->evp.iv)) { + if (ss_onetimeauth_verify(server->buf, server->d_ctx->evp.iv)) { LOGE("authentication error %d", atyp); report_addr(server->fd); close_and_free_server(EV_A_ server); @@ -682,12 +680,12 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } // XXX: should handle buffer carefully - if (r > offset) { - server->buf_len = r - offset; - memmove(server->buf, server->buf + offset, server->buf_len); + if (server->buf->len > offset) { + server->buf->len -= offset; + memmove(server->buf->array, server->buf->array + offset, server->buf->len); } - if (server->auth && !ss_check_hash(&server->buf, &server->buf_len, server->chunk, server->d_ctx, BUF_SIZE)) { + if (server->auth && !ss_check_hash(server->buf, server->chunk, server->d_ctx)) { LOGE("hash error"); report_addr(server->fd); close_and_free_server(EV_A_ server); @@ -695,7 +693,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } if (!need_query) { - struct remote *remote = connect_to_remote(&info, server); + remote_t *remote = connect_to_remote(&info, server); if (remote == NULL) { LOGE("connect error"); @@ -706,12 +704,12 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) remote->server = server; // XXX: should handle buffer carefully - if (server->buf_len > 0) { - memcpy(remote->buf, server->buf + server->buf_idx, server->buf_len); - remote->buf_len = server->buf_len; - remote->buf_idx = 0; - server->buf_len = 0; - server->buf_idx = 0; + if (server->buf->len > 0) { + memcpy(remote->buf->array, server->buf->array, server->buf->len); + remote->buf->len = server->buf->len; + remote->buf->idx = 0; + server->buf->len = 0; + server->buf->idx = 0; } server->stage = 4; @@ -736,9 +734,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) static void server_send_cb(EV_P_ ev_io *w, int revents) { - struct server_ctx *server_send_ctx = (struct server_ctx *)w; - struct server *server = server_send_ctx->server; - struct remote *remote = server->remote; + server_ctx_t *server_send_ctx = (server_ctx_t *)w; + server_t *server = server_send_ctx->server; + remote_t *remote = server->remote; if (remote == NULL) { LOGE("invalid server"); @@ -746,7 +744,7 @@ static void server_send_cb(EV_P_ ev_io *w, int revents) return; } - if (server->buf_len == 0) { + if (server->buf->len == 0) { // close and free if (verbose) { LOGI("server_send close the connection"); @@ -756,8 +754,8 @@ static void server_send_cb(EV_P_ ev_io *w, int revents) return; } else { // has data to send - ssize_t s = send(server->fd, server->buf + server->buf_idx, - server->buf_len, 0); + ssize_t s = send(server->fd, server->buf->array + server->buf->idx, + server->buf->len, 0); if (s < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK) { ERROR("server_send_send"); @@ -765,15 +763,15 @@ static void server_send_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); } return; - } else if (s < server->buf_len) { + } else if (s < server->buf->len) { // partly sent, move memory, wait for the next time to send - server->buf_len -= s; - server->buf_idx += s; + server->buf->len -= s; + server->buf->idx += s; return; } else { // all sent out, wait for reading - server->buf_len = 0; - server->buf_idx = 0; + server->buf->len = 0; + server->buf->idx = 0; ev_io_stop(EV_A_ & server_send_ctx->io); if (remote != NULL) { ev_io_start(EV_A_ & remote->recv_ctx->io); @@ -790,10 +788,10 @@ static void server_send_cb(EV_P_ ev_io *w, int revents) static void server_timeout_cb(EV_P_ ev_timer *watcher, int revents) { - struct server_ctx *server_ctx = (struct server_ctx *)(((void *)watcher) + server_ctx_t *server_ctx = (server_ctx_t *)(((void *)watcher) - sizeof(ev_io)); - struct server *server = server_ctx->server; - struct remote *remote = server->remote; + server_t *server = server_ctx->server; + remote_t *remote = server->remote; if (verbose) { LOGI("TCP connection timeout"); @@ -805,7 +803,7 @@ static void server_timeout_cb(EV_P_ ev_timer *watcher, int revents) static void server_resolve_cb(struct sockaddr *addr, void *data) { - struct server *server = (struct server *)data; + server_t *server = (server_t *)data; struct ev_loop *loop = server->listen_ctx->loop; server->query = NULL; @@ -851,7 +849,7 @@ static void server_resolve_cb(struct sockaddr *addr, void *data) info.ai_addrlen = sizeof(struct sockaddr_in6); } - struct remote *remote = connect_to_remote(&info, server); + remote_t *remote = connect_to_remote(&info, server); if (remote == NULL) { LOGE("connect error"); @@ -861,13 +859,13 @@ static void server_resolve_cb(struct sockaddr *addr, void *data) remote->server = server; // XXX: should handle buffer carefully - if (server->buf_len > 0) { - memcpy(remote->buf, server->buf + server->buf_idx, - server->buf_len); - remote->buf_len = server->buf_len; - remote->buf_idx = 0; - server->buf_len = 0; - server->buf_idx = 0; + if (server->buf->len > 0) { + memcpy(remote->buf->array, server->buf->array + server->buf->idx, + server->buf->len); + remote->buf->len = server->buf->len; + remote->buf->idx = 0; + server->buf->len = 0; + server->buf->idx = 0; } // listen to remote connected event @@ -878,9 +876,9 @@ static void server_resolve_cb(struct sockaddr *addr, void *data) static void remote_recv_cb(EV_P_ ev_io *w, int revents) { - struct remote_ctx *remote_recv_ctx = (struct remote_ctx *)w; - struct remote *remote = remote_recv_ctx->remote; - struct server *server = remote->server; + remote_ctx_t *remote_recv_ctx = (remote_ctx_t *)w; + remote_t *remote = remote_recv_ctx->remote; + server_t *server = remote->server; if (server == NULL) { LOGE("invalid server"); @@ -890,7 +888,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) ev_timer_again(EV_A_ & server->recv_ctx->watcher); - ssize_t r = recv(remote->fd, server->buf, BUF_SIZE, 0); + ssize_t r = recv(remote->fd, server->buf->array, BUF_SIZE, 0); if (r == 0) { // connection closed @@ -915,22 +913,23 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) rx += r; - server->buf = ss_encrypt(BUF_SIZE, server->buf, &r, server->e_ctx); - if (server->buf == NULL) { + server->buf->len = r; + int err = ss_encrypt(server->buf, server->e_ctx); + + if (err) { LOGE("invalid password or cipher"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } - int s = send(server->fd, server->buf, r, 0); + int s = send(server->fd, server->buf->array, server->buf->len, 0); if (s == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // no data, wait for send - server->buf_len = r; - server->buf_idx = 0; + server->buf->idx = 0; ev_io_stop(EV_A_ & remote_recv_ctx->io); ev_io_start(EV_A_ & server->send_ctx->io); } else { @@ -939,9 +938,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); } return; - } else if (s < r) { - server->buf_len = r - s; - server->buf_idx = s; + } else if (s < server->buf->len) { + server->buf->len -= s; + server->buf->idx = s; ev_io_stop(EV_A_ & remote_recv_ctx->io); ev_io_start(EV_A_ & server->send_ctx->io); return; @@ -950,9 +949,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) static void remote_send_cb(EV_P_ ev_io *w, int revents) { - struct remote_ctx *remote_send_ctx = (struct remote_ctx *)w; - struct remote *remote = remote_send_ctx->remote; - struct server *server = remote->server; + remote_ctx_t *remote_send_ctx = (remote_ctx_t *)w; + remote_t *remote = remote_send_ctx->remote; + server_t *server = remote->server; if (server == NULL) { LOGE("invalid server"); @@ -972,7 +971,7 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) } remote_send_ctx->connected = 1; - if (remote->buf_len == 0) { + if (remote->buf->len == 0) { server->stage = 5; ev_io_stop(EV_A_ & remote_send_ctx->io); ev_io_start(EV_A_ & server->recv_ctx->io); @@ -989,7 +988,7 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) } } - if (remote->buf_len == 0) { + if (remote->buf->len == 0) { // close and free if (verbose) { LOGI("remote_send close the connection"); @@ -999,8 +998,8 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) return; } else { // has data to send - ssize_t s = send(remote->fd, remote->buf + remote->buf_idx, - remote->buf_len, 0); + ssize_t s = send(remote->fd, remote->buf->array + remote->buf->idx, + remote->buf->len, 0); if (s == -1) { if (errno != EAGAIN && errno != EWOULDBLOCK) { ERROR("remote_send_send"); @@ -1009,15 +1008,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); } return; - } else if (s < remote->buf_len) { + } else if (s < remote->buf->len) { // partly sent, move memory, wait for the next time to send - remote->buf_len -= s; - remote->buf_idx += s; + remote->buf->len -= s; + remote->buf->idx += s; return; } else { // all sent out, wait for reading - remote->buf_len = 0; - remote->buf_idx = 0; + remote->buf->len = 0; + remote->buf->idx = 0; ev_io_stop(EV_A_ & remote_send_ctx->io); if (server != NULL) { ev_io_start(EV_A_ & server->recv_ctx->io); @@ -1035,17 +1034,20 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) } } -static struct remote * new_remote(int fd) +static remote_t * new_remote(int fd) { if (verbose) { remote_conn++; } - struct remote *remote; - remote = malloc(sizeof(struct remote)); - remote->buf = malloc(BUF_SIZE); - remote->recv_ctx = malloc(sizeof(struct remote_ctx)); - remote->send_ctx = malloc(sizeof(struct remote_ctx)); + remote_t *remote; + remote = malloc(sizeof(remote_t)); + remote->recv_ctx = malloc(sizeof(remote_ctx_t)); + remote->send_ctx = malloc(sizeof(remote_ctx_t)); + + remote->buf = malloc(sizeof(buffer_t)); + balloc(remote->buf, BUF_SIZE); + remote->fd = fd; ev_io_init(&remote->recv_ctx->io, remote_recv_cb, fd, EV_READ); ev_io_init(&remote->send_ctx->io, remote_send_cb, fd, EV_WRITE); @@ -1053,18 +1055,17 @@ static struct remote * new_remote(int fd) remote->recv_ctx->connected = 0; remote->send_ctx->remote = remote; remote->send_ctx->connected = 0; - remote->buf_len = 0; - remote->buf_idx = 0; remote->server = NULL; return remote; } -static void free_remote(struct remote *remote) +static void free_remote(remote_t *remote) { if (remote->server != NULL) { remote->server->remote = NULL; } if (remote->buf != NULL) { + bfree(remote->buf); free(remote->buf); } free(remote->recv_ctx); @@ -1072,7 +1073,7 @@ static void free_remote(struct remote *remote) free(remote); } -static void close_and_free_remote(EV_P_ struct remote *remote) +static void close_and_free_remote(EV_P_ remote_t *remote) { if (remote != NULL) { ev_io_stop(EV_A_ & remote->send_ctx->io); @@ -1086,20 +1087,22 @@ static void close_and_free_remote(EV_P_ struct remote *remote) } } -static struct server * new_server(int fd, struct listen_ctx *listener) +static server_t * new_server(int fd, listen_ctx_t *listener) { if (verbose) { server_conn++; } - struct server *server; - server = malloc(sizeof(struct server)); + server_t *server; + server = malloc(sizeof(server_t)); + + memset(server, 0, sizeof(server_t)); - memset(server, 0, sizeof(struct server)); + server->buf = malloc(sizeof(buffer_t)); + balloc(server->buf, BUF_SIZE); - server->buf = malloc(BUF_SIZE); - server->recv_ctx = malloc(sizeof(struct server_ctx)); - server->send_ctx = malloc(sizeof(struct server_ctx)); + server->recv_ctx = malloc(sizeof(server_ctx_t)); + server->send_ctx = malloc(sizeof(server_ctx_t)); server->fd = fd; ev_io_init(&server->recv_ctx->io, server_recv_cb, fd, EV_READ); ev_io_init(&server->send_ctx->io, server_send_cb, fd, EV_WRITE); @@ -1113,34 +1116,32 @@ static struct server * new_server(int fd, struct listen_ctx *listener) server->query = NULL; server->listen_ctx = listener; if (listener->method) { - server->e_ctx = malloc(sizeof(struct enc_ctx)); - server->d_ctx = malloc(sizeof(struct enc_ctx)); + server->e_ctx = malloc(sizeof(enc_ctx_t)); + server->d_ctx = malloc(sizeof(enc_ctx_t)); enc_ctx_init(listener->method, server->e_ctx, 1); enc_ctx_init(listener->method, server->d_ctx, 0); } else { server->e_ctx = NULL; server->d_ctx = NULL; } - server->buf_len = 0; - server->buf_idx = 0; server->remote = NULL; - server->chunk = (struct chunk *)malloc(sizeof(struct chunk)); - memset(server->chunk, 0, sizeof(struct chunk)); + server->chunk = (chunk_t *)malloc(sizeof(chunk_t)); + memset(server->chunk, 0, sizeof(chunk_t)); + server->chunk->buf = malloc(sizeof(buffer_t)); cork_dllist_add(&connections, &server->entries); return server; } -static void free_server(struct server *server) +static void free_server(server_t *server) { cork_dllist_remove(&server->entries); if (server->chunk != NULL) { - if (server->chunk->buf != NULL) { - free(server->chunk->buf); - } + bfree(server->chunk->buf); + free(server->chunk->buf); free(server->chunk); server->chunk = NULL; } @@ -1156,14 +1157,16 @@ static void free_server(struct server *server) free(server->d_ctx); } if (server->buf != NULL) { + bfree(server->buf); free(server->buf); } + free(server->recv_ctx); free(server->send_ctx); free(server); } -static void close_and_free_server(EV_P_ struct server *server) +static void close_and_free_server(EV_P_ server_t *server) { if (server != NULL) { if (server->query != NULL) { @@ -1195,7 +1198,7 @@ static void signal_cb(EV_P_ ev_signal *w, int revents) static void accept_cb(EV_P_ ev_io *w, int revents) { - struct listen_ctx *listener = (struct listen_ctx *)w; + listen_ctx_t *listener = (listen_ctx_t *)w; int serverfd = accept(listener->fd, NULL, NULL); if (serverfd == -1) { ERROR("accept"); @@ -1213,7 +1216,7 @@ static void accept_cb(EV_P_ ev_io *w, int revents) LOGI("accept a connection"); } - struct server *server = new_server(serverfd, listener); + server_t *server = new_server(serverfd, listener); ev_io_start(EV_A_ & server->recv_ctx->io); ev_timer_start(EV_A_ & server->recv_ctx->watcher); } @@ -1446,7 +1449,7 @@ int main(int argc, char **argv) } // inilitialize listen context - struct listen_ctx listen_ctx_list[server_num]; + listen_ctx_t listen_ctx_list[server_num]; // bind to each interface while (server_num > 0) { @@ -1464,7 +1467,7 @@ int main(int argc, char **argv) FATAL("listen() error"); } setnonblocking(listenfd); - struct listen_ctx *listen_ctx = &listen_ctx_list[index]; + listen_ctx_t *listen_ctx = &listen_ctx_list[index]; // Setup proxy context listen_ctx->timeout = atoi(timeout); @@ -1521,7 +1524,7 @@ int main(int argc, char **argv) // Clean up for (int i = 0; i <= server_num; i++) { - struct listen_ctx *listen_ctx = &listen_ctx_list[i]; + listen_ctx_t *listen_ctx = &listen_ctx_list[i]; if (mode != UDP_ONLY) { ev_io_stop(loop, &listen_ctx->io); close(listen_ctx->fd); diff --git a/src/server.h b/src/server.h index a2fcd64c..3febdf5f 100644 --- a/src/server.h +++ b/src/server.h @@ -33,28 +33,26 @@ #include "common.h" -struct listen_ctx { +typedef struct listen_ctx { ev_io io; int fd; int timeout; int method; char *iface; struct ev_loop *loop; -}; +} listen_ctx_t; -struct server_ctx { +typedef struct server_ctx { ev_io io; ev_timer watcher; int connected; struct server *server; -}; +} server_ctx_t; -struct server { +typedef struct server { int fd; int stage; - ssize_t buf_len; - ssize_t buf_idx; - char *buf; // server send from, remote recv into + buffer_t *buf; int auth; struct chunk *chunk; @@ -69,22 +67,20 @@ struct server { struct ResolvQuery *query; struct cork_dllist_item entries; -}; +} server_t; -struct remote_ctx { +typedef struct remote_ctx { ev_io io; int connected; struct remote *remote; -}; +} remote_ctx_t; -struct remote { +typedef struct remote { int fd; - ssize_t buf_len; - ssize_t buf_idx; - char *buf; // remote send from, server recv into + buffer_t *buf; struct remote_ctx *recv_ctx; struct remote_ctx *send_ctx; struct server *server; -}; +} remote_t; #endif // _SERVER_H diff --git a/src/tunnel.c b/src/tunnel.c index 3e027005..21106435 100644 --- a/src/tunnel.c +++ b/src/tunnel.c @@ -77,13 +77,13 @@ static void server_send_cb(EV_P_ ev_io *w, int revents); static void remote_recv_cb(EV_P_ ev_io *w, int revents); static void remote_send_cb(EV_P_ ev_io *w, int revents); -static struct remote * new_remote(int fd, int timeout); -static struct server * new_server(int fd, int method); +static remote_t * new_remote(int fd, int timeout); +static server_t * new_server(int fd, int method); -static void free_remote(struct remote *remote); -static void close_and_free_remote(EV_P_ struct remote *remote); -static void free_server(struct server *server); -static void close_and_free_server(EV_P_ struct server *server); +static void free_remote(remote_t *remote); +static void close_and_free_remote(EV_P_ remote_t *remote); +static void free_server(server_t *server); +static void close_and_free_server(EV_P_ server_t *server); #ifdef ANDROID int vpn = 0; @@ -167,16 +167,16 @@ int create_and_bind(const char *addr, const char *port) static void server_recv_cb(EV_P_ ev_io *w, int revents) { - struct server_ctx *server_recv_ctx = (struct server_ctx *)w; - struct server *server = server_recv_ctx->server; - struct remote *remote = server->remote; + server_ctx_t *server_recv_ctx = (server_ctx_t *)w; + server_t *server = server_recv_ctx->server; + remote_t *remote = server->remote; if (remote == NULL) { close_and_free_server(EV_A_ server); return; } - ssize_t r = recv(server->fd, remote->buf, BUF_SIZE, 0); + ssize_t r = recv(server->fd, remote->buf->array, BUF_SIZE, 0); if (r == 0) { // connection closed @@ -196,26 +196,27 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } } + remote->buf->len = r; + if (auth) { - remote->buf = ss_gen_hash(remote->buf, &r, &remote->counter, server->e_ctx, BUF_SIZE); + ss_gen_hash(remote->buf, &remote->counter, server->e_ctx); } - remote->buf = ss_encrypt(BUF_SIZE, remote->buf, &r, server->e_ctx); + int err = ss_encrypt(remote->buf, server->e_ctx); - if (remote->buf == NULL) { + if (err) { LOGE("invalid password or cipher"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } - int s = send(remote->fd, remote->buf, r, 0); + int s = send(remote->fd, remote->buf->array, remote->buf->len, 0); if (s == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // no data, wait for send - remote->buf_len = r; - remote->buf_idx = 0; + remote->buf->idx = 0; ev_io_stop(EV_A_ & server_recv_ctx->io); ev_io_start(EV_A_ & remote->send_ctx->io); return; @@ -225,9 +226,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); return; } - } else if (s < r) { - remote->buf_len = r - s; - remote->buf_idx = s; + } else if (s < remote->buf->len) { + remote->buf->len -= s; + remote->buf->idx = s; ev_io_stop(EV_A_ & server_recv_ctx->io); ev_io_start(EV_A_ & remote->send_ctx->io); return; @@ -236,18 +237,18 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) static void server_send_cb(EV_P_ ev_io *w, int revents) { - struct server_ctx *server_send_ctx = (struct server_ctx *)w; - struct server *server = server_send_ctx->server; - struct remote *remote = server->remote; - if (server->buf_len == 0) { + server_ctx_t *server_send_ctx = (server_ctx_t *)w; + server_t *server = server_send_ctx->server; + remote_t *remote = server->remote; + if (server->buf->len == 0) { // close and free close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } else { // has data to send - ssize_t s = send(server->fd, server->buf + server->buf_idx, - server->buf_len, 0); + ssize_t s = send(server->fd, server->buf->array + server->buf->idx, + server->buf->len, 0); if (s < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK) { ERROR("send"); @@ -255,15 +256,15 @@ static void server_send_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); } return; - } else if (s < server->buf_len) { + } else if (s < server->buf->len) { // partly sent, move memory, wait for the next time to send - server->buf_len -= s; - server->buf_idx += s; + server->buf->len -= s; + server->buf->idx += s; return; } else { // all sent out, wait for reading - server->buf_len = 0; - server->buf_idx = 0; + server->buf->len = 0; + server->buf->idx = 0; ev_io_stop(EV_A_ & server_send_ctx->io); if (remote != NULL) { ev_io_start(EV_A_ & remote->recv_ctx->io); @@ -279,10 +280,10 @@ static void server_send_cb(EV_P_ ev_io *w, int revents) static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) { - struct remote_ctx *remote_ctx = (struct remote_ctx *)(((void *)watcher) + remote_ctx_t *remote_ctx = (remote_ctx_t *)(((void *)watcher) - sizeof(ev_io)); - struct remote *remote = remote_ctx->remote; - struct server *server = remote->server; + remote_t *remote = remote_ctx->remote; + server_t *server = remote->server; if (verbose) { LOGI("TCP connection timeout"); @@ -296,11 +297,11 @@ static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) static void remote_recv_cb(EV_P_ ev_io *w, int revents) { - struct remote_ctx *remote_recv_ctx = (struct remote_ctx *)w; - struct remote *remote = remote_recv_ctx->remote; - struct server *server = remote->server; + remote_ctx_t *remote_recv_ctx = (remote_ctx_t *)w; + remote_t *remote = remote_recv_ctx->remote; + server_t *server = remote->server; - ssize_t r = recv(remote->fd, server->buf, BUF_SIZE, 0); + ssize_t r = recv(remote->fd, server->buf->array, BUF_SIZE, 0); if (r == 0) { // connection closed @@ -320,22 +321,23 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) } } - server->buf = ss_decrypt(BUF_SIZE, server->buf, &r, server->d_ctx); + server->buf->len = r; + + int err = ss_decrypt(server->buf, server->d_ctx); - if (server->buf == NULL) { + if (err) { LOGE("invalid password or cipher"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } - int s = send(server->fd, server->buf, r, 0); + int s = send(server->fd, server->buf->array, server->buf->len, 0); if (s == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // no data, wait for send - server->buf_len = r; - server->buf_idx = 0; + server->buf->idx = 0; ev_io_stop(EV_A_ & remote_recv_ctx->io); ev_io_start(EV_A_ & server->send_ctx->io); return; @@ -345,9 +347,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); return; } - } else if (s < r) { - server->buf_len = r - s; - server->buf_idx = s; + } else if (s < server->buf->len) { + server->buf->len -= s; + server->buf->idx = s; ev_io_stop(EV_A_ & remote_recv_ctx->io); ev_io_start(EV_A_ & server->send_ctx->io); return; @@ -356,9 +358,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) static void remote_send_cb(EV_P_ ev_io *w, int revents) { - struct remote_ctx *remote_send_ctx = (struct remote_ctx *)w; - struct remote *remote = remote_send_ctx->remote; - struct server *server = remote->server; + remote_ctx_t *remote_send_ctx = (remote_ctx_t *)w; + remote_t *remote = remote_send_ctx->remote; + server_t *server = remote->server; if (!remote_send_ctx->connected) { struct sockaddr_storage addr; @@ -369,8 +371,10 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) remote_send_ctx->connected = 1; ev_io_stop(EV_A_ & remote_send_ctx->io); ev_timer_stop(EV_A_ & remote_send_ctx->watcher); - char *ss_addr_to_send = malloc(BUF_SIZE); - ssize_t addr_len = 0; + + buffer_t ss_addr_to_send; + buffer_t *abuf = &ss_addr_to_send; + balloc(abuf, BUF_SIZE); ss_addr_t *sa = &server->destaddr; struct cork_ip ip; @@ -383,9 +387,9 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) if (dns_pton(AF_INET, sa->host, &host) == -1) { FATAL("IP parser error"); } - ss_addr_to_send[addr_len++] = 1; - memcpy(ss_addr_to_send + addr_len, &host, host_len); - addr_len += host_len; + abuf->array[abuf->len++] = 1; + memcpy(abuf->array + abuf->len, &host, host_len); + abuf->len += host_len; } else if (ip.version == 6) { // send as IPv6 struct in6_addr host; @@ -394,9 +398,9 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) if (dns_pton(AF_INET6, sa->host, &host) == -1) { FATAL("IP parser error"); } - ss_addr_to_send[addr_len++] = 4; - memcpy(ss_addr_to_send + addr_len, &host, host_len); - addr_len += host_len; + abuf->array[abuf->len++] = 4; + memcpy(abuf->array + abuf->len, &host, host_len); + abuf->len += host_len; } else { FATAL("IP parser error"); } @@ -404,35 +408,35 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) // send as domain int host_len = strlen(sa->host); - ss_addr_to_send[addr_len++] = 3; - ss_addr_to_send[addr_len++] = host_len; - memcpy(ss_addr_to_send + addr_len, sa->host, host_len); - addr_len += host_len; + abuf->array[abuf->len++] = 3; + abuf->array[abuf->len++] = host_len; + memcpy(abuf->array + abuf->len, sa->host, host_len); + abuf->len += host_len; } uint16_t port = htons(atoi(sa->port)); - memcpy(ss_addr_to_send + addr_len, &port, 2); - addr_len += 2; + memcpy(abuf->array + abuf->len, &port, 2); + abuf->len += 2; if (auth) { - ss_addr_to_send[0] |= ONETIMEAUTH_FLAG; - ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx->evp.iv); - addr_len += ONETIMEAUTH_BYTES; + abuf->array[0] |= ONETIMEAUTH_FLAG; + ss_onetimeauth(abuf, server->e_ctx->evp.iv); } - ss_addr_to_send = ss_encrypt(BUF_SIZE, ss_addr_to_send, &addr_len, - server->e_ctx); - if (ss_addr_to_send == NULL) { + int err = ss_encrypt(abuf, server->e_ctx); + if (err) { + bfree(abuf); LOGE("invalid password or cipher"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } - int s = send(remote->fd, ss_addr_to_send, addr_len, 0); - free(ss_addr_to_send); + int s = send(remote->fd, abuf->array, abuf->len, 0); + + bfree(abuf); - if (s < addr_len) { + if (s < abuf->len) { LOGE("failed to send addr"); close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); @@ -451,15 +455,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) return; } } else { - if (remote->buf_len == 0) { + if (remote->buf->len == 0) { // close and free close_and_free_remote(EV_A_ remote); close_and_free_server(EV_A_ server); return; } else { // has data to send - ssize_t s = send(remote->fd, remote->buf + remote->buf_idx, - remote->buf_len, 0); + ssize_t s = send(remote->fd, remote->buf->array + remote->buf->idx, + remote->buf->len, 0); if (s < 0) { if (errno != EAGAIN && errno != EWOULDBLOCK) { ERROR("send"); @@ -468,15 +472,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) close_and_free_server(EV_A_ server); } return; - } else if (s < remote->buf_len) { + } else if (s < remote->buf->len) { // partly sent, move memory, wait for the next time to send - remote->buf_len -= s; - remote->buf_idx += s; + remote->buf->len -= s; + remote->buf->idx += s; return; } else { // all sent out, wait for reading - remote->buf_len = 0; - remote->buf_idx = 0; + remote->buf->len = 0; + remote->buf->idx = 0; ev_io_stop(EV_A_ & remote_send_ctx->io); ev_io_start(EV_A_ & server->recv_ctx->io); } @@ -485,16 +489,16 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents) } } -static struct remote * new_remote(int fd, int timeout) +static remote_t * new_remote(int fd, int timeout) { - struct remote *remote; - remote = malloc(sizeof(struct remote)); + remote_t *remote; + remote = malloc(sizeof(remote_t)); - memset(remote, 0, sizeof(struct remote)); + memset(remote, 0, sizeof(remote_t)); remote->buf = malloc(BUF_SIZE); - remote->recv_ctx = malloc(sizeof(struct remote_ctx)); - remote->send_ctx = malloc(sizeof(struct remote_ctx)); + remote->recv_ctx = malloc(sizeof(remote_ctx_t)); + remote->send_ctx = malloc(sizeof(remote_ctx_t)); remote->fd = fd; ev_io_init(&remote->recv_ctx->io, remote_recv_cb, fd, EV_READ); ev_io_init(&remote->send_ctx->io, remote_send_cb, fd, EV_WRITE); @@ -504,18 +508,19 @@ static struct remote * new_remote(int fd, int timeout) remote->recv_ctx->connected = 0; remote->send_ctx->remote = remote; remote->send_ctx->connected = 0; - remote->buf_len = 0; - remote->buf_idx = 0; + remote->buf->len = 0; + remote->buf->idx = 0; return remote; } -static void free_remote(struct remote *remote) +static void free_remote(remote_t *remote) { if (remote != NULL) { if (remote->server != NULL) { remote->server->remote = NULL; } if (remote->buf) { + bfree(remote->buf); free(remote->buf); } free(remote->recv_ctx); @@ -524,7 +529,7 @@ static void free_remote(struct remote *remote) } } -static void close_and_free_remote(EV_P_ struct remote *remote) +static void close_and_free_remote(EV_P_ remote_t *remote) { if (remote != NULL) { ev_timer_stop(EV_A_ & remote->send_ctx->watcher); @@ -535,13 +540,13 @@ static void close_and_free_remote(EV_P_ struct remote *remote) } } -static struct server * new_server(int fd, int method) +static server_t * new_server(int fd, int method) { - struct server *server; - server = malloc(sizeof(struct server)); + server_t *server; + server = malloc(sizeof(server_t)); server->buf = malloc(BUF_SIZE); - server->recv_ctx = malloc(sizeof(struct server_ctx)); - server->send_ctx = malloc(sizeof(struct server_ctx)); + server->recv_ctx = malloc(sizeof(server_ctx_t)); + server->send_ctx = malloc(sizeof(server_ctx_t)); server->fd = fd; ev_io_init(&server->recv_ctx->io, server_recv_cb, fd, EV_READ); ev_io_init(&server->send_ctx->io, server_send_cb, fd, EV_WRITE); @@ -558,12 +563,12 @@ static struct server * new_server(int fd, int method) server->e_ctx = NULL; server->d_ctx = NULL; } - server->buf_len = 0; - server->buf_idx = 0; + server->buf->len = 0; + server->buf->idx = 0; return server; } -static void free_server(struct server *server) +static void free_server(server_t *server) { if (server != NULL) { if (server->remote != NULL) { @@ -578,6 +583,7 @@ static void free_server(struct server *server) free(server->d_ctx); } if (server->buf) { + bfree(server->buf); free(server->buf); } free(server->recv_ctx); @@ -586,7 +592,7 @@ static void free_server(struct server *server) } } -static void close_and_free_server(EV_P_ struct server *server) +static void close_and_free_server(EV_P_ server_t *server) { if (server != NULL) { ev_io_stop(EV_A_ & server->send_ctx->io); @@ -643,8 +649,8 @@ static void accept_cb(EV_P_ ev_io *w, int revents) } #endif - struct server *server = new_server(serverfd, listener->method); - struct remote *remote = new_remote(remotefd, listener->timeout); + server_t *server = new_server(serverfd, listener->method); + remote_t *remote = new_remote(remotefd, listener->timeout); server->destaddr = listener->tunnel_addr; server->remote = remote; remote->server = server; diff --git a/src/tunnel.h b/src/tunnel.h index 16f2b0b3..c33277fd 100644 --- a/src/tunnel.h +++ b/src/tunnel.h @@ -29,7 +29,7 @@ #include "common.h" -struct listen_ctx { +typedef struct listen_ctx { ev_io io; ss_addr_t tunnel_addr; char *iface; @@ -38,43 +38,39 @@ struct listen_ctx { int timeout; int fd; struct sockaddr **remote_addr; -}; +} listen_ctx_t; -struct server_ctx { +typedef struct server_ctx { ev_io io; int connected; struct server *server; -}; +} server_ctx_t; -struct server { +typedef struct server { int fd; - ssize_t buf_len; - ssize_t buf_idx; - char *buf; // server send from, remote recv into + buffer_t *buf; struct enc_ctx *e_ctx; struct enc_ctx *d_ctx; struct server_ctx *recv_ctx; struct server_ctx *send_ctx; struct remote *remote; ss_addr_t destaddr; -}; +} server_t; -struct remote_ctx { +typedef struct remote_ctx { ev_io io; ev_timer watcher; int connected; struct remote *remote; -}; +} remote_ctx_t; -struct remote { +typedef struct remote { int fd; - ssize_t buf_len; - ssize_t buf_idx; - char *buf; // remote send from, server recv into + buffer_t *buf; struct remote_ctx *recv_ctx; struct remote_ctx *send_ctx; struct server *server; uint32_t counter; -}; +} remote_t; #endif // _TUNNEL_H diff --git a/src/udprelay.c b/src/udprelay.c index da3b8f15..c25dc7e6 100644 --- a/src/udprelay.c +++ b/src/udprelay.c @@ -91,8 +91,8 @@ static char *hash_key(const int af, const struct sockaddr_storage *addr); #ifdef UDPRELAY_REMOTE static void query_resolve_cb(struct sockaddr *addr, void *data); #endif -static void close_and_free_remote(EV_P_ struct remote_ctx *ctx); -static struct remote_ctx * new_remote(int fd, struct server_ctx * server_ctx); +static void close_and_free_remote(EV_P_ remote_ctx_t *ctx); +static remote_ctx_t * new_remote(int fd, server_ctx_t * server_ctx); extern int verbose; extern int vpn; @@ -102,7 +102,7 @@ extern uint64_t rx; #endif static int server_num = 0; -static struct server_ctx *server_ctx_list[MAX_REMOTE_NUM] = { NULL }; +static server_ctx_t *server_ctx_list[MAX_REMOTE_NUM] = { NULL }; #ifndef __MINGW32__ static int setnonblocking(int fd) @@ -467,10 +467,10 @@ int create_server_socket(const char *host, const char *port) return server_sock; } -struct remote_ctx *new_remote(int fd, struct server_ctx *server_ctx) +remote_ctx_t *new_remote(int fd, server_ctx_t *server_ctx) { - struct remote_ctx *ctx = malloc(sizeof(struct remote_ctx)); - memset(ctx, 0, sizeof(struct remote_ctx)); + remote_ctx_t *ctx = malloc(sizeof(remote_ctx_t)); + memset(ctx, 0, sizeof(remote_ctx_t)); ctx->fd = fd; ctx->server_ctx = server_ctx; @@ -480,23 +480,23 @@ struct remote_ctx *new_remote(int fd, struct server_ctx *server_ctx) return ctx; } -struct server_ctx * new_server_ctx(int fd) +server_ctx_t * new_server_ctx(int fd) { - struct server_ctx *ctx = malloc(sizeof(struct server_ctx)); - memset(ctx, 0, sizeof(struct server_ctx)); + server_ctx_t *ctx = malloc(sizeof(server_ctx_t)); + memset(ctx, 0, sizeof(server_ctx_t)); ctx->fd = fd; ev_io_init(&ctx->io, server_recv_cb, fd, EV_READ); return ctx; } #ifdef UDPRELAY_REMOTE -struct query_ctx *new_query_ctx(const char *buf, const int buf_len) +struct query_ctx *new_query_ctx(char *buf, size_t len) { struct query_ctx *ctx = malloc(sizeof(struct query_ctx)); memset(ctx, 0, sizeof(struct query_ctx)); - ctx->buf = malloc(buf_len); - ctx->buf_len = buf_len; - memcpy(ctx->buf, buf, buf_len); + ctx->buf = malloc(sizeof(buffer_t)); + balloc(ctx->buf, len); + memcpy(ctx->buf->array, buf, len); return ctx; } @@ -508,6 +508,7 @@ void close_and_free_query(EV_P_ struct query_ctx *ctx) ctx->query = NULL; } if (ctx->buf != NULL) { + bfree(ctx->buf); free(ctx->buf); } free(ctx); @@ -516,7 +517,7 @@ void close_and_free_query(EV_P_ struct query_ctx *ctx) #endif -void close_and_free_remote(EV_P_ struct remote_ctx *ctx) +void close_and_free_remote(EV_P_ remote_ctx_t *ctx) { if (ctx != NULL) { ev_timer_stop(EV_A_ & ctx->watcher); @@ -528,7 +529,7 @@ void close_and_free_remote(EV_P_ struct remote_ctx *ctx) static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) { - struct remote_ctx *remote_ctx = (struct remote_ctx *)(((void *)watcher) + remote_ctx_t *remote_ctx = (remote_ctx_t *)(((void *)watcher) - sizeof(ev_io)); if (verbose) { @@ -554,7 +555,7 @@ static void query_resolve_cb(struct sockaddr *addr, void *data) if (addr == NULL) { LOGE("[udp] udns returned an error"); } else { - struct remote_ctx *remote_ctx = query_ctx->remote_ctx; + remote_ctx_t *remote_ctx = query_ctx->remote_ctx; int cache_hit = 0; // Lookup in the conn cache @@ -593,7 +594,7 @@ static void query_resolve_cb(struct sockaddr *addr, void *data) if (remote_ctx != NULL) { size_t addr_len = get_sockaddr_len(addr); - int s = sendto(remote_ctx->fd, query_ctx->buf, query_ctx->buf_len, + int s = sendto(remote_ctx->fd, query_ctx->buf->array, query_ctx->buf->len, 0, addr, addr_len); if (s == -1) { @@ -621,8 +622,8 @@ static void query_resolve_cb(struct sockaddr *addr, void *data) static void remote_recv_cb(EV_P_ ev_io *w, int revents) { - struct remote_ctx *remote_ctx = (struct remote_ctx *)w; - struct server_ctx *server_ctx = remote_ctx->server_ctx; + remote_ctx_t *remote_ctx = (remote_ctx_t *)w; + server_ctx_t *server_ctx = remote_ctx->server_ctx; // server has been closed if (server_ctx == NULL) { @@ -638,12 +639,14 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) struct sockaddr_storage src_addr; socklen_t src_addr_len = sizeof(src_addr); memset(&src_addr, 0, src_addr_len); - char *buf = malloc(BUF_SIZE); + + buffer_t *buf = malloc(sizeof(buffer_t)); + balloc(buf, BUF_SIZE); // recv - ssize_t buf_len = recvfrom(remote_ctx->fd, buf, BUF_SIZE, 0, (struct sockaddr *)&src_addr, &src_addr_len); + buf->len = recvfrom(remote_ctx->fd, buf->array, BUF_SIZE, 0, (struct sockaddr *)&src_addr, &src_addr_len); - if (buf_len == -1) { + if (buf->len == -1) { // error on recv // simply drop that packet ERROR("[udp] remote_recvfrom"); @@ -651,13 +654,13 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) } // packet size > default MTU - if (verbose && buf_len > MTU) { - LOGE("[udp] possible ip fragment, size: %d", (int)buf_len); + if (verbose && buf->len > MTU) { + LOGE("[udp] possible ip fragment, size: %d", (int)buf->len); } #ifdef UDPRELAY_LOCAL - buf = ss_decrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, 0); - if (buf == NULL) { + int err = ss_decrypt_all(buf, server_ctx->method, 0); + if (err) { // drop the packet silently goto CLEAN_UP; } @@ -665,14 +668,14 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) #ifdef UDPRELAY_REDIR struct sockaddr_storage dst_addr; memset(&dst_addr, 0, sizeof(struct sockaddr_storage)); - int len = parse_udprealy_header(buf, buf_len, NULL, NULL, NULL, &dst_addr); + int len = parse_udprealy_header(buf->array, buf->len, NULL, NULL, NULL, &dst_addr); if (dst_addr.ss_family != AF_INET && dst_addr.ss_family != AF_INET6) { LOGI("[udp] ss-redir does not support domain name"); goto CLEAN_UP; } #else - int len = parse_udprealy_header(buf, buf_len, NULL, NULL, NULL, NULL); + int len = parse_udprealy_header(buf->array, buf->len, NULL, NULL, NULL, NULL); #endif if (len == 0) { @@ -686,22 +689,20 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) #if defined(UDPRELAY_TUNNEL) || defined(UDPRELAY_REDIR) // Construct packet - buf_len -= len; - memmove(buf, buf + len, buf_len); + buf->len -= len; + memmove(buf->array, buf->array + len, buf->len); #else // Construct packet - if (BUF_SIZE < buf_len + 3) { - buf = realloc(buf, buf_len + 3); - } - memmove(buf + 3, buf, buf_len); - memset(buf, 0, 3); - buf_len += 3; + brealloc(buf, buf->len + 3, BUF_SIZE); + memmove(buf->array + 3, buf->array, buf->len); + memset(buf->array, 0, 3); + buf->len += 3; #endif #endif #ifdef UDPRELAY_REMOTE - rx += buf_len; + rx += buf->len; char addr_header_buf[256]; char *addr_header = remote_ctx->addr_header; @@ -713,14 +714,16 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) } // Construct packet - if (BUF_SIZE < buf_len + addr_header_len) { - buf = realloc(buf, buf_len + addr_header_len); - } - memmove(buf + addr_header_len, buf, buf_len); - memcpy(buf, addr_header, addr_header_len); - buf_len += addr_header_len; + brealloc(buf, buf->len + addr_header_len, BUF_SIZE); + memmove(buf->array + addr_header_len, buf->array, buf->len); + memcpy(buf->array, addr_header, addr_header_len); + buf->len += addr_header_len; - buf = ss_encrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, 0); + int err = ss_encrypt_all(buf, server_ctx->method, 0); + if (err) { + // drop the packet silently + goto CLEAN_UP; + } #endif size_t remote_src_addr_len = get_sockaddr_len((struct sockaddr *)&remote_ctx->src_addr); @@ -750,7 +753,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) goto CLEAN_UP; } - int s = sendto(src_fd, buf, buf_len, 0, + int s = sendto(src_fd, buf->array, buf->len, 0, (struct sockaddr *)&remote_ctx->src_addr, remote_src_addr_len); if (s == -1) { ERROR("[udp] remote_recv_sendto"); @@ -760,7 +763,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) close(src_fd); #else - int s = sendto(server_ctx->fd, buf, buf_len, 0, + int s = sendto(server_ctx->fd, buf->array, buf->len, 0, (struct sockaddr *)&remote_ctx->src_addr, remote_src_addr_len); if (s == -1) { ERROR("[udp] remote_recv_sendto"); @@ -774,16 +777,18 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents) CLEAN_UP: + bfree(buf); free(buf); - } static void server_recv_cb(EV_P_ ev_io *w, int revents) { - struct server_ctx *server_ctx = (struct server_ctx *)w; + server_ctx_t *server_ctx = (server_ctx_t *)w; struct sockaddr_storage src_addr; memset(&src_addr, 0, sizeof(struct sockaddr_storage)); - char *buf = malloc(BUF_SIZE); + + buffer_t *buf = malloc(sizeof(buffer_t)); + balloc(buf, BUF_SIZE); socklen_t src_addr_len = sizeof(struct sockaddr_storage); unsigned int offset = 0; @@ -800,7 +805,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) msg.msg_control = control_buffer; msg.msg_controllen = sizeof(control_buffer); - iov[0].iov_base = buf; + iov[0].iov_base = buf->array; iov[0].iov_len = BUF_SIZE; msg.msg_iov = iov; msg.msg_iovlen = 1; @@ -818,11 +823,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) src_addr_len = msg.msg_namelen; #else - ssize_t buf_len = - recvfrom(server_ctx->fd, buf, BUF_SIZE, 0, (struct sockaddr *)&src_addr, - &src_addr_len); + buf->len = recvfrom(server_ctx->fd, buf->array, BUF_SIZE, + 0, (struct sockaddr *)&src_addr, &src_addr_len); - if (buf_len == -1) { + if (buf->len == -1) { // error on recv // simply drop that packet ERROR("[udp] server_recvfrom"); @@ -836,10 +840,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) #ifdef UDPRELAY_REMOTE - tx += buf_len; + tx += buf->len; - buf = ss_decrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, server_ctx->auth); - if (buf == NULL) { + int err = ss_decrypt_all(buf, server_ctx->method, server_ctx->auth); + if (err) { // drop the packet silently goto CLEAN_UP; } @@ -847,14 +851,14 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) #ifdef UDPRELAY_LOCAL #if !defined(UDPRELAY_TUNNEL) && !defined(UDPRELAY_REDIR) - uint8_t frag = *(uint8_t *)(buf + 2); + uint8_t frag = *(uint8_t *)(buf->array + 2); offset += 3; #endif #endif // packet size > default MTU - if (verbose && buf_len > MTU) { - LOGE("[udp] possible ip fragment, size: %d", (int)buf_len); + if (verbose && buf->len > MTU) { + LOGE("[udp] possible ip fragment, size: %d", (int)buf->len); } /* @@ -911,12 +915,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } // reconstruct the buffer - if (BUF_SIZE < buf_len + addr_header_len) { - buf = realloc(buf, buf_len + addr_header_len); - } - memmove(buf + addr_header_len, buf, buf_len); - memcpy(buf, addr_header, addr_header_len); - buf_len += addr_header_len; + brealloc(buf, buf->len + addr_header_len, BUF_SIZE); + memmove(buf->array + addr_header_len, buf->array, buf->len); + memcpy(buf->array, addr_header, addr_header_len); + buf->len += addr_header_len; char *key = hash_key(dst_addr.ss_family, &src_addr); @@ -969,12 +971,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) addr_header_len += 2; // reconstruct the buffer - if (BUF_SIZE < buf_len + addr_header_len) { - buf = realloc(buf, buf_len + addr_header_len); - } - memmove(buf + addr_header_len, buf, buf_len); - memcpy(buf, addr_header, addr_header_len); - buf_len += addr_header_len; + brealloc(buf, buf->len + addr_header_len, BUF_SIZE); + memmove(buf->array + addr_header_len, buf->array, buf->len); + memcpy(buf->array, addr_header, addr_header_len); + buf->len += addr_header_len; char *key = hash_key(ip.version == 4 ? AF_INET : AF_INET6, &src_addr); @@ -985,21 +985,21 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) struct sockaddr_storage dst_addr; memset(&dst_addr, 0, sizeof(struct sockaddr_storage)); - int addr_header_len = parse_udprealy_header(buf + offset, buf_len - offset, + int addr_header_len = parse_udprealy_header(buf->array + offset, buf->len - offset, &server_ctx->auth, host, port, &dst_addr); if (addr_header_len == 0) { // error in parse header goto CLEAN_UP; } - char *addr_header = buf + offset; + char *addr_header = buf->array + offset; char *key = hash_key(dst_addr.ss_family, &src_addr); #endif struct cache *conn_cache = server_ctx->conn_cache; - struct remote_ctx *remote_ctx = NULL; + remote_ctx_t *remote_ctx = NULL; cache_lookup(conn_cache, key, HASH_KEY_LEN, (void *)&remote_ctx); if (remote_ctx != NULL) { @@ -1098,17 +1098,22 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) } if (offset > 0) { - buf_len -= offset; - memmove(buf, buf + offset, buf_len); + buf->len -= offset; + memmove(buf->array, buf->array + offset, buf->len); } if (server_ctx->auth) { - buf[0] |= ONETIMEAUTH_FLAG; + buf->array[0] |= ONETIMEAUTH_FLAG; } - buf = ss_encrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, server_ctx->auth); + int err = ss_encrypt_all(buf, server_ctx->method, server_ctx->auth); + + if (err) { + // drop the packet silently + goto CLEAN_UP; + } - int s = sendto(remote_ctx->fd, buf, buf_len, 0, remote_addr, remote_addr_len); + int s = sendto(remote_ctx->fd, buf->array, buf->len, 0, remote_addr, remote_addr_len); if (s == -1) { ERROR("[udp] sendto_remote"); @@ -1158,8 +1163,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) if (remote_ctx != NULL && !need_query) { size_t addr_len = get_sockaddr_len((struct sockaddr *)&dst_addr); - int s = sendto(remote_ctx->fd, buf + addr_header_len, - buf_len - addr_header_len, 0, + int s = sendto(remote_ctx->fd, buf->array + addr_header_len, + buf->len - addr_header_len, 0, (struct sockaddr *)&dst_addr, addr_len); if (s == -1) { @@ -1185,9 +1190,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) hints.ai_socktype = SOCK_DGRAM; hints.ai_protocol = IPPROTO_UDP; - struct query_ctx *query_ctx = new_query_ctx(buf + addr_header_len, - buf_len - - addr_header_len); + struct query_ctx *query_ctx = new_query_ctx(buf->array + addr_header_len, + buf->len - addr_header_len); query_ctx->server_ctx = server_ctx; query_ctx->addr_header_len = addr_header_len; query_ctx->src_addr = src_addr; @@ -1209,12 +1213,13 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents) #endif CLEAN_UP: + bfree(buf); free(buf); } void free_cb(void *element) { - struct remote_ctx *remote_ctx = (struct remote_ctx *)element; + remote_ctx_t *remote_ctx = (remote_ctx_t *)element; if (verbose) { LOGI("[udp] one connection freed"); @@ -1249,7 +1254,7 @@ int init_udprelay(const char *server_host, const char *server_port, } setnonblocking(serverfd); - struct server_ctx *server_ctx = new_server_ctx(serverfd); + server_ctx_t *server_ctx = new_server_ctx(serverfd); #ifdef UDPRELAY_REMOTE server_ctx->loop = loop; #endif @@ -1277,7 +1282,7 @@ void free_udprelay() { struct ev_loop *loop = EV_DEFAULT; while (server_num-- > 0) { - struct server_ctx *server_ctx = server_ctx_list[server_num]; + server_ctx_t *server_ctx = server_ctx_list[server_num]; ev_io_stop(loop, &server_ctx->io); close(server_ctx->fd); cache_delete(server_ctx->conn_cache, 0); diff --git a/src/udprelay.h b/src/udprelay.h index 5d5e1b1a..6f8bcc2f 100644 --- a/src/udprelay.h +++ b/src/udprelay.h @@ -41,7 +41,7 @@ #define MTU 1397 // 1492 - 1 - 28 - 2 - 64 = 1397, the default MTU for UDP relay -struct server_ctx { +typedef struct server_ctx { ev_io io; int fd; int method; @@ -59,22 +59,21 @@ struct server_ctx { #ifdef UDPRELAY_REMOTE struct ev_loop *loop; #endif -}; +} server_ctx_t; #ifdef UDPRELAY_REMOTE -struct query_ctx { +typedef struct query_ctx { struct ResolvQuery *query; struct sockaddr_storage src_addr; - int buf_len; - char *buf; // server send from, remote recv into + buffer_t *buf; int addr_header_len; char addr_header[384]; struct server_ctx *server_ctx; struct remote_ctx *remote_ctx; -}; +} query_ctx_t; #endif -struct remote_ctx { +typedef struct remote_ctx { ev_io io; ev_timer watcher; int af; @@ -83,16 +82,6 @@ struct remote_ctx { char addr_header[384]; struct sockaddr_storage src_addr; struct server_ctx *server_ctx; -}; - -#ifdef ANDROID -struct protect_ctx { - int buf_len; - char *buf; - struct sockaddr_storage addr; - int addr_len; - struct remote_ctx *remote_ctx; -}; -#endif +} remote_ctx_t; #endif // _UDPRELAY_H