#!/usr/bin/python
#
# This script is run by a VM host (eg "west") to prepare itself for testing.
# The test to prepare is selected with --testname; without it, the current
# working directory is taken to be the test directory.

import commands
import distutils.dir_util
import errno
import glob
import os
import shutil
import signal
import socket
import sys

import pexpect

# argparse may be missing on old interpreters; exit with a readable message
# instead of an ImportError traceback.
try:
	import argparse
except ImportError:
	sys.exit("swan-prep requires the python argparse module")

# prevent double installs from going unnoticed: abort when an ipsec binary
# exists both in the base system and under /usr/local
local_ipsec = "/usr/local/sbin/ipsec"
in_base_system = any(os.path.isfile(path) for path in ("/usr/sbin/ipsec", "/sbin/ipsec"))
if in_base_system and os.path.isfile(local_ipsec):
	sys.exit("\n\n---------------------------------------------------------------------\n"
		"ABORT: found a swan userland in the base system as well as /usr/local\n"
		"---------------------------------------------------------------------\n")

# exported so that the NSS tools started below inherit it
os.environ['NSS_DEFAULT_DB_TYPE'] = 'sql:'

# just in case a new copy was installed: run restorecon over it, but only
# when the restorecon binary is present
have_restorecon = os.path.isfile("/usr/sbin/restorecon")
if os.path.isfile(local_ipsec) and have_restorecon:
	commands.getoutput("restorecon -Rv /usr/local/libexec/ipsec /usr/local/sbin/ipsec")

# NOTE: these two module-level names are not read anywhere in this script;
# the parsed flags are available as args.ipv and args.ipv6.
ipv = 0
ipv6 = 0

# Command line interface. Every option is optional; the parsed result is
# kept in the module-level "args" used throughout the rest of the script.
parser = argparse.ArgumentParser(description='swan-prep arguments')
parser.add_argument('--testname', '-t', action='store', default='',
	help='The name of the test to prepare')
parser.add_argument('--hostname', '-H', action='store', default='',
	help='The name of the host to prepare as')
# we should get this from the testparams.sh file?
parser.add_argument('--userland', '-u', action='store', default='libreswan',
	help='which userland to prepare')
parser.add_argument('--x509', '-x', action='store_true',
	help='create X509 NSS file by importing test certs')
# with --x509 this replaces the hostname as the name of the .p12 to import
parser.add_argument('--x509name', '-X', action='store', default="",
	help='name of the test PKCS#12 to import with --x509 (default: the hostname)')
parser.add_argument('--fips', '-f', action='store_true',
	help='prepare /etc/ipsec.d for running in FIPS mode')
parser.add_argument('--revoked', '-r', action='store_true',
	help='load a revoked certificate')
parser.add_argument('--signedbyother', '-o', action='store_true',
	help='load the signedbyother certificate')
parser.add_argument('--eccert', '-e', action='store_true',
	help='enable an EC cert for this host')
parser.add_argument('--nsspw', '-P', action='store_true',
	help='enable an nss passworded db')
parser.add_argument('--certchain', '-c', action='store_true',
	help='import the ca-chain test certs')
parser.add_argument('--46', '--64', action='store_true', dest='ipv', default=False,
	help='Do not disable IPv6. Default is disable IPv6 ')
parser.add_argument('--6', action='store_true', dest='ipv6', default=False,
	help='Enable IPv6 and run - /etc/init.d/network restart')
args = parser.parse_args()

# An explicitly given --hostname takes precedence over the machine's own name.
hostname = args.hostname or socket.gethostname()
if args.hostname == "nic":
	# nothing to do, just stop
	sys.exit()
# keep only the part before the first dot (a name without dots is unchanged)
hostname = hostname.split(".")[0]

# Work out which test is being prepared: testname is the bare test name and
# testpath the directory holding its files.
if args.testname:
	# Should this instead try to guess /testing/pluto prefix, or
	# allow testname to specify a directory path?
	testname = args.testname
	testpath = "/testing/pluto/" + testname
	if not os.path.isdir(testpath):
		sys.exit("Unknown or bad testname '%s'"%args.testname)
else:
	# no --testname: use the current directory as the test directory
	# Validate this is sane?
	testpath = os.getcwd()
	testname = os.path.basename(testpath)

# Resolve and validate the userland first, so that a bad --userland aborts
# before anything is created on disk.
if args.userland:
	if args.userland not in ("libreswan", "strongswan", "racoon", "shrew", "openswan"):
		sys.exit("swan-prep: unknown userland type '%s'"%args.userland)
	userland = args.userland
