#!/usr/bin/env python
import os, sys
from optparse import OptionParser
from collections import defaultdict
import re
import pydot

from heapq import heappush, heappop

# Command-line interface. The single positional argument is the output file;
# its extension picks the pydot writer used at the end of the script.
usage = "usage: %prog [options] output_file.png"
description = ("Go through the include dir, find include relationships from \n"
               "all files and render them as a directed graph.")
parser = OptionParser(usage=usage, description=description)

# Where to look for files.
parser.add_option("-i", "--include", dest="include_dir", metavar="DIR",
                  help="parse includes from DIR [default: %default]", default=".")
parser.add_option("-a", "--additional", dest="additional_dir", metavar="DIR",
                  help="parse additional C files from DIR", default=None)

# Which nodes end up in the graph. The example in the help text refers to
# spin.h and xlog.h, matching the SPEC shown right before it.
parser.add_option("-s", "--select", dest="select", metavar="SPEC",
                  help="Select nodes by SPEC. Spec is comma separated "
                       "list of header names, followed by optional traverse "
                       "operator. Traverse operator is +N to select N levels "
                       "up, -N to select N levels down. +* and -* respectively "
                       "to select all upstream or downstream dependencies. For "
                       "example \"storage/spin.h+*,access/xlog.h+*\" selects "
                       "all headers that directly or indirectly depend on "
                       "spin.h or xlog.h. Header names may contain * as a "
                       "wildcard.")
parser.add_option("-x", "--exclude", dest="exclude", metavar="SPEC",
                  help="exclude nodes by SPEC, see select for details", default=None)
parser.add_option("-p", "--platform", dest="platform", action="store_true",
                  help="show list of platform includes", default=False)

# What to produce.
parser.add_option("-l", "--list", dest="list", action="store_true",
                  help="show list of found includes", default=False)
parser.add_option("-z", "--subgraphs", dest="subgraphs", action="store_true",
                  help="combine headers from one subdir as a subgraphs", default=False)

(options, args) = parser.parse_args()

# Find all files in include dir, stored as paths relative to it.
# os.path.relpath is used instead of slicing off len(include_dir)+1
# characters, which drops one character too many when the directory is a
# filesystem root (its absolute path already ends with the separator).
include_dir = os.path.abspath(options.include_dir)
header_files = []
for root, _dirs, files in os.walk(include_dir):
    header_files.extend(
        os.path.relpath(os.path.join(root, name), include_dir)
        for name in files)

# Optionally collect '.c' files from the additional dir, relative to it.
# ad_dir is always bound so later code can refer to it unconditionally.
c_files = []
ad_dir = None

if options.additional_dir:
    ad_dir = os.path.abspath(options.additional_dir)
    for root, _dirs, files in os.walk(ad_dir):
        c_files.extend(
            os.path.relpath(os.path.join(root, name), ad_dir)
            for name in files if name.endswith('.c'))

# includes:         scanned file (relative path) -> set of names it includes
# reverse_includes: included name -> set of scanned files that include it
includes = defaultdict(set)
reverse_includes = defaultdict(set)

# Raw string so that \S is a regex escape rather than an (invalid) string
# escape. The pattern is used with match(), so only lines that start with
# '#include ' followed by a non-blank token are recognised.
include_regexp = re.compile(r'#include (\S+)')

# Header files and the optional C files are scanned the same way; only the
# directory their relative paths are resolved against differs.
scan_roots = [(include_dir, header_files)]
if c_files:
    # c_files is only non-empty when the additional dir (ad_dir) was given.
    scan_roots.append((ad_dir, c_files))

for base_dir, rel_paths in scan_roots:
    for header in rel_paths:
        with open(os.path.join(base_dir, header)) as f:
            for line in f:
                match = include_regexp.match(line)
                if match:
                    # Remove double quotes to normalize header name
                    # (angle brackets of <...> includes are kept).
                    include_name = match.group(1).replace('"', '')
                    includes[header].add(include_name)
                    reverse_includes[include_name].add(header)


