#!/usr/bin/python
""" Testing for pluto's certificate verification
	
 Copyright (C) 2015 Matt Rogers <mrogers@libreswan.org>

 This program is free software; you can redistribute it and/or modify it
 under the terms of the GNU General Public License as published by the
 Free Software Foundation; either version 2 of the License, or (at your
 option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.

 This program is distributed in the hope that it will be useful, but
 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 for more details.
 """

import os
import sys
import ssl
import subprocess
import time
from datetime import datetime, timedelta

try:
	import argparse
	from OpenSSL import crypto
except ImportError as e:
	sys.exit("module %s required" % str(e)[16:])

# Working directory for the generated certs and the NSS database
FILE_DIR = 'testfiles/'

# X509v3 extension names handed to crypto.X509Extension
EXT_BASIC = 'basicConstraints'
EXT_KU = 'keyUsage'
EXT_EKU = 'extendedKeyUsage'
EXT_CRL = 'crlDistributionPoints'
EXT_OCSP = 'authorityInfoAccess'

# Extension values used for the CRL and OCSP extensions
CRL_URI = 'URI:http://nic.testing.libreswan.org/revoked.crl'
OCSP_URI = 'OCSP;URI:http://nic.testing.libreswan.org:2560'

# keyUsage value names - see man x509v3_config
KU_DIGSIG = 'digitalSignature'
KU_NONREP = 'nonRepudiation'
KU_KEYENC = 'keyEncipherment'
KU_DATENC = 'dataEncipherment'
KU_KEYAGR = 'keyAgreement'
KU_KEYSIG = 'keyCertSign'
KU_CRLSIG = 'cRLSign'
KU_ENONLY = 'encipherOnly'
KU_DEONLY = 'decipherOnly'

# Every keyUsage name defined above, in declaration order
all_kus = (
	KU_DIGSIG, KU_NONREP, KU_KEYENC,
	KU_DATENC, KU_KEYAGR, KU_KEYSIG,
	KU_CRLSIG, KU_ENONLY, KU_DEONLY,
)

# extendedKeyUsage value names
EKU_SRVAUTH = 'serverAuth'
EKU_CLIAUTH = 'clientAuth'
EKU_CODESIG = 'codeSigning'
EKU_OCSP = 'OCSPSigning'
EKU_EMAILPR = 'emailProtection'

# Every extendedKeyUsage name defined above, in declaration order
all_ekus = (
	EKU_SRVAUTH, EKU_CLIAUTH, EKU_CODESIG,
	EKU_OCSP, EKU_EMAILPR,
)
# more...

# Command line interface.  The parsed result is kept in the module-level
# `args`, which the helpers and the usageTest class below read directly.
parser = argparse.ArgumentParser(description='create and verify a cert chain with KU and EKU settings')
parser.add_argument('--verbose', '-v', action='store_true', help='verbose')
parser.add_argument('--usage', '-U', action='store', default='certUsageAnyCA', help='usage string')
parser.add_argument('--use_sub', '-u', action='store_true', default=False, help='use an intermediate cert')

# root CA settings
parser.add_argument('--rootku', '-r', action='store', default='', help='root CA keyUsage string')
parser.add_argument('--rooteku', '-R', action='store', default='', help='root extendedKeyUsage string')
parser.add_argument('--rootcrit', '-c', action='store_true', default=False, help='set root CA basicConstraints flag to Critical')
parser.add_argument('--roottrust', '-t', action='store', default='', help='root NSS db trust string')

# sub (intermediate) CA settings
parser.add_argument('--subku', '-s', action='store', default='', help='sub CA keyUsage string')
parser.add_argument('--subeku', '-S', action='store', default='', help='sub CA extendedKeyUsage string')
parser.add_argument('--subtrust', '-T', action='store', default='', help='sub CA NSS db trust string')
parser.add_argument('--subcrit', '-C', action='store_true', default=False, help='sub CA basicConstraints flag to Critical')

# end cert settings
parser.add_argument('--endku', '-e', action='store', default='', help='end cert keyUsage string')
parser.add_argument('--endeku', '-E', action='store', default='', help='end extendedKeyUsage string')

