#!/usr/bin/python3

import argparse
import sys
import curses
import errno
import json
import signal

from collections import OrderedDict
from datetime import datetime
from enum import Enum, unique
from threading import Event

import rados


class FSTopException(Exception):
    """Error raised by cephfs-top for failures that are reported to the user.

    The message is kept in ``error_msg`` (see get_error_msg()) and is also
    handed to the Exception base class so that ``str(exc)`` and ``exc.args``
    carry it no matter whether it was passed positionally or by keyword.
    """

    def __init__(self, msg=''):
        # initialise the base class explicitly; without this a message
        # passed as ``msg=...`` would leave ``args`` empty and str() blank
        super().__init__(msg)
        self.error_msg = msg

    def get_error_msg(self):
        """Return the message this exception was created with."""
        return self.error_msg


@unique
class MetricType(Enum):
    """Display category of a counter; selects how its value is rendered."""
    # no conversion: the first element of the metric tuple is printed as-is
    # and the column heading gets no unit suffix
    METRIC_TYPE_NONE = 0
    # rendered through calc_perc(); column heading is suffixed with "(%)"
    METRIC_TYPE_PERCENTAGE = 1
    # rendered through calc_lat(); column heading is suffixed with "(s)"
    METRIC_TYPE_LATENCY = 2


# program name shown in the first header line
FS_TOP_PROG_STR = 'cephfs-top'

# stats version this tool understands; compared against the 'version'
# field of the stats emitted by mgr/stats
FS_TOP_SUPPORTED_VER = 1

# padding placed between adjacent columns
ITEMS_PAD_LEN = 1
ITEMS_PAD = " " * ITEMS_PAD_LEN
# refresh interval (seconds) used when -d/--delay is not given
DEFAULT_REFRESH_INTERVAL = 1
# min refresh interval allowed
MIN_REFRESH_INTERVAL = 0.5

# column headings for the metadata provided by mgr/stats
FS_TOP_MAIN_WINDOW_COL_CLIENT_ID = "client_id"
FS_TOP_MAIN_WINDOW_COL_MNT_ROOT = "mount_root"
FS_TOP_MAIN_WINDOW_COL_MNTPT_HOST_ADDR = "mount_point@host/addr"

# columns drawn before (START) and after (END) the metric columns
MAIN_WINDOW_TOP_LINE_ITEMS_START = [ITEMS_PAD,
                                    FS_TOP_MAIN_WINDOW_COL_CLIENT_ID,
                                    FS_TOP_MAIN_WINDOW_COL_MNT_ROOT]
MAIN_WINDOW_TOP_LINE_ITEMS_END = [FS_TOP_MAIN_WINDOW_COL_MNTPT_HOST_ADDR]

# adjust this map according to stats version and maintain order
# as emitted by mgr/stats
MAIN_WINDOW_TOP_LINE_METRICS = OrderedDict([
    ("CAP_HIT", MetricType.METRIC_TYPE_PERCENTAGE),
    ("READ_LATENCY", MetricType.METRIC_TYPE_LATENCY),
    ("WRITE_LATENCY", MetricType.METRIC_TYPE_LATENCY),
    ("METADATA_LATENCY", MetricType.METRIC_TYPE_LATENCY),
    ("DENTRY_LEASE", MetricType.METRIC_TYPE_PERCENTAGE),
    ("OPENED_FILES", MetricType.METRIC_TYPE_NONE),
    ("PINNED_ICAPS", MetricType.METRIC_TYPE_NONE),
    ("OPENED_INODES", MetricType.METRIC_TYPE_NONE),
])
# counter names in display order
MGR_STATS_COUNTERS = list(MAIN_WINDOW_TOP_LINE_METRICS.keys())

# templates for the two header lines
FS_TOP_VERSION_HEADER_FMT = '{prog_name} - {now}'
FS_TOP_CLIENT_HEADER_FMT = 'Client(s): {num_clients} - {num_mounts} FUSE, '\
    '{num_kclients} kclient, {num_libs} libcephfs'

