#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# boltd mocking
#
# Copyright © 2017 Red Hat, Inc
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library. If not, see <http://www.gnu.org/licenses/>.
# Authors:
#       Christian J. Kellner <christian@kellner.me>

from __future__ import print_function

import argparse
import atexit
import os
import readline
import shlex
import shutil
import subprocess
import sys
import tempfile
import time
import uuid

try:
    import gi
    from gi.repository import GLib
    from gi.repository import Gio

    gi.require_version('UMockdev', '1.0')
    from gi.repository import UMockdev

except ImportError as e:
    sys.stderr.write('Missing dependencies: %s\n' % str(e))
    sys.exit(1)

# Bus name and object path used to reach the bolt daemon over D-Bus.
DBUS_NAME = 'org.freedesktop.bolt'
DBUS_PATH = '/org/freedesktop/bolt'
# D-Bus interface names; all of them share the same prefix.
DBUS_IFACE_PREFIX = 'org.freedesktop.bolt1.'
DBUS_IFACE_MANAGER = DBUS_IFACE_PREFIX + 'Manager'
DBUS_IFACE_DEVICE = DBUS_IFACE_PREFIX + 'Device'
DBUS_IFACE_DOMAIN = DBUS_IFACE_PREFIX + 'Domain'
# Service file whose `Exec=` line names the daemon binary; consulted when
# boltd is taken from the system installation.
SERVICE_FILE = '/usr/share/dbus-1/system-services/org.freedesktop.bolt.service'


class SysfsDev:
    """Base class for objects that are backed by a sysfs directory.

    Sub-classes list the attribute file names they want exposed in
    ``known_attrs`` and get decorated with :meth:`bridge`, which installs
    one read-only property per listed name.

    ``sysname`` and ``syspath`` are plain attributes; a ``syspath`` of
    ``None`` marks the object as disconnected.
    """
    known_attrs = []

    def __init__(self, sysname, syspath):
        self.sysname = sysname
        self.syspath = syspath

    @property
    def is_connected(self):
        """True as long as the object has a sysfs path."""
        return self.syspath is not None

    def sysattr_path(self, name):
        """Return the path of the attribute file `name` below `syspath`.

        Only meaningful while connected; joining against a `syspath` of
        ``None`` raises ``TypeError``.
        """
        return os.path.join(self.syspath, name)

    def have_sysattr(self, name):
        """Return True if the attribute file `name` exists.

        A disconnected object has no attributes, so the answer is False
        instead of an error from building the path.
        """
        if not self.is_connected:
            return False
        return os.path.isfile(self.sysattr_path(name))

    def read_sysattr(self, name):
        """Return the stripped content of attribute `name`.

        Returns ``None`` when the attribute file does not exist or the
        object is disconnected.
        """
        if not self.is_connected:
            return None
        try:
            with open(self.sysattr_path(name), 'r') as fd:
                return fd.read().strip()
        except FileNotFoundError:
            return None

    def dump(self, prefix='', file=sys.stdout):
        """Print a one line summary (name and path) to `file`."""
        if self.sysname is None:
            print('%s disconnected' % prefix, file=file)
        else:
            print('%s%s @ %s' % (prefix, self.sysname, self.syspath), file=file)

    @staticmethod
    def bridge(klass):
        """Class decorator: add a read-only property per `known_attrs` entry.

        Each property reads the attribute of the same name via
        :meth:`read_sysattr`. Returns `klass` itself.
        """
        def install(attr):
            # separate function so every getter closes over its own `attr`
            def getter(self):
                return self.read_sysattr(attr)
            setattr(klass, attr, property(getter))

        for attr in getattr(klass, 'known_attrs', []):
            install(attr)
        return klass


