/* SIP extension for UDP NAT alteration.
 *
 * (C) 2005 by Christian Hentschel <chentschel@arnet.com.ar>
 * based on RR's ip_nat_ftp.c and other modules.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <linux/types.h>
#include <linux/netfilter.h>
#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/ip.h>
#include <linux/udp.h>
#include <linux/ctype.h>
#include <net/netfilter/nf_nat.h>
#include <net/netfilter/nf_nat_helper.h>
#include <net/netfilter/nf_nat_rule.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include <net/netfilter/nf_conntrack_expect.h>
#include <net/netfilter/nf_conntrack_ecache.h>
#include <linux/netfilter/nf_conntrack_sip.h>

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Christian Hentschel <chentschel@arnet.com.ar>");
MODULE_DESCRIPTION("SIP NAT helper");
MODULE_ALIAS("ip_nat_sip");

/*
 * Debug tracing switch: change "#if 0" to "#if 1" to route DEBUGP() to
 * printk(); otherwise every DEBUGP() call expands to nothing, so its
 * arguments are not evaluated.
 */
#if 0
#define DEBUGP printk
#else
#define DEBUGP(format, args...)
#endif

/*
 * List of struct rtp_node entries.  nf_sdp_expected() adds one node per
 * expected connection and ip_nat_sip() walks the list to tag and tear down
 * the nodes whose master is the SIP conntrack being processed.
 * NOTE(review): no lock protecting this list is visible in this file --
 * confirm how concurrent access is serialised.
 */
LIST_HEAD(rtp_list);

/* Timeout, in seconds (multiplied by HZ at the point of use), applied to the
 * conntrack by poorly_port_help(). */
static unsigned int sip_timeout __read_mostly = SIP_TIMEOUT;

/* Forward declaration: poorly_port_help() calls ip_nat_sip() before its
 * definition further down in the file. */
static unsigned int ip_nat_sip(struct sk_buff **pskb,
			       enum ip_conntrack_info ctinfo,
			       struct nf_conn *ct,
			       const char **dptr);


/*
 * Textual form of a conntrack's addresses, one entry per direction
 * (filled in by addr_map_init()):
 *   src / dst          - "a.b.c.d:port" strings for the tuple's source and
 *                        destination
 *   srciplen / dstiplen - length of the "a.b.c.d" part only
 *   srclen / dstlen     - length of the whole "a.b.c.d:port" string
 */
struct addr_map {
	struct {
		char		src[sizeof("nnn.nnn.nnn.nnn:nnnnn")];
		char		dst[sizeof("nnn.nnn.nnn.nnn:nnnnn")];
		unsigned int	srclen, srciplen;
		unsigned int	dstlen, dstiplen;
	} addr[IP_CT_DIR_MAX];
};

void rtp_follow_nat(struct nf_conn *ct, struct nf_conntrack_expect *exp)
{
	struct nf_conn *master = ct->master;
	struct nf_ct_sip_master *info = &nfct_help(master)->help.ct_sip_info;
	struct nf_nat_range range;

	BUG_ON(ct->status & IPS_NAT_DONE_MASK);

	range.flags = IP_NAT_RANGE_MAP_IPS;
	if (exp->dir == IP_CT_DIR_ORIGINAL &&
	   (info->rtp_port != info->rtp_natport)) {
		DEBUGP("info: port=%u, natport=%u\n",
			ntohs(info->rtp_port), ntohs(info->rtp_natport));
		range.min.udp.port = info->rtp_natport;
		range.max =range.min;
		range.flags = (IP_NAT_RANGE_MAP_IPS | IP_NAT_RANGE_PROTO_SPECIFIED);
	}
	range.min_ip = range.max_ip
		= ct->master->tuplehash[!exp->dir].tuple.dst.u3.ip;
	DEBUGP("rtp_follow_nat: POST_ROUTING ip = "NIPQUAD_FMT" port=%u\n",
			NIPQUAD(range.min_ip), ntohs(range.min.udp.port));
	nf_nat_setup_info(ct, &range, NF_IP_POST_ROUTING);

	range.flags = (IP_NAT_RANGE_MAP_IPS | IP_NAT_RANGE_PROTO_SPECIFIED);
	range.min = range.max = exp->saved_proto;
	range.min_ip = range.max_ip
		= ct->master->tuplehash[!exp->dir].tuple.src.u3.ip;
	DEBUGP("rtp_follow_nat: PRE_ROUTING ip = "NIPQUAD_FMT" port=%u\n",
			NIPQUAD(range.min_ip), ntohs(range.min.udp.port));
	nf_nat_setup_info(ct, &range, NF_IP_PRE_ROUTING);
}

