import time, psycopg2, os, sys, threading, Queue

# Connection settings for the "victim" PostgreSQL server under test -- edit to match your setup.
def ConnectDB():
	"""Open a brand-new psycopg2 connection to the benchmark database and return it.

	Every call opens a separate connection; the caller owns it.
	"""
	settings = dict( host='localhost', user='peufeu', database='peufeu', port='5433' )
	return psycopg2.connect( **settings )

# Total number of rows each benchmark loads. MakeQueue rounds it down to a whole
# number of FILE_ROWS-row data files. The script at the bottom of this file
# rebinds it before some of the runs.
TOTAL_ROWS	= 20000000

# Number of rows in each data file. The same data file is used over and over
# again: it is COPYed TOTAL_ROWS // FILE_ROWS times in total.
# We are benchmarking COPY, so a small, well-cached data source is what we want.
FILE_ROWS	= 50000

# Intended number of tries (the shortest timing would be kept).
# NOTE(review): TRIES is not read by any code visible in this file.
TRIES = 1

# Number of COPY threads run in parallel.
THREADS = 4

# Directory for the COPY data files (must end with '/').
# The data files are written with COPY ... TO <file> and read with COPY ... FROM <file>,
# which the PostgreSQL server executes itself, so this path is resolved on the server
# and must be writable/readable by the server process.
STORAGE = '/home/peufeu/temp/'

# When True, print per-thread progress and the list of per-thread timings.
VERBOSE = True

# Runs each statement on its own connection, silently ignoring any error
def QueryNoError( *lines ):
	"""Execute each SQL statement in *lines* independently, best-effort.

	Every statement gets a fresh connection and is committed on its own. A failure
	(typically DROP TABLE of a table that does not exist yet) therefore neither
	affects the other statements nor reaches the caller: all exceptions are
	deliberately swallowed, because this is used for idempotent setup/teardown.
	The connection is closed even when the statement fails.
	"""
	for line in lines:
		try:
			db = ConnectDB()
			try:
				cursor = db.cursor()
				cursor.execute( line )
				cursor.execute( "COMMIT" )
			finally:
				# Previously skipped whenever execute() raised.
				db.close()
		except Exception:
			# Best-effort by design: e.g. DROP of a missing table must not abort setup.
			pass

# Creates the template_* tables whose column layouts the benchmark tables copy
def CreateTables():
	"""(Re)create the template_test_ints, template_test, template_test_t500 and template_test_t3k tables.

	Each DROP/CREATE is passed as a separate statement to QueryNoError, so each runs
	on its own connection and is committed on its own. A DROP of a table that does
	not exist yet is simply ignored.
	"""
	# template_test_t3k: rowid plus 26 TEXT columns named a..z.
	t3k_columns = ", ".join( "%s TEXT" % chr( n ) for n in range( ord('a'), ord('z') + 1 ) )
	# Every statement must be its own argument (mind the trailing commas):
	# adjacent string literals would otherwise be concatenated into one statement.
	QueryNoError(
		"DROP TABLE template_test_ints;",
		"CREATE TABLE template_test_ints ( rowid INTEGER, a INTEGER, b INTEGER, c INTEGER, d INTEGER, e INTEGER, f INTEGER, g INTEGER, h INTEGER );",
		"DROP TABLE template_test;",
		"CREATE TABLE template_test ( rowid INTEGER, a INTEGER, b INTEGER, c INTEGER, d FLOAT, e FLOAT, f NUMERIC(8,2), g NUMERIC(8,2), h TIMESTAMP WITHOUT TIME ZONE );",
		"DROP TABLE template_test_t500;",
		"CREATE TABLE template_test_t500 ( rowid INTEGER, a TEXT, b TEXT, c TEXT, d TEXT );",
		"DROP TABLE template_test_t3k;",
		"CREATE TABLE template_test_t3k ( rowid INTEGER, " + t3k_columns + " );",
		)

