#!/usr/bin/python

from __future__ import print_function

import argparse
import base64
import copy
from lxml import etree
import getpass
import grp
import jinja2
from keycloak_httpd_client import keycloak_cli
import logging
import logging.handlers
from collections import namedtuple
import os
import pwd
import re
from keycloak_httpd_client.keycloak_cli import RESTError
import requests
import shutil
import six
import socket
import subprocess
import sys
import tempfile
import traceback


from six.moves.urllib.parse import quote as urlquote
from six.moves.urllib.parse import urlsplit, urlunsplit

# ----------------------------- Global Variables ------------------------------

# Module-wide logger. Stays None until configure_logging() replaces it
# with the logger named after prog_name.
logger = None
# Template environment used by the build_mellon_*_file() helpers via
# get_template(); it is declared global in main() and assigned there.
template_env = None
# Basename of the invoked script; used as the logger name and in the
# default --log-file path.
prog_name = os.path.basename(sys.argv[0])
# backupCount passed to the RotatingFileHandler in configure_logging().
LOG_FILE_ROTATION_COUNT = 3
# Executable run_cmd() prepends to a command when a timeout is requested.
# NOTE(review): this name is assigned again, with the same value, in the
# Constants section further down.
BIN_TIMEOUT = "/usr/bin/timeout"

# -------------------------------- Constants ----------------------------------

# File run_cmd() opens as the output sink when skip_output is requested.
DEV_NULL = '/dev/null'
# Executable run_cmd() prepends to a command when a timeout is requested.
BIN_TIMEOUT = '/usr/bin/timeout'

# NOTE(review): presumably directory names relative to the httpd
# configuration root; their use is not visible in this part of the file.
HTTPD_SAML_DIR = 'saml2'
HTTPD_CONF_DIR = 'conf.d'

# Template name looked up by build_mellon_sp_metadata_file().
MELLON_METADATA_TEMPLATE = 'sp_metadata.tpl'
# NOTE(review): presumably the file name of the rendered SP metadata.
MELLON_METADATA = 'sp_metadata.xml'

# Template name looked up by build_mellon_httpd_config_file().
MELLON_HTTPD_CONFIG_TEMPLATE = 'mellon_httpd.conf'

# Process exit status codes.
STATUS_SUCCESS = 0
STATUS_OPERATION_ERROR = 1
STATUS_CONFIGURATION_ERROR = 2  # Must be 2 to match argparse exit status
STATUS_INSUFFICIENT_PRIVILEGE = 3
STATUS_COMMUNICATION_ERROR = 4
STATUS_ALREADY_EXISTS_ERROR = 5

# URN identifying the SAML 2.0 PAOS binding.
SAML_PAOS_BINDING='urn:oasis:names:tc:SAML:2.0:bindings:PAOS'

# --------------------------- Exception Definitions ---------------------------


class AlreadyExistsError(ValueError):
    '''ValueError subclass signalling that something already exists.

    NOTE(review): the code raising this is not visible here; presumably
    it pairs with the STATUS_ALREADY_EXISTS_ERROR exit status.
    '''

# --------------------------- Logging Configuration ---------------------------


def configure_logging(options):
    '''Configure file and console logging and set the global logger.

    A custom STEP level (one above INFO) is registered together with a
    Logger subclass offering a step() method which prefixes each message
    with an increasing step number.

    Two handlers are attached to the root logger:

    * a rotating file handler on options.log_file which records
      everything from DEBUG upwards. If the log file already exists it
      is rolled over first so each run starts with a fresh file and the
      previous run is preserved as a numbered backup.
    * a console handler on stdout whose level is STEP by default, INFO
      when options.verbose is set and DEBUG when options.debug is set.

    With options.debug the HTTP client library's debug output is enabled
    as well.

    :param options: object providing the attributes log_file, verbose
                    and debug
    :raises ValueError: if the directory part of options.log_file exists
                        but is not a directory
    '''
    global logger  # pylint: disable=W0603

    STEP = logging.INFO + 1

    class StepLogger(logging.Logger):
        '''Logger with a numbered step() method logging at STEP level.'''

        def __init__(self, name):
            self.step_number = 1
            super(StepLogger, self).__init__(name)

        def step(self, msg, *args, **kwargs):
            if self.isEnabledFor(STEP):
                self._log(STEP, ('[Step %2d] ' % self.step_number) + msg,
                          args, **kwargs)
                self.step_number += 1

    logging.addLevelName(STEP, 'STEP')
    logging.setLoggerClass(StepLogger)

    # A bare file name has an empty directory part; there is nothing to
    # create then and os.makedirs('') would raise.
    log_dir = os.path.dirname(options.log_file)
    if log_dir:
        if os.path.exists(log_dir):
            if not os.path.isdir(log_dir):
                raise ValueError('logging directory "{log_dir}" exists but '
                                 'is not directory'.format(log_dir=log_dir))
        else:
            os.makedirs(log_dir)

    # Check if log exists and should therefore be rolled
    need_roll = os.path.isfile(options.log_file)

    log_level = STEP
    if options.verbose:
        log_level = logging.INFO
    if options.debug:
        log_level = logging.DEBUG

        # These lines enable debugging at httplib level
        # (requests->urllib3->http.client) You will see the REQUEST,
        # including HEADERS and DATA, and RESPONSE with HEADERS but
        # without DATA.  The only thing missing will be the
        # response.body which is not logged.
        try:
            import http.client as http_client  # Python 3
        except ImportError:
            import httplib as http_client      # Python 2

        http_client.HTTPConnection.debuglevel = 1

    root_logger = logging.getLogger()
    logger = logging.getLogger(prog_name)

    try:
        # Open in append mode: the handler opens the file immediately and
        # mode 'w' would truncate the previous log before it is rolled
        # over, leaving an empty backup. After the rollover the file is
        # new, so each run still starts with an empty log.
        file_handler = logging.handlers.RotatingFileHandler(
            options.log_file, mode='a', backupCount=LOG_FILE_ROTATION_COUNT)
    except IOError as e:
        # Logging to file is best effort, console logging still works.
        print('Unable to open log file %s (%s)' % (options.log_file, e),
              file=sys.stderr)

    else:
        formatter = logging.Formatter(
            '%(asctime)s %(name)s %(levelname)s: %(message)s')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.DEBUG)
        root_logger.addHandler(file_handler)

        if need_roll:
            file_handler.doRollover()

    console_handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter('%(message)s')
    console_handler.setFormatter(formatter)
    console_handler.setLevel(log_level)
    root_logger.addHandler(console_handler)

    # Set the log level on the logger to the lowest level
    # possible. This allows the message to be emitted from the logger
    # to its handlers where the level will be filtered on a per
    # handler basis.
    root_logger.setLevel(1)

