From 5e19b26111708b654c7b49c3b9fd7dc18b94c0e7 Mon Sep 17 00:00:00 2001
From: Tristan Partin <tristan@neon.tech>
Date: Mon, 24 Jul 2023 11:12:59 -0500
Subject: [PATCH v7 1/2] Allow SIGINT to cancel psql database reconnections

After psql installs its SIGINT handler, SIGINT can no longer cancel
database reconnections. For instance, if the user starts a reconnection
(\connect) and it then blocks waiting on the network (i.e. psql is
polling the socket), there is currently no way to cancel it.

Use PQconnectStartParams() and drive PQconnectPoll() from psql itself,
so that a cancel_pressed check can be inserted into the polling loop.
---
 meson.build                |   1 +
 src/bin/psql/command.c     | 156 ++++++++++++++++++++++++++++++++++++-
 src/include/pg_config.h.in |   3 +
 src/tools/msvc/Solution.pm |   1 +
 4 files changed, 160 insertions(+), 1 deletion(-)

diff --git a/meson.build b/meson.build
index ee58ee7a06..2d63485c53 100644
--- a/meson.build
+++ b/meson.build
@@ -2440,6 +2440,7 @@ func_checks = [
   ['posix_fadvise'],
   ['posix_fallocate'],
   ['ppoll'],
+  ['pselect'],
   ['pstat'],
   ['pthread_barrier_wait', {'dependencies': [thread_dep]}],
   ['pthread_is_threaded_np', {'dependencies': [thread_dep]}],
diff --git a/src/bin/psql/command.c b/src/bin/psql/command.c
index 82cc091568..3a76623b05 100644
--- a/src/bin/psql/command.c
+++ b/src/bin/psql/command.c
@@ -11,6 +11,7 @@
 #include <time.h>
 #include <pwd.h>
 #include <utime.h>
+#include <sys/select.h>
 #ifndef WIN32
 #include <sys/stat.h>			/* for stat() */
 #include <sys/time.h>			/* for setitimer() */
@@ -40,6 +41,7 @@
 #include "large_obj.h"
 #include "libpq-fe.h"
 #include "libpq/pqcomm.h"
+#include "libpq/pqsignal.h"
 #include "mainloop.h"
 #include "portability/instr_time.h"
 #include "pqexpbuffer.h"
@@ -3298,6 +3300,157 @@ param_is_newly_set(const char *old_val, const char *new_val)
 	return false;
 }
 
