#!/usr/bin/python
#
# This script is run by a VM host (e.g. "west") to prepare itself for testing.
# It should be passed a testname as its only argument

import os,sys,socket,shutil,distutils.dir_util,commands
import pexpect, glob

try:
	import argparse
except ImportError:
	# argparse is only bundled with the standard library from Python 2.7 on;
	# abort with a readable message rather than a traceback when it is missing.
	sys.exit("swan-prep requires the python argparse module")

# Refuse to continue when an ipsec binary exists both in the base system
# and under /usr/local, so a double install does not go unnoticed.
base_ipsec = os.path.isfile("/usr/sbin/ipsec") or os.path.isfile("/sbin/ipsec")
local_ipsec = os.path.isfile("/usr/local/sbin/ipsec")
if base_ipsec and local_ipsec:
	sys.exit("ABORT: found a swan userland in the base system as well as /usr/local")

# A fresh copy may just have been installed under /usr/local: when the
# restorecon tool is present, run it over the installed files.
if local_ipsec and os.path.isfile("/usr/sbin/restorecon"):
	commands.getoutput("restorecon -Rv /usr/local/libexec/ipsec /usr/local/sbin/ipsec")

# Command line interface. Every option is optional:
#   --testname/-t  test to prepare (when empty, derived from the cwd later on)
#   --hostname/-H  host to prepare as (when empty, the system hostname is used)
#   --userland/-u  IKE userland to configure
#   --x509/-x      import the test certificates
#   --46/--64      leave IPv6 alone (stored as args.ipv)
#   --6            explicitly enable IPv6 and restart networking (args.ipv6)
parser = argparse.ArgumentParser(description='swan-prep arguments')
parser.add_argument('--testname', '-t', action='store', default='', help='The name of the test to prepare')
parser.add_argument('--hostname', '-H', action='store', default='', help='The name of the host to prepare as')
# we should get this from the testparams.sh file?
parser.add_argument('--userland', '-u', action='store', default='libreswan', help='which userland to prepare')
parser.add_argument('--x509', '-x', action='store_true', help='create X509 NSS file by importing test certs')
parser.add_argument('--46', '--64', action='store_true', help='Do not disable IPv6. Default is to disable IPv6', dest='ipv', default=False)
parser.add_argument('--6', action='store_true', help='Enable IPv6 and run - /etc/init.d/network restart', dest='ipv6', default=False)
args = parser.parse_args()

# The "nic" host needs no preparation at all: stop right away when it is
# requested explicitly.
if args.hostname == "nic":
	sys.exit()

# An explicit --hostname wins; otherwise fall back to the system hostname.
if args.hostname:
	hostname = args.hostname
else:
	hostname = socket.gethostname()

# Only the short name (everything before the first dot) is used.
hostname = hostname.split(".")[0]

# A --testname value is used as given (it is not checked against
# /testing/pluto). Without it, the test name is the name of the current
# working directory, which must then exist below /testing/pluto.
testname = args.testname
if not testname:
	testname = os.path.basename(os.getcwd())
	testpath = "/testing/pluto/%s"%testname
	if not os.path.isdir(testpath):
		sys.exit("testcase /testing/pluto/%s invalid - aborted"%testname)

# Validate the requested userland first, so that an unknown value aborts
# before any directory, log file or symlink has been touched. An empty
# --userland falls back to libreswan.
userland = args.userland or "libreswan"
if userland not in ("libreswan", "strongswan", "racoon", "shrew", "openswan"):
	sys.exit("swan-prep: unknown userland type '%s'"%userland)

# Setup pluto.log softlink
if hostname != "nic":
	# always drop a stale pluto log/link, whichever userland is used now
	if os.path.isfile("/tmp/pluto.log") or os.path.islink("/tmp/pluto.log"):
		os.unlink("/tmp/pluto.log")

	outputdir = "/testing/pluto/%s/OUTPUT/"%testname
	if not os.path.isdir(outputdir):
		os.mkdir(outputdir)
		os.chmod(outputdir, 0o777)

	# name of the daemon whose log is collected for this userland
	if userland in ("libreswan", "openswan"):
		dname = "pluto"
	elif userland == "strongswan":
		dname = "charon"
	elif userland == "racoon":
		dname = "racoon"
	else:
		# only "shrew" is left after the validation above
		dname = "iked"

	# /tmp/<daemon>.log is a symlink to the per-host log in the test's
	# OUTPUT directory
	daemonlogfile = os.path.join(outputdir, "%s.%s.log"%(hostname,dname))
	tmplink = "/tmp/%s.log"%dname
	if os.path.islink(tmplink) or os.path.isfile(tmplink):
		os.unlink(tmplink)
	os.symlink(daemonlogfile, tmplink)
	# create the log file, or empty it when it already exists
	with open(daemonlogfile, 'w'):
		pass
	os.chmod(daemonlogfile, 0o777)