# keys looked up in the 'fs perf stats' JSON: the per-client metadata
# section and the fields read from each client's metadata
CLIENT_METADATA_KEY = "client_metadata"
CLIENT_METADATA_MOUNT_POINT_KEY = "mount_point"
CLIENT_METADATA_MOUNT_ROOT_KEY = "root"
CLIENT_METADATA_IP_KEY = "IP"
CLIENT_METADATA_HOSTNAME_KEY = "hostname"
CLIENT_METADATA_VALID_METRICS_KEY = "valid_metrics"

# keys of the per-client metric values and of the list of counter names
GLOBAL_METRICS_KEY = "global_metrics"
GLOBAL_COUNTERS_KEY = "global_counters"


def calc_perc(c):
    """Return c[0] as a percentage of c[0] + c[1], rounded to 2 places.

    A pair of zeros yields 0.0 instead of dividing by zero.
    (presumably the pair is (hits, misses) -- confirm against mgr/stats)
    """
    first, second = c[0], c[1]
    if first == 0 and second == 0:
        return 0.0
    total = first + second
    return round((first / total) * 100, 2)


def calc_lat(c):
    """Return c[0] + c[1] / 10^9 rounded to 2 decimal places.

    (presumably the pair is (seconds, nanoseconds) -- confirm against
    mgr/stats)
    """
    whole = c[0]
    fraction = c[1] / 1000000000
    return round(whole + fraction, 2)


def wrap(s, sl):
    """Return *s* limited to at most *sl* characters.

    A string that fits (including one that is exactly *sl* characters
    long) is returned unchanged.  A longer string is cut to ``sl - 1``
    characters and suffixed with '+' to mark the truncation.  A
    non-positive *sl* leaves no room at all and yields an empty string.
    """
    if sl <= 0:
        # s[0:sl-1] would slice from the end and return more than asked
        return ''
    if len(s) <= sl:
        return s
    return f'{s[:sl - 1]}+'


