#!/usr/bin/env python
import os, sys
from optparse import OptionParser
from collections import defaultdict
import re
import pydot

from heapq import heappush, heappop

# Command-line interface.  Parsing happens at module level because this file
# is a flat script: `options` and `args` are read by the code further down.
usage = "usage: %prog [options] output_file.png"
# Built from adjacent literals so that no stray whitespace from the source
# layout ends up in the rendered --help text.
description = ("Go through the include dir, find include relationships from "
               "all files and render them as a directed graph.")
parser = OptionParser(usage=usage, description=description)
parser.add_option("-i", "--include", dest="include_dir", metavar="DIR",
                  help="parse includes from DIR [default: %default]", default=".")
# "+" walks towards the files that include the named header, "-" walks
# towards the files the named header includes.
parser.add_option("-s", "--select", dest="select", metavar="SPEC",
                  help="Select nodes by SPEC. Spec is comma separated "
                       "list of header names, followed by optional traverse "
                       "operator. Traverse operator is +N to select N levels "
                       "up, -N to select N levels down. +* and -* respectively "
                       "to select all upstream or downstream dependencies. For "
                       "example \"storage/spin.h+*,access/xlog.h+*\" selects "
                       "all headers that directly or indirectly depend on "
                       "spin.h or xlog.h.")
parser.add_option("-l", "--list", dest="list", action="store_true",
                  help="show list of found includes", default=False)
# Without this flag every name starting with '<' is dropped from the output.
parser.add_option("-p", "--platform", dest="platform", action="store_true",
                  help="include platform headers (#include <...>) in the "
                       "output; they are omitted by default", default=False)

(options, args) = parser.parse_args()

# Find all files in include dir.  Names are stored relative to include_dir
# and with forward slashes, which is how they are spelled inside
# `#include "dir/file.h"` directives.
include_dir = os.path.abspath(options.include_dir)
header_files = []
for root, _dirs, files in os.walk(include_dir):
    for header in files:
        # relpath (instead of slicing off len(include_dir)+1 characters) stays
        # correct when include_dir is a filesystem root such as "/".
        rel_path = os.path.relpath(os.path.join(root, header), include_dir)
        header_files.append(rel_path.replace(os.sep, '/'))

# includes[X] is the set of names that file X includes;
# reverse_includes[Y] is the set of files that include the name Y.
includes = defaultdict(set)
reverse_includes = defaultdict(set)
# Raw string, so the backslash escapes reach the regex engine untouched.
# Accepts optional whitespace before '#', between '#' and 'include', and
# before the file name.  A quoted or angle-bracketed name is captured up to
# its closing delimiter (trailing comments are ignored); any other token is
# captured only when separated from 'include' by whitespace, so directives
# such as `#include_next` are not mistaken for `#include`.
include_regexp = re.compile(
    r'\s*#\s*include\s*("[^"]*"|<[^>]*>|(?<=\s)\S+)')

for header in header_files:
    with open(os.path.join(include_dir, header)) as f:
        for line in f:
            match = include_regexp.match(line)
            if match:
                # Remove double quotes to normalize header name; angle
                # brackets are kept so platform headers stay recognizable.
                include_name = match.group(1).replace('"', '')
                includes[header].add(include_name)
                reverse_includes[include_name].add(header)

# Depth used for the "*" traverse operator; large enough to act as unlimited.
ALL_LEVELS = 100000

