Browse Source

Refine buffer handling

pull/524/head
Max Lv 9 years ago
parent
commit
f4fa30da5c
12 changed files with 876 additions and 883 deletions
  1. 393
      src/encrypt.c
  2. 44
      src/encrypt.h
  3. 274
      src/local.c
  4. 28
      src/local.h
  5. 225
      src/redir.c
  6. 28
      src/redir.h
  7. 299
      src/server.c
  8. 28
      src/server.h
  9. 206
      src/tunnel.c
  10. 28
      src/tunnel.h
  11. 181
      src/udprelay.c
  12. 25
      src/udprelay.h

393
src/encrypt.c

@ -201,6 +201,35 @@ static const int supported_ciphers_key_size[CIPHER_NUM] =
0, 16, 16, 16, 24, 32, 16, 16, 24, 32, 16, 8, 16, 16, 16, 32, 32
};
int balloc(buffer_t *ptr, size_t capacity)
{
memset(ptr, 0 , sizeof(buffer_t));
ptr->array = malloc(capacity);
ptr->capacity = capacity;
return capacity;
}
int brealloc(buffer_t *ptr, size_t len, size_t capacity)
{
int real_capacity = max(len, capacity);
if (ptr->capacity < real_capacity) {
ptr->array = realloc(ptr->array, real_capacity);
}
ptr->capacity = real_capacity;
return real_capacity;
}
void bfree(buffer_t *ptr)
{
ptr->idx = 0;
ptr->len = 0;
ptr->capacity = 0;
if (ptr->array != NULL) {
free(ptr->array);
ptr->array = NULL;
}
}
static int crypto_stream_xor_ic(uint8_t *c, const uint8_t *m, uint64_t mlen,
const uint8_t *n, uint64_t ic, const uint8_t *k,
int method)
@ -1030,25 +1059,28 @@ static int cipher_context_update(cipher_ctx_t *ctx, uint8_t *output, size_t *ole
#endif
}
int ss_onetimeauth(char *auth, char *msg, int msg_len, uint8_t *iv)
int ss_onetimeauth(buffer_t *buf, uint8_t *iv)
{
uint8_t hash[ONETIMEAUTH_BYTES * 2];
uint8_t auth_key[MAX_IV_LENGTH + MAX_KEY_LENGTH];
memcpy(auth_key, iv, enc_iv_len);
memcpy(auth_key + enc_iv_len, enc_key, enc_key_len);
brealloc(buf, ONETIMEAUTH_BYTES + buf->len, buf->capacity);
#if defined(USE_CRYPTO_OPENSSL)
HMAC(EVP_sha1(), auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, (uint8_t *)hash, NULL);
HMAC(EVP_sha1(), auth_key, enc_iv_len + enc_key_len, (uint8_t *)buf->array, buf->len, (uint8_t *)hash, NULL);
#else
ss_sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, (uint8_t *)hash);
ss_sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)buf->array, buf->len, (uint8_t *)hash);
#endif
memcpy(auth, hash, ONETIMEAUTH_BYTES);
memcpy(buf->array + buf->len, hash, ONETIMEAUTH_BYTES);
buf->len += ONETIMEAUTH_BYTES;
return 0;
}
int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, uint8_t *iv)
int ss_onetimeauth_verify(buffer_t *buf, uint8_t *iv)
{
uint8_t hash[ONETIMEAUTH_BYTES * 2];
uint8_t auth_key[MAX_IV_LENGTH + MAX_KEY_LENGTH];
@ -1056,219 +1088,190 @@ int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, uint8_t *iv)
memcpy(auth_key + enc_iv_len, enc_key, enc_key_len);
#if defined(USE_CRYPTO_OPENSSL)
HMAC(EVP_sha1(), auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, hash, NULL);
HMAC(EVP_sha1(), auth_key, enc_iv_len + enc_key_len, (uint8_t *)buf->array, buf->len, hash, NULL);
#else
ss_sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)msg, msg_len, hash);
ss_sha1_hmac(auth_key, enc_iv_len + enc_key_len, (uint8_t *)buf->array, buf->len, hash);
#endif
return memcmp(auth, hash, ONETIMEAUTH_BYTES);
return memcmp(buf->array + buf->len - ONETIMEAUTH_BYTES, hash, ONETIMEAUTH_BYTES);
}
char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method, int auth)
int ss_encrypt_all(buffer_t *plain, int method, int auth)
{
if (method > TABLE) {
cipher_ctx_t evp;
cipher_context_init(&evp, method, 1);
size_t p_len = *len, c_len = *len;
size_t p_len = plain->len, c_len = plain->len;
size_t iv_len = enc_iv_len;
int err = 1;
static int tmp_len = 0;
static char *tmp_buf = NULL;
int buf_len = max(iv_len + c_len, buf_size);
if (tmp_len < buf_len) {
tmp_len = buf_len;
tmp_buf = realloc(tmp_buf, buf_len);
}
char *ciphertext = tmp_buf;
static buffer_t tmp = {0};
brealloc(&tmp, iv_len + c_len, plain->capacity);
buffer_t *cipher = &tmp;
uint8_t iv[MAX_IV_LENGTH];
rand_bytes(iv, iv_len);
cipher_context_set_iv(&evp, iv, iv_len, 1);
memcpy(ciphertext, iv, iv_len);
memcpy(cipher, iv, iv_len);
if (auth) {
char hash[ONETIMEAUTH_BYTES * 2];
ss_onetimeauth(hash, plaintext, p_len, iv);
if (buf_size < ONETIMEAUTH_BYTES + p_len) {
plaintext = realloc(plaintext, ONETIMEAUTH_BYTES + p_len);
}
memcpy(plaintext + p_len, hash, ONETIMEAUTH_BYTES);
ss_onetimeauth(plain, iv);
p_len = c_len = p_len + ONETIMEAUTH_BYTES;
}
if (method >= SALSA20) {
crypto_stream_xor_ic((uint8_t *)(ciphertext + iv_len),
(const uint8_t *)plaintext, (uint64_t)(p_len),
crypto_stream_xor_ic((uint8_t *)(cipher + iv_len),
(const uint8_t *)plain->array, (uint64_t)(p_len),
(const uint8_t *)iv,
0, enc_key, method);
} else {
err = cipher_context_update(&evp, (uint8_t *)(ciphertext + iv_len),
&c_len, (const uint8_t *)plaintext,
err = cipher_context_update(&evp, (uint8_t *)(cipher + iv_len),
&c_len, (const uint8_t *)plain->array,
p_len);
}
if (!err) {
free(plaintext);
bfree(plain);
cipher_context_release(&evp);
return NULL;
return -1;
}
#ifdef DEBUG
dump("PLAIN", plaintext, *len);
dump("CIPHER", ciphertext + iv_len, c_len);
dump("PLAIN", plain->array, *len);
dump("CIPHER", cipher->array + iv_len, c_len);
#endif
cipher_context_release(&evp);
if (buf_size < iv_len + c_len) {
plaintext = realloc(plaintext, iv_len + c_len);
}
*len = iv_len + c_len;
memcpy(plaintext, ciphertext, *len);
brealloc(plain, iv_len + c_len, plain->capacity);
memcpy(plain, cipher, iv_len + c_len);
plain->len = iv_len + c_len;
return plaintext;
return 0;
} else {
char *begin = plaintext;
while (plaintext < begin + *len) {
*plaintext = (char)enc_table[(uint8_t)*plaintext];
plaintext++;
char *begin = plain->array;
char *ptr = plain->array;
while (ptr < begin + plain->len) {
*ptr = (char)enc_table[(uint8_t)*ptr];
ptr++;
}
return begin;
return 0;
}
}
char * ss_encrypt(int buf_size, char *plaintext, ssize_t *len,
struct enc_ctx *ctx)
int ss_encrypt(buffer_t *plain, enc_ctx_t *ctx)
{
if (ctx != NULL) {
static int tmp_len = 0;
static char *tmp_buf = NULL;
static buffer_t tmp = {0};
int err = 1;
size_t iv_len = 0;
size_t p_len = *len, c_len = *len;
size_t p_len = plain->len, c_len = plain->len;
if (!ctx->init) {
iv_len = enc_iv_len;
}
int buf_len = max(iv_len + c_len, buf_size);
if (tmp_len < buf_len) {
tmp_len = buf_len;
tmp_buf = realloc(tmp_buf, buf_len);
}
char *ciphertext = tmp_buf;
brealloc(&tmp, iv_len + c_len, plain->capacity);
buffer_t *cipher = &tmp;
if (!ctx->init) {
cipher_context_set_iv(&ctx->evp, ctx->evp.iv, iv_len, 1);
memcpy(ciphertext, ctx->evp.iv, iv_len);
memcpy(cipher->array, ctx->evp.iv, iv_len);
ctx->counter = 0;
ctx->init = 1;
}
if (enc_method >= SALSA20) {
int padding = ctx->counter % SODIUM_BLOCK_SIZE;
if (buf_len < iv_len + padding + c_len) {
buf_len = max(iv_len + (padding + c_len) * 2, buf_size);
ciphertext = realloc(ciphertext, buf_len);
tmp_len = buf_len;
tmp_buf = ciphertext;
}
brealloc(cipher, iv_len + (padding + c_len) * 2, cipher->capacity);
if (padding) {
plaintext = realloc(plaintext, max(p_len + padding, buf_size));
memmove(plaintext + padding, plaintext, p_len);
memset(plaintext, 0, padding);
brealloc(plain, p_len + padding, plain->capacity);
memmove(plain->array + padding, plain->array, p_len);
memset(plain->array, 0, padding);
}
crypto_stream_xor_ic((uint8_t *)(ciphertext + iv_len),
(const uint8_t *)plaintext,
crypto_stream_xor_ic((uint8_t *)(cipher->array + iv_len),
(const uint8_t *)plain->array,
(uint64_t)(p_len + padding),
(const uint8_t *)ctx->evp.iv,
ctx->counter / SODIUM_BLOCK_SIZE, enc_key,
enc_method);
ctx->counter += p_len;
if (padding) {
memmove(ciphertext + iv_len, ciphertext + iv_len + padding,
c_len);
memmove(cipher->array + iv_len,
cipher->array + iv_len + padding, c_len);
}
} else {
err =
cipher_context_update(&ctx->evp,
(uint8_t *)(ciphertext + iv_len),
&c_len, (const uint8_t *)plaintext,
(uint8_t *)(cipher->array + iv_len),
&c_len, (const uint8_t *)plain->array,
p_len);
if (!err) {
free(plaintext);
return NULL;
return -1;
}
}
#ifdef DEBUG
dump("PLAIN", plaintext, p_len);
dump("CIPHER", ciphertext + iv_len, c_len);
dump("PLAIN", plain->array, p_len);
dump("CIPHER", cipher->array + iv_len, c_len);
#endif
if (buf_size < iv_len + c_len) {
plaintext = realloc(plaintext, iv_len + c_len);
}
*len = iv_len + c_len;
memcpy(plaintext, ciphertext, *len);
brealloc(plain, iv_len + c_len, plain->capacity);
memcpy(plain->array, cipher->array, iv_len + c_len);
plain->len = iv_len + c_len;
return plaintext;
return 0;
} else {
char *begin = plaintext;
while (plaintext < begin + *len) {
*plaintext = (char)enc_table[(uint8_t)*plaintext];
plaintext++;
char *begin = plain->array;
char *ptr = plain->array;
while (ptr < begin + plain->len) {
*ptr = (char)enc_table[(uint8_t)*ptr];
ptr++;
}
return begin;
return 0;
}
}
char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method, int auth)
int ss_decrypt_all(buffer_t *cipher, int method, int auth)
{
if (method > TABLE) {
size_t iv_len = enc_iv_len;
size_t c_len = *len, p_len = *len - iv_len;
size_t c_len = cipher->len;
size_t p_len = cipher->len - iv_len;
int ret = 1;
if (*len <= iv_len) {
return NULL;
if (c_len <= iv_len) {
return -1;
}
cipher_ctx_t evp;
cipher_context_init(&evp, method, 0);
static int tmp_len = 0;
static char *tmp_buf = NULL;
int buf_len = max(p_len, buf_size);
if (tmp_len < buf_len) {
tmp_len = buf_len;
tmp_buf = realloc(tmp_buf, buf_len);
}
char *plaintext = tmp_buf;
static buffer_t tmp = {0};
brealloc(&tmp, cipher->len, cipher->capacity);
buffer_t *plain = &tmp;
uint8_t iv[MAX_IV_LENGTH];
memcpy(iv, ciphertext, iv_len);
memcpy(iv, cipher, iv_len);
cipher_context_set_iv(&evp, iv, iv_len, 0);
if (method >= SALSA20) {
crypto_stream_xor_ic((uint8_t *)plaintext,
(const uint8_t *)(ciphertext + iv_len),
crypto_stream_xor_ic((uint8_t *)plain->array,
(const uint8_t *)(cipher->array + iv_len),
(uint64_t)(c_len - iv_len),
(const uint8_t *)iv, 0, enc_key, method);
} else {
ret = cipher_context_update(&evp, (uint8_t *)plaintext, &p_len,
(const uint8_t *)(ciphertext + iv_len),
ret = cipher_context_update(&evp, (uint8_t *)plain->array, &p_len,
(const uint8_t *)(cipher->array + iv_len),
c_len - iv_len);
}
if (auth || (plaintext[0] & ONETIMEAUTH_FLAG)) {
if (auth || (plain->array[0] & ONETIMEAUTH_FLAG)) {
if (p_len > ONETIMEAUTH_BYTES) {
char hash[ONETIMEAUTH_BYTES];
memcpy(hash, plaintext + p_len - ONETIMEAUTH_BYTES, ONETIMEAUTH_BYTES);
ret = !ss_onetimeauth_verify(hash, plaintext, p_len - ONETIMEAUTH_BYTES, iv);
ret = !ss_onetimeauth_verify(plain, iv);
if (ret) {
p_len -= ONETIMEAUTH_BYTES;
}
@ -1278,65 +1281,61 @@ char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method,
}
if (!ret) {
free(ciphertext);
bfree(cipher);
cipher_context_release(&evp);
return NULL;
return -1;
}
#ifdef DEBUG
dump("PLAIN", plaintext, p_len);
dump("CIPHER", ciphertext + iv_len, c_len - iv_len);
dump("PLAIN", plain->array, p_len);
dump("CIPHER", cipher->array + iv_len, c_len - iv_len);
#endif
cipher_context_release(&evp);
if (buf_size < p_len) {
ciphertext = realloc(ciphertext, p_len);
}
*len = p_len;
memcpy(ciphertext, plaintext, *len);
brealloc(cipher, p_len, plain->capacity);
memcpy(cipher, plain, p_len);
cipher->len = p_len;
return ciphertext;
return 0;
} else {
char *begin = ciphertext;
while (ciphertext < begin + *len) {
*ciphertext = (char)dec_table[(uint8_t)*ciphertext];
ciphertext++;
char *begin = cipher->array;
char *ptr = cipher->array;
while (ptr < begin + cipher->len) {
*ptr = (char)dec_table[(uint8_t)*ptr];
ptr++;
}
return begin;
return 0;
}
}
char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len, struct enc_ctx *ctx)
int ss_decrypt(buffer_t *cipher, enc_ctx_t *ctx)
{
if (ctx != NULL) {
static int tmp_len = 0;
static char *tmp_buf = NULL;
static buffer_t tmp = {0};
size_t c_len = *len, p_len = *len;
size_t c_len = cipher->len;
size_t p_len = cipher->len;
size_t iv_len = 0;
int err = 1;
int buf_len = max(p_len, buf_size);
if (tmp_len < buf_len) {
tmp_len = buf_len;
tmp_buf = realloc(tmp_buf, buf_len);
}
char *plaintext = tmp_buf;
brealloc(&tmp, p_len, cipher->capacity);
buffer_t *plain = &tmp;
if (!ctx->init) {
uint8_t iv[MAX_IV_LENGTH];
iv_len = enc_iv_len;
p_len -= iv_len;
memcpy(iv, ciphertext, iv_len);
memcpy(iv, cipher->array, iv_len);
cipher_context_set_iv(&ctx->evp, iv, iv_len, 0);
ctx->counter = 0;
ctx->init = 1;
if (enc_method >= RC4_MD5) {
if (cache_key_exist(iv_cache, (char *)iv, iv_len)) {
free(ciphertext);
return NULL;
bfree(cipher);
return -1;
} else {
cache_insert(iv_cache, (char *)iv, iv_len, NULL);
}
@ -1345,64 +1344,59 @@ char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len, struct enc_ctx *
if (enc_method >= SALSA20) {
int padding = ctx->counter % SODIUM_BLOCK_SIZE;
if (buf_len < (p_len + padding) * 2) {
buf_len = max((p_len + padding) * 2, buf_size);
plaintext = realloc(plaintext, buf_len);
tmp_len = buf_len;
tmp_buf = plaintext;
}
brealloc(plain, (p_len + padding) * 2, plain->capacity);
if (padding) {
ciphertext = realloc(ciphertext, max(c_len + padding, buf_size));
memmove(ciphertext + iv_len + padding, ciphertext + iv_len,
brealloc(cipher, c_len + padding, cipher->capacity);
memmove(cipher->array + iv_len + padding, cipher->array + iv_len,
c_len - iv_len);
memset(ciphertext + iv_len, 0, padding);
memset(cipher->array + iv_len, 0, padding);
}
crypto_stream_xor_ic((uint8_t *)plaintext,
(const uint8_t *)(ciphertext + iv_len),
crypto_stream_xor_ic((uint8_t *)plain->array,
(const uint8_t *)(cipher->array + iv_len),
(uint64_t)(c_len - iv_len + padding),
(const uint8_t *)ctx->evp.iv,
ctx->counter / SODIUM_BLOCK_SIZE, enc_key,
enc_method);
ctx->counter += c_len - iv_len;
if (padding) {
memmove(plaintext, plaintext + padding, p_len);
memmove(plain->array, plain->array + padding, p_len);
}
} else {
err = cipher_context_update(&ctx->evp, (uint8_t *)plaintext, &p_len,
(const uint8_t *)(ciphertext + iv_len),
err = cipher_context_update(&ctx->evp, (uint8_t *)plain->array, &p_len,
(const uint8_t *)(cipher + iv_len),
c_len - iv_len);
}
if (!err) {
free(ciphertext);
return NULL;
bfree(cipher);
return -1;
}
#ifdef DEBUG
dump("PLAIN", plaintext, p_len);
dump("CIPHER", ciphertext + iv_len, c_len - iv_len);
dump("PLAIN", plain->array, p_len);
dump("CIPHER", cipher->array + iv_len, c_len - iv_len);
#endif
if (buf_size < p_len) {
ciphertext = realloc(ciphertext, p_len);
}
*len = p_len;
memcpy(ciphertext, plaintext, *len);
brealloc(cipher, p_len, cipher->capacity);
memcpy(cipher->array, plain->array, p_len);
cipher->len = p_len;
return ciphertext;
return 0;
} else {
char *begin = ciphertext;
while (ciphertext < begin + *len) {
*ciphertext = (char)dec_table[(uint8_t)*ciphertext];
ciphertext++;
char *begin = cipher->array;
char *ptr = cipher->array;
while (ptr < begin + cipher->len) {
*ptr = (char)dec_table[(uint8_t)*ptr];
ptr++;
}
return begin;
return 0;
}
}
void enc_ctx_init(int method, struct enc_ctx *ctx, int enc)
void enc_ctx_init(int method, enc_ctx_t *ctx, int enc)
{
memset(ctx, 0, sizeof(struct enc_ctx));
memset(ctx, 0, sizeof(enc_ctx_t));
cipher_context_init(&ctx->evp, method, enc);
if (enc) {
@ -1520,33 +1514,22 @@ int enc_init(const char *pass, const char *method)
return m;
}
int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct enc_ctx *ctx, int buf_size)
int ss_check_hash(buffer_t *buf, chunk_t *chunk, enc_ctx_t *ctx)
{
int i, j, k;
char *buf = *buf_ptr;
ssize_t blen = *buf_len;
ssize_t blen = buf->len;
uint32_t cidx = chunk->idx;
if (chunk->buf == NULL) {
chunk->buf = (char *)malloc(buf_size);
chunk->len = buf_size - AUTH_BYTES;
}
int size = max(chunk->len + blen, buf_size);
if (buf_size < size) {
buf = realloc(buf, size);
}
brealloc(chunk->buf, blen, buf->capacity);
brealloc(buf, chunk->len + blen, buf->capacity);
for (i = 0, j = 0, k = 0; i < blen; i++) {
chunk->buf[cidx++] = buf[k++];
chunk->buf->array[cidx++] = buf->array[k++];
if (cidx == CLEN_BYTES) {
uint16_t clen = ntohs(*((uint16_t *)chunk->buf));
if (buf_size < clen + AUTH_BYTES) {
chunk->buf = realloc(chunk->buf, clen + AUTH_BYTES);
}
uint16_t clen = ntohs(*((uint16_t *)chunk->buf->array));
brealloc(chunk->buf, clen + AUTH_BYTES, buf->capacity);
chunk->len = clen;
}
@ -1560,20 +1543,19 @@ int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct
memcpy(key + enc_iv_len, &c, sizeof(uint32_t));
#if defined(USE_CRYPTO_OPENSSL)
HMAC(EVP_sha1(), key, enc_iv_len + sizeof(uint32_t),
(uint8_t *)chunk->buf + AUTH_BYTES, chunk->len, hash, NULL);
(uint8_t *)chunk->buf->array + AUTH_BYTES, chunk->len, hash, NULL);
#else
ss_sha1_hmac(key, enc_iv_len + sizeof(uint32_t),
(uint8_t *)chunk->buf + AUTH_BYTES, chunk->len, hash);
(uint8_t *)chunk->buf->array + AUTH_BYTES, chunk->len, hash);
#endif
if (memcmp(hash, chunk->buf + CLEN_BYTES, ONETIMEAUTH_BYTES) != 0) {
*buf_ptr = buf;
if (memcmp(hash, chunk->buf->array + CLEN_BYTES, ONETIMEAUTH_BYTES) != 0) {
return 0;
}
// Copy chunk back to buffer
memmove(buf + j + chunk->len, buf + k, blen - i - 1);
memcpy(buf + j, chunk->buf + AUTH_BYTES, chunk->len);
memmove(buf->array + j + chunk->len, buf->array + k, blen - i - 1);
memcpy(buf->array + j, chunk->buf->array + AUTH_BYTES, chunk->len);
// Reset the base offset
j += chunk->len;
@ -1583,39 +1565,34 @@ int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct
}
}
*buf_ptr = buf;
*buf_len = j;
buf->len = j;
chunk->idx = cidx;
return 1;
}
char *ss_gen_hash(char *buf, ssize_t *buf_len, uint32_t *counter, struct enc_ctx *ctx, int buf_size)
int ss_gen_hash(buffer_t *buf, uint32_t *counter, enc_ctx_t *ctx)
{
ssize_t blen = *buf_len;
int size = max(AUTH_BYTES + blen, buf_size);
if (buf_size < size) {
buf = realloc(buf, size);
}
ssize_t blen = buf->len;
uint16_t chunk_len = htons((uint16_t)blen);
uint8_t hash[ONETIMEAUTH_BYTES * 2];
uint8_t key[MAX_IV_LENGTH + sizeof(uint32_t)];
uint32_t c = htonl(*counter);
brealloc(buf, AUTH_BYTES + blen, buf->capacity);
memcpy(key, ctx->evp.iv, enc_iv_len);
memcpy(key + enc_iv_len, &c, sizeof(uint32_t));
#if defined(USE_CRYPTO_OPENSSL)
HMAC(EVP_sha1(), key, enc_iv_len + sizeof(uint32_t), (uint8_t *)buf, blen, hash, NULL);
HMAC(EVP_sha1(), key, enc_iv_len + sizeof(uint32_t), (uint8_t *)buf->array, blen, hash, NULL);
#else
ss_sha1_hmac(key, enc_iv_len + sizeof(uint32_t), (uint8_t *)buf, blen, hash);
ss_sha1_hmac(key, enc_iv_len + sizeof(uint32_t), (uint8_t *)buf->array, blen, hash);
#endif
memmove(buf + AUTH_BYTES, buf, blen);
memcpy(buf + CLEN_BYTES, hash, ONETIMEAUTH_BYTES);
memcpy(buf, &chunk_len, CLEN_BYTES);
memmove(buf->array + AUTH_BYTES, buf->array, blen);
memcpy(buf->array + CLEN_BYTES, hash, ONETIMEAUTH_BYTES);
memcpy(buf->array, &chunk_len, CLEN_BYTES);
*counter = *counter + 1;
*buf_len = blen + AUTH_BYTES;
return buf;
buf->len = blen + AUTH_BYTES;
return 0;
}

44
src/encrypt.h

@ -151,35 +151,45 @@ typedef struct {
#define min(a, b) (((a) < (b)) ? (a) : (b))
#define max(a, b) (((a) > (b)) ? (a) : (b))
struct chunk {
typedef struct buffer {
size_t idx;
size_t len;
size_t capacity;
char *array;
} buffer_t;
typedef struct chunk {
uint32_t idx;
uint32_t len;
uint32_t counter;
char *buf;
};
buffer_t *buf;
} chunk_t;
struct enc_ctx {
typedef struct enc_ctx {
uint8_t init;
uint64_t counter;
cipher_ctx_t evp;
};
char * ss_encrypt_all(int buf_size, char *plaintext, ssize_t *len, int method, int auth);
char * ss_decrypt_all(int buf_size, char *ciphertext, ssize_t *len, int method, int auth);
char * ss_encrypt(int buf_size, char *plaintext, ssize_t *len,
struct enc_ctx *ctx);
char * ss_decrypt(int buf_size, char *ciphertext, ssize_t *len,
struct enc_ctx *ctx);
void enc_ctx_init(int method, struct enc_ctx *ctx, int enc);
} enc_ctx_t;
int ss_encrypt_all(buffer_t *plaintext, int method, int auth);
int ss_decrypt_all(buffer_t *ciphertext, int method, int auth);
int ss_encrypt(buffer_t *plaintext, enc_ctx_t *ctx);
int ss_decrypt(buffer_t *ciphertext, enc_ctx_t *ctx);
void enc_ctx_init(int method, enc_ctx_t *ctx, int enc);
int enc_init(const char *pass, const char *method);
int enc_get_iv_len(void);
void cipher_context_release(cipher_ctx_t *evp);
unsigned char *enc_md5(const unsigned char *d, size_t n, unsigned char *md);
int ss_onetimeauth(char *auth, char *msg, int msg_len, uint8_t *iv);
int ss_onetimeauth_verify(char *auth, char *msg, int msg_len, uint8_t *iv);
int ss_onetimeauth(buffer_t *buf, uint8_t *iv);
int ss_onetimeauth_verify(buffer_t *buf, uint8_t *iv);
int ss_check_hash(buffer_t *buf, chunk_t *chunk, enc_ctx_t *ctx);
int ss_gen_hash(buffer_t *buf, uint32_t *counter, enc_ctx_t *ctx);
int ss_check_hash(char **buf_ptr, ssize_t *buf_len, struct chunk *chunk, struct enc_ctx *ctx, int buf_size);
char *ss_gen_hash(char *buf, ssize_t *buf_len, uint32_t *counter, struct enc_ctx *ctx, int buf_size);
int balloc(buffer_t *ptr, size_t capacity);
int brealloc(buffer_t *ptr, size_t len, size_t capacity);
void bfree(buffer_t *ptr);
#endif // _ENCRYPT_H

274
src/local.c

@ -107,14 +107,14 @@ static void accept_cb(EV_P_ ev_io *w, int revents);
static void signal_cb(EV_P_ ev_signal *w, int revents);
static int create_and_bind(const char *addr, const char *port);
static struct remote * create_remote(struct listen_ctx *listener, struct sockaddr *addr);
static void free_remote(struct remote *remote);
static void close_and_free_remote(EV_P_ struct remote *remote);
static void free_server(struct server *server);
static void close_and_free_server(EV_P_ struct server *server);
static remote_t * create_remote(listen_ctx_t *listener, struct sockaddr *addr);
static void free_remote(remote_t *remote);
static void close_and_free_remote(EV_P_ remote_t *remote);
static void free_server(server_t *server);
static void close_and_free_server(EV_P_ server_t *server);
static struct remote * new_remote(int fd, int timeout);
static struct server * new_server(int fd, int method);
static remote_t * new_remote(int fd, int timeout);
static server_t * new_server(int fd, int method);
static struct cork_dllist connections;
@ -196,8 +196,8 @@ static void free_connections(struct ev_loop *loop)
for (curr = cork_dllist_start(&connections);
!cork_dllist_is_end(&connections, curr);
curr = curr->next) {
struct server *server = cork_container_of(curr, struct server, entries);
struct remote *remote = server->remote;
server_t *server = cork_container_of(curr, server_t, entries);
remote_t *remote = server->remote;
close_and_free_server(loop, server);
close_and_free_remote(loop, remote);
}
@ -205,10 +205,10 @@ static void free_connections(struct ev_loop *loop)
static void server_recv_cb(EV_P_ ev_io *w, int revents)
{
struct server_ctx *server_recv_ctx = (struct server_ctx *)w;
struct server *server = server_recv_ctx->server;
struct remote *remote = server->remote;
char *buf;
server_ctx_t *server_recv_ctx = (server_ctx_t *)w;
server_t *server = server_recv_ctx->server;
remote_t *remote = server->remote;
buffer_t *buf;
if (remote == NULL) {
buf = server->buf;
@ -218,7 +218,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
ssize_t r;
r = recv(server->fd, buf, BUF_SIZE, 0);
r = recv(server->fd, buf->array, BUF_SIZE, 0);
if (r == 0) {
// connection closed
@ -238,6 +238,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
}
buf->len = r;
while (1) {
// local socks5 server
if (server->stage == 5) {
@ -248,18 +250,17 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
if (!remote->direct && remote->send_ctx->connected && auth) {
remote->buf = ss_gen_hash(remote->buf, &r, &remote->counter, server->e_ctx, BUF_SIZE);
ss_gen_hash(remote->buf, &remote->counter, server->e_ctx);
}
// insert shadowsocks header
if (!remote->direct) {
#ifdef ANDROID
tx += r;
tx += remote->buf_len;
#endif
remote->buf = ss_encrypt(BUF_SIZE, remote->buf, &r,
server->e_ctx);
int err = ss_encrypt(remote->buf, server->e_ctx);
if (remote->buf == NULL) {
if (err) {
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
@ -280,8 +281,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
#endif
remote->buf_idx = 0;
remote->buf_len = r;
remote->buf->idx = 0;
if (!fast_open || remote->direct) {
// connecting, wait until connected
@ -293,13 +293,12 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
ev_timer_start(EV_A_ & remote->send_ctx->watcher);
} else {
#ifdef TCP_FASTOPEN
int s = sendto(remote->fd, remote->buf, r, MSG_FASTOPEN,
int s = sendto(remote->fd, remote->buf->array, remote->buf->len, MSG_FASTOPEN,
(struct sockaddr *)&(remote->addr), remote->addr_len);
if (s == -1) {
if (errno == EINPROGRESS) {
// in progress, wait until connected
remote->buf_idx = 0;
remote->buf_len = r;
remote->buf->idx = 0;
ev_io_stop(EV_A_ & server_recv_ctx->io);
ev_io_start(EV_A_ & remote->send_ctx->io);
return;
@ -315,9 +314,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
return;
}
} else if (s < r) {
remote->buf_len = r - s;
remote->buf_idx = s;
} else if (s < remote->buf->len) {
remote->buf->len -= s;
remote->buf->idx = s;
}
// Just connected
@ -331,12 +330,11 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
#endif
}
} else {
int s = send(remote->fd, remote->buf, r, 0);
int s = send(remote->fd, remote->buf->array, remote->buf->len, 0);
if (s == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// no data, wait for send
remote->buf_idx = 0;
remote->buf_len = r;
remote->buf->idx = 0;
ev_io_stop(EV_A_ & server_recv_ctx->io);
ev_io_start(EV_A_ & remote->send_ctx->io);
return;
@ -346,9 +344,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
return;
}
} else if (s < r) {
remote->buf_len = r - s;
remote->buf_idx = s;
} else if (s < remote->buf->len) {
remote->buf->len -= s;
remote->buf->idx = s;
ev_io_stop(EV_A_ & server_recv_ctx->io);
ev_io_start(EV_A_ & remote->send_ctx->io);
return;
@ -365,16 +363,16 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
send(server->fd, send_buf, sizeof(response), 0);
server->stage = 1;
int off = (buf[1] & 0xff) + 2;
if (buf[0] == 0x05 && off < r) {
memmove(buf, buf + off, r - off);
r -= off;
int off = (buf->array[1] & 0xff) + 2;
if (buf->array[0] == 0x05 && off < buf->len) {
memmove(buf->array, buf->array + off, buf->len - off);
buf->len -= off;
continue;
}
return;
} else if (server->stage == 1) {
struct socks5_request *request = (struct socks5_request *)buf;
struct socks5_request *request = (struct socks5_request *)buf->array;
struct sockaddr_in sock_addr;
memset(&sock_addr, 0, sizeof(sock_addr));
@ -403,21 +401,22 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
return;
} else {
char host[256], port[16];
char ss_addr_to_send[320];
ssize_t addr_len = 0;
ss_addr_to_send[addr_len++] = request->atyp;
buffer_t ss_addr_to_send;
buffer_t *abuf = &ss_addr_to_send;
balloc(abuf, BUF_SIZE);
abuf->array[abuf->len++] = request->atyp;
// get remote addr and port
if (request->atyp == 1) {
// IP V4
size_t in_addr_len = sizeof(struct in_addr);
memcpy(ss_addr_to_send + addr_len, buf + 4, in_addr_len + 2);
addr_len += in_addr_len + 2;
memcpy(abuf->array + abuf->len, buf->array + 4, in_addr_len + 2);
abuf->len += in_addr_len + 2;
if (acl || verbose) {
uint16_t p =
ntohs(*(uint16_t *)(buf + 4 + in_addr_len));
uint16_t p = ntohs(*(uint16_t *)(buf + 4 + in_addr_len));
dns_ntop(AF_INET, (const void *)(buf + 4),
host, INET_ADDRSTRLEN);
sprintf(port, "%d", p);
@ -425,9 +424,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
} else if (request->atyp == 3) {
// Domain name
uint8_t name_len = *(uint8_t *)(buf + 4);
ss_addr_to_send[addr_len++] = name_len;
memcpy(ss_addr_to_send + addr_len, buf + 4 + 1, name_len + 2);
addr_len += name_len + 2;
abuf->array[abuf->len++] = name_len;
memcpy(abuf->array + abuf->len, buf + 4 + 1, name_len + 2);
abuf->len += name_len + 2;
if (acl || verbose) {
uint16_t p =
@ -439,17 +438,17 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
} else if (request->atyp == 4) {
// IP V6
size_t in6_addr_len = sizeof(struct in6_addr);
memcpy(ss_addr_to_send + addr_len, buf + 4, in6_addr_len + 2);
addr_len += in6_addr_len + 2;
memcpy(abuf->array + abuf->len, buf + 4, in6_addr_len + 2);
abuf->len += in6_addr_len + 2;
if (acl || verbose) {
uint16_t p =
ntohs(*(uint16_t *)(buf + 4 + in6_addr_len));
uint16_t p = ntohs(*(uint16_t *)(buf + 4 + in6_addr_len));
dns_ntop(AF_INET6, (const void *)(buf + 4),
host, INET6_ADDRSTRLEN);
sprintf(port, "%d", p);
}
} else {
bfree(abuf);
LOGE("unsupported addrtype: %d", request->atyp);
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
@ -458,8 +457,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
server->stage = 5;
r -= (3 + addr_len);
buf += (3 + addr_len);
buf->len -= (3 + abuf->len);
if (buf->len > 0) {
memmove(buf->array, buf->array + 3 + abuf->len, buf->len);
}
if (verbose) {
LOGI("connect to %s:%s", host, port);
@ -480,6 +481,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
if (remote == NULL) {
bfree(abuf);
LOGE("invalid remote addr");
close_and_free_server(EV_A_ server);
return;
@ -487,28 +489,31 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
if (!remote->direct) {
if (auth) {
ss_addr_to_send[0] |= ONETIMEAUTH_FLAG;
ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx->evp.iv);
addr_len += ONETIMEAUTH_BYTES;
abuf->array[0] |= ONETIMEAUTH_FLAG;
ss_onetimeauth(abuf, server->e_ctx->evp.iv);
}
memcpy(remote->buf, ss_addr_to_send, addr_len);
brealloc(remote->buf, buf->len + abuf->len, BUF_SIZE);
memcpy(remote->buf->array, abuf->array, abuf->len);
remote->buf->len = buf->len + abuf->len;
if (r > 0) {
if (buf->len > 0) {
if (auth) {
buf = ss_gen_hash(buf, &r, &remote->counter, server->e_ctx, BUF_SIZE);
ss_gen_hash(buf, &remote->counter, server->e_ctx);
}
memcpy(remote->buf + addr_len, buf, r);
memcpy(remote->buf->array + abuf->len, buf->array, buf->len);
}
r += addr_len;
} else {
if (r > 0) {
memcpy(remote->buf, buf, r);
if (buf->len > 0) {
memcpy(remote->buf->array, buf->array, buf->len);
remote->buf->len = buf->len;
}
}
server->remote = remote;
remote->server = server;
bfree(abuf);
}
// Fake reply
@ -518,17 +523,17 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
response.rsv = 0;
response.atyp = 1;
memcpy(server->buf, &response, sizeof(struct socks5_response));
memcpy(server->buf + sizeof(struct socks5_response),
memcpy(server->buf->array, &response, sizeof(struct socks5_response));
memcpy(server->buf->array + sizeof(struct socks5_response),
&sock_addr.sin_addr, sizeof(sock_addr.sin_addr));
memcpy(server->buf + sizeof(struct socks5_response) +
memcpy(server->buf->array + sizeof(struct socks5_response) +
sizeof(sock_addr.sin_addr),
&sock_addr.sin_port, sizeof(sock_addr.sin_port));
int reply_size = sizeof(struct socks5_response) +
sizeof(sock_addr.sin_addr) +
sizeof(sock_addr.sin_port);
int s = send(server->fd, server->buf, reply_size, 0);
int s = send(server->fd, server->buf->array, reply_size, 0);
if (s < reply_size) {
LOGE("failed to send fake reply");
close_and_free_remote(EV_A_ remote);
@ -547,18 +552,18 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
static void server_send_cb(EV_P_ ev_io *w, int revents)
{
struct server_ctx *server_send_ctx = (struct server_ctx *)w;
struct server *server = server_send_ctx->server;
struct remote *remote = server->remote;
if (server->buf_len == 0) {
server_ctx_t *server_send_ctx = (server_ctx_t *)w;
server_t *server = server_send_ctx->server;
remote_t *remote = server->remote;
if (server->buf->len == 0) {
// close and free
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
} else {
// has data to send
ssize_t s = send(server->fd, server->buf + server->buf_idx,
server->buf_len, 0);
ssize_t s = send(server->fd, server->buf->array + server->buf->idx,
server->buf->len, 0);
if (s < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ERROR("server_send_cb_send");
@ -566,15 +571,15 @@ static void server_send_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
}
return;
} else if (s < server->buf_len) {
} else if (s < server->buf->len) {
// partly sent, move memory, wait for the next time to send
server->buf_len -= s;
server->buf_idx += s;
server->buf->len -= s;
server->buf->idx += s;
return;
} else {
// all sent out, wait for reading
server->buf_len = 0;
server->buf_idx = 0;
server->buf->len = 0;
server->buf->idx = 0;
ev_io_stop(EV_A_ & server_send_ctx->io);
ev_io_start(EV_A_ & remote->recv_ctx->io);
return;
@ -596,10 +601,10 @@ static void stat_update_cb(struct ev_loop *loop)
static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents)
{
struct remote_ctx *remote_ctx = (struct remote_ctx *)(((void *)watcher)
remote_ctx_t *remote_ctx = (remote_ctx_t *)(((void *)watcher)
- sizeof(ev_io));
struct remote *remote = remote_ctx->remote;
struct server *server = remote->server;
remote_t *remote = remote_ctx->remote;
server_t *server = remote->server;
if (verbose) {
LOGI("TCP connection timeout");
@ -611,9 +616,9 @@ static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents)
static void remote_recv_cb(EV_P_ ev_io *w, int revents)
{
struct remote_ctx *remote_recv_ctx = (struct remote_ctx *)w;
struct remote *remote = remote_recv_ctx->remote;
struct server *server = remote->server;
remote_ctx_t *remote_recv_ctx = (remote_ctx_t *)w;
remote_t *remote = remote_recv_ctx->remote;
server_t *server = remote->server;
ev_timer_again(EV_A_ & remote->recv_ctx->watcher);
@ -621,7 +626,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
stat_update_cb(loop);
#endif
ssize_t r = recv(remote->fd, server->buf, BUF_SIZE, 0);
ssize_t r = recv(remote->fd, server->buf->array, BUF_SIZE, 0);
if (r == 0) {
// connection closed
@ -641,12 +646,14 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
}
}
server->buf->len = r;
if (!remote->direct) {
#ifdef ANDROID
rx += r;
rx += server->buf->len;
#endif
server->buf = ss_decrypt(BUF_SIZE, server->buf, &r, server->d_ctx);
if (server->buf == NULL) {
int err = ss_decrypt(server->buf, server->d_ctx);
if (err) {
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
@ -654,13 +661,12 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
}
}
int s = send(server->fd, server->buf, r, 0);
int s = send(server->fd, server->buf->array, server->buf->len, 0);
if (s == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// no data, wait for send
server->buf_len = r;
server->buf_idx = 0;
server->buf->idx = 0;
ev_io_stop(EV_A_ & remote_recv_ctx->io);
ev_io_start(EV_A_ & server->send_ctx->io);
return;
@ -670,9 +676,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
return;
}
} else if (s < r) {
server->buf_len = r - s;
server->buf_idx = s;
} else if (s < server->buf->len) {
server->buf->len -= s;
server->buf->idx = s;
ev_io_stop(EV_A_ & remote_recv_ctx->io);
ev_io_start(EV_A_ & server->send_ctx->io);
return;
@ -681,9 +687,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
static void remote_send_cb(EV_P_ ev_io *w, int revents)
{
struct remote_ctx *remote_send_ctx = (struct remote_ctx *)w;
struct remote *remote = remote_send_ctx->remote;
struct server *server = remote->server;
remote_ctx_t *remote_send_ctx = (remote_ctx_t *)w;
remote_t *remote = remote_send_ctx->remote;
server_t *server = remote->server;
if (!remote_send_ctx->connected) {
struct sockaddr_storage addr;
@ -696,7 +702,7 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
ev_io_start(EV_A_ & remote->recv_ctx->io);
// no need to send any data
if (remote->buf_len == 0) {
if (remote->buf->len == 0) {
ev_io_stop(EV_A_ & remote_send_ctx->io);
ev_io_start(EV_A_ & server->recv_ctx->io);
return;
@ -710,15 +716,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
}
}
if (remote->buf_len == 0) {
if (remote->buf->len == 0) {
// close and free
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
} else {
// has data to send
ssize_t s = send(remote->fd, remote->buf + remote->buf_idx,
remote->buf_len, 0);
ssize_t s = send(remote->fd, remote->buf->array + remote->buf->idx,
remote->buf->len, 0);
if (s < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ERROR("remote_send_cb_send");
@ -727,31 +733,33 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
}
return;
} else if (s < remote->buf_len) {
} else if (s < remote->buf->len) {
// partly sent, move memory, wait for the next time to send
remote->buf_len -= s;
remote->buf_idx += s;
remote->buf->len -= s;
remote->buf->idx += s;
return;
} else {
// all sent out, wait for reading
remote->buf_len = 0;
remote->buf_idx = 0;
remote->buf->len = 0;
remote->buf->idx = 0;
ev_io_stop(EV_A_ & remote_send_ctx->io);
ev_io_start(EV_A_ & server->recv_ctx->io);
}
}
}
static struct remote * new_remote(int fd, int timeout)
static remote_t * new_remote(int fd, int timeout)
{
struct remote *remote;
remote = malloc(sizeof(struct remote));
remote_t *remote;
remote = malloc(sizeof(remote_t));
memset(remote, 0, sizeof(struct remote));
memset(remote, 0, sizeof(remote_t));
remote->buf = malloc(BUF_SIZE);
remote->recv_ctx = malloc(sizeof(struct remote_ctx));
remote->send_ctx = malloc(sizeof(struct remote_ctx));
remote->buf = malloc(sizeof(buffer_t));
balloc(remote->buf, BUF_SIZE);
remote->recv_ctx = malloc(sizeof(remote_ctx_t));
remote->send_ctx = malloc(sizeof(remote_ctx_t));
remote->recv_ctx->connected = 0;
remote->send_ctx->connected = 0;
remote->fd = fd;
@ -766,12 +774,13 @@ static struct remote * new_remote(int fd, int timeout)
return remote;
}
static void free_remote(struct remote *remote)
static void free_remote(remote_t *remote)
{
if (remote->server != NULL) {
remote->server->remote = NULL;
}
if (remote->buf != NULL) {
bfree(remote->buf);
free(remote->buf);
}
free(remote->recv_ctx);
@ -779,7 +788,7 @@ static void free_remote(struct remote *remote)
free(remote);
}
static void close_and_free_remote(EV_P_ struct remote *remote)
static void close_and_free_remote(EV_P_ remote_t *remote)
{
if (remote != NULL) {
ev_timer_stop(EV_A_ & remote->send_ctx->watcher);
@ -791,16 +800,18 @@ static void close_and_free_remote(EV_P_ struct remote *remote)
}
}
static struct server * new_server(int fd, int method)
static server_t * new_server(int fd, int method)
{
struct server *server;
server = malloc(sizeof(struct server));
server_t *server;
server = malloc(sizeof(server_t));
memset(server, 0, sizeof(server_t));
memset(server, 0, sizeof(struct server));
server->buf = malloc(sizeof(buffer_t));
balloc(server->buf, BUF_SIZE);
server->buf = malloc(BUF_SIZE);
server->recv_ctx = malloc(sizeof(struct server_ctx));
server->send_ctx = malloc(sizeof(struct server_ctx));
server->recv_ctx = malloc(sizeof(server_ctx_t));
server->send_ctx = malloc(sizeof(server_ctx_t));
server->recv_ctx->connected = 0;
server->send_ctx->connected = 0;
server->fd = fd;
@ -823,7 +834,7 @@ static struct server * new_server(int fd, int method)
return server;
}
static void free_server(struct server *server)
static void free_server(server_t *server)
{
cork_dllist_remove(&server->entries);
@ -839,6 +850,7 @@ static void free_server(struct server *server)
free(server->d_ctx);
}
if (server->buf != NULL) {
bfree(server->buf);
free(server->buf);
}
free(server->recv_ctx);
@ -846,7 +858,7 @@ static void free_server(struct server *server)
free(server);
}
static void close_and_free_server(EV_P_ struct server *server)
static void close_and_free_server(EV_P_ server_t *server)
{
if (server != NULL) {
ev_io_stop(EV_A_ & server->send_ctx->io);
@ -856,7 +868,7 @@ static void close_and_free_server(EV_P_ struct server *server)
}
}
static struct remote * create_remote(struct listen_ctx *listener,
static remote_t * create_remote(listen_ctx_t *listener,
struct sockaddr *addr)
{
struct sockaddr *remote_addr;
@ -889,7 +901,7 @@ static struct remote * create_remote(struct listen_ctx *listener,
}
#endif
struct remote *remote = new_remote(remotefd, listener->timeout);
remote_t *remote = new_remote(remotefd, listener->timeout);
remote->addr_len = get_sockaddr_len(remote_addr);
memcpy(&(remote->addr), remote_addr, remote->addr_len);
@ -909,7 +921,7 @@ static void signal_cb(EV_P_ ev_signal *w, int revents)
void accept_cb(EV_P_ ev_io *w, int revents)
{
struct listen_ctx *listener = (struct listen_ctx *)w;
listen_ctx_t *listener = (listen_ctx_t *)w;
int serverfd = accept(listener->fd, NULL, NULL);
if (serverfd == -1) {
ERROR("accept");
@ -922,7 +934,7 @@ void accept_cb(EV_P_ ev_io *w, int revents)
setsockopt(serverfd, SOL_SOCKET, SO_NOSIGPIPE, &opt, sizeof(opt));
#endif
struct server *server = new_server(serverfd, listener->method);
server_t *server = new_server(serverfd, listener->method);
server->listener = listener;
ev_io_start(EV_A_ & server->recv_ctx->io);
@ -1142,7 +1154,7 @@ int main(int argc, char **argv)
int m = enc_init(password, method);
// Setup proxy context
struct listen_ctx listen_ctx;
listen_ctx_t listen_ctx;
listen_ctx.remote_num = remote_num;
listen_ctx.remote_addr = malloc(sizeof(struct sockaddr *) * remote_num);
for (i = 0; i < remote_num; i++) {
@ -1287,7 +1299,7 @@ int start_ss_local_server(profile_t profile)
// Setup proxy context
struct ev_loop *loop = EV_DEFAULT;
struct listen_ctx listen_ctx;
listen_ctx_t listen_ctx;
listen_ctx.remote_num = 1;
listen_ctx.remote_addr = malloc(sizeof(struct sockaddr *));

28
src/local.h

@ -31,7 +31,7 @@
#include "common.h"
struct listen_ctx {
typedef struct listen_ctx {
ev_io io;
char *iface;
int remote_num;
@ -39,19 +39,17 @@ struct listen_ctx {
int timeout;
int fd;
struct sockaddr **remote_addr;
};
} listen_ctx_t;
struct server_ctx {
typedef struct server_ctx {
ev_io io;
int connected;
struct server *server;
};
} server_ctx_t;
struct server {
typedef struct server {
int fd;
ssize_t buf_len;
ssize_t buf_idx;
char *buf; // server send from, remote recv into
buffer_t *buf;
char stage;
struct enc_ctx *e_ctx;
struct enc_ctx *d_ctx;
@ -61,27 +59,25 @@ struct server {
struct remote *remote;
struct cork_dllist_item entries;
};
} server_t;
struct remote_ctx {
typedef struct remote_ctx {
ev_io io;
ev_timer watcher;
int connected;
struct remote *remote;
};
} remote_ctx_t;
struct remote {
typedef struct remote {
int fd;
ssize_t buf_len;
ssize_t buf_idx;
buffer_t *buf;
int direct;
char *buf; // remote send from, server recv into
struct remote_ctx *recv_ctx;
struct remote_ctx *send_ctx;
struct server *server;
struct sockaddr_storage addr;
int addr_len;
uint32_t counter;
};
} remote_t;
#endif // _LOCAL_H

225
src/redir.c

@ -72,13 +72,13 @@ static void server_send_cb(EV_P_ ev_io *w, int revents);
static void remote_recv_cb(EV_P_ ev_io *w, int revents);
static void remote_send_cb(EV_P_ ev_io *w, int revents);
static struct remote * new_remote(int fd, int timeout);
static struct server * new_server(int fd, int method);
static remote_t * new_remote(int fd, int timeout);
static server_t * new_server(int fd, int method);
static void free_remote(struct remote *remote);
static void close_and_free_remote(EV_P_ struct remote *remote);
static void free_server(struct server *server);
static void close_and_free_server(EV_P_ struct server *server);
static void free_remote(remote_t *remote);
static void close_and_free_remote(EV_P_ remote_t *remote);
static void free_server(server_t *server);
static void close_and_free_server(EV_P_ server_t *server);
int verbose = 0;
@ -160,11 +160,11 @@ int create_and_bind(const char *addr, const char *port)
static void server_recv_cb(EV_P_ ev_io *w, int revents)
{
struct server_ctx *server_recv_ctx = (struct server_ctx *)w;
struct server *server = server_recv_ctx->server;
struct remote *remote = server->remote;
server_ctx_t *server_recv_ctx = (server_ctx_t *)w;
server_t *server = server_recv_ctx->server;
remote_t *remote = server->remote;
ssize_t r = recv(server->fd, remote->buf, BUF_SIZE, 0);
ssize_t r = recv(server->fd, remote->buf->array, BUF_SIZE, 0);
if (r == 0) {
// connection closed
@ -184,26 +184,27 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
}
remote->buf->len = r;
if (auth) {
remote->buf = ss_gen_hash(remote->buf, &r, &remote->counter, server->e_ctx, BUF_SIZE);
ss_gen_hash(remote->buf, &remote->counter, server->e_ctx);
}
remote->buf = ss_encrypt(BUF_SIZE, remote->buf, &r, server->e_ctx);
int err = ss_encrypt(remote->buf, server->e_ctx);
if (remote->buf == NULL) {
if (err) {
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
}
int s = send(remote->fd, remote->buf, r, 0);
int s = send(remote->fd, remote->buf->array, remote->buf->len, 0);
if (s == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// no data, wait for send
remote->buf_len = r;
remote->buf_idx = 0;
remote->buf->idx = 0;
ev_io_stop(EV_A_ & server_recv_ctx->io);
ev_io_start(EV_A_ & remote->send_ctx->io);
return;
@ -213,9 +214,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
return;
}
} else if (s < r) {
remote->buf_len = r - s;
remote->buf_idx = s;
} else if (s < remote->buf->len) {
remote->buf->len -= s;
remote->buf->idx = s;
ev_io_stop(EV_A_ & server_recv_ctx->io);
ev_io_start(EV_A_ & remote->send_ctx->io);
return;
@ -225,18 +226,18 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
static void server_send_cb(EV_P_ ev_io *w, int revents)
{
struct server_ctx *server_send_ctx = (struct server_ctx *)w;
struct server *server = server_send_ctx->server;
struct remote *remote = server->remote;
if (server->buf_len == 0) {
server_ctx_t *server_send_ctx = (server_ctx_t *)w;
server_t *server = server_send_ctx->server;
remote_t *remote = server->remote;
if (server->buf->len == 0) {
// close and free
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
} else {
// has data to send
ssize_t s = send(server->fd, server->buf + server->buf_idx,
server->buf_len, 0);
ssize_t s = send(server->fd, server->buf->array + server->buf->idx,
server->buf->len, 0);
if (s < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ERROR("send");
@ -244,15 +245,15 @@ static void server_send_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
}
return;
} else if (s < server->buf_len) {
} else if (s < server->buf->len) {
// partly sent, move memory, wait for the next time to send
server->buf_len -= s;
server->buf_idx += s;
server->buf->len -= s;
server->buf->idx += s;
return;
} else {
// all sent out, wait for reading
server->buf_len = 0;
server->buf_idx = 0;
server->buf->len = 0;
server->buf->idx = 0;
ev_io_stop(EV_A_ & server_send_ctx->io);
ev_io_start(EV_A_ & remote->recv_ctx->io);
}
@ -262,10 +263,10 @@ static void server_send_cb(EV_P_ ev_io *w, int revents)
static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents)
{
struct remote_ctx *remote_ctx = (struct remote_ctx *)(((void *)watcher)
remote_ctx_t *remote_ctx = (remote_ctx_t *)(((void *)watcher)
- sizeof(ev_io));
struct remote *remote = remote_ctx->remote;
struct server *server = remote->server;
remote_t *remote = remote_ctx->remote;
server_t *server = remote->server;
ev_timer_stop(EV_A_ watcher);
@ -275,11 +276,11 @@ static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents)
static void remote_recv_cb(EV_P_ ev_io *w, int revents)
{
struct remote_ctx *remote_recv_ctx = (struct remote_ctx *)w;
struct remote *remote = remote_recv_ctx->remote;
struct server *server = remote->server;
remote_ctx_t *remote_recv_ctx = (remote_ctx_t *)w;
remote_t *remote = remote_recv_ctx->remote;
server_t *server = remote->server;
ssize_t r = recv(remote->fd, server->buf, BUF_SIZE, 0);
ssize_t r = recv(remote->fd, server->buf->array, BUF_SIZE, 0);
if (r == 0) {
// connection closed
@ -299,20 +300,21 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
}
}
server->buf = ss_decrypt(BUF_SIZE, server->buf, &r, server->d_ctx);
if (server->buf == NULL) {
server->buf->len = r;
int err = ss_decrypt(server->buf, server->d_ctx);
if (err) {
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
}
int s = send(server->fd, server->buf, r, 0);
int s = send(server->fd, server->buf->array, server->buf->len, 0);
if (s == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// no data, wait for send
server->buf_len = r;
server->buf_idx = 0;
server->buf->idx = 0;
ev_io_stop(EV_A_ & remote_recv_ctx->io);
ev_io_start(EV_A_ & server->send_ctx->io);
return;
@ -322,9 +324,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
return;
}
} else if (s < r) {
server->buf_len = r - s;
server->buf_idx = s;
} else if (s < server->buf->len) {
server->buf->len -= s;
server->buf->idx = s;
ev_io_stop(EV_A_ & remote_recv_ctx->io);
ev_io_start(EV_A_ & server->send_ctx->io);
return;
@ -333,9 +335,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
static void remote_send_cb(EV_P_ ev_io *w, int revents)
{
struct remote_ctx *remote_send_ctx = (struct remote_ctx *)w;
struct remote *remote = remote_send_ctx->remote;
struct server *server = remote->server;
remote_ctx_t *remote_send_ctx = (remote_ctx_t *)w;
remote_t *remote = remote_send_ctx->remote;
server_t *server = remote->server;
if (!remote_send_ctx->connected) {
@ -348,52 +350,52 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
ev_timer_stop(EV_A_ & remote_send_ctx->watcher);
// send destaddr
char *ss_addr_to_send = malloc(BUF_SIZE);
ssize_t addr_len = 0;
buffer_t ss_addr_to_send;
buffer_t *abuf = &ss_addr_to_send;
balloc(abuf, BUF_SIZE);
if (AF_INET6 == server->destaddr.ss_family) { // IPv6
ss_addr_to_send[addr_len++] = 4; //Type 4 is IPv6 address
abuf->array[abuf->len++] = 4; //Type 4 is IPv6 address
size_t in_addr_len = sizeof(struct in6_addr);
memcpy(ss_addr_to_send + addr_len,
size_t in6_addr_len = sizeof(struct in6_addr);
memcpy(abuf->array + abuf->len,
&(((struct sockaddr_in6 *)&(server->destaddr))->sin6_addr),
in_addr_len);
addr_len += in_addr_len;
memcpy(ss_addr_to_send + addr_len,
in6_addr_len);
abuf->len += in6_addr_len;
memcpy(abuf->array + abuf->len,
&(((struct sockaddr_in6 *)&(server->destaddr))->sin6_port),
2);
} else { //IPv4
ss_addr_to_send[addr_len++] = 1; //Type 1 is IPv4 address
abuf->array[abuf->len++] = 1; //Type 1 is IPv4 address
size_t in_addr_len = sizeof(struct in_addr);
memcpy(ss_addr_to_send + addr_len,
&((struct sockaddr_in *)&(server->destaddr))->sin_addr,
in_addr_len);
addr_len += in_addr_len;
memcpy(ss_addr_to_send + addr_len,
&((struct sockaddr_in *)&(server->destaddr))->sin_port,
2);
memcpy(abuf->array + abuf->len,
&((struct sockaddr_in *)&(server->destaddr))->sin_addr, in_addr_len);
abuf->len += in_addr_len;
memcpy(abuf->array + abuf->len,
&((struct sockaddr_in *)&(server->destaddr))->sin_port, 2);
}
addr_len += 2;
abuf->len += 2;
if (auth) {
ss_addr_to_send[0] |= ONETIMEAUTH_FLAG;
ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx->evp.iv);
addr_len += ONETIMEAUTH_BYTES;
abuf->array[0] |= ONETIMEAUTH_FLAG;
ss_onetimeauth(abuf, server->e_ctx->evp.iv);
}
ss_addr_to_send = ss_encrypt(BUF_SIZE, ss_addr_to_send, &addr_len,
server->e_ctx);
if (ss_addr_to_send == NULL) {
int err = ss_encrypt(abuf, server->e_ctx);
if (err) {
bfree(abuf);
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
}
int s = send(remote->fd, ss_addr_to_send, addr_len, 0);
free(ss_addr_to_send);
int s = send(remote->fd, abuf->array, abuf->len, 0);
if (s < addr_len) {
bfree(abuf);
if (s < abuf->len) {
LOGE("failed to send addr");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
@ -412,15 +414,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
return;
}
} else {
if (remote->buf_len == 0) {
if (remote->buf->len == 0) {
// close and free
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
} else {
// has data to send
ssize_t s = send(remote->fd, remote->buf + remote->buf_idx,
remote->buf_len, 0);
ssize_t s = send(remote->fd, remote->buf->array + remote->buf->idx,
remote->buf->len, 0);
if (s < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ERROR("send");
@ -429,15 +431,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
}
return;
} else if (s < remote->buf_len) {
} else if (s < remote->buf->len) {
// partly sent, move memory, wait for the next time to send
remote->buf_len -= s;
remote->buf_idx += s;
remote->buf->len -= s;
remote->buf->idx += s;
return;
} else {
// all sent out, wait for reading
remote->buf_len = 0;
remote->buf_idx = 0;
remote->buf->len = 0;
remote->buf->idx = 0;
ev_io_stop(EV_A_ & remote_send_ctx->io);
ev_io_start(EV_A_ & server->recv_ctx->io);
}
@ -446,16 +448,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
}
}
static struct remote * new_remote(int fd, int timeout)
static remote_t * new_remote(int fd, int timeout)
{
struct remote *remote;
remote = malloc(sizeof(struct remote));
remote_t *remote;
remote = malloc(sizeof(remote_t));
memset(remote, 0, sizeof(struct remote));
memset(remote, 0, sizeof(remote_t));
remote->buf = malloc(BUF_SIZE);
remote->recv_ctx = malloc(sizeof(struct remote_ctx));
remote->send_ctx = malloc(sizeof(struct remote_ctx));
remote->recv_ctx = malloc(sizeof(remote_ctx_t));
remote->send_ctx = malloc(sizeof(remote_ctx_t));
remote->fd = fd;
ev_io_init(&remote->recv_ctx->io, remote_recv_cb, fd, EV_READ);
ev_io_init(&remote->send_ctx->io, remote_send_cb, fd, EV_WRITE);
@ -465,18 +466,21 @@ static struct remote * new_remote(int fd, int timeout)
remote->recv_ctx->connected = 0;
remote->send_ctx->remote = remote;
remote->send_ctx->connected = 0;
remote->buf_len = 0;
remote->buf_idx = 0;
remote->buf = malloc(sizeof(buffer_t));
balloc(remote->buf, BUF_SIZE);
return remote;
}
static void free_remote(struct remote *remote)
static void free_remote(remote_t *remote)
{
if (remote != NULL) {
if (remote->server != NULL) {
remote->server->remote = NULL;
}
if (remote->buf != NULL) {
bfree(remote->buf);
free(remote->buf);
}
free(remote->recv_ctx);
@ -485,7 +489,7 @@ static void free_remote(struct remote *remote)
}
}
static void close_and_free_remote(EV_P_ struct remote *remote)
static void close_and_free_remote(EV_P_ remote_t *remote)
{
if (remote != NULL) {
ev_timer_stop(EV_A_ & remote->send_ctx->watcher);
@ -496,13 +500,13 @@ static void close_and_free_remote(EV_P_ struct remote *remote)
}
}
static struct server * new_server(int fd, int method)
static server_t * new_server(int fd, int method)
{
struct server *server;
server = malloc(sizeof(struct server));
server->buf = malloc(BUF_SIZE);
server->recv_ctx = malloc(sizeof(struct server_ctx));
server->send_ctx = malloc(sizeof(struct server_ctx));
server_t *server;
server = malloc(sizeof(server_t));
server->recv_ctx = malloc(sizeof(server_ctx_t));
server->send_ctx = malloc(sizeof(server_ctx_t));
server->fd = fd;
ev_io_init(&server->recv_ctx->io, server_recv_cb, fd, EV_READ);
ev_io_init(&server->send_ctx->io, server_send_cb, fd, EV_WRITE);
@ -511,20 +515,22 @@ static struct server * new_server(int fd, int method)
server->send_ctx->server = server;
server->send_ctx->connected = 0;
if (method) {
server->e_ctx = malloc(sizeof(struct enc_ctx));
server->d_ctx = malloc(sizeof(struct enc_ctx));
server->e_ctx = malloc(sizeof(enc_ctx_t));
server->d_ctx = malloc(sizeof(enc_ctx_t));
enc_ctx_init(method, server->e_ctx, 1);
enc_ctx_init(method, server->d_ctx, 0);
} else {
server->e_ctx = NULL;
server->d_ctx = NULL;
}
server->buf_len = 0;
server->buf_idx = 0;
server->buf = malloc(sizeof(buffer_t));
balloc(server->buf, BUF_SIZE);
return server;
}
static void free_server(struct server *server)
static void free_server(server_t *server)
{
if (server != NULL) {
if (server->remote != NULL) {
@ -539,6 +545,7 @@ static void free_server(struct server *server)
free(server->d_ctx);
}
if (server->buf != NULL) {
bfree(server->buf);
free(server->buf);
}
free(server->recv_ctx);
@ -547,7 +554,7 @@ static void free_server(struct server *server)
}
}
static void close_and_free_server(EV_P_ struct server *server)
static void close_and_free_server(EV_P_ server_t *server)
{
if (server != NULL) {
ev_io_stop(EV_A_ & server->send_ctx->io);
@ -559,7 +566,7 @@ static void close_and_free_server(EV_P_ struct server *server)
static void accept_cb(EV_P_ ev_io *w, int revents)
{
struct listen_ctx *listener = (struct listen_ctx *)w;
listen_ctx_t *listener = (listen_ctx_t *)w;
struct sockaddr_storage destaddr;
int err;
@ -599,8 +606,8 @@ static void accept_cb(EV_P_ ev_io *w, int revents)
// Setup
setnonblocking(remotefd);
struct server *server = new_server(serverfd, listener->method);
struct remote *remote = new_remote(remotefd, listener->timeout);
server_t *server = new_server(serverfd, listener->method);
remote_t *remote = new_remote(remotefd, listener->timeout);
server->remote = remote;
remote->server = server;
server->destaddr = destaddr;
@ -756,7 +763,7 @@ int main(int argc, char **argv)
int m = enc_init(password, method);
// Setup proxy context
struct listen_ctx listen_ctx;
listen_ctx_t listen_ctx;
listen_ctx.remote_num = remote_num;
listen_ctx.remote_addr = malloc(sizeof(struct sockaddr *) * remote_num);
for (int i = 0; i < remote_num; i++) {

28
src/redir.h

@ -27,50 +27,46 @@
#include "encrypt.h"
#include "jconf.h"
struct listen_ctx {
typedef struct listen_ctx {
ev_io io;
int remote_num;
int timeout;
int fd;
int method;
struct sockaddr **remote_addr;
};
} listen_ctx_t;
struct server_ctx {
typedef struct server_ctx {
ev_io io;
int connected;
struct server *server;
};
} server_ctx_t;
struct server {
typedef struct server {
int fd;
ssize_t buf_len;
ssize_t buf_idx;
char *buf; // server send from, remote recv into
buffer_t *buf;
struct sockaddr_storage destaddr;
struct enc_ctx *e_ctx;
struct enc_ctx *d_ctx;
struct server_ctx *recv_ctx;
struct server_ctx *send_ctx;
struct remote *remote;
};
} server_t;
struct remote_ctx {
typedef struct remote_ctx {
ev_io io;
ev_timer watcher;
int connected;
struct remote *remote;
};
} remote_ctx_t;
struct remote {
typedef struct remote {
int fd;
ssize_t buf_len;
ssize_t buf_idx;
char *buf; // remote send from, server recv into
buffer_t *buf;
struct remote_ctx *recv_ctx;
struct remote_ctx *send_ctx;
struct server *server;
uint32_t counter;
};
} remote_t;
#endif // _LOCAL_H

299
src/server.c

@ -93,15 +93,15 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents);
static void remote_send_cb(EV_P_ ev_io *w, int revents);
static void server_timeout_cb(EV_P_ ev_timer *watcher, int revents);
static struct remote * new_remote(int fd);
static struct server * new_server(int fd, struct listen_ctx *listener);
static struct remote *connect_to_remote(struct addrinfo *res,
struct server *server);
static remote_t * new_remote(int fd);
static server_t * new_server(int fd, listen_ctx_t *listener);
static remote_t *connect_to_remote(struct addrinfo *res,
server_t *server);
static void free_remote(struct remote *remote);
static void close_and_free_remote(EV_P_ struct remote *remote);
static void free_server(struct server *server);
static void close_and_free_server(EV_P_ struct server *server);
static void free_remote(remote_t *remote);
static void close_and_free_remote(EV_P_ remote_t *remote);
static void free_server(server_t *server);
static void close_and_free_server(EV_P_ server_t *server);
static void server_resolve_cb(struct sockaddr *addr, void *data);
@ -208,8 +208,8 @@ static void free_connections(struct ev_loop *loop)
for (curr = cork_dllist_start(&connections);
!cork_dllist_is_end(&connections, curr);
curr = curr->next) {
struct server *server = cork_container_of(curr, struct server, entries);
struct remote *remote = server->remote;
server_t *server = cork_container_of(curr, server_t, entries);
remote_t *remote = server->remote;
close_and_free_server(loop, server);
close_and_free_remote(loop, remote);
}
@ -357,8 +357,8 @@ int create_and_bind(const char *host, const char *port)
return listen_sock;
}
static struct remote *connect_to_remote(struct addrinfo *res,
struct server *server)
static remote_t *connect_to_remote(struct addrinfo *res,
server_t *server)
{
int sockfd;
#ifdef SET_INTERFACE
@ -379,7 +379,7 @@ static struct remote *connect_to_remote(struct addrinfo *res,
setsockopt(sockfd, SOL_SOCKET, SO_NOSIGPIPE, &opt, sizeof(opt));
#endif
struct remote *remote = new_remote(sockfd);
remote_t *remote = new_remote(sockfd);
// setup remote socks
setnonblocking(sockfd);
@ -391,8 +391,8 @@ static struct remote *connect_to_remote(struct addrinfo *res,
#ifdef TCP_FASTOPEN
if (fast_open) {
ssize_t s = sendto(sockfd, server->buf + server->buf_idx,
server->buf_len, MSG_FASTOPEN, res->ai_addr,
ssize_t s = sendto(sockfd, server->buf->array + server->buf->idx,
server->buf->len, MSG_FASTOPEN, res->ai_addr,
res->ai_addrlen);
if (s == -1) {
if (errno == EINPROGRESS || errno == EAGAIN
@ -408,12 +408,12 @@ static struct remote *connect_to_remote(struct addrinfo *res,
} else {
ERROR("sendto");
}
} else if (s < server->buf_len) {
server->buf_idx += s;
server->buf_len -= s;
} else if (s < server->buf->len) {
server->buf->idx += s;
server->buf->len -= s;
} else {
server->buf_idx = 0;
server->buf_len = 0;
server->buf->idx = 0;
server->buf->len = 0;
}
} else
#endif
@ -424,22 +424,22 @@ static struct remote *connect_to_remote(struct addrinfo *res,
static void server_recv_cb(EV_P_ ev_io *w, int revents)
{
struct server_ctx *server_recv_ctx = (struct server_ctx *)w;
struct server *server = server_recv_ctx->server;
struct remote *remote = NULL;
server_ctx_t *server_recv_ctx = (server_ctx_t *)w;
server_t *server = server_recv_ctx->server;
remote_t *remote = NULL;
int len = server->buf_len;
char **buf = &server->buf;
int len = server->buf->len;
buffer_t *buf = server->buf;
ev_timer_again(EV_A_ & server->recv_ctx->watcher);
if (server->stage != 0) {
remote = server->remote;
buf = &remote->buf;
buf = remote->buf;
len = 0;
}
ssize_t r = recv(server->fd, *buf + len, BUF_SIZE - len, 0);
ssize_t r = recv(server->fd, buf->array + len, BUF_SIZE - len, 0);
if (r == 0) {
// connection closed
@ -466,8 +466,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
// handle incomplete header
if (server->stage == 0) {
r += server->buf_len;
if (r <= enc_get_iv_len()) {
buf->len += r;
if (buf->len <= enc_get_iv_len()) {
// wait for more
if (verbose) {
#ifdef __MINGW32__
@ -476,16 +476,15 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
LOGI("imcomplete header: %zu", r);
#endif
}
server->buf_len = r;
return;
} else {
server->buf_len = 0;
}
} else {
buf->len = r;
}
*buf = ss_decrypt(BUF_SIZE, *buf, &r, server->d_ctx);
int err = ss_decrypt(buf, server->d_ctx);
if (*buf == NULL) {
if (err) {
LOGE("invalid password or cipher");
report_addr(server->fd);
close_and_free_remote(EV_A_ remote);
@ -495,19 +494,18 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
// handshake and transmit data
if (server->stage == 5) {
if (server->auth && !ss_check_hash(&remote->buf, &r, server->chunk, server->d_ctx, BUF_SIZE)) {
if (server->auth && !ss_check_hash(remote->buf, server->chunk, server->d_ctx)) {
LOGE("hash error");
report_addr(server->fd);
close_and_free_server(EV_A_ server);
close_and_free_remote(EV_A_ remote);
return;
}
int s = send(remote->fd, remote->buf, r, 0);
int s = send(remote->fd, remote->buf->array, remote->buf->len, 0);
if (s == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// no data, wait for send
remote->buf_len = r;
remote->buf_idx = 0;
remote->buf->idx = 0;
ev_io_stop(EV_A_ & server_recv_ctx->io);
ev_io_start(EV_A_ & remote->send_ctx->io);
} else {
@ -515,9 +513,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
}
} else if (s < r) {
remote->buf_len = r - s;
remote->buf_idx = s;
} else if (s < remote->buf->len) {
remote->buf->len -= s;
remote->buf->idx = s;
ev_io_stop(EV_A_ & server_recv_ctx->io);
ev_io_start(EV_A_ & remote->send_ctx->io);
}
@ -555,7 +553,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
int offset = 0;
int need_query = 0;
char atyp = server->buf[offset++];
char atyp = server->buf->array[offset++];
char host[256] = { 0 };
uint16_t port = 0;
struct addrinfo info;
@ -569,9 +567,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
struct sockaddr_in *addr = (struct sockaddr_in *)&storage;
size_t in_addr_len = sizeof(struct in_addr);
addr->sin_family = AF_INET;
if (r >= in_addr_len + 3) {
addr->sin_addr = *(struct in_addr *)(server->buf + offset);
dns_ntop(AF_INET, (const void *)(server->buf + offset),
if (server->buf->len >= in_addr_len + 3) {
addr->sin_addr = *(struct in_addr *)(server->buf->array + offset);
dns_ntop(AF_INET, (const void *)(server->buf->array + offset),
host, INET_ADDRSTRLEN);
offset += in_addr_len;
} else {
@ -580,7 +578,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
return;
}
addr->sin_port = *(uint16_t *)(server->buf + offset);
addr->sin_port = *(uint16_t *)(server->buf->array + offset);
info.ai_family = AF_INET;
info.ai_socktype = SOCK_STREAM;
info.ai_protocol = IPPROTO_TCP;
@ -588,9 +586,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
info.ai_addr = (struct sockaddr *)addr;
} else if ((atyp & ADDRTYPE_MASK) == 3) {
// Domain name
uint8_t name_len = *(uint8_t *)(server->buf + offset);
if (name_len + 4 <= r) {
memcpy(host, server->buf + offset + 1, name_len);
uint8_t name_len = *(uint8_t *)(server->buf->array + offset);
if (name_len + 4 <= server->buf->len) {
memcpy(host, server->buf->array + offset + 1, name_len);
offset += name_len + 1;
} else {
LOGE("invalid name length: %d", name_len);
@ -605,7 +603,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
if (ip.version == 4) {
struct sockaddr_in *addr = (struct sockaddr_in *)&storage;
dns_pton(AF_INET, host, &(addr->sin_addr));
addr->sin_port = *(uint16_t *)(server->buf + offset);
addr->sin_port = *(uint16_t *)(server->buf->array + offset);
addr->sin_family = AF_INET;
info.ai_family = AF_INET;
info.ai_addrlen = sizeof(struct sockaddr_in);
@ -613,7 +611,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
} else if (ip.version == 6) {
struct sockaddr_in6 *addr = (struct sockaddr_in6 *)&storage;
dns_pton(AF_INET6, host, &(addr->sin6_addr));
addr->sin6_port = *(uint16_t *)(server->buf + offset);
addr->sin6_port = *(uint16_t *)(server->buf->array + offset);
addr->sin6_family = AF_INET6;
info.ai_family = AF_INET6;
info.ai_addrlen = sizeof(struct sockaddr_in6);
@ -627,9 +625,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
struct sockaddr_in6 *addr = (struct sockaddr_in6 *)&storage;
size_t in6_addr_len = sizeof(struct in6_addr);
addr->sin6_family = AF_INET6;
if (r >= in6_addr_len + 3) {
addr->sin6_addr = *(struct in6_addr *)(server->buf + offset);
dns_ntop(AF_INET6, (const void *)(server->buf + offset),
if (server->buf->len >= in6_addr_len + 3) {
addr->sin6_addr = *(struct in6_addr *)(server->buf->array + offset);
dns_ntop(AF_INET6, (const void *)(server->buf->array + offset),
host, INET6_ADDRSTRLEN);
offset += in6_addr_len;
} else {
@ -638,7 +636,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
return;
}
addr->sin6_port = *(uint16_t *)(server->buf + offset);
addr->sin6_port = *(uint16_t *)(server->buf->array + offset);
info.ai_family = AF_INET6;
info.ai_socktype = SOCK_STREAM;
info.ai_protocol = IPPROTO_TCP;
@ -661,12 +659,12 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
return;
}
port = (*(uint16_t *)(server->buf + offset));
port = (*(uint16_t *)(server->buf->array + offset));
offset += 2;
if (auth || (atyp & ONETIMEAUTH_FLAG)) {
if (ss_onetimeauth_verify(server->buf + offset, server->buf, offset, server->d_ctx->evp.iv)) {
if (ss_onetimeauth_verify(server->buf, server->d_ctx->evp.iv)) {
LOGE("authentication error %d", atyp);
report_addr(server->fd);
close_and_free_server(EV_A_ server);
@ -682,12 +680,12 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
// XXX: should handle buffer carefully
if (r > offset) {
server->buf_len = r - offset;
memmove(server->buf, server->buf + offset, server->buf_len);
if (server->buf->len > offset) {
server->buf->len -= offset;
memmove(server->buf->array, server->buf->array + offset, server->buf->len);
}
if (server->auth && !ss_check_hash(&server->buf, &server->buf_len, server->chunk, server->d_ctx, BUF_SIZE)) {
if (server->auth && !ss_check_hash(server->buf, server->chunk, server->d_ctx)) {
LOGE("hash error");
report_addr(server->fd);
close_and_free_server(EV_A_ server);
@ -695,7 +693,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
if (!need_query) {
struct remote *remote = connect_to_remote(&info, server);
remote_t *remote = connect_to_remote(&info, server);
if (remote == NULL) {
LOGE("connect error");
@ -706,12 +704,12 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
remote->server = server;
// XXX: should handle buffer carefully
if (server->buf_len > 0) {
memcpy(remote->buf, server->buf + server->buf_idx, server->buf_len);
remote->buf_len = server->buf_len;
remote->buf_idx = 0;
server->buf_len = 0;
server->buf_idx = 0;
if (server->buf->len > 0) {
memcpy(remote->buf->array, server->buf->array, server->buf->len);
remote->buf->len = server->buf->len;
remote->buf->idx = 0;
server->buf->len = 0;
server->buf->idx = 0;
}
server->stage = 4;
@ -736,9 +734,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
static void server_send_cb(EV_P_ ev_io *w, int revents)
{
struct server_ctx *server_send_ctx = (struct server_ctx *)w;
struct server *server = server_send_ctx->server;
struct remote *remote = server->remote;
server_ctx_t *server_send_ctx = (server_ctx_t *)w;
server_t *server = server_send_ctx->server;
remote_t *remote = server->remote;
if (remote == NULL) {
LOGE("invalid server");
@ -746,7 +744,7 @@ static void server_send_cb(EV_P_ ev_io *w, int revents)
return;
}
if (server->buf_len == 0) {
if (server->buf->len == 0) {
// close and free
if (verbose) {
LOGI("server_send close the connection");
@ -756,8 +754,8 @@ static void server_send_cb(EV_P_ ev_io *w, int revents)
return;
} else {
// has data to send
ssize_t s = send(server->fd, server->buf + server->buf_idx,
server->buf_len, 0);
ssize_t s = send(server->fd, server->buf->array + server->buf->idx,
server->buf->len, 0);
if (s < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ERROR("server_send_send");
@ -765,15 +763,15 @@ static void server_send_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
}
return;
} else if (s < server->buf_len) {
} else if (s < server->buf->len) {
// partly sent, move memory, wait for the next time to send
server->buf_len -= s;
server->buf_idx += s;
server->buf->len -= s;
server->buf->idx += s;
return;
} else {
// all sent out, wait for reading
server->buf_len = 0;
server->buf_idx = 0;
server->buf->len = 0;
server->buf->idx = 0;
ev_io_stop(EV_A_ & server_send_ctx->io);
if (remote != NULL) {
ev_io_start(EV_A_ & remote->recv_ctx->io);
@ -790,10 +788,10 @@ static void server_send_cb(EV_P_ ev_io *w, int revents)
static void server_timeout_cb(EV_P_ ev_timer *watcher, int revents)
{
struct server_ctx *server_ctx = (struct server_ctx *)(((void *)watcher)
server_ctx_t *server_ctx = (server_ctx_t *)(((void *)watcher)
- sizeof(ev_io));
struct server *server = server_ctx->server;
struct remote *remote = server->remote;
server_t *server = server_ctx->server;
remote_t *remote = server->remote;
if (verbose) {
LOGI("TCP connection timeout");
@ -805,7 +803,7 @@ static void server_timeout_cb(EV_P_ ev_timer *watcher, int revents)
static void server_resolve_cb(struct sockaddr *addr, void *data)
{
struct server *server = (struct server *)data;
server_t *server = (server_t *)data;
struct ev_loop *loop = server->listen_ctx->loop;
server->query = NULL;
@ -851,7 +849,7 @@ static void server_resolve_cb(struct sockaddr *addr, void *data)
info.ai_addrlen = sizeof(struct sockaddr_in6);
}
struct remote *remote = connect_to_remote(&info, server);
remote_t *remote = connect_to_remote(&info, server);
if (remote == NULL) {
LOGE("connect error");
@ -861,13 +859,13 @@ static void server_resolve_cb(struct sockaddr *addr, void *data)
remote->server = server;
// XXX: should handle buffer carefully
if (server->buf_len > 0) {
memcpy(remote->buf, server->buf + server->buf_idx,
server->buf_len);
remote->buf_len = server->buf_len;
remote->buf_idx = 0;
server->buf_len = 0;
server->buf_idx = 0;
if (server->buf->len > 0) {
memcpy(remote->buf->array, server->buf->array + server->buf->idx,
server->buf->len);
remote->buf->len = server->buf->len;
remote->buf->idx = 0;
server->buf->len = 0;
server->buf->idx = 0;
}
// listen to remote connected event
@ -878,9 +876,9 @@ static void server_resolve_cb(struct sockaddr *addr, void *data)
static void remote_recv_cb(EV_P_ ev_io *w, int revents)
{
struct remote_ctx *remote_recv_ctx = (struct remote_ctx *)w;
struct remote *remote = remote_recv_ctx->remote;
struct server *server = remote->server;
remote_ctx_t *remote_recv_ctx = (remote_ctx_t *)w;
remote_t *remote = remote_recv_ctx->remote;
server_t *server = remote->server;
if (server == NULL) {
LOGE("invalid server");
@ -890,7 +888,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
ev_timer_again(EV_A_ & server->recv_ctx->watcher);
ssize_t r = recv(remote->fd, server->buf, BUF_SIZE, 0);
ssize_t r = recv(remote->fd, server->buf->array, BUF_SIZE, 0);
if (r == 0) {
// connection closed
@ -915,22 +913,23 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
rx += r;
server->buf = ss_encrypt(BUF_SIZE, server->buf, &r, server->e_ctx);
if (server->buf == NULL) {
server->buf->len = r;
int err = ss_encrypt(server->buf, server->e_ctx);
if (err) {
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
}
int s = send(server->fd, server->buf, r, 0);
int s = send(server->fd, server->buf->array, server->buf->len, 0);
if (s == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// no data, wait for send
server->buf_len = r;
server->buf_idx = 0;
server->buf->idx = 0;
ev_io_stop(EV_A_ & remote_recv_ctx->io);
ev_io_start(EV_A_ & server->send_ctx->io);
} else {
@ -939,9 +938,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
}
return;
} else if (s < r) {
server->buf_len = r - s;
server->buf_idx = s;
} else if (s < server->buf->len) {
server->buf->len -= s;
server->buf->idx = s;
ev_io_stop(EV_A_ & remote_recv_ctx->io);
ev_io_start(EV_A_ & server->send_ctx->io);
return;
@ -950,9 +949,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
static void remote_send_cb(EV_P_ ev_io *w, int revents)
{
struct remote_ctx *remote_send_ctx = (struct remote_ctx *)w;
struct remote *remote = remote_send_ctx->remote;
struct server *server = remote->server;
remote_ctx_t *remote_send_ctx = (remote_ctx_t *)w;
remote_t *remote = remote_send_ctx->remote;
server_t *server = remote->server;
if (server == NULL) {
LOGE("invalid server");
@ -972,7 +971,7 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
}
remote_send_ctx->connected = 1;
if (remote->buf_len == 0) {
if (remote->buf->len == 0) {
server->stage = 5;
ev_io_stop(EV_A_ & remote_send_ctx->io);
ev_io_start(EV_A_ & server->recv_ctx->io);
@ -989,7 +988,7 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
}
}
if (remote->buf_len == 0) {
if (remote->buf->len == 0) {
// close and free
if (verbose) {
LOGI("remote_send close the connection");
@ -999,8 +998,8 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
return;
} else {
// has data to send
ssize_t s = send(remote->fd, remote->buf + remote->buf_idx,
remote->buf_len, 0);
ssize_t s = send(remote->fd, remote->buf->array + remote->buf->idx,
remote->buf->len, 0);
if (s == -1) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ERROR("remote_send_send");
@ -1009,15 +1008,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
}
return;
} else if (s < remote->buf_len) {
} else if (s < remote->buf->len) {
// partly sent, move memory, wait for the next time to send
remote->buf_len -= s;
remote->buf_idx += s;
remote->buf->len -= s;
remote->buf->idx += s;
return;
} else {
// all sent out, wait for reading
remote->buf_len = 0;
remote->buf_idx = 0;
remote->buf->len = 0;
remote->buf->idx = 0;
ev_io_stop(EV_A_ & remote_send_ctx->io);
if (server != NULL) {
ev_io_start(EV_A_ & server->recv_ctx->io);
@ -1035,17 +1034,20 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
}
}
static struct remote * new_remote(int fd)
static remote_t * new_remote(int fd)
{
if (verbose) {
remote_conn++;
}
struct remote *remote;
remote = malloc(sizeof(struct remote));
remote->buf = malloc(BUF_SIZE);
remote->recv_ctx = malloc(sizeof(struct remote_ctx));
remote->send_ctx = malloc(sizeof(struct remote_ctx));
remote_t *remote;
remote = malloc(sizeof(remote_t));
remote->recv_ctx = malloc(sizeof(remote_ctx_t));
remote->send_ctx = malloc(sizeof(remote_ctx_t));
remote->buf = malloc(sizeof(buffer_t));
balloc(remote->buf, BUF_SIZE);
remote->fd = fd;
ev_io_init(&remote->recv_ctx->io, remote_recv_cb, fd, EV_READ);
ev_io_init(&remote->send_ctx->io, remote_send_cb, fd, EV_WRITE);
@ -1053,18 +1055,17 @@ static struct remote * new_remote(int fd)
remote->recv_ctx->connected = 0;
remote->send_ctx->remote = remote;
remote->send_ctx->connected = 0;
remote->buf_len = 0;
remote->buf_idx = 0;
remote->server = NULL;
return remote;
}
static void free_remote(struct remote *remote)
static void free_remote(remote_t *remote)
{
if (remote->server != NULL) {
remote->server->remote = NULL;
}
if (remote->buf != NULL) {
bfree(remote->buf);
free(remote->buf);
}
free(remote->recv_ctx);
@ -1072,7 +1073,7 @@ static void free_remote(struct remote *remote)
free(remote);
}
static void close_and_free_remote(EV_P_ struct remote *remote)
static void close_and_free_remote(EV_P_ remote_t *remote)
{
if (remote != NULL) {
ev_io_stop(EV_A_ & remote->send_ctx->io);
@ -1086,20 +1087,22 @@ static void close_and_free_remote(EV_P_ struct remote *remote)
}
}
static struct server * new_server(int fd, struct listen_ctx *listener)
static server_t * new_server(int fd, listen_ctx_t *listener)
{
if (verbose) {
server_conn++;
}
struct server *server;
server = malloc(sizeof(struct server));
server_t *server;
server = malloc(sizeof(server_t));
memset(server, 0, sizeof(server_t));
memset(server, 0, sizeof(struct server));
server->buf = malloc(sizeof(buffer_t));
balloc(server->buf, BUF_SIZE);
server->buf = malloc(BUF_SIZE);
server->recv_ctx = malloc(sizeof(struct server_ctx));
server->send_ctx = malloc(sizeof(struct server_ctx));
server->recv_ctx = malloc(sizeof(server_ctx_t));
server->send_ctx = malloc(sizeof(server_ctx_t));
server->fd = fd;
ev_io_init(&server->recv_ctx->io, server_recv_cb, fd, EV_READ);
ev_io_init(&server->send_ctx->io, server_send_cb, fd, EV_WRITE);
@ -1113,34 +1116,32 @@ static struct server * new_server(int fd, struct listen_ctx *listener)
server->query = NULL;
server->listen_ctx = listener;
if (listener->method) {
server->e_ctx = malloc(sizeof(struct enc_ctx));
server->d_ctx = malloc(sizeof(struct enc_ctx));
server->e_ctx = malloc(sizeof(enc_ctx_t));
server->d_ctx = malloc(sizeof(enc_ctx_t));
enc_ctx_init(listener->method, server->e_ctx, 1);
enc_ctx_init(listener->method, server->d_ctx, 0);
} else {
server->e_ctx = NULL;
server->d_ctx = NULL;
}
server->buf_len = 0;
server->buf_idx = 0;
server->remote = NULL;
server->chunk = (struct chunk *)malloc(sizeof(struct chunk));
memset(server->chunk, 0, sizeof(struct chunk));
server->chunk = (chunk_t *)malloc(sizeof(chunk_t));
memset(server->chunk, 0, sizeof(chunk_t));
server->chunk->buf = malloc(sizeof(buffer_t));
cork_dllist_add(&connections, &server->entries);
return server;
}
static void free_server(struct server *server)
static void free_server(server_t *server)
{
cork_dllist_remove(&server->entries);
if (server->chunk != NULL) {
if (server->chunk->buf != NULL) {
free(server->chunk->buf);
}
bfree(server->chunk->buf);
free(server->chunk->buf);
free(server->chunk);
server->chunk = NULL;
}
@ -1156,14 +1157,16 @@ static void free_server(struct server *server)
free(server->d_ctx);
}
if (server->buf != NULL) {
bfree(server->buf);
free(server->buf);
}
free(server->recv_ctx);
free(server->send_ctx);
free(server);
}
static void close_and_free_server(EV_P_ struct server *server)
static void close_and_free_server(EV_P_ server_t *server)
{
if (server != NULL) {
if (server->query != NULL) {
@ -1195,7 +1198,7 @@ static void signal_cb(EV_P_ ev_signal *w, int revents)
static void accept_cb(EV_P_ ev_io *w, int revents)
{
struct listen_ctx *listener = (struct listen_ctx *)w;
listen_ctx_t *listener = (listen_ctx_t *)w;
int serverfd = accept(listener->fd, NULL, NULL);
if (serverfd == -1) {
ERROR("accept");
@ -1213,7 +1216,7 @@ static void accept_cb(EV_P_ ev_io *w, int revents)
LOGI("accept a connection");
}
struct server *server = new_server(serverfd, listener);
server_t *server = new_server(serverfd, listener);
ev_io_start(EV_A_ & server->recv_ctx->io);
ev_timer_start(EV_A_ & server->recv_ctx->watcher);
}
@ -1446,7 +1449,7 @@ int main(int argc, char **argv)
}
// inilitialize listen context
struct listen_ctx listen_ctx_list[server_num];
listen_ctx_t listen_ctx_list[server_num];
// bind to each interface
while (server_num > 0) {
@ -1464,7 +1467,7 @@ int main(int argc, char **argv)
FATAL("listen() error");
}
setnonblocking(listenfd);
struct listen_ctx *listen_ctx = &listen_ctx_list[index];
listen_ctx_t *listen_ctx = &listen_ctx_list[index];
// Setup proxy context
listen_ctx->timeout = atoi(timeout);
@ -1521,7 +1524,7 @@ int main(int argc, char **argv)
// Clean up
for (int i = 0; i <= server_num; i++) {
struct listen_ctx *listen_ctx = &listen_ctx_list[i];
listen_ctx_t *listen_ctx = &listen_ctx_list[i];
if (mode != UDP_ONLY) {
ev_io_stop(loop, &listen_ctx->io);
close(listen_ctx->fd);

28
src/server.h

@ -33,28 +33,26 @@
#include "common.h"
struct listen_ctx {
typedef struct listen_ctx {
ev_io io;
int fd;
int timeout;
int method;
char *iface;
struct ev_loop *loop;
};
} listen_ctx_t;
struct server_ctx {
typedef struct server_ctx {
ev_io io;
ev_timer watcher;
int connected;
struct server *server;
};
} server_ctx_t;
struct server {
typedef struct server {
int fd;
int stage;
ssize_t buf_len;
ssize_t buf_idx;
char *buf; // server send from, remote recv into
buffer_t *buf;
int auth;
struct chunk *chunk;
@ -69,22 +67,20 @@ struct server {
struct ResolvQuery *query;
struct cork_dllist_item entries;
};
} server_t;
struct remote_ctx {
typedef struct remote_ctx {
ev_io io;
int connected;
struct remote *remote;
};
} remote_ctx_t;
struct remote {
typedef struct remote {
int fd;
ssize_t buf_len;
ssize_t buf_idx;
char *buf; // remote send from, server recv into
buffer_t *buf;
struct remote_ctx *recv_ctx;
struct remote_ctx *send_ctx;
struct server *server;
};
} remote_t;
#endif // _SERVER_H

206
src/tunnel.c

@ -77,13 +77,13 @@ static void server_send_cb(EV_P_ ev_io *w, int revents);
static void remote_recv_cb(EV_P_ ev_io *w, int revents);
static void remote_send_cb(EV_P_ ev_io *w, int revents);
static struct remote * new_remote(int fd, int timeout);
static struct server * new_server(int fd, int method);
static remote_t * new_remote(int fd, int timeout);
static server_t * new_server(int fd, int method);
static void free_remote(struct remote *remote);
static void close_and_free_remote(EV_P_ struct remote *remote);
static void free_server(struct server *server);
static void close_and_free_server(EV_P_ struct server *server);
static void free_remote(remote_t *remote);
static void close_and_free_remote(EV_P_ remote_t *remote);
static void free_server(server_t *server);
static void close_and_free_server(EV_P_ server_t *server);
#ifdef ANDROID
int vpn = 0;
@ -167,16 +167,16 @@ int create_and_bind(const char *addr, const char *port)
static void server_recv_cb(EV_P_ ev_io *w, int revents)
{
struct server_ctx *server_recv_ctx = (struct server_ctx *)w;
struct server *server = server_recv_ctx->server;
struct remote *remote = server->remote;
server_ctx_t *server_recv_ctx = (server_ctx_t *)w;
server_t *server = server_recv_ctx->server;
remote_t *remote = server->remote;
if (remote == NULL) {
close_and_free_server(EV_A_ server);
return;
}
ssize_t r = recv(server->fd, remote->buf, BUF_SIZE, 0);
ssize_t r = recv(server->fd, remote->buf->array, BUF_SIZE, 0);
if (r == 0) {
// connection closed
@ -196,26 +196,27 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
}
remote->buf->len = r;
if (auth) {
remote->buf = ss_gen_hash(remote->buf, &r, &remote->counter, server->e_ctx, BUF_SIZE);
ss_gen_hash(remote->buf, &remote->counter, server->e_ctx);
}
remote->buf = ss_encrypt(BUF_SIZE, remote->buf, &r, server->e_ctx);
int err = ss_encrypt(remote->buf, server->e_ctx);
if (remote->buf == NULL) {
if (err) {
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
}
int s = send(remote->fd, remote->buf, r, 0);
int s = send(remote->fd, remote->buf->array, remote->buf->len, 0);
if (s == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// no data, wait for send
remote->buf_len = r;
remote->buf_idx = 0;
remote->buf->idx = 0;
ev_io_stop(EV_A_ & server_recv_ctx->io);
ev_io_start(EV_A_ & remote->send_ctx->io);
return;
@ -225,9 +226,9 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
return;
}
} else if (s < r) {
remote->buf_len = r - s;
remote->buf_idx = s;
} else if (s < remote->buf->len) {
remote->buf->len -= s;
remote->buf->idx = s;
ev_io_stop(EV_A_ & server_recv_ctx->io);
ev_io_start(EV_A_ & remote->send_ctx->io);
return;
@ -236,18 +237,18 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
static void server_send_cb(EV_P_ ev_io *w, int revents)
{
struct server_ctx *server_send_ctx = (struct server_ctx *)w;
struct server *server = server_send_ctx->server;
struct remote *remote = server->remote;
if (server->buf_len == 0) {
server_ctx_t *server_send_ctx = (server_ctx_t *)w;
server_t *server = server_send_ctx->server;
remote_t *remote = server->remote;
if (server->buf->len == 0) {
// close and free
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
} else {
// has data to send
ssize_t s = send(server->fd, server->buf + server->buf_idx,
server->buf_len, 0);
ssize_t s = send(server->fd, server->buf->array + server->buf->idx,
server->buf->len, 0);
if (s < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ERROR("send");
@ -255,15 +256,15 @@ static void server_send_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
}
return;
} else if (s < server->buf_len) {
} else if (s < server->buf->len) {
// partly sent, move memory, wait for the next time to send
server->buf_len -= s;
server->buf_idx += s;
server->buf->len -= s;
server->buf->idx += s;
return;
} else {
// all sent out, wait for reading
server->buf_len = 0;
server->buf_idx = 0;
server->buf->len = 0;
server->buf->idx = 0;
ev_io_stop(EV_A_ & server_send_ctx->io);
if (remote != NULL) {
ev_io_start(EV_A_ & remote->recv_ctx->io);
@ -279,10 +280,10 @@ static void server_send_cb(EV_P_ ev_io *w, int revents)
static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents)
{
struct remote_ctx *remote_ctx = (struct remote_ctx *)(((void *)watcher)
remote_ctx_t *remote_ctx = (remote_ctx_t *)(((void *)watcher)
- sizeof(ev_io));
struct remote *remote = remote_ctx->remote;
struct server *server = remote->server;
remote_t *remote = remote_ctx->remote;
server_t *server = remote->server;
if (verbose) {
LOGI("TCP connection timeout");
@ -296,11 +297,11 @@ static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents)
static void remote_recv_cb(EV_P_ ev_io *w, int revents)
{
struct remote_ctx *remote_recv_ctx = (struct remote_ctx *)w;
struct remote *remote = remote_recv_ctx->remote;
struct server *server = remote->server;
remote_ctx_t *remote_recv_ctx = (remote_ctx_t *)w;
remote_t *remote = remote_recv_ctx->remote;
server_t *server = remote->server;
ssize_t r = recv(remote->fd, server->buf, BUF_SIZE, 0);
ssize_t r = recv(remote->fd, server->buf->array, BUF_SIZE, 0);
if (r == 0) {
// connection closed
@ -320,22 +321,23 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
}
}
server->buf = ss_decrypt(BUF_SIZE, server->buf, &r, server->d_ctx);
server->buf->len = r;
int err = ss_decrypt(server->buf, server->d_ctx);
if (server->buf == NULL) {
if (err) {
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
}
int s = send(server->fd, server->buf, r, 0);
int s = send(server->fd, server->buf->array, server->buf->len, 0);
if (s == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// no data, wait for send
server->buf_len = r;
server->buf_idx = 0;
server->buf->idx = 0;
ev_io_stop(EV_A_ & remote_recv_ctx->io);
ev_io_start(EV_A_ & server->send_ctx->io);
return;
@ -345,9 +347,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
return;
}
} else if (s < r) {
server->buf_len = r - s;
server->buf_idx = s;
} else if (s < server->buf->len) {
server->buf->len -= s;
server->buf->idx = s;
ev_io_stop(EV_A_ & remote_recv_ctx->io);
ev_io_start(EV_A_ & server->send_ctx->io);
return;
@ -356,9 +358,9 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
static void remote_send_cb(EV_P_ ev_io *w, int revents)
{
struct remote_ctx *remote_send_ctx = (struct remote_ctx *)w;
struct remote *remote = remote_send_ctx->remote;
struct server *server = remote->server;
remote_ctx_t *remote_send_ctx = (remote_ctx_t *)w;
remote_t *remote = remote_send_ctx->remote;
server_t *server = remote->server;
if (!remote_send_ctx->connected) {
struct sockaddr_storage addr;
@ -369,8 +371,10 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
remote_send_ctx->connected = 1;
ev_io_stop(EV_A_ & remote_send_ctx->io);
ev_timer_stop(EV_A_ & remote_send_ctx->watcher);
char *ss_addr_to_send = malloc(BUF_SIZE);
ssize_t addr_len = 0;
buffer_t ss_addr_to_send;
buffer_t *abuf = &ss_addr_to_send;
balloc(abuf, BUF_SIZE);
ss_addr_t *sa = &server->destaddr;
struct cork_ip ip;
@ -383,9 +387,9 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
if (dns_pton(AF_INET, sa->host, &host) == -1) {
FATAL("IP parser error");
}
ss_addr_to_send[addr_len++] = 1;
memcpy(ss_addr_to_send + addr_len, &host, host_len);
addr_len += host_len;
abuf->array[abuf->len++] = 1;
memcpy(abuf->array + abuf->len, &host, host_len);
abuf->len += host_len;
} else if (ip.version == 6) {
// send as IPv6
struct in6_addr host;
@ -394,9 +398,9 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
if (dns_pton(AF_INET6, sa->host, &host) == -1) {
FATAL("IP parser error");
}
ss_addr_to_send[addr_len++] = 4;
memcpy(ss_addr_to_send + addr_len, &host, host_len);
addr_len += host_len;
abuf->array[abuf->len++] = 4;
memcpy(abuf->array + abuf->len, &host, host_len);
abuf->len += host_len;
} else {
FATAL("IP parser error");
}
@ -404,35 +408,35 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
// send as domain
int host_len = strlen(sa->host);
ss_addr_to_send[addr_len++] = 3;
ss_addr_to_send[addr_len++] = host_len;
memcpy(ss_addr_to_send + addr_len, sa->host, host_len);
addr_len += host_len;
abuf->array[abuf->len++] = 3;
abuf->array[abuf->len++] = host_len;
memcpy(abuf->array + abuf->len, sa->host, host_len);
abuf->len += host_len;
}
uint16_t port = htons(atoi(sa->port));
memcpy(ss_addr_to_send + addr_len, &port, 2);
addr_len += 2;
memcpy(abuf->array + abuf->len, &port, 2);
abuf->len += 2;
if (auth) {
ss_addr_to_send[0] |= ONETIMEAUTH_FLAG;
ss_onetimeauth(ss_addr_to_send + addr_len, ss_addr_to_send, addr_len, server->e_ctx->evp.iv);
addr_len += ONETIMEAUTH_BYTES;
abuf->array[0] |= ONETIMEAUTH_FLAG;
ss_onetimeauth(abuf, server->e_ctx->evp.iv);
}
ss_addr_to_send = ss_encrypt(BUF_SIZE, ss_addr_to_send, &addr_len,
server->e_ctx);
if (ss_addr_to_send == NULL) {
int err = ss_encrypt(abuf, server->e_ctx);
if (err) {
bfree(abuf);
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
}
int s = send(remote->fd, ss_addr_to_send, addr_len, 0);
free(ss_addr_to_send);
int s = send(remote->fd, abuf->array, abuf->len, 0);
bfree(abuf);
if (s < addr_len) {
if (s < abuf->len) {
LOGE("failed to send addr");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
@ -451,15 +455,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
return;
}
} else {
if (remote->buf_len == 0) {
if (remote->buf->len == 0) {
// close and free
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
} else {
// has data to send
ssize_t s = send(remote->fd, remote->buf + remote->buf_idx,
remote->buf_len, 0);
ssize_t s = send(remote->fd, remote->buf->array + remote->buf->idx,
remote->buf->len, 0);
if (s < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
ERROR("send");
@ -468,15 +472,15 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
close_and_free_server(EV_A_ server);
}
return;
} else if (s < remote->buf_len) {
} else if (s < remote->buf->len) {
// partly sent, move memory, wait for the next time to send
remote->buf_len -= s;
remote->buf_idx += s;
remote->buf->len -= s;
remote->buf->idx += s;
return;
} else {
// all sent out, wait for reading
remote->buf_len = 0;
remote->buf_idx = 0;
remote->buf->len = 0;
remote->buf->idx = 0;
ev_io_stop(EV_A_ & remote_send_ctx->io);
ev_io_start(EV_A_ & server->recv_ctx->io);
}
@ -485,16 +489,16 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
}
}
static struct remote * new_remote(int fd, int timeout)
static remote_t * new_remote(int fd, int timeout)
{
struct remote *remote;
remote = malloc(sizeof(struct remote));
remote_t *remote;
remote = malloc(sizeof(remote_t));
memset(remote, 0, sizeof(struct remote));
memset(remote, 0, sizeof(remote_t));
remote->buf = malloc(BUF_SIZE);
remote->recv_ctx = malloc(sizeof(struct remote_ctx));
remote->send_ctx = malloc(sizeof(struct remote_ctx));
remote->recv_ctx = malloc(sizeof(remote_ctx_t));
remote->send_ctx = malloc(sizeof(remote_ctx_t));
remote->fd = fd;
ev_io_init(&remote->recv_ctx->io, remote_recv_cb, fd, EV_READ);
ev_io_init(&remote->send_ctx->io, remote_send_cb, fd, EV_WRITE);
@ -504,18 +508,19 @@ static struct remote * new_remote(int fd, int timeout)
remote->recv_ctx->connected = 0;
remote->send_ctx->remote = remote;
remote->send_ctx->connected = 0;
remote->buf_len = 0;
remote->buf_idx = 0;
remote->buf->len = 0;
remote->buf->idx = 0;
return remote;
}
static void free_remote(struct remote *remote)
static void free_remote(remote_t *remote)
{
if (remote != NULL) {
if (remote->server != NULL) {
remote->server->remote = NULL;
}
if (remote->buf) {
bfree(remote->buf);
free(remote->buf);
}
free(remote->recv_ctx);
@ -524,7 +529,7 @@ static void free_remote(struct remote *remote)
}
}
static void close_and_free_remote(EV_P_ struct remote *remote)
static void close_and_free_remote(EV_P_ remote_t *remote)
{
if (remote != NULL) {
ev_timer_stop(EV_A_ & remote->send_ctx->watcher);
@ -535,13 +540,13 @@ static void close_and_free_remote(EV_P_ struct remote *remote)
}
}
static struct server * new_server(int fd, int method)
static server_t * new_server(int fd, int method)
{
struct server *server;
server = malloc(sizeof(struct server));
server_t *server;
server = malloc(sizeof(server_t));
server->buf = malloc(BUF_SIZE);
server->recv_ctx = malloc(sizeof(struct server_ctx));
server->send_ctx = malloc(sizeof(struct server_ctx));
server->recv_ctx = malloc(sizeof(server_ctx_t));
server->send_ctx = malloc(sizeof(server_ctx_t));
server->fd = fd;
ev_io_init(&server->recv_ctx->io, server_recv_cb, fd, EV_READ);
ev_io_init(&server->send_ctx->io, server_send_cb, fd, EV_WRITE);
@ -558,12 +563,12 @@ static struct server * new_server(int fd, int method)
server->e_ctx = NULL;
server->d_ctx = NULL;
}
server->buf_len = 0;
server->buf_idx = 0;
server->buf->len = 0;
server->buf->idx = 0;
return server;
}
static void free_server(struct server *server)
static void free_server(server_t *server)
{
if (server != NULL) {
if (server->remote != NULL) {
@ -578,6 +583,7 @@ static void free_server(struct server *server)
free(server->d_ctx);
}
if (server->buf) {
bfree(server->buf);
free(server->buf);
}
free(server->recv_ctx);
@ -586,7 +592,7 @@ static void free_server(struct server *server)
}
}
static void close_and_free_server(EV_P_ struct server *server)
static void close_and_free_server(EV_P_ server_t *server)
{
if (server != NULL) {
ev_io_stop(EV_A_ & server->send_ctx->io);
@ -643,8 +649,8 @@ static void accept_cb(EV_P_ ev_io *w, int revents)
}
#endif
struct server *server = new_server(serverfd, listener->method);
struct remote *remote = new_remote(remotefd, listener->timeout);
server_t *server = new_server(serverfd, listener->method);
remote_t *remote = new_remote(remotefd, listener->timeout);
server->destaddr = listener->tunnel_addr;
server->remote = remote;
remote->server = server;

28
src/tunnel.h

@ -29,7 +29,7 @@
#include "common.h"
struct listen_ctx {
typedef struct listen_ctx {
ev_io io;
ss_addr_t tunnel_addr;
char *iface;
@ -38,43 +38,39 @@ struct listen_ctx {
int timeout;
int fd;
struct sockaddr **remote_addr;
};
} listen_ctx_t;
struct server_ctx {
typedef struct server_ctx {
ev_io io;
int connected;
struct server *server;
};
} server_ctx_t;
struct server {
typedef struct server {
int fd;
ssize_t buf_len;
ssize_t buf_idx;
char *buf; // server send from, remote recv into
buffer_t *buf;
struct enc_ctx *e_ctx;
struct enc_ctx *d_ctx;
struct server_ctx *recv_ctx;
struct server_ctx *send_ctx;
struct remote *remote;
ss_addr_t destaddr;
};
} server_t;
struct remote_ctx {
typedef struct remote_ctx {
ev_io io;
ev_timer watcher;
int connected;
struct remote *remote;
};
} remote_ctx_t;
struct remote {
typedef struct remote {
int fd;
ssize_t buf_len;
ssize_t buf_idx;
char *buf; // remote send from, server recv into
buffer_t *buf;
struct remote_ctx *recv_ctx;
struct remote_ctx *send_ctx;
struct server *server;
uint32_t counter;
};
} remote_t;
#endif // _TUNNEL_H

181
src/udprelay.c

@ -91,8 +91,8 @@ static char *hash_key(const int af, const struct sockaddr_storage *addr);
#ifdef UDPRELAY_REMOTE
static void query_resolve_cb(struct sockaddr *addr, void *data);
#endif
static void close_and_free_remote(EV_P_ struct remote_ctx *ctx);
static struct remote_ctx * new_remote(int fd, struct server_ctx * server_ctx);
static void close_and_free_remote(EV_P_ remote_ctx_t *ctx);
static remote_ctx_t * new_remote(int fd, server_ctx_t * server_ctx);
extern int verbose;
extern int vpn;
@ -102,7 +102,7 @@ extern uint64_t rx;
#endif
static int server_num = 0;
static struct server_ctx *server_ctx_list[MAX_REMOTE_NUM] = { NULL };
static server_ctx_t *server_ctx_list[MAX_REMOTE_NUM] = { NULL };
#ifndef __MINGW32__
static int setnonblocking(int fd)
@ -467,10 +467,10 @@ int create_server_socket(const char *host, const char *port)
return server_sock;
}
struct remote_ctx *new_remote(int fd, struct server_ctx *server_ctx)
remote_ctx_t *new_remote(int fd, server_ctx_t *server_ctx)
{
struct remote_ctx *ctx = malloc(sizeof(struct remote_ctx));
memset(ctx, 0, sizeof(struct remote_ctx));
remote_ctx_t *ctx = malloc(sizeof(remote_ctx_t));
memset(ctx, 0, sizeof(remote_ctx_t));
ctx->fd = fd;
ctx->server_ctx = server_ctx;
@ -480,23 +480,23 @@ struct remote_ctx *new_remote(int fd, struct server_ctx *server_ctx)
return ctx;
}
struct server_ctx * new_server_ctx(int fd)
server_ctx_t * new_server_ctx(int fd)
{
struct server_ctx *ctx = malloc(sizeof(struct server_ctx));
memset(ctx, 0, sizeof(struct server_ctx));
server_ctx_t *ctx = malloc(sizeof(server_ctx_t));
memset(ctx, 0, sizeof(server_ctx_t));
ctx->fd = fd;
ev_io_init(&ctx->io, server_recv_cb, fd, EV_READ);
return ctx;
}
#ifdef UDPRELAY_REMOTE
struct query_ctx *new_query_ctx(const char *buf, const int buf_len)
struct query_ctx *new_query_ctx(char *buf, size_t len)
{
struct query_ctx *ctx = malloc(sizeof(struct query_ctx));
memset(ctx, 0, sizeof(struct query_ctx));
ctx->buf = malloc(buf_len);
ctx->buf_len = buf_len;
memcpy(ctx->buf, buf, buf_len);
ctx->buf = malloc(sizeof(buffer_t));
balloc(ctx->buf, len);
memcpy(ctx->buf->array, buf, len);
return ctx;
}
@ -508,6 +508,7 @@ void close_and_free_query(EV_P_ struct query_ctx *ctx)
ctx->query = NULL;
}
if (ctx->buf != NULL) {
bfree(ctx->buf);
free(ctx->buf);
}
free(ctx);
@ -516,7 +517,7 @@ void close_and_free_query(EV_P_ struct query_ctx *ctx)
#endif
void close_and_free_remote(EV_P_ struct remote_ctx *ctx)
void close_and_free_remote(EV_P_ remote_ctx_t *ctx)
{
if (ctx != NULL) {
ev_timer_stop(EV_A_ & ctx->watcher);
@ -528,7 +529,7 @@ void close_and_free_remote(EV_P_ struct remote_ctx *ctx)
static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents)
{
struct remote_ctx *remote_ctx = (struct remote_ctx *)(((void *)watcher)
remote_ctx_t *remote_ctx = (remote_ctx_t *)(((void *)watcher)
- sizeof(ev_io));
if (verbose) {
@ -554,7 +555,7 @@ static void query_resolve_cb(struct sockaddr *addr, void *data)
if (addr == NULL) {
LOGE("[udp] udns returned an error");
} else {
struct remote_ctx *remote_ctx = query_ctx->remote_ctx;
remote_ctx_t *remote_ctx = query_ctx->remote_ctx;
int cache_hit = 0;
// Lookup in the conn cache
@ -593,7 +594,7 @@ static void query_resolve_cb(struct sockaddr *addr, void *data)
if (remote_ctx != NULL) {
size_t addr_len = get_sockaddr_len(addr);
int s = sendto(remote_ctx->fd, query_ctx->buf, query_ctx->buf_len,
int s = sendto(remote_ctx->fd, query_ctx->buf->array, query_ctx->buf->len,
0, addr, addr_len);
if (s == -1) {
@ -621,8 +622,8 @@ static void query_resolve_cb(struct sockaddr *addr, void *data)
static void remote_recv_cb(EV_P_ ev_io *w, int revents)
{
struct remote_ctx *remote_ctx = (struct remote_ctx *)w;
struct server_ctx *server_ctx = remote_ctx->server_ctx;
remote_ctx_t *remote_ctx = (remote_ctx_t *)w;
server_ctx_t *server_ctx = remote_ctx->server_ctx;
// server has been closed
if (server_ctx == NULL) {
@ -638,12 +639,14 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
struct sockaddr_storage src_addr;
socklen_t src_addr_len = sizeof(src_addr);
memset(&src_addr, 0, src_addr_len);
char *buf = malloc(BUF_SIZE);
buffer_t *buf = malloc(sizeof(buffer_t));
balloc(buf, BUF_SIZE);
// recv
ssize_t buf_len = recvfrom(remote_ctx->fd, buf, BUF_SIZE, 0, (struct sockaddr *)&src_addr, &src_addr_len);
buf->len = recvfrom(remote_ctx->fd, buf->array, BUF_SIZE, 0, (struct sockaddr *)&src_addr, &src_addr_len);
if (buf_len == -1) {
if (buf->len == -1) {
// error on recv
// simply drop that packet
ERROR("[udp] remote_recvfrom");
@ -651,13 +654,13 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
}
// packet size > default MTU
if (verbose && buf_len > MTU) {
LOGE("[udp] possible ip fragment, size: %d", (int)buf_len);
if (verbose && buf->len > MTU) {
LOGE("[udp] possible ip fragment, size: %d", (int)buf->len);
}
#ifdef UDPRELAY_LOCAL
buf = ss_decrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, 0);
if (buf == NULL) {
int err = ss_decrypt_all(buf, server_ctx->method, 0);
if (err) {
// drop the packet silently
goto CLEAN_UP;
}
@ -665,14 +668,14 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
#ifdef UDPRELAY_REDIR
struct sockaddr_storage dst_addr;
memset(&dst_addr, 0, sizeof(struct sockaddr_storage));
int len = parse_udprealy_header(buf, buf_len, NULL, NULL, NULL, &dst_addr);
int len = parse_udprealy_header(buf->array, buf->len, NULL, NULL, NULL, &dst_addr);
if (dst_addr.ss_family != AF_INET && dst_addr.ss_family != AF_INET6) {
LOGI("[udp] ss-redir does not support domain name");
goto CLEAN_UP;
}
#else
int len = parse_udprealy_header(buf, buf_len, NULL, NULL, NULL, NULL);
int len = parse_udprealy_header(buf->array, buf->len, NULL, NULL, NULL, NULL);
#endif
if (len == 0) {
@ -686,22 +689,20 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
#if defined(UDPRELAY_TUNNEL) || defined(UDPRELAY_REDIR)
// Construct packet
buf_len -= len;
memmove(buf, buf + len, buf_len);
buf->len -= len;
memmove(buf->array, buf->array + len, buf->len);
#else
// Construct packet
if (BUF_SIZE < buf_len + 3) {
buf = realloc(buf, buf_len + 3);
}
memmove(buf + 3, buf, buf_len);
memset(buf, 0, 3);
buf_len += 3;
brealloc(buf, buf->len + 3, BUF_SIZE);
memmove(buf->array + 3, buf->array, buf->len);
memset(buf->array, 0, 3);
buf->len += 3;
#endif
#endif
#ifdef UDPRELAY_REMOTE
rx += buf_len;
rx += buf->len;
char addr_header_buf[256];
char *addr_header = remote_ctx->addr_header;
@ -713,14 +714,16 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
}
// Construct packet
if (BUF_SIZE < buf_len + addr_header_len) {
buf = realloc(buf, buf_len + addr_header_len);
}
memmove(buf + addr_header_len, buf, buf_len);
memcpy(buf, addr_header, addr_header_len);
buf_len += addr_header_len;
brealloc(buf, buf->len + addr_header_len, BUF_SIZE);
memmove(buf->array + addr_header_len, buf->array, buf->len);
memcpy(buf->array, addr_header, addr_header_len);
buf->len += addr_header_len;
buf = ss_encrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, 0);
int err = ss_encrypt_all(buf, server_ctx->method, 0);
if (err) {
// drop the packet silently
goto CLEAN_UP;
}
#endif
size_t remote_src_addr_len = get_sockaddr_len((struct sockaddr *)&remote_ctx->src_addr);
@ -750,7 +753,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
goto CLEAN_UP;
}
int s = sendto(src_fd, buf, buf_len, 0,
int s = sendto(src_fd, buf->array, buf->len, 0,
(struct sockaddr *)&remote_ctx->src_addr, remote_src_addr_len);
if (s == -1) {
ERROR("[udp] remote_recv_sendto");
@ -760,7 +763,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
close(src_fd);
#else
int s = sendto(server_ctx->fd, buf, buf_len, 0,
int s = sendto(server_ctx->fd, buf->array, buf->len, 0,
(struct sockaddr *)&remote_ctx->src_addr, remote_src_addr_len);
if (s == -1) {
ERROR("[udp] remote_recv_sendto");
@ -774,16 +777,18 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
CLEAN_UP:
bfree(buf);
free(buf);
}
static void server_recv_cb(EV_P_ ev_io *w, int revents)
{
struct server_ctx *server_ctx = (struct server_ctx *)w;
server_ctx_t *server_ctx = (server_ctx_t *)w;
struct sockaddr_storage src_addr;
memset(&src_addr, 0, sizeof(struct sockaddr_storage));
char *buf = malloc(BUF_SIZE);
buffer_t *buf = malloc(sizeof(buffer_t));
balloc(buf, BUF_SIZE);
socklen_t src_addr_len = sizeof(struct sockaddr_storage);
unsigned int offset = 0;
@ -800,7 +805,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
msg.msg_control = control_buffer;
msg.msg_controllen = sizeof(control_buffer);
iov[0].iov_base = buf;
iov[0].iov_base = buf->array;
iov[0].iov_len = BUF_SIZE;
msg.msg_iov = iov;
msg.msg_iovlen = 1;
@ -818,11 +823,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
src_addr_len = msg.msg_namelen;
#else
ssize_t buf_len =
recvfrom(server_ctx->fd, buf, BUF_SIZE, 0, (struct sockaddr *)&src_addr,
&src_addr_len);
buf->len = recvfrom(server_ctx->fd, buf->array, BUF_SIZE,
0, (struct sockaddr *)&src_addr, &src_addr_len);
if (buf_len == -1) {
if (buf->len == -1) {
// error on recv
// simply drop that packet
ERROR("[udp] server_recvfrom");
@ -836,10 +840,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
#ifdef UDPRELAY_REMOTE
tx += buf_len;
tx += buf->len;
buf = ss_decrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, server_ctx->auth);
if (buf == NULL) {
int err = ss_decrypt_all(buf, server_ctx->method, server_ctx->auth);
if (err) {
// drop the packet silently
goto CLEAN_UP;
}
@ -847,14 +851,14 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
#ifdef UDPRELAY_LOCAL
#if !defined(UDPRELAY_TUNNEL) && !defined(UDPRELAY_REDIR)
uint8_t frag = *(uint8_t *)(buf + 2);
uint8_t frag = *(uint8_t *)(buf->array + 2);
offset += 3;
#endif
#endif
// packet size > default MTU
if (verbose && buf_len > MTU) {
LOGE("[udp] possible ip fragment, size: %d", (int)buf_len);
if (verbose && buf->len > MTU) {
LOGE("[udp] possible ip fragment, size: %d", (int)buf->len);
}
/*
@ -911,12 +915,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
// reconstruct the buffer
if (BUF_SIZE < buf_len + addr_header_len) {
buf = realloc(buf, buf_len + addr_header_len);
}
memmove(buf + addr_header_len, buf, buf_len);
memcpy(buf, addr_header, addr_header_len);
buf_len += addr_header_len;
brealloc(buf, buf->len + addr_header_len, BUF_SIZE);
memmove(buf->array + addr_header_len, buf->array, buf->len);
memcpy(buf->array, addr_header, addr_header_len);
buf->len += addr_header_len;
char *key = hash_key(dst_addr.ss_family, &src_addr);
@ -969,12 +971,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
addr_header_len += 2;
// reconstruct the buffer
if (BUF_SIZE < buf_len + addr_header_len) {
buf = realloc(buf, buf_len + addr_header_len);
}
memmove(buf + addr_header_len, buf, buf_len);
memcpy(buf, addr_header, addr_header_len);
buf_len += addr_header_len;
brealloc(buf, buf->len + addr_header_len, BUF_SIZE);
memmove(buf->array + addr_header_len, buf->array, buf->len);
memcpy(buf->array, addr_header, addr_header_len);
buf->len += addr_header_len;
char *key = hash_key(ip.version == 4 ? AF_INET : AF_INET6, &src_addr);
@ -985,21 +985,21 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
struct sockaddr_storage dst_addr;
memset(&dst_addr, 0, sizeof(struct sockaddr_storage));
int addr_header_len = parse_udprealy_header(buf + offset, buf_len - offset,
int addr_header_len = parse_udprealy_header(buf->array + offset, buf->len - offset,
&server_ctx->auth, host, port,
&dst_addr);
if (addr_header_len == 0) {
// error in parse header
goto CLEAN_UP;
}
char *addr_header = buf + offset;
char *addr_header = buf->array + offset;
char *key = hash_key(dst_addr.ss_family, &src_addr);
#endif
struct cache *conn_cache = server_ctx->conn_cache;
struct remote_ctx *remote_ctx = NULL;
remote_ctx_t *remote_ctx = NULL;
cache_lookup(conn_cache, key, HASH_KEY_LEN, (void *)&remote_ctx);
if (remote_ctx != NULL) {
@ -1098,17 +1098,22 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
if (offset > 0) {
buf_len -= offset;
memmove(buf, buf + offset, buf_len);
buf->len -= offset;
memmove(buf->array, buf->array + offset, buf->len);
}
if (server_ctx->auth) {
buf[0] |= ONETIMEAUTH_FLAG;
buf->array[0] |= ONETIMEAUTH_FLAG;
}
buf = ss_encrypt_all(BUF_SIZE, buf, &buf_len, server_ctx->method, server_ctx->auth);
int err = ss_encrypt_all(buf, server_ctx->method, server_ctx->auth);
if (err) {
// drop the packet silently
goto CLEAN_UP;
}
int s = sendto(remote_ctx->fd, buf, buf_len, 0, remote_addr, remote_addr_len);
int s = sendto(remote_ctx->fd, buf->array, buf->len, 0, remote_addr, remote_addr_len);
if (s == -1) {
ERROR("[udp] sendto_remote");
@ -1158,8 +1163,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
if (remote_ctx != NULL && !need_query) {
size_t addr_len = get_sockaddr_len((struct sockaddr *)&dst_addr);
int s = sendto(remote_ctx->fd, buf + addr_header_len,
buf_len - addr_header_len, 0,
int s = sendto(remote_ctx->fd, buf->array + addr_header_len,
buf->len - addr_header_len, 0,
(struct sockaddr *)&dst_addr, addr_len);
if (s == -1) {
@ -1185,9 +1190,8 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
hints.ai_socktype = SOCK_DGRAM;
hints.ai_protocol = IPPROTO_UDP;
struct query_ctx *query_ctx = new_query_ctx(buf + addr_header_len,
buf_len -
addr_header_len);
struct query_ctx *query_ctx = new_query_ctx(buf->array + addr_header_len,
buf->len - addr_header_len);
query_ctx->server_ctx = server_ctx;
query_ctx->addr_header_len = addr_header_len;
query_ctx->src_addr = src_addr;
@ -1209,12 +1213,13 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
#endif
CLEAN_UP:
bfree(buf);
free(buf);
}
void free_cb(void *element)
{
struct remote_ctx *remote_ctx = (struct remote_ctx *)element;
remote_ctx_t *remote_ctx = (remote_ctx_t *)element;
if (verbose) {
LOGI("[udp] one connection freed");
@ -1249,7 +1254,7 @@ int init_udprelay(const char *server_host, const char *server_port,
}
setnonblocking(serverfd);
struct server_ctx *server_ctx = new_server_ctx(serverfd);
server_ctx_t *server_ctx = new_server_ctx(serverfd);
#ifdef UDPRELAY_REMOTE
server_ctx->loop = loop;
#endif
@ -1277,7 +1282,7 @@ void free_udprelay()
{
struct ev_loop *loop = EV_DEFAULT;
while (server_num-- > 0) {
struct server_ctx *server_ctx = server_ctx_list[server_num];
server_ctx_t *server_ctx = server_ctx_list[server_num];
ev_io_stop(loop, &server_ctx->io);
close(server_ctx->fd);
cache_delete(server_ctx->conn_cache, 0);

25
src/udprelay.h

@ -41,7 +41,7 @@
#define MTU 1397 // 1492 - 1 - 28 - 2 - 64 = 1397, the default MTU for UDP relay
struct server_ctx {
typedef struct server_ctx {
ev_io io;
int fd;
int method;
@ -59,22 +59,21 @@ struct server_ctx {
#ifdef UDPRELAY_REMOTE
struct ev_loop *loop;
#endif
};
} server_ctx_t;
#ifdef UDPRELAY_REMOTE
struct query_ctx {
typedef struct query_ctx {
struct ResolvQuery *query;
struct sockaddr_storage src_addr;
int buf_len;
char *buf; // server send from, remote recv into
buffer_t *buf;
int addr_header_len;
char addr_header[384];
struct server_ctx *server_ctx;
struct remote_ctx *remote_ctx;
};
} query_ctx_t;
#endif
struct remote_ctx {
typedef struct remote_ctx {
ev_io io;
ev_timer watcher;
int af;
@ -83,16 +82,6 @@ struct remote_ctx {
char addr_header[384];
struct sockaddr_storage src_addr;
struct server_ctx *server_ctx;
};
#ifdef ANDROID
struct protect_ctx {
int buf_len;
char *buf;
struct sockaddr_storage addr;
int addr_len;
struct remote_ctx *remote_ctx;
};
#endif
} remote_ctx_t;
#endif // _UDPRELAY_H
Loading…
Cancel
Save