#!/usr/bin/env python3

import argparse
import atexit
import binascii
import http.server
import os
import platform
import re
import select
import shutil
import signal
import socket
import socketserver
import subprocess
import sys
import tempfile
import time
import threading

# Set by got_sigterm() once SIGTERM has been forwarded to qemu; a later
# SIGTERM then exits this script instead of signalling qemu again.
killed_qemu: bool = False
# Path of the file passed to qemu's -pidfile; assigned in main().
qemu_pidfile = None


def got_sigterm(a1, a2):
    """SIGTERM handler: forward the first SIGTERM to qemu, otherwise exit.

    a1/a2 are the signal number and frame passed by the signal module (unused).
    If the qemu pid file is known and qemu has not been signalled yet, read the
    pid, send it SIGTERM, remember that in ``killed_qemu`` and return so the
    caller can observe qemu exiting. Any failure along the way, or a repeated
    SIGTERM, ends the script via exit_error("Terminated").
    """
    global killed_qemu
    if qemu_pidfile and not killed_qemu:
        try:
            with open(qemu_pidfile, "r") as pid_file:
                qemu_pid = int(pid_file.read())
                os.kill(qemu_pid, signal.SIGTERM)
                killed_qemu = True
        except Exception:
            pass
        else:
            # qemu was signalled; return and let the main flow handle its exit.
            return
    exit_error("Terminated")


class ShutdownHandler:
    """Context manager that shuts a socketserver down on SIGINT/SIGTERM.

    While the context is active, SIGINT and SIGTERM call ``server.shutdown()``.
    On exit, the handlers that were installed before entering are restored.

    NOTE: ``shutdown()`` blocks until ``serve_forever()`` returns, so
    ``serve_forever()`` must run in a thread other than the one receiving the
    signal (the main thread), otherwise it deadlocks.
    """

    def __init__(self, server: socketserver.TCPServer) -> None:
        self.server = server
        # signal number -> handler that was active before __enter__
        self._previous_handlers = {}

    def handle_kill(self, signum, frame):
        self.server.shutdown()

    def __enter__(self):
        for signum in (signal.SIGINT, signal.SIGTERM):
            self._previous_handlers[signum] = signal.signal(signum, self.handle_kill)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for signum, handler in self._previous_handlers.items():
            # getsignal()/signal() report None for handlers not installed from
            # Python; such handlers cannot be reinstalled, so leave them be.
            if handler is not None:
                signal.signal(signum, handler)
        self._previous_handlers.clear()
        return False  # never swallow exceptions from the with-body


# Global verbosity switch read by print_verbose(); main() sets it from --verbose.
is_verbose: bool = False


def print_verbose(s):
    """Print *s* to stdout, but only when the module-level is_verbose flag is set."""
    if not is_verbose:
        return
    print(s)


def print_error(s):
    """Write *s* (plus a newline) to standard error."""
    stream = sys.stderr
    print(s, file=stream)


def exit_error(s):
    """Report *s* on stderr as "Error: <s>" and terminate with exit status 1."""
    message = "Error: " + s
    print_error(message)
    raise SystemExit(1)


def bool_arg(val):
    """Translate a truthy/falsy value into qemu's "on"/"off" option spelling."""
    return ("off", "on")[bool(val)]


def find_qemu(arch):
    """Return the path of a qemu system emulator for *arch*.

    Candidates are ``qemu-system-<arch>`` and, when *arch* is the host
    architecture, ``qemu-kvm`` (tried in that order). If QEMU_BUILD_DIR is set,
    only that directory is searched. Otherwise /usr/bin, /usr/libexec and the
    PATH entries are searched. Exits via exit_error() if no candidate is found.
    """
    binary_names = [f"qemu-system-{arch}"]
    if arch == platform.machine():
        binary_names.append("qemu-kvm")

    build_dir = os.environ.get("QEMU_BUILD_DIR")
    if build_dir is not None:
        # An explicit build dir is authoritative: don't fall back to system
        # binaries, but do try every candidate name before giving up.
        search_dirs = [build_dir]
    else:
        search_dirs = ["/usr/bin", "/usr/libexec"]
        if "PATH" in os.environ:
            search_dirs += os.environ["PATH"].split(os.pathsep)

    for binary_name in binary_names:
        for d in search_dirs:
            p = os.path.join(d, binary_name)
            if os.path.isfile(p):
                return p

    exit_error(f"Can't find {' or '.join(binary_names)}")