class FSTop(object):
    """Curses front end that periodically renders 'fs perf stats' output.

    init() connects to the cluster through librados and checks that the mgr
    'stats' module is enabled.  run_display() then redraws three curses
    windows (header, column headings, one row per client) once per refresh
    interval until the exit event is set by SIGINT/SIGTERM.
    """

    def __init__(self, args):
        """Store the parsed command line options.

        :param args: object providing the ``id``, ``cluster``, ``conffile``
                     and ``delay`` attributes.
        """
        self.rados = None
        self.stdscr = None  # curses instance
        # sub windows and terminal geometry, (re)created on every pass of
        # run_display()
        self.header = None
        self.topl = None
        self.mainw = None
        self.height = 0
        self.width = 0
        self.client_name = args.id
        self.cluster_name = args.cluster
        self.conffile = args.conffile
        self.refresh_interval_secs = args.delay
        self.exit_ev = Event()

    def refresh_window_size(self):
        """Cache the current size of the curses screen."""
        self.height, self.width = self.stdscr.getmaxyx()

    def handle_signal(self, signum, _):
        """Signal handler: request termination of the display loop."""
        self.exit_ev.set()

    def init(self):
        """Connect to the cluster and install the signal handlers.

        :raises FSTopException: if the connection fails or the 'stats'
                                module is not usable.
        """
        try:
            if self.conffile:
                r_rados = rados.Rados(rados_id=self.client_name, clustername=self.cluster_name,
                                      conffile=self.conffile)
            else:
                r_rados = rados.Rados(rados_id=self.client_name, clustername=self.cluster_name)
            r_rados.conf_read_file()
            r_rados.connect()
            self.rados = r_rados
        except rados.Error as e:
            if e.errno == errno.ENOENT:
                raise FSTopException(f'cluster {self.cluster_name} does not exist')
            else:
                raise FSTopException(f'error connecting to cluster: {e}')
        self.verify_perf_stats_support()
        signal.signal(signal.SIGTERM, self.handle_signal)
        signal.signal(signal.SIGINT, self.handle_signal)

    def fini(self):
        """Shut down the cluster handle, if any.  Safe to call repeatedly."""
        if self.rados:
            self.rados.shutdown()
            self.rados = None

    def selftest(self):
        """Run one stats query and validate it without touching curses.

        :raises FSTopException: on a stats version mismatch or when the
                                query reports counters this tool does not
                                know.
        """
        stats_json = self.perf_stats_query()
        if stats_json['version'] != FS_TOP_SUPPORTED_VER:
            raise FSTopException('perf stats version mismatch!')
        missing = [m for m in stats_json[GLOBAL_COUNTERS_KEY]
                   if m.upper() not in MGR_STATS_COUNTERS]
        if missing:
            raise FSTopException('Cannot handle unknown metrics from \'ceph fs perf stats\': '
                                 f'{missing}')

    def setup_curses(self, win):
        """Entry point for curses.wrapper(): prepare the screen and run.

        :param win: the main curses window handed over by the wrapper.
        """
        self.stdscr = win
        curses.use_default_colors()
        curses.start_color()
        try:
            curses.curs_set(0)
        except curses.error:
            # If the terminal do not support the visibility
            # requested it will raise an exception
            pass
        self.run_display()

    def verify_perf_stats_support(self):
        """Check that the mgr 'stats' module is enabled.

        :raises FSTopException: if the check cannot be performed or the
                                module is not in the enabled list.
        """
        mon_cmd = {'prefix': 'mgr module ls', 'format': 'json'}
        try:
            ret, buf, out = self.rados.mon_command(json.dumps(mon_cmd), b'')
        except Exception as e:
            raise FSTopException(f'error checking \'stats\' module: {e}')
        if ret != 0:
            raise FSTopException(f'error checking \'stats\' module: {out}')
        if 'stats' not in json.loads(buf.decode('utf-8'))['enabled_modules']:
            raise FSTopException('\'stats\' module not enabled. Use \'ceph mgr module '
                                 'enable stats\' to enable')

    def perf_stats_query(self):
        """Run 'fs perf stats' against the mgr and return the decoded JSON.

        :raises FSTopException: if the command cannot be sent or returns a
                                non-zero status.
        """
        mgr_cmd = {'prefix': 'fs perf stats', 'format': 'json'}
        try:
            ret, buf, out = self.rados.mgr_command(json.dumps(mgr_cmd), b'')
        except Exception as e:
            raise FSTopException(f'error in \'perf stats\' query: {e}')
        if ret != 0:
            raise FSTopException(f'error in \'perf stats\' query: {out}')
        return json.loads(buf.decode('utf-8'))

    def items(self, item):
        """Return the short column heading for counter name *item*.

        Unknown names map to an empty string.
        """
        short_names = {
            "CAP_HIT": "chit",
            "READ_LATENCY": "rlat",
            "WRITE_LATENCY": "wlat",
            "METADATA_LATENCY": "mlat",
            "DENTRY_LEASE": "dlease",
            "OPENED_FILES": "ofiles",
            "PINNED_ICAPS": "oicaps",
            "OPENED_INODES": "oinodes",
        }
        # return empty string for none type
        return short_names.get(item, '')

    def mtype(self, typ):
        """Return the unit suffix shown after a column heading of type *typ*."""
        if typ == MetricType.METRIC_TYPE_PERCENTAGE:
            return "(%)"
        elif typ == MetricType.METRIC_TYPE_LATENCY:
            return "(s)"
        else:
            # return empty string for none type
            return ''

    def refresh_top_line_and_build_coord(self):
        """Draw the column headings and compute the column layout.

        :return: dict mapping each column key to ``(x offset, width
                 including padding)``, or None when there is no top line
                 window to draw into.
        """
        if self.topl is None:
            return None

        # (key used in the coordinate map, text shown in the heading)
        columns = [(item, item) for item in MAIN_WINDOW_TOP_LINE_ITEMS_START]
        columns += [(item, f'{self.items(item)}{self.mtype(typ)}')
                    for item, typ in MAIN_WINDOW_TOP_LINE_METRICS.items()]
        columns += [(item, item) for item in MAIN_WINDOW_TOP_LINE_ITEMS_END]

        xp = 0
        x_coord_map = {}
        heading = []
        for key, label in columns:
            heading.append(label)
            nlen = len(label) + len(ITEMS_PAD)
            x_coord_map[key] = (xp, nlen)
            xp += nlen

        title = ITEMS_PAD.join(heading)
        hlen = min(self.width - 2, len(title))
        self.topl.addnstr(0, 0, title, hlen, curses.A_STANDOUT | curses.A_BOLD)
        self.topl.refresh()
        return x_coord_map

    @staticmethod
    def has_metric(metadata, metrics_key):
        """Return True if *metrics_key* is present in *metadata*."""
        return metrics_key in metadata

    @staticmethod
    def has_metrics(metadata, metrics_keys):
        """Return True if every key of *metrics_keys* is present in *metadata*."""
        return all(FSTop.has_metric(metadata, key) for key in metrics_keys)

    def refresh_client(self, client_id, metrics, counters, client_meta, x_coord_map, y_coord):
        """Draw one client row at line *y_coord* of the main window.

        :param client_id: client name; the part after the first '.' is shown.
        :param metrics: per-counter values, indexed like *counters*.
        :param counters: upper-cased counter names from the stats query.
        :param client_meta: metadata dict of this client.
        :param x_coord_map: layout from refresh_top_line_and_build_coord().
        :param y_coord: row inside the main window.

        Drawing stops as soon as the row has no horizontal room left.
        """
        remaining_hlen = self.width - 1
        for item in MAIN_WINDOW_TOP_LINE_ITEMS_START:
            coord = x_coord_map[item]
            hlen = min(coord[1] - len(ITEMS_PAD), remaining_hlen)
            remaining_hlen = max(remaining_hlen - coord[1], 0)
            if item == FS_TOP_MAIN_WINDOW_COL_CLIENT_ID:
                self.mainw.addnstr(y_coord, coord[0],
                                   wrap(client_id.split('.')[1], hlen),
                                   hlen)
            elif item == FS_TOP_MAIN_WINDOW_COL_MNT_ROOT:
                if FSTop.has_metric(client_meta, CLIENT_METADATA_MOUNT_ROOT_KEY):
                    self.mainw.addnstr(y_coord, coord[0],
                                       wrap(client_meta[CLIENT_METADATA_MOUNT_ROOT_KEY], hlen),
                                       hlen)
                else:
                    self.mainw.addnstr(y_coord, coord[0], "N/A", hlen)

            if remaining_hlen == 0:
                return

        valid_metrics = client_meta.get(CLIENT_METADATA_VALID_METRICS_KEY, [])
        for cidx, item in enumerate(counters):
            if item not in MAIN_WINDOW_TOP_LINE_METRICS:
                # no column exists for a counter this tool does not know
                continue
            coord = x_coord_map[item]
            # look the type up by name so that the rendering stays correct
            # even if the counters arrive in a different order
            typ = MAIN_WINDOW_TOP_LINE_METRICS[item]
            hlen = min(coord[1] - len(ITEMS_PAD), remaining_hlen)
            remaining_hlen = max(remaining_hlen - coord[1], 0)
            m = metrics[cidx]
            if item.lower() in valid_metrics:
                if typ == MetricType.METRIC_TYPE_PERCENTAGE:
                    text = f'{calc_perc(m)}'
                elif typ == MetricType.METRIC_TYPE_LATENCY:
                    text = f'{calc_lat(m)}'
                else:
                    # display 0th element from metric tuple
                    text = f'{m[0]}'
            else:
                text = "N/A"
            self.mainw.addnstr(y_coord, coord[0], text, hlen)

            if remaining_hlen == 0:
                return

        for item in MAIN_WINDOW_TOP_LINE_ITEMS_END:
            coord = x_coord_map[item]
            # always place the FS_TOP_MAIN_WINDOW_COL_MNTPT_HOST_ADDR in the
            # last, it will be a very long string to display; it may use
            # all of the room that is left in the row
            if item == FS_TOP_MAIN_WINDOW_COL_MNTPT_HOST_ADDR:
                if FSTop.has_metrics(client_meta, [CLIENT_METADATA_MOUNT_POINT_KEY,
                                                   CLIENT_METADATA_HOSTNAME_KEY,
                                                   CLIENT_METADATA_IP_KEY]):
                    text = (f'{client_meta[CLIENT_METADATA_MOUNT_POINT_KEY]}@'
                            f'{client_meta[CLIENT_METADATA_HOSTNAME_KEY]}/'
                            f'{client_meta[CLIENT_METADATA_IP_KEY]}')
                else:
                    text = "N/A"
                self.mainw.addnstr(y_coord, coord[0], text, remaining_hlen)
            remaining_hlen = max(remaining_hlen - coord[1], 0)
            if remaining_hlen == 0:
                return

    def refresh_clients(self, x_coord_map, stats_json):
        """Draw one row per client, limited to the rows of the main window."""
        counters = [m.upper() for m in stats_json[GLOBAL_COUNTERS_KEY]]
        # writing below the last line of the window raises curses.error,
        # so stop once the window is full
        max_rows = self.mainw.getmaxyx()[0]
        for y_coord, (client_id, metrics) in enumerate(stats_json[GLOBAL_METRICS_KEY].items()):
            if y_coord >= max_rows:
                break
            self.refresh_client(client_id,
                                metrics,
                                counters,
                                stats_json[CLIENT_METADATA_KEY][client_id],
                                x_coord_map,
                                y_coord)

    def refresh_main_window(self, x_coord_map, stats_json):
        """Redraw the client rows; no-op without a main window or layout."""
        if self.mainw is None or x_coord_map is None:
            return
        self.refresh_clients(x_coord_map, stats_json)
        self.mainw.refresh()

    def refresh_header(self, stats_json):
        """Redraw the two header lines.

        :return: True if the stats version is supported and the rest of the
                 screen should be drawn, False otherwise (a mismatch notice
                 is shown instead of the header).
        """
        hlen = self.width - 2
        if stats_json['version'] != FS_TOP_SUPPORTED_VER:
            self.header.addnstr(0, 0, 'perf stats version mismatch!', hlen)
            # push the notice to the terminal; nothing else refreshes
            # this window on this path
            self.header.refresh()
            return False
        client_metadata = stats_json[CLIENT_METADATA_KEY]
        num_clients = len(client_metadata)
        num_mounts = len([client for client, metadata in client_metadata.items() if
                          CLIENT_METADATA_MOUNT_POINT_KEY in metadata
                          and metadata[CLIENT_METADATA_MOUNT_POINT_KEY] != 'N/A'])
        num_kclients = len([client for client, metadata in client_metadata.items() if
                            "kernel_version" in metadata])
        num_libs = num_clients - (num_mounts + num_kclients)
        now = datetime.now().ctime()
        self.header.addnstr(0, 0,
                            FS_TOP_VERSION_HEADER_FMT.format(prog_name=FS_TOP_PROG_STR, now=now),
                            hlen, curses.A_STANDOUT | curses.A_BOLD)
        self.header.addnstr(1, 0, FS_TOP_CLIENT_HEADER_FMT.format(num_clients=num_clients,
                                                                  num_mounts=num_mounts,
                                                                  num_kclients=num_kclients,
                                                                  num_libs=num_libs), hlen)
        self.header.refresh()
        return True

    def run_display(self):
        """Redraw the screen once per refresh interval until asked to exit."""
        while not self.exit_ev.is_set():
            # use stdscr.clear() instead of clearing each window
            # to avoid screen blinking.
            self.stdscr.clear()
            self.refresh_window_size()
            # too small to hold even the two header lines: wait and re-check
            if self.height <= 2 or self.width <= 2:
                self.exit_ev.wait(timeout=self.refresh_interval_secs)
                continue

            # coordinate constants for windowing -- (height, width, y, x)
            # NOTE: requires initscr() call before accessing COLS, LINES.
            try:
                header_window_coord = (2, self.width - 1, 0, 0)
                self.header = curses.newwin(*header_window_coord)
                # the top line lives on row 3, so at least 4 rows are needed
                if self.height >= 4:
                    topline_window_coord = (1, self.width - 1, 3, 0)
                    self.topl = curses.newwin(*topline_window_coord)
                else:
                    self.topl = None
                if self.height >= 5:
                    main_window_coord = (self.height - 4, self.width - 1, 4, 0)
                    self.mainw = curses.newwin(*main_window_coord)
                else:
                    self.mainw = None
            except curses.error:
                # this may happen when creating the sub windows the
                # terminal window size changed, just retry it
                continue

            stats_json = self.perf_stats_query()
            try:
                if self.refresh_header(stats_json):
                    x_coord_map = self.refresh_top_line_and_build_coord()
                    self.refresh_main_window(x_coord_map, stats_json)
            except curses.error:
                # this may happen when addstr the terminal window
                # size changed, redraw on the next pass
                pass
            # wait even after a drawing error so that a persistent failure
            # cannot turn into a busy loop that floods the mgr with queries
            self.exit_ev.wait(timeout=self.refresh_interval_secs)


