# Copyright Red Hat 2017, Jake Hunsaker <jhunsake@redhat.com>
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import os
import fnmatch

from pipes import quote
from soscollector.clusters import Cluster

# Path to the engine's SSH private key used for communicating with
# hypervisors. ovirt._check_for_engine_keys() falls back to it as ssh_key when
# running locally on the engine and no key or password was supplied.
ENGINE_KEY = '/etc/pki/ovirt-engine/keys/engine_id_rsa'


class ovirt(Cluster):
    '''
    Cluster profile for community oVirt.

    Hypervisor hosts are discovered by querying the engine database from the
    master node. Optionally, a postgresql-only sosreport is run on the master
    to capture a database dump.
    '''

    cluster_name = 'Community oVirt'
    packages = ('ovirt-engine',)
    # Engine-provided psql wrapper; the SQL query is appended after '-c'.
    db_exec = '/usr/share/ovirt-engine/dbscripts/engine-psql.sh -c'

    # (option name, default value, description)
    option_list = [
        ('no-database', False, 'Do not collect a database dump'),
        ('cluster', '', 'Only collect from hosts in this cluster'),
        ('datacenter', '', 'Only collect from hosts in this datacenter'),
        ('no-hypervisors', False, 'Do not collect from hypervisors'),
        ('spm-only', False, 'Only collect from SPM host(s)')
    ]

    def _run_db_query(self, query):
        '''
        Wrapper for running DB queries on the master. Any scrubbing of the
        query should be done _before_ passing the query to this method.

        The whole query is shell-quoted and handed to db_exec as root.
        Returns the result dict from exec_master_cmd.
        '''
        cmd = "%s %s" % (self.db_exec, quote(query))
        return self.exec_master_cmd(cmd, need_root=True)

    def _sql_scrub(self, val):
        '''
        Manually sanitize SQL queries since we can't leave this up to the
        driver since we do not have an actual DB connection.

        Returns '%' (match anything in a LIKE clause) when val is empty or
        contains any character that could escape the quoted literal;
        otherwise returns val unchanged.
        '''
        if not val:
            return '%'

        invalid_chars = ['\x00', '\\', '\n', '\r', '\032', '"', '\'']
        if any(x in invalid_chars for x in val):
            self.log_warn("WARNING: Cluster option \'%s\' contains invalid "
                          "characters. Using '%%' instead." % val)
            return '%'

        return val

    def _check_for_engine_keys(self):
        '''
        Checks for the presence of the VDSM ssh keys the manager uses for
        communication with hypervisors.

        This only runs if we're locally on the RHV-M, *and* if no ssh-keys are
        called out on the command line, *and* no --password option is given.
        When all of that holds and ENGINE_KEY exists, it becomes ssh_key.
        '''
        if not self.master.local:
            return
        if any([self.config['ssh_key'], self.config['password'],
                self.config['password_per_node']]):
            return
        if self.master.file_exists(ENGINE_KEY):
            self.config['ssh_key'] = ENGINE_KEY
            self.log_debug("Found engine SSH key. User command line"
                           " does not specify a key or password, using"
                           " engine key.")

    def setup(self):
        '''
        Prepare the profile. Reads the engine DB config unless 'no-database'
        is set, builds the host query, and selects the engine SSH key when
        appropriate.
        '''
        self.pg_pass = False
        # Always define self.conf so later code never hits AttributeError,
        # even when the database dump is disabled.
        self.conf = False
        if not self.get_option('no-database'):
            self.conf = self.parse_db_conf()
        self.format_db_cmd()
        self._check_for_engine_keys()

    def format_db_cmd(self):
        '''
        Build self.dbquery, which selects host_name for hosts whose cluster
        name and datacenter (storage pool) name match the scrubbed
        'cluster' / 'datacenter' options. Optionally it is restricted to the
        SPM host.
        '''
        cluster = self._sql_scrub(self.get_option('cluster'))
        datacenter = self._sql_scrub(self.get_option('datacenter'))
        self.dbquery = ("SELECT host_name from vds where cluster_id in "
                        "(select cluster_id FROM cluster WHERE name like '%s'"
                        " and storage_pool_id in (SELECT id FROM storage_pool "
                        "WHERE name like '%s'))" % (cluster, datacenter))
        if self.get_option('spm-only'):
            # spm_status is an integer with the following meanings
            # 0 - Normal (not SPM)
            # 1 - Contending (SPM election in progress, but is not SPM)
            # 2 - SPM
            self.dbquery += ' AND spm_status = 2'
        self.log_debug('Query command for ovirt DB set to: %s' % self.dbquery)

    def get_nodes(self):
        '''
        Return the list of hypervisor host names from the engine database.
        Returns an empty list when 'no-hypervisors' is set.

        Raises Exception if the database query exits non-zero.
        '''
        if self.get_option('no-hypervisors'):
            return []
        res = self._run_db_query(self.dbquery)
        if res['status'] != 0:
            raise Exception('database query failed, return code: %s'
                            % res['status'])
        # Assumes psql's default aligned output: header line, separator line,
        # one row per line, then a "(N rows)" footer. The first two lines and
        # the last line are skipped. Text from '(' onward is dropped, so a
        # footer that lands inside the slice (for example when a trailing
        # blank line is present) becomes '' and is filtered out instead of
        # being returned as a bogus empty host.
        nodes = []
        for line in res['stdout'].splitlines()[2:-1]:
            host = line.split('(')[0].strip()
            if host:
                nodes.append(host)
        return nodes

    def run_extra_cmd(self):
        '''
        Collect the database dump unless disabled or the DB config could not
        be read. Returns the sosreport path, or False.
        '''
        if not self.get_option('no-database') and self.conf:
            return self.collect_database()
        return False

    def parse_db_conf(self):
        '''
        Read the engine's database setup file on the master and return its
        KEY=VALUE pairs as a dict, with double quotes stripped from values.
        Lines without '=' are ignored. Returns False if the file can't be
        read.
        '''
        engconf = '/etc/ovirt-engine/engine.conf.d/10-setup-database.conf'
        res = self.exec_master_cmd('cat %s' % engconf, need_root=True)
        if res['status'] != 0:
            return False
        conf = {}
        for line in res['stdout'].splitlines():
            # Split on the first '=' only, so values that themselves contain
            # '=' (e.g. a DB password) are kept intact instead of truncated.
            key, sep, value = line.partition('=')
            if not sep:
                continue
            conf[key] = value.replace('"', '')
        return conf

    def collect_database(self):
        '''
        Run a postgresql-only sosreport on the master against the engine DB,
        using the connection settings from parse_db_conf().

        Returns the line of output naming the sosreport tarball (stripped),
        or False if none was found.
        '''
        # Every config-derived value is shell-quoted. The password in
        # particular may contain spaces or shell metacharacters that would
        # otherwise break, or alter, the command line.
        sos_opt = (
                   '-k {plugin}.dbname={db} '
                   '-k {plugin}.dbhost={dbhost} '
                   '-k {plugin}.dbport={dbport} '
                   '-k {plugin}.username={dbuser} '
                   ).format(plugin='postgresql',
                            db=quote(self.conf['ENGINE_DB_DATABASE']),
                            dbhost=quote(self.conf['ENGINE_DB_HOST']),
                            dbport=quote(self.conf['ENGINE_DB_PORT']),
                            dbuser=quote(self.conf['ENGINE_DB_USER'])
                            )
        cmd = ('PGPASSWORD={} /usr/sbin/sosreport --name=postgresql '
               '--batch -o postgresql {}'
               ).format(quote(self.conf['ENGINE_DB_PASSWORD']), sos_opt)
        db_sos = self.exec_master_cmd(cmd, need_root=True)
        for line in db_sos['stdout'].splitlines():
            if fnmatch.fnmatch(line, '*sosreport-*tar*'):
                return line.strip()
        self.log_error('Failed to gather database dump')
        return False