def download_aboot_u_boot(arch, dest_path):
    """Download the autosig-u-boot RPM for *arch* and copy its boot/u-boot.bin to *dest_path*.

    Uses ``dnf download`` against the CentOS Automotive SIG repo, then unpacks
    the RPM with ``rpm2cpio | cpio`` in a temporary directory.

    Raises:
        subprocess.CalledProcessError: if dnf, rpm2cpio or cpio fails.
        RuntimeError: if dnf reported success but no .rpm was downloaded.
    """
    with tempfile.TemporaryDirectory(prefix="runvm-uboot") as tmpdir:
        autosigrepo = "https://mirror.stream.centos.org/SIGs/9-stream/automotive"
        subprocess.run(["dnf", "download", "-q",
                        "--downloaddir", tmpdir,
                        "--disablerepo", "*",
                        "--repofrompath", f"autosig,{autosigrepo}/{arch}/packages-main/",
                        "autosig-u-boot"], check=True)
        rpm_files = sorted(f for f in os.listdir(tmpdir) if f.endswith(".rpm"))
        if not rpm_files:
            raise RuntimeError("dnf did not download autosig-u-boot")
        rpm_path = os.path.join(tmpdir, rpm_files[0])

        rpm2cpio = subprocess.Popen(["rpm2cpio", rpm_path], stdout=subprocess.PIPE)
        try:
            subprocess.run(["cpio", "-id", "--quiet", "-D", tmpdir],
                           stdin=rpm2cpio.stdout, check=True)
        finally:
            # Drop our end of the pipe so rpm2cpio gets SIGPIPE instead of
            # blocking if cpio stopped early, then reap it.
            rpm2cpio.stdout.close()
            rpm2cpio.wait()
        if rpm2cpio.returncode != 0:
            raise subprocess.CalledProcessError(rpm2cpio.returncode, rpm2cpio.args)

        shutil.copy(os.path.join(tmpdir, "boot/u-boot.bin"), dest_path)


def qemu_available_accels(qemu):
    """Return the accelerators listed by ``<qemu> -accel help``.

    Only names from (kvm, xen, hvf, hax, tcg) are reported, in that order.
    Output is compared as whole whitespace-separated tokens, so a name that
    appears first or only as part of another word is handled correctly.
    """
    # Pass argv as a list so a qemu path containing spaces still works.
    info = subprocess.check_output([qemu, "-accel", "help"]).decode("utf-8")
    listed = set(info.split())
    return [accel for accel in ("kvm", "xen", "hvf", "hax", "tcg") if accel in listed]


def random_id():
    """Return 16 lowercase hex characters built from 8 bytes of os.urandom()."""
    return os.urandom(8).hex()


def machine_id():
    """Return a hex string identifying this host (used to derive the VM MAC).

    Sources, in order:
      1. /etc/machine-id (stripped), if it is readable and non-empty;
      2. on macOS, the IOPlatformUUID from ``ioreg`` with dashes removed;
      3. otherwise the hostname repeated/truncated to 16 characters, each
         character hex-encoded.
    """
    mid = ""
    try:
        with open("/etc/machine-id", "r") as f:
            mid = f.read().strip()
    except OSError:
        pass  # missing or unreadable: use a fallback below
    # An empty machine-id (e.g. not yet initialised) would yield an invalid MAC.
    if mid:
        return mid

    if sys.platform == "darwin":
        # for macOS
        import plistlib
        cmd = "ioreg -rd1 -c IOPlatformExpertDevice -a"
        plist_data = subprocess.check_output(cmd.split(" "))
        return plistlib.loads(plist_data)[0]["IOPlatformUUID"].replace("-", "")

    # fallback for the other distros; never work from an empty hostname
    hostname = socket.gethostname() or "localhost"
    return ''.join(hex(ord(x))[2:] for x in (hostname * 16)[:16])


def generate_mac_address():
    """Derive the VM's MAC address from machine_id().

    The first octet is the fixed "FE". The other five octets are the two-character
    slices of the machine id at offsets -12, -10, -8, -6 and -4 (i.e. the ten
    characters before the last two).
    """
    data = machine_id()
    octets = ["FE"]
    for offset in range(-12, -2, 2):
        octets.append(data[offset:offset + 2])
    return ":".join(octets)


