#!/usr/bin/python3

import argparse
import contextlib
import functools
import glob
import mimetypes
import json
import os
import subprocess
import sys
import tempfile
import xml.etree.ElementTree


def run_ostree(*args, _input=None, _check=True, **kwargs):
    """Run the `ostree` command line tool and return the completed process.

    Positional arguments are passed through verbatim; every keyword argument
    `key=value` is appended as a `--key=value` option. The command line is
    echoed to stderr before it is executed.

    `_input` is fed to the process on stdin and `_check` controls whether a
    non-zero exit status raises `subprocess.CalledProcessError`. Standard
    output is captured as UTF-8 text on the returned object.
    """
    argv = list(args)
    argv.extend(f'--{k}={v}' for k, v in kwargs.items())
    print("ostree " + " ".join(argv), file=sys.stderr)
    return subprocess.run(["ostree", *argv],
                          encoding="utf-8",
                          stdout=subprocess.PIPE,
                          input=_input,
                          check=_check)


@contextlib.contextmanager
def nbd_connect(image):
    """Context manager that attaches `image` read-only to an nbd device.

    Every path matching `/dev/nbd*` is tried in turn with `qemu-nbd --connect`
    until one succeeds; that device path is yielded and the device is
    disconnected again when the context exits.

    Raises RuntimeError if no device could be connected.
    """
    for device in glob.glob("/dev/nbd*"):
        attach = subprocess.run(["qemu-nbd", "--connect", device, "--read-only", image], check=False)
        if attach.returncode != 0:
            # device is busy or unusable, try the next one
            continue

        try:
            yield device
        finally:
            subprocess.run(["qemu-nbd", "--disconnect", device], check=True, stdout=subprocess.DEVNULL)
        return

    raise RuntimeError("no free network block device")


@contextlib.contextmanager
def mount_at(device, mountpoint, options=None, extra=None):
    """Context manager that mounts `device` read-only at `mountpoint`.

    `options` is an optional iterable of additional mount options; they are
    joined with commas after the always-present `ro` option. `extra` is an
    optional iterable of additional arguments for mount(8) (e.g. `--bind`),
    inserted before the device and mountpoint.

    Yields `mountpoint`; the filesystem is lazily unmounted when the context
    exits. Raises `subprocess.CalledProcessError` if mount or umount fails.
    """
    # None instead of mutable list defaults; copying also accepts tuples
    opts = ",".join(["ro"] + list(options or []))
    subprocess.run(["mount", "-o", opts] + list(extra or []) + [device, mountpoint], check=True)
    try:
        yield mountpoint
    finally:
        subprocess.run(["umount", "--lazy", mountpoint], check=True)


@contextlib.contextmanager
def mount(device):
    """Context manager that mounts `device` read-only on a temporary directory.

    Yields the path of the temporary mountpoint. On exit the filesystem is
    lazily unmounted and the temporary directory is removed.
    """
    with tempfile.TemporaryDirectory() as mountpoint:
        # without extra options mount_at runs `mount -o ro device mountpoint`
        with mount_at(device, mountpoint) as mounted:
            yield mounted


def parse_environment_vars(s):
    """Parse `KEY=value` lines from the string `s` into a dict.

    Lines are stripped of surrounding whitespace; empty lines and lines
    starting with `#` are ignored. Each remaining line is split at the first
    `=` and double quotes are stripped from both ends of the value. A line
    without `=` raises ValueError.
    """
    variables = {}
    for raw in s.split("\n"):
        entry = raw.strip()
        if not entry or entry.startswith("#"):
            continue
        name, val = entry.split("=", 1)
        variables[name] = val.strip('"')
    return variables


def parse_unit_files(s, expected_state):
    """Extract the units in `expected_state` from `systemctl list-unit-files`.

    `s` is the command's text output. The first line (the column header) is
    skipped. Every following line is split on whitespace; the first field is
    taken as the unit name and the second as its state. Lines with fewer than
    two fields (e.g. blank lines) are ignored.

    Returns a sorted list of unique unit names whose state equals
    `expected_state`.
    """
    units = set()
    for line in s.split("\n")[1:]:
        fields = line.split()
        if len(fields) < 2:
            # Blank or malformed line: skip it instead of falling through
            # with `unit`/`state` unbound or left over from the previous line.
            continue
        unit, state = fields[0], fields[1]
        if state == expected_state:
            units.add(unit)

    return sorted(units)