# ----------------------------- General Utilities -----------------------------

def join_path(*args):
    '''Join each argument into a final path assuring there is
    exactly one slash separating all components in the final path
    and there are no leading or trailing spaces around path components.
    Initial or final slashes are preserved but are collapsed into a
    single slash.

    Why not use posixpath.join and posixpath.normpath? Because they do not
    handle multiple slashes, leading and trailing slashes the way we want

    :param args: path fragments (strings), each may itself contain slashes
    :returns: the joined path, an empty string if no arguments were given
    '''

    if not args:
        return ''

    components = []

    for item in args:
        components.extend(item.split('/'))

    # An empty first (last) component means the first (last) argument
    # started (ended) with a slash, or was empty.
    leading_slash = not components[0]
    trailing_slash = not components[-1]

    # Strip before discarding empty components, otherwise a component
    # consisting only of whitespace survives as '' and yields a double
    # slash in the result.
    components = [x for x in (c.strip() for c in components) if x]

    path = '/'.join(components)

    if leading_slash:
        path = '/' + path

    # Without components the leading slash already covers the '/' case.
    if trailing_slash and components:
        path = path + '/'

    return path

# -------------------------- Shell Command Utilities --------------------------


def nolog_replace(string, nolog):
    """Replace occurrences of strings given in `nolog` with XXXXXXXX

    Each value is masked in three spellings: shell quoted, verbatim and
    URL quoted. Values which are not strings are ignored, as are empty
    strings.

    :param string: the text to sanitize
    :param nolog: iterable of sensitive strings
    :returns: string with every occurrence of a nolog value masked
    """
    for value in nolog:
        if not isinstance(value, six.string_types):
            continue
        if not value:
            # str.replace('', x) inserts x between every character, an
            # empty secret would garble the whole text.
            continue

        quoted = urlquote(value)
        shquoted = shell_quote(value)
        # The shell quoted form goes first because it contains the plain
        # value and would otherwise be left with dangling quotes.
        for nolog_value in (shquoted, value, quoted):
            string = string.replace(nolog_value, 'XXXXXXXX')
    return string


def shell_quote(string):
    '''Wrap string in single quotes for use on a shell command line.

    Embedded single quotes are emitted as '\\'' (close the quote, an
    escaped quote, reopen the quote).
    '''
    escaped = string.replace("'", "'\\''")
    return ''.join(("'", escaped, "'"))


def _output_to_text(data):
    '''Return captured process output as a native string.

    None (the output was not captured) becomes an empty string, bytes
    are decoded as UTF-8 with undecodable bytes replaced and a value
    which already is a native string is returned unchanged.
    '''
    if data is None:
        return ''
    if isinstance(data, str):
        return data
    return data.decode('utf-8', 'replace')


def run_cmd(args, stdin=None, raiseonerr=True,
            nolog=(), env=None, capture_output=True, skip_output=False,
            cwd=None, runas=None, timeout=None, suplementary_groups=None):
    """
    Execute a command and return stdout, stderr and the process return code.

    :param args: List of arguments for the command. The list is not
        modified.
    :param stdin: Optional input to the command. Text is encoded as UTF-8
        before being passed to the process.
    :param raiseonerr: If True, raises an exception if the return code is
        not zero
    :param nolog: Tuple of strings that shouldn't be logged, like passwords.
        Each tuple consists of a string to be replaced by XXXXXXXX.

        Example:
        We have a command
            [paths.SETPASSWD, '--password', 'Secret123', 'someuser']
        and we don't want to log the password so nolog would be set to:
        ('Secret123',)
        The resulting log output would be:

        /usr/bin/setpasswd --password XXXXXXXX someuser

        If a value isn't found in the list it is silently ignored.
    :param env: Dictionary of environment variables passed to the command.
        When None, current environment is copied and PATH is replaced
        with a fixed list of system directories.
    :param capture_output: Capture stderr and stdout
    :param skip_output: Redirect the output to /dev/null and do not capture it
    :param cwd: Current working directory
    :param runas: Name of a user that the command should be run as. The spawned
        process will have both real and effective UID and GID set.
    :param timeout: Timeout if the command hasn't returned within the specified
        number of seconds.
    :param suplementary_groups: List of group names that will be used as
        suplementary groups for subprocess. None means no groups.
        The option runas must be specified together with this option.
    :returns: 3-tuple (stdout, stderr, returncode). stdout and stderr are
        native strings, they are empty when the output was not captured.
        When captured, values listed in nolog are masked in them.
    :raises ValueError: if nolog is a single string
    :raises subprocess.CalledProcessError: if the command exits with a
        non-zero status and raiseonerr is true
    """
    if suplementary_groups is None:
        suplementary_groups = []
    assert isinstance(suplementary_groups, list)
    p_in = None
    p_out = None
    p_err = None

    if isinstance(nolog, six.string_types):
        # We expect a tuple (or list, or other iterable) of nolog strings.
        # Passing just a single string is bad: strings are also iterable,
        # so this would result in every individual character of that
        # string being replaced by XXXXXXXX.
        # This is a sanity check to prevent that.
        raise ValueError('nolog must be a tuple of strings.')

    if env is None:
        # os.environ.copy() returns a plain dict. A (deep) copy of the
        # os.environ object itself still forwards assignments to
        # putenv() and would alter the PATH of this very process.
        env = os.environ.copy()
        env["PATH"] = (
            "/bin:/sbin:/usr/kerberos/bin:"
            "/usr/kerberos/sbin:/usr/bin:/usr/sbin")
    if stdin:
        p_in = subprocess.PIPE
        if isinstance(stdin, six.text_type):
            # The pipe is opened in binary mode
            stdin = stdin.encode('utf-8')
    if capture_output and not skip_output:
        p_out = subprocess.PIPE
        p_err = subprocess.PIPE

    if timeout:
        # If a timeout was provided, use the timeout command
        # to execute the requested command. Build a new list so the
        # caller's list is left untouched.
        args = [BIN_TIMEOUT, str(timeout)] + list(args)

    arg_string = nolog_replace(' '.join(shell_quote(a) for a in args), nolog)
    logger.debug('Starting external process')
    logger.debug('args=%s', arg_string)

    preexec_fn = None
    if runas is not None:
        pent = pwd.getpwnam(runas)

        suplementary_gids = [
            grp.getgrnam(group).gr_gid for group in suplementary_groups
        ]

        logger.debug('runas=%s (UID %d, GID %s)', runas,
                     pent.pw_uid, pent.pw_gid)
        for group, gid in zip(suplementary_groups, suplementary_gids):
            logger.debug('suplementary_group=%s (GID %d)', group, gid)

        # Runs in the child before exec. Groups must be changed while
        # still privileged, i.e. before the UID is dropped.
        preexec_fn = lambda: (
            os.setgroups(suplementary_gids),
            os.setregid(pent.pw_gid, pent.pw_gid),
            os.setreuid(pent.pw_uid, pent.pw_uid),
        )

    # Opened only now so that a failure above (e.g. an unknown runas
    # user) cannot leak the file handle; it is closed in the finally
    # clause below.
    devnull = None
    if skip_output:
        devnull = open(DEV_NULL, 'w')
        p_out = p_err = devnull

    p = None
    try:
        p = subprocess.Popen(args, stdin=p_in, stdout=p_out, stderr=p_err,
                             close_fds=True, env=env, cwd=cwd,
                             preexec_fn=preexec_fn)
        stdout, stderr = p.communicate(stdin)
        stdout = _output_to_text(stdout)
        stderr = _output_to_text(stderr)
    except KeyboardInterrupt:
        logger.debug('Process interrupted')
        # The interrupt may arrive before the process was created
        if p is not None:
            p.wait()
        raise
    except Exception:
        logger.debug('Process execution failed')
        raise
    finally:
        if devnull is not None:
            devnull.close()

    # 124 is the exit status of the timeout command when the time limit
    # was reached.
    if timeout and p.returncode == 124:
        logger.debug('Process did not complete before timeout')

    logger.debug('Process finished, return code=%s', p.returncode)

    # The command and its output may include passwords that we don't want
    # to log. Replace those.
    if capture_output and not skip_output:
        stdout = nolog_replace(stdout, nolog)
        stderr = nolog_replace(stderr, nolog)
        logger.debug('stdout=%s', stdout)
        logger.debug('stderr=%s', stderr)

    if p.returncode != 0 and raiseonerr:
        raise subprocess.CalledProcessError(p.returncode, arg_string, stdout)

    return (stdout, stderr, p.returncode)