#print "swan-prep running on %s for test %s with userland %s"%(hostname,testname,userland)

# Remove configuration left behind in /etc/ipsec.*: the two flat files are
# deleted, and an existing /etc/ipsec.d is replaced by an empty directory.
for stale in ("/etc/ipsec.conf", "/etc/ipsec.secrets"):
	if os.path.isfile(stale):
		os.unlink(stale)

confdir = "/etc/ipsec.d"
if os.path.isdir(confdir):
	shutil.rmtree(confdir)
	os.mkdir(confdir)

# if using systemd, ensure we don't restart pluto on crash
if os.path.isfile("/lib/systemd/system/ipsec.service"):
	with open("/lib/systemd/system/ipsec.service") as fp:
		service = fp.read()
	if "Restart=always" in service:
		# The unit file must be completely written and closed before systemd
		# is told to reload it; otherwise the reload can see a truncated file.
		with open("/lib/systemd/system/ipsec.service","w") as fp:
			fp.write(service.replace("Restart=always","Restart=no"))
		commands.getoutput("/usr/bin/systemctl daemon-reload")

# easy way to ensure we have the LOGDROP table for those OSes where we don't start a firewall with config
commands.getoutput("iptables-restore /etc/sysconfig/iptables")

# Install configuration, secrets and (with --x509) certificates for the
# libreswan, openswan and strongswan userlands.
if userland in ("libreswan", "openswan", "strongswan"):
	# copy in base configs

	# this brings in the nss *.db files that are path-specific - they have pathnames hardcoded inside the file
	distutils.dir_util.copy_tree("/testing/baseconfigs/%s/etc/ipsec.d"%hostname, "/etc/ipsec.d/")

	# fill in any missing dirs
	if userland == "strongswan":
		prefix = "strongswan/"
	else:
		prefix = ""
	confdir = "/etc/%sipsec.d"%prefix
	for subdir in ("", "/policies", "/cacerts", "/crls", "/certs", "/private"):
		path = confdir + subdir
		if not os.path.isdir(path):
			os.mkdir(path)
			if "private" in path:
				# private key material: owner access only
				os.chmod(path, 0o700)

	# test specific files, falling back to the host's base configs
	ipsecconf = "/testing/pluto/%s/%s.conf"%(testname,hostname)
	ipsecsecrets = "/testing/pluto/%s/%s.secrets"%(testname,hostname)
	if not os.path.isfile(ipsecconf):
		ipsecconf = "/testing/baseconfigs/%s/etc/ipsec.conf"%hostname
	if not os.path.isfile(ipsecsecrets):
		ipsecsecrets = "/testing/baseconfigs/%s/etc/ipsec.secrets"%hostname

	if userland == "strongswan":
		strongswanconf = "/testing/pluto/%s/%sstrongswan.conf"%(testname,hostname)
		shutil.copy(strongswanconf, "/etc/strongswan/strongswan.conf")
		for extradir in ("/etc/strongswan/ipsec.d/aacerts", "/etc/strongswan/ipsec.d/ocspcerts"):
			if not os.path.isdir(extradir):
				os.mkdir(extradir)

	dstconf = "/etc/%sipsec.conf"%(prefix)
	dstsecrets = "/etc/%sipsec.secrets"%(prefix)
	shutil.copy(ipsecconf, dstconf)
	shutil.copy(ipsecsecrets, dstsecrets)
	os.chmod(dstsecrets, 0o600)

	if args.x509:
		# public certificates installed on every host
		peercerts = ("west", "east", "road", "north", "hashsha2")
		if userland in ("libreswan", "openswan"):
			print("Preparing X.509 files")
			# clean out old db files and generate new ones from scratch, does not support -w for password
			for oldfile in glob.glob("/etc/ipsec.d/*db"):
				os.unlink(oldfile)
			cmd = "certutil -N -d /etc/ipsec.d"
			timer = 10
			child = pexpect.spawn(cmd)
			# answer the two password prompts with an empty password
			child.sendline('')
			child.sendline('')
			# Wait for certutil to finish creating the database before
			# pk12util and certutil -A start importing into it.
			try:
				child.expect(pexpect.EOF, timeout=timer)
			except pexpect.TIMEOUT:
				sys.stderr.write("swan-prep: '%s' did not finish within %d seconds\n"%(cmd,timer))
			child.close()

			commands.getoutput("pk12util -i /testing/x509/pkcs12/mainca/%s.p12 -d /etc/ipsec.d -w /testing/x509/nss-pw"%hostname)
			# install all other public certs
			# libreswanhost = os.getenv("LIBRESWANHOSTS") #kvmsetu.sh is not sourced
			for certname in peercerts:
				# compare by value: the host's own cert came in with the p12 above
				if certname != hostname:
					commands.getoutput("certutil -A -n %s -t 'P,u' -d /etc/ipsec.d/ -a -i /testing/x509/certs/%s.crt"%(certname,certname))
		elif userland == "strongswan":
			shutil.copy("/testing/x509/cacerts/mainca.crt","/etc/strongswan/ipsec.d/cacerts/")
			shutil.copy("/testing/x509/keys/%s.key"%hostname,"/etc/strongswan/ipsec.d/private/")
			# install every peer certificate; the host's own certificate is
			# added as well when it is not part of the standard list
			certnames = list(peercerts)
			if hostname not in certnames:
				certnames.append(hostname)
			for certname in certnames:
				shutil.copy("/testing/x509/certs/%s.crt"%certname,"/etc/strongswan/ipsec.d/certs/")