# Every name seen, either as an including file or as an include target.
all_names = set(includes) | set(reverse_includes)

# Level count used when a select spec asks for '*' (all levels).
ALL_LEVELS = 100000

def traverse(workheap, deps):
    """Collect the nodes reachable from the heap entries within their level budget.

    workheap -- heap (as used by heapq) of (counter, node) tuples. The counter
                is the negated number of levels still allowed; it is
                incremented for every step taken and a node whose counter is
                0 is not expanded any further. The heap is consumed.
    deps     -- mapping from a node to the iterable of its neighbours.

    Returns the set of nodes reached through at least one edge. The starting
    nodes themselves are only part of the result if some edge leads to them.
    """
    reached = set()
    while workheap:
        remaining, node = heappop(workheap)
        if not remaining:
            # Level budget used up: do not look beyond this node.
            continue
        for neighbour in deps[node]:
            if neighbour in reached:
                continue
            reached.add(neighbour)
            heappush(workheap, (remaining + 1, neighbour))
    return reached

def evaluate_spec(specs):
    """Return the set of names selected by a comma separated list of specs.

    Each spec is a header name optionally followed by '+N' and/or '-N' (in
    that order), where N is a number of levels or '*' for ALL_LEVELS.
    '+' follows reverse_includes (files that include the header), '-' follows
    includes (names the header includes). A '*' inside the header name is a
    wildcard that is matched against all_names; every other character of the
    name is matched literally.

    The named (or wildcard-matched) headers are always part of the result,
    together with everything reached by the requested traversals.

    Raises ValueError if a spec cannot be parsed.
    """
    workheap_up = []
    workheap_down = []

    selected = set()

    for spec in specs.split(","):
        match = re.match(r'(.*?)(\+([0-9]+|\*))?(-([0-9]+|\*))?$', spec)
        if not match:
            raise ValueError("Invalid select spec: %r" % spec)

        # Level budgets are the same for every header matched by this spec.
        up_levels = down_levels = None
        if match.group(2):
            n = match.group(3)
            up_levels = ALL_LEVELS if n == "*" else int(n)
        if match.group(4):
            n = match.group(5)
            down_levels = ALL_LEVELS if n == "*" else int(n)

        name_pattern = match.group(1)
        if '*' in name_pattern:
            # Escape the literal pieces so that characters such as '.' or '+'
            # in a header name only match themselves; only '*' is a wildcard.
            h_regexp = re.compile(
                '.*'.join(re.escape(piece) for piece in name_pattern.split('*'))
                + '$')
            header_names = [name for name in all_names if h_regexp.match(name)]
        else:
            header_names = [name_pattern]

        for header_name in header_names:
            selected.add(header_name)
            # Counters are stored negated; traverse() counts them up to 0.
            if up_levels is not None:
                heappush(workheap_up, (-up_levels, header_name))
            if down_levels is not None:
                heappush(workheap_down, (-down_levels, header_name))

    selected.update(traverse(workheap_up, reverse_includes))
    selected.update(traverse(workheap_down, includes))
    return selected

# Start from the --select spec if one was given, otherwise from every name
# that includes something or is included by something.
selected = (evaluate_spec(options.select) if options.select
            else set(includes) | set(reverse_includes))

# Drop whatever the --exclude spec selects.
if options.exclude:
    selected.difference_update(evaluate_spec(options.exclude))

# Names starting with '<' (angle-bracket includes) are only kept with --platform.
if not options.platform:
    selected = set(name for name in selected if not name.startswith('<'))

# --list prints the selection, one name per line, and stops here.
if options.list:
    listing = "\n".join(sorted(selected))
    print(listing)
    sys.exit(0)

def normalize_name(name):
    """Return name with '<' and '>' replaced by '&lt;' and '&gt;'."""
    escaped = name
    for char, entity in (('<', '&lt;'), ('>', '&gt;')):
        escaped = escaped.replace(char, entity)
    return escaped