def run_http_server(path):
    """Serve *path* over HTTP on 127.0.0.1 from a forked child and return its port.

    The child binds an ephemeral port and reports it back through a socketpair.
    A SIGTERM is sent to the child when this process exits (atexit). Exits via
    exit_error() if the child fails before reporting a port.
    """
    writer, reader = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
    # Flush first so output still buffered in the parent isn't written twice.
    sys.stdout.flush()
    sys.stderr.flush()
    child_pid = os.fork()
    if child_pid == 0:
        # Child. It must never return into the caller, and it ends with
        # os._exit() so atexit handlers/finalizers inherited from the parent
        # don't run here.
        exit_code = 1
        try:
            reader.close()
            os.chdir(path)

            class HTTPHandler(http.server.SimpleHTTPRequestHandler):
                def log_message(self, format, *args):
                    pass  # Silence logs

            httpd = socketserver.TCPServer(("127.0.0.1", 0), HTTPHandler)
            writer.sendall(str(httpd.server_address[1]).encode("utf8"))
            writer.close()

            with ShutdownHandler(httpd):
                # serve_forever runs in a thread so the SIGTERM handler in the
                # main thread can call shutdown() without deadlocking.
                server_thread = threading.Thread(target=httpd.serve_forever, daemon=True)
                server_thread.start()
                server_thread.join()
            exit_code = 0
        except Exception as e:
            print_error(f"http server failed: {e}")
        finally:
            try:
                sys.stdout.flush()
                sys.stderr.flush()
            finally:
                os._exit(exit_code)

    # Parent
    writer.close()

    def _stop_child():
        try:
            os.kill(child_pid, signal.SIGTERM)
        except ProcessLookupError:
            pass  # already gone

    atexit.register(_stop_child)

    # Read until EOF (the child closes its end after sending the port).
    chunks = []
    while True:
        chunk = reader.recv(128)
        if not chunk:
            break
        chunks.append(chunk)
    reader.close()

    port_text = b"".join(chunks).decode("utf8")
    if not port_text:
        exit_error("Failed to start http server")
    return int(port_text)


def run_virtiofs_server(socket, sharedir):
    """Start /usr/libexec/virtiofsd exporting *sharedir* on the socket path *socket*.

    virtiofsd logging is disabled unless verbose mode is on. Returns the Popen
    handle; main() terminates it after qemu exits. Note that the ``socket``
    parameter shadows the socket module inside this function.
    """
    command = [
        "/usr/libexec/virtiofsd",
        "--socket-path=" + socket,
        "-o", "source=" + sharedir,
        "-o", "cache=always",
    ]
    if not is_verbose:
        command.extend(["--log-level", "off"])
    print_verbose("Running: " + " ".join(command))
    return subprocess.Popen(command)


def find_ovmf(args):
    """Return the first directory containing a matching OVMF_CODE/OVMF_VARS pair.

    Uses the ``.secboot`` variants when ``args.secureboot`` is set. Searches
    ``args.ovmf_dir`` first (if given), then ~/.local/share/ovmf,
    /usr/share/OVMF and /usr/share/edk2/ovmf.

    Raises:
        RuntimeError: if no directory contains both files.
    """
    suffix = ".secboot" if args.secureboot else ""
    dirs = [
        "~/.local/share/ovmf",
        "/usr/share/OVMF",
        "/usr/share/edk2/ovmf",
    ]
    if args.ovmf_dir:
        dirs.insert(0, args.ovmf_dir)

    for d in dirs:
        path = os.path.expanduser(d)
        code_fd = os.path.join(path, f"OVMF_CODE{suffix}.fd")
        vars_fd = os.path.join(path, f"OVMF_VARS{suffix}.fd")
        if os.path.exists(code_fd) and os.path.exists(vars_fd):
            return path

    raise RuntimeError("Could not find OVMF")


# Candidate directories for qemu's edk2 firmware files; the location differs
# depending on how qemu was installed. find_edk2() and find_edk2_code_fd()
# search them in this order.
qemu_dirs = [
    "/usr/local/share/qemu",
    "/opt/homebrew/share/qemu",
    "/usr/share/edk2/aarch64",
    "/usr/share/qemu",
]


# location can differ depending on how qemu is installed
def find_edk2(search_dirs=None):
    """Return the first existing directory among *search_dirs*.

    Args:
        search_dirs: directories to try in order; defaults to ``qemu_dirs``.

    Raises:
        RuntimeError: if none of the directories exist.
    """
    candidates = qemu_dirs if search_dirs is None else search_dirs
    for path in candidates:
        if os.path.exists(path):
            return path

    raise RuntimeError("Could not find edk2 directory")