# Install the per-test racoon2 configuration and a stub /etc/ipsec.conf.
if userland in ("racoon2", "racoon"):
	# Racoon uses x509 files straight from /testing/x509/*
	distutils.dir_util.copy_tree("/testing/pluto/%s/%s-racoon"%(testname,hostname), "/etc/racoon2/")
	# secrets must not be readable by anyone but the owner
	os.chmod("/etc/racoon2/psk/test.psk", 0o600)
	os.chmod("/etc/racoon2/spmd.pwd", 0o600)
	if not os.path.isdir("/var/run/racoon2"):
		os.mkdir("/var/run/racoon2")
		# a directory needs the search (x) bit to be usable by its owner
		os.chmod("/var/run/racoon2", 0o700)
	# create stub ipsec.conf for stackmanager to load netkey
	with open("/etc/ipsec.conf","w") as stubconf:
		stubconf.write("version 2\n")
		stubconf.write("config setup\n")
		stubconf.write("\tprotostack=netkey\n")

# Nothing is set up for shrew; only a notice is printed.
if userland == "shrew":
	print("shrew not yet tested/integrated")

# IPv6 is disabled unless --46/--64 was given. --6 enables it again
# afterwards and restarts networking.
ipv6_off = (
	"sysctl net.ipv6.conf.all.disable_ipv6=1",
	"sysctl net.ipv6.conf.default.disable_ipv6=1",
)
ipv6_on = (
	"sysctl net.ipv6.conf.all.disable_ipv6=0",
	"sysctl net.ipv6.conf.default.disable_ipv6=0",
	"/etc/init.d/network restart",
)

if not args.ipv:
	for shellcmd in ipv6_off:
		commands.getoutput(shellcmd)

if args.ipv6:
	for shellcmd in ipv6_on:
		commands.getoutput(shellcmd)

# final prep - this kills any running userland
commands.getoutput("ipsec stop")
# python has no pidof :/
for daemon in ("pluto", "charon", "iked", "racoon2-spmd", "racoon2-iked"):
	# pidof lists every matching PID on one line, and getoutput() also
	# captures stderr, so act on each numeric token and ignore anything else.
	for pid in commands.getoutput("pidof %s"%daemon).split():
		if not pid.isdigit():
			continue
		try:
			os.kill(int(pid), 9)
		except OSError as err:
			# e.g. the process exited between pidof and kill; keep going
			sys.stderr.write("swan-prep: could not kill %s (pid %s): %s\n"%(daemon,pid,err))

# Delete pid files left behind by the daemons.
# note shrew and racoon2 share a pid file name
stale_pidfiles = (
	"/var/run/pluto/pluto.pid",
	"/var/run/charon.pid",
	"/var/run/iked.pid",
	"/var/run/spmd.pid",
)
for pidfile in filter(os.path.isfile, stale_pidfiles):
	os.unlink(pidfile)

# remove stacks so test can start the stack it needs.
commands.getoutput("ipsec _stackmanager stop")