def subprocess_check_output(argv, parse_fn=None):
    """Run `argv` and return its standard output, optionally parsed.

    The output is decoded as UTF-8. When `parse_fn` is given, it is called
    with the output text and its result is returned instead of the raw text.

    If the command exits with a non-zero status, the output it produced is
    written to stderr and the `subprocess.CalledProcessError` is re-raised.
    """
    try:
        captured = subprocess.check_output(argv, encoding="utf-8")
    except subprocess.CalledProcessError as err:
        sys.stderr.write(f"--- Output from {argv}:\n")
        sys.stderr.write(err.stdout)
        sys.stderr.write("\n--- End of the output\n")
        raise

    if parse_fn:
        return parse_fn(captured)
    return captured


def read_image_format(device):
    """Return the `format` field that `qemu-img info` reports for `device`."""
    command = ["qemu-img", "info", "--output=json", device]
    info = subprocess_check_output(command, json.loads)
    return info["format"]


def read_partition(device, bootable, typ=None, start=0, size=0, type=None):
    """Describe one partition (or whole-device filesystem) as a dict.

    The label, UUID, partition UUID and filesystem type are looked up with
    `blkid --output export`; each is None when blkid does not report it.
    `bootable`, `typ`, `start` and `size` are copied into the result as
    given. The `type` parameter is accepted but not used.
    """
    blkid = subprocess_check_output(["blkid", "--output", "export", device], parse_environment_vars)

    label = blkid.get("LABEL")  # doesn't exist for mbr
    uuid = blkid.get("UUID")
    partuuid = blkid.get("PARTUUID")
    fstype = blkid.get("TYPE")

    return {
        "label": label,
        "type": typ,
        "uuid": uuid,
        "partuuid": partuuid,
        "fstype": fstype,
        "bootable": bootable,
        "start": start,
        "size": size,
    }


def read_partition_table(device):
    """Read the partition table of `device` using `sfdisk --json`.

    Returns a `(label, id, partitions)` tuple, where `partitions` is a list
    of dicts as produced by `read_partition`. Start and size are converted
    from sectors to bytes by multiplying with 512.

    When sfdisk fails, the device is described as a single, non-bootable
    partition and `(None, None, [partition])` is returned.
    """
    try:
        sfdisk = subprocess_check_output(["sfdisk", "--json", device], json.loads)
    except subprocess.CalledProcessError:
        return None, None, [read_partition(device, False)]

    ptable = sfdisk["partitiontable"]
    assert ptable["unit"] == "sectors"

    partitions = [
        read_partition(p["node"], p.get("bootable", False), p["type"], p["start"] * 512, p["size"] * 512)
        for p in ptable["partitions"]
    ]
    return ptable["label"], ptable["id"], partitions


def read_bootloader_type(device):
    """Return "grub" if the first 512 bytes of `device` contain b"GRUB".

    Any other content yields "unknown".
    """
    with open(device, "rb") as f:
        first_sector = f.read(512)

    return "grub" if b"GRUB" in first_sector else "unknown"


def read_boot_entries(boot_dir):
    """Read the boot loader entries below `boot_dir`.

    Every `{boot_dir}/loader/entries/*.conf` file is parsed into a dict that
    maps each key (first word of a line) to its value (rest of the line).
    Blank lines and lines starting with `#` are skipped, key and value may be
    separated by any run of whitespace, and a key without a value maps to the
    empty string.

    Returns the list of entries sorted by their `title`; entries that have no
    title sort first. An empty list is returned when there are no entries.
    """
    entries = []
    # sort the paths so that entries with equal titles keep a stable order
    for conf in sorted(glob.glob(f"{boot_dir}/loader/entries/*.conf")):
        entry = {}
        with open(conf) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                fields = line.split(None, 1)
                entry[fields[0]] = fields[1] if len(fields) > 1 else ""
        entries.append(entry)

    return sorted(entries, key=lambda e: e.get("title", ""))