# Everything below renders the graph, so the output file argument is
# required; check it before doing any further work.
if not args:
    parser.print_usage()
    sys.exit(1)

# (source, destination) pairs for includes where both ends are selected.
# Names are passed through normalize_name(). The loop variable is named
# dest_headers so that it does not rebind the module-level `includes`.
include_edges = [(normalize_name(src_header), normalize_name(dest_header))
                 for src_header, dest_headers in includes.items()
                 if src_header in selected
                 for dest_header in dest_headers
                 if dest_header in selected]

graph = pydot.Dot(graph_type='digraph')
graph.set_rankdir("LR") # top-to-bottom rank gets too wide
graph.set_dpi("72")
graph.set_ranksep("1.5")

# nodes: selected name -> pydot.Node, filled in further down.
nodes = {}
subgraphs = {}

# Attributes shared by every node.
node_args = dict(shape='rectangle', fontsize='12.0', margin="0.05,0.01")

def split_package(name):
    """Split name at its last '/' into a (package, basename) tuple.

    Names that start with '<' and names without any '/' are not split: they
    are returned whole with an empty package.
    """
    package, slash, basename = name.rpartition('/')
    if name.startswith('<') or not slash:
        return '', name
    return package, basename

# Fill colours, one per package (the part of a name before its last '/').
# Entries are '/scheme/index' strings handed to the node's fillcolor.
colors = {}
color_list = []
# range() instead of xrange() so this runs on Python 2 and Python 3 alike.
color_list.extend('/set312/%d' % i for i in range(1,13))
color_list.extend('/greys4/%d' % i for i in range(1,5))
color_list.extend('/paired12/%d' % i for i in range(1,13))
color_list.extend('/dark28/%d' % i for i in range(1,9))

# include_edges refers to nodes by their normalize_name()d name, so nodes are
# indexed by that name as well; `nodes` keeps the raw name for split_package().
edge_nodes = {}

# Sorted so that colour assignment does not depend on set iteration order.
for token in sorted(selected):
    parts = token.rsplit('/', 1)
    package = parts[0] if len(parts) == 2 else ''
    if package not in colors:
        # Wrap around instead of failing once there are more packages than
        # palette entries; the first len(color_list) packages get distinct colours.
        colors[package] = color_list[len(colors) % len(color_list)]
    # The escaped name is used for the pydot name and label, presumably so a
    # '<...>' include is not taken for an HTML-like string -- TODO confirm.
    safe_name = normalize_name(token)
    node = pydot.Node(safe_name, label=safe_name, style="filled",
                      fillcolor=colors[package], **node_args)
    nodes[token] = node
    edge_nodes[safe_name] = node

if options.subgraphs:
    # Group nodes by package; inside a cluster the label is just the basename.
    packages = defaultdict(list)
    for node_name, node in nodes.items():
        package, label_name = split_package(node_name)
        node.set_label(normalize_name(label_name))
        packages[package].append(node)
    for cluster_idx, (package_name, pkg_nodes) in enumerate(sorted(packages.items())):
        if not package_name:
            # Nodes without a package go directly into the top-level graph.
            subgraph = graph
        else:
            subgraph = pydot.Subgraph('cluster_%d' % cluster_idx, label=package_name)
            graph.add_subgraph(subgraph)
        for node in pkg_nodes:
            subgraph.add_node(node)
else:
    for node in nodes.values():
        graph.add_node(node)

for src, dest in include_edges:
    graph.add_edge(pydot.Edge(edge_nodes[src], edge_nodes[dest]))

# Use the writer named after the output file's extension if the graph object
# has one (write_png, write_svg, ...); otherwise fall back to graph.write().
ext = args[0].rsplit('.',1)[-1]
if hasattr(graph, 'write_'+ext):
    getattr(graph, 'write_'+ext)(args[0])
else:
    graph.write(args[0])