# Verifier switches.  These are boolean flags, so they default to False
# like the other store_true options rather than to an empty string.
parser.add_argument('--pkix', '-p', action='store_true', default=False, help='use the pkix validator')
parser.add_argument('--retry', '-x', action='store_true', default=False, help='use the usage failover retry')
parser.add_argument('--ocsp', '-O', action='store_true', default=False, help='ocsp checking')
parser.add_argument('--crl', '-L', action='store_true', default=False, help='crl checking')
parser.add_argument('--strict', '-z', action='store_true', default=False, help='strict checks')
args = parser.parse_args()

# NSS certutil binary and the command templates built around it
CU_PATH = "/usr/bin/certutil"
# (certutil, db dir) - create an empty-password sql: database
CREATE_DB_FMT = "%s -N -d sql:%s --empty-password"
# (certutil, db dir, nickname, trust string, cert file) - import a cert
ADD_CERT_FMT = "%s -A -d sql:%s -n '%s' -t '%s' -i %s"

# Verifier binary and its command template: (verifier, usage string)
VERIFIER = "./verify"
VERIFIER_FMT = "%s -u %s"

def create_keypair(algo=crypto.TYPE_RSA, bits=1024):
	""" Generate and return a fresh OpenSSL key pair

	algo is the pyOpenSSL key type constant and bits the key size; both
	are passed straight to PKey.generate_key().
	"""
	keypair = crypto.PKey()
	keypair.generate_key(algo, bits)
	return keypair

def create_csr(pkey, **kwargs):
	""" Build a self-signed certificate request for pkey

	The subject is filled from the 'CN', 'L' and 'O' keyword arguments
	(all three are required; any other keyword is ignored).  The request
	carries pkey as its public key and is signed with pkey using sha1.
	"""
	request = crypto.X509Req()
	name = request.get_subject()
	name.CN = kwargs['CN']
	# L and O are coerced to str before being stored, CN is taken as-is
	for field in ('L', 'O'):
		setattr(name, field, str(kwargs[field]))
	request.set_pubkey(pkey)
	request.sign(pkey, 'sha1')
	return request

def writeout_cert(filename, item, type=crypto.FILETYPE_ASN1):
	""" Serialize a certificate and write it to filename

	item is the certificate object handed to crypto.dump_certificate()
	and type the pyOpenSSL file type constant selecting the encoding
	(ASN.1/DER by default).  Any failure is reported on stdout and
	terminates the program with exit status 1.
	"""
	try:
		# dump_certificate() returns the raw encoded certificate, so the
		# file is opened in binary mode to keep the bytes untouched.
		with open(filename, "wb") as f:
			f.write(crypto.dump_certificate(type, item))
	except Exception as e:
		print("writeout_cert %s, quitting: %s" % (filename, e))
		sys.exit(1)

def add_usage(ku_str, ku_add):
	""" Append one usage name to a comma-separated usage string

	ku_str is the string built so far ('' when nothing has been added
	yet) and ku_add the name to append.  A ku_add of None or '' leaves
	ku_str unchanged.  Returns the resulting string.
	"""
	# Compare by value: an identity test against a '' literal only works
	# when the interpreter happens to share the empty string object.
	if ku_add is None or ku_add == '':
		return ku_str

	# No separator in front of the very first entry
	if ku_str == '':
		return ku_add
	return ku_str + ',' + ku_add

def run_cmd(cmd, perr=False, fatal=True):
	""" Run cmd through the shell and return its exit status

	The command's output (stdout and stderr combined) is captured and
	only shown, when perr is set, if the command fails.  With fatal set
	a failing command re-raises subprocess.CalledProcessError; otherwise
	its return code is returned.  A successful command returns 0.
	The command line is echoed first when --verbose was given.
	"""
	if args.verbose:
		print("run_cmd: %s" % cmd)

	try:
		subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT)
	except subprocess.CalledProcessError as failure:
		if perr:
			print("%s" % (failure.output))
		if fatal:
			raise
		return failure.returncode

	return 0