def rpm_verify(tree):
    """Run `rpm --verify --all` inside `tree` and summarise the result.

    Returns a dict with two keys: `missing`, a sorted list of paths rpm
    reported as missing, and `changed`, a dict mapping each changed path to
    the 9-character attribute string rpm printed for it.
    """
    # cannot use `rpm --root` here, because rpm uses passwd from the host to
    # verify user and group ownership:
    #   https://github.com/rpm-software-management/rpm/issues/882
    changed = {}
    missing = []

    # The context manager closes the stdout pipe and waits for rpm to exit,
    # even if parsing fails half way through.
    with subprocess.Popen(["chroot", tree, "rpm", "--verify", "--all"],
                          stdout=subprocess.PIPE, encoding="utf-8") as rpm:
        for line in rpm.stdout:
            if not line.strip():
                # nothing to parse on a blank line
                continue
            # format description in rpm(8), under `--verify`
            attrs = line[:9]
            if attrs == "missing  ":
                missing.append(line[12:].rstrip())
            else:
                changed[line[13:].rstrip()] = attrs

    # the return value is ignored, because rpm returns non-zero when it
    # found changes

    return {
        "missing": sorted(missing),
        "changed": changed
    }


def rpm_packages(tree, is_ostree):
    """Return the sorted list of packages installed in `tree`.

    The packages are queried with `rpm --root <tree> -qa`. When `is_ostree`
    is true, the rpm database at `/usr/share/rpm` is used.
    """
    dbpath = ["--dbpath", "/usr/share/rpm"] if is_ostree else []
    cmd = ["rpm", "--root", tree, "-qa", *dbpath]
    return sorted(subprocess_check_output(cmd, str.split))


def read_services(tree, state):
    """Return the unit files in `tree` whose state equals `state`.

    Runs `systemctl --root=<tree> list-unit-files` and filters its output
    with `parse_unit_files`.
    """
    def units_in_state(output):
        return parse_unit_files(output, state)

    command = ["systemctl", f"--root={tree}", "list-unit-files"]
    return subprocess_check_output(command, units_in_state)


def read_firewall_zone(tree):
    """Return the service names listed in the default firewalld zone of `tree`.

    The zone name is the `DefaultZone` value from
    `etc/firewalld/firewalld.conf`, or "public" if that file does not exist.
    The zone definition is read from `etc/firewalld/zones`, falling back to
    `usr/lib/firewalld/zones`.

    Raises FileNotFoundError if neither zone file exists.
    """
    default = "public"
    try:
        with open(f"{tree}/etc/firewalld/firewalld.conf") as f:
            default = parse_environment_vars(f.read())["DefaultZone"]
    except FileNotFoundError:
        # no configuration file: keep the "public" default
        pass

    try:
        zone = xml.etree.ElementTree.parse(f"{tree}/etc/firewalld/zones/{default}.xml")
    except FileNotFoundError:
        zone = xml.etree.ElementTree.parse(f"{tree}/usr/lib/firewalld/zones/{default}.xml")

    return [service.get("name") for service in zone.getroot().findall("service")]


