#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <limits.h>

#include <unistd.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sys/sem.h>
#include <sys/time.h>

/*
 * Argument type for semctl()'s optional fourth parameter.  POSIX requires the
 * application itself to declare this union.  The visible code only uses the
 * val member (SETVAL and IPC_RMID calls).
 * NOTE(review): some platforms already declare union semun in <sys/sem.h>,
 * where this definition would clash — confirm before porting.
 */
union semun {
  int              val;    /* Value for SETVAL */
  struct semid_ds *buf;    /* Buffer for IPC_STAT, IPC_SET */
  unsigned short  *array;  /* Array for GETALL, SETALL */
  struct seminfo  *__buf;  /* Buffer for IPC_INFO (Linux specific) */
};

/* Number of semaphores in each System V semaphore set created by main(). */
#define SEMAS_PER_SET	16
/* Permission bits for the semaphore sets and the shared memory segment. */
#define IPCProtection	(0600)	/* access/modify by user only */
/* NOTE(review): not referenced anywhere in the visible code — confirm and
   drop if it is unused. */
#define PGSemaMagic		537		/* must be less than SEMVMX */
/* Base IPC key.  main() increments the key before every semget(), so the
   first set uses SEMAS_KEY_START + 1. */
#define SEMAS_KEY_START	(5431*1000-1)

/* Set by main() before forking, so every worker inherits them:
   nthreads = number of workers, timeout = per-wait alarm in milliseconds,
   shmid = id of the shared memory segment that holds wakers[]. */
int nthreads, timeout, shmid;
/*
 * Shared (shmat-attached) array with one byte per worker.  wakers[n] holds
 * either the index of the worker expected to wake worker n, or one of the
 * WAKER_* state values.  volatile forces every access to go to memory; it
 * provides no atomicity or ordering guarantees between processes.
 */
volatile unsigned char *wakers;
/* Location of one worker's semaphore: which set, and which slot in it. */
typedef struct PGSemaphoreData
{
	int			semId;			/* semaphore set identifier */
	int			semNum;			/* semaphore number within set */
} PGSemaphoreData;
/* sems[n] is worker n's semaphore; allocated and filled in by main(). */
PGSemaphoreData *sems;

/* Forward declarations for the helpers defined after main().  Explicit
   (void) gives atexit_handler a real prototype: an empty () declares a
   function with unspecified parameters in pre-C23 C. */
static void atexit_handler(void);
static void worker(int n);

static void down(int n);
static void up(int n);

/* Upper bound accepted for the nthreads argument. */
#define MAX_THREADS 250
/*
 * Special wakers[] values.  Otherwise a slot holds the index (< MAX_THREADS)
 * of the worker expected to wake that slot's owner.  worker() treats any value
 * <= MAX_THREADS as such an index, so these must be larger than MAX_THREADS
 * and still fit in an unsigned char.
 */
#define WAKER_NOOP 253      /* not waiting: worker 0, or a worker that has exited */
#define WAKER_EXIT 254      /* parent/cleanup asks this worker to exit */
#define WAKER_RUNNABLE 255  /* running, or woken by its waker */

/*
 * SIGINT handler for the parent.  It deliberately calls exit() (which is not
 * async-signal-safe) so that the registered atexit handler runs and removes
 * the IPC objects.  The exit status is 127 plus the signal number.
 */
static void
handle_sig(int arg)
{
  int status = 127 + arg;

  exit(status);
}

/*
 * Undo partial setup by removing the semaphore sets backing sems[0..count-1].
 * main() calls this on failure paths reached before atexit_handler is
 * registered, so that no System V semaphore sets are left behind.
 */
static void remove_sem_sets(int count)
{
  int i;

  for (i = 0; i < count; i += SEMAS_PER_SET) {
    union semun semun;

    semun.val = 0;
    if (semctl(sems[i].semId, 0, IPC_RMID, semun) < 0)
      fprintf(stderr, "semctl(sems[%d].semId==%d, 0, IPC_RMID, {0}): %s\n",
              i, sems[i].semId, strerror(errno));
  }
}

/*
 * Usage: prog [nthreads [runtime_seconds [timeout_seconds]]]
 *
 * Creates one SysV semaphore per worker (SEMAS_PER_SET per set) plus a shared
 * wakers[] byte array, and forks nthreads worker processes.  It then sleeps
 * for `runtime` seconds and asks each worker in turn to exit.  The IPC objects
 * are removed by atexit_handler, which is registered once all workers exist.
 * Setup failures before that point clean up explicitly.
 */
