import psycopg2
import random
import sys
import time
from cStringIO import StringIO

def norm(weights):
    """Turn (weight, value) pairs into a cumulative distribution table.

    The pairs are sorted in descending order (whole tuples are compared, so
    equal weights fall back to comparing the values) and every weight is
    replaced by the running sum of weight / total.  The resulting list is
    the lookup table that weighed_choice() walks.

    weights -- any iterable of (weight, value) pairs.  It is copied into a
               list first, so one-shot iterators such as generators work.

    Returns a list of (cumulative_fraction, value) tuples; empty input gives
    an empty list.  Raises ZeroDivisionError when the input is non-empty but
    the weights sum to zero.
    """
    # The pairs are walked twice (once to sum, once to sort); without this
    # copy a generator would be exhausted by sum() and the result would
    # silently be [].
    pairs = list(weights)
    total_weight = float(sum(weight for weight, _ in pairs))
    result = []
    cumulative = 0
    for weight, value in sorted(pairs, reverse=True):
        cumulative += weight / total_weight
        result.append((cumulative, value))
    return result

def weighed_choice(weights):
    """Pick one value at random from a table of (limit, value) pairs.

    A single random.random() draw is compared against each limit in turn and
    the value of the first pair whose limit is greater than the draw is
    returned.  With a table produced by norm() the limits are cumulative
    fractions, so values are chosen in proportion to their weights.

    weights -- iterable of (limit, value) pairs.

    Returns the chosen value.  If no limit exceeds the draw, the value of the
    last pair is returned.  Raises ValueError if the table is empty.
    """
    pick = random.random()
    seen_any = False
    value = None
    for limit, value in weights:
        if pick < limit:
            return value
        seen_any = True
    if not seen_any:
        # Previously this fell through to an unbound local and raised a
        # confusing UnboundLocalError.
        raise ValueError("weighed_choice() needs at least one (limit, value) pair")
    # account for floating point error: the final cumulative limit can end up
    # slightly below 1.0, so fall back to the last value.
    return value

# Two-level weighted table.  Each outer entry is
# (weight, (string, inner_table)), where inner_table is itself a normalised
# table of (weight, int) pairs, so the distribution of the int depends on
# which string was picked.
# list(...) around a generator is used instead of a list comprehension so
# the loop names do not leak into module scope on Python 2.
weights = norm(list(
    (outer_weight, (label, norm(inner_pairs)))
    for outer_weight, label, inner_pairs in [
        (3, 'a', [(4, 1), (2, 2), (0, 3)]),
        (2, 'b', [(1, 1), (5, 2), (100, 3)]),
        (1, 'c', [(1, 1), (100, 4), (5, 5), (6, 3)]),
    ]
))

# Names of the tables this script drops, creates, fills and analyzes.
tables = ['expdist', 'normdist', 'weigheddist']

# A host value starting with '/' is interpreted by libpq as the directory
# holding the server's Unix-domain socket, not as a network host name.
conn = psycopg2.connect(host="/tmp")
cur = conn.cursor()

# Optional first command-line argument 'recreate': drop the tables from an
# earlier run before they are created again further down.
if len(sys.argv) > 1 and sys.argv[1] == 'recreate':
    print("Dropping tables")
    try:
        for tablename in tables:
            # IF EXISTS turns a missing table into a no-op.  A plain DROP
            # would fail, abort the transaction and leave the remaining
            # tables undropped.
            cur.execute("DROP TABLE IF EXISTS %s" % tablename)
        conn.commit()
    except psycopg2.Error as exc:
        # Still best-effort (the script carries on), but the failure is
        # reported instead of being swallowed silently.
        conn.rollback()
        sys.stderr.write("Could not drop tables: %s\n" % exc)

print("Creating tables")
for ddl in (
    # a is exponential distribution, b is normal dist around a
    "CREATE TABLE expdist (id serial primary key, a int, b int)",
    # a is a large normal dist, b is small normal dist around a
    "CREATE TABLE normdist (id serial primary key, a int, b int)",
    # a is a small distribution of strings, b is ints with weights dependent on a
    "CREATE TABLE weigheddist (id serial primary key, a text, b int)",
):
    cur.execute(ddl)

NUM_BLOCKS = 10
BLOCK_SIZE = 100000
print("Creating %d rows of data" % (NUM_BLOCKS * BLOCK_SIZE))

# Every pass stages BLOCK_SIZE comma-separated rows per table in an in-memory
# buffer and bulk-loads them with COPY ... FROM STDIN, then updates a
# single-line progress display.
for block_no in xrange(1, NUM_BLOCKS + 1):
    # weigheddist: a is a weighted string, b an int drawn from the table
    # attached to that string.
    staged = StringIO()
    for _ in xrange(BLOCK_SIZE):
        label, inner_table = weighed_choice(weights)
        staged.write("%s,%s\n" % (label, weighed_choice(inner_table)))
    staged.seek(0)
    cur.copy_expert("COPY weigheddist (a, b) FROM STDIN WITH (DELIMITER ',')", staged)

    # normdist: a ~ N(1000, 100), b ~ N(a, 10), both truncated to int.
    staged = StringIO()
    for _ in xrange(BLOCK_SIZE):
        centre = int(random.normalvariate(1000, 100))
        nearby = int(random.normalvariate(centre, 10))
        staged.write("%d,%d\n" % (centre, nearby))
    staged.seek(0)
    cur.copy_expert("COPY normdist (a, b) FROM STDIN WITH (DELIMITER ',')", staged)

    # expdist: a is exponential with mean 16, b ~ N(a, 4), both truncated
    # to int.
    staged = StringIO()
    for _ in xrange(BLOCK_SIZE):
        centre = int(random.expovariate(1.0 / 16.0))
        nearby = int(random.normalvariate(centre, 4.0))
        staged.write("%d,%d\n" % (centre, nearby))
    staged.seek(0)
    cur.copy_expert("COPY expdist (a, b) FROM STDIN WITH (DELIMITER ',')", staged)

    # '\r' rewinds to the start of the line so the progress text is
    # overwritten in place.
    rows_done = block_no * BLOCK_SIZE
    sys.stdout.write("\r%d%% - %d rows" % (100 * block_no / NUM_BLOCKS, rows_done))
    sys.stdout.flush()
# Terminate the progress line.
sys.stdout.write("\n")

# NOTE(review): CREATE CROSS COLUMN STATISTICS does not look like stock
# PostgreSQL syntax -- presumably this needs a patched server; confirm.
print("Creating cross col stats")
for name in tables:
    stats_sql = "CREATE CROSS COLUMN STATISTICS ON TABLE %s (a, b) WITH (1000)" % name
    cur.execute(stats_sql)

print("Analyzing tables")
for name in tables:
    analyze_sql = "ANALYZE %s" % name
    cur.execute(analyze_sql)

# The only commit after the optional drop step: table creation, the loaded
# rows, the statistics definitions and the ANALYZE results go in together.
conn.commit()