def append_filesystem(report, tree, *, is_ostree=False):
    """Add information about the filesystem mounted at `tree` to `report`.

    `report` is a dict that is modified in place. What is collected depends
    on what is found in `tree`:

    - `etc/os-release` exists: packages, rpm verification (skipped when
      `is_ostree` is true), os-release, enabled/disabled services, passwd and
      groups are recorded; hostname, timezone, firewall services and fstab
      are recorded when the corresponding files exist. With `is_ostree`,
      `usr/lib/passwd` and `usr/lib/group` are recorded as well. If `boot` is
      a non-empty directory, the grub environment (when present) and the
      boot menu entries are added.
    - otherwise, `vmlinuz-*` files exist at the top of `tree`: the grub
      environment and the boot menu entries are read relative to `tree`.
    - otherwise, `EFI` exists: only a message is printed to stderr.
    """
    if os.path.exists(f"{tree}/etc/os-release"):
        report["packages"] = rpm_packages(tree, is_ostree)
        if not is_ostree:
            report["rpm-verify"] = rpm_verify(tree)

        with open(f"{tree}/etc/os-release") as f:
            report["os-release"] = parse_environment_vars(f.read())

        report["services-enabled"] = read_services(tree, "enabled")
        report["services-disabled"] = read_services(tree, "disabled")

        # the following items are optional: a missing file leaves the key unset
        try:
            with open(f"{tree}/etc/hostname") as f:
                report["hostname"] = f.read().strip()
        except FileNotFoundError:
            pass

        try:
            # only the last path component of the symlink target is recorded
            report["timezone"] = os.path.basename(os.readlink(f"{tree}/etc/localtime"))
        except FileNotFoundError:
            pass

        try:
            report["firewall-enabled"] = read_firewall_zone(tree)
        except FileNotFoundError:
            pass

        try:
            with open(f"{tree}/etc/fstab") as f:
                # skip empty lines and comments, split each entry into its fields
                report["fstab"] = sorted([line.split() for line in f.read().split("\n") if line and not line.startswith("#")])
        except FileNotFoundError:
            pass

        with open(f"{tree}/etc/passwd") as f:
            report["passwd"] = sorted(f.read().strip().split("\n"))

        with open(f"{tree}/etc/group") as f:
            report["groups"] = sorted(f.read().strip().split("\n"))

        if is_ostree:
            with open(f"{tree}/usr/lib/passwd") as f:
                report["passwd-system"] = sorted(f.read().strip().split("\n"))

            with open(f"{tree}/usr/lib/group") as f:
                report["groups-system"] = sorted(f.read().strip().split("\n"))

        if os.path.exists(f"{tree}/boot") and len(os.listdir(f"{tree}/boot")) > 0:
            # the boot menu must only be found in one place per report
            assert "bootmenu" not in report
            try:
                with open(f"{tree}/boot/grub2/grubenv") as f:
                    report["boot-environment"] = parse_environment_vars(f.read())
            except FileNotFoundError:
                pass
            report["bootmenu"] = read_boot_entries(f"{tree}/boot")

    elif len(glob.glob(f"{tree}/vmlinuz-*")) > 0:
        # kernels at the top level: presumably a separate boot partition
        assert "bootmenu" not in report
        with open(f"{tree}/grub2/grubenv") as f:
            report["boot-environment"] = parse_environment_vars(f.read())
        report["bootmenu"] = read_boot_entries(tree)
    elif len(glob.glob(f"{tree}/EFI")):
        # nothing is added to the report for an EFI partition
        print("EFI partition", file=sys.stderr)


def partition_is_esp(partition):
    """Return True if the partition's `type` is the EFI System Partition GUID."""
    esp_type = "C12A7328-F81F-11D2-BA4B-00A0C93EC93B"
    return partition["type"] == esp_type


def find_esp(partitions):
    """Find the first EFI System Partition in `partitions`.

    Returns a `(partition, index)` tuple, or `(None, 0)` if there is none.
    """
    for index, candidate in enumerate(partitions):
        if not partition_is_esp(candidate):
            continue
        return candidate, index
    return None, 0


def analyse_image(image):
    """Inspect the disk image `image` and return a report dict.

    The image is attached to a network block device. Its format, boot loader
    type and partition table are recorded. Every partition that has a
    filesystem is then mounted and inspected; if the image has an EFI System
    Partition and the mounted tree contains `boot/efi`, the ESP is mounted
    there first. An image without a partition table is mounted and inspected
    as a whole.
    """
    subprocess.run(["modprobe", "nbd"], check=True)

    report = {}
    with nbd_connect(image) as device:
        report["image-format"] = read_image_format(image)
        report["bootloader"] = read_bootloader_type(device)

        label, table_id, partitions = read_partition_table(device)
        report["partition-table"] = label
        report["partition-table-id"] = table_id
        report["partitions"] = partitions

        if not label:
            # no partition table: the device itself holds the filesystem
            with mount(device) as tree:
                append_filesystem(report, tree)
        else:
            esp, esp_id = find_esp(partitions)
            for n, partition in enumerate(partitions):
                if not partition["fstype"]:
                    continue
                # partition device nodes are numbered starting from 1
                with mount(device + f"p{n + 1}") as tree:
                    if esp and os.path.exists(f"{tree}/boot/efi"):
                        with mount_at(device + f"p{esp_id + 1}", f"{tree}/boot/efi", options=['umask=077']):
                            append_filesystem(report, tree)
                    else:
                        append_filesystem(report, tree)

    return report


