diff --git a/src/server.c b/src/server.c index 23c5bb9d..6310084c 100644 --- a/src/server.c +++ b/src/server.c @@ -78,6 +78,18 @@ #define MAX_FRAG 1 #endif +#ifdef USE_NFCONNTRACK_TOS + +#ifndef MARK_MAX_PACKET +#define MARK_MAX_PACKET 10 +#endif + +#ifndef MARK_MASK_PREFIX +#define MARK_MASK_PREFIX 0xDC00 +#endif + +#endif + static void signal_cb(EV_P_ ev_signal *w, int revents); static void accept_cb(EV_P_ ev_io *w, int revents); static void server_send_cb(EV_P_ ev_io *w, int revents); @@ -551,6 +563,93 @@ connect_to_remote(EV_P_ struct addrinfo *res, return remote; } +#ifdef USE_NFCONNTRACK_TOS +int setMarkDscpCallback(enum nf_conntrack_msg_type type, struct nf_conntrack *ct, void *data) +{ + server_t* server = (server_t*) data; + struct dscptracker* tracker = server->tracker; + + tracker->mark = nfct_get_attr_u32(ct, ATTR_MARK); + if ((tracker->mark & 0xff00) == MARK_MASK_PREFIX) { + // Extract DSCP value from mark value + tracker->dscp = tracker->mark & 0x00ff; + int tos = (tracker->dscp) << 2; + if (setsockopt(server->fd, IPPROTO_IP, IP_TOS, &tos, sizeof(tos)) != 0) { + ERROR("iptable setsockopt IP_TOS"); + }; + } + return NFCT_CB_CONTINUE; +} + +void conntrackQuery(server_t* server) { + struct dscptracker* tracker = server->tracker; + if(tracker && tracker->ct) { + // Trying query mark from nf conntrack + struct nfct_handle *h = nfct_open(CONNTRACK, 0); + if (h) { + nfct_callback_register(h, NFCT_T_ALL, setMarkDscpCallback, (void*) server); + int x = nfct_query(h, NFCT_Q_GET, tracker->ct); + if (x == -1) { + LOGE("QOS: Failed to retrieve connection mark %s", strerror(errno)); + } + nfct_close(h); + } else { + LOGE("QOS: Failed to open conntrack handle for upstream netfilter mark retrieval."); + } + } +} + +void setTosFromConnmark(remote_t* remote, server_t* server) +{ + if(server->tracker && server->tracker->ct) { + if(server->tracker->mark == 0 && server->tracker->packet_count < MARK_MAX_PACKET) { + server->tracker->packet_count++; + conntrackQuery(server); + } + } else { + socklen_t len; + struct sockaddr_storage sin; + len = sizeof(sin); + if (getsockname(remote->fd, (struct sockaddr *)&sin, &len) == 0) { + struct sockaddr_storage from_addr; + len = sizeof from_addr; + if(getpeername(remote->fd, (struct sockaddr*)&from_addr, &len) == 0) { + if((server->tracker = (struct dscptracker*) malloc(sizeof(struct dscptracker)))) + { + if ((server->tracker->ct = nfct_new())) { + // Build conntrack query SELECT + if (from_addr.ss_family == AF_INET) { + struct sockaddr_in *src = (struct sockaddr_in *)&from_addr; + struct sockaddr_in *dst = (struct sockaddr_in *)&sin; + + nfct_set_attr_u8(server->tracker->ct, ATTR_L3PROTO, AF_INET); + nfct_set_attr_u32(server->tracker->ct, ATTR_IPV4_DST, dst->sin_addr.s_addr); + nfct_set_attr_u32(server->tracker->ct, ATTR_IPV4_SRC, src->sin_addr.s_addr); + nfct_set_attr_u16(server->tracker->ct, ATTR_PORT_DST, dst->sin_port); + nfct_set_attr_u16(server->tracker->ct, ATTR_PORT_SRC, src->sin_port); + } else if (from_addr.ss_family == AF_INET6) { + struct sockaddr_in6 *src = (struct sockaddr_in6 *)&from_addr; + struct sockaddr_in6 *dst = (struct sockaddr_in6 *)&sin; + + nfct_set_attr_u8(server->tracker->ct, ATTR_L3PROTO, AF_INET6); + nfct_set_attr(server->tracker->ct, ATTR_IPV6_DST, dst->sin6_addr.s6_addr); + nfct_set_attr(server->tracker->ct, ATTR_IPV6_SRC, src->sin6_addr.s6_addr); + nfct_set_attr_u16(server->tracker->ct, ATTR_PORT_DST, dst->sin6_port); + nfct_set_attr_u16(server->tracker->ct, ATTR_PORT_SRC, src->sin6_port); + } + nfct_set_attr_u8(server->tracker->ct, ATTR_L4PROTO, IPPROTO_TCP); + conntrackQuery(server); + } else { + LOGE("Failed to allocate new conntrack for upstream netfilter mark retrieval."); + server->tracker->ct=NULL; + }; + } + } + } + } +} +#endif + static void server_recv_cb(EV_P_ ev_io *w, int revents) { @@ -1004,6 +1103,9 @@ remote_recv_cb(EV_P_ ev_io *w, int revents) return; } +#ifdef USE_NFCONNTRACK_TOS + setTosFromConnmark(remote, server); +#endif int s = send(server->fd, server->buf->data, server->buf->len, 0); if (s == -1) { @@ -1232,6 +1334,17 @@ new_server(int fd, listen_ctx_t *listener) static void free_server(server_t *server) { +#ifdef USE_NFCONNTRACK_TOS + if(server->tracker) { + struct dscptracker* tracker = server->tracker; + struct nf_conntrack* ct = server->tracker->ct; + server->tracker = NULL; + if (ct) { + nfct_destroy(ct); + } + free(tracker); + }; +#endif cork_dllist_remove(&server->entries); if (server->remote != NULL) { diff --git a/src/server.h b/src/server.h index a1d74995..90cf051d 100644 --- a/src/server.h +++ b/src/server.h @@ -53,6 +53,20 @@ typedef struct server_ctx { struct server *server; } server_ctx_t; +#ifdef USE_NFCONNTRACK_TOS + +#include +#include + +struct dscptracker { + struct nf_conntrack *ct; + long unsigned int mark; + unsigned int dscp; + unsigned int packet_count; +}; + +#endif + typedef struct server { int fd; int stage; @@ -70,6 +84,9 @@ typedef struct server { struct ResolvQuery *query; struct cork_dllist_item entries; +#ifdef USE_NFCONNTRACK_TOS + struct dscptracker* tracker; +#endif } server_t; typedef struct query {