#!/usr/bin/python3
#
# bolt integration test suite
#
# Copyright © 2017 Red Hat, Inc
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library. If not, see <http://www.gnu.org/licenses/>.
# Authors:
#       Christian J. Kellner <christian@kellner.me>

import binascii
import os
import shutil
import sys
import subprocess
import unittest
import uuid
import tempfile
import time

from collections import namedtuple
from functools import reduce
from itertools import chain

try:
    import gi
    from gi.repository import GLib
    from gi.repository import Gio
    gi.require_version('UMockdev', '1.0')
    from gi.repository import UMockdev

    import dbus
    import dbusmock
except ImportError as e:
    sys.stderr.write('Skipping integration test due to missing depdendencies: %s\n' % str(e))
    sys.exit(1)


try:
    from subprocess import DEVNULL
except ImportError:
    DEVNULL = open(os.devnull, 'wb')

# Bus name and object path under which the bolt daemon (boltd) is reached.
DBUS_NAME = 'org.freedesktop.bolt'
DBUS_PATH = '/org/freedesktop/bolt'
# D-Bus interface names of the manager object and of the device objects.
DBUS_IFACE_PREFIX = 'org.freedesktop.bolt1.'
DBUS_IFACE_MANAGER = DBUS_IFACE_PREFIX + 'Manager'
DBUS_IFACE_DEVICE = DBUS_IFACE_PREFIX + 'Device'
# D-Bus service file of an installed boltd; its Exec= line is used to locate
# the daemon binary when neither a build dir nor JHBuild is in use.
SERVICE_FILE = '/usr/share/dbus-1/system-services/org.freedesktop.bolt.service'


def get_timeout(topic='default'):
    """Return the timeout, in seconds, to use for *topic*.

    Running under valgrind slows everything down a lot, so a separate and
    more generous table is used whenever $VALGRIND is set (to any value).

    :param topic: 'default' or 'daemon_start'
    :raises ValueError: if *topic* is not a known timeout topic
    """
    under_valgrind = os.getenv('VALGRIND') is not None
    if under_valgrind:
        table = {'default': 20, 'daemon_start': 60}
    else:
        table = {'default': 3, 'daemon_start': 5}

    try:
        return table[topic]
    except KeyError:
        raise ValueError('invalid topic') from None