class rhv(ovirt):
    '''
    Cluster profile for Red Hat Virtualization, built on the oVirt profile.

    Differs from the parent only in its name, the packages that identify it,
    the sos preset, and how nodes are labelled.
    '''

    cluster_name = 'Red Hat Virtualization'
    packages = ('rhevm', 'rhvm')
    sos_preset = 'rhv'

    def set_node_label(self, node):
        '''
        Label a node for this cluster:
          'manager' - the node whose address matches the master's
          'rhvh'    - a node with ovirt-node-ng-nodectl installed
          'rhelh'   - any other node
        '''
        if self.master.address == node.address:
            return 'manager'
        has_nodectl = node.is_installed('ovirt-node-ng-nodectl')
        return 'rhvh' if has_nodectl else 'rhelh'


class rhhi_virt(rhv):
    '''
    Cluster profile for Red Hat Hyperconverged Infrastructure -
    Virtualization: an RHV setup whose engine database lists gluster servers.
    Extends the RHV profile with gluster plugin settings
    (sos_plugins / sos_plugin_options).
    '''

    cluster_name = 'Red Hat Hyperconverged Infrastructure - Virtualization'
    sos_plugins = ('gluster',)
    sos_plugin_options = {'gluster.dump': 'on'}
    sos_preset = 'rhv'

    def check_enabled(self):
        '''
        Enable this profile only when 'rhvm' is installed on the master and
        the engine database reports at least one gluster server.
        '''
        return (self.master.is_installed('rhvm') and self._check_for_rhhiv())

    def _check_for_rhhiv(self):
        '''
        Return True if the engine's gluster_server table has any entries.

        Returns False when the query fails or its output does not contain a
        parsable count, rather than raising out of check_enabled().
        '''
        ret = self._run_db_query('SELECT count(server_id) FROM gluster_server')
        if ret['status'] != 0:
            return False
        # Assumes psql's aligned output: header, separator, then the count
        # value on the third line.
        try:
            count = int(ret['stdout'].splitlines()[2].strip())
        except (IndexError, ValueError):
            return False
        # if there are any entries in this table, RHHI-V is in use
        return count != 0