@SysfsDev.bridge
class Domain(SysfsDev):
    """A domain entry in sysfs together with its (optional) host device.

    `index` is the numeric part of the domain name. `serial` is a counter
    handed out via :meth:`gen_serial`; it starts at 100.
    """
    known_attrs = ['boot_acl',
                   'security']

    def __init__(self, index, sysname, syspath):
        super().__init__(sysname, syspath)
        self.host = None
        self.parent = None
        self.index = index
        self.serial = 100

    @property
    def uid(self):
        """The uid of the host device, or the (falsy) host if there is none."""
        host = self.host
        if not host:
            return host
        return host.uid

    def gen_serial(self):
        """Return the current serial number and advance the counter."""
        current, self.serial = self.serial, self.serial + 1
        return current

    def find_device(self, name_or_id):
        """Search the device tree below the host; None without a host."""
        host = self.host
        return None if host is None else host.find_device(name_or_id)

    def show(self, prefix='', file=sys.stdout):
        """Print name, host uid, sysfs location and security level."""
        label = self.sysname or 'domain'
        host_uid = self.host.uid if self.host else ''
        print('%s%s %s' % (prefix, label, host_uid), file=file)
        indent = prefix + '  '
        super().dump(prefix=indent, file=file)
        print('%ssecurity %s' % (indent, self.security or 'unknown'), file=file)


@SysfsDev.bridge
class Device(SysfsDev):
    """A device entry in sysfs, forming a tree through `children`.

    `uid` is kept on the object itself, so it stays available after the
    sysfs entry is gone; everything in `known_attrs` is read from sysfs.
    """
    known_attrs = ['authorized',
                   'boot',
                   'device',
                   'device_name',
                   'generation',
                   'key',
                   'unique_id',
                   'vendor',
                   'vendor_name']

    # attributes a device can be looked up by, see find_device()
    search_fields = ('uid',
                     'sysname',
                     'device_name')

    def __init__(self, uid, sysname, syspath):
        super(Device, self).__init__(sysname, syspath)
        self.uid = uid
        self.domain = None
        self.is_host = False
        self.children = []

    def find_device(self, identifier, fields=search_fields):
        """Depth-first search for a device matching `identifier`.

        A device matches when any attribute named in `fields` equals
        `identifier`. The same `fields` are used for all descendants.
        Returns the matching device or None.
        """
        for field in fields:
            # sysfs backed attributes cannot be read once the device has
            # been removed (syspath is None); skip them so one removed
            # device does not break the search through the whole tree
            if field in self.known_attrs and not self.is_connected:
                continue
            if getattr(self, field) == identifier:
                return self
        for child in self.children:
            found = child.find_device(identifier, fields)
            if found is not None:
                return found
        return None

    def show(self, prefix='', file=sys.stdout):
        """Print uid and sysfs location; for connected non-host devices
        the `authorized` attribute is printed as well."""
        print('%sdevice %s' % (prefix, self.uid), file=file)
        pf = prefix + '  '
        super(Device, self).dump(prefix=pf, file=file)
        if not self.is_connected or self.is_host:
            return
        print('%sauthorized %s' % (pf, self.authorized), file=file)