def install_file(src_file, dst_file):
    '''Copy src_file to dst_file, keeping a one time backup.

    An existing dst_file is renamed to dst_file + ".orig" unless such a
    backup already exists, in which case dst_file is simply overwritten.

    :raises ValueError: if dst_file exists but is not a plain file
    '''
    logger.debug('install_file dst_file="%s"', dst_file)
    dst_present = os.path.exists(dst_file)
    if dst_present and not os.path.isfile(dst_file):
        raise ValueError('install file "{dst_file}" exists but is not '
                         'plain file'.format(dst_file=dst_file))
    if dst_present:
        backup = dst_file + ".orig"
        # Only the very first original is preserved
        if not os.path.exists(backup):
            os.rename(dst_file, backup)
    shutil.copy(src_file, dst_file)


def install_file_from_data(data, dst_file):
    '''Write data to dst_file, keeping a one time backup.

    An existing dst_file is renamed to dst_file + ".orig" unless such a
    backup already exists, in which case dst_file is simply overwritten.

    :raises ValueError: if dst_file exists but is not a plain file
    '''
    logger.debug('install_file_from_data dst_file="%s"', dst_file)
    dst_present = os.path.exists(dst_file)
    if dst_present and not os.path.isfile(dst_file):
        raise ValueError('install file "{dst_file}" exists but is not '
                         'plain file'.format(dst_file=dst_file))
    if dst_present:
        backup = dst_file + ".orig"
        # Only the very first original is preserved
        if not os.path.exists(backup):
            os.rename(dst_file, backup)
    with open(dst_file, 'w') as out:
        out.write(data)


def mkdir(pathname, mode=0o775):
    '''Create directory pathname, including parents, if it is missing.

    An already existing directory is left untouched.

    :raises ValueError: if pathname exists but is not a directory
    '''
    logger.debug('mkdir pathname="%s" mode=%#o', pathname, mode)
    if not os.path.exists(pathname):
        os.makedirs(pathname, mode)
        return
    if os.path.isdir(pathname):
        return
    raise ValueError('mkdir "{pathname}" exists but is not '
                     'directory'.format(pathname=pathname))


def httpd_restart():
    '''Restart httpd.service by running systemctl through run_cmd().'''
    run_cmd(['/usr/bin/systemctl', 'restart', 'httpd.service'])

# ----------------------------- HTTP Utilities --------------------------------

def normalize_url(url, default_scheme='https'):
    '''Assure scheme and port are canonical.

    SAML requires a scheme for URL's, if a scheme is not present add a
    default scheme.

    Strip the port from the URL if it matches the scheme (e.g. 80 for
    http and 443 for https)

    Explicitly specifying a default port (e.g. http://example.com:80
    or https://example.com:443) will cause Mellon to fail. This occurs
    because the port gets embedded into the location URL for each
    endpoint in the SP metadata (e.g the Assertion Consumer
    Service). The IdP sets the Destination attribute in the SAML
    response by looking it up in the SP metadata, thus the Destination
    will have the default port in it (e.g. 443). Upon receiving the
    SAML response the SP compares the URL of the request to the
    Destination attribute in the SAML response, they must match for
    the response to be considered valid. However when Mellon asks
    Apache what the request URL was it won't have the port in it thus
    the URL comparison fails. So why is the port absent? It turns out
    that most (all?) browsers will strip the port from a URL if it
    matches the port for the scheme (e.g. 80 for http and 443 for
    https). Thus even if you include the port in the URL it will never
    be included in the URL the browser emits. This also includes
    stripping the port from the HTTP host header (which Apache uses to
    reconstruct the URL).

    The network location of the result is rebuilt from the host name and
    port only, a userinfo part (user:password@) is not retained.

    :param string url: the URL to normalize, with or without a scheme
    :param string default_scheme: scheme to use if url has none
    :returns: the normalized URL
    '''

    # urlsplit() only recognises the host when the URL contains '//'.
    # Without a scheme 'example.com/path' is parsed as a bare path and
    # 'example.com:8443' may even be taken for a scheme, so the default
    # scheme has to be prepended before parsing.
    if default_scheme and not re.match(r'[A-Za-z][A-Za-z0-9+.-]*://', url):
        url = '%s://%s' % (default_scheme, url.lstrip('/'))

    s = urlsplit(url)
    scheme = s.scheme or default_scheme
    hostname = s.hostname or ''
    port = s.port

    if (scheme, port) in (('http', 80), ('https', 443)):
        port = None

    # The hostname attribute has the brackets of an IPv6 literal
    # removed, they are required again in the network location.
    if ':' in hostname:
        hostname = '[%s]' % hostname

    if port is None:
        netloc = hostname
    else:
        netloc = "%s:%d" % (hostname, port)

    return urlunsplit((scheme, netloc, s.path, s.query, s.fragment))

# ------------------------------ PEM Utilities --------------------------------