+/*
+ * Check a file descriptor for read and/or write readiness, possibly waiting.
+ * If neither forRead nor forWrite is set, immediately return a timeout
+ * condition (without waiting).  Returns >0 if the condition is met, 0 if a
+ * timeout occurred, and -1 with errno set if an error or interrupt occurred;
+ * a pending cancel request (cancel_pressed) is reported as -1/EINTR.
+ *
+ * end_time is an absolute time, as returned by time(2).  The timeout is
+ * infinite if end_time is -1, and immediate (no blocking) if end_time is 0
+ * (or indeed, any time before now).
+ *
+ * Uses the select(2) family of functions because it is available on every
+ * platform; psql is unlikely to hold enough file descriptors to need poll(2)
+ * or ppoll(2).  When pselect(2) is available, SIGINT is blocked while
+ * cancel_pressed is tested and the original mask is restored atomically for
+ * the wait, so a SIGINT arriving between the test and the wait is not lost.
+ */
+static int
+pqSocketPoll(int sock, int forRead, int forWrite, time_t end_time)
+{
+	fd_set		input_mask;
+	fd_set		output_mask;
+	fd_set		except_mask;
+#ifdef HAVE_PSELECT
+	int			rc;
+	int			save_errno;
+	sigset_t	blockset;
+	sigset_t	origset;
+	struct timespec timeout;
+	struct timespec *ptr_timeout = NULL;
+#else
+	struct timeval timeout;
+	struct timeval *ptr_timeout = NULL;
+#endif
+
+	if (!forRead && !forWrite)
+		return 0;
+
+	FD_ZERO(&input_mask);
+	FD_ZERO(&output_mask);
+	FD_ZERO(&except_mask);
+	if (forRead)
+		FD_SET(sock, &input_mask);
+	if (forWrite)
+		FD_SET(sock, &output_mask);
+	FD_SET(sock, &except_mask);
+
+	/*
+	 * Both select(2) and pselect(2) take a relative interval, so convert the
+	 * absolute end_time; a deadline already in the past means "don't wait".
+	 */
+	if (end_time != ((time_t) -1))
+	{
+		time_t		now = time(NULL);
+
+		timeout.tv_sec = (end_time > now) ? end_time - now : 0;
+#ifdef HAVE_PSELECT
+		timeout.tv_nsec = 0;
+#else
+		timeout.tv_usec = 0;
+#endif
+		ptr_timeout = &timeout;
+	}
+
+#ifdef HAVE_PSELECT
+	/* Block SIGINT so it cannot slip in between the check and pselect() */
+	sigemptyset(&blockset);
+	sigaddset(&blockset, SIGINT);
+	sigprocmask(SIG_BLOCK, &blockset, &origset);
+
+	if (cancel_pressed)
+	{
+		sigprocmask(SIG_SETMASK, &origset, NULL);
+		errno = EINTR;
+		return -1;
+	}
+
+	/* Wait under the caller's own mask; an unblocked SIGINT interrupts it */
+	rc = pselect(sock + 1, &input_mask, &output_mask,
+				 &except_mask, ptr_timeout, &origset);
+	save_errno = errno;
+	sigprocmask(SIG_SETMASK, &origset, NULL);
+	errno = save_errno;
+	return rc;
+#else
+	if (cancel_pressed)
+	{
+		errno = EINTR;
+		return -1;
+	}
+
+	return select(sock + 1, &input_mask, &output_mask,
+				  &except_mask, ptr_timeout);
+#endif
+}
+
+/*
+ * Drive a connection begun with PQconnectStartParams() until it succeeds,
+ * fails, or ^C is pressed (leaving it unfinished); callers check PQstatus().
+ */
+static void
+process_connection_state_machine(PGconn *conn)
+{
+	/* Before the first PQconnectPoll(), libpq says to wait as if for WRITING */
+	bool		for_read = false;
+
+	while (true)
+	{
+		int			rc;
+		int			sock;
+		time_t		end_time;
+
+		/* Stop on ^C, or when libpq has no socket for us to wait on */
+		sock = PQsocket(conn);
+		if (cancel_pressed || sock == -1)
+			return;
+
+		/*
+		 * pqSocketPoll's pselect(2) path cannot miss a ^C, so it may block
+		 * forever; otherwise wake once per second to re-check cancel_pressed.
+		 */
+#if defined(HAVE_PSELECT)
+		end_time = -1;
+#else
+		end_time = time(NULL) + 1;	/* absolute deadline */
+#endif
+
+		errno = 0;
+		rc = pqSocketPoll(sock, for_read, !for_read, end_time);
+		if (rc < 0 && errno != EINTR)
+			return;
+		if (rc <= 0)
+			continue;			/* interrupted or timed out: re-check cancel */
+
+		switch (PQconnectPoll(conn))
+		{
+			case PGRES_POLLING_OK:
+			case PGRES_POLLING_FAILED:
+				return;
+			case PGRES_POLLING_READING:
+				for_read = true;
+				break;
+			case PGRES_POLLING_WRITING:
+				for_read = false;
+				break;
+			case PGRES_POLLING_ACTIVE:
+				pg_unreachable();
+		}
+	}
+}
+
 /*
  * do_connect -- handler for \connect
  *
@@ -3614,11 +3767,12 @@ do_connect(enum trivalue reuse_previous_specification,
 		values[paramnum] = NULL;
 
 		/* Note we do not want libpq to re-expand the dbname parameter */
-		n_conn = PQconnectdbParams(keywords, values, false);
+		n_conn = PQconnectStartParams(keywords, values, false);
 
 		pg_free(keywords);
 		pg_free(values);
 
+		process_connection_state_machine(n_conn);
 		if (PQstatus(n_conn) == CONNECTION_OK)
 			break;
 
diff --git a/src/include/pg_config.h.in b/src/include/pg_config.h.in
index d8a2985567..f9fd7d0de7 100644
--- a/src/include/pg_config.h.in
+++ b/src/include/pg_config.h.in
@@ -333,6 +333,9 @@
 /* Define to 1 if you have the `ppoll' function. */
 #undef HAVE_PPOLL
 
+/* Define to 1 if you have the `pselect' function. */
+#undef HAVE_PSELECT
+
 /* Define if you have POSIX threads libraries and header files. */
 #undef HAVE_PTHREAD
 
diff --git a/src/tools/msvc/Solution.pm b/src/tools/msvc/Solution.pm
index 98a5b5d872..d035f44f73 100644
--- a/src/tools/msvc/Solution.pm
+++ b/src/tools/msvc/Solution.pm
@@ -308,6 +308,7 @@ sub GenerateFiles
 		HAVE_POSIX_FADVISE => undef,
 		HAVE_POSIX_FALLOCATE => undef,
 		HAVE_PPOLL => undef,
+		HAVE_PSELECT => undef,
 		HAVE_PTHREAD => undef,
 		HAVE_PTHREAD_BARRIER_WAIT => undef,
 		HAVE_PTHREAD_IS_THREADED_NP => undef,
-- 
Tristan Partin
Neon (https://neon.tech)