/*
 * Expectation callback installed by ip_nat_sdp().
 *
 * Records the new connection @ct together with its master in rtp_list
 * (initial state INVITE) so that ip_nat_sip() can find it later, then sets
 * up NAT for it through rtp_follow_nat().  If the bookkeeping node cannot be
 * allocated the NAT setup is still performed.
 */
void nf_sdp_expected(struct nf_conn *ct, struct nf_conntrack_expect *exp)
{
	struct nf_conn *master = ct->master;
	struct rtp_node *node = kmalloc(sizeof(struct rtp_node), GFP_ATOMIC);

	if (node) {
		node->me = ct;
		node->master = master;
		node->rtp_conn = INVITE;

		list_add(&node->list, &rtp_list);

		DEBUGP("sdp expected: add list ct=%p for master=%p\n", ct, master);
	} else {
		DEBUGP("nf_sdp_expected: can't alloc memory\n");
	}

	/* NAT setup goes through rtp_follow_nat() (original note: "don't
	 * follow the master"). */
	rtp_follow_nat(ct, exp);
}

/*
 * File-scope scratch buffer sized for a dotted-quad IPv4 address followed by
 * ":port".
 * NOTE(review): it is global and unsynchronised, so any code formatting into
 * it is not safe against packets being processed concurrently -- confirm
 * whether a per-call buffer should be used instead.
 */
char saddr_buff[sizeof("nnn.nnn.nnn.nnn:nnnnn")];

/*
 * Conntrack helper callback used by poorly_helper[].
 *
 * Refreshes the conntrack timeout and, for linear skbs whose UDP payload is
 * at least as long as "SIP/2.0 XXX", runs the SIP NAT rewriting on it.
 *
 * @protoff: offset of the UDP header inside the skb.
 * Returns NF_ACCEPT when the packet is skipped, otherwise the verdict of
 * ip_nat_sip().
 */
static int poorly_port_help(struct sk_buff **pskb, unsigned int protoff,
		struct nf_conn *ct, enum ip_conntrack_info ctinfo)
{
	const unsigned int payload_off = protoff + sizeof(struct udphdr);
	const char *payload;

	/* No UDP payload at all. */
	if (payload_off >= (*pskb)->len)
		return NF_ACCEPT;

	nf_ct_refresh(ct, *pskb, sip_timeout * HZ);

	if (skb_is_nonlinear(*pskb)) {
		DEBUGP("Copy of skbuff not supported yet.\n");
		return NF_ACCEPT;
	}
	payload = (*pskb)->data + payload_off;

	/* Too short to be a SIP message. */
	if ((*pskb)->len - payload_off < sizeof("SIP/2.0 XXX") - 1)
		return NF_ACCEPT;

	return ip_nat_sip(pskb, ctinfo, ct, &payload);
}

/*
 * Helper descriptors attached to the expectations created in
 * map_sip_addr().  Both entries are named "pp_helper", match UDP with source
 * port 5060 (all mask bits set) and call poorly_port_help(); they differ
 * only in the layer-3 family (AF_INET vs AF_INET6).
 * NOTE(review): no registration of this array is visible in this file, and
 * "exp->helper = poorly_helper" refers to the first (AF_INET) entry only --
 * confirm the AF_INET6 entry is actually used.
 */
static struct nf_conntrack_helper poorly_helper[] __read_mostly = {
	{
		.name			= "pp_helper",
		.me			= THIS_MODULE,
		.max_expected   	= 20,
		.timeout		= 120,
		.tuple.src.l3num	= AF_INET,
		.tuple.src.u.udp.port   = __constant_htons(5060),
		.tuple.dst.protonum	= IPPROTO_UDP,
		.mask.src.l3num		= 0xFFFF,
		.mask.src.u.udp.port	= __constant_htons(0xFFFF),
		.mask.dst.protonum	= 0xFF,
		.help			= poorly_port_help,
	},
	{
		.name			= "pp_helper",
		.me			= THIS_MODULE,
		.max_expected   	= 20,
		.timeout		= 120,
		.tuple.src.l3num	= AF_INET6,
		.tuple.src.u.udp.port   = __constant_htons(5060),
		.tuple.dst.protonum	= IPPROTO_UDP,
		.mask.src.l3num		= 0xFFFF,
		.mask.src.u.udp.port	= __constant_htons(0xFFFF),
		.mask.dst.protonum	= 0xFF,
		.help			= poorly_port_help,
	},
};


