import psycopg2
import threading
from time import time, sleep
from datetime import datetime
import sys

# Number of rows generated into high_throughput (ids 1..TBL_SIZE); the write
# workload deletes one row and inserts one per transaction, keeping the count.
TBL_SIZE = 1000
# Number of spaces stored in each row's padding column.
PADDING_SIZE = 100
# Seconds between progress reports (also the pause between interfering counts).
DISPLAY_INTERVAL = 10
# Connection string from the first command-line argument, or "" when absent.
CONNSTRING = (sys.argv[1:] or [""])[0]

def connect(name):
    """Open and return a new psycopg2 connection tagged with *name*.

    *name* is appended to CONNSTRING as an extra ``application_name=...``
    key/value pair so each worker connection can be told apart by its
    application name.
    """
    # NOTE(review): appending a key=value pair only works when CONNSTRING is in
    # keyword/value form (or empty); a postgresql:// URI would not accept it.
    dsn = "%s application_name=%s" % (CONNSTRING, name)
    return psycopg2.connect(dsn)

def init_high_tp(cur):
    """(Re)create the high_throughput table on *cur* and populate it.

    Any existing table is dropped, then TBL_SIZE rows with ids 1..TBL_SIZE and
    a padding column of PADDING_SIZE spaces are inserted.  Nothing is committed
    here; the caller owns the transaction.
    """
    schema_statements = (
        "DROP TABLE IF EXISTS high_throughput",
        "CREATE TABLE high_throughput (id int4 primary key, padding text)",
    )
    for statement in schema_statements:
        cur.execute(statement)
    # Let psycopg2 interpolate the (integer) sizes rather than formatting SQL by
    # hand; for these int constants the rendered query text is the same.
    fill_query = ("INSERT INTO high_throughput SELECT x, repeat(' ', %s) "
                  "FROM generate_series(1,%s) x")
    cur.execute(fill_query, (PADDING_SIZE, TBL_SIZE))

def show(msg):
    """Write *msg* to stdout, prefixed with the current local time as [HH:MM:SS].

    The whole line is built first and emitted with one write() call.  This
    makes lines from the concurrently running worker threads far less likely
    to interleave than with the multi-part ``print`` statement used before.
    stdout is flushed so progress appears promptly even when redirected.
    Valid under both Python 2 and Python 3.
    """
    stamp = datetime.now().strftime("[%H:%M:%S] ")
    # Same layout as the old ``print stamp, msg``: the stamp already ends in a
    # space and print inserted one more separator before msg.
    sys.stdout.write("%s %s\n" % (stamp, msg))
    sys.stdout.flush()

def high_tp_thread():
    """Run the write-heavy workload on the high_throughput table forever.

    Recreates the table via init_high_tp(), reports the server's
    old_snapshot_threshold setting, then loops: each iteration deletes row
    ``id = i`` and inserts ``id = i + TBL_SIZE`` and commits.  The table keeps
    TBL_SIZE rows while its ids slide upward.  About every DISPLAY_INTERVAL
    seconds it also reports pg_table_size() of the table and the time since
    its last autovacuum.  The connection is closed if the loop is ever left
    through an exception.
    """
    conn = connect("write-workload")
    try:
        cur = conn.cursor()
        init_high_tp(cur)
        cur.execute("SHOW old_snapshot_threshold")
        row = cur.fetchone()
        show("old_snapshot_threshold = %s" % (row[0],))
        conn.commit()

        last_display = 0
        i = 1
        start = time()
        while True:
            cur_time = time()
            if cur_time - last_display > DISPLAY_INTERVAL:
                last_display = cur_time
                cur.execute("SELECT pg_table_size('high_throughput'), clock_timestamp() - last_autovacuum FROM pg_stat_user_tables WHERE relname = 'high_throughput'")
                size_bytes, vacuum_age = cur.fetchone()
                # last_autovacuum may be NULL (no autovacuum recorded yet for
                # the freshly created table); avoid printing "None ago".
                if vacuum_age is None:
                    vacuum_desc = "never"
                else:
                    vacuum_desc = "%s ago" % (vacuum_age,)
                show("High throughput table size @ %5ds. Size %6dkB Last vacuum %s"
                     % (int(cur_time - start), size_bytes // 1024, vacuum_desc))

            cur.execute("DELETE FROM high_throughput WHERE id = %s", (i,))
            cur.execute("INSERT INTO high_throughput VALUES (%s, REPEAT(' ',%s))", (i + TBL_SIZE, PADDING_SIZE))
            conn.commit()
            i += 1
    finally:
        # Previously unreachable after the infinite loop; now always runs.
        conn.close()

def long_ss_thread(interval=1800):
    """Repeatedly run a query that sleeps *interval* seconds server-side.

    Each round runs ``SELECT NOW(), pg_sleep(interval)`` and commits.  The
    first psycopg2 error (e.g. the query being cancelled) is reported and
    ends the loop; the connection is closed on every exit path.

    :param interval: seconds passed to pg_sleep() per query (default 1800).
    """
    conn = connect("long-unrelated-query")
    try:
        cur = conn.cursor()
        while True:
            show("Starting %ds long query" % interval)
            try:
                cur.execute("SELECT NOW(), pg_sleep(%s)", (interval,))
                conn.commit()
            except psycopg2.Error as e:
                show("Long query canceled due to %s" % (e,))
                break
    finally:
        # The original leaked the connection after ``break``.
        conn.close()

def long_ss_error_thread():
    """Hold one REPEATABLE READ snapshot on high_throughput while polling it.

    After a 1 s head start for the other workers, it opens a transaction
    whose first statement sets REPEATABLE READ.  psycopg2 is non-autocommit
    by default, so the transaction starts implicitly.  It then counts rows
    every DISPLAY_INTERVAL seconds inside that same transaction.  On a
    psycopg2 error it reports the error, rolls back, waits 3 minutes and
    starts a fresh transaction.  If the rollback itself fails, the thread
    gives up.  The connection is closed on every exit path.
    """
    sleep(1)
    conn = connect("interfering-query")
    try:
        cur = conn.cursor()
        while True:
            try:
                cur.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ")
                while True:
                    cur.execute("SELECT COUNT(*), MAX(id) FROM high_throughput")
                    row_count, max_id = cur.fetchone()
                    # MAX(id) is NULL on an empty table; %s prints it instead of
                    # raising TypeError (which would escape the handler below).
                    show("Counted %d rows with max %s in high_throughput table" % (row_count, max_id))
                    sleep(DISPLAY_INTERVAL)
            except psycopg2.Error as e:
                show("Interfering query got error %s" % (e,))
                try:
                    conn.rollback()
                except psycopg2.Error:
                    # Connection is unusable; stop this worker.
                    return
                show("Waiting 3min to restart interfering query")
                sleep(180)
    finally:
        conn.close()

# Launch the three workloads concurrently and wait for them.
threads = []
for parallel_func in [high_tp_thread, long_ss_thread, long_ss_error_thread]:
    t = threading.Thread(target=parallel_func, name=parallel_func.__name__)
    # The workers loop forever; daemon threads let the process exit on Ctrl-C
    # instead of waiting for them.
    t.daemon = True
    threads.append(t)
    t.start()

for t in threads:
    # Join with a timeout so the main thread keeps running bytecode and can
    # receive KeyboardInterrupt; an untimed join() blocks it in Python 2.
    while t.is_alive():
        t.join(1)