class InvalidBase64Error(ValueError):
    '''The base64 text inside a PEM block could not be decoded.

    Raised by pem_search().
    '''

# Maps the logical PEM type names accepted by parse_pem() to the text
# found after 'BEGIN ' in the corresponding PEM header line.
pem_headers = {
    'csr': 'NEW CERTIFICATE REQUEST',
    'cert': 'CERTIFICATE',
    'crl': 'CRL',
    'cms': 'CMS',
    'key': 'PRIVATE KEY',
}

# Result of pem_search(): the PEM type, the slice positions of the whole
# block and of its base64 payload, the payload text and the decoded bytes.
PEMParseResult = namedtuple('PEMParseResult',
                            ['pem_type',
                             'pem_start', 'pem_end',
                             'base64_start', 'base64_end', 'base64_text',
                             'binary_data'])

# Match a complete '-----BEGIN <type>-----' / '-----END <type>-----' line
# (MULTILINE, so anywhere in the text); group 1 captures <type>.
pem_begin_re = re.compile(r'^-{5}BEGIN\s+([^-]+)-{5}\s*$', re.MULTILINE)
pem_end_re = re.compile(r'^-{5}END\s+([^-]+)-{5}\s*$', re.MULTILINE)


def pem_search(text, start=0):
    '''Search for a block of PEM formatted data

    Search for a PEM block in a text string. The search begins at
    start. If a PEM block is found a PEMParseResult named tuple is
    returned, otherwise if no PEM block is found None is returned.

    The PEMParseResult named tuple is:
    (pem_type, pem_start, pem_end, base64_start, base64_end,
     base64_text, binary_data)

    pem_type
        The text following '-----BEGIN ' in the PEM header.
        Common examples are 'CERTIFICATE', 'CRL', 'CMS'.
    pem_start, pem_end
        The beginning and ending positions of the PEM block
        including the PEM header and footer.
    base64_start, base64_end
        The beginning and ending positions of the base64 text
        contained inside the PEM header and footer.
    base64_text
        The base64 text (e.g. text[b.base64_start : b.base64_end])
    binary_data
        The decoded base64 text.

    The start and end positions are suitable for use as slices into
    the text. To search for multiple PEM blocks pass pem_end as the
    start position for the next iteration. Terminate the iteration
    when None is returned. Example:

        start = 0
        while True:
            b = pem_search(text, start)
            if b is None:
                break
            start = b.pem_end

    :param string text: the text to search for PEM blocks
    :param int start: the position in text to start searching from
    :returns: PEMParseResult named tuple or None if not found
    :raises ValueError: if a PEM header has no footer or the pem_type
                        in the header and the footer differ
    :raises InvalidBase64Error: if the base64 text cannot be decoded
    '''

    begin_match = pem_begin_re.search(text, pos=start)
    if begin_match is None:
        return None

    pem_start = begin_match.start()
    begin_text = begin_match.group(0)
    # +1 steps over the newline terminating the header line
    base64_start = min(len(text), begin_match.end() + 1)
    begin_pem_type = begin_match.group(1).strip()

    end_match = pem_end_re.search(text, pos=base64_start)
    if end_match is None:
        raise ValueError("failed to find end matching '%s'" % begin_text)

    # +1 steps over the newline terminating the footer line, -1 excludes
    # the newline in front of the footer from the base64 text.
    pem_end = min(len(text), end_match.end() + 1)
    base64_end = end_match.start() - 1
    end_pem_type = end_match.group(1).strip()

    if begin_pem_type != end_pem_type:
        raise ValueError("beginning & end PEM types do not match "
                         "(%s != %s)" % (begin_pem_type, end_pem_type))

    pem_type = begin_pem_type
    base64_text = text[base64_start:base64_end]
    try:
        binary_data = base64.b64decode(base64_text)
    except (TypeError, ValueError) as e:
        # Python 2 reports bad base64 as TypeError, Python 3 as
        # binascii.Error which is a ValueError.
        raise InvalidBase64Error('failed to base64 decode %s PEM '
                                 'at position %d: %s' %
                                 (pem_type, pem_start, e))

    return PEMParseResult(pem_type=pem_type,
                          pem_start=pem_start, pem_end=pem_end,
                          base64_start=base64_start, base64_end=base64_end,
                          base64_text=base64_text,
                          binary_data=binary_data)


def parse_pem(text, pem_type=None, max_items=None):
    '''Scan text for PEM data, return list of PEMParseResult

    pem_type operates as a filter on the type of PEM desired. If
    pem_type is specified only those PEM blocks which match will be
    included. The pem_type is a logical name, not the actual text in
    the pem header (e.g. 'cert'). If the pem_type is None all PEM
    blocks are returned.

    If max_items is specified the scan stops as soon as that number of
    items has been collected.

    The return value is a list of PEMParseResult named tuples.  The
    PEMParseResult provides complete information about the PEM block
    including the decoded binary data for the PEM block.  The list is
    ordered in the same order as found in the text.

    Examples:

        # Get all certs
        certs = parse_pem(text, 'cert')

        # Get the first cert
        try:
            binary_cert = parse_pem(text, 'cert', 1)[0].binary_data
        except IndexError:
            raise ValueError('no cert found')

    :param string text: The text to search for PEM blocks
    :param string pem_type: Only return data for this pem_type.
                            Valid types are: csr, cert, crl, cms, key.
                            If pem_type is None no filtering is performed.
    :param int max_items: Limit the number of blocks returned.
    :returns: List of PEMParseResult, one for each PEM block found
    :raises ValueError: if pem_type is not one of the valid types, or
                        for any error raised by pem_search()
    '''

    # Resolve the filter before scanning so an unknown pem_type is
    # reported even when the text contains no PEM block at all.
    if pem_type is None:
        wanted_header = None
    else:
        try:
            wanted_header = pem_headers[pem_type]
        except KeyError:
            raise ValueError('unknown pem_type: %s' % (pem_type))

    pem_blocks = []
    start = 0

    while True:
        b = pem_search(text, start)
        if b is None:
            break
        # Continue the next search behind the block just found
        start = b.pem_end
        if wanted_header is None or b.pem_type == wanted_header:
            pem_blocks.append(b)

        if max_items is not None and len(pem_blocks) >= max_items:
            break

    return pem_blocks

# ------------------------- SAML Metadata Utilities ---------------------------