def find_edk2_code_fd(search_dirs=None):
    """Return the path of the first aarch64 edk2 code image found.

    For each directory (in order), QEMU_EFI.fd is preferred over
    edk2-aarch64-code.fd.

    Args:
        search_dirs: directories to try in order; defaults to ``qemu_dirs``.

    Raises:
        RuntimeError: if no candidate file exists.
    """
    files = [
        "QEMU_EFI.fd",
        "edk2-aarch64-code.fd"
    ]
    candidates = qemu_dirs if search_dirs is None else search_dirs

    for d in candidates:
        for f in files:
            dir_and_file = os.path.join(d, f)
            if os.path.exists(dir_and_file):
                return dir_and_file

    raise RuntimeError("Could not find edk2 code fd file")


def qemu_run_command(qmp_socket_path, command, timeout=10.0):
    """Send one QMP *command* (a JSON string) to qemu's monitor socket.

    Performs the QMP handshake (read greeting, send qmp_capabilities), sends
    *command* and reads one reply chunk before disconnecting. Replies are not
    parsed or checked.

    Args:
        qmp_socket_path: path of qemu's ``-qmp unix:...`` socket.
        command: JSON text of the command, without trailing newline.
        timeout: per-operation socket timeout in seconds, so a wedged qemu
            cannot block the caller forever.

    Raises:
        OSError: on connection failure or timeout.
    """
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as qmp:
        qmp.settimeout(timeout)
        qmp.connect(qmp_socket_path)
        qmp.recv(1024)  # greeting
        qmp.sendall('{"execute":"qmp_capabilities"}\n'.encode("utf8"))
        qmp.recv(1024)
        qmp.sendall(f'{command}\n'.encode("utf8"))
        qmp.recv(1024)


def virtio_serial_connect(virtio_socket_path):
    """Connect to the Unix socket at *virtio_socket_path*, retrying until it accepts.

    Retries every 0.1 s while the socket file does not exist yet or exists
    but is not yet accepting connections. Returns the connected socket.
    """
    while True:
        time.sleep(0.1)
        # Use a fresh socket per attempt: after a failed connect() the socket's
        # state is unspecified and it should not be reused.
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.connect(virtio_socket_path)
            return sock
        except (FileNotFoundError, ConnectionRefusedError):
            sock.close()


def available_tcp_port(port_range_from=1024, attempts=32):
    """Return the first TCP port >= *port_range_from* that can currently be bound.

    Tries at most *attempts* consecutive ports, binding on all IPv4 interfaces
    with SO_REUSEADDR. The probe socket is closed again, so the port is only
    known to have been free at the time of the check.

    Raises:
        RuntimeError: if none of the tried ports could be bound.
    """
    for port in range(port_range_from, port_range_from + attempts):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe.bind(('', port))
            except (OSError, OverflowError):
                # in use / not permitted / out of the 0-65535 range
                continue
        return port
    raise RuntimeError(
        f"No free TCP port in range {port_range_from}-{port_range_from + attempts - 1}")


class WatchdogCommand:
    """One command received from the guest over the watchdog channel.

    ``op`` is START or STOP. For START, ``arg`` holds the timeout in seconds.
    STOP commands are built without an argument, so ``arg`` is None.
    """

    # operation codes
    START = 1
    STOP = 2

    def __init__(self, op, arg=None):
        self.op, self.arg = op, arg


def parse_watchdog_commands(sock, bufsize=4096):
    """Read one chunk from *sock* and parse it into WatchdogCommand objects.

    Recognised lines: ``START [seconds]`` (defaults to 30 when the number is
    missing or invalid) and ``STOP``. Blank lines are ignored, and anything else
    is reported via print_verbose(). Returns an empty list on EOF.

    Args:
        bufsize: maximum bytes read per call. This was 16, which truncated
            and split commands such as "START 300\\nSTOP\\n...".
    """
    commands = []
    # errors="replace": a multi-byte sequence cut at the buffer edge must not
    # kill the watchdog.
    data = sock.recv(bufsize).decode("utf8", errors="replace")
    for line in data.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith("START"):
            try:
                arg = int(line[5:])
            except ValueError:
                arg = 30  # Default if not specified
            commands.append(WatchdogCommand(WatchdogCommand.START, arg))
        elif line.startswith("STOP"):
            commands.append(WatchdogCommand(WatchdogCommand.STOP))
        else:
            print_verbose(f"Unsupported watchdog command {line}")
    return commands


