import psycopg2
import sys
import time

# Benchmark BRIN "minmax" vs. "bloom" operator classes for "a IN (...)" lookups.
#
# For every combination of column type (int / text), BRIN opclass
# (minmax / bloom) and disorder factor d, the table t is rebuilt and
# re-indexed. Then, NRUNS times, one query per size in LIST_SIZES is timed,
# each using IN-lists of values sampled from the table. Each measurement is
# printed as one line:
#
#     <type> <opclass> <run> <d> <list size> <milliseconds>
#
# The time is the client-observed round trip of the EXPLAIN ANALYZE statement
# (planning + execution + result transfer), not the server-reported
# execution time.

conn = psycopg2.connect('host=localhost dbname=test')

# Every statement commits on its own; VACUUM in particular cannot run inside
# a transaction block.
conn.set_session(autocommit=True)

# Number of rows loaded into t for each configuration.
nrows = 10000000

# Sizes of the IN (...) lists probed in every run.
LIST_SIZES = [1, 10, 20, 50, 100, 500]

# Number of timed repetitions per (type, opclass, d) configuration.
NRUNS = 10

# Data-load statements per column type. The two %d placeholders are the row
# count and the disorder factor d. Rows are inserted ordered by
# i + d * random() / 100.0, so a larger d lets a value drift further from its
# sorted position, i.e. a less correlated physical order.
_LOAD_SQL = {
    'int': 'insert into t select * from (select i from generate_series(1,%d) s(i)) foo order by (i + %d * random() / 100.0)',
    'text': 'insert into t select a from (select row_number() over (order by a) as i, a from (select md5(i::text) a from generate_series(1,%d) s(i) order by 1) foo) bar order by (i + %d * random() / 100.0)',
}

# BRIN index definitions per (column type, opclass).
_INDEX_SQL = {
    ('int', 'minmax'): 'create index on t using brin (a int4_minmax_ops) with (pages_per_range=1)',
    ('int', 'bloom'): 'create index on t using brin (a int4_bloom_ops(n_distinct_per_range=200)) with (pages_per_range=1)',
    ('text', 'minmax'): 'create index on t using brin (a text_minmax_ops) with (pages_per_range=1)',
    ('text', 'bloom'): 'create index on t using brin (a text_bloom_ops(n_distinct_per_range=250)) with (pages_per_range=1)',
}


def _build_table(cur, coltype, opclass, rows, disorder):
    """Recreate t with *rows* rows of *coltype*, index it with *opclass*, vacuum/analyze."""
    cur.execute('drop table if exists t')
    cur.execute('create table t (a %s)' % (coltype,))
    cur.execute(_LOAD_SQL[coltype] % (rows, disorder))
    cur.execute(_INDEX_SQL[coltype, opclass])
    cur.execute('vacuum analyze t')


def _sample_values(cur, count):
    """Return *count* randomly ordered values of t.a, converted to str."""
    cur.execute('select a from t order by random() limit %d' % (count,))
    return [str(row[0]) for row in cur.fetchall()]


def _format_in_list(coltype, vals):
    """Render *vals* as the body of an SQL IN (...) list for *coltype*.

    Text values are wrapped in single quotes without escaping. That is safe
    here only because they come from md5() output (hex digits) -- do not
    reuse with arbitrary strings.
    """
    if coltype == 'text':
        return "'" + "','".join(vals) + "'"
    return ','.join(vals)


try:
    for t in ['int', 'text']:
        for o in ['minmax', 'bloom']:
            # Disorder factors: 5% and 1% of the row count.
            for d in [nrows * 5 // 100, nrows // 100]:
                with conn.cursor() as cur:
                    _build_table(cur, t, o, nrows, d)

                    # Sample exactly as many values as the timed runs consume.
                    values = _sample_values(cur, sum(LIST_SIZES) * NRUNS)

                    # Session settings persist on this autocommit connection.
                    # Set them once, before timing: discourage sequential
                    # scans so the BRIN index is used, and disable parallel
                    # query so runs are comparable.
                    cur.execute('set enable_seqscan = off')
                    cur.execute('set max_parallel_workers_per_gather = 0')

                    pos = 0  # next unused position in `values`
                    for r in range(NRUNS):
                        for v in LIST_SIZES:
                            vals = values[pos:pos + v]
                            pos += v

                            # Sanity check: the sample must cover every run.
                            if len(vals) != v:
                                print('incorrect length', len(vals), file=sys.stderr)
                                sys.exit(1)

                            query = 'explain (analyze, timing off) select * from t where a in (%s)' % (_format_in_list(t, vals),)

                            # perf_counter is monotonic and high-resolution,
                            # unlike time.time(), which tracks the wall clock.
                            s = time.perf_counter()
                            cur.execute(query)
                            e = time.perf_counter()

                            print(t, o, r, d, v, round(1000 * (e - s), 3))
                            sys.stdout.flush()
finally:
    conn.close()
