From 60a3427507929dd6d29359d3a89633d23d213181 Mon Sep 17 00:00:00 2001
From: Thomas Munro <thomas.munro@gmail.com>
Date: Sat, 31 Dec 2022 23:36:47 +1300
Subject: [PATCH v2 3/3] Use latch API to wait for RADIUS authentication.

Handle interrupts while waiting for a response from a RADIUS server,
resolving a TODO left in the code comments.  Remote clients can't really
interrupt authentication (they don't have a cancel key yet), but it's
important to process other kinds of interrupts promptly.

While here, also convert RADIUS_TIMEOUT into a configurable pg_hba.conf
option, radiustimeout, resolving another TODO left in the comments.  This
lets automated tests raise it to our standard, very generous 180s wait on
slow, overloaded, or Valgrind build farm animals.

Now that the timeout is adjustable, we can also test the timeout code
path by using a very short timeout.  Since the RADIUS protocol is
connectionless, we can provoke a timeout by sending a UDP packet to a
port where no RADIUS server is listening; nobody will reply, so we will
time out.

Reviewed-by: Andreas Karlsson <andreas@proxel.se>
Discussion: https://postgr.es/m/CA%2BhUKGKxNoVjkMCksnj6z3BwiS3y2v6LN6z7_CisLK%2Brv%2B0V4g%40mail.gmail.com
---
 doc/src/sgml/client-auth.sgml | 12 +++++
 src/backend/libpq/auth.c      | 94 +++++++++++++++++++++--------------
 src/backend/libpq/hba.c       | 15 ++++++
 src/include/libpq/hba.h       |  1 +
 src/test/radius/t/001_auth.pl | 16 +++++-
 5 files changed, 100 insertions(+), 38 deletions(-)

diff --git a/doc/src/sgml/client-auth.sgml b/doc/src/sgml/client-auth.sgml
index b9d73deced..66c29b0de0 100644
--- a/doc/src/sgml/client-auth.sgml
+++ b/doc/src/sgml/client-auth.sgml
@@ -2147,6 +2147,18 @@ host ... ldap ldapbasedn="dc=example,dc=net"
        </listitem>
       </varlistentry>
 
+      <varlistentry>
+       <term><literal>radiustimeout</literal></term>
+       <listitem>
+        <para>
+         The time, in milliseconds, to wait for a response from a RADIUS
+         server before trying the next one in the list.  Must be a positive
+         integer.  The default is 3000 milliseconds (3 seconds), which may
+         be too short for some systems.
+        </para>
+       </listitem>
+      </varlistentry>
+
      </variablelist>
    </para>
 
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 25b3a781cd..10c87414ce 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -195,7 +195,7 @@ static int	pg_SSPI_make_upn(char *accountname,
  *----------------------------------------------------------------
  */
 static int	CheckRADIUSAuth(Port *port);
-static int	PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd);
+static int	PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd, int timeout);
 
 
 /*
@@ -2802,9 +2802,6 @@ typedef struct
 /* RADIUS service types */
 #define RADIUS_AUTHENTICATE_ONLY	8
 
-/* Seconds to wait - XXX: should be in a config variable! */
-#define RADIUS_TIMEOUT 3
-
 static void
 radius_add_attribute(radius_packet *packet, uint8 type, const unsigned char *data, int len)
 {
@@ -2839,6 +2836,7 @@ CheckRADIUSAuth(Port *port)
 			   *secrets,
 			   *radiusports,
 			   *identifiers;
+	int			timeout;
 
 	/* Make sure struct alignment is correct */
 	Assert(offsetof(radius_packet, vector) == 4);
@@ -2858,6 +2856,10 @@ CheckRADIUSAuth(Port *port)
 		return STATUS_ERROR;
 	}
 
+	timeout = port->hba->radiustimeout;
+	if (timeout == 0)
+		timeout = 3000;			/* default to 3 seconds */
+
 	/* Send regular password request to client, and get the response */
 	sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
 