def run_watchdog(watch_socket_path, qmp_socket_path):
    """Guest-driven watchdog loop: reset the VM over QMP when the guest goes quiet.

    Connects to the watchdog virtio-serial socket and processes START [secs] /
    STOP commands. A START arms (or re-arms) a deadline and a STOP disarms it.
    When an armed deadline passes, a QMP ``system_reset`` is sent and the
    deadline is re-armed 30 s later, until STOP is received. Calls sys.exit(0)
    when the socket reports hang-up or an error.
    """
    sock = virtio_serial_connect(watch_socket_path)

    poller = select.poll()
    poller.register(sock, select.POLLIN)

    watchdog_timeout = None  # absolute time.time() deadline, or None if disarmed
    watchdog_delay = 30  # seconds between resets while the guest stays silent
    # POLLHUP/POLLERR/POLLNVAL are reported even though only POLLIN is
    # registered. Treat all of them as "peer gone" rather than spinning.
    dead_mask = select.POLLHUP | select.POLLERR | select.POLLNVAL

    while True:
        timeout = None
        if watchdog_timeout is not None:
            # poll() wants milliseconds and must not get a negative value
            timeout = max(watchdog_timeout - time.time(), 0) * 1000

        poll_res = poller.poll(timeout)

        if poll_res:
            revents = poll_res[0][1]
            if revents & dead_mask:
                sys.exit(0)
            for cmd in parse_watchdog_commands(sock):
                if cmd.op == WatchdogCommand.START:
                    print_verbose(f"Starting watchdog for {cmd.arg} sec")
                    watchdog_timeout = time.time() + cmd.arg
                elif cmd.op == WatchdogCommand.STOP:
                    print_verbose("Stopped watchdog")
                    watchdog_timeout = None

        if watchdog_timeout is not None and time.time() >= watchdog_timeout:
            print_verbose("Triggering watchdog")
            try:
                qemu_run_command(qmp_socket_path, '{"execute": "system_reset"}')
            except OSError as e:
                # Keep watching: a single undeliverable reset must not end the watchdog.
                print_error(f"Watchdog reset failed: {e}")

            # Queue a new timeout in case the next boot fails, until disabled
            watchdog_timeout = time.time() + watchdog_delay


def _call_or_exit(func, *args):
    """Call ``func(*args)``, reporting a RuntimeError via exit_error() instead of a traceback."""
    try:
        return func(*args)
    except RuntimeError as e:
        exit_error(str(e))


def _wait_for_socket(path, proc, timeout=10.0):
    """Wait until *path* exists; exit_error() if *proc* dies first or *timeout* seconds pass."""
    deadline = time.monotonic() + timeout
    while not os.path.exists(path):
        if proc.poll() is not None:
            exit_error(f"{proc.args[0]} exited before creating {path}")
        if time.monotonic() >= deadline:
            exit_error(f"Timed out waiting for {path}")
        time.sleep(0.05)


def _run_watchdog_child(watch_socket_path, qmp_socket_path):
    """Body of the forked watchdog process; never returns (always ends with os._exit()).

    os._exit() skips atexit handlers and finalizers inherited from the parent
    (the run_http_server() kill, the TemporaryDirectory cleanup). Those must
    only run in the parent.
    """
    exit_code = 1
    try:
        run_watchdog(watch_socket_path, qmp_socket_path)
        exit_code = 0
    except SystemExit as e:
        exit_code = e.code if isinstance(e.code, int) else int(e.code is not None)
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print_error(f"Watchdog failed: {e}")
    finally:
        try:
            sys.stdout.flush()
            sys.stderr.flush()
        finally:
            os._exit(exit_code)


