#!/usr/bin/env python3

import os, sys, argparse, re, difflib, json

def jsonKeys2int(x):
    """JSON ``object_hook`` that converts the keys of a decoded object to ints.

    Values that are not dicts are handed back untouched. The conversion is
    all-or-nothing: as soon as one key is rejected by ``int()`` with a
    ValueError, the original mapping is returned unchanged.
    """
    if not isinstance(x, dict):
        return x

    converted = {}
    try:
        for key, value in x.items():
            converted[int(key)] = value
    except ValueError:
        # At least one key is not numeric: keep the mapping exactly as decoded.
        return x
    return converted

def symtypes_parse(path, data=None):
    """Parse a symtypes file into the dependency-graph dictionary.

    Each non-blank line is split on whitespace. The first token names the
    entry; every later token that is longer than two characters and either
    has ``#`` as its second character or equals ``UNKNOWN`` is recorded as a
    child (dependency) of that entry. Entries are attached to node 0, the
    ``(root)`` sentinel.

    Args:
        path: Path of the symtypes file to read.
        data: Optional existing graph to extend. A falsy value starts a new
            graph. The ``"file"`` table of the graph is reset on every call
            and afterwards only holds the lines of ``path``.

    Returns:
        The graph dict with keys ``children``/``parents`` (index -> list of
        indices), ``strtab`` (index -> name), ``index`` (name -> index) and
        ``file`` (basename -> {index: raw line}).
    """
    if not data:
        data = {
            'children': {0: []},
            'parents': {0: []},
            'strtab': ["(root)"],
            'index': {},
            'file': {},
        }

    bpath = os.path.basename(path)
    data["file"] = {bpath: {0: ""}}

    with open(path, 'r') as fp_ref:
        for line in fp_ref:
            # Whitespace splitting keeps the trailing newline from sticking
            # to the last token, which would otherwise register a bogus
            # "x#name\n" node or hide a trailing UNKNOWN.
            tokens = line.split()
            if not tokens:
                # Blank line: nothing to index.
                continue
            root = tokens[0]
            children = [
                tok for tok in tokens[1:]
                if len(tok) > 2 and (tok[1] == '#' or tok == "UNKNOWN")
            ]

            # An entry that already has children was defined earlier; the
            # first definition wins.
            known = data["index"].get(root)
            if known is not None and data["children"].get(known):
                continue

            # The raw line (newline included) is stored for later diffing.
            index = data_add(data, root, 0, bpath, line)
            for child in children:
                data_add(data, child, index, "", "")
    return data

def data_add(data, ident, parent, bpath, line):
    """Register ``ident`` in the graph and link it below ``parent``.

    Args:
        data: Graph dict with ``strtab``, ``index``, ``children``,
            ``parents`` and ``file`` tables; it is modified in place.
        ident: Name of the node to add or look up.
        parent: Index of the parent node; must already have a ``children``
            entry.
        bpath: Key into ``data['file']`` under which ``line`` is stored.
        line: Source line defining ``ident``. It is stored only when both
            ``bpath`` and ``line`` are non-empty.

    Returns:
        The index of ``ident`` in ``data['strtab']``.
    """
    # data['index'] mirrors strtab, so use it instead of scanning the list
    # (the scan made parsing quadratic in the number of names).
    index = data['index'].get(ident)
    if index is None and data['strtab'] and data['strtab'][0] == ident:
        # The sentinel in slot 0 is never entered into data['index'].
        index = 0
    if index is None:
        index = len(data['strtab'])
        data['strtab'].append(ident)
        data['index'][ident] = index

    # Child links are appended unconditionally (duplicates are kept), while
    # parent links are recorded at most once.
    data['children'][parent].append(index)
    data['children'].setdefault(index, [])
    parents = data['parents'].setdefault(index, [])
    if parent not in parents:
        parents.append(parent)

    if bpath and line:
        data['file'][bpath][index] = line
    return index

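# Legacy signature, kept for reference only:
#   def symtypes_dfs(data_a, source, sink, trace, inverse = False):
# Debugging aid that prints a path as a chain of node names:
#   print(' > '.join(list(map(lambda i: data_a['strtab'][i], path + [e]))))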
def symtypes_dfs(data, start, inverse=False, full=False):
    """Walk the graph depth-first from ``start``.

    Args:
        data: Graph dict with ``index``, ``children`` and ``parents`` tables.
        start: Name of the start node; must be a key of ``data["index"]``.
        inverse: Follow ``parents`` edges instead of ``children`` edges.
            Node 0 (the root sentinel) is never entered in this mode.
        full: When false, every node is expanded once. When true, every
            path that does not revisit a node already on that path is
            enumerated, so a node may appear at the end of several paths.

    Returns:
        A ``(visited, paths)`` tuple: the set of reached node indices and a
        list of index paths (each starting at the start node) in the order
        the nodes were visited.

    Raises:
        KeyError: If ``start`` is not present in ``data["index"]``.
    """
    start_i = data["index"][start]
    edges = data["parents"] if inverse else data["children"]
    stack = [(start_i, [start_i])]
    visited = set()
    paths = []
    while stack:
        node, path = stack.pop()
        if full:
            # Only cut cycles: skip a node already on the current path.
            if node in path[:-1]:
                continue
        elif node in visited:
            continue
        # Record the node in both modes so the returned set is meaningful
        # for full traversals too (it used to stay empty there).
        visited.add(node)
        paths.append(path)
        # Push in reverse so neighbours are popped in their stored order.
        for nxt in reversed(edges[node]):
            if inverse and nxt == 0:
                continue
            stack.append((nxt, path + [nxt]))

    return visited, paths

def st_open(path):
    """Load a dependency graph from a JSON index or a raw symtypes file.

    The file is first tried as JSON (decoded objects go through
    ``jsonKeys2int``); if that fails with a ValueError it is parsed as a
    symtypes file instead.

    Args:
        path: Path of the index or symtypes file.

    Returns:
        The decoded JSON value, or the graph built by ``symtypes_parse``.

    Raises:
        ValueError: If ``path`` is empty or None.
        OSError: If ``path`` does not exist.
    """
    if not path:
        raise ValueError("Input path must not be blank.")
    if not os.path.exists(path):
        raise OSError(f"Path {path} does not exist.")
    with open(path, "r") as f:
        try:
            return json.load(f, object_hook=jsonKeys2int)
        except ValueError:
            # Not JSON (json.JSONDecodeError is a ValueError): fall through
            # and treat the file as raw symtypes.
            pass
    return symtypes_parse(path)

def st_write(path, data):
    """Serialise the graph to ``path`` as JSON, leaving out the ``"file"`` table.

    Args:
        path: Destination file; it is created or truncated.
        data: Graph dict to store. It is not modified: the ``"file"`` entry
            is omitted from the output only, so the caller keeps its
            per-line source table.
    """
    # Filter into a new dict instead of deleting from the caller's object.
    payload = {key: value for key, value in data.items() if key != "file"}
    with open(path, "w") as f:
        json.dump(payload, f)

def index(symtype, output):
    """Load the graph for ``symtype`` and optionally save it as an index.

    Args:
        symtype: Path handed to ``st_open``.
        output: Path handed to ``st_write``; a falsy value skips writing.

    Returns:
        The loaded graph.
    """
    graph = st_open(symtype)
    if not output:
        return graph
    st_write(output, graph)
    return graph

def st_print(node):
    """Return a readable form of a symtypes node name.

    Names of the form ``<c>#<name>`` with a known one-letter prefix are
    expanded (``s#foo`` -> ``struct foo``). Everything else, including
    names shorter than two characters and unknown prefixes, is returned
    unchanged.

    Args:
        node: Node name as stored in the string table.

    Returns:
        The expanded name, or ``node`` itself.
    """
    # The length guard avoids an IndexError on empty or one-character names.
    if len(node) < 2 or node[1] != '#':
        return node
    keywords = {
        's': "struct ",
        't': "typedef ",
        'E': "enum const ",
        'e': "enum ",
        'u': "union ",
    }
    keyword = keywords.get(node[0])
    if keyword is None:
        return node
    return keyword + node[2:]

def im(file, dump_list, dump_path, dump_tree, start, inverse, silent):
    """Print the nodes reachable from ``start`` in the graph loaded from ``file``.

    Args:
        file: Symtypes or index file to load.
        dump_list: Print every reached node on its own line.
        dump_path: Print each traversal path as a list of node names.
        dump_tree: Print each reached node indented by its depth; takes
            precedence over ``dump_path``.
        start: Name of the start node.
        inverse: Walk parent links instead of child links.
        silent: Suppress the "not found" message.

    Exits the process with status 1 when ``start`` is not in the graph.
    """
    data = index(file, None)

    if start not in data["index"]:
        if not silent:
            print(f"Node {start} not found in file {file}. Exitting.")
        sys.exit(1)

    nodes, paths = symtypes_dfs(data, start, inverse)

    if dump_list:
        for node_index in nodes:
            node = data['strtab'][node_index]
            print(f"{st_print(node)} (symtype node: {node})")

    if not (dump_path or dump_tree):
        return

    for path in paths:
        if dump_tree:
            # One line per node, indented two spaces per level below start.
            leaf = data['strtab'][path[-1]]
            indent = (len(path) - 1) * "  "
            print(indent + " - " + f"{st_print(leaf)} (symtype node: {leaf})")
        else:
            # Reaching this point without dump_tree implies dump_path.
            print([data["strtab"][i] for i in path])

def diff(ref, new, start):
    """Report differences between two symtypes files below node ``start``.

    Both files are loaded and walked from ``start`` along child links.
    Nodes reached in only one of the files are listed, and for every node
    reached in both whose set of children differs and whose defining line
    differs, an ndiff of the two lines is printed. Output is sorted by node
    name so repeated runs print in the same order.

    The per-node source lines come from the ``"file"`` table of the loaded
    graphs, which JSON indexes written by this tool do not contain; use raw
    symtypes files as inputs.

    Args:
        ref: Path of the reference symtypes file.
        new: Path of the new symtypes file.
        start: Name of the node to start from.

    Exits the process with status 1 when ``start`` is missing from either
    file.
    """
    data_ref = index(ref, None)
    data_new = index(new, None)

    # Fail with a message instead of a KeyError traceback when the start
    # node is absent (or no start node was given at all).
    for path, data in ((ref, data_ref), (new, data_new)):
        if start not in data["index"]:
            print(f"Node {start} not found in file {path}.", file=sys.stderr)
            sys.exit(1)

    nodes_ref, _ = symtypes_dfs(data_ref, start)
    nodes_new, _ = symtypes_dfs(data_new, start)

    names_ref = {data_ref['strtab'][i] for i in nodes_ref}
    names_new = {data_new['strtab'][i] for i in nodes_new}
    only_new = names_new - names_ref
    only_ref = names_ref - names_new

    if only_ref:
        print("The following nodes were encountered only in reference symtypes:")
        print("\t" + "\n\t".join(sorted(only_ref)))

    if only_new:
        print("The following nodes were encountered only in new symtypes:")
        print("\t" + "\n\t".join(sorted(only_new)))

    bpath_a = os.path.basename(ref)
    bpath_b = os.path.basename(new)
    for node in sorted(names_ref & names_new):
        idx_a = data_ref['index'][node]
        idx_b = data_new['index'][node]
        r = {data_ref['strtab'][i] for i in data_ref['children'][idx_a]}
        n = {data_new['strtab'][i] for i in data_new['children'][idx_b]}

        if r == n:
            continue

        line_ref = "\t" + data_ref['file'][bpath_a][idx_a]
        line_new = "\t" + data_new['file'][bpath_b][idx_b]
        if line_ref == line_new:
            continue

        print(f"Possible breakage detected for {st_print(node)} (symtype node: {node}) ...")
        if len(n) == 1 and "UNKNOWN" in n:
            print("\treplaced by UNKNOWN. Please inspect changes to #include directives")
        if len(r) == 1 and "UNKNOWN" in r:
            print("\tUNKNOWN got replaced. Please inspect changes to #include directives")
        # The stored lines keep their trailing newline, hence end="".
        delta = difflib.ndiff([line_ref], [line_new])
        print(''.join(delta), end="")


if __name__ == "__main__":
    # Command-line entry point: one sub-command per mode of operation.
    parser = argparse.ArgumentParser()

    subparsers = parser.add_subparsers(help='Modes of operation.',
            dest="mode")
    parser_index = subparsers.add_parser('index',
            help='Calculate symtypes index.')
    parser_image = subparsers.add_parser('image',
            help='Show type/symbol dependencies.')
    parser_preimage = subparsers.add_parser('preimage',
            help='Show type/symbol preimage.')
    parser_df = subparsers.add_parser('diff',
            help='Calculate simple symtype diff.')

    parser_index.add_argument('-o', '--output', type=str, required=True,
            help='Output index file.')
    parser_index.add_argument('symtype', type=str)

    # 'image' and 'preimage' take identical options; they differ only in
    # the direction of the walk.
    for p in [ parser_image, parser_preimage ]:
        # NOTE: --index is accepted but its value is not read anywhere below.
        p.add_argument('-i', '--index', action='store_true',
                help="Input is an index file.")
        p.add_argument('-S', '--silent', action='store_true')
        p.add_argument('-l', '--ls', action='store_true',
                help="List dependent nodes.")
        p.add_argument('-p', '--path', action='store_true',
                help="List paths to dependent nodes.")
        p.add_argument('-t', '--tree', action='store_true',
                help="Dump tree.")
        p.add_argument('-s', '--start', type=str, nargs='?',
                help="Start symtype entry/entries.")
        p.add_argument('symtype', type=str)

    parser_df.add_argument('reference', type=str)
    parser_df.add_argument('new', type=str)
    parser_df.add_argument('-s', '--start', type=str, nargs='?',
                help="Start symtype entry/entries.")

    args = parser.parse_args()

    if args.mode == "index":
        index(args.symtype, args.output)
    elif args.mode in ("image", "preimage"):
        im(args.symtype, args.ls, args.path, args.tree, args.start, args.mode == "preimage", args.silent)
    elif args.mode == "diff":
        diff(args.reference, args.new, args.start)
    else:
        # Sub-commands are optional for argparse, so args.mode is None when
        # none was given; report that instead of silently doing nothing.
        parser.error("no mode of operation given")