# Creates the data files used in the test, one per template table
def CreateData():
	"""Write FILE_ROWS rows of generated data per benchmark into files under STORAGE.

	Produces test_ints.txt, test.txt, test_t500.txt and test_t3k.txt. Each file has
	the same column layout as the matching template_* table. The files are written
	with COPY (SELECT ...) TO <file>, which the PostgreSQL server executes itself.
	STORAGE must therefore be writable by the server process, and the user needs
	the privilege to write server-side files.
	"""
	# A random value spanning (almost) the whole signed 32-bit INTEGER range.
	rand_int = "(random()*4294967294-2147483647)::INTEGER"
	# 119-character text literal: 4 of them make a ~500 byte row (t500), 26 a ~3 KB row (t3k).
	text = "'d66f 5c2b a853 0137 57c7 c13c ec61 ac95 fad6 a47a 2913 dc1d d933 bdad 91ee ddd1 7cd3 13c2 fc7d 3a2a 6760 4b10 6c30 458d'"
	# NUMERIC(8,2) tops out at 999999.99. random()*1000000 could round up to
	# 1000000.00 and fail with "numeric field overflow", so scale by the maximum instead.
	rand_numeric = "(random()*999999.99)::NUMERIC(8,2)"
	outputs = (
		# rowid + 8 INTEGERs
		( 'test_ints.txt', [ rand_int ] * 8 ),
		# rowid + 3 INTEGERs, 2 FLOATs, 2 NUMERIC(8,2)s and a timestamp within the next month
		( 'test.txt', [ rand_int ] * 3
			+ [ "random()::FLOAT" ] * 2
			+ [ rand_numeric ] * 2
			+ [ "now() + '1 mon'::INTERVAL * random()" ] ),
		# rowid + 4 TEXT columns
		( 'test_t500.txt', [ text ] * 4 ),
		# rowid + 26 TEXT columns (a..z)
		( 'test_t3k.txt', [ text ] * 26 ),
	)
	db = ConnectDB()
	try:
		cursor = db.cursor()
		for filename, columns in outputs:
			# psycopg2 binds only the three %s placeholders; the column expressions
			# contain no '%' characters, so they pass through unchanged.
			cursor.execute(
				"COPY (SELECT rowid, " + ", ".join( columns )
				+ " FROM generate_series(%s,%s) AS rowid) TO %s;",
				( 1, FILE_ROWS, STORAGE + filename ) )
	finally:
		db.close()


#
#	Helper functions
#
def MakeQueue():
	"""Build the work queue shared by the COPY threads.

	It holds one integer per COPY of the data file: 0 .. TOTAL_ROWS // FILE_ROWS - 1.
	Rows that would not fill a whole data file are dropped. The queue's capacity
	leaves 10 spare slots beyond the items it is filled with.
	"""
	chunk_count = TOTAL_ROWS // FILE_ROWS
	work = Queue.Queue( maxsize = chunk_count + 10 )
	for chunk in range( chunk_count ):
		work.put( chunk )
	return work

def RunThreads( func, nthreads=None ):
	"""Run func(n) for n in 0 .. nthreads-1, each in its own thread, and wait for all.

	nthreads defaults to the module-level THREADS setting. An exception raised
	inside func is not propagated to the caller: the threading module reports it
	on stderr and that thread simply ends.
	"""
	if nthreads is None:
		nthreads = THREADS
	threads = [ threading.Thread( target=func, args=(n,) ) for n in range( nthreads ) ]
	# Start every thread before joining any, so they actually run concurrently.
	for t in threads:
		t.start()
	for t in threads:
		t.join()