int main(int argc, char *argv[])
{
  int i, j, semKey, runtime;
  double timeout_ms;
  void *shmaddr;
  pid_t *pids;
  struct sigaction act, oact;
  int semId=-1;

  if (argc <= 1)
    nthreads = 10;
  else
    nthreads = atoi(argv[1]);

  if (nthreads <= 0 || nthreads > MAX_THREADS) {
    fprintf(stderr, "usage: nthreads not between 1 and %d\n", MAX_THREADS);
    exit(1);
  }

  if (argc <= 2)
    runtime = 10;
  else
    runtime = atoi(argv[2]);

  if (runtime < 1) {
    fprintf(stderr, "usage: runtime shorter than 1s\n");
    exit(1);
  }

  /* The timeout argument is in seconds; `timeout` is kept in milliseconds.
     Validate in double first: converting a NaN or out-of-range double to
     int is undefined behavior.  !(x >= 1.0) also rejects NaN. */
  if (argc <= 3)
    timeout_ms = 1000.0*60;
  else
    timeout_ms = 1000.0*atof(argv[3]);

  if (!(timeout_ms >= 1.0)) {
    fprintf(stderr, "usage: timeout shorter than 1ms\n");
    exit(1);
  }
  if (timeout_ms > INT_MAX) {
    fprintf(stderr, "usage: timeout longer than %dms\n", INT_MAX);
    exit(1);
  }
  timeout = (int) timeout_ms;

  printf("running with %d processes for %ds with timeout of %dms\n", nthreads, runtime, timeout);

  /* Allocate both arrays before creating any IPC objects, so an allocation
     failure cannot strand semaphore sets or shared memory. */
  sems = malloc(sizeof(*sems)*nthreads);
  pids = malloc(sizeof(*pids)*nthreads);
  if (sems == NULL || pids == NULL) {
    perror("malloc");
    exit(1);
  }

  semKey = SEMAS_KEY_START;
  for (i=0;i<nthreads;i++) {
    union semun semun;
    int semNum = i % SEMAS_PER_SET;

    if (semNum == 0) {
      /* A new set every SEMAS_PER_SET workers, each under its own key. */
      semKey += 1;
      semId = semget(semKey, SEMAS_PER_SET, IPC_CREAT | IPC_EXCL | IPCProtection);
      if (semId < 0) {
        perror("semget");
        remove_sem_sets(i);     /* sets already created for sems[0..i-1] */
        exit(1);
      }
    }

    /* Record the slot before SETVAL so the failure cleanup covers this set. */
    sems[i].semId = semId;
    sems[i].semNum = semNum;

    semun.val = 0;
    if (semctl(semId, semNum, SETVAL, semun) < 0) {
      fprintf(stderr, "semctl(%d, %d, SETVAL, 0): %s\n", semId, semNum, strerror(errno));
      remove_sem_sets(i + 1);
      exit(1);
    }
  }

  shmid = shmget(IPC_PRIVATE, nthreads*sizeof(unsigned char), IPC_CREAT | IPC_EXCL | IPCProtection);
  if (shmid == -1) {
    perror("shmget");
    remove_sem_sets(nthreads);
    exit(1);
  }

  /* shmat() reports failure as (void *) -1, not NULL. */
  shmaddr = shmat(shmid, NULL, 0);
  if (shmaddr == (void *) -1) {
    perror("shmat");
    if (shmctl(shmid, IPC_RMID, NULL) < 0)
      perror("shmctl ipc_rmid");
    remove_sem_sets(nthreads);
    exit(1);
  }
  wakers = shmaddr;
  wakers[0] = WAKER_NOOP;
  for (i=1;i<nthreads;i++)
    wakers[i] = WAKER_RUNNABLE;

  /* Flush pending stdout output before forking.  Otherwise, when stdout is a
     pipe or file, every child inherits the buffered text and writes it again
     when it calls exit(). */
  fflush(stdout);

  for (i=0;i<nthreads;i++) {
    switch(pids[i] = fork()) {
    case 0:
      worker(i);
      exit(0);
    case -1:
      perror("fork");
      /* Ask the workers already forked to stop, then remove the IPC objects
         now, because atexit_handler is not registered yet.  Workers still
         blocked in down() then fail out of semop() and exit. */
      for (j=0;j<i;j++)
        wakers[j] = WAKER_EXIT;
      atexit_handler();
      exit(1);
    default:
      break;
    }
  }

  act.sa_handler = handle_sig;
  sigemptyset(&act.sa_mask);
  act.sa_flags = 0;
  if (sigaction(SIGINT, &act, &oact) < 0)
    perror("sigaction");
  atexit(atexit_handler);

  sleep(runtime);
  printf("telling threads to exit\n");

  /* Shut workers down in index order.  Wait until each one has published a
     waker (is no longer running), mark it WAKER_EXIT, wake it, and wait for
     it to acknowledge by clearing the slot. */
  for (i=0;i<nthreads;i++) {
    while (wakers[i] == WAKER_RUNNABLE) {
      printf("still waiting for thread %d to block\n", i);
      sleep(1);
    }

    wakers[i] = WAKER_EXIT;
    up(i);
    usleep(20000);

    while(wakers[i] == WAKER_EXIT) {
      printf("still waiting for thread %d to exit\n", i);
      sleep(1);
    }
  }

  printf("run done\n");

  exit(0);
}

