#!/usr/bin/python

# Copyright 2015 Red Hat Inc., Durham, North Carolina.
# All Rights Reserved.
#
# openscap-daemon is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 2.1 of the License, or
# (at your option) any later version.
#
# openscap-daemon is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.

# You should have received a copy of the GNU Lesser General Public License
# along with openscap-daemon.  If not, see <http://www.gnu.org/licenses/>.
#
# Authors:
#   Martin Preisler <mpreisle@redhat.com>

from openscap_daemon import dbus_utils
from openscap_daemon import oscap_helpers
from openscap_daemon import cli_helpers
from openscap_daemon import version

import dbus
import sys
import argparse
from datetime import datetime
import os.path
import json
import time
import io

if sys.version_info < (3,):
    import gobject
else:
    from gi.repository import GObject as gobject

# Atomic is an optional dependency. The flag below records whether it could be
# imported; main() only registers and dispatches the "scan" subcommand when it
# is set.
try:
    import Atomic.util
    atomic_support = True
except:
    # NOTE(review): the bare except treats *any* failure while importing
    # Atomic (not only a missing package) as "no Atomic support" -- confirm
    # that this is intended.
    atomic_support = False


class TaskAccessor(object):
    """Registry that maps CLI actions onto D-Bus task getters and setters.

    Actions are named "get-<key>" and "set-<key>". Each one is stored as a
    ``(method_name, cast_func, result_func)`` record:

    * ``method_name`` - name of the method to call on the D-Bus interface,
    * ``cast_func`` - callable turning the list of CLI parameters into the
      list of extra arguments for the D-Bus call (raises ValueError when the
      parameters are unacceptable), or None to expect one plain string,
    * ``result_func`` - optional callable that receives the D-Bus result.
    """

    def __init__(self):
        self._attributes = dict()

    @staticmethod
    def get_bool(values):
        """Expect one parameter and convert it with the module-level get_bool.

        Raises ValueError for a wrong parameter count or an unparsable value.
        """
        TaskAccessor._expect_param_len(values, 1)
        # Inside a method body the bare name resolves to the module-level
        # get_bool() function, not to this static method.
        bool_val = get_bool(values[0])
        return [bool_val]

    @staticmethod
    def get_int(values):
        """Expect one parameter and convert it to int (ValueError otherwise)."""
        TaskAccessor._expect_param_len(values, 1)
        return [int(values[0])]

    @staticmethod
    def _expect_param_len(values, expected):
        """Raise ValueError unless exactly `expected` parameters were given."""
        length = len(values)
        if length != expected:
            raise ValueError(
                "Expected %d parameters, but %d were provided." %
                (expected, length)
            )

    @staticmethod
    def _expect_no_params(values):
        """Accept only an empty parameter list; used for getters."""
        TaskAccessor._expect_param_len(values, 0)
        return []

    @staticmethod
    def get_string(values):
        """Expect one parameter and pass it through unchanged."""
        TaskAccessor._expect_param_len(values, 1)
        return [values[0]]

    def add_accessor(self, key, dbus_getter, dbus_setter, check=None, result_processor=None):
        """Register both "get-<key>" and "set-<key>" in one call."""
        self.add_getter(key, dbus_getter, result_processor)
        self.add_setter(key, dbus_setter, check)

    def add_getter(self, key, dbus_getter, result_processor=None):
        """Register "get-<key>"; the D-Bus result goes to result_processor."""
        # A getter is called with the task ID only. Previously no cast
        # function was stored, so eval() fell back to get_string() and
        # insisted on exactly one CLI parameter that was then forwarded to
        # the D-Bus getter.
        self._attributes["get-%s" % key] = (
            dbus_getter, TaskAccessor._expect_no_params, result_processor
        )

    def add_setter(self, key, dbus_setter, check=None):
        """Register "set-<key>"; `check` converts the CLI parameters."""
        self._attributes["set-%s" % key] = (dbus_setter, check, None)

    def eval(self, dbus_iface, key, task_id, args):
        """Run the action `key` for task `task_id`.

        `args` is the list of CLI parameters. The D-Bus method is called with
        the task ID followed by the converted parameters; when the record has
        a result function it is called with the value the method returned.

        Raises KeyError for an unregistered key and ValueError when the
        parameters are rejected by the conversion function.
        """
        record = self._attributes[key]
        method_name = record[0]
        cast_func = record[1]
        result_func = record[2]
        if cast_func:
            casted_args = cast_func(args)
        else:
            casted_args = TaskAccessor.get_string(args)

        # Build a fresh list so a list returned by cast_func is not mutated.
        call_args = [task_id] + list(casted_args)

        dbus_method = getattr(dbus_iface, method_name)
        res = dbus_method(*call_args)
        if result_func:
            result_func(res)

    def __contains__(self, key):
        return key in self._attributes

    def get_allowed(self):
        """Return the names of all registered actions as a list."""
        # A real list (not a dict view) behaves the same on Python 2 and 3.
        return list(self._attributes)