else:
	userland = "libreswan"

# Setup pluto.log softlink
if hostname != "nic":
	if os.path.isfile("/tmp/pluto.log") or os.path.islink("/tmp/pluto.log"):
		os.unlink("/tmp/pluto.log")

	# per-test output directory, made world accessible when created here
	outputdir = "%s/OUTPUT/"%testpath
	if not os.path.isdir(outputdir):
		os.mkdir(outputdir)
		os.chmod(outputdir, 0o777)

	# name used for the daemon's log file, derived from the resolved userland
	if userland in ("libreswan", "openswan"):
		dname = "pluto"
	elif userland == "strongswan":
		dname = "charon"
	elif userland == "racoon":
		dname = "racoon"
	else:
		dname = "iked"

	# /tmp/<daemon>.log is a symlink to the log file inside the test's
	# OUTPUT directory
	daemonlogfile = "%s/%s.%s.log"%(outputdir,hostname,dname)
	tmplink = "/tmp/%s.log"%dname
	if os.path.islink(tmplink) or os.path.isfile(tmplink):
		os.unlink(tmplink)
	os.symlink(daemonlogfile, tmplink)
	# create (or truncate) the log file the link points at
	with open(daemonlogfile, 'w'):
		pass
	os.chmod(daemonlogfile, 0o777)

#print "swan-prep running on %s for test %s with userland %s"%(hostname,testname,userland)

# wipe any old configs in /etc/ipsec.*
for stale_file in ("/etc/ipsec.conf", "/etc/ipsec.secrets"):
	if os.path.isfile(stale_file):
		os.unlink(stale_file)

# an existing /etc/ipsec.d is replaced by an empty directory; a missing one
# is not created here
stale_dir = "/etc/ipsec.d"
if os.path.isdir(stale_dir):
	shutil.rmtree(stale_dir)
	os.mkdir(stale_dir)

# if using systemd, ensure we don't restart pluto on crash
unitfile = "/lib/systemd/system/ipsec.service"
if os.path.isfile(unitfile):
	with open(unitfile) as fp:
		service = fp.read()
	if "Restart=always" in service:
		# the "with" closes (and so flushes) the rewritten unit file before
		# systemd is asked to re-read it below
		with open(unitfile, "w") as fp:
			fp.write(service.replace("Restart=always","Restart=no"))
	# always reload to avoid "service is masked" errors
	commands.getoutput("/usr/bin/systemctl daemon-reload")

# we have to cleanup the audit log or we could get entries from previous test
if os.path.isfile("/var/log/audit/audit.log"):
	# opening for writing truncates the log; nothing is written to it
	with open("/var/log/audit/audit.log","w"):
		pass
	if os.path.isfile("/lib/systemd/system/auditd.service"):
		# systemctl takes the verb first, then the unit
		# NOTE(review): if the unit file refuses manual restarts this will
		# still fail - confirm on the test images
		commands.getoutput("/usr/bin/systemctl restart auditd.service")

# ensure cores are just dropped, not sent to abrt-hook-ccpp or systemd.
# setting /proc/sys/kernel/core_pattern to "core" or a pattern does not
# work. And you cannot use shell redirection ">". So we hack it with "tee":
# the leading "|" makes the kernel pipe the core into that command.
with open("/proc/sys/kernel/core_pattern","w") as core_pattern:
	core_pattern.write("|/usr/bin/tee /tmp/core.%h.%e.%p")