/*
 * Fill @map with the "a.b.c.d:port" text of the source and destination of
 * both tuples of @ct, recording for each string the length of the IP part
 * alone and of the full string.
 */
static void addr_map_init(struct nf_conn *ct, struct addr_map *map)
{
	enum ip_conntrack_dir dir;

	for (dir = 0; dir < IP_CT_DIR_MAX; dir++) {
		struct nf_conntrack_tuple *tuple = &ct->tuplehash[dir].tuple;
		char *src = map->addr[dir].src;
		char *dst = map->addr[dir].dst;
		unsigned int len;

		/* Source: IP first, then append ":port". */
		len = sprintf(src, "%u.%u.%u.%u", NIPQUAD(tuple->src.u3.ip));
		map->addr[dir].srciplen = len;
		len += sprintf(src + len, ":%u", ntohs(tuple->src.u.udp.port));
		map->addr[dir].srclen = len;

		/* Destination: same layout. */
		len = sprintf(dst, "%u.%u.%u.%u", NIPQUAD(tuple->dst.u3.ip));
		map->addr[dir].dstiplen = len;
		len += sprintf(dst + len, ":%u", ntohs(tuple->dst.u.udp.port));
		map->addr[dir].dstlen = len;
	}
}

/*
 * Return non-zero when the matched text @match (length @matchlen) begins
 * with the complete address @ip (length @iplen): the bytes must be equal and
 * must not be followed by another digit, so "10.0.0.1" is not accepted as a
 * prefix of "10.0.0.12".  Never reads beyond @matchlen bytes of @match.
 */
static int sip_match_starts_with_ip(const char *match, unsigned int matchlen,
				    const char *ip, unsigned int iplen)
{
	if (matchlen < iplen || memcmp(match, ip, iplen) != 0)
		return 0;
	if (matchlen == iplen)
		return 1;
	return match[iplen] < '0' || match[iplen] > '9';
}

/*
 * Parse a ":<digits>" suffix located in [@p, @end).  Returns the port value,
 * or 0 when there is no such suffix or the number does not fit in 16 bits.
 */
static unsigned int sip_parse_port_suffix(const char *p, const char *end)
{
	unsigned int val = 0;

	if (p >= end || *p != ':')
		return 0;

	for (p++; p < end && *p >= '0' && *p <= '9'; p++) {
		val = val * 10 + (unsigned int)(*p - '0');
		if (val > 65535)
			return 0;
	}
	return val;
}

/*
 * Rewrite the address found at header position @pos of the SIP message.
 *
 * The address located by ct_sip_get_info() is compared with the textual
 * tuple addresses in @map for the packet's direction:
 *  - equal to the source (IP or IP:port)      -> replaced by the opposite
 *    direction's destination "ip:port";
 *  - equal to the destination (IP or IP:port) -> replaced by the opposite
 *    direction's source "ip:port";
 *  - Via starting with the source IP but carrying another port -> only the
 *    IP part is replaced;
 *  - Contact starting with the source IP but carrying another port -> in a
 *    REGISTER an expectation for that port is created and the address is
 *    replaced by "ip:port" with the port that could be reserved (IP only if
 *    none could); in any other message the whole match is replaced by the
 *    IP alone.
 *    NOTE(review): the non-REGISTER Contact case drops the ":port" part of
 *    the match; this mirrors the previous behaviour -- confirm it is
 *    intended.
 *
 * On success *@dptr is reloaded, because mangling may move the payload.
 *
 * Returns 1 when the header is absent, does not need rewriting, or was
 * rewritten; 0 on failure (the callers then drop the packet).
 */
static int map_sip_addr(struct sk_buff **pskb, enum ip_conntrack_info ctinfo,
			struct nf_conn *ct, const char **dptr, size_t dlen,
			enum sip_header_pos pos, struct addr_map *map)
{
	enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
	unsigned int matchlen, matchoff, addrlen;
	unsigned int wanted, natport;
	/* Per-call buffer: a shared global would race between CPUs. */
	char natbuf[sizeof("nnn.nnn.nnn.nnn:nnnnn")];
	const char *match;
	char *addr;
	struct nf_conntrack_expect *exp = NULL;
	__be16 port;
	int mangled;