def traverse(workheap, deps):
    """Collect every node reachable from the seeds within their level budget.

    workheap -- heap (as maintained by heapq) of (n_left, name) tuples, where
                n_left is the NEGATED number of levels still allowed from
                that node (0 means "select the node but do not expand it").
                The heap is consumed: it is empty when the function returns.
    deps     -- mapping from a node name to an iterable of neighbour names.
                Names missing from the mapping are treated as having no
                neighbours; the mapping is never modified.

    Returns the set of visited node names, seeds included.
    """
    visited = set()
    while workheap:
        n_left, src = heappop(workheap)
        # Entries come off the heap with the largest remaining budget first,
        # so the first time a node is popped it already has the best budget
        # it will ever get.  Later duplicates can add nothing; skipping them
        # keeps the work linear instead of re-expanding shared dependencies.
        if src in visited:
            continue
        visited.add(src)
        if n_left:
            # .get() rather than deps[src]: indexing a defaultdict would
            # insert an empty entry for every leaf node that is looked up.
            for dest in deps.get(src, ()):
                if dest not in visited:
                    heappush(workheap, (n_left + 1, dest))
    return visited

# Work out which headers take part in the output.
if options.select:
    # header name, then optional "+N"/"+*" and optional "-N"/"-*" suffixes.
    spec_regexp = re.compile(r'(.*?)(\+([0-9]+|\*))?(-([0-9]+|\*))?$')
    workheap_up = []
    workheap_down = []
    selected = set()

    for spec in options.select.split(","):
        # Tolerate "a.h+1, b.h-2" and trailing commas.
        spec = spec.strip()
        if not spec:
            continue

        match = spec_regexp.match(spec)
        if not match:
            raise ValueError("Invalid select spec: %r" % spec)

        header_name = match.group(1)
        # The traverse operator is optional: a header named without one must
        # still be selected itself instead of being silently dropped.
        selected.add(header_name)
        # Heap priorities are the negated level count, so the entry with the
        # most levels left is processed first by traverse().
        if match.group(2):
            n = match.group(3)
            n = ALL_LEVELS if n == "*" else int(n)
            heappush(workheap_up, (-n, header_name))
        if match.group(4):
            n = match.group(5)
            n = ALL_LEVELS if n == "*" else int(n)
            heappush(workheap_down, (-n, header_name))

    # "+" follows reverse_includes (files that include the header),
    # "-" follows includes (files the header includes).
    selected.update(traverse(workheap_up, reverse_includes))
    selected.update(traverse(workheap_down, includes))
else:
    # No selection given: every file that includes or is included.
    selected = set(includes).union(reverse_includes)

if not options.platform:
    # Angle-bracket includes keep their '<' prefix, which marks them as
    # platform headers.
    selected = set(inc for inc in selected if not inc.startswith('<'))

if options.list:
    print("\n".join(sorted(selected)))
    sys.exit(0)

def normalize_name(name):
    """Return name with '<' and '>' replaced by '&lt;' and '&gt;'.

    Used on both ends of every edge before the edge list is handed to pydot.
    """
    for raw, escaped in (('<', '&lt;'), ('>', '&gt;')):
        name = name.replace(raw, escaped)
    return name

# An output file is required for rendering; check before doing any work.
if len(args) < 1:
    parser.print_usage()
    sys.exit(1)

# Keep only the edges whose two endpoints are both selected.  The inner loop
# variable must not be called `includes`: that would shadow the module-level
# dict (and, on Python 2, rebind it).  Sorting makes the edge order, and
# therefore the generated graph, reproducible from run to run.
include_edges = sorted(
    (normalize_name(src_header), normalize_name(dest_header))
    for src_header, dest_headers in includes.items()
    for dest_header in dest_headers
    if src_header in selected and dest_header in selected)

graph = pydot.graph_from_edges(include_edges, directed=True)
graph.set_rankdir("LR") # top-to-bottom rank gets too wide
graph.set_dpi("72")

# Pick the writer from the file extension (graph.write_png for "x.png", ...).
# The extension is lower-cased so that "x.PNG" is rendered as an image too;
# anything without a matching writer falls back to graph.write().
output_file = args[0]
ext = output_file.rsplit('.', 1)[-1].lower()
if hasattr(graph, 'write_' + ext):
    getattr(graph, 'write_' + ext)(output_file)
else:
    graph.write(output_file)