# Preparation shared by the libreswan, openswan and strongswan userlands:
# install the base and test specific configuration files, then optionally
# set up FIPS mode and the X.509 material.
if userland in ("libreswan", "openswan", "strongswan"):
	# copy in base configs

	# this brings in the nss *.db files that are path-specific - they have pathnames hardcoded inside the file
	distutils.dir_util.copy_tree("/testing/baseconfigs/%s/etc/ipsec.d"%hostname, "/etc/ipsec.d/")

	# fill in any missing dirs
	if userland == "strongswan":
		prefix = "strongswan/"
	else:
		prefix = ""
	ipsecd = "/etc/%sipsec.d"%prefix
	for dirpath in (ipsecd, ipsecd + "/policies", ipsecd + "/cacerts", ipsecd + "/crls", ipsecd + "/certs", ipsecd + "/private"):
		if not os.path.isdir(dirpath):
			os.mkdir(dirpath)
			# the private key directory is restricted to its owner
			if "private" in dirpath:
				os.chmod(dirpath, 0o700)

	# test specific files
	ipsecconf = "%s/%s.conf"%(testpath,hostname)
	ipsecsecrets = "%s/%s.secrets"%(testpath,hostname)
	xl2tpdconf = "%s/%s.xl2tpd.conf"%(testpath,hostname)
	pppoptions = "%s/%s.ppp-options.xl2tpd"%(testpath,hostname)
	chapfile = "%s/chap-secrets"%testpath

	# fall back to the host's base config when the test does not ship one
	if not os.path.isfile(ipsecconf):
		ipsecconf = "/testing/baseconfigs/%s/etc/ipsec.conf"%hostname
	if not os.path.isfile(ipsecsecrets):
		ipsecsecrets = "/testing/baseconfigs/%s/etc/ipsec.secrets"%hostname

	if userland == "strongswan":
		# note: no separator between the host name and "strongswan.conf"
		strongswanconf = "%s/%sstrongswan.conf"%(testpath,hostname)
		shutil.copy(strongswanconf, "/etc/strongswan/strongswan.conf")
		for dirpath in ("/etc/strongswan/ipsec.d/aacerts", "/etc/strongswan/ipsec.d/ocspcerts"):
			if not os.path.isdir(dirpath):
				os.mkdir(dirpath)

	dstconf = "/etc/%sipsec.conf"%(prefix)
	dstsecrets = "/etc/%sipsec.secrets"%(prefix)
	shutil.copy(ipsecconf, dstconf)
	shutil.copy(ipsecsecrets, dstsecrets)
	os.chmod(dstsecrets, 0o600)

	# optional per-test xl2tpd / ppp files, installed only when present
	for src, dst in ((xl2tpdconf, "/etc/xl2tpd/xl2tpd.conf"),
			(pppoptions, "/etc/ppp/options.xl2tpd"),
			(chapfile, "/etc/ppp/chap-secrets")):
		if os.path.isfile(src):
			shutil.copyfile(src, dst)

	# the copied NSS db files must be owned by root
	for dbfile in glob.glob("/etc/ipsec.d/*db"):
		os.chown(dbfile,0,0)

	# restore /etc/hosts to original - some tests make changes
	shutil.copyfile("/testing/baseconfigs/all/etc/hosts", "/etc/hosts")
	# a host specific resolv.conf wins over the shared one
	resolv = "/testing/baseconfigs/%s/etc/resolv.conf"%hostname
	if not os.path.isfile(resolv):
		resolv = "/testing/baseconfigs/all/etc/resolv.conf"
	shutil.copyfile(resolv, "/etc/resolv.conf")

	if args.nsspw or args.fips:
		dbpassword = 's3cret'
		# option strings appended to the certutil / pk12util command lines
		util_pw = ' -f /etc/ipsec.d/nsspassword'
		p12cmd_pw = ' -k /etc/ipsec.d/nsspassword'
		if not os.path.isfile('/etc/ipsec.d/nsspassword'):
			with open('/etc/ipsec.d/nsspassword', 'w') as f:
				f.write("NSS FIPS 140-2 Certificate DB:" + dbpassword + '\n')
				f.write("NSS Certificate DB:" + dbpassword + '\n')

	if args.fips:
		# create the (empty) /etc/system-fips marker file
		with open("/etc/system-fips","w"):
			pass
		commands.getoutput("/testing/guestbin/fipson")
		# the test also requires using a modutil cmd which we cannot run here
		shutil.copyfile("/testing/baseconfigs/all/etc/sysconfig/pluto.fips", "/etc/sysconfig/pluto")
		# passwd required: /tmp/pw holds the current (empty) password that
		# certutil -W replaces with the one from the nsspassword file
		with open("/tmp/pw","w") as fp:
			fp.write("\n")
		out = commands.getoutput("certutil -W -f /tmp/pw -@ /etc/ipsec.d/nsspassword -d sql:/etc/ipsec.d")
		print("certutil out:%s"%out)
	else:
		shutil.copyfile("/testing/baseconfigs/all/etc/sysconfig/pluto", "/etc/sysconfig/pluto")
		if os.path.isfile("/etc/system-fips"):
			os.unlink("/etc/system-fips")

	# this section is getting rough. could use a nice refactoring
	if args.x509:
		if userland in ("libreswan", "openswan"):
			print("Preparing X.509 files")
			# The NSS db is recreated below with an empty password, so the
			# password options set up for --nsspw / --fips are dropped here.
			dbpassword = ''
			util_pw = ''
			p12cmd_pw = ''

			# clean out old db files and generate new ones from scratch
			# certutil accepts --empty-password, however the following way
			# is useful for both empty and passworded tests.
			for oldfile in glob.glob("/etc/ipsec.d/*db"):
				os.unlink(oldfile)
			commands.getoutput("certutil -N --empty-password -d sql:/etc/ipsec.d")

			# only a warning: the imports below are attempted regardless
			if not os.path.isfile("/testing/x509/keys/mainca.key"):
				print("\n\n---------------------------------------------------------------------\n"
					  "WARNING: no mainca.key file, did you run testing/x509/dist_certs.py?\n"
					  "---------------------------------------------------------------------\n")

			# pick the PKCS#12 file, its CA directory and the pk12util
			# password option
			if args.eccert:
				p12 = hostname + "-ec"
				ca = "curveca"
				pw = "-W \"\""
			else:
				if args.x509name:
					p12 = args.x509name
				else:
					p12 = hostname
				ca = "mainca"
				pw = "-w /testing/x509/nss-pw"

			if args.certchain:
				icanum = 2
				pw = "-w /testing/x509/nss-pw"
				root = "mainca"
				ica_p = hostname + "_chain_int_"
				ee = hostname + "_chain_endcert"

				# a note about DB trusts
				# 'CT,,' is our root's trust. T is important!! it is for verifying "SSLClient" x509 KU
				# ',,' is an intermediate's trust
				# 'P,,' (trusted peer) is nic's for OCSP
				# 'u,u,u' will be end cert trusts that are p12 imported (with privkey)

				# mainca and nic import
				commands.getoutput("certutil -A -n %s -t 'CT,,' -d sql:/etc/ipsec.d/ -a -i /testing/x509/cacerts/%s.crt%s" % (root, root, util_pw))
				commands.getoutput("certutil -A -n nic -t 'P,,' -d sql:/etc/ipsec.d/ -a -i /testing/x509/certs/nic.crt%s" % util_pw)

				# install ee
				commands.getoutput("pk12util -i /testing/x509/pkcs12/%s.p12 -d sql:/etc/ipsec.d %s%s" % (ee, pw, p12cmd_pw))

				# install intermediates
				for i in range(1, icanum+1):
					acrt = ica_p + str(i)
					commands.getoutput("certutil -A -n %s -t ',,' -d sql:/etc/ipsec.d/ -a -i /testing/x509/certs/%s.crt%s" % (acrt, acrt, util_pw))
				if args.revoked:
					# same "sql:" database prefix as every other NSS command here
					commands.getoutput("pk12util -i /testing/x509/pkcs12/%s_revoked.p12 -d sql:/etc/ipsec.d %s%s" % (hostname + "_chain", pw, p12cmd_pw))

			else:
				commands.getoutput("pk12util -i /testing/x509/pkcs12/%s/%s.p12 -K '' -d sql:/etc/ipsec.d %s%s"%(ca, p12, pw, p12cmd_pw))
				# install all other public certs

				if args.revoked:
					commands.getoutput("pk12util -i /testing/x509/pkcs12/mainca/revoked.p12 -d sql:/etc/ipsec.d %s%s" % (pw, p12cmd_pw))

				if args.signedbyother:
					commands.getoutput("pk12util -i /testing/x509/pkcs12/otherca/signedbyother.p12 -d sql:/etc/ipsec.d %s%s" % (pw, p12cmd_pw))

				# fix trust import from p12
				# is pw needed?
				commands.getoutput("certutil -M -n 'Libreswan test CA for mainca - Libreswan' -d sql:/etc/ipsec.d/ -t 'CT,,'")

				# import the other certs as trusted peers; this is a
				# substring test, so on "west" both "west" and "west-ec"
				# are skipped
				for certname in ("west", "east", "road", "north", "hashsha2", "west-ec", "east-ec", "nic"):
					if hostname not in certname:
						commands.getoutput("certutil -A -n %s -t 'P,,' -d sql:/etc/ipsec.d/ -a -i /testing/x509/certs/%s.crt%s"%(certname,certname,util_pw))
		elif userland == "strongswan":
			# strongswan gets plain files copied into its ipsec.d tree
			if args.eccert:
				ca = "curveca"
				key = hostname + "-ec"
			else:
				ca = "mainca"
				key = hostname

			shutil.copy("/testing/x509/cacerts/%s.crt"%ca,"/etc/strongswan/ipsec.d/cacerts/")
			shutil.copy("/testing/x509/keys/%s.key"%key,"/etc/strongswan/ipsec.d/private/")

			for certname in ("west", "east", "road", "north", "hashsha2", "west-ec", "east-ec"):
				shutil.copy("/testing/x509/certs/%s.crt"%certname,"/etc/strongswan/ipsec.d/certs/")