def get_dbus_interface():
    """Return a dbus.Interface for the daemon object, or None.

    None is returned when dbus_utils.get_dbus() yields no bus, or when the
    bus hands back no object for the daemon's bus name and object path.
    """
    connection = dbus_utils.get_dbus()

    proxy = None
    if connection is not None:
        proxy = connection.get_object(
            dbus_utils.BUS_NAME, dbus_utils.OBJECT_PATH
        )

    if proxy is None:
        return None

    return dbus.Interface(proxy, dbus_utils.DBUS_INTERFACE)


def cli_status(dbus_iface, args):
    """Print whatever the daemon's GetAsyncActionsStatus call returns.

    `args` is accepted for symmetry with the other subcommand handlers and is
    not used.
    """
    print(dbus_iface.GetAsyncActionsStatus())


def cli_eval(dbus_iface, args):
    """Interactively build an evaluation spec, run it and store its outputs.

    The spec is created through cli_helpers.cli_create_evaluation_spec(); if
    that returns None nothing is evaluated. Otherwise the evaluation is
    started asynchronously and polled once per second until the daemon
    reports success. The ARF, stdout and stderr of the evaluation are then
    written to ``args.results_arf``, ``args.stdout`` and ``args.stderr`` for
    each of those that is an open file.

    Any exception raised while waiting (including KeyboardInterrupt) cancels
    the evaluation on the daemon and is re-raised.
    """
    # Already opened by argparse (or None when the option was not given).
    output_files = (args.results_arf, args.stdout, args.stderr)

    try:
        eval_spec = cli_helpers.cli_create_evaluation_spec(dbus_iface)
        if eval_spec is None:
            return

        token = dbus_iface.EvaluateSpecXMLAsync(eval_spec.to_xml_source())
        try:
            print("Evaluating...")
            while True:
                # exit_code is part of the reply but not used yet.
                success, arf, stdout, stderr, exit_code = \
                    dbus_iface.GetEvaluateSpecXMLAsyncResults(token)

                if success:
                    contents = (arf, stdout, stderr)
                    for output_file, content in zip(output_files, contents):
                        if output_file:
                            output_file.write(content)

                    # TODO: show the results
                    break
                time.sleep(1)
        except:
            # Deliberately bare: Ctrl-C must cancel the running evaluation
            # too. The exception is always re-raised.
            dbus_iface.CancelEvaluateSpecXMLAsync(token)
            raise
    finally:
        # Close the output files on every path; previously they were only
        # closed after a successful evaluation.
        for output_file in output_files:
            if output_file:
                output_file.close()


def cli_print_results_table(dbus_iface, task_id, result_ids,
                            max_items=sys.maxsize):
    """Print an ID / Timestamp / Status table for results of one task.

    Only the first `max_items` entries of `result_ids` are queried and
    shown; when more exist, a trailing line says how many were left out.
    """
    rows = [["ID", "Timestamp", "Status"]]

    for shown_id in result_ids[:max_items]:
        code = dbus_iface.GetExitCodeOfTaskResult(task_id, shown_id)
        status_text = oscap_helpers.get_status_from_exit_code(code)
        created = dbus_iface.GetResultCreatedTimestamp(task_id, shown_id)

        rows.append(
            [str(shown_id), datetime.fromtimestamp(created), status_text]
        )

    cli_helpers.print_table(rows)

    omitted = len(result_ids) - max_items
    if omitted > 0:
        print("... and %i more" % (omitted))


