#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>

#include <semaphore.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>

#define IPCProtection	(0600)	/* access/modify by user only */

/* Number of worker processes; set once in main() before any fork(). */
int nthreads;
/* Array of nthreads process-shared semaphores living in a SysV shared
   memory segment; worker i blocks on sems[i]. */
sem_t *sems;
/* Array of nthreads bytes in a second SysV shared memory segment.
   wakers[i] holds the index of the worker that worker i has asked to wake
   it, or 255 when no wakeup request is pending. */
unsigned char *wakers;

static void worker(int n);

/*
 * Parse TEXT as a decimal integer in the range [1, max].
 *
 * On success stores the value in *out and returns 0.  Returns -1 (leaving
 * *out untouched) if TEXT is empty, has trailing garbage, overflows, or is
 * outside the range.
 */
static int parse_positive(const char *text, int max, int *out)
{
  char *end;
  long value;

  errno = 0;
  value = strtol(text, &end, 10);
  if (end == text || *end != '\0' || errno == ERANGE)
    return -1;
  if (value <= 0 || value > max)
    return -1;

  *out = (int)value;
  return 0;
}

/*
 * Semaphore wakeup stress test driver.
 *
 * usage: prog [nprocs] [runtime]
 *   nprocs   number of worker processes, 1..200 (default 10)
 *   runtime  seconds to let the workers run, > 0 (default 10)
 *
 * Creates two private SysV shared memory segments (one holding a
 * process-shared semaphore per worker, one holding the per-worker "waker"
 * byte), forks the workers, lets them run, then kills worker 0 and, one
 * second later, reports every other worker whose waker byte is still 255
 * as having lost a wakeup before killing it too.
 *
 * Exits with status 0 after a completed run and 1 on bad arguments or any
 * setup failure.  On every path taken after the segments are created they
 * are removed again and any children already forked are signalled, so a
 * failed run leaves neither IPC objects nor spinning processes behind.
 */
int main(int argc, char *argv[])
{
  /* NoWaker is the "no wakeup pending" marker stored in an unsigned char,
     so the worker count must stay below it. */
  enum { NoWaker = 255, MaxProcs = 200 };
  int i, runtime;
  int shmid1 = -1, shmid2 = -1;
  int nforked = 0;      /* children that still need to be signalled */
  int status = 1;
  pid_t *pids = NULL;
  void *mem;

  nthreads = 10;
  runtime = 10;
  if ((argc > 1 && parse_positive(argv[1], MaxProcs, &nthreads) < 0) ||
      (argc > 2 && parse_positive(argv[2], INT_MAX, &runtime) < 0)) {
    fprintf(stderr, "usage: %s [nprocs 1-%d] [runtime-seconds >0]\n",
            (argc > 0 && argv[0] != NULL) ? argv[0] : "prog", MaxProcs);
    exit(1);
  }

  printf("running with %d processes for %ds\n", nthreads, runtime);
  /* Flush before forking so children do not inherit (and later re-emit)
     buffered output when stdout is a pipe or file. */
  fflush(stdout);

  shmid1 = shmget(IPC_PRIVATE, nthreads*sizeof(sem_t), IPC_CREAT | IPC_EXCL | IPCProtection);
  if (shmid1 == -1) {
    perror("shmget");
    goto out;
  }

  /* shmat() reports failure as (void *)-1, not NULL. */
  mem = shmat(shmid1, NULL, 0);
  if (mem == (void *)-1) {
    perror("shmat");
    goto out;
  }
  sems = mem;
  for (i = 0; i < nthreads; i++) {
    if (sem_init(&sems[i], 1, 0) < 0) {
      perror("sem_init");
      goto out;
    }
  }

  shmid2 = shmget(IPC_PRIVATE, nthreads*sizeof(unsigned char), IPC_CREAT | IPC_EXCL | IPCProtection);
  if (shmid2 == -1) {
    perror("shmget");
    goto out;
  }

  mem = shmat(shmid2, NULL, 0);
  if (mem == (void *)-1) {
    perror("shmat");
    goto out;
  }
  wakers = mem;
  for (i = 0; i < nthreads; i++)
    wakers[i] = NoWaker;

  pids = malloc(nthreads * sizeof *pids);
  if (pids == NULL) {
    perror("malloc");
    goto out;
  }

  for (i = 0; i < nthreads; i++) {
    pid_t pid = fork();

    if (pid == 0) {
      worker(i);
      exit(0);
    }
    if (pid == -1) {
      perror("fork");
      goto out;         /* cleanup signals the children forked so far */
    }
    pids[nforked++] = pid;
  }

  sleep(runtime);
  kill(pids[0], SIGQUIT);
  sleep(1);

  /* With worker 0 gone and a second to settle, a waker byte still at
     NoWaker is reported as a lost wakeup (presumably the worker was marked
     as woken but never ran again to register a new waker). */
  for (i = 1; i < nthreads; i++) {
    if (wakers[i] == NoWaker)
      printf("thread %d lost a wakeup!!!\n", i);
    kill(pids[i], SIGQUIT);
  }
  nforked = 0;          /* every child has been signalled already */
  status = 0;

out:
  for (i = 0; i < nforked; i++)
    kill(pids[i], SIGQUIT);

  /* SysV segments outlive the process unless explicitly removed. */
  if (shmid1 != -1 && shmctl(shmid1, IPC_RMID, NULL) < 0)
    perror("shmctl ipc_rmid");
  if (shmid2 != -1 && shmctl(shmid2, IPC_RMID, NULL) < 0)
    perror("shmctl ipc_rmid");

  free(pids);
  exit(status);
}

static void worker(int n)
{
  srandom(getpid());

  for(;;) {
    int waker;
    int i;

    /* wake anyone following us waiting for us to wake them */
    for (i=n+1;i<nthreads;i++) {
      if (wakers[i] == n) {
	/*printf("thread %d waking thread %d\n", n, i);*/
	wakers[i] = 255;
	if (sem_post(&sems[i]) < 0) {
	  perror("sem_post");
	  exit(1);
	}
      }
    }

    if (n == 0) {
      /* we're the first thread so we just sleep and then go around waking
	 people again */
      continue;
    }

    /* otherwise pick a random thread earlier than us to wake us and go to
       sleep until awoken by it */

    waker = random()%n;
    /*printf("thread %d sleeping waiting for %d to wake us\n", n, waker);*/
    wakers[n] = waker;
    if (sem_wait(&sems[n]) < 0) {
      perror("sem_wait");
      exit(1);
    }
    if ((waker = wakers[n]) != 255) {
      printf("thread %d awake but waker is still set to %d !!!!\n", n, waker);
      exit(1);
    }

    /*printf("thread %d awoken\n", n);*/
  }
}