def CheckTable( table ):
	"""Sanity-check that *table* received every row the benchmark queued.

	Prints the row count found, then asserts it equals TOTAL_ROWS rounded down to
	a whole number of FILE_ROWS-row data files (MakeQueue only queues whole files).
	"""
	cur = ConnectDB().cursor()
	cur.execute( "SELECT count(*) FROM %s" % table )
	( found, ) = cur.fetchone()
	print( "Found %d rows" % found )
	assert found == ( TOTAL_ROWS // FILE_ROWS ) * FILE_ROWS

#
#	Parallel COPY benchmarks
#
#	TestParallelCopy      : all threads COPY into one shared table that was created
#	                        (and committed) beforehand.
#	TestParallelCopyNoLog : each thread creates its own table t<n> and COPYs into it
#	                        inside the same transaction.
#	TestParallelCopyHack  : like NoLog, then merges the t<n> tables into one table
#	                        with CREATE TABLE ... AS ... UNION ALL.
#
def _CopyWorker( n, queue, timings, target, datafile, setup_sql=None ):
	"""Thread body: drain *queue*, COPYing *datafile* into *target* once per work item.

	Everything runs in one transaction on a private connection: BEGIN, the optional
	*setup_sql*, one COPY per queue item until the queue is empty, then COMMIT.
	On success, the wall time elapsed since connecting is appended to *timings*.
	On failure the exception propagates and ends the thread. The connection is
	closed either way, which discards an uncommitted transaction.

	n         -- thread index, also used to indent the progress display
	queue     -- Queue.Queue of work items (see MakeQueue)
	timings   -- shared list receiving one float per successful thread
	target    -- table to COPY into
	datafile  -- data file path, read by the PostgreSQL server
	setup_sql -- optional statement executed right after BEGIN
	"""
	# Connect outside the try: if this fails there is nothing to clean up
	# (the old finally-clause raised NameError here and hid the real error).
	db = ConnectDB()
	try:
		c = db.cursor()
		t = time.time()
		# psycopg2 already opened a transaction implicitly; this explicit BEGIN
		# only makes the server emit a warning.
		c.execute( "BEGIN;" )
		if setup_sql:
			c.execute( setup_sql )

		while True:
			try:
				elem = queue.get_nowait()
			except Queue.Empty:
				break

			if VERBOSE:
				sys.stdout.write( "%s%d\r" % ("\t"*n, elem) )
				sys.stdout.flush()
			c.execute( "COPY %s FROM %%s" % target, (datafile,) )

		c.execute( "COMMIT" )
		timings.append( time.time() - t )
		if VERBOSE:
			sys.stdout.write( "%sDONE\r" % ("\t"*n) )
			sys.stdout.flush()
	finally:
		db.close()

def _ReportTimings( timings, extra=0 ):
	"""Print per-thread timings (if VERBOSE) and the total time, then a blank line.

	The total is the slowest thread (threads run concurrently) plus *extra*.
	Threads that failed recorded no timing, so this is reported explicitly
	instead of printing a misleading total or crashing in max() on an empty list.
	"""
	if VERBOSE:
		print( "All threads timings : %s" % (timings,) )
	if len( timings ) < THREADS:
		print( "WARNING: only %d of %d threads completed" % (len( timings ), THREADS) )
	if timings:
		print( "Total time : %s" % (max( timings ) + extra,) )
	else:
		print( "Total time : n/a (no thread completed)" )
	print( "" )

def _DropThreadTables():
	"""Best-effort DROP of the per-thread tables t0 .. t<THREADS-1>."""
	for n in range( THREADS ):
		QueryNoError( "DROP TABLE t%d" % n )

def _CopyIntoThreadTables( table ):
	"""Have each thread create t<n> LIKE template_<table> and fill it in one transaction.

	Leftover t<n> tables and <table> are dropped first (followed by a CHECKPOINT).
	Returns the list of per-thread timings.
	"""
	_DropThreadTables()
	QueryNoError( "DROP TABLE %s;" % table, "CHECKPOINT" )

	timings = []
	queue = MakeQueue()
	datafile = STORAGE + table + '.txt'

	def worker( n ):
		_CopyWorker( n, queue, timings, "t%d" % n, datafile,
			setup_sql="CREATE TABLE t%d (LIKE template_%s) TABLESPACE ramdisk;" % (n, table) )

	RunThreads( worker )
	return timings

def TestParallelCopy( table ):
	"""Benchmark THREADS parallel COPYs into one shared, pre-created table.

	<table> is created LIKE template_<table> in the "ramdisk" tablespace and
	committed before the threads start. Because the table is not created in the
	loading transaction, the COPYs are WAL-logged ("with xlog"). The table is
	dropped (with a CHECKPOINT) before and after the run.
	"""
	QueryNoError( "DROP TABLE %s;" % table, "CHECKPOINT" )
	db = ConnectDB()
	try:
		cursor = db.cursor()
		cursor.execute( "CREATE TABLE %s (LIKE template_%s) TABLESPACE ramdisk;" % (table,table) )
		cursor.execute( "COMMIT;" )
	finally:
		db.close()

	timings = []
	queue = MakeQueue()
	datafile = STORAGE + table + '.txt'

	def worker( n ):
		_CopyWorker( n, queue, timings, table, datafile )

	RunThreads( worker )
	# Uncomment to verify the row count before the table is dropped:
	#~ CheckTable( table )
	QueryNoError( "DROP TABLE %s;" % table, "CHECKPOINT" )

	_ReportTimings( timings )

def TestParallelCopyNoLog( table ):
	"""Benchmark parallel COPY where each thread loads its own freshly created table.

	Creating the target table in the same transaction as the COPY lets PostgreSQL
	skip WAL for the load when wal_level is minimal (hence "no xlog").
	The per-thread tables are dropped afterwards.
	"""
	timings = _CopyIntoThreadTables( table )
	_DropThreadTables()
	_ReportTimings( timings )

def TestParallelCopyHack( table ):
	"""Like TestParallelCopyNoLog, then merge the per-thread tables into <table>.

	The merge (CREATE TABLE <table> ... AS SELECT ... UNION ALL ...) is timed
	separately and added to the slowest thread's time. <table> is left in place;
	only the per-thread tables are dropped.
	"""
	timings = _CopyIntoThreadTables( table )

	sql = "CREATE TABLE %s TABLESPACE ramdisk AS %s;" % (table, " UNION ALL ".join( "SELECT * FROM t%d"%x for x in range(THREADS) ) )
	print( sql )
	db = ConnectDB()
	try:
		cur = db.cursor()
		t = time.time()
		cur.execute( sql )
		cur.execute( "COMMIT" )
		t = time.time() - t
	finally:
		db.close()

	_DropThreadTables()
	_ReportTimings( timings, t )
	
# Build the template tables and the data files, then run every data set in its
# "no xlog" variant followed by its logged variant.
CreateTables()
CreateData()

#~ print "8 INTS, hack =>"
#~ TestParallelCopyHack( "test_ints" )

# Each entry: (TOTAL_ROWS for this data set, table name, no-xlog label, logged label).
# TOTAL_ROWS is rebound as a module global because MakeQueue reads it at call time.
for rows, table, nolog_label, log_label in (
	( TOTAL_ROWS, "test_ints", "8 INTS, no xlog =>", "8 INTS =>" ),
	( 15000000, "test", "Doubles & Numerics, no xlog =>", "Doubles & Numerics =>" ),
	( 3200000, "test_t500", "500B TEXT, no xlog    =>", "500B TEXT    =>" ),
	):
	TOTAL_ROWS = rows
	print( nolog_label )
	TestParallelCopyNoLog( table )
	print( log_label )
	TestParallelCopy( table )