	if (ct_sip_get_info(ct, *dptr, dlen, &matchoff, &matchlen, pos) <= 0)
		return 1;

	DEBUGP("###########%.20s#############\n", *dptr);

	match = *dptr + matchoff;

	if ((matchlen == map->addr[dir].srciplen ||
	     matchlen == map->addr[dir].srclen) &&
	    memcmp(match, map->addr[dir].src, matchlen) == 0) {
		addr    = map->addr[!dir].dst;
		addrlen = map->addr[!dir].dstlen;
	} else if ((matchlen == map->addr[dir].dstiplen ||
		    matchlen == map->addr[dir].dstlen) &&
		   memcmp(match, map->addr[dir].dst, matchlen) == 0) {
		addr    = map->addr[!dir].src;
		addrlen = map->addr[!dir].srclen;
	} else if ((pos == POS_VIA || pos == POS_CONTACT) &&
		   sip_match_starts_with_ip(match, matchlen,
					    map->addr[dir].src,
					    map->addr[dir].srciplen)) {
		/* Source IP with a port that differs from the tuple's. */
		addr    = map->addr[!dir].dst;
		addrlen = map->addr[!dir].dstiplen;

		if (pos == POS_VIA) {
			DEBUGP("via port not equal to src port, poor prog -- Arthur\n");
			matchlen = map->addr[dir].srciplen;
		} else if (dlen >= sizeof("REGISTER") - 1 &&
			   memcmp(*dptr, "REGISTER", sizeof("REGISTER") - 1) == 0) {
			natport = 0;
			wanted = sip_parse_port_suffix(
					match + map->addr[dir].srciplen,
					match + matchlen);

			if (wanted != 0) {
				exp = nf_conntrack_expect_alloc(ct);
				if (exp == NULL)
					return 0;
				port = htons(wanted);

				nf_conntrack_expect_init(exp, ct->tuplehash[!dir].tuple.src.l3num,
							 &ct->tuplehash[!dir].tuple.src.u3,
							 &ct->tuplehash[!dir].tuple.dst.u3,
							 IPPROTO_UDP,
							 &ct->tuplehash[!dir].tuple.src.u.udp.port,
							 &port);

				exp->saved_proto.udp.port = exp->tuple.dst.u.udp.port;
				exp->dir = !dir;
				exp->expectfn = nf_nat_follow_master;
				exp->helper = poorly_helper;

				/* Try the advertised port first, then the ones above it. */
				for (natport = wanted; natport <= 65535; natport++) {
					exp->tuple.dst.u.udp.port = htons(natport);
					if (nf_conntrack_expect_related(exp) == 0) {
						DEBUGP("nf_nat_sip: poorly port expected, port = %u\n", natport);
						break;
					}
				}

				if (natport > 65535) {
					/* Nothing was registered: release our
					 * reference now and rewrite the IP only. */
					nf_conntrack_expect_put(exp);
					exp = NULL;
					natport = 0;
				}
			}

			if (natport != 0) {
				/* "nnn.nnn.nnn.nnn" + ":nnnnn" always fits in natbuf. */
				addrlen = sprintf(natbuf, "%.*s:%u",
						  (int)map->addr[!dir].dstiplen,
						  map->addr[!dir].dst, natport);
				addr = natbuf;
			} else
				matchlen = map->addr[dir].srciplen;
		}
	} else
		return 1;

	mangled = nf_nat_mangle_udp_packet(pskb, ct, ctinfo,
					   matchoff, matchlen, addr, addrlen);

	if (exp) {
		/* Withdraw the expectation if the packet could not be
		 * rewritten, and always drop the reference obtained from
		 * nf_conntrack_expect_alloc(). */
		if (!mangled)
			nf_conntrack_unexpect_related(exp);
		nf_conntrack_expect_put(exp);
	}

	if (!mangled)
		return 0;

	*dptr = (*pskb)->data + (*pskb)->nh.iph->ihl*4 + sizeof(struct udphdr);
	return 1;
}

