#include <time.h>
#include <stdio.h>
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <inttypes.h>
#include <sys/time.h>
#include <pthread.h>
#include <libpq-fe.h>

#define MAX_THREADS 1024 /* upper bound on -c; sizes the per-thread arrays below */
#define MAX_ATTEMPTS 10  /* connection attempts per worker before giving up */
#define INT8OID 20       /* PostgreSQL type OID of int8, used as the prepared-statement parameter type */

/* libpq connection string; overridden by -d. */
char const* connection = "dbname=postgres host=localhost port=5432 sslmode=disable connect_timeout=10";
/* Percentage of operations issued as updates (compared against rand % 100); set by -u. */
int update_percent = 1;
/* Keys are drawn from 1..n_records; set by -n. */
int n_records = 100;
/* Number of worker threads; set by -c. */
int n_clients = 100;
/*
 * Per-thread operation counters. Each worker increments only its own slot,
 * but main() reads every slot while the workers are running, so the
 * counters are atomic to avoid a data race.
 */
_Atomic long selects[MAX_THREADS];
_Atomic long updates[MAX_THREADS];
/*
 * Set to non-zero by main() to make the workers leave their loop.
 * Atomic rather than volatile: volatile provides no inter-thread ordering.
 */
_Atomic int termination;

/*
 * One xorshift32 step: advances *state and returns the new value.
 * Used instead of rand(), which POSIX does not require to be thread-safe
 * (and which glibc serializes with a global lock). *state must be non-zero.
 */
static uint32_t next_random(uint32_t* state)
{
	uint32_t x = *state;
	x ^= x << 13;
	x ^= x >> 17;
	x ^= x << 5;
	*state = x;
	return x;
}

/*
 * Prepare a one-parameter statement on con; exits the process on failure.
 * The PGresult returned by PQprepare is always released.
 */
static void prepare_statement(PGconn* con, char const* name, char const* sql, Oid const* paramTypes)
{
	PGresult* res = PQprepare(con, name, sql, 1, paramTypes);
	if (PQresultStatus(res) != PGRES_COMMAND_OK) {
		fprintf(stderr, "Prepare of statement %s failed: %s\n", name, PQresultErrorMessage(res));
		exit(1);
	}
	PQclear(res);
}

/*
 * Benchmark client thread.
 *
 * arg carries the thread index (cast through size_t); it selects the slot of
 * selects[]/updates[] this thread increments and seeds its private PRNG.
 * The thread opens its own connection (up to MAX_ATTEMPTS tries), prepares
 * the "update" and "select" statements, then loops until `termination` is
 * set. Each iteration picks a pseudo-random key in 1..n_records and issues
 * an update (when the random draw % 100 is below update_percent) or a
 * select. Any unexpected result terminates the whole process with exit(1).
 *
 * Returns NULL.
 */
void* worker(void* arg)
{
	PGconn* con = NULL;
	PGresult* res;
	Oid paramTypes[1] = { INT8OID };
	size_t id = (size_t)arg;
	/* Distinct, deterministic per-thread seed; xorshift requires it non-zero. */
	uint32_t seed = (uint32_t)id * 2654435761u + 1u;
	int i;

	if (seed == 0)
		seed = 1;

	for (i = 0; i < MAX_ATTEMPTS; i++) {
		/* Discard the previous failed attempt only when retrying, so the
		 * final failed connection is still alive for PQerrorMessage below. */
		if (con != NULL)
			PQfinish(con);
		con = PQconnectdb(connection);
		if (PQstatus(con) == CONNECTION_OK)
			break;
	}
	if (PQstatus(con) != CONNECTION_OK) {
		/* By libpq convention a non-empty PQerrorMessage ends with a newline. */
		fprintf(stderr, "Could not establish connection to server %s, error = %s",
				connection, PQerrorMessage(con));
		exit(1);
	}

	/* Alternative update that locks the row first:
	 * "update t set v=v+1 from (select k from t where k=$1 for no key update) q where t.k=q.k" */
	prepare_statement(con, "update", "update t set v=v+1 where k=$1", paramTypes);
	prepare_statement(con, "select", "select v from t where k=$1", paramTypes);

	while (!termination) {
		char key[64];
		char const* paramValues[] = {key};
		int k = (int)(next_random(&seed) % (uint32_t)n_records) + 1;

		snprintf(key, sizeof key, "%d", k);
		if ((int)(next_random(&seed) % 100) < update_percent) {
			res = PQexecPrepared(con, "update", 1, paramValues, NULL, NULL, 0);
			if (PQresultStatus(res) != PGRES_COMMAND_OK) {
				fprintf(stderr, "Update failed: %s\n", PQresultErrorMessage(res));
				exit(1);
			}
			/* Every key in 1..n_records is expected to exist exactly once. */
			if (strcmp(PQcmdTuples(res), "1") != 0) {
				fprintf(stderr, "Update affect wrong number of tuples: %s\n", PQcmdTuples(res));
				exit(1);
			}
			updates[id] += 1;
		} else {
			res = PQexecPrepared(con, "select", 1, paramValues, NULL, NULL, 0);
			if (PQresultStatus(res) != PGRES_TUPLES_OK) {
				fprintf(stderr, "Select failed: %s\n", PQresultErrorMessage(res));
				exit(1);
			}
			if (PQntuples(res) != 1) {
				fprintf(stderr, "Select returns wrong number of tuples: %d\n", PQntuples(res));
				exit(1);
			}
			selects[id] += 1;
		}
		PQclear(res);
	}
	PQfinish(con);
	return NULL;
}