@@ -2886,7 +2888,8 @@ CheckRADIUSAuth(Port *port)
 												   radiusports ? lfirst(radiusports) : NULL,
 												   identifiers ? lfirst(identifiers) : NULL,
 												   port->user_name,
-												   passwd);
+												   passwd,
+												   timeout);
 
 		/*------
 		 * STATUS_OK = Login OK
@@ -2927,7 +2930,13 @@ CheckRADIUSAuth(Port *port)
 }
 
 static int
-PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd)
+PerformRadiusTransaction(const char *server,
+						 const char *secret,
+						 const char *portstr,
+						 const char *identifier,
+						 const char *user_name,
+						 const char *passwd,
+						 int timeout)
 {
 	radius_packet radius_send_pack;
 	radius_packet radius_recv_pack;
@@ -2949,8 +2958,7 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
 	struct addrinfo *serveraddrs;
 	int			port;
 	socklen_t	addrsize;
-	fd_set		fdset;
-	struct timeval endtime;
+	TimestampTz endtime;
 	int			i,
 				j,
 				r;
@@ -3085,26 +3093,18 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
 
 	/*
 	 * Figure out at what time we should time out. We can't just use a single
-	 * call to select() with a timeout, since somebody can be sending invalid
-	 * packets to our port thus causing us to retry in a loop and never time
-	 * out.
-	 *
-	 * XXX: Using WaitLatchOrSocket() and doing a CHECK_FOR_INTERRUPTS() if
-	 * the latch was set would improve the responsiveness to
-	 * timeouts/cancellations.
+	 * wait with a timeout, since somebody can be sending invalid packets to
+	 * our port thus causing us to retry in a loop and never time out.
 	 */