/*
 * Parent-side cleanup.  main() registers it with atexit() once all workers
 * are forked, so it also runs when handle_sig() calls exit() after that.
 *
 * It sets every worker's wakers[] slot to WAKER_EXIT, then removes the shared
 * memory segment and each semaphore set (one per SEMAS_PER_SET workers).
 * Workers blocked in down() cannot see WAKER_EXIT.  Removing their semaphore
 * set makes semop() fail, and mysemop() then exits the worker.
 */
static void atexit_handler(void)
{
  int i;

  printf("cleaning up semaphores and shared memory\n");

  /* Start at 0.  Worker 0 never blocks and is only told to stop through
     WAKER_EXIT, so skipping it would leave it looping forever if the
     parent exits before telling it to (e.g. on SIGINT). */
  for (i=0;i<nthreads;i++)
    wakers[i] = WAKER_EXIT;

  if (shmctl(shmid, IPC_RMID, NULL) < 0)
    perror("shmctl ipc_rmid");

  /* sems[i] at each multiple of SEMAS_PER_SET names a distinct set. */
  for(i=0; i<nthreads; i += SEMAS_PER_SET) {
    union semun semun;
    semun.val = 0;
    if (semctl(sems[i].semId, 0, IPC_RMID, semun) < 0)
      fprintf(stderr, "semctl(sems[%d].semId==%d, 0, IPC_RMID, {0}): %s\n", i, sems[i].semId, strerror(errno));
  }
}



/* Index of this worker process.  worker() sets it before installing the
   SIGALRM handler, which reads it. */
int MyThread;
/* Flags written by the SIGALRM handler and read by the worker loop.  The C
   standard only guarantees defined behavior for writes from an asynchronous
   signal handler to objects of type volatile sig_atomic_t. */
volatile sig_atomic_t sigalarm_fired, sigalarm_found_myself_runnable;

static void
handle_sig_alarm(int arg)
{
  int waker = wakers[MyThread];

  sigalarm_fired = 1;

  if (waker == WAKER_RUNNABLE)
    sigalarm_found_myself_runnable = 1;
}

/*
 * Body of worker process n.  main() forks it; it never returns and leaves
 * only via exit().
 *
 * Each iteration:
 *   1. For every higher-numbered worker i with wakers[i] == n (waiting for us),
 *      mark it WAKER_RUNNABLE and up() its semaphore.
 *   2. Exit if wakers[n] has been set to WAKER_EXIT.
 *   3. Worker 0 never blocks: it sleeps 10ms and repeats.
 *   4. Any other worker:
 *      - arms a one-shot ITIMER_REAL of `timeout` ms;
 *      - picks a random lower-numbered worker as its waker and publishes that
 *        index in wakers[n];
 *      - blocks in down(n);
 *      - after waking, checks that its slot no longer holds a worker index,
 *        then disarms the timer;
 *      - reports a lost wakeup if the SIGALRM handler saw wakers[n] ==
 *        WAKER_RUNNABLE.
 *
 * NOTE(review): the timer is disarmed only after down() returns and the slot
 * is checked.  An alarm that lands in that window also sees WAKER_RUNNABLE
 * and is reported as a lost wakeup (possible false positive with small
 * timeouts).
 */