def cli_task(dbus_iface, task_accessor, args):
    """List all tasks, or perform one action on a single task.

    Without ``args.task_id`` a summary table of every task is printed and
    ``args.task_action`` is not consulted. With a task ID the action named by
    ``args.task_action`` is carried out: one of the built-in actions (info,
    guide, run, enable, disable, remove) or an action registered in
    `task_accessor`.

    Exits with status 1 when an accessor action raises ValueError; raises
    RuntimeError for an action that is neither built in nor registered.
    """
    selected_id = args.task_id

    if selected_id is None:
        # Listing mode: the requested action plays no role here.
        all_ids = dbus_iface.ListTaskIDs()
        rows = [["ID", "Title", "Target", "Modified", "Enabled"]]
        active_total = 0

        for current_id in all_ids:
            title = dbus_iface.GetTaskTitle(current_id)
            target = dbus_iface.GetTaskTarget(current_id)
            modified = datetime.fromtimestamp(
                dbus_iface.GetTaskModifiedTimestamp(current_id)
            )

            is_enabled = dbus_iface.GetTaskEnabled(current_id)
            if is_enabled:
                active_total += 1

            # TODO: Maybe we can show the disabled state in a better way?
            state = "enabled" if is_enabled else "disabled"
            rows.append([str(current_id), title, target, modified, state])

        cli_helpers.print_table(rows)
        print("")
        print("Found %i tasks, %i of them enabled." %
              (len(all_ids), active_total))
        return

    action = args.task_action

    if action == "info":
        title = dbus_iface.GetTaskTitle(selected_id)
        target = dbus_iface.GetTaskTarget(selected_id)
        created = datetime.fromtimestamp(
            dbus_iface.GetTaskCreatedTimestamp(selected_id)
        )
        modified = datetime.fromtimestamp(
            dbus_iface.GetTaskModifiedTimestamp(selected_id)
        )

        details = [
            ["Title", title],
            ["ID", str(selected_id)],
            ["Target", target],
            ["Created", created],
            ["Modified", modified],
        ]
        cli_helpers.print_table(details, first_row_header=False)
        print("")

        result_ids = dbus_iface.GetTaskResultIDs(selected_id)
        if len(result_ids) > 0:
            print("Latest results:")
            cli_print_results_table(dbus_iface, selected_id, result_ids, 5)
            print("")

        if not dbus_iface.GetTaskEnabled(selected_id):
            print("This task is currently disabled. Enable it by calling:")
            print("$ oscapd-cli task %i enable" % (selected_id))
        # TODO

    elif action == "guide":
        print(dbus_iface.GenerateGuideForTask(selected_id))

    elif action == "run":
        dbus_iface.RunTaskOutsideSchedule(selected_id)

    elif action in ("enable", "disable"):
        dbus_iface.SetTaskEnabled(selected_id, action == "enable")

    elif action == "remove":
        # --force skips the interactive confirmation entirely.
        if args.force or confirm("Do you really want to delete task with ID %i?" % selected_id):
            dbus_iface.RemoveTask(selected_id, args.remove_results)

    elif action in task_accessor:
        try:
            task_accessor.eval(
                dbus_iface, action, selected_id, args.parameters[0]
            )
        except ValueError as e:
            sys.stderr.write("%s\n" % (e))
            sys.exit(1)

    else:
        # argparse restricts the action to known choices, so reaching this
        # point means the parser and this function are out of sync
        raise RuntimeError("Unknown action '%s'." % (action))


def _choice_to_index(raw_choice, choice_count, what):
    """Turn a 1-based menu answer into a 0-based index, or exit with status 1.

    Answers that are not a number between 1 and `choice_count` are reported
    on stderr. `what` names the thing being chosen in that message.
    """
    try:
        index = int(raw_choice) - 1
    except ValueError:
        index = -1

    # Reject out-of-range answers explicitly; a plain list index would let
    # "0" silently select the last entry.
    if not 0 <= index < choice_count:
        sys.stderr.write(
            "'%s' is not a valid %s number. Please enter a number between "
            "1 and %i.\n" % (raw_choice, what, choice_count)
        )
        sys.exit(1)

    return index