def clear_path(path):
	""" Delete every regular file directly inside path

	Subdirectories (and anything else that is not a regular file) are
	left alone.  If a file cannot be checked or removed the error is
	printed and the program exits with status 1.
	"""
	for the_file in os.listdir(path):
		file_path = os.path.join(path, the_file)
		try:
			if os.path.isfile(file_path):
				os.unlink(file_path)
		except Exception as e:
			print(e)
			sys.exit(1)

class usageTest:
	def __init__(self, start, end, subca=False):
		clear_path(FILE_DIR)
		self.use_subca = subca
		self.START = start
		self.END = end
		self.final_result = 0

		self.rootdn = dict(CN='ca.libreswan.org', C='CA', L='Toronto', O='Libreswan')
		self.rootsn = 0
		self.root_use_bc = True
		self.root_bc_crit = False
		self.root_ku_str = ''
		self.root_eku_str = ''

		self.subdn = dict(CN='sub.libreswan.org', C='SubCA', L='Toronto', O='Libreswan')
		self.subsn = 1
		self.sub_use_bc = True
		self.sub_bc_crit = False
		self.sub_ku_str = ''
		self.sub_eku_str = ''

		self.enddn = dict(CN='end.libreswan.org', C='End', L='Toronto', O='Libreswan')
		self.endsn = 2
		self.end_ku_str = ''
		self.end_eku_str = ''

	def generate_certs(self):
		self._create_root_key()
		self._create_root_cert()
		if self.use_subca:
			self._create_sub_key()
			self._create_sub_cert()
		self._create_end_key()
		self._create_end_cert()
		self._set_all_exts()
		self._sign_all_certs()
		self._dump_all_certs()

	def _prepare_db(self):
		run_cmd(CREATE_DB_FMT % (CU_PATH, FILE_DIR))

	def _prepare_certs(self):
		if args.roottrust == '':
			rtrust = 'CT,,'
		else:
			rtrust = args.roottrust

		if args.subtrust == '':
			strust = ',,'
		else:
			strust = args.subtrust

		self._add_cert_to_db("root", FILE_DIR + "root.pem", rtrust)
		#if self.use_subca:
		#	self._add_cert_to_db("sub", FILE_DIR + "sub.pem", strust)

	def _add_cert_to_db(self, name, bin_cert, trust=None):
		truststr = ',,'
		if trust is not None:
			truststr = trust
		run_cmd(ADD_CERT_FMT % (CU_PATH, FILE_DIR, name, trust, bin_cert))

	def _launch_verifier(self):
		cmd = VERIFIER_FMT
		if args.pkix:
			cmd = cmd + " -p"
		if args.ocsp:
			cmd = cmd + " -o"
		if args.crl:
			cmd = cmd + " -c"
		if args.strict:
			cmd = cmd + " -S"
		if args.retry:
			cmd = cmd + " -r"
		if args.use_sub:
			cmd = cmd + " -s " + FILE_DIR + "sub.pem"

		self.final_result = run_cmd(cmd % (VERIFIER, args.usage), perr=True, fatal=False)

	def run_test(self):
		self._prepare_db()
		self._prepare_certs()
		self._launch_verifier()

	def get_result(self):
		return self.final_result

	def set_root_bc(self, crit=False):
		self.root_use_bc = True
		self.root_bc_crit = crit

	def set_sub_bc(self, crit=False):
		self.sub_use_bc = True
		self.sub_bc_crit = crit

	def add_root_ku(self, ku_add):
		self.root_ku_str = add_usage(self.root_ku_str, ku_add)

	def add_sub_ku(self, ku_add):
		self.sub_ku_str = add_usage(self.sub_ku_str, ku_add)

	def add_end_ku(self, ku_add):
		self.end_ku_str = add_usage(self.end_ku_str, ku_add)

	def add_root_eku(self, eku_add):
		self.root_eku_str = add_usage(self.root_eku_str, eku_add)

	def add_sub_eku(self, eku_add):
		self.sub_eku_str = add_usage(self.sub_eku_str, eku_add)

	def add_end_eku(self, eku_add):
		self.end_eku_str = add_usage(self.end_eku_str, eku_add)

    # cert creation
	def _create_root_key(self):
		self.rootkey = create_keypair()

	def _create_sub_key(self):
		self.subkey = create_keypair()

	def _create_end_key(self):
		self.endkey = create_keypair()

	def _create_root_cert(self):
		rootreq = create_csr(self.rootkey, **self.rootdn)
		self.rootcert = self._create_cert_from_req(rootreq, rootreq.get_subject(), self.rootsn,
											 	   self.START, self.END, version=3)
		if self.root_use_bc:
			self._set_basic_constraint(self.rootcert, self.root_bc_crit, isCA=True)

	def _create_sub_cert(self):
		subreq = create_csr(self.subkey, **self.subdn)
		self.subcert = self._create_cert_from_req(subreq, self.rootcert.get_subject(), self.subsn,
											 	   self.START, self.END, version=3)
		if self.sub_use_bc:
			self._set_basic_constraint(self.subcert, self.sub_bc_crit, isCA=True)

	def _create_end_cert(self):
		if self.use_subca:
			issuer = self.subcert
		else:
			issuer = self.rootcert

		endreq = create_csr(self.endkey, **self.enddn)
		self.endcert = self._create_cert_from_req(endreq, issuer.get_subject(), self.endsn,
											 	  self.START, self.END, version=3)

	def _create_cert_from_req(self, req, issuer_subj, sn, start, end, version=3):
		cert = crypto.X509()
		cert.set_serial_number(sn)
		cert.set_notBefore(start)
		cert.set_notAfter(end)
		cert.set_issuer(issuer_subj)
		cert.set_subject(req.get_subject())
		cert.set_pubkey(req.get_pubkey())
		cert.set_version(version)
		return cert

	def _sign_all_certs(self):
		self.rootcert.sign(self.rootkey, 'sha1')
		if self.use_subca:
			self.subcert.sign(self.rootkey, 'sha1')
			endsigkey = self.subkey
		else:
			endsigkey = self.rootkey
		self.endcert.sign(endsigkey, 'sha1')

	def _dump_all_certs(self):
		writeout_cert(FILE_DIR + "root.pem", self.rootcert)
		if self.use_subca:
			writeout_cert(FILE_DIR + "sub.pem", self.subcert)
		writeout_cert(FILE_DIR + "end.pem", self.endcert)

	# extension setters
	def _set_all_exts(self):
		self._set_keyusage(self.rootcert, self.root_ku_str)
		if (args.ocsp):
			self._set_ocspext(self.rootcert)
		if (args.crl):
			self._set_crlext(self.rootcert)

		if self.use_subca:
			self._set_keyusage(self.subcert, self.sub_ku_str)
		self._set_keyusage(self.endcert, self.end_ku_str)

		self._set_extendedkeyusage(self.rootcert, self.root_eku_str)
		if self.use_subca:
			self._set_extendedkeyusage(self.subcert, self.sub_eku_str)
		self._set_extendedkeyusage(self.endcert, self.end_eku_str)

	def _set_basic_constraint(self, cert, crit, isCA):
		if isCA:
			bc = 'CA:TRUE'
		else:
			bc = 'CA:FALSE'
		self._set_extension(cert, EXT_BASIC, bc, crit)

	def _set_keyusage(self, cert, ku_str, crit=False):
		self._set_extension(cert, EXT_KU, ku_str, crit)

	def _set_extendedkeyusage(self, cert, eku_str, crit=False):
		self._set_extension(cert, EXT_EKU, eku_str, crit)
	
	def _set_crlext(self, cert):
		self._set_extension(cert, EXT_CRL, CRL_URI, crit=False)

	def _set_ocspext(self, cert):
		self._set_extension(cert, EXT_OCSP, OCSP_URI, crit=False)

	def _set_extension(self, cert, ext, ext_str, crit):
		if ext_str is '' or ext_str is None:
			return
		try:
			if args.verbose and ext is not EXT_BASIC:
				print "setting %s \'%s\' on %s - crit: %s" % (ext, ext_str, cert.get_subject().commonName, crit)
			cert.add_extensions([crypto.X509Extension(ext, crit, ext_str)])
		except Exception, e:
			print "Error adding extension string %s, quitting: %s" % (ext_str, e)
			sys.exit(1)

