/* RTSP extension for TCP NAT alteration.
 * (C) 2008 by Arthur Tang <tangjingbiao@gmail.com>
 *
 * based on nf_nat_ftp.c 
 * 	  & ip_nat_rtsp.c (Tom Marshall)
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 *
 */

#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include <linux/kernel.h>
#include <linux/netfilter_ipv4.h>
#include <net/netfilter/nf_nat.h>
#include <net/netfilter/nf_nat_helper.h>
#include <net/netfilter/nf_nat_rule.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include <net/netfilter/nf_conntrack_expect.h>
#include <linux/netfilter/nf_conntrack_rtsp.h>

#include <linux/netfilter_helpers.h>
#include <linux/netfilter_mime.h>

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Arthur Tang <tangjingbiao@gmail.com>");
MODULE_DESCRIPTION("rtsp NAT helper");
MODULE_ALIAS("ip_nat_rtsp");

/*
 * Debug tracing switch.  Change "#if 0" to "#if 1" to route DEBUGP() to
 * printk; in the default configuration DEBUGP() expands to nothing, so the
 * call sites compile away and their arguments are never evaluated.
 */
#if 0
#define DEBUGP printk
#else
#define DEBUGP(format, args...)
#endif

/*
 * nf_nat_rtsp - reserve expectation port(s) for an RTSP exchange and rewrite
 * the matching text in the TCP payload.
 *
 * @pskb:     packet being processed (may be reallocated by the mangle helper)
 * @ctinfo:   conntrack state of the packet; its direction selects which
 *            tuple supplies the replacement address
 * @matchoff: payload offset of the port text to replace
 * @matchlen: length of the port text to replace
 * @exp:      first expectation (always required)
 * @exp2:     second expectation; needed for pb_range and pb_discon, may be
 *            NULL for pb_single
 * @expinfo:  port description (pbtype, loport, hiport)
 * @destioff: payload offset of the address text to replace, 0 if none
 * @destilen: length of the address text to replace, 0 if none
 *
 * Both expectations are set up to follow the master connection's NAT
 * (nf_nat_follow_master) in the opposite direction of the current packet.
 * Free ports are then searched upwards from the requested ones.
 *
 * Returns NF_DROP when no port (pair) could be reserved or when rewriting the
 * port text fails; in both cases no expectation is left registered by this
 * function.  Returns NF_ACCEPT otherwise, including when only the optional
 * address rewrite fails.
 */