static unsigned int ip_nat_sip(struct sk_buff **pskb,
			       enum ip_conntrack_info ctinfo,
			       struct nf_conn *ct,
			       const char **dptr)
{
	enum sip_header_pos pos;
	struct addr_map map;
	int dataoff, datalen;

	dataoff = (*pskb)->nh.iph->ihl*4 + sizeof(struct udphdr);
	datalen = (*pskb)->len - dataoff;
	if (datalen < sizeof("SIP/2.0") - 1)
		return NF_DROP;

	addr_map_init(ct, &map);

	/* Basic rules: requests and responses. */
	if (strncmp(*dptr, "SIP/2.0", sizeof("SIP/2.0") - 1) != 0) {
		/* 10.2: Constructing the REGISTER Request:
		 *
		 * The "userinfo" and "@" components of the SIP URI MUST NOT
		 * be present.
		 */
		if (datalen >= sizeof("REGISTER") - 1 &&
		    strncmp(*dptr, "REGISTER", sizeof("REGISTER") - 1) == 0)
			pos = POS_REG_REQ_URI;
		else
			pos = POS_REQ_URI;

		if (!map_sip_addr(pskb, ctinfo, ct, dptr, datalen, pos, &map))
			return NF_DROP;

	}

	if (!map_sip_addr(pskb, ctinfo, ct, dptr, datalen, POS_FROM, &map) ||
	    !map_sip_addr(pskb, ctinfo, ct, dptr, datalen, POS_TO, &map) ||
	    !map_sip_addr(pskb, ctinfo, ct, dptr, datalen, POS_VIA, &map) ||
	    !map_sip_addr(pskb, ctinfo, ct, dptr, datalen, POS_CONTACT, &map))
		return NF_DROP;
	
	if ((strncmp(*dptr, "BYE", sizeof("BYE") - 1) == 0) ||
	    (strncmp(*dptr, "CANCEL", sizeof("CANCEL") - 1) == 0) ||
	    (strncmp(*dptr, "SIP/2.0 400 Bad", sizeof("SIP/2.0 400 Bad") - 1) == 0)) {
		DEBUGP("---------------- BYE\\CANCEL\\Bad Request --------------\n");
		struct rtp_node *i, *tmp;

		nf_ct_remove_expectations(ct);
		list_for_each_entry_safe(i, tmp, &rtp_list, list)
		{
			if (list_empty(&rtp_list)) {
				DEBUGP("rtp_list empty\n");
				break;
			}
			if (i->master == ct) {
				if (strncmp(*dptr, "BYE", sizeof("BYE") - 1) == 0)
					i->rtp_conn = BYE;
				else
					i->rtp_conn = CANCEL;
			}
		}
	}

	if (strncmp(*dptr, "SIP/2.0 200 OK", sizeof("SIP/2.0 200 OK") - 1) == 0) {
		struct rtp_node *i, *tmp;
		struct nf_conn *rtp;

		list_for_each_entry_safe(i, tmp, &rtp_list, list)
		{
			if (list_empty(&rtp_list))
			{
				DEBUGP("rtp_list empty\n");
				break;
			}
			if (i->master == ct && 
			    ((i->rtp_conn == BYE) || (i->rtp_conn == CANCEL))) {
				rtp = i->me;
				// destroy the conntrack i->me by using timeout --Arthur
				DEBUGP("Destroy ct = %p by command BYE or (Bad Cancel?)\n", rtp);
				if (del_timer(&rtp->timeout))
				{
					//rtp->timeout.expires = jiffies + HZ;
					//add_timer(&rtp->timeout);
					//nf_conntrack_event(IPCT_REFRESH, rtp);
					rtp->timeout.function((unsigned long)rtp);
				}
				list_del(&i->list);
				kfree(i);
			}
		}
	}

	return NF_ACCEPT;
}

/*
 * Replace the text found at header position @pos with the @bufflen bytes of
 * @buffer.
 *
 * Returns 1 on success, with *@dptr reloaded to the (possibly moved) SIP
 * payload; 0 when the position is not found or the packet cannot be mangled.
 */
static unsigned int mangle_sip_packet(struct sk_buff **pskb,
				      enum ip_conntrack_info ctinfo,
				      struct nf_conn *ct,
				      const char **dptr, size_t dlen,
				      char *buffer, int bufflen,
				      enum sip_header_pos pos)
{
	unsigned int off, len;
	int found;

