#!/usr/bin/env python3

""" Create a JSON allowlist/policy from given files as input """

import argparse
import binascii
import json
import sys

from cryptography import x509
from cryptography.hazmat import backends

# Default IMA measurement list path; used when -m/--ima-measurement-list is not given.
IMA_MEASUREMENT_LIST = "/sys/kernel/security/ima/ascii_runtime_measurements"
# Default value for -i/--ignored-keyrings: no keyring is ignored unless requested.
IGNORED_KEYRINGS = []


def is_x509_cert(bindata):
    """Tell whether ``bindata`` can be loaded as a DER x509 certificate.

    Returns True when the loader accepts the data and False when it rejects
    it with a ValueError.
    """
    try:
        backend = backends.default_backend()
        x509.load_der_x509_certificate(bindata, backend=backend)
    except ValueError:
        # The loader could not parse the data as a certificate.
        return False
    return True


def process_flat_allowlist(allowlist_file, hashes_map):
    """Merge the entries of a flat allowlist file into ``hashes_map``.

    Every non-blank line is expected to look like ``<checksum> <path>``. The
    line is split on the first run of whitespace only, so the path part may
    itself contain spaces. Lines that do not yield exactly two parts are
    reported and skipped.

    Args:
        allowlist_file: Path of the plain-text allowlist to read.
        hashes_map: Dict mapping a path to its list of checksums; it is
            updated in place.

    Returns:
        A ``(hashes_map, ret)`` tuple where ``ret`` is 0 on success and 1 when
        the file could not be opened (missing file or permission denied).
    """
    ret = 0
    try:
        with open(allowlist_file, "r") as fobj:
            for raw_line in fobj:
                line = raw_line.strip()
                if not line:
                    continue
                pieces = line.split(None, 1)
                if len(pieces) != 2:
                    print(f"Skipping line that was split into {len(pieces)} parts, expected 2: {line}")
                    # Without this the unpacking below would raise ValueError.
                    continue
                checksum_hash, path = pieces
                hashes_map.setdefault(path, []).append(checksum_hash)
    except (PermissionError, FileNotFoundError) as ex:
        print(f"An error occurred while accessing the allowlist: {ex}")
        ret = 1
    return hashes_map, ret


def get_hashes_from_measurement_list(ima_measurement_list_file, hashes_map):
    """Collect file hashes from an IMA measurement list into ``hashes_map``.

    Lines are split on single spaces. Only entries whose third field is
    ``ima-sig`` or ``ima-ng`` are used: the fourth field is expected to look
    like ``<algorithm>:<digest>`` and the fifth field is taken as the path.
    Lines with fewer than five fields or a digest field without ``:`` are
    reported and skipped.

    Args:
        ima_measurement_list_file: Path of the measurement list to read.
        hashes_map: Dict mapping a path to its list of digests; it is updated
            in place.

    Returns:
        A ``(hashes_map, ret)`` tuple where ``ret`` is 0 on success and 1 when
        the file could not be opened (missing file or permission denied).
    """
    ret = 0
    try:
        with open(ima_measurement_list_file, "r") as fobj:
            for raw_line in fobj:
                # Drop the line terminator first; otherwise it ends up glued
                # to the path of every five-field (ima-ng) entry.
                line = raw_line.rstrip("\n")
                pieces = line.split(" ")
                if len(pieces) < 5:
                    print(f"Skipping line that was split into {len(pieces)} pieces, expected at least 5: {line}")
                    continue
                if pieces[2] not in ("ima-sig", "ima-ng"):
                    continue
                digest_fields = pieces[3].split(":")
                if len(digest_fields) < 2:
                    print(f"Skipping line with malformed digest field {pieces[3]!r}: {line}")
                    continue
                # FIXME: filenames with spaces may be problematic
                path = pieces[4]
                hashes_map.setdefault(path, []).append(digest_fields[1])
    except (PermissionError, FileNotFoundError) as ex:
        print(f"An error occurred: {ex}")
        ret = 1
    return hashes_map, ret