class Signal(object):
    """A minimal multi-callback signal, optionally bridged to a GObject signal.

    Callbacks are added with :meth:`connect` (or ``+=``) and removed with
    :meth:`disconnect` (or ``-=``). :meth:`emit` (or calling the instance)
    invokes every callback and returns True if any of them returned a
    truthy value.

    A signal can be *bridged* to a signal of another object, i.e. anything
    with GObject style ``connect(name, cb) -> id`` / ``disconnect(id)``.
    The underlying connection is only held while at least one callback is
    connected.
    """

    def __init__(self, name):
        self.name = name
        self.callbacks = set()
        # optional observer, called as notify(signal, 'connect'|'disconnect', count)
        self.notify = None
        self._bridge = None

    def connect(self, callback):
        """Add *callback*; the first callback makes the bridge (if any) go live."""
        self.callbacks.add(callback)
        if self.notify is not None:
            self.notify(self, 'connect', len(self.callbacks))
        if len(self.callbacks) == 1:
            self._bridge_build()

    def disconnect(self, callback):
        """Remove *callback* (KeyError if it is not connected)."""
        self.callbacks.remove(callback)
        if self.notify is not None:
            self.notify(self, 'disconnect', len(self.callbacks))
        if not self.callbacks:
            self._bridge_destroy()

    def disconnect_all(self):
        """Remove every callback and tear down the bridge connection, if built."""
        self.callbacks = set()
        if self.notify is not None:
            self.notify(self, 'disconnect', 0)
        self._bridge_destroy()

    def emit(self, *args, **kwargs):
        """Call every callback; return True if any returned a truthy value."""
        # Iterate over a snapshot so callbacks may (dis)connect while we emit.
        res = [cb(*args, **kwargs) for cb in list(self.callbacks)]
        return any(res)

    def bridge(self, obj, name, callback):
        """Forward signal *name* of *obj* to this signal.

        *callback*, if given, is a filter called as ``filter(args, kwargs)``
        that returns ``(forward, args, kwargs)``.

        :raises ValueError: if the signal is already bridged
        """
        if self._bridge is not None:
            raise ValueError('already bridged')
        self._bridge = {'object': obj,
                        'name': name}
        if callback is not None:
            self._bridge['filter'] = callback
        # connect() only builds the bridge for the first callback, so if
        # callbacks are already present the bridge has to go live now.
        if self.callbacks:
            self._bridge_build()

    def bridge_destroy(self):
        """Disconnect from the bridged object (if connected) and drop the bridge."""
        self._bridge_destroy()
        self._bridge = None

    # Backwards compatible alias for the historical misspelling.
    birdge_destroy = bridge_destroy

    def _bridge_build(self):
        b = self._bridge
        # Nothing to do without a bridge, or if it is already connected.
        if b is None or 'signal_id' in b:
            return
        b['signal_id'] = b['object'].connect(b['name'], self._bridge_signal)

    def _bridge_destroy(self):
        b = self._bridge
        # The bridge may exist without ever having been connected.
        if b is None or 'signal_id' not in b:
            return
        b['object'].disconnect(b.pop('signal_id'))

    # Backwards compatible alias for the historical misspelling.
    _bridge_destory = _bridge_destroy

    def _bridge_signal(self, *args, **kwargs):
        b = self._bridge
        if b is None:
            # bridge was removed while a dispatch was still pending
            return None
        filt = b.get('filter')
        if filt is not None:
            res, args, kwargs = filt(args, kwargs)
            if not res:
                return None
        return self.emit(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.emit(*args, **kwargs)

    def __iadd__(self, callback):
        self.connect(callback)
        return self

    def __isub__(self, callback):
        self.disconnect(callback)
        return self

    @staticmethod
    def enable(klass):
        """Class decorator: install a per-instance Signal for each name in ``klass.signals``.

        Each listed name becomes a property returning a lazily created
        Signal stored in the instance's ``__signals`` dict. The property
        setter ignores the assigned value so that ``obj.sig += cb`` works.
        Afterwards ``klass.signals`` lists the installed signals plus those
        declared on the direct base classes.
        """
        lst = getattr(klass, 'signals', [])
        methods = [m for m in dir(klass) if not m.startswith('__')]

        def install(l):
            if l is None:
                return
            if l in methods:
                print('WARNING: signal "%s" will overwrite method' % l, file=sys.stderr)

            def get_signals(self):
                signals = getattr(self, '__signals', None)
                if signals is None:
                    signals = {}
                    setattr(self, '__signals', signals)
                return signals

            def get_signal(self):
                signals = get_signals(self)
                if l not in signals:
                    signals[l] = Signal(l)
                return signals[l]

            def getter(self):
                return get_signal(self)

            def setter(self, value):
                # needed for 'obj.sig += cb', which re-assigns the attribute
                return get_signal(self)

            p = property(getter, setter)
            setattr(klass, l, p)
            return l

        bases = klass.__bases__
        ps = {s for b in bases for s in getattr(b, 'signals', [])}
        installed = {install(l) for l in lst}
        # install() returns None for None entries; never advertise those
        installed.discard(None)
        klass.signals = list(ps.union(installed))
        return klass


@Signal.enable
class Recorder(object):
    """Record property changes and D-Bus signals of a proxy wrapper.

    Every recorded item is appended to ``events`` as an :attr:`Event` and
    re-emitted through the ``event`` signal. Usable as a context manager;
    recording stops on exit.
    """
    # what: 'property' or 'signal'; name: property or signal name;
    # details: new property value (None for signals); time: time.time()
    Event = namedtuple('Event', ['what', 'name', 'details', 'time'])

    signals = ['event']

    def __init__(self, target):
        self.recording = True
        self.events = []
        self.target = target
        self.target.g_properties_changed += self._on_props_changed
        self.target.g_signal += self._on_signal

    def close(self):
        """Stop recording (safe to call repeatedly) and return the recorded events."""
        if self.recording:
            self.target.g_properties_changed -= self._on_props_changed
            self.target.g_signal -= self._on_signal
            self.recording = False
        return self.events

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def _on_props_changed(self, props):
        # one event per changed property, all with the same timestamp
        now = time.time()
        for k, v in props.items():
            e = self.Event('property', k, v, now)
            self._add_event(e)

    def _on_signal(self, proxy, sender, signal, params):
        now = time.time()
        e = self.Event('signal', signal, None, now)
        self._add_event(e)

    def _add_event(self, event):
        self.events.append(event)
        self.event.emit(event)

    @staticmethod
    def event_match(e, target):
        """Return True if event *e* matches the pattern *target*.

        ``what`` and ``name`` must be equal; a pattern whose ``details`` is
        None matches any details.
        """
        if e.what != target.what or e.name != target.name:
            return False
        if target.details is None:
            return True
        return e.details == target.details

    @staticmethod
    def events_list_has_event(events, target):
        """Return the events in *events* that match the pattern *target*."""
        return list(filter(lambda x: Recorder.event_match(x, target), events))

    @staticmethod
    def events_list_contains(events, what, name, details=None):
        """Return the events in *events* matching (*what*, *name*, *details*)."""
        target = Recorder.Event(what, name, details, None)
        return list(filter(lambda x: Recorder.event_match(x, target), events))

    def events_filter(self, e):
        return self.events_list_has_event(self.events, e)

    def have_event(self, e):
        return len(self.events_filter(e)) > 0

    def wait_for_events(self, lst, timeout=None):
        """Run a main loop until every pattern in *lst* has been seen.

        Only events recorded after this call are considered. Matched
        patterns are removed from *lst* in place.

        :param lst: list of Event patterns to wait for
        :param timeout: seconds (int or float); defaults to get_timeout()
        :returns: True if all patterns were matched, False on timeout
        """
        if not lst:
            return True

        loop = GLib.MainLoop()
        timeout_id = None

        def got_event(event):
            for idx, pattern in enumerate(lst):
                # Argument order matters: the *pattern* is the one whose
                # details=None acts as a wildcard.
                if self.event_match(event, pattern):
                    del lst[idx]
                    break
            if len(lst) == 0:
                loop.quit()

        def got_timeout():
            nonlocal timeout_id
            timeout_id = None  # source is removed by returning False
            print('WARNING: timeout reached! Waiting list: ',
                  str(lst), file=sys.stderr)
            loop.quit()
            return False

        self.event += got_event
        timeout = timeout or get_timeout()
        timeout_id = GLib.timeout_add(int(timeout * 1000), got_timeout)
        try:
            loop.run()
        finally:
            self.event -= got_event
            # If we finished early, drop the timeout so it cannot fire
            # (and quit some unrelated loop) later on.
            if timeout_id is not None:
                GLib.source_remove(timeout_id)
        return len(lst) == 0

    def wait_for_event(self, what, name, details=None):
        """Return True once a matching event was recorded (now or within the timeout)."""
        t = self.Event(what, name, details, None)
        if self.have_event(t):
            return True
        return self.wait_for_events([t])


@Signal.enable
class ProxyWrapper(object):
    """Convenience wrapper around a Gio.DBusProxy for an object of the bolt daemon.

    Unknown attributes are forwarded to the proxy. The Python-style name is
    converted to D-Bus CamelCase (``sysfs_path`` -> ``SysfsPath``,
    ``version`` -> ``Version``). A cached D-Bus property with that name is
    returned unpacked; anything else is looked up on the proxy itself
    (e.g. methods such as ``ListDevices``).

    Signals: ``g_properties_changed(changed_properties_dict)`` and
    ``g_signal(proxy, sender, signal, params)``, bridged from the proxy.
    """
    signals = ['g_properties_changed', 'g_signal']

    def __init__(self, bus, iname, path):
        self._proxy = Gio.DBusProxy.new_sync(bus,
                                             Gio.DBusProxyFlags.DO_NOT_AUTO_START,
                                             None,
                                             DBUS_NAME,
                                             path,
                                             iname,
                                             None)

        def props_changed(args, kwargs):
            # args: (proxy, changed_properties, invalidated_properties);
            # only forward the changed properties as a plain dict
            return True, [args[1].unpack()], {}

        self.g_properties_changed.bridge(self._proxy,
                                         'g-properties-changed',
                                         props_changed)

        self.g_signal.bridge(self._proxy, 'g-signal', None)

    def __getattr__(self, name):
        # Never forward private names; this also keeps a lookup of '_proxy'
        # before __init__ has set it from recursing.
        if name.startswith('_'):
            raise AttributeError(name)
        if '_' in name:
            name = "".join(x.title() for x in name.split('_'))
        else:
            name = name[0].upper() + name[1:]
        # get_cached_property_names() returns None (NULL) when the proxy
        # has no cached properties, e.g. while the daemon is not running.
        cached = self._proxy.get_cached_property_names() or []
        if name in cached:
            value = self._proxy.get_cached_property(name)
            return value.unpack() if value is not None else None

        return getattr(self._proxy, name)

    def record(self):
        """Return a Recorder that records this object's properties and signals."""
        return Recorder(self)

    @property
    def object_path(self):
        return self._proxy.get_object_path()


class BoltDevice(ProxyWrapper):
    """Client-side view of a device object exported by the bolt daemon."""

    # status values, mapped from the daemon's 'Status' strings
    UNKNOWN = -1
    DISCONNECTED = 0
    CONNECTED = 1
    CONNECTING = 2
    AUTHORIZING = 3
    AUTH_ERROR = 4
    AUTHORIZED = 5

    # key states, mapped from the daemon's 'Key' strings
    KEY_MISSING = 0
    KEY_HAVE = 1
    KEY_NEW = 2

    # authorization flag bits, parsed from the '|'-separated 'AuthFlags' string
    NOPCIE = 1 << 0
    SECURE = 1 << 1
    NOKEY  = 1 << 2
    BOOT   = 1 << 3

    def __init__(self, bus, path):
        super(BoltDevice, self).__init__(bus,
                                         DBUS_IFACE_DEVICE,
                                         path)

    @property
    def is_connected(self):
        return self.status > self.DISCONNECTED

    @property
    def is_authorized(self):
        return self.status >= self.AUTHORIZED

    def authorize(self, flags=""):
        """Call the daemon's Authorize method with *flags*; returns True."""
        self.Authorize('(s)', flags)
        return True

    @property
    def status(self):
        """The 'Status' property as one of the status constants (UNKNOWN if unrecognised)."""
        res = getattr(self, 'Status')
        mapping = {'unknown': self.UNKNOWN,
                   'disconnected': self.DISCONNECTED,
                   'connecting': self.CONNECTING,
                   'connected': self.CONNECTED,
                   'authorizing': self.AUTHORIZING,
                   'auth-error': self.AUTH_ERROR,
                   'authorized': self.AUTHORIZED}

        return mapping.get(res, self.UNKNOWN)

    @property
    def authflags(self):
        """The 'AuthFlags' property as a bitmask of NOPCIE/SECURE/NOKEY/BOOT.

        Unknown flag names are ignored; a missing or empty value yields 0.
        """
        res = getattr(self, 'AuthFlags')
        if not res:
            return 0
        mapping = {'none': 0,
                   'nopcie': self.NOPCIE,
                   'secure': self.SECURE,
                   'nokey': self.NOKEY,
                   'boot': self.BOOT}

        keys = (x.strip() for x in res.split('|'))
        return reduce(lambda r, x: r | mapping.get(x, 0), keys, 0)

    @property
    def key(self):
        """The 'Key' property as a KEY_* constant; unknown values map to KEY_MISSING."""
        res = getattr(self, 'Key')
        mapping = {'missing': self.KEY_MISSING,
                   'have': self.KEY_HAVE,
                   'new': self.KEY_NEW}

        return mapping.get(res, self.KEY_MISSING)

    @property
    def label(self):
        """The 'Label' property, with an empty label reported as None."""
        res = getattr(self, 'Label')
        if res is not None and len(res) < 1:
            return None
        return res

    @label.setter
    def label(self, value):
        # Set via org.freedesktop.DBus.Properties.Set; str values are wrapped
        # in an 's' variant, anything else is passed through unchanged
        # (presumably already a GLib.Variant).
        if isinstance(value, str):
            value = GLib.Variant("s", value)

        res = self._proxy.call_sync('org.freedesktop.DBus.Properties.Set',
                                    GLib.Variant("(ssv)",
                                                 (DBUS_IFACE_DEVICE,
                                                  "Label",
                                                  value)),
                                    0,
                                    -1,
                                    None)
        return res is not None


@Signal.enable
class BoltClient(ProxyWrapper):
    """Client for the bolt daemon's manager object.

    The daemon's DeviceAdded signal is re-emitted as ``device_added`` with
    a BoltDevice; DeviceRemoved is re-emitted as ``device_removed`` with
    the removed device's object path.
    """
    signals = ['device_added', 'device_removed']

    POLICY_DEFAULT = 'default'
    POLICY_MANUAL = 'manual'
    POLICY_AUTO = 'auto'

    def __init__(self, bus):
        super(BoltClient, self).__init__(bus, DBUS_IFACE_MANAGER, DBUS_PATH)
        self._proxy.connect('g-signal', self._on_dbus_signal)

    def _device(self, object_path):
        """Wrap *object_path* in a BoltDevice on the proxy's connection."""
        return BoltDevice(self._proxy.get_connection(), object_path)

    def _on_dbus_signal(self, proxy, sender, signal, params):
        if signal == 'DeviceAdded':
            self.device_added.emit(self._device(params[0]))
        elif signal == 'DeviceRemoved':
            self.device_removed.emit(params[0])
        else:
            return False
        return True

    def list_devices(self):
        """Return a BoltDevice for every device known to the daemon (None if no reply)."""
        paths = self.ListDevices()
        return None if paths is None else [self._device(p) for p in paths]

    def device_by_uid(self, uid):
        """Return the BoltDevice with unique id *uid* (None if no path was returned)."""
        path = self.DeviceByUid("(s)", uid)
        return None if path is None else self._device(path)

    def enroll(self, uid, policy=POLICY_DEFAULT, flags=""):
        """Enroll device *uid* with *policy* and *flags*; return the BoltDevice."""
        path = self.EnrollDevice("(sss)", uid, policy, flags)
        return None if path is None else self._device(path)

    def forget(self, uid):
        """Ask the daemon to forget device *uid*; returns True."""
        self.ForgetDevice("(s)", uid)
        return True


# Mock Device Tree
@Signal.enable
class Device(object):
    """A node in the mock udev device tree.

    Subclasses list the sysfs attributes (``udev_attrs``) and udev
    properties (``udev_props``) that are written when the device is added
    to a umockdev testbed; each name is read from the lower-cased instance
    attribute of the same name. The root of the tree emits
    ``device_connected`` / ``device_disconnected`` for every node that is
    connected to or disconnected from a testbed.
    """
    subsystem = "unknown"
    udev_attrs = []
    udev_props = []

    signals = ['device_connected',
               'device_disconnected']

    def __init__(self, name, children):
        self._parent = None
        self.children = [self._adopt(c) for c in children]
        self.udev = None
        self.name = name
        self.syspath = None
        # testbed this device is currently connected to (set by connect())
        self.testbed = None

    def _adopt(self, device):
        device.parent = self
        return device

    def _get_own(self, items):
        # flat [name, value, name, value, ...] list, as taken by add_device()
        i = chain.from_iterable([a, str(getattr(self, a.lower()))] for a in items)
        return list(i)

    def collect(self, predicate):
        """Return all devices of this subtree (pre-order) matching *predicate*."""
        found = [self] if predicate(self) else []
        for c in self.children:
            # child results are already filtered; no need to filter again
            found.extend(c.collect(predicate))
        return found

    def first(self, predicate):
        """Return the first device (pre-order) matching *predicate*, or None."""
        if predicate(self):
            return self
        for c in self.children:
            found = c.first(predicate)
            if found is not None:
                return found
        return None

    @property
    def parent(self):
        return self._parent

    @parent.setter
    def parent(self, value):
        self._parent = value

    @property
    def root(self):
        return self if self.parent is None else self.parent.root

    def connect_tree(self, bed):
        """Connect this device and then, recursively, all of its children."""
        self.connect(bed)
        for c in self.children:
            c.connect_tree(bed)

    def connect(self, bed):
        """Add this device (not its children) to *bed* below its parent's syspath."""
        print('connecting ' + self.name, file=sys.stderr)
        assert self.syspath is None
        attributes = self._get_own(self.udev_attrs)
        properties = self._get_own(self.udev_props)
        sysparent = self.parent and self.parent.syspath
        self.syspath = bed.add_device(self.subsystem,
                                      self.name,
                                      sysparent,
                                      attributes,
                                      properties)
        self.root.device_connected(self)
        self.testbed = bed

    def disconnect(self, bed):
        """Remove all children and then this device from *bed*.

        Also resets ``authorized`` to 0 and ``key`` to "".
        """
        print('disconnecting ' + self.name, file=sys.stderr)
        for c in self.children:
            c.disconnect(bed)
        bed.uevent(self.syspath, "remove")
        bed.remove_device(self.syspath)
        self.authorized = 0
        self.key = ""
        self.root.device_disconnected(self)
        self.syspath = None
        self.testbed = None


class TbDevice(Device):
    """A mock thunderbolt device (DEVTYPE thunderbolt_device)."""
    subsystem = "thunderbolt"
    devtype = "thunderbolt_device"

    udev_attrs = ['authorized',
                  'device',
                  'device_name',
                  'key',
                  'unique_id',
                  'vendor',
                  'vendor_name']

    udev_props = ['DEVTYPE']

    def __init__(self, name, authorized=0, vendor=None, uid=None, children=None):
        super(TbDevice, self).__init__(name, children or [])
        self.unique_id = uid or str(uuid.uuid4())
        self.device_name = 'Thunderbolt ' + name
        self.device = self._make_id(self.device_name)
        self.vendor_name = vendor or 'GNOME.org'
        self.vendor = self._make_id(self.vendor_name)
        self.authorized = authorized
        self.key = ""

    def _make_id(self, name):
        # deterministic hex id derived from the CRC32 of the name
        return '0x%X' % binascii.crc32(name.encode('utf-8'))

    @property
    def authorized_file(self):
        """Path of the sysfs 'authorized' file, or None when not connected."""
        if self.syspath is None:
            return None
        return os.path.join(self.syspath, 'authorized')

    @property
    def bolt_status(self):
        """The BoltDevice status this mock device should be reported with.

        DISCONNECTED without syspath, CONNECTED for authorized == 0,
        AUTHORIZED for 1 or 2, None for any other value.
        """
        if self.syspath is None:
            return BoltDevice.DISCONNECTED
        elif self.authorized == 0:
            return BoltDevice.CONNECTED
        elif self.authorized in [1, 2]:
            return BoltDevice.AUTHORIZED

    @property
    def bolt_authflags(self):
        """The BoltDevice auth flags this mock device should be reported with."""
        flags = 0
        if self.syspath is None:
            return 0

        if self.authorized == 2:
            flags |= BoltDevice.SECURE

        if self.domain.security == TbDomain.SECURITY_SECURE:
            # NOTE(review): key is "" (not None) after __init__ and
            # reload_auth(), so NOKEY is only expected for a key explicitly
            # set to None -- confirm this matches boltd's semantics.
            if self.key is None:
                flags |= BoltDevice.NOKEY
        elif self.domain.security in [TbDomain.SECURITY_DPONLY,
                                      TbDomain.SECURITY_USBONLY]:
            flags |= BoltDevice.NOPCIE

        return flags

    @property
    def domain(self):
        return self.parent.domain

    @staticmethod
    def is_unauthorized(d):
        return isinstance(d, TbDevice) and d.authorized == 0

    def reload_auth(self):
        """Re-read 'authorized' and 'key' from sysfs; emit a 'change' uevent if either changed."""
        old_authorized, old_key = self.authorized, self.key
        with open(self.authorized_file, 'r') as f:
            self.authorized = int(f.read())
        with open(os.path.join(self.syspath, 'key'), 'r') as f:
            self.key = f.read().strip()
        changed = self.authorized != old_authorized or self.key != old_key
        if changed and self.syspath:
            self.testbed.uevent(self.syspath, 'change')

    def authorize(self, level):
        """Write *level* (str or int) to the 'authorized' file and reload the state."""
        with open(self.authorized_file, 'w') as f:
            f.write(str(level))
        self.reload_auth()


class TbHost(TbDevice):
    """The mock host controller ('Laptop'): authorized, with a fixed unique id."""

    def __init__(self, children):
        super(TbHost, self).__init__(
            'Laptop',
            authorized=1,
            uid='3b7d4bad-4fdf-44ff-8730-ffffdeadbabe',
            children=children,
        )

    def connect(self, bed):
        # Device.disconnect() resets 'authorized' to 0, so restore the
        # host's authorization before every (re)connect.
        self.authorized = 1
        super(TbHost, self).connect(bed)


class TbDomain(Device):
    """Root of a mock thunderbolt tree: a domain with a single TbHost below it."""
    subsystem = "thunderbolt"
    devtype = "thunderbolt_domain"

    udev_attrs = ['security']
    udev_props = ['DEVTYPE']

    SECURITY_NONE = 'none'
    SECURITY_USER = 'user'
    SECURITY_SECURE = 'secure'
    SECURITY_DPONLY = 'dponly'
    SECURITY_USBONLY = 'usbonly'

    def __init__(self, security=SECURITY_SECURE, index=0, host=None):
        assert host
        assert isinstance(host, TbHost)
        super(TbDomain, self).__init__('domain%d' % index, children=[host])
        self.security = security

    @property
    def devices(self):
        """All TbDevice nodes of this domain (host included), pre-order."""
        return self.collect(lambda node: isinstance(node, TbDevice))

    @property
    def domain(self):
        return self

    @staticmethod
    def checkattr(d, k, v):
        """True if *d* has attribute *k* and its value equals *v*."""
        if not hasattr(d, k):
            return False
        return getattr(d, k) == v

    def find(self, **kwargs):
        """Return the first device whose attributes equal all of *kwargs*, or None."""
        def matches(node):
            checks = [self.checkattr(node, attr, wanted)
                      for attr, wanted in kwargs.items()]
            return all(checks)

        return self.first(matches)


class TreeChecker(object):
    """Wait until the daemon has caught up with changes to a mock device tree.

    Connecting or disconnecting a TbDevice in the local (umockdev) tree
    records a pending action for its unique id. The daemon's
    DeviceAdded/DeviceRemoved signals settle these actions again.
    :meth:`sync` spins a GLib main loop until nothing is pending or the
    timeout expires.
    """

    def __init__(self, client, tree):
        self.client = client
        self.tree = tree
        # Set up all bookkeeping state before hooking up the handlers that use it.
        self.remote_devices = {}   # object path -> remote device
        self.actions = {}          # unique id -> 'connected' | 'disconnected'
        self.loop = None
        self._timeout_id = None
        for dev in client.list_devices() or []:
            self._register_device(dev)
        client.device_added += self._on_device_added
        client.device_removed += self._on_device_removed
        tree.device_connected += self._on_udev_connected
        tree.device_disconnected += self._on_udev_disconnected

    # signal handler for local (udev) devices
    def _on_udev_connected(self, dev):
        if not isinstance(dev, TbDevice):
            return

        uid = dev.unique_id
        self.actions[uid] = 'connected'

    def _on_udev_disconnected(self, dev):
        if not isinstance(dev, TbDevice):
            return
        uid = dev.unique_id
        # A still pending action for this device is cancelled instead.
        if uid in self.actions:
            del self.actions[uid]
        else:
            self.actions[uid] = 'disconnected'

    # signal handler for remote devices
    def _on_device_added(self, dev):
        self._register_device(dev)
        self._check_action(dev.uid, 'connected')

    def _on_device_removed(self, object_path):
        dev = self.remote_devices.get(object_path, None)
        if dev is None:
            return
        self._check_action(dev.uid, 'disconnected')
        self._unregister_device(dev)

    # book keeping of remote devices
    def _register_device(self, dev):
        self.remote_devices[dev.object_path] = dev
        return dev

    def _unregister_device(self, dev):
        del self.remote_devices[dev.object_path]

    def _check_action(self, uid, action):
        """Settle the pending *action* for *uid*; return True if it was pending."""
        if self.actions.get(uid, None) != action:
            return False
        del self.actions[uid]
        self._check_sync()
        return True

    def _check_sync(self):
        keep_going = len(self.actions) > 0
        if not keep_going:
            self._stop_loop()
        return keep_going

    def _stop_loop(self):
        if self.loop is None:
            return
        self.loop.quit()
        self.loop = None

    def _on_timeout(self):
        print('WARNING, timeout reached', file=sys.stderr)
        # returning False removes the source, so sync() must not remove it
        self._timeout_id = None
        self._stop_loop()
        return False

    def sync(self, timeout=None):
        """Wait until all pending actions are settled.

        :param timeout: seconds (int or float); defaults to get_timeout()
        :returns: True if nothing is left pending
        """
        if not self.actions:
            return True
        timeout = timeout or get_timeout()
        self.loop = GLib.MainLoop()
        self._timeout_id = GLib.timeout_add(int(timeout * 1000), self._on_timeout)
        self.loop.run()
        self.loop = None
        # If we finished early, drop the timeout so it cannot fire during a
        # later sync() and stop that loop prematurely.
        if self._timeout_id is not None:
            GLib.source_remove(self._timeout_id)
            self._timeout_id = None
        return not self.actions

    def close(self):
        self.client.device_added.disconnect_all()
        self.client.device_removed.disconnect_all()
        self.tree.device_connected.disconnect_all()
        self.tree.device_disconnected.disconnect_all()


# Test Suite
class BoltTest(dbusmock.DBusTestCase):
    @staticmethod
    def path_from_service_file(sf):
        """Return the value of the first ``Exec=`` line of the service file *sf*.

        Returns None if *sf* contains no ``Exec=`` line.
        """
        with open(sf) as f:
            for line in f:
                if line.startswith('Exec='):
                    return line.split('=', 1)[1].strip()
        return None

    @classmethod
    def setUpClass(cls):
        """Locate the daemon binary and bring up a private test D-Bus.

        The daemon is taken from $BOLT_BUILD_DIR, else from $JHBUILD_PREFIX
        when $UNDER_JHBUILD is set, else from the installed service file.
        The private bus address is exported as the *system* bus address, any
        session bus address is dropped, and cls.dbus is connected to it.
        """
        if 'BOLT_BUILD_DIR' in os.environ:
            print('Testing local build')
            path = os.path.join(os.environ['BOLT_BUILD_DIR'], 'boltd')
        elif 'UNDER_JHBUILD' in os.environ:
            print('Testing JHBuild version')
            path = os.path.join(os.environ['JHBUILD_PREFIX'], 'libexec', 'boltd')
        else:
            print('Testing installed system binaries')
            path = BoltTest.path_from_service_file(SERVICE_FILE)

        assert path is not None, 'failed to find daemon'
        cls.paths = {'daemon': path}

        cls.test_bus = Gio.TestDBus.new(Gio.TestDBusFlags.NONE)
        cls.test_bus.up()
        os.environ.pop('DBUS_SESSION_BUS_ADDRESS', None)
        os.environ['DBUS_SYSTEM_BUS_ADDRESS'] = cls.test_bus.get_bus_address()
        cls.dbus = Gio.bus_get_sync(Gio.BusType.SYSTEM, None)

    @classmethod
    def tearDownClass(cls):
        """Shut down the private test bus, then run dbusmock's class teardown."""
        bus = cls.test_bus
        bus.down()
        dbusmock.DBusTestCase.tearDownClass()

    def setUp(self):
        """Create a fresh umockdev testbed and an empty database directory."""
        self.testbed = UMockdev.Testbed.new()
        self.assertTrue(UMockdev.in_mock_environment())

        self.dbpath = tempfile.mkdtemp()
        for sub in ('devices', 'keys'):
            os.makedirs(os.path.join(self.dbpath, sub))

        # client/daemon are set by daemon_start(), polkitd by polkitd_start()
        self.client = self.log = self.daemon = self.polkitd = None
        self.valgrind = False

    def tearDown(self):
        """Stop the daemon and mock polkitd, then remove the per-test state.

        The processes are stopped first, so the daemon is no longer running
        while its database directory and testbed are removed. The cleanup
        in ``finally`` still runs if stopping a process fails.
        """
        try:
            self.daemon_stop()
            self.polkitd_stop()
        finally:
            shutil.rmtree(self.dbpath)
            del self.testbed

    # dbus helper methods
    def get_dbus_property(self, name, interface=DBUS_IFACE_MANAGER):
        """Fetch property *name* of *interface* on the bolt root object.

        Uses org.freedesktop.DBus.Properties.Get without auto-starting the
        service.
        """
        props = Gio.DBusProxy.new_sync(
            self.dbus,
            Gio.DBusProxyFlags.DO_NOT_AUTO_START,
            None,
            DBUS_NAME,
            DBUS_PATH,
            'org.freedesktop.DBus.Properties',
            None,
        )
        return props.Get('(ss)', interface, name)

    # daemon helper
    def daemon_start(self):
        """Start boltd and wait until it answers on the test bus.

        The daemon runs against the umockdev testbed and the per-test
        database directory. If $VALGRIND is set, it runs under valgrind;
        when $VALGRIND names an existing file, that file is used as the
        suppression file. The test fails if the daemon exits, or does not
        answer a 'Version' property query, within
        get_timeout('daemon_start') seconds. On success, self.client is a
        BoltClient on the test bus.
        """
        timeout = get_timeout('daemon_start')  # seconds
        env = os.environ.copy()
        env['G_DEBUG'] = 'fatal-criticals'
        env['UMOCKDEV_DIR'] = self.testbed.get_root_dir()
        env['BOLT_DBPATH'] = self.dbpath
        argv = [self.paths['daemon'], '-v']
        valgrind = os.getenv('VALGRIND')
        if valgrind is not None:
            argv.insert(0, 'valgrind')
            argv.insert(1, '--leak-check=full')
            if os.path.exists(valgrind):
                argv.insert(2, '--suppressions=%s' % valgrind)
            self.valgrind = True
        self.daemon = subprocess.Popen(argv,
                                       env=env,
                                       stdout=self.log,
                                       stderr=subprocess.STDOUT)

        # poll every 100 ms, i.e. 10 attempts per second of timeout
        timeout_count = timeout * 10
        timeout_sleep = 0.1
        while timeout_count > 0:
            time.sleep(timeout_sleep)
            timeout_count -= 1
            rc = self.daemon.poll()
            if rc is not None:
                self.fail('daemon exited during startup (exit code %d)' % rc)
            try:
                self.get_dbus_property('Version')
                break
            except GLib.GError:
                pass
        else:
            self.fail('daemon did not start in %d seconds' % timeout)

        self.client = BoltClient(self.dbus)
        self.assertEqual(self.daemon.poll(), None, 'daemon crashed')

    def daemon_stop(self):
        """Terminate and reap the daemon (if started) and drop the client."""
        proc = self.daemon
        if proc:
            try:
                proc.terminate()
            except OSError:
                # failing to signal the process is not an error here
                pass
            proc.wait()

        self.daemon = None
        self.client = None

    def polkitd_start(self):
        """Spawn dbusmock's polkitd template and keep its mock control interface."""
        proc, obj = self.spawn_server_template('polkitd', {}, stdout=DEVNULL)
        self._polkitd = proc
        self._polkitd_obj = obj
        self.polkitd = dbus.Interface(obj, dbusmock.MOCK_IFACE)

    def polkitd_stop(self):
        """Terminate and reap the mock polkitd, if one was started."""
        if self.polkitd is not None:
            proc = self._polkitd
            proc.terminate()
            proc.wait()
            self.polkitd = None

    def user_config(self, **kwargs):
        """Write *kwargs* as the [config] section of boltd.conf in the db dir.

        Option names keep their case. Values are converted with str(),
        because configparser only accepts string values (so e.g. ints and
        bools can be passed directly). The written file is echoed to stdout.
        """
        import configparser
        cfg = configparser.ConfigParser()
        cfg.optionxform = lambda option: option  # preserve key case

        cfg['config'] = {}
        for k, v in kwargs.items():
            cfg['config'][k] = str(v)

        path = os.path.join(self.dbpath, 'boltd.conf')
        with open(path, 'w') as f:
            cfg.write(f)

        with open(path, 'r') as f:
            print(f.read())

    # mock tree stuff
    def default_mock_tree(self):
        """Return a secure-mode domain whose host has three unauthorized
        children: 'Cable1', 'Cable2' and 'SSD1'."""
        children = [TbDevice(name) for name in ('Cable1', 'Cable2', 'SSD1')]
        return TbDomain(host=TbHost(children))

    def simple_mock_tree(self):
        """Return a secure-mode domain with a single unauthorized 'Dock' below the host."""
        dock = TbDevice('Dock')
        return TbDomain(host=TbHost([dock]))

    def assertDeviceEqual(self, local, remote):
        """Assert that mock device *local* and daemon device *remote* agree.

        Identity, name and vendor are always compared. sysfs path,
        connection state and parent are compared only while *local* is
        connected; status and auth flags are always compared.
        """
        self.assertTrue(local and remote)
        self.assertEqual(local.unique_id, remote.uid)
        self.assertEqual(local.device_name, remote.Name)
        self.assertEqual(local.vendor_name, remote.Vendor)

        connected = local.syspath is not None
        if connected:
            self.assertEqual(local.syspath, remote.sysfs_path)
            self.assertTrue(remote.is_connected)

            # the remote parent is only meaningful while connected
            parent = local.parent
            if isinstance(parent, TbDevice):
                self.assertEqual(parent.unique_id, remote.parent)

        self.assertEqual(local.bolt_status, remote.status)
        self.assertEqual(local.bolt_authflags, remote.authflags)
        return True

    def add_domain_host(self, domain=0, security='secure'):
        """Add a thunderbolt domain and its host directly to the testbed.

        Returns the values returned by testbed.add_device() for the domain
        and for the host.
        """
        domain_path = self.testbed.add_device(
            'thunderbolt', 'domain%d' % domain, None,
            ['security', security],
            ['DEVTYPE', 'thunderbolt_domain'])

        host_attrs = ['device_name', 'Host',
                      'device', '0x23',
                      'vendor_name', 'GNOME.org',
                      'vendor', '0x23',
                      'authorized', '1',
                      'unique_id', str(uuid.uuid4())]
        host_path = self.testbed.add_device(
            'thunderbolt', '%d-0' % domain, domain_path,
            host_attrs,
            ['DEVTYPE', 'thunderbolt_device'])
        return domain_path, host_path

    def add_device(self, parent, devid, name, vendor, domain=0, authorized=1, key='', boot=None):
        """Add a thunderbolt device named '<domain>-<devid>' below *parent*.

        The 'key' and 'boot' attributes are only written when not None.
        Returns the testbed.add_device() result and the generated unique id.
        """
        uid = str(uuid.uuid4())
        attrs = ['device_name', name,
                 'device', '0x23',
                 'vendor_name', vendor,
                 'vendor', '0x23',
                 'authorized', '%d' % authorized,
                 'unique_id', uid]

        for attr, value in (('key', key), ('boot', boot)):
            if value is not None:
                attrs += [attr, value]

        syspath = self.testbed.add_device('thunderbolt',
                                          '%d-%d' % (domain, devid),
                                          parent,
                                          attrs,
                                          ['DEVTYPE', 'thunderbolt_device'])
        return syspath, uid

    def find_device_by_uid(self, lst, uid):
        """Return the single device in *lst* whose uid is *uid*; fail otherwise."""
        matches = [dev for dev in lst if dev.uid == uid]
        self.assertEqual(len(matches), 1)
        return matches[0]

    def store_device(self, dev, policy='auto', key=None):
        """Write a database entry for mock device *dev*, and optionally its key.

        *key* is a literal key string, 'known' for a fixed test key,
        'device' for the device's current key, or None to store no key.
        """
        import configparser
        uid = dev.unique_id

        entry = configparser.ConfigParser()
        entry.optionxform = lambda option: option  # preserve key case
        entry['device'] = {
            'name': dev.device_name,
            'vendor': dev.vendor_name,
            'type': 'peripheral'
        }
        entry['user'] = {
            'storetime': int(time.time()),
            'policy': policy
        }

        with open(os.path.join(self.dbpath, 'devices', uid), 'w') as f:
            entry.write(f)

        known_key = 'a26d5ad55b011df39ae06cae1fd329babfecac3465fe0a8828d6178f88e59083'
        if key == 'known':
            key = known_key
        elif key == 'device':
            key = dev.key

        if key is None:
            return

        with open(os.path.join(self.dbpath, 'keys', uid), 'w') as f:
            f.write(key)


    # the actual tests
    def test_basic(self):
        """Check daemon basics and the device life-cycle with the default tree.

        Starts the daemon and checks the version, the (empty) device list and
        the default policy. Then it connects the default mock tree, compares
        every exported device with its local model, disconnects everything
        again and checks that the device list is empty.
        """
        self.daemon_start()
        # use a unittest assertion instead of a bare 'assert' so the check
        # is not stripped when running under 'python -O'
        version = self.client.version
        self.assertIsNotNone(version)
        d = self.client.list_devices()
        self.assertEqual(len(d), 0)
        policy = self.client.default_policy
        self.assertIn(policy, [self.client.POLICY_AUTO,
                               self.client.POLICY_MANUAL])
        # connect all devices
        tree = self.default_mock_tree()
        chk = TreeChecker(self.client, tree)
        tree.connect_tree(self.testbed)
        chk.sync()  # check for the signals

        devices = self.client.list_devices()
        self.assertEqual(len(devices), len(tree.devices))
        for remote in devices:
            local = tree.find(unique_id=remote.uid)
            self.assertDeviceEqual(local, remote)

        # disconnect all devices again
        tree.disconnect(self.testbed)
        chk.sync()  # check for the signals
        chk.close()

        devices = self.client.list_devices()
        self.assertEqual(len(devices), 0)
        self.daemon_stop()

    def test_signals_on_start(self):
        """Check DeviceAdded is emitted for devices present at daemon start.

        The default mock tree is connected *before* the daemon runs, so its
        un-authorized devices, which are not in the database, have to be
        announced with a DeviceAdded signal once the daemon is up.
        """
        watcher = BoltClient(self.dbus)
        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)

        with watcher.record() as recorder:
            self.daemon_start()
            self.assertTrue(recorder.wait_for_event('signal', 'DeviceAdded', None))
        self.daemon_stop()

    def test_basic_device_name(self):
        """Check name, vendor and label of devices with tricky names.

        Covers a device name that already starts with the vendor name (the
        label does not repeat the vendor), duplicated devices (labels get a
        " #N" suffix) and non-ASCII characters.
        """
        # prepare the basic setup
        _, host = self.add_domain_host()

        devs = [
            # name              vendor           label                         notes
            ['GNOME.org Cable', 'GNOME.org',     'GNOME.org Cable'       ],  # duplicated vendor name
            ['GNOME.org Cable', 'GNOME.org',     'GNOME.org Cable #2'    ],  # duplicated device
            ['GNOME.org Cable', 'GNOME.org',     'GNOME.org Cable #3'    ],  # duplicated device, again
            ['⍾ Laptop',        'Evil Corp. ☢', 'Evil Corp. ☢ ⍾ Laptop'],  # utf-8 chars
        ]

        devs = [{'name': d[0], 'vendor': d[1], 'label': d[2], 'id': i + 1}
                for i, d in enumerate(devs)]

        for d in devs:
            path, uid = self.add_device(host, d['id'], d['name'], d['vendor'])
            d['path'] = path
            d['uid'] = uid

        self.daemon_start()
        devices = self.client.list_devices()
        # all test devices plus the host
        self.assertEqual(len(devices), len(devs) + 1)

        for d in devs:
            remote = self.find_device_by_uid(devices, d['uid'])
            self.assertEqual(remote.name, d['name'])
            self.assertEqual(remote.vendor, d['vendor'])
            self.assertEqual(remote.label, d['label'])

        self.daemon_stop()

    def test_basic_user_config(self):
        """Check that DefaultPolicy from the user config is honored."""
        self.user_config(DefaultPolicy='manual')

        self.daemon_start()
        self.assertEqual(self.client.default_policy,
                         self.client.POLICY_MANUAL)
        self.daemon_stop()

    def test_device_by_uid(self):
        """Check ``device_by_uid`` for invalid uids and for connected devices."""
        self.daemon_start()

        # both an empty and an unknown uid must be rejected
        for bogus in ("", "nonexistant"):
            with self.assertRaises(GLib.GError):
                self.client.device_by_uid(bogus)

        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)

        for local in tree.devices:
            remote = self.client.device_by_uid(local.unique_id)
            self.assertIsNotNone(remote)
            self.assertDeviceEqual(local, remote)

        self.daemon_stop()

    def test_device_authflags(self):
        """Check the authflags reported for devices in different security modes.

        With the domain as created by add_domain_host(), a device authorized
        with a key reports SECURE and a key-less device with the boot flag
        reports NOKEY | BOOT. After switching the domain to 'dponly' and then
        'usbonly', every device reports NOPCIE.
        """
        key = 'b68bce095a13ac39e9254a88b189a38f240487aa6f78f803390a0cdeceb774d8'

        def start_and_list():
            # start boltd and expect the host plus the two docks
            self.daemon_start()
            found = self.client.list_devices()
            self.assertEqual(len(found), 3)
            return found

        def check_flags(found, uid, expected):
            remote = self.find_device_by_uid(found, uid)
            self.assertEqual(remote.authflags, expected)

        dc, host = self.add_domain_host()
        d1, d1_uid = self.add_device(host, 1, "Dock", "GNOME.org", authorized=2, key=key, boot='0')
        d2, d2_uid = self.add_device(host, 2, "Dock2", "GNOME.org", authorized=1, key=None, boot='1')

        found = start_and_list()
        check_flags(found, d1_uid, BoltDevice.SECURE)
        check_flags(found, d2_uid, BoltDevice.NOKEY | BoltDevice.BOOT)
        self.daemon_stop()

        # switch to 'dponly' and re-create both docks without a key
        self.testbed.set_attribute(dc, 'security', 'dponly')
        for stale in (d1, d2):
            self.testbed.remove_device(stale)

        d1, d1_uid = self.add_device(host, 1, "Dock", "GNOME.org", authorized=1, key=None)
        d2, d2_uid = self.add_device(host, 2, "Dock2", "GNOME.org", authorized=1, key=None)

        found = start_and_list()
        check_flags(found, d1_uid, BoltDevice.NOPCIE)
        check_flags(found, d2_uid, BoltDevice.NOPCIE)
        self.daemon_stop()

        # 'usbonly' is expected to be reported the same way
        self.testbed.set_attribute(dc, 'security', 'usbonly')
        found = start_and_list()
        for uid in (d1_uid, d2_uid):
            check_flags(found, uid, BoltDevice.NOPCIE)
        self.daemon_stop()

    def test_device_authorize(self):
        """Check authorizing devices via ``authorize()`` on the device object.

        Without polkit permission every attempt must fail with ACCESS_DENIED.
        Once 'org.freedesktop.bolt.authorize' is allowed, each un-authorized
        device of the default tree gets authorized, and its AuthorizeTime lies
        within the test run. Authorizing an already authorized device must
        fail.
        """
        self.daemon_start()
        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)

        self.polkitd_start()

        to_authorize = tree.collect(TbDevice.is_unauthorized)

        # check that we are not allowed to authorize devices
        for d in to_authorize:
            remote = self.client.device_by_uid(d.unique_id)
            with self.assertRaises(GLib.GError) as cm:
                remote.authorize()
            err = cm.exception
            self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
            self.assertEqual(err.code, int(Gio.DBusError.ACCESS_DENIED))

        self.polkitd.SetAllowed(['org.freedesktop.bolt.authorize'])
        before = int(time.time())
        for d in to_authorize:
            remote = self.client.device_by_uid(d.unique_id)
            # record within a 'with' block so the recorder is closed even
            # when one of the assertions below fails
            with remote.record() as tape:
                remote.authorize()
                d.reload_auth()  # will emit the uevent, so the daemon can update
                res = tape.wait_for_event('property', 'Status', 'authorized')
                self.assertTrue(res)
                self.assertDeviceEqual(d, remote)
                # make sure AuthorizeTime is correct
                now = int(time.time())
                self.assertGreater(remote.AuthorizeTime, 1)
                self.assertGreaterEqual(remote.AuthorizeTime, before)
                self.assertLessEqual(remote.AuthorizeTime, now)

        # authorizing an already authorized device must fail
        for d in to_authorize:
            remote = self.client.device_by_uid(d.unique_id)
            with self.assertRaises(GLib.GError):
                remote.authorize()

        self.daemon_stop()

    def test_device_auto_auth(self):
        """Check automatic authorization of stored devices in SECURE mode.

        Topology: host -> cable1 -> ssd1 and host -> cable2 -> ssd2. cable2
        and both SSDs are stored with a known key; cable1 is stored without a
        key. All use store_device()'s default policy, 'auto'.
        ssd2 is expected to get authorized automatically, while cable1 and
        ssd1 (behind cable1) stay merely connected. After cable1 is
        authorized "manually" (simulating the user), boltd must pick up the
        change and auto-authorize ssd1, which goes through 'authorizing'.
        """
        ssd1 = TbDevice('SSD1',)
        cable1 = TbDevice('Cable1', children=[ssd1])
        ssd2 = TbDevice('SSD2')
        cable2 = TbDevice('Cable2', children=[ssd2])
        tree = TbDomain(security=TbDomain.SECURITY_SECURE,
                        host=TbHost([
                            cable1,
                            cable2,
                        ]))
        tree.connect_tree(self.testbed)

        # populate the database before the daemon starts
        self.store_device(cable1, key=None)
        self.store_device(ssd1, key='known')
        self.store_device(cable2, key='known')
        self.store_device(ssd2, key='known')

        self.daemon_start()

        devices = self.client.list_devices()
        self.assertEqual(len(devices), len(tree.devices))

        remote_ssd2 = self.find_device_by_uid(devices, ssd2.unique_id)
        # poll (10 x 0.1s) for ssd2 to become authorized
        # NOTE(review): this only observes a change if reading `status`
        # reflects daemon updates while we sleep (no main loop is iterated
        # here) — confirm in BoltDevice; a tape recorder may be more reliable
        tries = 0
        while remote_ssd2.status != BoltDevice.AUTHORIZED and tries < 10:
            time.sleep(.1)
            tries += 1

        self.assertEqual(remote_ssd2.status, BoltDevice.AUTHORIZED)

        # cable1 does *NOT* have a key but we are in SECURE mode
        # so it should not be authorized
        remote_c1 = self.find_device_by_uid(devices, cable1.unique_id)
        self.assertEqual(remote_c1.status, BoltDevice.CONNECTED)

        remote_ssd1 = self.find_device_by_uid(devices, ssd1.unique_id)
        self.assertEqual(remote_ssd1.status, BoltDevice.CONNECTED)

        # now we pretend the user has authorized the device manually,
        # to check if boltd picks up udev changes properly and then
        # auto-authorizes also SSD1

        # we start the tape recorder for ssd1 too, so we don't miss its events
        tape_ssd1 = remote_ssd1.record()
        with remote_c1.record() as tape:
            cable1.authorize('1')
            res = tape.wait_for_event('property', 'Status', 'authorized')
            self.assertTrue(res)
        self.assertEqual(remote_c1.status, BoltDevice.AUTHORIZED)

        # ssd1 must have passed through 'authorizing' before 'authorized'
        res = tape_ssd1.wait_for_event('property', 'Status', 'authorized')
        events = tape_ssd1.close()
        self.assertTrue(Recorder.events_list_contains(events, 'property', 'Status', 'authorizing'))
        self.assertTrue(res)
        self.assertEqual(remote_ssd1.status, BoltDevice.AUTHORIZED)
        self.daemon_stop()

    def test_device_auto_import(self):
        """Check which already authorized devices get imported at start-up.

        With the domain in 'user' security mode, all devices are expected to
        show up as authorized. Only the key-less device with the boot flag
        set is expected to end up stored.
        """
        key = 'b68bce095a13ac39e9254a88b189a38f240487aa6f78f803390a0cdeceb774d8'

        devs = [
            {'authorized': 1, 'key': None, 'boot': 0, 'stored': False},  # no boot flag
            {'authorized': 2, 'key': key,  'boot': 0, 'stored': False},  # no boot flag
            {'authorized': 1, 'key': None, 'boot': 1, 'stored': True},   # boot, user mode -> import
            # TODO: check we are not authorizing a device without a key in secure mode
        ]

        dc, host = self.add_domain_host(security='user')

        for did, spec in enumerate(devs, start=1):
            spec['path'], spec['uid'] = self.add_device(host,
                                                        did,
                                                        "Dock%d" % did,
                                                        "GNOME.org",
                                                        authorized=spec['authorized'],
                                                        key=spec['key'],
                                                        boot='%d' % spec['boot'])

        self.daemon_start()
        self.polkitd_start()

        devices = self.client.list_devices()
        # all docks plus the host
        self.assertEqual(len(devices), len(devs) + 1)

        for spec in devs:
            remote = self.find_device_by_uid(devices, spec['uid'])
            self.assertEqual(remote.status, BoltDevice.AUTHORIZED)
            self.assertEqual(remote.stored, spec['stored'])

        self.daemon_stop()

    def test_device_enroll(self):
        """Check enrolling devices via ``client.enroll``.

        Without polkit permission enrolling must fail with ACCESS_DENIED.
        With 'org.freedesktop.bolt.enroll' allowed, enrolling an unknown uid
        must fail. Every un-authorized device of the default tree is then
        stored with policy 'auto' and a new key. After the tree is
        disconnected, the stored devices must still be listed. When they are
        reconnected, they must get authorized again.
        """
        self.daemon_start()
        tree = self.default_mock_tree()
        tree.connect_tree(self.testbed)
        self.polkitd_start()

        client = self.client

        to_enroll = tree.collect(TbDevice.is_unauthorized)

        # check that we are not allowed to enroll devices, i.e. the correct
        # policykit action is called.

        for d in to_enroll:
            with self.assertRaises(GLib.GError) as cm:
                client.enroll(d.unique_id)
            err = cm.exception
            self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
            self.assertEqual(err.code, int(Gio.DBusError.ACCESS_DENIED))

        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        # check we get a proper error for an unknown device. This has to use
        # enroll: forget needs the (not granted) 'manage' action and would
        # fail with ACCESS_DENIED no matter which uid is passed
        with self.assertRaises(GLib.GError):
            # non-existent uuid
            client.enroll("884c6edd-7118-4b21-b186-b02d396ecca0")

        before = int(time.time())
        policy = BoltClient.POLICY_AUTO
        for d in to_enroll:
            remote = client.enroll(d.unique_id, policy)
            d.reload_auth()  # will emit the uevent, so the daemon can update
            # the security level for the domain is SECURE, which means we should
            # have authorized via a new key:
            #  status should be AUTHORIZED
            #  stored should be True
            #  key(state) should be KEY_NEW
            #  authflags should be 'secure'
            self.assertDeviceEqual(d, remote)
            self.assertTrue(remote.stored)
            self.assertEqual(remote.key, BoltDevice.KEY_NEW)
            self.assertEqual(remote.policy, policy)
            # check the StoreTime is correct
            now = int(time.time())
            self.assertEqual(remote.stored, True)
            self.assertGreater(remote.StoreTime, 1)
            self.assertGreaterEqual(remote.StoreTime, before)
            self.assertLessEqual(remote.StoreTime, now)
            self.assertGreater(remote.AuthorizeTime, 1)
            self.assertGreaterEqual(remote.AuthorizeTime, before)
            self.assertLessEqual(remote.AuthorizeTime, now)

        # we disconnect the tree, but since the devices are enrolled (stored)
        # the daemon should still list them
        tree.disconnect(self.testbed)

        # the host itself is not stored in the daemon
        expected_number = len(tree.devices) - 1
        devices = self.client.list_devices()
        tries = 0
        while expected_number != len(devices) and tries < 3:
            time.sleep(.2)
            tries += 1
            devices = self.client.list_devices()
        # check after the retry loop. Inside the loop the check would be
        # skipped when the count is right immediately, and it would fail
        # before the remaining retries otherwise
        self.assertEqual(len(devices), expected_number)

        tree.connect(self.testbed)              # we connect the domain again
        tree.children[0].connect(self.testbed)  # and the host too

        for remote in devices:
            local = tree.find(unique_id=remote.uid)
            self.assertDeviceEqual(local, remote)
            self.assertTrue(remote.stored)
            # key status should have changed to HAVE from NEW
            self.assertEqual(remote.key, BoltDevice.KEY_HAVE)
            self.assertEqual(remote.policy, policy)

            # now we connect that specific device and wait for
            # the property changes
            with remote.record() as tape:
                local.connect(self.testbed)
                res = tape.wait_for_event('property',
                                          'Status',
                                          'authorized')
                local.reload_auth()  # will emit the uevent, so the daemon can update
                self.assertTrue(res)
                self.assertDeviceEqual(local, remote)

        self.daemon_stop()

    def test_enroll_authorized(self):
        """Check enrolling devices that are already authorized.

        Both docks are authorized before the daemon starts and are not
        stored. Enrolling them must emit a 'Stored' property change and set a
        StoreTime within the test run.
        """
        key = 'b68bce095a13ac39e9254a88b189a38f240487aa6f78f803390a0cdeceb774d8'

        dc, host = self.add_domain_host()
        d1, d1_uid = self.add_device(host, 1, "Dock", "GNOME.org", authorized=2, key=key, boot='0')
        d2, d2_uid = self.add_device(host, 2, "Dock2", "GNOME.org", authorized=1, key=None, boot='1')

        self.daemon_start()
        self.polkitd_start()
        client = self.client

        devices = self.client.list_devices()
        self.assertEqual(len(devices), 3)

        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        targets = [(self.find_device_by_uid(devices, uid), uid)
                   for uid in (d1_uid, d2_uid)]

        before = int(time.time())

        # authorized already, but not yet stored
        for remote, _ in targets:
            self.assertEqual(remote.status, BoltDevice.AUTHORIZED)
            self.assertEqual(remote.stored, False)

        policy = BoltClient.POLICY_AUTO
        for remote, uid in targets:
            with remote.record() as tape:
                client.enroll(uid, policy)
                self.assertTrue(tape.wait_for_event('property', 'Stored', True))
            now = int(time.time())
            self.assertEqual(remote.stored, True)
            self.assertTrue(remote.StoreTime > 1)
            self.assertTrue(before <= remote.StoreTime <= now)

        self.daemon_stop()

    def test_device_forget(self):
        """Check forgetting stored devices.

        Enrolls all un-authorized devices of the default tree (policy
        'manual') and disconnects the tree. Then it checks three things:
        forgetting is denied while only 'org.freedesktop.bolt.enroll' is
        allowed, an unknown uid is rejected once 'org.freedesktop.bolt.manage'
        is allowed, and forgotten devices disappear from the device list.
        """
        self.daemon_start()
        tree = self.default_mock_tree()
        self.polkitd_start()
        tree.connect_tree(self.testbed)

        client = self.client
        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        to_enroll = tree.collect(TbDevice.is_unauthorized)
        policy = BoltClient.POLICY_MANUAL
        for d in to_enroll:
            remote = client.enroll(d.unique_id, policy)
            d.reload_auth()
            self.assertDeviceEqual(d, remote)
            self.assertTrue(remote.stored)
            self.assertEqual(remote.key, BoltDevice.KEY_NEW)
            self.assertEqual(remote.policy, policy)

        tree.disconnect(self.testbed)
        expected_number = len(tree.devices) - 1  # host is not stored
        devices = self.client.list_devices()
        tries = 0
        while expected_number != len(devices) and tries < 3:
            time.sleep(.2)
            tries += 1
            devices = self.client.list_devices()
        self.assertEqual(len(devices), expected_number)

        # only 'enroll' is allowed so far: forgetting any stored device
        # must be denied
        for remote in devices:
            with self.assertRaises(GLib.GError) as cm:
                client.forget(remote.uid)
            err = cm.exception
            self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
            self.assertEqual(err.code, int(Gio.DBusError.ACCESS_DENIED))

        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage'])

        # check we get a proper error for an unknown device
        with self.assertRaises(GLib.GError):
            # non-existent uuid
            client.forget("884c6edd-7118-4b21-b186-b02d396ecca0")

        for remote in devices:
            client.forget(remote.uid)

        devices = self.client.list_devices()
        self.assertEqual(len(devices), 0)

        self.daemon_stop()

    def test_device_label(self):
        """Check reading and changing the label of an enrolled device.

        The label defaults to "<vendor> <name>". Changing it requires the
        'org.freedesktop.bolt.manage' action. Empty or whitespace-only labels
        are rejected with INVALID_ARGS, and a valid new label is applied and
        announced through a 'Label' property change.
        """
        self.daemon_start()
        tree = self.simple_mock_tree()
        self.polkitd_start()
        tree.connect_tree(self.testbed)

        client = self.client
        self.polkitd.SetAllowed(['org.freedesktop.bolt.enroll'])

        local = tree.collect(TbDevice.is_unauthorized)[0]
        policy = BoltClient.POLICY_AUTO

        remote = client.enroll(local.unique_id, policy)
        local.reload_auth()

        default_label = "%s %s" % (local.vendor_name, local.device_name)
        self.assertEqual(remote.label, default_label)

        self.assertDeviceEqual(local, remote)
        self.assertTrue(remote.stored)
        self.assertEqual(remote.key, BoltDevice.KEY_NEW)
        self.assertEqual(remote.policy, policy)

        # only 'enroll' is allowed: relabeling must be denied
        with self.assertRaises(GLib.GError) as cm:
            remote.label = 'not authorized'
        err = cm.exception
        self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
        self.assertEqual(err.code, int(Gio.DBusError.ACCESS_DENIED))

        self.polkitd.SetAllowed(['org.freedesktop.bolt.manage'])
        for val in ['', ' ', '     ']:
            with self.assertRaises(GLib.GError) as cm:
                remote.label = val
            err = cm.exception
            self.assertEqual(err.domain, GLib.quark_to_string(Gio.DBusError.quark()))
            self.assertEqual(err.code, int(Gio.DBusError.INVALID_ARGS))

        # the rejected attempts must not have changed the label
        self.assertEqual(remote.label, default_label)

        val = 'A valid label'

        with remote.record() as tape:
            remote.label = val
            res = tape.wait_for_event('property',
                                      'Label',
                                      val)
            self.assertTrue(res)
        self.assertEqual(remote.label, val)

        self.daemon_stop()


if __name__ == '__main__':
    if sys.argv[1:] == ["list-tests"]:
        # print every test id without its leading '__main__.' module prefix,
        # each followed by a space
        loader = unittest.defaultTestLoader
        for test in loader.loadTestsFromTestCase(BoltTest):
            print(test.id()[len('__main__.'):], end=" ")
        sys.exit(0)
    unittest.main(verbosity=2)