class MockSysfs:
    """Mocked sysfs tree, built on top of a UMockdev testbed.

    `domains` maps the sysfs name of each domain ('domain0', ...) to its
    `Domain` object. Entries stay in the mapping after removal.
    """

    def __init__(self):
        self.testbed = UMockdev.Testbed.new()
        self.domains = {}

    @property
    def root(self):
        """Root directory of the testbed."""
        return self.testbed.get_root_dir()

    def _gen_domain_id(self):
        """Return the lowest index below 1024 not used by any known domain.

        Raises ValueError when all indices are taken.
        """
        for i in range(1024):
            if ('domain%d' % i) not in self.domains:
                return i
        raise ValueError('too many domains in use')

    def find_domain(self, name_or_id):
        """Return the domain with the given sysfs name or host uid, or None."""
        for name, dom in self.domains.items():
            if name_or_id in (name, dom.uid):
                return dom
        return None

    def find_device(self, name_or_id):
        """Search the device trees of all domains; returns a device or None."""
        for dom in self.domains.values():
            d = dom.find_device(name_or_id)
            if d is not None:
                return d
        return None

    def domain_add(self, security, bootacl, iommu):
        """Add a new domain to the testbed and return its `Domain`.

        `bootacl` may be a ready-made string or an integer N, which is
        expanded to N - 1 commas (presumably N empty slots). Passing
        None for `bootacl` or `iommu` leaves the respective attribute out.
        """
        index = self._gen_domain_id()
        sysname = 'domain%d' % index

        props = ['security', security]

        if isinstance(bootacl, int):
            bootacl = ',' * (bootacl - 1)

        if bootacl is not None:
            props += ['boot_acl', bootacl]

        if iommu is not None:
            props += ['iommu_dma_protection', str(iommu) + '\n']

        syspath = self.testbed.add_device('thunderbolt', sysname, None,
                                          props,
                                          ['DEVTYPE', 'thunderbolt_domain'])
        domain = Domain(index, sysname, syspath)
        self.domains[sysname] = domain
        return domain

    def host_add(self, domain, uid, name, vendor, generation):
        """Add the host device ('<index>-0') below `domain` and link both.

        The host is always created authorized; `generation` is only
        written when it is not None. Returns the new `Device`.
        """
        sysname = "%d-0" % domain.index
        attributes = ['device_name', name,
                      'device', '0x23',
                      'vendor_name', vendor,
                      'vendor', '0x23',
                      'authorized', '1',
                      'unique_id', uid]
        if generation is not None:
            attributes += ['generation', str(generation)]

        syspath = self.testbed.add_device('thunderbolt', sysname, domain.syspath,
                                          attributes,
                                          ['DEVTYPE', 'thunderbolt_device'])
        host = Device(uid, sysname, syspath)
        host.is_host = True
        domain.host = host
        host.domain = domain
        return host

    def device_add(self, parent, uid, name, vendor, authorized, key, boot, gen):
        """Add a device below `parent` and return the new `Device`.

        The sysfs name is '<domain index>-<serial>' with the serial taken
        from the parent's domain. `key`, `boot` and `gen` are optional:
        None leaves the respective attribute out.
        """
        domain = parent.domain
        serial = domain.gen_serial()
        sysname = "%d-%d" % (domain.index, serial)

        props = ['device_name', name,
                 'device', '0x23',
                 'vendor_name', vendor,
                 'vendor', '0x23',
                 'authorized', '%d\n' % authorized,
                 'unique_id', uid]

        if key is not None:
            # The kernel always returns the key with trailing `\n`
            if not key.endswith('\n'):
                key += '\n'
            # the attribute is written no matter whether the newline had
            # to be added or was already there
            props += ['key', key]

        if boot is not None:
            props += ['boot', str(boot)]

        if gen is not None:
            props += ['generation', str(gen)]

        syspath = self.testbed.add_device('thunderbolt',
                                          sysname, parent.syspath,
                                          props,
                                          ['DEVTYPE', 'thunderbolt_device'])

        device = Device(uid, sysname, syspath)
        device.domain = domain
        parent.children.append(device)

        return device

    def remove(self, dev):
        """Send a "remove" uevent for `dev`, drop it from the testbed and
        mark the object as disconnected.

        The object stays referenced by its parent or by `domains`, so the
        method can be handed the same object again; that is a no-op.
        """
        if dev.syspath is None:
            # already disconnected, nothing left to remove
            return
        self.testbed.uevent(dev.syspath, "remove")
        self.testbed.remove_device(dev.syspath)
        dev.syspath = None
        dev.sysname = None


class Store:
    """Location of the daemon's state directory.

    Without an explicit `path` a temporary store is created; such a
    store is flagged `mocked` and deleted again in `__del__`.
    """

    def __init__(self, path=None):
        self.path = path
        self.mocked = False
        if path is None:
            self.make_mock_store()

    def __del__(self):
        # stores handed in from the outside are never deleted
        if not self.mocked:
            return
        shutil.rmtree(self.path)

    def make_mock_store(self):
        """Create a temporary directory with the sub-directories
        'devices', 'domains' and 'keys' and use it as the store."""
        root = tempfile.mkdtemp()
        for sub in ('devices', 'domains', 'keys'):
            os.makedirs(os.path.join(root, sub))
        self.path, self.mocked = root, True