def process_ima_buf_in_measurement_list(
    ima_measurement_list_file, ignored_keyrings, get_keyrings, keyrings_map, get_ima_buf, ima_buf_map
):
    """Sort the ``ima-buf`` entries of a measurement list into two maps.

    Lines are split on single spaces and must have exactly six fields; other
    lines are reported and skipped. Only entries whose third field is
    ``ima-buf`` are used. The sixth field is hex-decoded; when the decoded
    data loads as an x509 certificate the entry is treated as key-related and
    can only go to ``keyrings_map``, never to ``ima_buf_map``. Every other
    entry goes to ``ima_buf_map``.

    Args:
        ima_measurement_list_file: Path of the measurement list to read.
        ignored_keyrings: Names (fifth field) whose key-related entries are
            dropped.
        get_keyrings: When false, key-related entries are dropped.
        keyrings_map: Dict mapping a name to its list of digests; updated in
            place.
        get_ima_buf: When false, non-key entries are dropped.
        ima_buf_map: Dict mapping a name to its list of digests; updated in
            place.

    Returns:
        A ``(keyrings_map, ima_buf_map, ret)`` tuple where ``ret`` is 0 on
        success and 1 when the file could not be opened (missing file or
        permission denied).
    """
    ret = 0
    try:
        with open(ima_measurement_list_file, "r") as fobj:
            for raw_line in fobj:
                line = raw_line.rstrip("\n")
                # FIXME: filenames with spaces may be problematic
                pieces = line.split(" ")
                if len(pieces) != 6:
                    print(f"Skipping line that was split into {len(pieces)} pieces, expected 6: {line}")
                    continue
                if pieces[2] != "ima-buf":
                    continue
                digest_fields = pieces[3].split(":")
                if len(digest_fields) < 2:
                    print(f"Skipping line with malformed digest field {pieces[3]!r}: {line}")
                    continue
                checksum_hash = digest_fields[1]
                path = pieces[4]

                try:
                    bindata = binascii.unhexlify(pieces[5].strip())
                except ValueError:
                    # binascii.Error (bad hex digits, odd length) is a
                    # ValueError subclass; a plain ValueError is raised for
                    # non-ASCII text. Either way the data cannot be a
                    # certificate, so handle it as an ordinary ima-buf entry.
                    bindata = None

                # check whether buf's bindata contains a key; if so, we will only
                # append it to 'keyrings', never to 'ima-buf'
                if bindata and is_x509_cert(bindata):
                    if get_keyrings and path not in ignored_keyrings:
                        keyrings_map.setdefault(path, []).append(checksum_hash)
                    continue

                if get_ima_buf:
                    ima_buf_map.setdefault(path, []).append(checksum_hash)
    except (PermissionError, FileNotFoundError) as ex:
        print(f"An error occurred: {ex}")
        ret = 1
    return keyrings_map, ima_buf_map, ret