def get_sp_assertion_consumer_url(metadata_file, entity_id=None,
                                  binding=None):
    '''Retrieve AssertionConsumerURL(s) from SP metadata

    Read and parse the SAML metadata contained in metadata_file.

    If the entity_id is supplied then select the SP matching it,
    this is useful when the metadata contains multiple SP's. If the
    entity_id is not supplied then there must be exactly 1 SP in the
    metadata, that one will be selected.

    If the SAML endpoint binding is supplied then only
    AssertionConsumerServiceURL's matching that binding will be returned,
    otherwise all AssertionConsumerURL's will be returned.

    The return value is a list of AssertionConsumerServiceURL's in the order
    found in the metadata.

    :param metadata_file:        Pathname of SAML Metadata file
    :param entity_id (optional): EntityID of SP
    :param binding (optional):   Filter matching this binding
    :return:                     List of AssertionConsumerServiceURL's
    :raises ValueError:          if the SP cannot be selected
                                 unambiguously
    :raises IndexError:          if entity_id is supplied but no SP with
                                 that EntityID is found
    '''

    namespaces = dict(md='urn:oasis:names:tc:SAML:2.0:metadata',
                      saml='urn:oasis:names:tc:SAML:2.0:assertion',
                      ds='http://www.w3.org/2000/09/xmldsig#')

    root = etree.parse(metadata_file).getroot()

    if not entity_id:
        # If entity_id was not supplied locate a unique SPSSODescriptor
        xpath = '//md:EntityDescriptor/md:SPSSODescriptor'
        sp = root.xpath(xpath, namespaces=namespaces)
        if len(sp) == 0:
            raise ValueError('entity_id not supplied and no '
                             'SPSSODescriptor was found')
        elif len(sp) > 1:
            raise ValueError('entity_id not supplied and multiple '
                             'SPSSODescriptor elements were found')
    else:
        # The value is passed as an XPath variable rather than being
        # formatted into the expression, so quotes in it cannot break
        # or alter the query.
        xpath = ('//md:EntityDescriptor[@entityID=$entity_id]'
                 '/md:SPSSODescriptor')
        sp = root.xpath(xpath, namespaces=namespaces, entity_id=entity_id)
        if len(sp) == 0:
            raise IndexError('SPSSODescriptor with EntityID="{entity_id}" '
                             'not found'.format(entity_id=entity_id))
        elif len(sp) > 1:
            raise ValueError('multiple SPSSODescriptor with '
                             'EntityID="{entity_id}" found'.format(
                                 entity_id=entity_id))
    sp = sp[0]

    if not binding:
        xpath = 'md:AssertionConsumerService'
        acs = sp.xpath(xpath, namespaces=namespaces)
    else:
        # Filter on the binding requested by the caller
        xpath = 'md:AssertionConsumerService[@Binding=$binding]'
        acs = sp.xpath(xpath, namespaces=namespaces, binding=binding)

    return [x.attrib['Location'] for x in acs]


def get_entity_id_from_metadata(metadata_file, role):
    '''Retrieve entityID from metadata

    Parse the SAML metadata in metadata_file, locate the single role
    descriptor of the requested kind and return the entityID attribute
    of the EntityDescriptor containing it.

    Supported roles:

    SSO Identity Provider (role='idp')
    SSO Service Provider (role='sp')
    Authentication Authority (role='authn_authority')
    Attribute Authority (role='attr_authority')
    Policy Decision Point (role='pdp')

    :param metadata_file:        Pathname of SAML Metadata file
    :param role:                 one of: idp, sp, authn_authority,
                                 attr_authority, pdp
    :return:                     entityID
    :raises ValueError:          if role is invalid, or the metadata
                                 contains no or more than one matching
                                 role descriptor
    '''

    descriptor_by_role = {
        'idp':             'IDPSSODescriptor',
        'sp':              'SPSSODescriptor',
        'authn_authority': 'AuthnAuthorityDescriptor',
        'attr_authority':  'AttributeAuthorityDescriptor',
        'pdp':             'PDPDescriptor',
    }

    if role not in descriptor_by_role:
        raise ValueError("invalid role '%s', must be one of: %s" %
                         (role, ', '.join(sorted(descriptor_by_role))))
    role_descriptor = descriptor_by_role[role]

    ns = {'md': 'urn:oasis:names:tc:SAML:2.0:metadata',
          'saml': 'urn:oasis:names:tc:SAML:2.0:assertion',
          'ds': 'http://www.w3.org/2000/09/xmldsig#'}

    tree_root = etree.parse(metadata_file).getroot()

    matches = tree_root.xpath('//md:EntityDescriptor/md:%s' % role_descriptor,
                              namespaces=ns)
    if not matches:
        raise ValueError('no %s found' % role_descriptor)
    if len(matches) > 1:
        raise ValueError('multiple EntityDescriptor elements found')

    # Walk back up from the role descriptor to its EntityDescriptor
    owners = matches[0].xpath('ancestor::md:EntityDescriptor', namespaces=ns)
    return owners[0].attrib['entityID']

# -------------------- Certificate Creation & Installation --------------------


def load_cert_from_file(filename, format='base64_text'):
    '''Load a cert from a file, return as either base64 text or binary.

    The file must contain exactly one PEM certificate block.

    :param string filename: The input file to read the cert from.
    :param string format: One of: 'base64_text', 'binary'
    :returns: cert in requested format
    :raises ValueError: if the file holds no cert or more than one cert,
                        or if format is not one of the supported values
    '''
    with open(filename, 'r') as f:
        data = f.read()

    certs = parse_pem(data, 'cert')

    if len(certs) == 0:
        raise ValueError('No cert found in {filename}'.format(
            filename=filename))

    if len(certs) > 1:
        raise ValueError('Multiple certs ({num_certs}) '
                         'found in {filename}'.format(
                             num_certs=len(certs),
                             filename=filename))

    if format == 'base64_text':
        return certs[0].base64_text
    elif format == 'binary':
        # The decoded bytes are stored in the binary_data field of
        # PEMParseResult, the tuple has no field called binary.
        return certs[0].binary_data
    else:
        raise ValueError('Uknown format "{format}"'.format(
            format=format))


def generate_cert(subject):
    '''Generate self-signed cert and key.

    openssl is run to create a new RSA 2048 bit key and a self-signed
    certificate, valid for 1825 days, whose subject is /CN=<subject>.
    The files are produced in a temporary directory which is removed
    again before returning, whether or not an error occurred.

    :param string subject: Certificate subject.
    :returns: key, cert as 2-tuple of PEM formatted strings
    '''

    workdir = tempfile.mkdtemp()
    try:
        key_path = os.path.join(workdir, 'key.pem')
        cert_path = os.path.join(workdir, 'cert.pem')

        run_cmd(['openssl',
                 'req', '-x509', '-batch', '-days', '1825',
                 '-newkey', 'rsa:2048', '-nodes',
                 '-subj', '/CN=%s' % subject,
                 '-keyout', key_path, '-out', cert_path])

        with open(key_path, 'r') as key_in:
            key = key_in.read()

        with open(cert_path, 'r') as cert_in:
            cert = cert_in.read()

        return key, cert
    finally:
        shutil.rmtree(workdir)


