diff --git a/encrypt.c b/encrypt.c index 3f0baa97..157d34a7 100644 --- a/encrypt.c +++ b/encrypt.c @@ -125,13 +125,14 @@ static void merge_sort(uint8_t array[], int length) void encrypt_ctx(char *buf, int len, EVP_CIPHER_CTX *ctx) { if (ctx != NULL) { int outlen; - unsigned char mybuf[BUF_SIZE]; + unsigned char *mybuf = malloc(BUF_SIZE); EVP_CipherUpdate(ctx, mybuf, &outlen, (unsigned char*)buf, len); memcpy(buf, mybuf, len); + free(mybuf); } else { char *end = buf + len; while (buf < end) { - *buf = (char)encrypt_table[(uint8_t)*buf]; + *buf = (char)enc_ctx.table.encrypt_table[(uint8_t)*buf]; buf++; } } @@ -140,48 +141,69 @@ void encrypt_ctx(char *buf, int len, EVP_CIPHER_CTX *ctx) { void decrypt_ctx(char *buf, int len, EVP_CIPHER_CTX *ctx) { if (ctx != NULL) { int outlen; - unsigned char mybuf[BUF_SIZE]; + unsigned char *mybuf = malloc(BUF_SIZE); EVP_CipherUpdate(ctx, mybuf, &outlen, (unsigned char*) buf, len); memcpy(buf, mybuf, len); + free(mybuf); } else { char *end = buf + len; while (buf < end) { - *buf = (char)decrypt_table[(uint8_t)*buf]; + *buf = (char)enc_ctx.table.decrypt_table[(uint8_t)*buf]; buf++; } } } -void enc_ctx_init(EVP_CIPHER_CTX *ctx, const char *pass, int enc) { - unsigned char key[EVP_MAX_KEY_LENGTH]; - unsigned char iv[EVP_MAX_IV_LENGTH]; - int key_len = EVP_BytesToKey(EVP_rc4(), EVP_md5(), NULL, (unsigned char*) pass, - strlen(pass), 1, key, iv); +void enc_ctx_init(EVP_CIPHER_CTX *ctx, int enc) { + uint8_t *key = enc_ctx.rc4.key; + int key_len = enc_ctx.rc4.key_len; EVP_CIPHER_CTX_init(ctx); EVP_CipherInit_ex(ctx, EVP_rc4(), NULL, NULL, NULL, enc); if (!EVP_CIPHER_CTX_set_key_length(ctx, key_len)) { - LOGE("Invalid key length: %d", key_len); + LOGE("Invalid key length: %d\n", key_len); EVP_CIPHER_CTX_cleanup(ctx); exit(EXIT_FAILURE); } - EVP_CipherInit_ex(ctx, NULL, NULL, key, iv, enc); + EVP_CipherInit_ex(ctx, NULL, NULL, key, NULL, enc); +} + +void enc_key_init(const char *pass) { + unsigned char key[EVP_MAX_KEY_LENGTH]; + unsigned char iv[EVP_MAX_IV_LENGTH]; + int key_len = EVP_BytesToKey(EVP_rc4(), EVP_md5(), NULL, (unsigned char*) pass, + strlen(pass), 1, key, iv); + if (!key_len) { + LOGE("Invalid key length: %d\n", key_len); + exit(EXIT_FAILURE); + } + enc_ctx.rc4.key_len = key_len; + enc_ctx.rc4.key = malloc(key_len); + memcpy(enc_ctx.rc4.key, key, key_len); } void get_table(const char *pass) { - uint8_t *table = encrypt_table; + uint8_t *enc_table = enc_ctx.table.encrypt_table; + uint8_t *dec_table = enc_ctx.table.decrypt_table; uint8_t *tmp_hash = MD5((unsigned char *) pass, strlen(pass), NULL); - _a = htole64(*(uint64_t *) tmp_hash); uint32_t i; + _a = htole64(*(uint64_t *) tmp_hash); + + enc_table = malloc(256); + dec_table = malloc(256); + for(i = 0; i < 256; ++i) { - table[i] = i; + enc_table[i] = i; } for(i = 1; i < 1024; ++i) { _i = i; - merge_sort(table, 256); + merge_sort(enc_table, 256); } for(i = 0; i < 256; ++i) { // gen decrypt table from encrypt table - decrypt_table[encrypt_table[i]] = i; + dec_table[enc_table[i]] = i; } + + enc_ctx.table.encrypt_table = enc_table; + enc_ctx.table.decrypt_table = dec_table; } diff --git a/encrypt.h b/encrypt.h index 4138af1f..bd02e388 100755 --- a/encrypt.h +++ b/encrypt.h @@ -13,13 +13,23 @@ #define TABLE 0 #define RC4 1 -unsigned char encrypt_table[256]; -unsigned char decrypt_table[256]; +union { + struct { + unsigned char *encrypt_table; + unsigned char *decrypt_table; + } table; + + struct { + unsigned char *key; + int key_len; + } rc4; +} enc_ctx; void get_table(const char* key); void encrypt_ctx(char *buf, int len, EVP_CIPHER_CTX *ctx); void decrypt_ctx(char *buf, int len, EVP_CIPHER_CTX *ctx); -void enc_ctx_init(EVP_CIPHER_CTX *ctx, const char *pass, int enc); +void enc_ctx_init(EVP_CIPHER_CTX *ctx, int enc); +void enc_key_init(const char *pass); unsigned int _i; unsigned long long _a; diff --git a/local.c b/local.c index a475c220..010dbe93 100755 --- a/local.c +++ b/local.c @@ -38,7 +38,6 @@ static char *_server; static char *_remote_port; static int _timeout; -static char *_key; int setnonblocking(int fd) { int flags; @@ -495,8 +494,8 @@ struct server* new_server(int fd) { if (_method == RC4) { server->e_ctx = malloc(sizeof(EVP_CIPHER_CTX)); server->d_ctx = malloc(sizeof(EVP_CIPHER_CTX)); - enc_ctx_init(server->e_ctx, _key, 1); - enc_ctx_init(server->d_ctx, _key, 0); + enc_ctx_init(server->e_ctx, 1); + enc_ctx_init(server->d_ctx, 0); } else { server->e_ctx = NULL; server->d_ctx = NULL; @@ -694,7 +693,6 @@ int main (int argc, char **argv) _server = strdup(server); _remote_port = strdup(remote_port); _timeout = atoi(timeout); - _key = key; _method = TABLE; if (method != NULL) { if (strcmp(method, "rc4") == 0) { @@ -703,7 +701,9 @@ int main (int argc, char **argv) } LOGD("calculating ciphers %d\n", _method); - if (_method != RC4) { + if (_method == RC4) { + enc_key_init(key); + } else { get_table(key); }