class Daemon:
    """Locates, starts and stops the boltd process under test.

    `paths` holds the locations of the 'daemon' and 'boltctl' binaries,
    `rundir` is a temporary runtime directory that is deleted in
    `__del__`, and `log` is passed as stdout of the daemon process.
    """

    def __init__(self, store, sysfs):
        # Everything __del__ looks at is set up before anything that can
        # fail, so a failed construction does not raise a second time
        # during cleanup.
        self._boltd = None
        self.rundir = None
        self.log = None
        self.store = store
        self.sysfs = sysfs
        self._discover_binary()
        self.rundir = tempfile.mkdtemp()

    def __del__(self):
        if self.is_running:
            self.stop()
        if self.rundir is not None:
            shutil.rmtree(self.rundir)

    @staticmethod
    def path_from_service_file(sf):
        """Return the value of the first `Exec=` line of the file `sf`,
        or None if there is no such line."""
        with open(sf) as f:
            for line in f:
                if not line.startswith('Exec='):
                    continue
                return line.split('=', 1)[1].strip()
        return None

    def _discover_binary(self):
        """Find boltd and boltctl and connect to the system bus.

        Lookup order: $BOLT_BUILD_DIR, a JHBuild prefix, a `build`
        directory below the current working directory and finally the
        system installation. Raises AssertionError if the daemon could
        not be found or boltctl is missing or not executable.
        """
        if 'BOLT_BUILD_DIR' in os.environ:
            print('Using boltd from local build')
            build_dir = os.environ['BOLT_BUILD_DIR']
            boltd = os.path.join(build_dir, 'boltd')
            boltctl = os.path.join(build_dir, 'boltctl')
        elif 'UNDER_JHBUILD' in os.environ:
            print('Using boltd from JHBuild')
            jhbuild_prefix = os.environ['JHBUILD_PREFIX']
            boltd = os.path.join(jhbuild_prefix, 'libexec', 'boltd')
            boltctl = os.path.join(jhbuild_prefix, 'bin', 'boltctl')
        elif os.path.exists(os.path.abspath('build/boltd')):
            build_dir = os.path.abspath('build')
            print('Using boltd from %s' % build_dir)
            boltd = os.path.join(build_dir, 'boltd')
            boltctl = os.path.join(build_dir, 'boltctl')
        else:
            print('Using boltd from system installation')
            boltd = Daemon.path_from_service_file(SERVICE_FILE)
            boltctl = shutil.which('boltctl')

        assert boltd is not None, 'failed to find daemon'
        # shutil.which() yields None for a missing boltctl; check that
        # first so the failure is the intended message, not a TypeError
        assert boltctl is not None, 'failed to find boltctl'
        assert os.access(boltctl, os.X_OK), "could not execute @ " + boltctl

        self.paths = {'daemon': boltd, 'boltctl': boltctl}
        self.dbus = Gio.bus_get_sync(Gio.BusType.SYSTEM, None)

    # dbus helper methods
    def _get_dbus_property(self, name, interface=DBUS_IFACE_MANAGER):
        """Read property `name` of `interface` from the daemon's object.

        The proxy is created with DO_NOT_AUTO_START. Failures surface as
        GLib.GError (start() relies on that while polling).
        """
        proxy = Gio.DBusProxy.new_sync(self.dbus,
                                       Gio.DBusProxyFlags.DO_NOT_AUTO_START,
                                       None,
                                       DBUS_NAME,
                                       DBUS_PATH,
                                       'org.freedesktop.DBus.Properties',
                                       None)
        return proxy.Get('(ss)', interface, name)

    def start(self, args):
        """Launch boltd and wait up to 10 seconds for it to answer on D-Bus.

        `args.verbose` adds `-v` to the daemon's command line. If the
        daemon exits or never answers, a message is printed and the
        object is back in the stopped state afterwards.
        """
        timeout = 10
        env = os.environ.copy()
        env['G_DEBUG'] = 'fatal-criticals'
        env['UMOCKDEV_DIR'] = self.sysfs.root
        env['STATE_DIRECTORY'] = self.store.path
        env['RUNTIME_DIRECTORY'] = self.rundir
        argv = [self.paths['daemon'], '--replace']
        if args.verbose:
            argv += ['-v']
        self._boltd = subprocess.Popen(argv,
                                       env=env,
                                       stdout=self.log,
                                       stderr=subprocess.STDOUT)

        poll_interval = 0.1
        for _ in range(timeout * 10):
            time.sleep(poll_interval)
            if self._boltd.poll() is not None:
                # the process is already gone, no point in waiting on
                print('daemon crashed :(')
                self._boltd = None
                return
            try:
                self._get_dbus_property('Version')
            except GLib.GError:
                continue
            break
        else:
            print('daemon did not start in %d seconds' % timeout)
            # terminate the unresponsive process instead of merely
            # forgetting about it, which would leave it running
            self.stop()
            return

        if self._boltd.poll() is not None:
            print('daemon crashed :(')
            self._boltd = None

    def stop(self):
        """Terminate the daemon (if any) and wait for it to exit."""
        if self._boltd:
            try:
                self._boltd.terminate()
            except OSError:
                pass
            self._boltd.wait()
        self._boltd = None

    @property
    def is_running(self):
        """True while a daemon process is being tracked."""
        return self._boltd is not None