# Preparation for the racoon userland.
# NOTE: "racoon2" is not in the list of userlands this script accepts for
# --userland, so only "racoon" can actually reach this branch.
if userland == "racoon2" or userland == "racoon":
	# Racoon uses x509 files straight from /testing/x509/*
	distutils.dir_util.copy_tree("%s/%s-racoon"%(testpath,hostname), "/etc/racoon2/")
	os.chmod("/etc/racoon2/psk/test.psk", 0o600)
	os.chmod("/etc/racoon2/spmd.pwd", 0o600)
	if not os.path.isdir("/var/run/racoon2"):
		os.mkdir("/var/run/racoon2")
		# owner-only, but a directory needs the search (x) bit to be usable
		os.chmod("/var/run/racoon2", 0o700)
	# create stub ipsec.conf for stackmanager to load netkey
	with open("/etc/ipsec.conf","w") as stubconf:
		stubconf.write("version 2\n")
		stubconf.write("config setup\n")
		stubconf.write("\tprotostack=netkey\n")

# shrew is accepted as a userland but nothing is prepared for it
if userland == "shrew":
	print("shrew not yet tested/integrated")

# the two sysctl knobs that switch IPv6 off (1) or on (0)
ipv6_knobs = ("net.ipv6.conf.all.disable_ipv6", "net.ipv6.conf.default.disable_ipv6")