def install_mellon_cert(options):
    '''Install the mellon key and cert at their destination paths.

    If neither options.mellon_key_file nor options.mellon_cert_file is
    given a self-signed pair is generated with options.mellon_hostname
    as the subject. Otherwise both files must be given and are copied.

    :raises ValueError: if only one of the key and cert file is given
    '''
    if not (options.mellon_key_file or options.mellon_cert_file):
        # Nothing supplied, create our own key and cert
        key, cert = generate_cert(options.mellon_hostname)
        install_file_from_data(key, options.mellon_dst_key_file)
        install_file_from_data(cert, options.mellon_dst_cert_file)
        return

    if not (options.mellon_key_file and options.mellon_cert_file):
        raise ValueError('You must specify both a cert and key file, '
                         'not just one.')

    install_file(options.mellon_key_file, options.mellon_dst_key_file)
    install_file(options.mellon_cert_file, options.mellon_dst_cert_file)

# ---------------------------- Template Builders ------------------------------


def build_template_params(options):
    '''Return a dict of every attribute of options whose name does not
    start with an underscore, keyed by attribute name.'''
    return {name: getattr(options, name)
            for name in dir(options)
            if not name.startswith('_')}


def build_mellon_httpd_config_file(options):
    '''Render the MELLON_HTTPD_CONFIG_TEMPLATE template using the public
    attributes of options as parameters and return the result.'''
    params = build_template_params(options)
    return template_env.get_template(
        MELLON_HTTPD_CONFIG_TEMPLATE).render(params)


def build_mellon_sp_metadata_file(options):
    '''Render the MELLON_METADATA_TEMPLATE template using the public
    attributes of options as parameters and return the result.'''
    params = build_template_params(options)
    return template_env.get_template(
        MELLON_METADATA_TEMPLATE).render(params)

# ------------ Argparse Argument Conversion/Validation Functions --------------


def arg_type_mellon_endpoint(value):
    '''Argparse type: remove leading and trailing spaces and slashes.'''
    return value.strip(' /')


def arg_type_mellon_protected_location(value):
    '''Argparse type: accept value only if it starts with a slash.

    :returns: value unchanged
    :raises argparse.ArgumentTypeError: if value does not start with '/'
    '''
    if value.startswith('/'):
        return value
    raise argparse.ArgumentTypeError('Location must be absolute '
                                     '(arg="%s")' % value)

# -----------------------------------------------------------------------------


def _report_exception(exc, options, with_class_name=True):
    """Print a fatal exception on stderr.

    A traceback is printed first when ``options.show_traceback`` is set.
    Must be called from inside the ``except`` block handling *exc* so
    that ``traceback.print_exc()`` can find the active exception.

    When *with_class_name* is true the message is prefixed with the
    exception class name.
    """
    if options.show_traceback:
        traceback.print_exc()
    message = six.text_type(exc)
    if with_class_name:
        message = '%s: %s' % (exc.__class__.__name__, message)
    print(message, file=sys.stderr)