class Command:
    """An interactive command: a handler function plus its argument parser.

    The parser is named after the handler function and described by its
    docstring. Handlers are called as ``handler(obj, args, context)``
    where `context` is a dict with the keys 'parser' (the parser that
    belongs to the executed command) and 'unknown' (left-over arguments
    when `use_parse_known` is set, None otherwise).
    """

    class SubCommand:
        """A sub-command of a `Command`, stored as `func` parser default."""

        def __init__(self, name, desc, parser, handler):
            self.name = name
            self.desc = desc
            self.parser = parser
            self.handler = handler

        def __call__(self, obj, argv, parser=None):
            """Run the sub-command, or print its help on `--help`.

            `argv` is the parsed namespace. The third argument keeps its
            historic name but receives the context dict of the parent
            command (or None); the handler gets a context of the same
            shape with 'parser' set to the sub-command's own parser.
            """
            if argv.show_help:
                print(self.parser.format_help())
                return
            parent = parser if isinstance(parser, dict) else {}
            context = {'parser': self.parser,
                       'unknown': parent.get('unknown')}
            self.handler(obj, argv, context)

    def __init__(self, handler):
        self.handler = handler
        self.name = handler.__name__
        self.desc = handler.__doc__ or self.name
        self.parser = argparse.ArgumentParser(prog=self.name,
                                              description=self.desc,
                                              add_help=False)
        self.parser.add_argument('--help', dest='show_help', action='store_true')
        self.subparsers = None
        self.use_parse_known = False

    def __call__(self, obj, argv):
        """Parse the argument list `argv` and run the matching handler."""
        if self.use_parse_known:
            args, rest = self.parser.parse_known_args(argv)
        else:
            args = self.parser.parse_args(argv)
            rest = None
        context = {'parser': self.parser, 'unknown': rest}
        if hasattr(args, 'func'):
            # A sub-command was selected. It has to be asked first, as
            # its `--help` sets the very same `show_help` flag and must
            # show the help of the sub-command, not of this command.
            args.func(obj, args, context)
        elif args.show_help:
            print(self.parser.format_help())
        else:
            self.handler(obj, args, context)

    def subcommand(self, handler):
        """Decorator: register `handler` as sub-command of this command.

        The function must be named '<command>_<subcommand>', otherwise
        ValueError is raised. Returns the new `SubCommand`.
        """
        if self.subparsers is None:
            self.subparsers = self.parser.add_subparsers()
        name = handler.__name__
        if not name.startswith(self.name + '_'):
            raise ValueError('Invalid naming scheme')
        name = name[len(self.name + '_'):]
        desc = handler.__doc__ or self.name
        p = self.subparsers.add_parser(name, help=desc, add_help=False)
        p.add_argument('--help', dest='show_help', action='store_true')
        cmd = Command.SubCommand(name, desc, p, handler)
        p.set_defaults(func=cmd)
        return cmd

    @staticmethod
    def subcommand_dispatch(obj, args):
        """Run the sub-command selected in the parsed namespace `args`."""
        args.func(obj, args, None)


def command(func):
    """Decorator that turns the handler `func` into a `Command`."""
    return Command(func)


def arg(name_or_flags, *args, **kwargs):
    """Decorator factory that adds one argument to a command's parser.

    All parameters are handed to `add_argument` of the `parser`
    attribute of the decorated object, which is returned unchanged.
    """
    def attach(target):
        target.parser.add_argument(name_or_flags, *args, **kwargs)
        return target

    return attach