def cli_task_create(dbus_iface, args):
    """Create a new task by asking the user for its properties.

    Only the interactive mode (``args.interactive``) is implemented; any
    other call raises NotImplementedError. The user is asked for title,
    target, SCAP input (SSG content offered by the daemon or a custom
    absolute path), tailoring file, profile, online remediation and the
    schedule. All answers are collected before the task is created, so
    answers rejected here do not leave a half-configured task behind. The
    task is left in whatever enabled state the daemon gives new tasks; the
    final message tells the user how to enable it.

    Exits with status 1 on invalid menu choices, a missing input file or
    non-absolute paths.
    """
    if not args.interactive:
        raise NotImplementedError("Not yet!")

    print("Creating new task in interactive mode")

    title = cli_helpers.py2_raw_input("Title: ")
    target = cli_helpers.py2_raw_input("Target (empty for localhost): ")
    if not target:
        target = "localhost"

    input_ssg_choice = ""
    ssg_choices = dbus_iface.GetSSGChoices()
    if ssg_choices:
        print("Found the following SCAP Security Guide content: ")
        for number, ssg_choice in enumerate(ssg_choices, 1):
            print("\t%i:  %s" % (number, ssg_choice))

        input_ssg_choice = cli_helpers.py2_raw_input(
            "Choose SSG content by number (empty for custom content): ")

    if input_ssg_choice:
        input_file = ssg_choices[
            _choice_to_index(input_ssg_choice, len(ssg_choices), "SSG content")
        ]
    else:
        input_file = cli_helpers.py2_raw_input("Input file (absolute path): ")

    if not input_file:
        sys.stderr.write(
            "You have to provide an SCAP input file for the task!\n"
        )
        sys.exit(1)

    if not os.path.isabs(input_file):
        sys.stderr.write(
            "'%s' is not an absolute path. Please provide the absolute "
            "path that can be used to access the SCAP content on the "
            "machine running openscap-daemon.\n" % (input_file)
        )
        sys.exit(1)

    tailoring_file = cli_helpers.py2_raw_input(
        "Tailoring file (absolute path, empty for no tailoring): ")
    if not tailoring_file:
        # the D-Bus setters are given an empty string for "no tailoring"
        tailoring_file = ""
    elif not os.path.isabs(tailoring_file):
        sys.stderr.write(
            "'%s' is not an absolute path. Please provide the absolute "
            "path that can be used to access the tailoring file on the "
            "machine running openscap-daemon.\n" % (tailoring_file)
        )
        sys.exit(1)

    print("Found the following possible profiles: ")
    profile_choices = dbus_iface.GetProfileChoicesForInput(
        input_file, tailoring_file
    )
    # Freeze the key order once so the printed numbers and the lookup below
    # are guaranteed to agree.
    profile_ids = list(profile_choices.keys())
    for number, profile_id in enumerate(profile_ids, 1):
        print("\t%i:  %s (id='%s')" %
              (number, profile_choices[profile_id], profile_id))

    profile_choice = cli_helpers.py2_raw_input(
        "Choose profile by number (empty for (default) profile): ")
    if profile_choice:
        profile = profile_ids[
            _choice_to_index(profile_choice, len(profile_ids), "profile")
        ]
    else:
        profile = ""

    online_remediation_raw = cli_helpers.py2_raw_input(
        "Online remediation (1, y or Y for yes, else no): "
    )
    try:
        online_remediation = get_bool(online_remediation_raw, default=False)
    except ValueError:
        # The prompt promises "else no". Without this assignment the name
        # stayed unbound and the task setup below failed.
        online_remediation = False

    print("Schedule: ")
    schedule_not_before_str = cli_helpers.py2_raw_input(
        " - not before (YYYY-MM-DD HH:MM in UTC, empty for NOW): "
    )
    if schedule_not_before_str == "":
        # The prompt asks for UTC, so "NOW" has to be UTC as well rather
        # than local time. utcnow() is used because this file still
        # supports Python 2.
        schedule_not_before = datetime.utcnow()
    else:
        schedule_not_before = datetime.strptime(
            schedule_not_before_str, "%Y-%m-%d %H:%M"
        )

    schedule_repeat_after_str = cli_helpers.py2_raw_input(
        " - repeat after (hours or @daily, @weekly, @monthly, "
        "empty or 0 for no repeat): "
    )

    # the repeat interval is expressed in hours; 0 means no repeat
    if not schedule_repeat_after_str:
        schedule_repeat_after = 0
    elif schedule_repeat_after_str == "@daily":
        schedule_repeat_after = 1 * 24
    elif schedule_repeat_after_str == "@weekly":
        schedule_repeat_after = 7 * 24
    elif schedule_repeat_after_str == "@monthly":
        schedule_repeat_after = 30 * 24
    else:
        schedule_repeat_after = int(schedule_repeat_after_str)

    # most users need just drop_missed_aligned, we will not offer the
    # other options here
    # schedule_slip_mode = task.SlipMode.DROP_MISSED_ALIGNED

    task_id = dbus_iface.CreateTask()
    dbus_iface.SetTaskTitle(task_id, title)
    dbus_iface.SetTaskTarget(task_id, target)
    dbus_iface.SetTaskInput(task_id, input_file)
    dbus_iface.SetTaskTailoring(task_id, tailoring_file)
    dbus_iface.SetTaskProfileID(task_id, profile)
    dbus_iface.SetTaskOnlineRemediation(task_id, online_remediation)
    dbus_iface.SetTaskScheduleNotBefore(
        task_id, schedule_not_before.strftime("%Y-%m-%dT%H:%M")
    )
    dbus_iface.SetTaskScheduleRepeatAfter(task_id, schedule_repeat_after)

    print(
        "Task created with ID '%i'. It is currently set as disabled. "
        "You can enable it with `oscapd-cli task %i enable`." %
        (task_id, task_id)
    )
    # TODO: Setting Schedule SlipMode