if __name__ == '__main__':
    def float_greater_than(x):
        """argparse type: parse *x* as a finite float >= MIN_REFRESH_INTERVAL.

        A string that is not a number makes float() raise ValueError, which
        argparse reports as an invalid argument value.
        """
        import math

        value = float(x)
        # NaN compares false against every bound and would slip through a
        # plain '<' test; neither NaN nor inf is a usable wait timeout
        if not math.isfinite(value) or value < MIN_REFRESH_INTERVAL:
            raise argparse.ArgumentTypeError(f'{value} should be a finite number greater '
                                             f'than or equal to {MIN_REFRESH_INTERVAL}')
        return value

    parser = argparse.ArgumentParser(description='Ceph Filesystem top utility')
    parser.add_argument('--cluster', nargs='?', const='ceph', default='ceph',
                        help='Ceph cluster to connect (default: ceph)')
    parser.add_argument('--id', nargs='?', const='fstop', default='fstop',
                        help='Ceph user to use for the connection (default: fstop)')
    parser.add_argument('--conffile', nargs='?', default=None,
                        help='Path to cluster configuration file')
    parser.add_argument('--selftest', dest='selftest', action='store_true',
                        help='run in selftest mode')
    # const makes a bare '-d' fall back to the default interval instead of
    # None, which would turn the refresh wait into an unbounded one
    parser.add_argument('-d', '--delay', nargs='?', const=DEFAULT_REFRESH_INTERVAL,
                        default=DEFAULT_REFRESH_INTERVAL,
                        type=float_greater_than, help='Interval to refresh data '
                        f'(default: {DEFAULT_REFRESH_INTERVAL})')

    args = parser.parse_args()
    err = False
    ft = FSTop(args)
    try:
        ft.init()
        if args.selftest:
            ft.selftest()
            sys.stdout.write("selftest ok\n")
        else:
            # curses.wrapper restores the terminal even if the display
            # loop raises
            curses.wrapper(ft.setup_curses)
    except FSTopException as fst:
        err = True
        sys.stderr.write(f'{fst.get_error_msg()}\n')
    except Exception as e:
        err = True
        sys.stderr.write(f'exception: {e}\n')
    finally:
        # always release the cluster handle, also on failure
        ft.fini()
    sys.exit(0 if not err else -1)