def gmc(timestamp):
	""" Format an epoch timestamp as a UTC 'YYYYMMDDHHMMSSZ' string
	"""
	utc_fields = time.gmtime(timestamp)
	return time.strftime("%Y%m%d%H%M%SZ", utc_fields)

def gen_gmtime_dates():
	""" Return the (start, end) validity stamps for the test certs

	start is one day before the current time and end is 365 days after
	start; both are formatted by gmc().
	"""
	one_day = 60*60*24
	# time.time() already is seconds since the epoch.  Formatting the
	# current time with "%b ..." and parsing it back through
	# ssl.cert_time_to_seconds() depends on the locale's month names and,
	# on interpreters where that function reads its input as local time,
	# shifts the result by the UTC offset.
	ok_stamp = int(time.time()) - one_day
	future_stamp = ok_stamp + (one_day*365*1)
	return gmc(ok_stamp), gmc(future_stamp)

def single_test(start, end, use_sub, **kwargs):
	""" Configure, generate and verify one certificate chain

	start / end are the validity stamps and use_sub selects a chain with
	an intermediate CA.  kwargs carries the settings: root_bc,
	root_bc_crit, root_ku, root_eku, end_ku, end_eku and - read only
	when use_sub is set - sub_bc, sub_bc_crit, sub_ku, sub_eku.  The
	*_ku / *_eku entries are iterables of usage names.

	Prints "OK" when the verifier succeeds and returns its exit status.
	"""
	runner = usageTest(start, end, use_sub)

	if kwargs['root_bc']:
		runner.set_root_bc(kwargs['root_bc_crit'])
	if use_sub and kwargs['sub_bc']:
		runner.set_sub_bc(kwargs['sub_bc_crit'])

	# (kwargs key, setter) pairs in the order they are applied; the sub
	# CA keys are only looked up when a sub CA is in use
	feeds = [('root_ku', runner.add_root_ku),
			 ('root_eku', runner.add_root_eku)]
	if use_sub:
		feeds.append(('sub_ku', runner.add_sub_ku))
		feeds.append(('sub_eku', runner.add_sub_eku))
	feeds.append(('end_ku', runner.add_end_ku))
	feeds.append(('end_eku', runner.add_end_eku))

	for key, adder in feeds:
		for usage in kwargs[key]:
			adder(usage)

	runner.generate_certs()
	runner.run_test()

	status = runner.get_result()
	if status == 0:
		print("OK")
	return status