def main():
    """Configure mod_auth_mellon as a Keycloak SAML client.

    Parses the command line, derives the file locations and URLs used by
    Mellon, installs the SP certificate, the httpd configuration and the
    SP metadata, then creates the matching client on the Keycloak server
    and installs the realm's IdP metadata.

    Returns:
        One of the STATUS_* exit codes. Invalid option combinations are
        reported with ``parser.error()``, which terminates the process
        with exit status 2 (the value of STATUS_CONFIGURATION_ERROR).
    """
    global logger, template_env

    # ===== Command Line Arguments =====
    parser = argparse.ArgumentParser(
            description='Configure mod_auth_mellon as Keycloak client',
            prog=prog_name,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # ---- Common Arguments ----

    parser.add_argument('--no-root-check', dest='root_check',
                        action='store_false',
                        help='permit running by non-root')

    parser.add_argument('-v', '--verbose', action='store_true',
                        help='be chatty')

    parser.add_argument('-d', '--debug', action='store_true',
                        help='turn on debug info')

    parser.add_argument('--show-traceback', action='store_true',
                        help='exceptions print traceback in addition to '
                             'error message')

    parser.add_argument('--log-file',
                        default=('/var/log/python-keycloak-httpd-client/'
                                 '{prog_name}.log'.
                                 format(prog_name=prog_name)),
                        help='log file pathname')

    parser.add_argument('--app-name',
                        required=True,
                        help='name of the web app being protected by mellon')

    parser.add_argument('--force', action='store_true',
                        help='forcefully override safety checks')

    parser.add_argument('--permit-insecure-transport', action='store_true',
                        help='Normally secure transport such as TLS '
                        'is required, defeat this check')

    # The help formatter appends the default value itself, so it is not
    # repeated in the help text.
    parser.add_argument('--tls-verify', action=keycloak_cli.TlsVerifyAction,
                        default=True,
                        help='TLS certificate verification for requests to'
                        ' the server. May be one of case insensitive '
                        '[true, yes, on] to enable, '
                        '[false, no, off] to disable. '
                        'Or the pathname to an OpenSSL CA bundle to use.')

    # ---- Argument Group "Program Configuration"  ----

    group = parser.add_argument_group('Program Configuration')

    group.add_argument('--template-dir',
                       default=('/usr/share/'
                                'keycloak-httpd-client-install/templates'),
                       help='Template location')

    group.add_argument('--httpd-dir',
                       default='/etc/httpd',
                       help='httpd configuration directory')

    # ---- Argument Group "Keycloak IdP"  ----

    group = parser.add_argument_group('Keycloak IdP')

    group.add_argument('-r', '--keycloak-realm',
                       required=True,
                       help='realm name')

    group.add_argument('-s', '--keycloak-server-url',
                       required=True,
                       help='Keycloak server URL')

    group.add_argument('-a', '--keycloak-auth-role',
                       choices=keycloak_cli.AUTH_ROLES,
                       default='root-admin',
                       help='authenticating as what type of user')

    group.add_argument('-u', '--keycloak-admin-username', default='admin',
                       help='admin user name')

    group.add_argument('-p', '--keycloak-admin-password',
                       help='admin password (use - to read from stdin)')

    group.add_argument('--keycloak-admin-realm',
                       default='master',
                       help='realm admin belongs to')

    group.add_argument('--initial-access-token',
                       help='realm initial access token for '
                       'client registration')

    group.add_argument('--client-originate-method',
                       choices=['descriptor', 'registration'],
                       default='descriptor',
                       help='select Keycloak method for creating SAML client')

    # ---- Argument Group "Mellon SP"  ----

    group = parser.add_argument_group('Mellon SP')

    group.add_argument('--mellon-key-file',
                       help='certificate key file')

    group.add_argument('--mellon-cert-file',
                       help='certificate file')

    group.add_argument('--mellon-hostname', default=socket.getfqdn(),
                       help="Machine's fully qualified host name")

    group.add_argument('--mellon-https-port', default=443, type=int,
                       help="SSL/TLS port on mellon-hostname")

    group.add_argument('--mellon-root', default='/',
                       help='common root ancestor for all mellon endpoints')

    group.add_argument('--mellon-endpoint', default='mellon',
                       type=arg_type_mellon_endpoint,
                       help='Used to form the MellonEndpointPath, e.g. '
                       '{mellon_root}/{mellon_endpoint}.')

    group.add_argument('--mellon-entity-id',
                       help='SP SAML Entity ID, '
                       'defaults to {mellon_http_url}/{mellon_endpoint_path}/metadata')

    group.add_argument('--mellon-idp-attr-name', default='IDP',
                       help='name of the attribute Mellon adds which will '
                       'contain the IdP entity id')

    group.add_argument('--mellon-organization-name',
                       help='Add SAML OrganizationName to SP metadata')

    group.add_argument('--mellon-organization-display-name',
                       help='Add SAML OrganizationDisplayName to SP metadata')

    group.add_argument('--mellon-organization-url',
                       help='Add SAML OrganizationURL to SP metadata')

    group.add_argument('-l', '--mellon-protected-locations', action='append',
                       type=arg_type_mellon_protected_location, default=[],
                       help='Web location to protect with Mellon. '
                            'May be specified multiple times')

    # ===== Process command line arguments =====

    options = parser.parse_args()

    # Kept in a local (not on options) because every public attribute of
    # options is exported to the templates.
    have_admin_role = options.keycloak_auth_role in ('root-admin',
                                                     'realm-admin')

    # ===== Verify process permission =====

    if options.root_check and os.getuid() != 0:
        print("You must be root to run this program",
              file=sys.stderr)
        return STATUS_INSUFFICIENT_PRIVILEGE

    # ===== Configure Logging =====

    configure_logging(options)

    # ===== Options requiring special handling =====

    if have_admin_role:
        if options.keycloak_admin_password is None:
            # Not on the command line: fall back to the environment,
            # then to an interactive prompt.
            env_password = os.environ.get('KEYCLOAK_ADMIN_PASSWORD')
            if env_password:
                options.keycloak_admin_password = env_password
            else:
                options.keycloak_admin_password = getpass.getpass(
                    '%s password: ' % (options.keycloak_admin_username))
        elif options.keycloak_admin_password == '-':
            options.keycloak_admin_password = sys.stdin.readline().rstrip('\n')

        if not options.keycloak_admin_password:
            parser.error('argument %s is required '
                         'unless passed in the environment '
                         'variable KEYCLOAK_ADMIN_PASSWORD' %
                         ('keycloak-admin-password'))

    # ===== Normalize Options =====

    options.mellon_root = '/' + options.mellon_root.strip(' /')
    options.mellon_endpoint = options.mellon_endpoint.strip(' /')

    # ===== Synthesize Derived Options =====

    options.httpd_saml_dir = os.path.join(options.httpd_dir, HTTPD_SAML_DIR)
    options.httpd_conf_dir = os.path.join(options.httpd_dir, HTTPD_CONF_DIR)
    options.mellon_httpd_config_filename = \
        os.path.join(options.httpd_conf_dir,
                     '{app_name}_mellon_keycloak_{realm}.conf'.format(
                         app_name=options.app_name,
                         realm=options.keycloak_realm))
    # Only the file name is a format template; the directory part must
    # not be passed through str.format().
    options.mellon_sp_metadata_filename = \
        os.path.join(options.httpd_saml_dir,
                     '{app_name}_{mellon_metadata}'.format(
                         app_name=options.app_name,
                         mellon_metadata=MELLON_METADATA))
    options.mellon_dst_key_file = \
        os.path.join(options.httpd_saml_dir,
                     '{app_name}.key'.format(
                         app_name=options.app_name))
    options.mellon_dst_cert_file = \
        os.path.join(options.httpd_saml_dir, '{app_name}.cert'.format(
            app_name=options.app_name))
    options.mellon_dst_idp_metadata_file = \
        os.path.join(options.httpd_saml_dir,
                     '{app_name}_keycloak_{realm}_idp_metadata.xml'.format(
                         app_name=options.app_name,
                         realm=options.keycloak_realm))
    options.mellon_http_url = \
        normalize_url('https://{mellon_hostname}:{mellon_https_port}'.format(
            mellon_hostname=options.mellon_hostname,
            mellon_https_port=options.mellon_https_port))
    options.mellon_endpoint_path = join_path(options.mellon_root, options.mellon_endpoint)

    if not options.mellon_entity_id:
        url = urlsplit(options.mellon_http_url)
        options.mellon_entity_id = urlunsplit((url.scheme, url.netloc,
                                               join_path(options.mellon_endpoint_path, 'metadata'),
                                               '', ''))

    # ===== Validate Options =====

    if not have_admin_role and options.client_originate_method == 'descriptor':
        # Creating a client from a descriptor goes through the admin
        # connection, which only exists for the admin roles.
        parser.error("client originate method 'descriptor' requires an "
                     "admin authentication role, use "
                     "--client-originate-method registration with the "
                     "'%s' authentication role" % options.keycloak_auth_role)

    if options.keycloak_auth_role == 'anonymous':
        if options.client_originate_method == 'registration':
            logger.warning("Using client originate method 'registration' with the 'anonymous'\n"
                           "authentication role disables updating the client configuration after\n"
                           "registration. You may need to adjust the client configuration manually\n"
                           "in the Keycloak admin console. Use one of the admin authentication\n"
                           "roles to permit automated client configuration.\n")

        if not options.initial_access_token:
            parser.error("You must supply an initial access token "
                         "with anonymous authentication")

    # ===== Establish Keycloak Server Communication =====

    try:
        logger.step('Connect to Keycloak Server')
        logger.info('Connecting to Keycloak server "%s"',
                    options.keycloak_server_url)
        if options.permit_insecure_transport:
            os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = '1'

        anonymous_conn = keycloak_cli.KeycloakAnonymousConnection(
            options.keycloak_server_url,
            options.tls_verify)

        if have_admin_role:
            admin_conn = keycloak_cli.KeycloakAdminConnection(
                options.keycloak_server_url,
                options.keycloak_auth_role,
                options.keycloak_admin_realm,
                keycloak_cli.ADMIN_CLIENT_ID,
                options.keycloak_admin_username,
                options.keycloak_admin_password,
                options.tls_verify)
        else:
            admin_conn = None

    except Exception as e:
        _report_exception(e, options)
        return STATUS_COMMUNICATION_ERROR

    # ===== Assure required directories are present =====

    try:
        logger.step('Create Directories')
        mkdir(options.httpd_saml_dir)
        mkdir(options.httpd_conf_dir)
    except Exception as e:
        _report_exception(e, options)
        return STATUS_OPERATION_ERROR

    # ===== Create jinja2 Template Environment =====

    try:
        logger.step('Set up template environment')
        template_env = jinja2.Environment(trim_blocks=True,
                                          lstrip_blocks=True,
                                          keep_trailing_newline=True,
                                          undefined=jinja2.StrictUndefined,
                                          loader=jinja2.FileSystemLoader(
                                              options.template_dir))
    except Exception as e:
        _report_exception(e, options)
        return STATUS_CONFIGURATION_ERROR

    # ===== Configure Mellon  =====

    try:
        logger.step('Set up Service Provider X509 Certificates')
        install_mellon_cert(options)

        # The same certificate is used for both signing and encryption.
        cert_base64 = load_cert_from_file(options.mellon_dst_cert_file)
        options.sp_signing_cert = cert_base64
        options.sp_encryption_cert = cert_base64

        logger.step('Build Mellon httpd config file')
        mellon_httpd_config = build_mellon_httpd_config_file(options)
        install_file_from_data(mellon_httpd_config,
                               options.mellon_httpd_config_filename)

        logger.step('Build Mellon SP metadata file')
        mellon_sp_metadata = build_mellon_sp_metadata_file(options)
        install_file_from_data(mellon_sp_metadata,
                               options.mellon_sp_metadata_filename)
    except Exception as e:
        _report_exception(e, options)
        return STATUS_OPERATION_ERROR

    # ===== Configure Keycloak  =====

    try:
        if options.keycloak_auth_role == 'root-admin':
            logger.step('Query realms from Keycloak server')
            realms = admin_conn.get_realms()
            realm_names = keycloak_cli.get_realm_names_from_realms(realms)
            logger.info('existing realms [%s]', ', '.join(realm_names))

            if options.keycloak_realm not in realm_names:
                logger.step('Create realm on Keycloak server')
                logger.info('Create realm "%s"', options.keycloak_realm)
                admin_conn.create_realm(options.keycloak_realm)
            else:
                logger.step('Use existing realm on Keycloak server')

        if have_admin_role:
            logger.step('Query realm clients from Keycloak server')
            clients = admin_conn.get_clients(options.keycloak_realm)
            client_ids = keycloak_cli.get_client_client_ids_from_clients(clients)
            logger.info('existing clients in realm %s = [%s]',
                        options.keycloak_realm, ', '.join(client_ids))

            if options.mellon_entity_id in client_ids:
                if options.force:
                    logger.step('Force delete client on Keycloak server')
                    logger.info('Delete client "%s"', options.mellon_entity_id)
                    admin_conn.delete_client_by_name(options.keycloak_realm,
                                                     options.mellon_entity_id)

                else:
                    raise AlreadyExistsError('client "{client_id}" '
                                             'already exists in realm "{realm}". '
                                             'Use --force to replace it.'.format(
                                                 client_id=options.mellon_entity_id,
                                                 realm=options.keycloak_realm))

        if options.client_originate_method == 'descriptor':
            # Option validation guarantees admin_conn is set here.
            logger.step('Creating new client from descriptor')
            logger.info('Create new client "%s"', options.mellon_entity_id)
            admin_conn.create_client(options.keycloak_realm,
                                     mellon_sp_metadata)
        elif options.client_originate_method == 'registration':

            if options.initial_access_token:
                logger.step('Use provided initial access token')
                initial_access_token = options.initial_access_token
            else:
                if have_admin_role:
                    logger.step('Get new initial access token')
                    client_initial_access = admin_conn.get_initial_access_token(
                        options.keycloak_realm)
                    initial_access_token = client_initial_access['token']
                else:
                    raise ValueError("You must have root or realm admin "
                                     "privileges to acquire an initial "
                                     "access token")

            logger.step('Creating new client using registration service')
            logger.info('Register new client "%s"', options.mellon_entity_id)

            try:
                anonymous_conn.register_client(initial_access_token,
                                               options.keycloak_realm,
                                               mellon_sp_metadata)
            except RESTError as e:
                if e.error_description == "Client Identifier in use":
                    raise AlreadyExistsError('client "{client_id}" '
                                             'already exists in realm "{realm}"'.format(
                                                 client_id=options.mellon_entity_id,
                                                 realm=options.keycloak_realm))
                else:
                    raise
        else:
            raise ValueError("Unknown client-originate-method = '%s'" %
                             options.client_originate_method)

        if have_admin_role:
            # Enable Force Post Binding, registration service fails to
            # enable it (however creation with client descriptor does)
            logger.step('Enable saml.force.post.binding')
            update_attrs = {'saml.force.post.binding': True}
            admin_conn.update_client_by_name_attributes(options.keycloak_realm,
                                                        options.mellon_entity_id,
                                                        update_attrs)

            logger.step('Add group attribute mapper to client')
            mapper = admin_conn.new_saml_group_protocol_mapper(
                'group list', 'groups',
                friendly_name='List of groups user is a member of')
            admin_conn.create_client_by_name_protocol_mapper(options.keycloak_realm,
                                                             options.mellon_entity_id,
                                                             mapper)

            logger.step('Add Redirect URIs to client')
            urls = get_sp_assertion_consumer_url(options.mellon_sp_metadata_filename,
                                                 entity_id=options.mellon_entity_id)
            admin_conn.add_client_by_name_redirect_uris(options.keycloak_realm,
                                                        options.mellon_entity_id,
                                                        urls)

        logger.step('Retrieve IdP metadata from Keycloak server')
        idp_metadata = anonymous_conn.get_realm_metadata(options.keycloak_realm)
        install_file_from_data(idp_metadata,
                               options.mellon_dst_idp_metadata_file)

    except AlreadyExistsError as e:
        # The message is already self-explanatory, omit the class name.
        _report_exception(e, options, with_class_name=False)
        return STATUS_ALREADY_EXISTS_ERROR
    except Exception as e:
        _report_exception(e, options)
        return STATUS_OPERATION_ERROR

    # ===== Wrap Up =====

    logger.step('Completed Successfully')
    logger.info('mellon entityID="%s"', options.mellon_entity_id)
    return STATUS_SUCCESS

# -----------------------------------------------------------------------------

if __name__ == '__main__':
    # Run as a script: the value returned by main() becomes the process
    # exit status.
    exit_status = main()
    sys.exit(exit_status)