static void worker(int n)
{
  long niterations=0;		/* only read by the commented-out debug printf */
  struct itimerval timeval;
  struct sigaction act, oact;

  /* Seed per process so workers pick different wakers. */
  srandom(getpid());
  MyThread = n;

  /* SIGALRM interrupts the semop() in down(); mysemop() retries on EINTR. */
  act.sa_handler = handle_sig_alarm;
  sigemptyset(&act.sa_mask);
  act.sa_flags = 0;
  if (sigaction(SIGALRM, &act, &oact) < 0)
    perror("sigaction");

  for(;;) {
    int waker;
    int i;

    /* wake anyone following us waiting for us to wake them */
    for (i=n+1;i<nthreads;i++) {
      if (wakers[i] == n) {
	/*printf("thread %d waking thread %d\n", n, i);*/
	/* Mark runnable before up(), so the woken worker sees a non-index value. */
	wakers[i] = WAKER_RUNNABLE;
	up(i);
      }
    }
    
    niterations++;
    /* Acknowledge the exit request by setting the slot to WAKER_NOOP;
       main() waits for this. */
    if (wakers[n] == WAKER_EXIT) {
      /*printf("thread %d exiting after %ld iterations\n", n, niterations);*/
      wakers[n] = WAKER_NOOP;
      exit(0);
    }
  
    if (n == 0) {
      /* we're the first thread so we just sleep and then go around waking
	 people again */
      usleep(10000);
      continue;
    }

    /* otherwise pick a random thread earlier than us to wake us and go to
       sleep until awoken by it */

    /* One-shot timer (it_interval stays zero) of `timeout` milliseconds. */
    sigalarm_fired = 0;
    sigalarm_found_myself_runnable = 0;
    memset(&timeval, 0, sizeof(struct itimerval));
    timeval.it_value.tv_sec  = timeout / 1000;
    timeval.it_value.tv_usec = (timeout % 1000) * 1000;
    if (setitimer(ITIMER_REAL, &timeval, NULL)) {
      perror("setitimer");
      exit(1);
    }
    waker = random()%n;
    /*printf("thread %d sleeping waiting for %d to wake us\n", n, waker);*/
    /* Publish the chosen waker before blocking so it can find us in step 1. */
    wakers[n] = waker;
    down(n);

    /* Woken legitimately means the slot now holds WAKER_RUNNABLE or
       WAKER_EXIT; a worker index here means we woke without being marked. */
    if ((waker = wakers[n]) <= MAX_THREADS) {
      printf("thread %d awake but waker is still set to %d !!!!\n", n, waker);
      exit(1);
    }

    /* Disarm the timer (all-zero it_value). */
    memset(&timeval, 0, sizeof(struct itimerval));
    if (setitimer(ITIMER_REAL, &timeval, NULL)) {
      perror("setitimer");
      exit(1);
    }

    /*
    if (sigalarm_fired)
      printf("timer fired\n");
    */
    if (sigalarm_found_myself_runnable)
      printf("thread %d lost a wakeup!!!\n", n);

    /*printf("thread %d awoken\n", n);*/
  }
}

/*
 * Apply `op` to worker n's semaphore (sems[n]).  up() passes +1 and down()
 * passes -1.  The call is retried while semop() is interrupted by a signal
 * (EINTR, e.g. the worker's SIGALRM).  Any other failure is reported on
 * stderr and the process exits with status 1.
 */
static void mysemop(int n, int op)
{
  struct sembuf sop;
  int rc;

  sop.sem_num = sems[n].semNum;
  sop.sem_op = op;
  sop.sem_flg = 0;

  for (;;) {
    rc = semop(sems[n].semId, &sop, 1);
    if (rc >= 0 || errno != EINTR)
      break;
  }

  if (rc >= 0)
    return;

  fprintf(stderr, "semop(%d, {%d,0,%d}, 1): %s\n", sems[n].semId, op, sems[n].semNum, strerror(errno));
  exit(1);
}

/* Increment worker n's semaphore by one, releasing it if blocked in down().
   The call is not written as `return mysemop(...)`: ISO C forbids a return
   statement with an expression in a void function. */
static void up(int n)
{
  mysemop(n, 1);
}

/* Decrement worker n's semaphore by one, blocking until it can (see
   mysemop() for EINTR handling).  The call is not written as
   `return mysemop(...)`: ISO C forbids a return statement with an
   expression in a void function. */
static void down(int n)
{
  mysemop(n, -1);
}