# IPv6 is disabled unless --46 / --64 was given
if not args.ipv:
	for knob in ipv6_knobs:
		commands.getoutput("sysctl %s=1"%knob)

# --6 enables IPv6 again (after the step above) and restarts networking
if args.ipv6:
	for knob in ipv6_knobs:
		commands.getoutput("sysctl %s=0"%knob)
	commands.getoutput("/etc/init.d/network restart")

# give root a .gdbinit with an auto-load safe-path of "/" unless one already
# exists; an existing file is left untouched
gdbinit = "/root/.gdbinit"
if not os.path.isfile(gdbinit):
	with open(gdbinit, "w") as rcfile:
		rcfile.write("set auto-load safe-path /")

# Create LOGDROP (used to be done in swan-transmogrify but we want it here for docker)
# The chain is created for both IPv4 and IPv6 and gets a LOG rule followed
# by a DROP rule; the output of the commands is ignored.
logdrop_steps = ("-N LOGDROP ", "-A LOGDROP -j LOG", "-A LOGDROP -j DROP")
for ipt in ("iptables", "ip6tables"):
	for step in logdrop_steps:
		commands.getoutput("%s %s"%(ipt, step))

# final prep - this kills any running userland
commands.getoutput("ipsec stop")

# python has no pidof :/. getoutput returns a string, not a list, so split
# it to get the individual pids. Words that are not numbers (for example an
# error message from the shell when pidof cannot be run) are ignored.
pids = []
for daemon in ("pluto", "starter", "charon", "iked", "racoon2-spmd", "racoon2-iked"):
	found = commands.getoutput("pidof %s"%daemon)
	pids.extend(word for word in found.split() if word.isdigit())

for pid in pids:
	try:
		os.kill(int(pid), signal.SIGKILL)
	except OSError as e:
		# the process may already be gone by the time we get here; any
		# other failure is still reported
		if e.errno != errno.ESRCH:
			raise

# remove pid files left behind by the daemons
# note shrew and racoon2 share a pid file name
stale_pidfiles = (
	"/var/run/pluto/pluto.pid",
	"/var/run/charon.pid",
	"/var/run/iked.pid",
	"/var/run/spmd.pid",
	"/var/run/starter.charon.pid",
)
for pidfile in stale_pidfiles:
	if not os.path.isfile(pidfile):
		continue
	os.unlink(pidfile)

# remove stacks so test can start the stack it needs.
commands.getoutput("ipsec _stackmanager stop")