def parse_known(func):
    """Decorator that sets `use_parse_known` on the decorated object."""
    setattr(func, 'use_parse_known', True)
    return func


def collect_args(klass):
    """Class decorator that gathers all `Command` members of `klass`.

    Every attribute whose name does not start with a double underscore
    and whose value is a `Command` is stored, keyed by the command's
    name, in the new class attribute `commands`. Returns `klass`.
    """
    registry = {}
    for attr_name in dir(klass):
        if attr_name.startswith('__'):
            continue
        member = getattr(klass, attr_name)
        if isinstance(member, Command):
            registry[member.name] = member
    klass.commands = registry
    return klass


@collect_args
class World:
    """The interactive shell: ties store, mock sysfs and daemon together.

    Every method decorated with `@command` (or registered as a
    sub-command) can be invoked by name from the prompt. The docstrings
    of those methods double as the help texts shown to the user, so
    additional notes on them are kept in comments.
    """

    def __init__(self, store, sysfs, daemon):
        self.store = store
        self.sysfs = sysfs
        self.daemon = daemon

    def handle_line(self, line):
        """Execute one line of input.

        Empty lines and comments (everything from a word starting with
        '#') are ignored. Returns False if the shell should terminate
        ('exit'), True otherwise.
        """
        line = line.strip()
        if not line or line.startswith('#'):
            return True
        if line == 'exit':
            return False

        try:
            argv = shlex.split(line)
        except ValueError as err:
            # e.g. unbalanced quotes; must not take down the shell
            print('parse error: %s' % err)
            return True

        for i, word in enumerate(argv):
            if word.startswith('#'):
                argv = argv[:i]
                break

        if not argv:
            # nothing but a comment was left on the line
            return True

        cmd = self.commands.get(argv[0], None)
        if cmd is None:
            print('unknown command')
            return True
        try:
            cmd(self, argv[1:])
        except SystemExit:
            # raised by argparse on invalid arguments; the error has
            # been reported already, just carry on with the next line
            pass
        return True

    def loop(self):
        """Read and execute lines from the prompt until 'exit' or EOF."""
        do_loop = True
        while do_loop:
            try:
                line = input('> ')
            except EOFError:
                print('Bye.')
                break

            do_loop = self.handle_line(line)

    # Lists all top-level commands with their description.
    @command
    def help(self, args, context):
        for name, cmd in self.commands.items():
            print('%s - %s ' % (name, cmd.desc))

    @arg('--verbose', action='store_true')
    @command
    def start(self, args, context):
        '''start the bolt daemon'''
        if self.daemon.is_running:
            print('boltd already running')
        else:
            print('starting boltd')
            self.daemon.start(args)

    @command
    def stop(self, args, context):
        '''stops the bolt daemon'''
        if not self.daemon.is_running:
            print('boltd not running')
        else:
            print('stopping boltd')
            self.daemon.stop()

    @command
    def status(self, args, context):
        '''Show overall status'''
        print('sysfs: %s' % self.sysfs.root)
        print('store: %s' % self.store.path)
        print('boltd: %s' % self.daemon.paths['daemon'])
        print('rundir: %s' % self.daemon.rundir)
        print('boltd %s' % ('running' if self.daemon.is_running else 'stopped'))

    # Every line of the file is echoed and then executed like typed input.
    @arg('file', type=argparse.FileType('r'), default=None)
    @command
    def load(self, args, context):
        '''Load and execute a commands from a file'''
        try:
            for line in args.file:
                line = line.strip()
                print('@ %s' % line)
                self.handle_line(line)
        finally:
            # argparse hands out sys.stdin for '-'; that one must stay
            # open for the prompt, every other file is ours to close
            if args.file is not sys.stdin:
                args.file.close()

    # Without a sub-command only the help text is shown.
    @command
    def controller(self, args, context):
        '''control controllers. hah!'''
        parser = context['parser']
        print(parser.format_help())

    @arg('--security', type=str, default='secure')
    @arg('--uuid', type=str, default=None)
    @arg('--iommu', type=str, choices=[None, '0', '1'], default=None)
    @arg('--bootacl', type=str, default=None)
    @arg('--generation', type=int, default=3)
    @arg('--vendor', type=str, default='GNOME.org')
    @arg('--name', type=str, default='Laptop')
    @controller.subcommand
    def controller_new(self, args, context):
        '''create a new domain+host combination'''
        security = args.security
        bootacl = args.bootacl
        iommu = args.iommu
        gen = args.generation
        # a random uuid is used unless one was given explicitly
        uid = args.uuid or str(uuid.uuid4())
        print(uid)

        domain = self.sysfs.domain_add(security, bootacl, iommu)
        domain.show()

        name = args.name
        vendor = args.vendor

        host = self.sysfs.host_add(domain, uid, name, vendor, gen)
        host.show()

    @arg('id', type=str, help='name or uuid')
    @controller.subcommand
    def controller_rm(self, args, context):
        '''removes a new domain+host combination'''
        target = args.id
        domain = self.sysfs.find_domain(target)
        if domain is None:
            print('Could not find controller')
            return
        # the host goes first, then the domain itself
        if domain.host is not None:
            self.sysfs.remove(domain.host)
        self.sysfs.remove(domain)

    # Without a sub-command only the help text is shown. Top-level
    # handlers receive the context dict, the parser is one of its entries.
    @command
    def device(self, args, context):
        '''control devices'''
        parser = context['parser']
        print(parser.format_help())

    @arg('--key', type=str, default=None)
    @arg('--boot', type=int, choices=[None, 0, 1], default=None)
    @arg('--authorized', type=int, choices=[0, 1], default=0)
    @arg('--uuid', type=str, default=None)
    @arg('--vendor', type=str, default='GNOME.org')
    @arg('--generation', type=int, default=3)
    @arg('name', type=str)
    @arg('parent', type=str)
    @device.subcommand
    def device_new(self, args, context):
        '''create a new device combination'''
        parent = self.sysfs.find_device(args.parent)
        if parent is None:
            print('unknown parent')
            return

        # a random uuid is used unless one was given explicitly
        uid = args.uuid or str(uuid.uuid4())

        device = self.sysfs.device_add(parent,
                                       uid,
                                       args.name,
                                       args.vendor,
                                       args.authorized,
                                       args.key,
                                       args.boot,
                                       args.generation)

        device.show()

    @arg('id', type=str, help='name or uuid')
    @device.subcommand
    def device_rm(self, args, context):
        '''removes a device'''
        device = self.sysfs.find_device(args.id)
        if device is None:
            print('unknown device')
            return
        print('disconnecting %s' % device.uid)
        self.sysfs.remove(device)

    # All arguments the parser does not know are passed on to boltctl.
    @parse_known
    @command
    def boltctl(self, args, context):
        '''Invoke boltctl'''
        boltctl = self.daemon.paths['boltctl']
        bc_args = context['unknown'] or []
        argv = [boltctl] + bc_args
        subprocess.call(argv)


