/* Behavioural controls */
/* When defined, set_keepalive() enables and tunes TCP keepalives on the
 * client socket. Remove it to build without them; set_keepalive() then
 * always returns -1. */
#define HAVE_TCP_KEEPALIVE
/* TCP port the server binds to on all local interfaces (INADDR_ANY). */
static const unsigned short listen_port = 9999;
/* Period, in seconds, of the SIGALRM that interrupts main()'s sleep so the
 * client socket can be checked; the handler re-arms it with this value. */
static const unsigned int conn_check_delay_seconds = 10;
/* End behavioural controls */

#define _POSIX_SOURCE
#define _BSD_SOURCE
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <errno.h>
#include <unistd.h>
#include <fcntl.h>
#include <signal.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <stdint.h>
#include <assert.h>
#if defined(HAVE_LINUX_IP_RECVERR)
#include <linux/errqueue.h>
#endif

/**
 * Enable TCP keepalives on the socket fd passed, if the platform
 * supports them.
 *
 * SO_KEEPALIVE is switched on first; failing that is an error. The probe
 * count, idle time and probe interval are then tuned (very aggressively,
 * for testing). A failure to tune is reported with perror() but is not
 * treated as an error, because keepalives are already on at that point.
 *
 * \param clsockfd file descriptor for target socket
 * \returns -1 on error (errno set), 0 on success. When built without
 *          HAVE_TCP_KEEPALIVE this always fails with errno ENOPROTOOPT.
 */
int set_keepalive(int clsockfd) {
#if defined(HAVE_TCP_KEEPALIVE)
	int val = 1;
	const socklen_t vallen = sizeof(val);
	int ret = setsockopt(clsockfd, SOL_SOCKET, SO_KEEPALIVE, &val, vallen);
	if (ret != 0)
		return ret;

	/*
	 * For testing purposes, make the keepalives insanely, stupidly aggressive.
	 * IPPROTO_TCP is the portable option level for TCP options (SOL_TCP is a
	 * Linux-specific alias). The perror() labels state the value actually set.
	 */

	/* Only probe val time(s) before giving up */
	val = 5;
	if (setsockopt(clsockfd, IPPROTO_TCP, TCP_KEEPCNT, &val, vallen))
		perror("setsockopt(sock, IPPROTO_TCP, TCP_KEEPCNT, 5)");

	/* Assume the connection is idle after val second(s) of inactivity */
	val = 5;
	if (setsockopt(clsockfd, IPPROTO_TCP, TCP_KEEPIDLE, &val, vallen))
		perror("setsockopt(sock, IPPROTO_TCP, TCP_KEEPIDLE, 5)");

	/* Poke the remote end every val second(s) when the connection is idle */
	val = 5;
	if (setsockopt(clsockfd, IPPROTO_TCP, TCP_KEEPINTVL, &val, vallen))
		perror("setsockopt(sock, IPPROTO_TCP, TCP_KEEPINTVL, 5)");

	return 0;
#else
	(void)clsockfd;
	/* Set errno so callers can report why with perror()/strerror(). */
	errno = ENOPROTOOPT;
	return -1;
#endif
}

/**
 * Print the IPv4 address and port of a newly accepted peer to stdout.
 *
 * \param addr peer address as filled in by accept(); only read
 */
static void printpeer(const struct sockaddr_in *addr) {
	char buf[INET_ADDRSTRLEN];
	if (inet_ntop(AF_INET, &addr->sin_addr, buf, sizeof(buf)) == NULL) {
		/* Never print an uninitialised buffer if conversion fails. */
		snprintf(buf, sizeof(buf), "?");
	}
	printf("Accepted connection from peer %s:%hu\n", buf, ntohs(addr->sin_port));
}

/**
 * Create a TCP/IPv4 listening socket bound to listen_port on all local
 * interfaces, with a backlog of one pending connection.
 *
 * Any failure is fatal: a diagnostic is printed and the process exits
 * with the failing call's errno as its status.
 *
 * \returns the listening socket's file descriptor
 */
static int createsrvsock(void) {
	int srvsockfd = socket(AF_INET, SOCK_STREAM, 0);
	if (srvsockfd == -1) {
		int err = errno; /* printf() may overwrite errno */
		printf("Couldn't create socket: %i %s\n", err, strerror(err));
		exit(err);
	}

	/* Allow an immediate restart even while a previous run's connection is
	 * still in TIME_WAIT; not fatal if it cannot be set. */
	int reuse = 1;
	if (setsockopt(srvsockfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)))
		perror("setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1)");

	/* Zero everything, including sin_zero, which some implementations expect. */
	struct sockaddr_in addr;
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_port = htons(listen_port);
	addr.sin_addr.s_addr = htonl(INADDR_ANY);

	/* bind() takes the generic struct sockaddr *; the cast is required. */
	if (bind(srvsockfd, (struct sockaddr *)&addr, sizeof(addr))) {
		int err = errno;
		printf("Couldn't bind socket: %i %s\n", err, strerror(err));
		exit(err);
	}
	if (listen(srvsockfd, 1)) {
		int err = errno;
		printf("Couldn't listen: %i %s\n", err, strerror(err));
		exit(err);
	}
	return srvsockfd;
}

/**
 * Apply per-connection socket options to a freshly accepted client socket.
 * At present this means switching on TCP keepalives. If that fails, a
 * message is printed and the connection is used without them.
 *
 * \param clsockfd accepted client socket
 */
static void setupclsock(int clsockfd) {
	const int keepalive_status = set_keepalive(clsockfd);
	if (keepalive_status != -1)
		return;
	printf("Couldn't enable socket keepalives\n");
}