	found = ct_sip_get_info(ct, *dptr, dlen, &off, &len, pos);
	if (found > 0 &&
	    nf_nat_mangle_udp_packet(pskb, ct, ctinfo, off, len,
				     buffer, bufflen)) {
		/* The skb data may have been reallocated: recompute the
		 * payload pointer from the current headers. */
		*dptr = (*pskb)->data + (*pskb)->nh.iph->ihl*4 + sizeof(struct udphdr);
		return 1;
	}

	return 0;
}

/*
 * Rewrite the Content-Length value so that it matches the SDP body after
 * the body has been mangled.
 *
 * The body length is computed from the position reported for
 * POS_SDP_HEADER; the decimal value is then written over the text found at
 * POS_CONTENT.
 *
 * Returns the result of nf_nat_mangle_udp_packet(), or 0 when either
 * position cannot be located.
 */
static int mangle_content_len(struct sk_buff **pskb,
			      enum ip_conntrack_info ctinfo,
			      struct nf_conn *ct,
			      const char *dptr)
{
	unsigned int dataoff, datalen, matchoff, matchlen, c_len;
	/* Large enough for any 32-bit value printed with "%u". */
	char buffer[sizeof("4294967295")];
	int bufflen;

	dataoff = (*pskb)->nh.iph->ihl*4 + sizeof(struct udphdr);
	datalen = (*pskb)->len - dataoff;

	/* Get actual SDP length */
	if (ct_sip_get_info(ct, dptr, datalen, &matchoff,
			    &matchlen, POS_SDP_HEADER) <= 0)
		return 0;

	/* since ct_sip_get_info() gives us a pointer past 'v='
	   we need to add 2 bytes to this count. */
	c_len = datalen - matchoff + 2;

	/* Now, update SDP length */
	if (ct_sip_get_info(ct, dptr, datalen, &matchoff,
			    &matchlen, POS_CONTENT) <= 0)
		return 0;

	bufflen = sprintf(buffer, "%u", c_len);
	return nf_nat_mangle_udp_packet(pskb, ct, ctinfo,
					matchoff, matchlen,
					buffer, bufflen);
}

/*
 * Rewrite the SDP body: the owner and connection IPv4 addresses become
 * @newip, the media port becomes @port (host byte order), and finally the
 * Content-Length is adjusted.
 *
 * Returns 0 as soon as one step fails, otherwise the result of
 * mangle_content_len().
 */
static unsigned int mangle_sdp(struct sk_buff **pskb,
			       enum ip_conntrack_info ctinfo,
			       struct nf_conn *ct,
			       __be32 newip, u_int16_t port,
			       const char *dptr)
{
	char text[sizeof("nnn.nnn.nnn.nnn")];
	unsigned int hdrlen, textlen;
	int ok;

	hdrlen = (*pskb)->nh.iph->ihl*4 + sizeof(struct udphdr);

	/* Mangle owner and contact info. */
	textlen = sprintf(text, "%u.%u.%u.%u", NIPQUAD(newip));
	ok = mangle_sip_packet(pskb, ctinfo, ct, &dptr, (*pskb)->len - hdrlen,
			       text, textlen, POS_OWNER_IP4);
	if (ok)
		ok = mangle_sip_packet(pskb, ctinfo, ct, &dptr,
				       (*pskb)->len - hdrlen,
				       text, textlen, POS_CONNECTION_IP4);
	if (!ok)
		return 0;

	/* Mangle media port. */
	textlen = sprintf(text, "%u", port);
	ok = mangle_sip_packet(pskb, ctinfo, ct, &dptr, (*pskb)->len - hdrlen,
			       text, textlen, POS_MEDIA);
	if (!ok)
		return 0;

	return mangle_content_len(pskb, ctinfo, ct, dptr);
}

/*
 * So, this packet has hit the connection tracking matching code.
 * Mangle it, and change the expectation to match the new version.
 *
 * The expectation @exp is redirected to the address the peer will see
 * (destination IP of the master's tuple in the opposite direction), and a
 * free port is searched starting at the advertised one, in steps of two.
 * For packets in the original direction the advertised and the reserved
 * port are recorded in the master's SIP info for rtp_follow_nat().  The SDP
 * body is then rewritten with the new address and port.
 *
 * Returns NF_ACCEPT, or NF_DROP when no port could be reserved or the SDP
 * could not be mangled (the expectation is withdrawn in the latter case).
 */