def cli_result(dbus_iface, args):
    """Show or remove stored results of one task.

    * no ``args.result_id``: print a table of all results of the task,
    * ``args.result_id == "remove"``: remove every result of the task,
    * otherwise: apply ``args.result_action`` to that single result (print
      its arf, stdout, stderr, exit_code or report, or remove it).

    Removals ask for confirmation unless ``args.force`` is set. Raises
    RuntimeError for an unknown result action.
    """
    task_id = args.task_id
    result_id = args.result_id

    if result_id is None:
        task_title = dbus_iface.GetTaskTitle(task_id)

        print("Results of Task \"%s\", ID = %i" % (task_title, task_id))
        print("")

        result_ids = dbus_iface.GetTaskResultIDs(task_id)
        cli_print_results_table(dbus_iface, task_id, result_ids)
        return

    if result_id == "remove":
        if args.force or confirm("Do you really want to remove all results of task %d"
                                 % task_id):
            dbus_iface.RemoveTaskResults(task_id)
        return

    action = args.result_action

    # Actions that just fetch one document for the result and print it.
    fetch_and_print = {
        "arf": "GetARFOfTaskResult",
        "stdout": "GetStdOutOfTaskResult",
        "stderr": "GetStdErrOfTaskResult",
        "report": "GenerateReportForTaskResult",
    }

    if action in fetch_and_print:
        fetch = getattr(dbus_iface, fetch_and_print[action])
        print(fetch(task_id, result_id))

    elif action == "exit_code":
        print("%i" % (dbus_iface.GetExitCodeOfTaskResult(task_id, result_id)))

    elif action == "remove":
        if args.force or confirm("Do you really want to remove result %d from task %d"
                                 % (result_id, task_id)):
            dbus_iface.RemoveTaskResult(task_id, result_id)

    else:
        raise RuntimeError(
            "Unknown result action '%s'." % (action)
        )