-	gettimeofday(&endtime, NULL);
-	endtime.tv_sec += RADIUS_TIMEOUT;
+	endtime = TimestampTzPlusMilliseconds(GetCurrentTimestamp(), timeout);
 
 	while (true)
 	{
-		struct timeval timeout;
-		struct timeval now;
-		int64		timeoutval;
+		int			timeoutval;
 		const char *errstr = NULL;
 
-		gettimeofday(&now, NULL);
-		timeoutval = (endtime.tv_sec * 1000000 + endtime.tv_usec) - (now.tv_sec * 1000000 + now.tv_usec);
+		/* Remaining time, rounded up to the nearest millisecond. */
+		timeoutval = TimestampDifferenceMilliseconds(GetCurrentTimestamp(), endtime);
 		if (timeoutval <= 0)
 		{
 			ereport(LOG,
@@ -3113,25 +3113,39 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
 			closesocket(sock);
 			return STATUS_ERROR;
 		}
-		timeout.tv_sec = timeoutval / 1000000;
-		timeout.tv_usec = timeoutval % 1000000;
 
-		FD_ZERO(&fdset);
-		FD_SET(sock, &fdset);
-
-		r = select(sock + 1, &fdset, NULL, NULL, &timeout);
-		if (r < 0)
+		/*
+		 * No point in supplying a wait_event value as we don't have a row in
+		 * pg_stat_activity yet.
+		 */
+		r = WaitLatchOrSocket(MyLatch,
+							  WL_SOCKET_READABLE | WL_EXIT_ON_PM_DEATH |
+							  WL_LATCH_SET | WL_TIMEOUT,
+							  sock,
+							  timeoutval,
+							  0);
+		if (r & WL_LATCH_SET)
 		{
-			if (errno == EINTR)
-				continue;
+			ResetLatch(MyLatch);
 
-			/* Anything else is an actual error */
-			ereport(LOG,
-					(errmsg("could not check status on RADIUS socket: %m")));
-			closesocket(sock);
-			return STATUS_ERROR;
+			/* Process interrupts, making sure to close the socket if we throw. */
+			if (INTERRUPTS_PENDING_CONDITION())
+			{
+				volatile pgsocket vsock = sock;
+				PG_TRY();
+				{
+					CHECK_FOR_INTERRUPTS();
+				}
+				PG_CATCH();
+				{
+					closesocket(vsock);
+					PG_RE_THROW();
+				}
+				PG_END_TRY();
+			}
+			continue;
 		}
-		if (r == 0)
+		else if (r & WL_TIMEOUT)
 		{
 			ereport(LOG,
 					(errmsg("timeout waiting for RADIUS response from %s",
@@ -3139,6 +3153,10 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
 			closesocket(sock);
 			return STATUS_ERROR;
 		}
+		else
+		{
+			Assert(r & WL_SOCKET_READABLE);
+		}
 
 		/*
 		 * Attempt to read the response packet, and verify the contents.
@@ -3246,4 +3264,6 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
 			continue;
 		}
 	}							/* while (true) */
+
+	pg_unreachable();
 }
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index adedbd3128..3a33f52af1 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -2514,6 +2514,21 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 		hbaline->radiusidentifiers = parsed_identifiers;
 		hbaline->radiusidentifiers_s = pstrdup(val);
 	}
+	else if (strcmp(name, "radiustimeout") == 0)
+	{
+		REQUIRE_AUTH_OPTION(uaRADIUS, "radiustimeout", "radius");
+
+		if ((hbaline->radiustimeout = atoi(val)) <= 0)
+		{
+			ereport(elevel,
+					(errcode(ERRCODE_CONFIG_FILE_ERROR),
+					 errmsg("could not parse RADIUS timeout \"%s\"",
+							val),
+					 errcontext("line %d of configuration file \"%s\"",
+								line_num, file_name)));
+			return false;
+		}
+	}
 	else
 	{
 		ereport(elevel,
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 189f6d0df2..3f051d0ab5 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -135,6 +135,7 @@ typedef struct HbaLine
 	char	   *radiusidentifiers_s;
 	List	   *radiusports;
 	char	   *radiusports_s;
+	int			radiustimeout;
 } HbaLine;
 
 typedef struct IdentLine
diff --git a/src/test/radius/t/001_auth.pl b/src/test/radius/t/001_auth.pl
index 44db62a3d7..6d2e6e7345 100644
--- a/src/test/radius/t/001_auth.pl
+++ b/src/test/radius/t/001_auth.pl
@@ -60,6 +60,7 @@ else
 }
 
 my $radius_port     = PostgreSQL::Test::Cluster::get_free_port();
+my $not_radius_port = PostgreSQL::Test::Cluster::get_free_port();
 
 note "setting up radiusd";
 
@@ -159,8 +160,9 @@ sub test_access
 note "enable RADIUS auth";
 
 unlink($node->data_dir . '/pg_hba.conf');
+my $timeout = $PostgreSQL::Test::Utils::timeout_default * 1000;
 $node->append_conf('pg_hba.conf',
-	qq{local all all radius radiusservers="127.0.0.1" radiussecrets="secret" radiusports="$radius_port"}
+	qq{local all all radius radiusservers="127.0.0.1" radiussecrets="secret" radiusports="$radius_port" radiustimeout="$timeout"}
 );
 $node->restart;
 
@@ -184,4 +186,16 @@ test_access(
 		qr/connection authenticated: identity="test2" method=radius/
 	],);
 
+# Set the timeout very short and point to a non-existent radius server
+unlink($node->data_dir . '/pg_hba.conf');
+$node->append_conf('pg_hba.conf',
+	qq{local all all radius radiusservers="127.0.0.1" radiussecrets="secret" radiusports="$not_radius_port" radiustimeout="2"}
+);
+$node->restart;
+
+test_access(
+	$node, 'test2', 2,
+	'authentication fails with timeout',
+	log_like => [qr/timeout waiting for RADIUS response/]);
+
 done_testing();
-- 
2.39.2