int main (int argc, char* argv[])
{
	int i;
	pthread_t threads[MAX_THREADS];
	int test_duration = 10;
	long thread_updates[MAX_THREADS];
	long thread_selects[MAX_THREADS];
	time_t finish;
	int iteration = 0;
	int initialize = 0;
	long total_selects = 0;
	long total_updates = 0;
	int verbose = 0;
	if (argc == 1) {
        fprintf(stderr, "Use -h to show usage options\n");
        return 1;
    }

    for (i = 1; i < argc; i++) {
        if (argv[i][0] == '-') {
            switch (argv[i][1]) {
			  case 'n':
				n_records = atol(argv[++i]);
				continue;
			  case 'c':
				n_clients = atoi(argv[++i]);
				continue;
			  case 'u':
				update_percent = atoi(argv[++i]);
				continue;
			  case 't':
				test_duration = atoi(argv[++i]);
				continue;
	    case 'v':
	      verbose = 1;
	      continue;
	                  case 'd':
                connection = argv[++i];
                continue;
			  case 'i':
				initialize = 1;
				continue;
			}
        }
		printf("Options:\n"
			   "\t-i\tinitialize database\n"
			   "\t-n\tnumber of records (default 100)\n"
			   "\t-c\tnumber of client (default 100)\n"
			   "\t-u\tupdate percent (default 1%%)\n"
			   "\t-t\ttest duration (default 10 sec)\n"
			   "\t-d\tconnection string ('host=localhost port=5432')\n");
        return 1;
    }
	if (initialize) {
		PGconn* con = PQconnectdb(connection);
		char sql[256];
		sprintf(sql, "insert into t values (generate_series(1, %d), 0)", n_records);
		PQexec(con, "drop table t");
		PQexec(con, "create table t(k integer primary key, v integer)");
		PQexec(con, sql);
		PQfinish(con);
	}
	for (i = 0; i < n_clients; i++) {
		thread_updates[i] = 0;
		thread_selects[i] = 0;
		pthread_create(&threads[i], NULL, worker, (void*)(size_t)i);
	}
	finish = time(NULL) + test_duration;
	do {
		total_selects = 0;
		total_updates = 0;
		sleep(1);
		for (i = 0; i < n_clients; i++) {
			total_selects += selects[i] - thread_selects[i];
			thread_selects[i] = selects[i];
			total_updates += updates[i] - thread_updates[i];
			thread_updates[i] = updates[i];
		}
		if (verbose) 
		  printf("%d: %ld (%ld updates, %ld selects)\n", ++iteration, total_updates + total_selects, total_updates, total_selects);
	} while (time(NULL) < finish);

	total_selects = 0;
	total_updates = 0;
	for (i = 0; i < n_clients; i++) {
		total_selects += selects[i];
		total_updates += updates[i];
	}
	if (verbose)
	    printf("Summury: %ld (%ld updates, %ld selects)\n", (total_updates + total_selects)/test_duration, total_updates, total_selects);
	else 
	    printf("%ld\n", (total_updates + total_selects)/test_duration);

	termination = 1;

	for (i = 0; i < n_clients; i++) {
		pthread_join(threads[i], NULL);
	}
	return 0;
}