static unsigned int ip_nat_sdp(struct sk_buff **pskb,
			       enum ip_conntrack_info ctinfo,
			       struct nf_conntrack_expect *exp,
			       const char *dptr)
{
	struct nf_conn *ct = exp->master;
	enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
	__be32 newip;
	u_int16_t port;
	unsigned int try;
	struct nf_ct_sip_master *info = &nfct_help(ct)->help.ct_sip_info;

	DEBUGP("ip_nat_sdp():\n");

	/* Connection will come from reply */
	newip = ct->tuplehash[!dir].tuple.dst.u3.ip;

	exp->tuple.dst.u3.ip = newip;
	exp->saved_proto.udp.port = exp->tuple.dst.u.udp.port;
	exp->dir = !dir;

	/* When you see the packet, we need to NAT it the same as the
	   this one. */
	exp->expectfn = nf_sdp_expected;

	/* Try to get same port: if not, try to change it. */
	port = ntohs(exp->saved_proto.udp.port);
	if (port > 65530)
		port = 2000;

	/*
	 * Walk upwards in steps of two using a counter wider than 16 bits.
	 * A 16-bit counter starting from an odd port never wraps to exactly
	 * 0, so the search would not terminate if every attempt failed.
	 */
	for (try = port; try != 0 && try <= 65535; try += 2) {
		exp->tuple.dst.u.udp.port = htons(try);
		if (nf_conntrack_expect_related(exp) == 0) {
			DEBUGP("set_expected_rtp: ct  = %p, exp = %p, dir = %s\n",
				ct, exp,
				dir == IP_CT_DIR_REPLY? "IP_CT_DIR_REPLY":"IP_CT_DIR_ORIGINAL");
			DEBUGP("set exp src = " NIPQUAD_FMT ":%u -> dst = " NIPQUAD_FMT ":%u\n",
				NIPQUAD(exp->tuple.src.u3.ip),
				ntohs(exp->tuple.src.u.all),
				NIPQUAD(exp->tuple.dst.u3.ip),
				ntohs(exp->tuple.dst.u.all));

			break;
		}
	}

	/* No port available (or the advertised port was 0). */
	if (try == 0 || try > 65535)
		return NF_DROP;
	port = try;

	if (dir == IP_CT_DIR_ORIGINAL) {
		info->rtp_port = exp->saved_proto.udp.port;
		info->rtp_natport = exp->tuple.dst.u.udp.port;
	}

	DEBUGP("nat_sdp: rtp_port=%u  rtp_natport=%u\n",
		ntohs(info->rtp_port), ntohs(info->rtp_natport));

	if (!mangle_sdp(pskb, ctinfo, ct, newip, port, dptr)) {
		nf_conntrack_unexpect_related(exp);
		return NF_DROP;
	}
	return NF_ACCEPT;
}

/*
 * Module exit: unhook the SIP and SDP NAT callbacks, wait for readers that
 * may still be running them, then free the bookkeeping nodes left on
 * rtp_list so they are not leaked when the module goes away.
 * NOTE(review): assumes nothing adds to rtp_list once the hooks are gone --
 * confirm that no expectation still pointing at nf_sdp_expected() can fire
 * at this point.
 */
static void __exit nf_nat_sip_fini(void)
{
	struct rtp_node *i, *tmp;

	rcu_assign_pointer(nf_nat_sip_hook, NULL);
	rcu_assign_pointer(nf_nat_sdp_hook, NULL);
	synchronize_rcu();

	list_for_each_entry_safe(i, tmp, &rtp_list, list) {
		list_del(&i->list);
		kfree(i);
	}
}

/*
 * Module init: publish the SIP and SDP NAT callbacks through the hook
 * pointers.  Both hooks must be unset beforehand (enforced by BUG_ON).
 * Always returns 0.
 */
static int __init nf_nat_sip_init(void)
{
	/* Refuse to overwrite a hook that somebody else already installed. */
	BUG_ON(rcu_dereference(nf_nat_sip_hook));
	BUG_ON(rcu_dereference(nf_nat_sdp_hook));

	rcu_assign_pointer(nf_nat_sip_hook, ip_nat_sip);
	rcu_assign_pointer(nf_nat_sdp_hook, ip_nat_sdp);

	return 0;
}

/* Register the module entry and exit points. */
module_init(nf_nat_sip_init);
module_exit(nf_nat_sip_fini);