def main(argv):
    """Build a JSON policy from the command line in ``argv`` and emit it.

    Steps, in order:

    1. Optionally load a base policy (``-B``) and take over its ``hashes``,
       ``keyrings``, ``ima-buf`` and ignored keyrings. Keyrings given with
       ``-i`` are added to the base policy's ignored keyrings rather than
       being discarded.
    2. Add file hashes from the flat allowlist (``-a``) if one is given;
       otherwise from the IMA measurement list unless ``--no-hashes`` is set.
    3. If ``-k`` and/or ``-b`` is given, add keyrings / ima-buf entries from
       the IMA measurement list.
    4. Write the policy as JSON to ``-o`` or, by default, print it to stdout.

    Args:
        argv: Full argument vector; ``argv[0]`` is used as the program name.

    The function always terminates the process through ``sys.exit()``: status
    0 on success, 1 when any input could not be read or the output could not
    be written.
    """
    parser = argparse.ArgumentParser(argv[0])
    parser.add_argument(
        "-B",
        "--base-policy",
        action="store",
        dest="base_policy",
        help='Use given JSON policy as a "base" and merge new data into it',
    )
    parser.add_argument(
        "-k", "--keyrings", action="store_true", dest="get_keyrings", help="Create keyrings policy entries"
    )
    parser.add_argument(
        "-b",
        "--ima-buf",
        action="store_true",
        dest="get_ima_buf",
        help="Process ima-buf entries other than those related to keyrings",
    )
    parser.add_argument("-a", "--allowlist", action="store", dest="allowlist", help="Use given plain-text allowlist")
    parser.add_argument(
        "-m",
        "--ima-measurement-list",
        action="store",
        dest="ima_measurement_list",
        default=IMA_MEASUREMENT_LIST,
        help="Use given IMA measurement list for keyrings and critical "
        f"data extraction rather than {IMA_MEASUREMENT_LIST}",
    )
    parser.add_argument(
        "-i",
        "--ignored-keyrings",
        action="append",
        dest="ignored_keyrings",
        default=IGNORED_KEYRINGS,
        help="Ignored the given keyring; this option may be passed multiple times",
    )
    parser.add_argument(
        "-o",
        "--output",
        action="store",
        dest="output",
        help="File to write JSON policy into; default is to print to stdout",
    )
    parser.add_argument(
        "--no-hashes", action="store_true", dest="no_hashes", help="Do not add any hashes to the policy"
    )
    args = parser.parse_args(argv[1:])

    policy = {
        "meta": {"version": 4},
        "release": 0,
        "hashes": {},
        "keyrings": {},
        "ima-buf": {},
        "ima": {
            # Copy so the policy never aliases the module-level default list.
            "ignored_keyrings": list(args.ignored_keyrings),
        },
    }

    ret = 0

    if args.base_policy:
        try:
            with open(args.base_policy, "r") as fobj:
                base_policy = json.load(fobj)
        except (PermissionError, FileNotFoundError) as ex:
            print(f"An error occurred while loading the policy: {ex}")
            sys.exit(1)
        except json.decoder.JSONDecodeError as ex:
            print(f"An error occurred while converting the policy to a JSON object: {ex}")
            sys.exit(1)

        if not isinstance(base_policy, dict):
            # Valid JSON, but e.g. a list or a string: there is nothing to merge from.
            print("An error occurred while loading the policy: top-level JSON value is not an object")
            sys.exit(1)

        # Cherry-pick from base policy what is supported and merge into policy
        policy["hashes"] = base_policy.get("hashes", {})
        policy["keyrings"] = base_policy.get("keyrings", {})
        policy["ima-buf"] = base_policy.get("ima-buf", {})

        # Start from the base policy's ignored keyrings and add the ones from
        # the command line, so that -i is honoured together with -B.
        ignored_keyrings = list(base_policy.get("ima", {}).get("ignored_keyrings", []))
        for keyring in args.ignored_keyrings:
            if keyring not in ignored_keyrings:
                ignored_keyrings.append(keyring)
        policy["ima"]["ignored_keyrings"] = ignored_keyrings

    # Add the hashes map either from the allowlist, if given, or the IMA measurement list.
    if args.allowlist:
        policy["hashes"], ret = process_flat_allowlist(args.allowlist, policy["hashes"])
    elif not args.no_hashes:
        policy["hashes"], ret = get_hashes_from_measurement_list(args.ima_measurement_list, policy["hashes"])
    if ret:
        sys.exit(ret)

    if args.get_keyrings or args.get_ima_buf:
        policy["keyrings"], policy["ima-buf"], ret = process_ima_buf_in_measurement_list(
            args.ima_measurement_list,
            policy["ima"]["ignored_keyrings"],
            args.get_keyrings,
            policy["keyrings"],
            args.get_ima_buf,
            policy["ima-buf"],
        )
    if ret:
        sys.exit(ret)

    jsonpolicy = json.dumps(policy)
    if args.output:
        try:
            with open(args.output, "w") as fobj:
                fobj.write(jsonpolicy)
        except (PermissionError, FileNotFoundError) as ex:
            print(f"An error occurred while writing the policy: {ex}")
            ret = 1
    else:
        print(jsonpolicy)

    sys.exit(ret)


# Script entry point: run main() with the full argument vector (main() exits
# the process itself); nothing runs when the module is merely imported.
if __name__ == "__main__":
    main(sys.argv)