/**
 * Check, without blocking, whether anything has happened on the client socket.
 *
 * Any pending data is read and discarded (at most sizeof(readbuf) bytes per
 * call). An orderly shutdown by the peer, ETIMEDOUT (e.g. unanswered TCP
 * keepalive probes) or ECONNRESET means the connection is gone. Any other
 * error aborts the program.
 *
 * \param clsockfd connected client socket
 * \returns 1 if the connection still appears usable, 0 if it has been lost
 */
static int pokeclient(int clsockfd) {
	char readbuf[2048];
	/* Names the call whose failure is being reported below. */
	const char *op = "recv(..., MSG_DONTWAIT|MSG_PEEK)";

	/* MSG_PEEK: only look. MSG_DONTWAIT: an idle socket fails with EAGAIN
	 * instead of blocking. */
	ssize_t read_size = recv(clsockfd, readbuf, sizeof(readbuf), MSG_DONTWAIT|MSG_PEEK);
	if (read_size == 0) {
		/* If the remote end disconnected cleanly we get a zero-length read */
		printf("Connection dropped by remote\n");
		return 0;
	}
	if (read_size > 0) {
		/* Data is already queued, so this read() returns immediately. */
		op = "read()";
		read_size = read(clsockfd, readbuf, sizeof(readbuf));
		if (read_size >= 0) {
			printf("Read and discarded %zd bytes\n", read_size);
			return 1;
		}
	}

	/* Only failed calls get here; save errno before printf() can overwrite it. */
	int err = errno;
	switch (err) {
	case EAGAIN:
#if EWOULDBLOCK != EAGAIN
	case EWOULDBLOCK:
#endif
	case EINTR:
		/* No data ready (or interrupted). Nothing interesting happening on the socket. */
		return 1;
	case ETIMEDOUT:
		/* If a TCP keepalive kills the remote end, we exit here */
		printf("%s returned ETIMEDOUT: Remote has dropped or lost the connection.\n", op);
		return 0;
	case ECONNRESET:
		/* The peer answered with a reset, e.g. to a keepalive probe after a reboot */
		printf("%s returned ECONNRESET: Remote has reset the connection.\n", op);
		return 0;
	default:
		/* Something else went wrong */
		printf("%s returned %i: %s\n", op, err, strerror(err));
		printf("Aborting\n");
		exit(1);
	}
}

/**
 * SIGALRM handler: announces the alarm and re-arms it, so the main loop's
 * sleep() keeps being interrupted every conn_check_delay_seconds.
 *
 * Only async-signal-safe calls (write(), alarm()) are made here; printf()
 * is not safe to call from a signal handler. errno is saved and restored
 * because the handler can run between a failing call in the interrupted
 * code and that code's errno check.
 *
 * "(alarm)" bypasses stdio, so it can appear ahead of stdout output that
 * has been buffered but not yet flushed.
 *
 * \param signum signal number; expected to be SIGALRM
 */
void alarm_callback(int signum) {
	const int saved_errno = errno;
	static const char msg[] = "(alarm)";

	assert(signum == SIGALRM);
	(void)signum; /* unused when NDEBUG disables assert() */

	/* Best effort: there is nothing useful to do if this write fails. */
	ssize_t ignored = write(STDOUT_FILENO, msg, sizeof(msg) - 1);
	(void)ignored;

	/* Re-schedule alarm */
	alarm(conn_check_delay_seconds);

	errno = saved_errno;
}

/**
 * Keepalive test server.
 *
 * Accepts a single TCP client and enables keepalives on it. It then "works"
 * (sleeps for up to an hour) while a periodic SIGALRM wakes it to check
 * whether the client connection is still alive.
 *
 * \returns 0 on normal exit, or the errno of a failed accept()/sigaction()
 */
int main(void) {
	int srvsockfd = createsrvsock();

	/* accept() reads len as the capacity of addr, so it must be initialised. */
	struct sockaddr_in addr;
	socklen_t len = sizeof(addr);
	int clsockfd = accept(srvsockfd, (struct sockaddr *)&addr, &len);
	if (clsockfd == -1) {
		int err = errno; /* printf() may overwrite errno */
		printf("accept() failure: %i %s\n", err, strerror(err));
		return err;
	}

	setupclsock(clsockfd);
	printpeer(&addr);

	/*
	 * Set up an alarm to periodically check the client socket, and set our
	 * signal handler for SIGALRM to do nothing interesting. Delivery of the
	 * signal is what interrupts sleep() below.
	 */
	{
		struct sigaction act;
		memset(&act, 0, sizeof(act));
		act.sa_handler = &alarm_callback;
		act.sa_flags = 0;
		sigemptyset(&act.sa_mask);
		if (sigaction(SIGALRM, &act, NULL)) {
			int err = errno; /* perror() may overwrite errno */
			perror("sigaction()");
			return err;
		}

		alarm(conn_check_delay_seconds);
	}

	/*
	 * Do some long-running non-network-related "work"
	 */
	unsigned int sleeptime = 3600;
	int should_continue = 1;
	do {
		printf("Sleeping...");
		/* Flush the partial line so it is visible while we block, and is not
		 * reordered with output written directly to the fd. */
		fflush(stdout);
		sleeptime = sleep(sleeptime);
		printf(" intr at %u\n", sleeptime);

		if (sleeptime) {
			/* Sleep interrupted, probably by SIGALRM but we don't really care why.
			 * Poke the client socket to see if it's alive then go back to, er,
			 * "working" unless should_continue gets cleared.
			 */
			should_continue = pokeclient(clsockfd);
		}
	} while (sleeptime && should_continue);

	printf("Exiting\n");

	shutdown(clsockfd, SHUT_RDWR);
	close(clsockfd);

	close(srvsockfd);
	return 0;
}