static unsigned int nf_nat_rtsp(struct sk_buff **pskb,
				enum ip_conntrack_info ctinfo,
				unsigned int matchoff,
				unsigned int matchlen,
				struct nf_conntrack_expect *exp,
				struct nf_conntrack_expect *exp2,
				struct nf_ct_rtsp_expect *expinfo,
				unsigned int destioff,
				unsigned int destilen)
{
	__be32 newip;
	/* Search cursor; wider than 16 bits so "+= 2" / "++" cannot wrap. */
	unsigned int port;
	/* Ports actually reserved; 0 means "nothing reserved". */
	u_int16_t loport = 0, hiport = 0;
	char buf_allport[16];	/* "65535-65535" needs 12 bytes incl. NUL */
	unsigned int buf_allportlen = 0;
	int dir = CTINFO2DIR(ctinfo);
	struct nf_conn *ct = exp->master;
	char newipbuff[sizeof("nnn.nnn.nnn.nnn")];
	unsigned int newiplen;

	newip = ct->tuplehash[!dir].tuple.dst.u3.ip;

	exp->saved_proto.udp.port = exp->tuple.dst.u.udp.port;
	exp->dir = !dir;
	exp->expectfn = nf_nat_follow_master;

	if (exp2) {
		exp2->saved_proto.udp.port = exp2->tuple.dst.u.udp.port;
		exp2->dir = !dir;
		exp2->expectfn = nf_nat_follow_master;
	}

	DEBUGP("RTSP_NAT: exp = %p, exp2 = %p, newip = " NIPQUAD_FMT ", port = %u\n", exp, exp2,
			NIPQUAD(newip), ntohs(exp->saved_proto.udp.port));

	DEBUGP("nf_nat_rtsp: expinfo = %p, ports=(%d,%hu,%hu)\n", expinfo,
			(int)expinfo->pbtype,
			expinfo->loport,
			expinfo->hiport);

	switch (expinfo->pbtype) {
	case pb_single:
		/* First free port at or above the originally requested one. */
		for (port = ntohs(exp->saved_proto.udp.port);
		     port != 0 && port <= 0xFFFF; port++) {
			exp->tuple.dst.u.udp.port = htons(port);
			if (nf_conntrack_expect_related(exp) == 0) {
				loport = port;
				DEBUGP("nf_nat_rstp: use single port %u\n", loport);
				break;
			}
		}
		if (loport)
			buf_allportlen = snprintf(buf_allport,
						  sizeof(buf_allport),
						  "%hu", loport);
		break;

	case pb_range:
		/* A range needs both expectations. */
		if (!exp2)
			break;
		/*
		 * Reserve the pair (port, port + 1).  The bound keeps port + 1
		 * within 16 bits and guarantees termination even when the
		 * starting port is odd.
		 */
		for (port = expinfo->loport;
		     port != 0 && port < 0xFFFF; port += 2) {
			exp->tuple.dst.u.udp.port = htons(port);
			if (nf_conntrack_expect_related(exp) != 0)
				continue;

			exp2->tuple.dst.u.udp.port = htons(port + 1);
			if (nf_conntrack_expect_related(exp2) == 0) {
				loport = port;
				hiport = port + 1;
				DEBUGP("nf_nat_rtsp: using port range %hu-%hu\n",
					       loport, hiport);
				break;
			}
			/* Second half unavailable: release the first, move on. */
			nf_conntrack_unexpect_related(exp);
		}
		if (loport)
			buf_allportlen = snprintf(buf_allport,
						  sizeof(buf_allport),
						  "%hu-%hu", loport, hiport);
		break;

	case pb_discon:
		/* Two independent ports need both expectations. */
		if (!exp2)
			break;
		for (port = expinfo->loport;
		     port != 0 && port <= 0xFFFF; port++) {
			exp->tuple.dst.u.udp.port = htons(port);
			if (nf_conntrack_expect_related(exp) == 0) {
				loport = port;
				DEBUGP("nf_nat_rtsp: using port %hu (1 of 2)\n", loport);
				break;
			}
		}
		/* No point reserving the second port without the first. */
		if (!loport)
			break;

		for (port = expinfo->hiport;
		     port != 0 && port <= 0xFFFF; port++) {
			exp2->tuple.dst.u.udp.port = htons(port);
			if (nf_conntrack_expect_related(exp2) == 0) {
				hiport = port;
				DEBUGP("nf_nat_rtsp: using port %hu (2 of 2)\n", hiport);
				break;
			}
		}
		if (!hiport) {
			/* Do not leave a half-registered pair behind. */
			nf_conntrack_unexpect_related(exp);
			break;
		}

		/* Adjacent ports are written as a range, others as "lo/hi". */
		if (hiport == loport + 1)
			buf_allportlen = snprintf(buf_allport,
						  sizeof(buf_allport),
						  "%hu-%hu", loport, hiport);
		else
			buf_allportlen = snprintf(buf_allport,
						  sizeof(buf_allport),
						  "%hu/%hu", loport, hiport);
		break;

	default:
		/* Unknown port block type: nothing reserved, handled below. */
		break;
	}

	/* Nothing could be reserved (same value as the former "return 0"). */
	if (buf_allportlen == 0)
		return NF_DROP;

	/* Rewrite the port text in the payload with the reserved port(s). */
	DEBUGP("mangle client port - Arthur\n");
	if (!nf_nat_mangle_tcp_packet(pskb, ct, ctinfo, matchoff, matchlen,
				      buf_allport, buf_allportlen)) {
		DEBUGP("nf_nat_rtsp: mangle client port error?\n");
		nf_conntrack_unexpect_related(exp);
		if (exp2)
			nf_conntrack_unexpect_related(exp2);
		return NF_DROP;
	}

	/* Optionally rewrite the address text as well. */
	DEBUGP("destioff = %u, destilen = %u\n", destioff, destilen);
	if (destioff != 0 && destilen != 0) {
		DEBUGP("mangle destination - Arthur\n");
		newiplen = snprintf(newipbuff, sizeof(newipbuff),
				    NIPQUAD_FMT, NIPQUAD(newip));
		DEBUGP("newip = %s   newiplen = %d destilen = %d\n", 
				newipbuff, newiplen, destilen);
		if (!nf_nat_mangle_tcp_packet(pskb, ct, ctinfo, destioff,
					      destilen, newipbuff, newiplen)) {
			/*
			 * Best effort: the ports are already rewritten and
			 * the expectations stay in place, so the packet is
			 * still accepted.
			 */
			DEBUGP("nf_nat_rtsp: mangle destination error?\n");
			return NF_ACCEPT;
		}
	}

	return NF_ACCEPT;
}

/*
 * Module unload: withdraw the NAT hook.
 *
 * The order matters: the hook pointer is cleared first so no new caller can
 * pick it up, then synchronize_rcu() waits for an RCU grace period so that
 * any reader that already fetched the old pointer has finished before this
 * function returns.
 */
static void __exit nf_nat_rtsp_fini(void)
{
	rcu_assign_pointer(nf_nat_rtsp_hook, NULL);
	synchronize_rcu();
}

/*
 * Module load: publish nf_nat_rtsp() through nf_nat_rtsp_hook.
 *
 * Always returns 0; a hook that is already set is treated as a fatal
 * programming error rather than a recoverable condition.
 */
static int __init nf_nat_rtsp_init(void)
{
	/* The hook slot must still be empty when this module claims it. */
	BUG_ON(rcu_dereference(nf_nat_rtsp_hook) != NULL);

	rcu_assign_pointer(nf_nat_rtsp_hook, nf_nat_rtsp);

	return 0;
}

module_init(nf_nat_rtsp_init);
module_exit(nf_nat_rtsp_fini);