def main():
    """Parse the command line, assemble the qemu command line and run the VM.

    Helper processes (watchdog, swtpm, virtiofsd) and the temporary directory
    are torn down on every exit path. Returns qemu's exit status.
    """
    global is_verbose
    global qemu_pidfile

    parser = argparse.ArgumentParser(description="Boot virtual machine images")
    parser.add_argument("--verbose", default=False, action="store_true")
    parser.add_argument("--arch", default=platform.machine(), action="store",
                        help=f"Arch to run for (default {platform.machine()})")
    parser.add_argument("--publish-dir", action="store",
                        help="Publish the specified directory over http in the vm")
    parser.add_argument("--memory", default="2G",
                        help="Memory size (default 2G)")
    parser.add_argument("--vga", default=False, action="store_true",
                        help="Use vga graphics instead of the default virtio-gpu")
    parser.add_argument("--nographics", default=False, action="store_true",
                        help="Run without graphics")
    parser.add_argument("--nosmp", default=False, action="store_true",
                        help="Use a single core")
    parser.add_argument("--aboot", default=False, action="store_true",
                        help="Boot with aboot")
    parser.add_argument("--watchdog", default=False, action="store_true",
                        help="Enable watchdog")
    parser.add_argument("--tpm2", default=False, action="store_true",
                        help="Enable TPM2")
    parser.add_argument("--nvme", default=False, action="store_true",
                        help="Use nvme instead of virtio")
    parser.add_argument("--snapshot", default=False, action="store_true",
                        help="Work on a snapshot of the image")
    parser.add_argument("--ovmf-dir", action="store",
                        help="Specify directory for OVMF files (Open Virtual Machine Firmware)")
    parser.add_argument("--secureboot-vars", dest="secureboot", action="store",
                        help="Use the specified vars file for secureboot vars, will be initiated if non-existing")
    parser.add_argument("--secureboot-writeable", dest="secureboot_writeable", action="store_true", default=False,
                        help="Make the --secureboot-vars file persistently writeable (for enrollment).")
    parser.add_argument("--ssh-port", type=int, default=2222,
                        help="SSH port forwarding to SSH_PORT (default 2222)")
    parser.add_argument("--port-forward", type=str, metavar="host:guest,...",
                        help="Add port forwarding rules by host:guest format with comma separation.\ne.g - \"8443:443,8143:143\" will forward the port 8443 of host machine to the port 443 of guest OS and the port 8143 of host to the port 143 of guest too")
    parser.add_argument("--cdrom", action="store",
                        help="Specify .iso to load")
    parser.add_argument("--noaccel", default=False, action="store_true",
                        help="Disable acceleration (kvm or hvf)")
    parser.add_argument("--sharedir", action="store",
                        help="Share directory using virtiofs")
    parser.add_argument("--ip", default=None,
                        help="IP address that the VM is configured for, defaults to: 10.0.2.15")
    parser.add_argument("--network", default="10.0.2.0/24",
                        help="Network IP range to use, defaults to: 10.0.2.0/24")
    parser.add_argument("--serial-socket",
                        help="Unix socket to redirect serial output to")
    parser.add_argument("--virtio-console",
                        help="Unix socket to create virtual console at")
    parser.add_argument("image", type=str, help="The image to boot")
    parser.add_argument('extra_args', nargs=argparse.REMAINDER, metavar="...", help="extra qemu arguments")

    args = parser.parse_args(sys.argv[1:])

    is_verbose = args.verbose

    # arm64 is an alias for aarch64 on macOS
    if args.arch == "arm64":
        args.arch = "aarch64"

    if args.aboot:
        if args.arch != "aarch64":
            exit_error("--aboot only supported with --arch=aarch64")
        aboot_bios = f"qemu-u-boot-{args.arch}.bin"
        if not os.path.exists(aboot_bios):
            print(f"Missing file '{aboot_bios}'. Downloading")
            download_aboot_u_boot(args.arch, aboot_bios)

    qemu = find_qemu(args.arch)
    accel_list = qemu_available_accels(qemu)
    qemu_args = [qemu]

    # os.cpu_count() returns None when the count cannot be determined
    num_cpus = os.cpu_count() or 1

    if args.arch == "x86_64":
        if args.secureboot:
            machine = "q35,smm=on"
        else:
            machine = "q35"
        default_cpu = "qemu64,+ssse3,+sse4.1,+sse4.2,+popcnt"

        ovmf = _call_or_exit(find_ovmf, args)
        if args.secureboot:
            vars_file = args.secureboot
            if not os.path.exists(vars_file):
                subprocess.run(["cp", os.path.join(ovmf, "OVMF_VARS.fd"), vars_file], check=True)
            perms_args = "" if args.secureboot_writeable else ",snapshot=on,readonly=off"
            qemu_args += [
                "-drive", f"file={ovmf}/OVMF_CODE.secboot.fd,if=pflash,format=raw,unit=0,readonly=on",
                "-drive", f"file={vars_file},if=pflash,format=raw,unit=1{perms_args}",
            ]
        else:
            qemu_args += [
                "-drive", f"file={ovmf}/OVMF_CODE.fd,if=pflash,format=raw,unit=0,readonly=on",
                "-drive", f"file={ovmf}/OVMF_VARS.fd,if=pflash,format=raw,unit=1,snapshot=on,readonly=off",
            ]
    elif args.arch == "aarch64":
        machine = "virt"
        default_cpu = "cortex-a57"
        num_cpus = min(num_cpus, 8)  # for up to 8 cores (limitation of qemu-system-aarch64)
        if sys.platform == "darwin":
            qemu_args += [
                "-device", "virtio-gpu-pci",  # for display
                "-display", "default,show-cursor=on",  # for display
                "-device", "qemu-xhci",  # for keyboard
                "-device", "usb-kbd",  # for keyboard
                "-device", "usb-tablet",  # for mouse
            ]
        if args.aboot:
            args.memory = "4G"  # this is hardcoded in our dtb files
            qemu_args += [
                "-bios", aboot_bios
            ]
        elif sys.platform == "darwin":
            edk2 = _call_or_exit(find_edk2)
            qemu_args += [
                "-drive", f"file={edk2}/edk2-aarch64-code.fd,if=pflash,format=raw,unit=0,readonly=on",
                "-drive", f"file={edk2}/edk2-arm-vars.fd,if=pflash,format=raw,unit=1,snapshot=on,readonly=off"
            ]
        else:
            edk2_file = _call_or_exit(find_edk2_code_fd)
            qemu_args += [
                "-bios", f"{edk2_file}",
                "-boot", "efi"
            ]
    else:
        exit_error(f"unsupported architecture {args.arch}")

    if not args.nosmp and num_cpus > 1:
        qemu_args += [
            "-smp", str(num_cpus)
        ]

    accel_enabled = True

    # There are some cases that acceleration may not work,
    # kvm accelerated aboot is one, kernel crash
    if args.noaccel:
        accel_enabled = False
    elif 'kvm' in accel_list and os.path.exists("/dev/kvm"):
        qemu_args += ['-enable-kvm']
    elif 'hvf' in accel_list:
        qemu_args += ['-accel', 'hvf']
    else:
        accel_enabled = False

    if not accel_enabled:
        print_verbose("Acceleration: off")

    qemu_args += [
        "-m", str(args.memory),
        "-machine", machine,
        "-cpu", "host" if accel_enabled else default_cpu
    ]

    guestfwds = ""

    if args.publish_dir:
        if shutil.which("netcat") is None:
            exit_error("Command `netcat` not found in path and --publish-dir was specified")
        else:
            httpd_port = run_http_server(args.publish_dir)
            guestfwds = f"guestfwd=tcp:10.0.2.100:80-cmd:netcat 127.0.0.1 {httpd_port},"
            print_verbose(f"publishing {args.publish_dir} on http://10.0.2.100/")

    portfwd = {
        _call_or_exit(available_tcp_port, args.ssh_port): 22
    }

    if args.port_forward:
        for rule in args.port_forward.split(','):
            match = re.search(r'([0-9]+):([0-9]+)', rule)
            if match:
                host, guest = match.groups()
                portfwd[_call_or_exit(available_tcp_port, int(host))] = int(guest)
            else:
                exit_error(f'Invalid port-forward rule "{rule}"')

    for local, remote in portfwd.items():
        print_verbose(f"port: {local} → {remote}")

    ip = args.ip or ""
    fwds = [f"hostfwd=tcp::{h}-{ip}:{g}" for h, g in portfwd.items()]

    macstr = generate_mac_address()
    print_verbose(f"MAC: {macstr}")

    net_dev = "virtio-net-pci"
    if args.aboot:
        net_dev = "virtio-net-device"
    qemu_args += [
        "-device", f"{net_dev},netdev=n0,mac={macstr}",
        "-netdev", f"user,id=n0,net={args.network}," + guestfwds + ",".join(fwds),
    ]

    if args.nographics:
        qemu_args += ["-nographic"]
    elif args.vga:
        if args.arch != "x86_64":
            exit_error(f"-vga is not supported on {args.arch}")
        # On x86, vga is already the default
    else:  # Default is virtio-gpu
        # vga is on by default on x86, disable
        if args.arch == "x86_64":
            qemu_args += ["-vga", "none"]
        qemu_args += [
            "-device", "virtio-gpu-pci",
            "-device", "virtio-keyboard-pci",
            "-device", "virtio-mouse-pci",
        ]

    runvm_id = random_id()

    tmpdir = tempfile.TemporaryDirectory(prefix=f"runvm-{runvm_id}")

    watchdog_pid = 0
    swtpm = None
    virtiod = None
    try:
        if args.watchdog:
            qmp_socket_path = os.path.join(tmpdir.name, "qmp-socket")
            watch_socket_path = os.path.join(tmpdir.name, "watch-socket")

            qemu_args += [
                "-qmp", f"unix:{qmp_socket_path},server=on,wait=off",
                "-device", "virtio-serial", "-chardev", f"socket,path={watch_socket_path},server=on,wait=off,id=watchdog",
                "-device", "virtserialport,chardev=watchdog,name=watchdog.0"
            ]

            # Flush first so output still buffered here isn't emitted twice.
            sys.stdout.flush()
            sys.stderr.flush()
            watchdog_pid = os.fork()
            if watchdog_pid == 0:
                _run_watchdog_child(watch_socket_path, qmp_socket_path)

        if args.tpm2:
            if shutil.which("swtpm") is None:
                exit_error("Command `swtpm` not found in path, this is needed for tpm2 support")

            tpm2_socket = os.path.join(tmpdir.name, "tpm-socket")

            if args.snapshot:
                tpm2_path = os.path.join(tmpdir.name, "tpm2_state")
            else:
                tpm2_path = ".tpm2_state"
            os.makedirs(tpm2_path, exist_ok=True)

            swtpm_args = ["swtpm", "socket", "--tpm2", "--tpmstate", f"dir={tpm2_path}", "--ctrl", f"type=unixio,path={tpm2_socket}"]
            swtpm = subprocess.Popen(swtpm_args)
            # qemu connects to the chardev at startup, so the socket must exist first.
            _wait_for_socket(tpm2_socket, swtpm)

            # tpm-tis is the x86 (ISA) TIS device; the Arm virt machine uses
            # the sysbus variant tpm-tis-device (see QEMU TPM docs).
            tpm_device = "tpm-tis-device" if args.arch == "aarch64" else "tpm-tis"
            qemu_args += [
                "-chardev", f"socket,id=chrtpm,path={tpm2_socket}",
                "-tpmdev", "emulator,id=tpm0,chardev=chrtpm",
                "-device", f"{tpm_device},tpmdev=tpm0"
            ]

        print_verbose(f"Image: {args.image}")

        # Chose disk device
        disk_if = "none"
        if args.nvme:
            qemu_args += ["-device", "nvme,serial=deadbeef,drive=rootdisk"]
        elif args.aboot:
            # virtio-blk-pci doesn't work with aboot, so use virtio-blk-device
            qemu_args += ["-device", "virtio-blk-device,drive=rootdisk"]
        else:
            # Default "if=virtio", typically a pci based virtio device
            disk_if = "virtio"

        disk_format = "qcow2"
        if args.image.endswith(".raw") or args.image.endswith(".img"):
            disk_format = "raw"

        qemu_args += [
            "-drive", f"file={args.image},index=0,media=disk,format={disk_format},if={disk_if},id=rootdisk,snapshot={bool_arg(args.snapshot)}",
        ]

        if args.cdrom:
            qemu_args += [
                "-cdrom", args.cdrom,
                "-boot", "d"
            ]

        if args.sharedir:
            if not os.path.isdir(args.sharedir):
                exit_error(f"Shared dir {args.sharedir} is not a valid directory")

            vhostsocket = os.path.join(tmpdir.name, "vhost")

            virtiod = run_virtiofs_server(vhostsocket, args.sharedir)

            qemu_args += [
                "-chardev", "socket,id=char0,path=" + vhostsocket,
                "-device", "vhost-user-fs-pci,queue-size=1024,chardev=char0,tag=host",
                "-object", "memory-backend-file,id=mem,size=" + str(args.memory) + ",mem-path=/dev/shm,share=on",
                "-numa", "node,memdev=mem"
            ]
            print(f"Sharing directory {args.sharedir}, mount using 'mount -t virtiofs host /mnt'")

        if args.serial_socket:
            qemu_args += ["-serial", f"unix:{args.serial_socket},server=on"]

        if args.virtio_console:
            qemu_args += [
                "-device", "virtio-serial-pci",
                "-chardev", f"socket,id=con0,path={args.virtio_console},server=on,wait=off",
                "-device", "virtconsole,chardev=con0,id=vc0",
            ]

        qemu_args += args.extra_args

        # Handle SIGTERM and try to kill qemu
        qemu_pidfile = os.path.join(tmpdir.name, "qemu.pid")
        qemu_args += [
            "-pidfile", qemu_pidfile
        ]
        signal.signal(signal.SIGTERM, got_sigterm)

        print_verbose(f"Running: {' '.join(qemu_args)}")
        try:
            res = subprocess.run(qemu_args, check=False)
        except KeyboardInterrupt:
            exit_error("Aborted")

        return res.returncode
    finally:
        # Runs on normal return, exit_error()/SystemExit and Ctrl-C alike, so
        # helper processes are never orphaned.
        if watchdog_pid:
            try:
                os.kill(watchdog_pid, signal.SIGTERM)
            except ProcessLookupError:
                pass
        if virtiod is not None:
            virtiod.terminate()
        if swtpm is not None:
            swtpm.terminate()
        tmpdir.cleanup()


if __name__ == "__main__":
    # main() returns qemu's exit status, which becomes this script's status.
    raise SystemExit(main())