def cli_scan(dbus_iface, args):
    """Run a CVE scan over images and/or containers and print the outcome.

    Targets are collected from ``--all`` / ``--images`` / ``--containers``
    (IDs listed by the daemon) plus any explicitly named ``scan_targets``.
    The scan runs asynchronously and is polled once per second; any
    exception while waiting (including KeyboardInterrupt) cancels the scan
    on the daemon and is re-raised.

    With ``--json`` the raw result string is printed. Otherwise the result is
    handed to the Atomic summary helpers and the process exits with status 1
    when the helper returns a falsy value.

    Raises RuntimeError when no target was specified at all and
    NotImplementedError when the plain summary is requested for explicitly
    named targets.
    """
    if args.fetch_cves is None:
        fetch_cve = 2  # use defaults
    elif args.fetch_cves:
        fetch_cve = 1
    else:
        fetch_cve = 0
    # NOTE(review): the previous inline comments labelled 1 as "disable" and
    # 0 as "enable", which contradicts the option name (--fetch_cves true
    # yields 1). Confirm the meaning against the daemon's CVEScanListAsync.

    threads_count = 4

    scan_targets = []

    any_target_specified = False
    if args.all or args.images:
        images = json.loads(dbus_iface.images())
        scan_targets.extend(str(image["Id"]) for image in images)
        any_target_specified = True

    if args.all or args.containers:
        containers = json.loads(dbus_iface.containers())
        scan_targets.extend(str(container["Id"]) for container in containers)
        any_target_specified = True

    if args.scan_targets:
        scan_targets.extend(args.scan_targets)  # todo do check if targets are valid
        any_target_specified = True

    if not any_target_specified:
        raise RuntimeError("No scan target")

    token = dbus_iface.CVEScanListAsync(
        scan_targets, threads_count, fetch_cve
    )
    try:
        print("Processing...")
        while True:
            success, scan_results = dbus_iface.GetCVEScanListAsyncResults(token)
            if success:
                break
            time.sleep(1)
    except:
        # Deliberately bare: Ctrl-C must cancel the running scan too. The
        # exception is always re-raised.
        dbus_iface.CancelCVEScanListAsync(token)
        raise

    if args.json:
        print(scan_results)
        return

    json_parsed = json.loads(scan_results)
    if args.detail:
        clean = Atomic.util.print_detail_scan_summary(
            json_parsed
        )

    else:
        if args.scan_targets:
            # NotImplemented is a constant, not an exception class; calling
            # it raised a confusing TypeError instead of this message.
            raise NotImplementedError(
                "This type of output is not implemented "
                "for specified targets.\n"
            )
        clean = Atomic.util.print_scan_summary(
            json_parsed, scan_targets
        )

    if not clean:
        sys.exit(1)


def get_bool(val, default=False):
    """Interpret a yes/no style string typed by the user.

    Matching is case-insensitive and ignores surrounding whitespace.
    "n", "0", "false" and "no" give False; "y", "1", "true" and "yes" give
    True; an empty (or whitespace-only) string gives `default`.

    Raises ValueError for anything else.
    """
    # Strip first so that answers such as "y " are not rejected.
    normalized = val.strip().lower()

    if not normalized:
        return default

    if normalized in ('n', '0', 'false', 'no'):
        return False

    if normalized in ('y', '1', 'true', 'yes'):
        return True

    # Report the value exactly as it was typed, not the lowercased copy.
    raise ValueError("'%s' is not valid value, use y/n instead." % (val))


def confirm(prompt, default=False):
    """Ask a yes/no question until it is answered; return the answer.

    The hint after the prompt capitalises the default ("Y/n" or "y/N"). An
    empty answer selects `default`, an unparsable answer repeats the
    question, and end-of-input aborts with `default` after a note on stderr.
    """
    if default:
        hint = "Y/n"
    else:
        hint = "y/N"
    question = "%s [%s]: " % (prompt, hint)

    while True:
        try:
            answer = cli_helpers.py2_raw_input(question)
            return get_bool(answer, default)
        except EOFError:
            sys.stderr.write("Operation aborted.\n")
            return default
        except ValueError:
            # not a recognisable yes/no answer, ask again
            pass