def main():
    """Set up readline history, create the mock world and run the shell."""
    readline.set_history_length(1000)
    histdir = os.path.join(os.path.expanduser("~"), ".cache")
    histfile = os.path.join(histdir, "bolt_mock")

    try:
        readline.read_history_file(histfile)
    except FileNotFoundError:
        pass

    # The history is written at exit, which needs the directory to
    # exist. Creating it is best effort: history is a convenience and
    # must not prevent the tool from starting.
    try:
        os.makedirs(histdir, exist_ok=True)
    except OSError:
        pass

    atexit.register(readline.write_history_file, histfile)

    store = Store()
    sysfs = MockSysfs()
    daemon = Daemon(store, sysfs)

    ctrl = World(store, sysfs, daemon)
    try:
        ctrl.loop()
    finally:
        # do not leave a daemon behind, whatever ended the shell
        daemon.stop()


if __name__ == '__main__':
    # Re-execute ourselves via umockdev-wrapper unless LD_PRELOAD already
    # mentions umockdev.
    if 'umockdev' not in os.environ.get('LD_PRELOAD', ''):
        wrapped = ['umockdev-wrapper'] + sys.argv
        try:
            os.execvp(wrapped[0], wrapped)
        except OSError as e:
            # report a missing wrapper like the other missing dependencies
            sys.stderr.write('Failed to run %s: %s\n' % (wrapped[0], e))
            sys.exit(1)

    main()
