Browse Source

fix encryption context

pull/4/merge
Max Lv 11 years ago
parent
commit
68f31592d9
4 changed files with 54 additions and 45 deletions
  1. 24
      jconf.c
  2. 50
      server.c
  3. 24
      utils.c
  4. 1
      utils.h

24
jconf.c

@ -8,30 +8,6 @@
#include "json.h"
#include "string.h"
#define INT_DIGITS 19 /* enough for 64 bit integer */
static char *itoa(int i)
{
/* Room for INT_DIGITS digits, - and '\0' */
static char buf[INT_DIGITS + 2];
char *p = buf + INT_DIGITS + 1; /* points to terminating '\0' */
if (i >= 0) {
do {
*--p = '0' + (i % 10);
i /= 10;
} while (i != 0);
return p;
}
else { /* i < 0 */
do {
*--p = '0' - (i % 10);
i /= 10;
} while (i != 0);
*--p = '-';
}
return p;
}
static char *to_string(const json_value *value) {
if (value->type == json_string) {
return strndup(value->u.string.ptr, value->u.string.length);

50
server.c

@ -35,6 +35,8 @@
#define min(a,b) (((a)<(b))?(a):(b))
static int verbose = 0;
int setnonblocking(int fd) {
int flags;
if (-1 ==(flags = fcntl(fd, F_GETFL, 0)))
@ -53,7 +55,7 @@ int create_and_bind(const char *host, const char *port) {
s = getaddrinfo(host, port, &hints, &result);
if (s != 0) {
LOGD("getaddrinfo: %s\n", gai_strerror(s));
LOGE("getaddrinfo: %s\n", gai_strerror(s));
return -1;
}
@ -178,7 +180,7 @@ static void server_recv_cb (EV_P_ ev_io *w, int revents) {
ev_io_start(EV_A_ &remote->send_ctx->io);
return;
} else {
perror("send");
perror("server_recv_send");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
@ -211,14 +213,15 @@ static void server_recv_cb (EV_P_ ev_io *w, int revents) {
int offset = 0;
char atyp = server->buf[offset++];
char host[256];
int port;
memset(host, 0, 256);
int port = 0;
// get remote addr and port
if (atyp == 1) {
// IP V4
size_t in_addr_len = sizeof(struct in_addr);
char *a = inet_ntoa(*(struct in_addr*)(server->buf + offset));
memcpy(host, a, sizeof(a));
memcpy(host, a, strlen(a));
offset += in_addr_len;
} else if (atyp == 3) {
@ -237,14 +240,18 @@ static void server_recv_cb (EV_P_ ev_io *w, int revents) {
port += *(uint8_t *)(server->buf + offset++) << 8;
port += *(uint8_t *)(server->buf + offset);
struct remote *remote = connect_to_remote(host, port, server->timeout);
if (verbose) {
LOGD("connect to: %s:%s\n", host, itoa(port));
}
struct remote *remote = connect_to_remote(host, itoa(port), server->timeout);
if (remote == NULL) {
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
}
// listen to remote connected event
ev_io_stop(EV_A_ &server->recv_ctx->io);
ev_io_start(EV_A_ &remote->send_ctx->io);
ev_timer_start(EV_A_ &remote->send_ctx->watcher);
@ -270,7 +277,7 @@ static void server_send_cb (EV_P_ ev_io *w, int revents) {
server->buf_len, 0);
if (r < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
perror("send");
perror("server_send_send");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
}
@ -308,7 +315,7 @@ static void remote_timeout_cb(EV_P_ ev_timer *watcher, int revents) {
struct remote *remote = remote_ctx->remote;
struct server *server = remote->server;
LOGD("remote timeout\n");
LOGE("remote timeout\n");
ev_timer_stop(EV_A_ watcher);
@ -349,7 +356,7 @@ static void remote_recv_cb (EV_P_ ev_io *w, int revents) {
return;
}
}
decrypt_ctx(server->buf, r, server->d_ctx);
encrypt_ctx(server->buf, r, server->e_ctx);
int w = send(server->fd, server->buf, r, 0);
if(w == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
@ -359,7 +366,7 @@ static void remote_recv_cb (EV_P_ ev_io *w, int revents) {
ev_io_start(EV_A_ &server->send_ctx->io);
break;
} else {
perror("send");
perror("remote_recv_send");
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
return;
@ -416,7 +423,7 @@ static void remote_send_cb (EV_P_ ev_io *w, int revents) {
remote->buf_len, 0);
if (r < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK) {
perror("send");
perror("remote_send_send");
// close and free
close_and_free_remote(EV_A_ remote);
close_and_free_server(EV_A_ server);
@ -547,15 +554,19 @@ static void accept_cb (EV_P_ ev_io *w, int revents) {
}
setnonblocking(serverfd);
if (verbose) {
LOGD("Accept a connection.\n");
}
struct server *server = new_server(serverfd);
server->timeout = listener->timeout;
ev_io_start(EV_A_ &server->recv_ctx->io);
}
int main (int argc, char **argv) {
int i, c;
int pid_flags = 0;
char *local_port = NULL;
char *password = NULL;
char *timeout = NULL;
char *method = NULL;
@ -568,7 +579,7 @@ int main (int argc, char **argv) {
opterr = 0;
while ((c = getopt (argc, argv, "f:s:p:l:k:t:m:c:")) != -1) {
while ((c = getopt (argc, argv, "f:s:p:l:k:t:m:c:v")) != -1) {
switch (c) {
case 's':
server_host[server_num++] = optarg;
@ -576,9 +587,6 @@ int main (int argc, char **argv) {
case 'p':
server_port = optarg;
break;
case 'l':
local_port = optarg;
break;
case 'k':
password = optarg;
break;
@ -595,6 +603,8 @@ int main (int argc, char **argv) {
case 'c':
conf_path = optarg;
break;
case 'v':
verbose = 1;
}
}
@ -612,14 +622,12 @@ int main (int argc, char **argv) {
}
}
if (server_port == NULL) server_port = conf->remote_port;
if (local_port == NULL) local_port = conf->local_port;
if (password == NULL) password = conf->password;
if (method == NULL) method = conf->method;
if (timeout == NULL) timeout = conf->timeout;
}
if (server_num == 0 || server_port == NULL ||
local_port == NULL || password == NULL) {
if (server_num == 0 || server_port == NULL || password == NULL) {
usage();
exit(EXIT_FAILURE);
}
@ -650,7 +658,7 @@ int main (int argc, char **argv) {
// Bind to port
int listenfd;
listenfd = create_and_bind(host, local_port);
listenfd = create_and_bind(host, server_port);
if (listenfd < 0) {
FATAL("bind() error..\n");
}
@ -658,7 +666,7 @@ int main (int argc, char **argv) {
FATAL("listen() error.\n");
}
setnonblocking(listenfd);
LOGD("server listening at port %s.\n", local_port);
LOGD("server listening at port %s.\n", server_port);
// Setup proxy context
struct listen_ctx listen_ctx;

24
utils.c

@ -4,6 +4,30 @@
#include "utils.h"
#define INT_DIGITS 19 /* enough for 64 bit integer */
char *itoa(int i) {
/* Room for INT_DIGITS digits, - and '\0' */
static char buf[INT_DIGITS + 2];
char *p = buf + INT_DIGITS + 1; /* points to terminating '\0' */
if (i >= 0) {
do {
*--p = '0' + (i % 10);
i /= 10;
} while (i != 0);
return p;
}
else { /* i < 0 */
do {
*--p = '0' - (i % 10);
i /= 10;
} while (i != 0);
*--p = '-';
}
return p;
}
void FATAL(const char *msg) {
fprintf(stderr, "%s", msg);
exit(-1);

1
utils.h

@ -7,5 +7,6 @@
void FATAL(const char *msg);
void usage(void);
void demonize(const char* path);
char *itoa(int i);
#endif // _UTILS_H
Loading…
Cancel
Save