def append_directory(report, tree):
    """Add information about the directory `tree` to `report`.

    If `tree` contains an `ostree` entry, `etc` is created when missing,
    `usr/etc` is bind-mounted over it and the tree is inspected as an ostree
    tree. Otherwise the tree is inspected as a regular filesystem.
    """
    if not os.path.lexists(f"{tree}/ostree"):
        append_filesystem(report, tree)
        return

    os.makedirs(f"{tree}/etc", exist_ok=True)
    with mount_at(f"{tree}/usr/etc", f"{tree}/etc", extra=["--bind"]):
        append_filesystem(report, tree, is_ostree=True)


def append_ostree_repo(report, repo):
    """Add information about the ostree repository `repo` to `report`.

    Records the repository's `core.mode` and, for every ref, the
    `rpmostree.inputhash` metadata of the commit it resolves to. The commit
    of the first ref is then checked out into a temporary directory below
    `/var/tmp` and inspected with `append_directory`.
    """
    ostree = functools.partial(run_ostree, repo=repo)

    core_mode = ostree("config", "get", "core.mode").stdout.strip()
    report["ostree"] = {
        "repo": {
            "core.mode": core_mode
        }
    }

    ref_names = ostree("refs").stdout.strip().split("\n")
    report["ostree"]["refs"] = ref_names

    # resolve every ref to its commit checksum
    resolved = {}
    for name in ref_names:
        resolved[name] = ostree("rev-parse", name).stdout.strip()
    commit = resolved[ref_names[0]]

    # replace the plain list of ref names with per-ref details
    details = {}
    for name in ref_names:
        shown = ostree("show", "--print-metadata-key=rpmostree.inputhash", resolved[name])
        details[name] = {"inputhash": shown.stdout.strip("'\n")}
    report["ostree"]["refs"] = details

    with tempfile.TemporaryDirectory(dir="/var/tmp") as tmpdir:
        tree = os.path.join(tmpdir, "tree")
        ostree("checkout", "--force-copy", commit, tree)
        append_directory(report, tree)


def analyse_directory(path):
    """Inspect the directory `path` and return a report dict.

    Three layouts are recognised:

    - `compose.json` exists: an ostree commit; the repository is expected in
      the `repo` subdirectory and the report type is "ostree/commit".
    - a `refs` subdirectory exists: `path` itself is an ostree repository and
      the report type is "ostree/repo".
    - anything else is inspected as a plain directory tree.
    """
    report = {}

    if os.path.exists(os.path.join(path, "compose.json")):
        report["type"] = "ostree/commit"
        append_ostree_repo(report, os.path.join(path, "repo"))
    elif os.path.isdir(os.path.join(path, "refs")):
        report["type"] = "ostree/repo"
        # the directory that contains `refs` is the repository itself
        append_ostree_repo(report, path)
    else:
        append_directory(report, path)

    return report


def is_tarball(path):
    """Return True if the MIME type guessed from `path` is `application/x-tar`."""
    return mimetypes.guess_type(path)[0] == "application/x-tar"


def analyse_tarball(path):
    """Extract the tarball `path` and inspect its contents.

    The archive is unpacked with tar(1) into a temporary directory below
    `/var/tmp`, with tar's standard output redirected to stderr. The
    extracted tree is passed to `analyse_directory` and its report returned.
    """
    with tempfile.TemporaryDirectory(dir="/var/tmp") as tmpdir:
        tree = os.path.join(tmpdir, "root")
        os.makedirs(tree)
        extract = ["tar", "-x", "--auto-compress", "-f", path, "-C", tree]
        subprocess.run(extract, stdout=sys.stderr, check=True)
        return analyse_directory(tree)


def main():
    """Parse the command line, analyse the target and print a JSON report.

    Directories, tarballs and everything else (treated as a disk image) are
    each handled by their own analyser. The report is written to stdout.
    """
    parser = argparse.ArgumentParser(description="Inspect an image")
    parser.add_argument("target",  help="The file or directory to analyse")
    target = parser.parse_args().target

    if os.path.isdir(target):
        analyse = analyse_directory
    elif is_tarball(target):
        analyse = analyse_tarball
    else:
        analyse = analyse_image

    json.dump(analyse(target), sys.stdout, sort_keys=True, indent=2)

# Only run the inspection when executed as a script, not when imported.
if __name__ == "__main__":
    main()

