Browse Source

Refine buffer allocation (2)

pull/481/head
Max Lv 9 years ago
parent
commit
d2ef245a05
7 changed files with 57 additions and 57 deletions
  1. 50
      src/encrypt.c
  2. 14
      src/encrypt.h
  3. 12
      src/local.c
  4. 10
      src/redir.c
  5. 10
      src/server.c
  6. 10
      src/tunnel.c
  7. 8
      src/udprelay.c

50
src/encrypt.c

@ -1058,14 +1058,14 @@ static int cipher_context_update(cipher_ctx_t *ctx, uint8_t *output, size_t *ole
#endif
}
int ss_onetimeauth(buffer_t *buf, uint8_t *iv)
int ss_onetimeauth(buffer_t *buf, uint8_t *iv, size_t capacity)
{
uint8_t hash[ONETIMEAUTH_BYTES * 2];
uint8_t auth_key[MAX_IV_LENGTH + MAX_KEY_LENGTH];
memcpy(auth_key, iv, enc_iv_len);
memcpy(auth_key + enc_iv_len, enc_key, enc_key_len);
brealloc(buf, ONETIMEAUTH_BYTES + buf->len, buf->capacity);
brealloc(buf, ONETIMEAUTH_BYTES + buf->len, capacity);
#if defined(USE_CRYPTO_OPENSSL)
HMAC(EVP_sha1(), auth_key, enc_iv_len + enc_key_len, (uint8_t *)buf->array, buf->len, (uint8_t *)hash, NULL);
@ -1100,7 +1100,7 @@ int ss_onetimeauth_verify(buffer_t *buf, uint8_t *iv)
return safe_memcmp(buf->array + len, hash, ONETIMEAUTH_BYTES);
}
int ss_encrypt_all(buffer_t *plain, int method, int auth)
int ss_encrypt_all(buffer_t *plain, int method, int auth, size_t capacity)
{
if (method > TABLE) {
cipher_ctx_t evp;
@ -1110,7 +1110,7 @@ int ss_encrypt_all(buffer_t *plain, int method, int auth)
int err = 1;
static buffer_t tmp = { 0 };
brealloc(&tmp, iv_len + plain->len, plain->capacity);
brealloc(&tmp, iv_len + plain->len, capacity);
buffer_t *cipher = &tmp;
cipher->len = plain->len;
@ -1121,7 +1121,7 @@ int ss_encrypt_all(buffer_t *plain, int method, int auth)
memcpy(cipher->array, iv, iv_len);
if (auth) {
ss_onetimeauth(plain, iv);
ss_onetimeauth(plain, iv, capacity);
cipher->len = plain->len;
}
@ -1149,7 +1149,7 @@ int ss_encrypt_all(buffer_t *plain, int method, int auth)
cipher_context_release(&evp);
brealloc(plain, iv_len + cipher->len, plain->capacity);
brealloc(plain, iv_len + cipher->len, capacity);
memcpy(plain->array, cipher->array, iv_len + cipher->len);
plain->len = iv_len + cipher->len;
@ -1165,7 +1165,7 @@ int ss_encrypt_all(buffer_t *plain, int method, int auth)
}
}
int ss_encrypt(buffer_t *plain, enc_ctx_t *ctx)
int ss_encrypt(buffer_t *plain, enc_ctx_t *ctx, size_t capacity)
{
if (ctx != NULL) {
static buffer_t tmp = { 0 };
@ -1176,7 +1176,7 @@ int ss_encrypt(buffer_t *plain, enc_ctx_t *ctx)
iv_len = enc_iv_len;
}
brealloc(&tmp, iv_len + plain->len, plain->capacity);
brealloc(&tmp, iv_len + plain->len, capacity);
buffer_t *cipher = &tmp;
cipher->len = plain->len;
@ -1189,9 +1189,9 @@ int ss_encrypt(buffer_t *plain, enc_ctx_t *ctx)
if (enc_method >= SALSA20) {
int padding = ctx->counter % SODIUM_BLOCK_SIZE;
brealloc(cipher, iv_len + (padding + cipher->len) * 2, cipher->capacity);
brealloc(cipher, iv_len + (padding + cipher->len) * 2, capacity);
if (padding) {
brealloc(plain, plain->len + padding, plain->capacity);
brealloc(plain, plain->len + padding, capacity);
memmove(plain->array + padding, plain->array, plain->len);
memset(plain->array, 0, padding);
}
@ -1222,7 +1222,7 @@ int ss_encrypt(buffer_t *plain, enc_ctx_t *ctx)
dump("CIPHER", cipher->array + iv_len, cipher->len);
#endif
brealloc(plain, iv_len + cipher->len, plain->capacity);
brealloc(plain, iv_len + cipher->len, capacity);
memcpy(plain->array, cipher->array, iv_len + cipher->len);
plain->len = iv_len + cipher->len;
@ -1238,7 +1238,7 @@ int ss_encrypt(buffer_t *plain, enc_ctx_t *ctx)
}
}
int ss_decrypt_all(buffer_t *cipher, int method, int auth)
int ss_decrypt_all(buffer_t *cipher, int method, int auth, size_t capacity)
{
if (method > TABLE) {
size_t iv_len = enc_iv_len;
@ -1252,7 +1252,7 @@ int ss_decrypt_all(buffer_t *cipher, int method, int auth)
cipher_context_init(&evp, method, 0);
static buffer_t tmp = { 0 };
brealloc(&tmp, cipher->len, cipher->capacity);
brealloc(&tmp, cipher->len, capacity);
buffer_t *plain = &tmp;
plain->len = cipher->len - iv_len;
@ -1295,7 +1295,7 @@ int ss_decrypt_all(buffer_t *cipher, int method, int auth)
cipher_context_release(&evp);
brealloc(cipher, plain->len, plain->capacity);
brealloc(cipher, plain->len, capacity);
memcpy(cipher->array, plain->array, plain->len);
cipher->len = plain->len;
@ -1311,7 +1311,7 @@ int ss_decrypt_all(buffer_t *cipher, int method, int auth)
}
}
int ss_decrypt(buffer_t *cipher, enc_ctx_t *ctx)
int ss_decrypt(buffer_t *cipher, enc_ctx_t *ctx, size_t capacity)
{
if (ctx != NULL) {
static buffer_t tmp = { 0 };
@ -1319,7 +1319,7 @@ int ss_decrypt(buffer_t *cipher, enc_ctx_t *ctx)
size_t iv_len = 0;
int err = 1;
brealloc(&tmp, cipher->len, cipher->capacity);
brealloc(&tmp, cipher->len, capacity);
buffer_t *plain = &tmp;
plain->len = cipher->len;
@ -1345,10 +1345,10 @@ int ss_decrypt(buffer_t *cipher, enc_ctx_t *ctx)
if (enc_method >= SALSA20) {
int padding = ctx->counter % SODIUM_BLOCK_SIZE;
brealloc(plain, (plain->len + padding) * 2, plain->capacity);
brealloc(plain, (plain->len + padding) * 2, capacity);
if (padding) {
brealloc(cipher, cipher->len + padding, cipher->capacity);
brealloc(cipher, cipher->len + padding, capacity);
memmove(cipher->array + iv_len + padding, cipher->array + iv_len,
cipher->len - iv_len);
memset(cipher->array + iv_len, 0, padding);
@ -1379,7 +1379,7 @@ int ss_decrypt(buffer_t *cipher, enc_ctx_t *ctx)
dump("CIPHER", cipher->array + iv_len, cipher->len - iv_len);
#endif
brealloc(cipher, plain->len, cipher->capacity);
brealloc(cipher, plain->len, capacity);
memcpy(cipher->array, plain->array, plain->len);
cipher->len = plain->len;
@ -1514,21 +1514,21 @@ int enc_init(const char *pass, const char *method)
return m;
}
int ss_check_hash(buffer_t *buf, chunk_t *chunk, enc_ctx_t *ctx)
int ss_check_hash(buffer_t *buf, chunk_t *chunk, enc_ctx_t *ctx, size_t capacity)
{
int i, j, k;
ssize_t blen = buf->len;
uint32_t cidx = chunk->idx;
brealloc(chunk->buf, chunk->len + blen, buf->capacity);
brealloc(buf, chunk->len + blen, buf->capacity);
brealloc(chunk->buf, chunk->len + blen, capacity);
brealloc(buf, chunk->len + blen, capacity);
for (i = 0, j = 0, k = 0; i < blen; i++) {
chunk->buf->array[cidx++] = buf->array[k++];
if (cidx == CLEN_BYTES) {
uint16_t clen = ntohs(*((uint16_t *)chunk->buf->array));
brealloc(chunk->buf, clen + AUTH_BYTES, buf->capacity);
brealloc(chunk->buf, clen + AUTH_BYTES, capacity);
chunk->len = clen;
}
@ -1572,7 +1572,7 @@ int ss_check_hash(buffer_t *buf, chunk_t *chunk, enc_ctx_t *ctx)
return 1;
}
int ss_gen_hash(buffer_t *buf, uint32_t *counter, enc_ctx_t *ctx)
int ss_gen_hash(buffer_t *buf, uint32_t *counter, enc_ctx_t *ctx, size_t capacity)
{
ssize_t blen = buf->len;
uint16_t chunk_len = htons((uint16_t)blen);
@ -1580,7 +1580,7 @@ int ss_gen_hash(buffer_t *buf, uint32_t *counter, enc_ctx_t *ctx)
uint8_t key[MAX_IV_LENGTH + sizeof(uint32_t)];
uint32_t c = htonl(*counter);
brealloc(buf, AUTH_BYTES + blen, buf->capacity);
brealloc(buf, AUTH_BYTES + blen, capacity);
memcpy(key, ctx->evp.iv, enc_iv_len);
memcpy(key + enc_iv_len, &c, sizeof(uint32_t));
#if defined(USE_CRYPTO_OPENSSL)

14
src/encrypt.h

@ -170,10 +170,10 @@ typedef struct enc_ctx {
cipher_ctx_t evp;
} enc_ctx_t;
int ss_encrypt_all(buffer_t *plaintext, int method, int auth);
int ss_decrypt_all(buffer_t *ciphertext, int method, int auth);
int ss_encrypt(buffer_t *plaintext, enc_ctx_t *ctx);
int ss_decrypt(buffer_t *ciphertext, enc_ctx_t *ctx);
int ss_encrypt_all(buffer_t *plaintext, int method, int auth, size_t capacity);
int ss_decrypt_all(buffer_t *ciphertext, int method, int auth, size_t capacity);
int ss_encrypt(buffer_t *plaintext, enc_ctx_t *ctx, size_t capacity);
int ss_decrypt(buffer_t *ciphertext, enc_ctx_t *ctx, size_t capacity);
void enc_ctx_init(int method, enc_ctx_t *ctx, int enc);
int enc_init(const char *pass, const char *method);
@ -181,11 +181,11 @@ int enc_get_iv_len(void);
void cipher_context_release(cipher_ctx_t *evp);
unsigned char *enc_md5(const unsigned char *d, size_t n, unsigned char *md);
int ss_onetimeauth(buffer_t *buf, uint8_t *iv);
int ss_onetimeauth(buffer_t *buf, uint8_t *iv, size_t capacity);
int ss_onetimeauth_verify(buffer_t *buf, uint8_t *iv);
int ss_check_hash(buffer_t *buf, chunk_t *chunk, enc_ctx_t *ctx);
int ss_gen_hash(buffer_t *buf, uint32_t *counter, enc_ctx_t *ctx);
int ss_check_hash(buffer_t *buf, chunk_t *chunk, enc_ctx_t *ctx, size_t capacity);
int ss_gen_hash(buffer_t *buf, uint32_t *counter, enc_ctx_t *ctx, size_t capacity);
int balloc(buffer_t *ptr, size_t capacity);
int brealloc(buffer_t *ptr, size_t len, size_t capacity);

12
src/local.c

@ -233,7 +233,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
// continue to wait for recv
return;
} else {
ERROR("server_recv_cb_recv");
if (verbose) ERROR("server_recv_cb_recv");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
@ -252,7 +252,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
}
if (!remote->direct && remote->send_ctx->connected && auth) {
ss_gen_hash(remote->buf, &remote->counter, server->e_ctx);
ss_gen_hash(remote->buf, &remote->counter, server->e_ctx, BUF_SIZE);
}
// insert shadowsocks header
@ -260,7 +260,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
#ifdef ANDROID
tx += remote->buf->len;
#endif
int err = ss_encrypt(remote->buf, server->e_ctx);
int err = ss_encrypt(remote->buf, server->e_ctx, BUF_SIZE);
if (err) {
LOGE("invalid password or cipher");
@ -505,7 +505,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
if (!remote->direct) {
if (auth) {
abuf->array[0] |= ONETIMEAUTH_FLAG;
ss_onetimeauth(abuf, server->e_ctx->evp.iv);
ss_onetimeauth(abuf, server->e_ctx->evp.iv, BUF_SIZE);
}
brealloc(remote->buf, buf->len + abuf->len, BUF_SIZE);
@ -514,7 +514,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
if (buf->len > 0) {
if (auth) {
ss_gen_hash(buf, &remote->counter, server->e_ctx);
ss_gen_hash(buf, &remote->counter, server->e_ctx, BUF_SIZE);
}
memcpy(remote->buf->array + abuf->len, buf->array, buf->len);
}
@ -667,7 +667,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
#ifdef ANDROID
rx += server->buf->len;
#endif
int err = ss_decrypt(server->buf, server->d_ctx);
int err = ss_decrypt(server->buf, server->d_ctx, BUF_SIZE);
if (err) {
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);

10
src/redir.c

@ -187,10 +187,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
remote->buf->len = r;
if (auth) {
ss_gen_hash(remote->buf, &remote->counter, server->e_ctx);
ss_gen_hash(remote->buf, &remote->counter, server->e_ctx, BUF_SIZE);
}
int err = ss_encrypt(remote->buf, server->e_ctx);
int err = ss_encrypt(remote->buf, server->e_ctx, BUF_SIZE);
if (err) {
LOGE("invalid password or cipher");
@ -300,7 +300,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
server->buf->len = r;
int err = ss_decrypt(server->buf, server->d_ctx);
int err = ss_decrypt(server->buf, server->d_ctx, BUF_SIZE);
if (err) {
LOGE("invalid password or cipher");
close_and_free_remote(EV_A_ remote);
@ -376,10 +376,10 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
if (auth) {
abuf->array[0] |= ONETIMEAUTH_FLAG;
ss_onetimeauth(abuf, server->e_ctx->evp.iv);
ss_onetimeauth(abuf, server->e_ctx->evp.iv, BUF_SIZE);
}
int err = ss_encrypt(abuf, server->e_ctx);
int err = ss_encrypt(abuf, server->e_ctx, BUF_SIZE);
if (err) {
bfree(abuf);
LOGE("invalid password or cipher");

10
src/server.c

@ -520,7 +520,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
buf->len = r;
}
int err = ss_decrypt(buf, server->d_ctx);
int err = ss_decrypt(buf, server->d_ctx, BUF_SIZE);
if (err) {
LOGE("invalid password or cipher");
@ -532,7 +532,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
// handshake and transmit data
if (server->stage == 5) {
if (server->auth && !ss_check_hash(remote->buf, server->chunk, server->d_ctx)) {
if (server->auth && !ss_check_hash(remote->buf, server->chunk, server->d_ctx, BUF_SIZE)) {
LOGE("hash error");
report_addr(server->fd);
close_and_free_server(EV_A_ server);
@ -735,7 +735,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
LOGI("connect to: %s:%d", host, ntohs(port));
}
if (server->auth && !ss_check_hash(server->buf, server->chunk, server->d_ctx)) {
if (server->auth && !ss_check_hash(server->buf, server->chunk, server->d_ctx, BUF_SIZE)) {
LOGE("hash error");
report_addr(server->fd);
close_and_free_server(EV_A_ server);
@ -945,7 +945,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
rx += r;
server->buf->len = r;
int err = ss_encrypt(server->buf, server->e_ctx);
int err = ss_encrypt(server->buf, server->e_ctx, BUF_SIZE);
if (err) {
LOGE("invalid password or cipher");
@ -1161,7 +1161,7 @@ static server_t *new_server(int fd, listen_ctx_t *listener)
server->chunk = (chunk_t *)malloc(sizeof(chunk_t));
memset(server->chunk, 0, sizeof(chunk_t));
server->chunk->buf = malloc(sizeof(buffer_t));
balloc(server->chunk->buf, BUF_SIZE);
memset(server->chunk->buf, 0, sizeof(buffer_t));
cork_dllist_add(&connections, &server->entries);

10
src/tunnel.c

@ -201,10 +201,10 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
remote->buf->len = r;
if (auth) {
ss_gen_hash(remote->buf, &remote->counter, server->e_ctx);
ss_gen_hash(remote->buf, &remote->counter, server->e_ctx, BUF_SIZE);
}
int err = ss_encrypt(remote->buf, server->e_ctx);
int err = ss_encrypt(remote->buf, server->e_ctx, BUF_SIZE);
if (err) {
LOGE("invalid password or cipher");
@ -324,7 +324,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
server->buf->len = r;
int err = ss_decrypt(server->buf, server->d_ctx);
int err = ss_decrypt(server->buf, server->d_ctx, BUF_SIZE);
if (err) {
LOGE("invalid password or cipher");
@ -421,10 +421,10 @@ static void remote_send_cb(EV_P_ ev_io *w, int revents)
if (auth) {
abuf->array[0] |= ONETIMEAUTH_FLAG;
ss_onetimeauth(abuf, server->e_ctx->evp.iv);
ss_onetimeauth(abuf, server->e_ctx->evp.iv, BUF_SIZE);
}
int err = ss_encrypt(abuf, server->e_ctx);
int err = ss_encrypt(abuf, server->e_ctx, BUF_SIZE);
if (err) {
bfree(abuf);
LOGE("invalid password or cipher");

8
src/udprelay.c

@ -667,7 +667,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
}
#ifdef MODULE_LOCAL
int err = ss_decrypt_all(buf, server_ctx->method, 0);
int err = ss_decrypt_all(buf, server_ctx->method, 0, BUF_SIZE);
if (err) {
// drop the packet silently
goto CLEAN_UP;
@ -727,7 +727,7 @@ static void remote_recv_cb(EV_P_ ev_io *w, int revents)
memcpy(buf->array, addr_header, addr_header_len);
buf->len += addr_header_len;
int err = ss_encrypt_all(buf, server_ctx->method, 0);
int err = ss_encrypt_all(buf, server_ctx->method, 0, BUF_SIZE);
if (err) {
// drop the packet silently
goto CLEAN_UP;
@ -850,7 +850,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
tx += buf->len;
int err = ss_decrypt_all(buf, server_ctx->method, server_ctx->auth);
int err = ss_decrypt_all(buf, server_ctx->method, server_ctx->auth, BUF_SIZE);
if (err) {
// drop the packet silently
goto CLEAN_UP;
@ -1113,7 +1113,7 @@ static void server_recv_cb(EV_P_ ev_io *w, int revents)
buf->array[0] |= ONETIMEAUTH_FLAG;
}
int err = ss_encrypt_all(buf, server_ctx->method, server_ctx->auth);
int err = ss_encrypt_all(buf, server_ctx->method, server_ctx->auth, BUF_SIZE);
if (err) {
// drop the packet silently

Loading…
Cancel
Save