def get_test_parms():
	""" Translate the command line options into single_test() keywords

	Every comma-separated usage option becomes a tuple of names (an
	unset option yields ('',)).  basicConstraints is always enabled for
	both CAs; its critical flag follows --rootcrit / --subcrit.
	"""
	def as_tuple(csv):
		return tuple(csv.split(','))

	return {
		'root_ku': as_tuple(args.rootku),
		'root_eku': as_tuple(args.rooteku),
		'root_bc': True,
		'root_bc_crit': args.rootcrit,
		'sub_ku': as_tuple(args.subku),
		'sub_eku': as_tuple(args.subeku),
		'sub_bc': True,
		'sub_bc_crit': args.subcrit,
		'end_ku': as_tuple(args.endku),
		'end_eku': as_tuple(args.endeku),
	}
def main():
	""" Run one verification test as described by the command line

	Creates FILE_DIR and builds the verifier with "make" when either is
	missing, then exits with the status returned by single_test().
	"""
	have_workdir = os.path.exists(FILE_DIR)
	if not have_workdir:
		os.makedirs(FILE_DIR)

	have_verifier = os.path.exists(VERIFIER)
	if not have_verifier:
		run_cmd("make")

	validity = gen_gmtime_dates()
	status = single_test(validity[0], validity[1], args.use_sub, **get_test_parms())
	sys.exit(status)

# only run the test when executed as a script
if __name__ == '__main__':
	main()