def main():
    """Parse the command line, connect to the daemon and run the subcommand.

    Builds the argparse parser (the "scan" subcommand only when Atomic could
    be imported), connects to the daemon's D-Bus interface, warns on stderr
    when the daemon's version differs from this CLI's or cannot be queried,
    and dispatches to the matching ``cli_*`` handler.

    Exits with status 1 when no D-Bus interface could be obtained or access
    to it is denied. Raises RuntimeError for an action without a handler.
    """
    parser = argparse.ArgumentParser(
        description="OpenSCAP-Daemon command line interface."
    )
    parser.add_argument(
        "-v", "--version", action="version",
        version="%(prog)s " + version.VERSION_STRING
    )
    subparsers = parser.add_subparsers(dest="action")
    subparsers.required = True

    # Setter actions ("set-<key>") offered by the "task" subcommand.
    task_accessor = TaskAccessor()

    task_accessor.add_setter("enabled",
                             "SetTaskEnabled",
                             TaskAccessor.get_bool)
    task_accessor.add_setter("title", "SetTaskTitle")
    task_accessor.add_setter("target", "SetTaskTarget")
    task_accessor.add_setter("input", "SetTaskInput")
    task_accessor.add_setter("tailoring", "SetTaskTailoring")
    task_accessor.add_setter("profile-id", "SetTaskProfileID")
    task_accessor.add_setter("online-remediation",
                             "SetTaskOnlineRemediation",
                             TaskAccessor.get_bool)
    task_accessor.add_setter("schedule-not-before", "SetTaskScheduleNotBefore")
    task_accessor.add_setter("schedule-repeat-after",
                             "SetTaskScheduleRepeatAfter",
                             TaskAccessor.get_int)

    def add_eval_parser(subparsers):
        eval_parser = subparsers.add_parser(
            "eval",
            help="Interactive one-off evaluation of any target supported by "
                 "OpenSCAP Daemon"
        )
        # The output files are opened while parsing; cli_eval writes to them.
        eval_parser.add_argument(
            "--results-arf", dest="results_arf",
            type=lambda path: io.open(path, "w", encoding="utf-8"),
            help="Write ARF (result data stream) into file on this path."
        )
        eval_parser.add_argument(
            "--stdout",
            type=lambda path: io.open(path, "w", encoding="utf-8"),
            help="Write stdout from oscap into file on this path."
        )
        eval_parser.add_argument(
            "--stderr",
            type=lambda path: io.open(path, "w", encoding="utf-8"),
            help="Write stderr from oscap into file on this path."
        )
        # todo non-interactive
    add_eval_parser(subparsers)

    def add_task_parser(subparsers, task_accessor):
        task_parser = subparsers.add_parser(
            "task",
            help="Show info about tasks that have already been defined. "
            "Perform operations on already defined tasks."
        )
        task_parser.add_argument(
            "task_id", metavar="TASK_ID", type=int, nargs="?",
            help="ID of the task to display, or perform action on. If none is "
            "provided a summary of all tasks is displayed."
        )

        task_actions = ["info", "guide", "run", "enable", "disable", "remove"]
        task_actions += task_accessor.get_allowed()

        task_parser.add_argument(
            "task_action", metavar="ACTION", type=str,
            choices=task_actions,
            help="Which action to perform on selected task. Use one of " +
            ", ".join(task_actions),
            default="info", nargs="?"
        )

        # action="append" with nargs="*" yields a list of lists; cli_task
        # reads the first inner list.
        task_parser.add_argument(
            "parameters", metavar="parameter", action="append", nargs="*",
            help="Parameters for the ACTION. For setter actions this is the "
            "string that you want to set. Some actions, such as enable, remove, "
            "... don't require any parameters."
        )

        task_parser.add_argument(
            "-f", "--force", help="remove task without confirmation",
            action="store_true"
        )
        task_parser.add_argument(
            "-r", "--remove-results", help="remove with results",
            action="store_true"
        )
    add_task_parser(subparsers, task_accessor)

    def add_task_create_parser(subparsers):
        task_create_parser = subparsers.add_parser(
            "task-create",
            help="Create new task."
        )
        task_create_parser.add_argument(
            "-i", "--interactive", action="store_true", dest="interactive", required=True
        )
    add_task_create_parser(subparsers)

    def add_status_parser(subparsers):
        subparsers.add_parser(
            "status",
            help="Displays status, tasks that are planned and tasks that are "
            "being evaluated."
        )
    add_status_parser(subparsers)

    def result_id_or_action(val):
        # RESULT_ID doubles as the keyword "remove" (remove all results).
        if val == "remove":
            return "remove"

        try:
            return int(val)
        except ValueError:
            raise argparse.ArgumentTypeError("'%s' is not \"remove\" or integer"
                                             % (val))

    def add_result_parser(subparsers):
        result_parser = subparsers.add_parser(
            "result",
            help="Displays info about past results"
        )
        result_parser.add_argument(
            "task_id", metavar="TASK_ID", type=int
        )

        result_parser.add_argument(
            "result_id", metavar="RESULT_ID", type=result_id_or_action, nargs="?",
            help="ID of the result we want to interact with, if none is "
            "provided a summary of all results of given task is displayed."
        )

        result_actions = [
            "arf", "stdout", "stderr", "exit_code", "report", "remove"
        ]
        result_parser.add_argument(
            "result_action", metavar="ACTION", type=str,
            choices=result_actions,
            help="Which action to perform on selected result. Use one of " +
            ", ".join(result_actions),
            default="arf", nargs="?",
        )

        result_parser.add_argument(
            "-f", "--force", help="remove results without confirmation",
            action="store_true"
        )

    add_result_parser(subparsers)

    def add_scan_parser(subparsers):
        scan_parser = subparsers.add_parser(
            "scan", help="scan an image or container for CVEs",
            epilog="atomic scan <input> scans a container or image for CVEs"
        )

        scan_parser.add_argument(
            "scan_targets", nargs='*', help="container image"
        )
        scan_parser.add_argument(
            "--fetch_cves", type=get_bool, default=None
        )

        scan_out = scan_parser.add_mutually_exclusive_group()

        scan_out.add_argument(
            "--json", default=False, action='store_true',
            help="output json"
        )
        scan_out.add_argument(
            "--detail", default=False, action='store_true',
            help="output more detail"
        )

        scan_group = scan_parser.add_mutually_exclusive_group()
        scan_group.add_argument(
            "--all", default=False, action='store_true',
            help="scan all images (excluding intermediate layers) and containers"
        )
        scan_group.add_argument(
            "--images", default=False, action='store_true',
            help="scan all images (excluding intermediate layers)"
        )
        scan_group.add_argument(
            "--containers", default=False, action='store_true',
            help="scan all containers"
        )

    if atomic_support:
        add_scan_parser(subparsers)

    args = parser.parse_args()

    gobject.threads_init()

    connect_error = (
        "Error: Failed to connect to the SCAP Client DBus interface. "
        "Is the daemon running?\n\n"
    )

    try:
        dbus_iface = get_dbus_interface()

    except Exception:
        sys.stderr.write(connect_error)
        raise

    if dbus_iface is None:
        # get_dbus_interface() signals "no bus / no object" with None;
        # without this check the version query below failed with an
        # AttributeError on None.
        sys.stderr.write(connect_error)
        sys.exit(1)

    try:
        oscapd_version_major, oscapd_version_minor, oscapd_version_patch = \
            dbus_iface.GetVersion()

        if (oscapd_version_major, oscapd_version_minor, oscapd_version_patch) \
                != (version.VERSION_MAJOR, version.VERSION_MINOR, version.VERSION_PATCH):
            sys.stderr.write(
                "Warning: Version mismatch between oscapd-cli and oscapd.\n")

    except dbus.exceptions.DBusException as e:
        if e.get_dbus_name() == "org.freedesktop.DBus.Error.UnknownMethod":
            sys.stderr.write(
                "Warning: Can't perform version check, the openscap-daemon dbus"
                " interface doesn't provide the GetVersion method.\n\n"
            )
        elif e.get_dbus_name() == "org.freedesktop.DBus.Error.AccessDenied":
            sys.stderr.write(
                "Error: Access denied on the DBus interface. "
                "Do you have required permissions?\n\n"
            )
            sys.exit(1)
        else:
            raise

    if args.action == "status":
        cli_status(dbus_iface, args)
    elif args.action == "eval":
        cli_eval(dbus_iface, args)
    elif args.action == "task":
        cli_task(dbus_iface, task_accessor, args)
    elif args.action == "task-create":
        cli_task_create(dbus_iface, args)
    elif args.action == "result":
        cli_result(dbus_iface, args)
    elif atomic_support and args.action == "scan":
        cli_scan(dbus_iface, args)
    else:
        raise RuntimeError("Unknown action '%s'." % (args.action))


# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
