From 760b516554101d5482265847db3d680accd23cf7 Mon Sep 17 00:00:00 2001
From: Thomas Munro <thomas.munro@gmail.com>
Date: Sat, 31 Dec 2022 23:36:47 +1300
Subject: [PATCH 3/3] Use latch API to wait for RADIUS authentication.

Handle interrupts while waiting for a response from a RADIUS server,
completing a TODO in comments.  Remote clients can't really interrupt
authentication (they don't have a cancel key yet), but it's important to
process other kinds of interrupts promptly.

Since CHECK_FOR_INTERRUPTS() might throw, leaving us with a leaked
socket, use PG_FINALLY() to make sure we clean up.

While here, also convert RADIUS_TIMEOUT to a GUC, radius_timeout
(milliseconds, default 3s), resolving another TODO in comments, so that
it can be turned up very high for automated testing on slow or
overloaded computers.

Now that the timeout is adjustable, we can also add a test of the
timeout code path using a very short timeout.  Since RADIUS runs over
connectionless UDP, we can provoke a timeout by sending the request to a
port where no RADIUS server is listening; nobody will write back to us
and we will time out.
---
 src/backend/libpq/auth.c                      | 342 +++++++++---------
 src/backend/utils/misc/guc_tables.c           |  11 +
 src/backend/utils/misc/postgresql.conf.sample |   1 +
 src/include/libpq/auth.h                      |   1 +
 src/test/radius/t/001_auth.pl                 |  15 +
 5 files changed, 197 insertions(+), 173 deletions(-)

diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 25b3a781cd..4511b0f3e6 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -2802,8 +2802,9 @@ typedef struct
 /* RADIUS service types */
 #define RADIUS_AUTHENTICATE_ONLY	8
 
-/* Seconds to wait - XXX: should be in a config variable! */
-#define RADIUS_TIMEOUT 3
+/* RADIUS GUCs */
+
+int			radius_timeout = 3000;	/* milliseconds */
 
 static void
 radius_add_attribute(radius_packet *packet, uint8 type, const unsigned char *data, int len)
@@ -2949,8 +2950,8 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
 	struct addrinfo *serveraddrs;
 	int			port;
 	socklen_t	addrsize;
-	fd_set		fdset;
-	struct timeval endtime;
+	TimestampTz endtime;
+	int			result = STATUS_ERROR;
 	int			i,
 				j,
 				r;
@@ -3053,197 +3054,206 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
 		return STATUS_ERROR;
 	}
 
 	memset(&localaddr, 0, sizeof(localaddr));
 	localaddr.sin6_family = serveraddrs[0].ai_family;
 	localaddr.sin6_addr = in6addr_any;
 	if (localaddr.sin6_family == AF_INET6)
 		addrsize = sizeof(struct sockaddr_in6);
 	else
 		addrsize = sizeof(struct sockaddr_in);
 
 	if (bind(sock, (struct sockaddr *) &localaddr, addrsize))
 	{
 		ereport(LOG,
 				(errmsg("could not bind local RADIUS socket: %m")));
 		closesocket(sock);
 		pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
 		return STATUS_ERROR;
 	}
 
 	if (sendto(sock, radius_buffer, packetlength, 0,
 			   serveraddrs[0].ai_addr, serveraddrs[0].ai_addrlen) < 0)
 	{
 		ereport(LOG,
 				(errmsg("could not send RADIUS packet: %m")));
 		closesocket(sock);
 		pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
 		return STATUS_ERROR;
 	}
 
 	/* Don't need the server address anymore */
 	pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
 
 	/*
 	 * Figure out at what time we should time out. We can't just use a single
-	 * call to select() with a timeout, since somebody can be sending invalid
-	 * packets to our port thus causing us to retry in a loop and never time
-	 * out.
-	 *
-	 * XXX: Using WaitLatchOrSocket() and doing a CHECK_FOR_INTERRUPTS() if
-	 * the latch was set would improve the responsiveness to
-	 * timeouts/cancellations.
-	 */
-	gettimeofday(&endtime, NULL);
-	endtime.tv_sec += RADIUS_TIMEOUT;
-
-	while (true)
-	{
-		struct timeval timeout;
-		struct timeval now;
-		int64		timeoutval;
-		const char *errstr = NULL;
-
-		gettimeofday(&now, NULL);
-		timeoutval = (endtime.tv_sec * 1000000 + endtime.tv_usec) - (now.tv_sec * 1000000 + now.tv_usec);
-		if (timeoutval <= 0)
-		{
-			ereport(LOG,
-					(errmsg("timeout waiting for RADIUS response from %s",
-							server)));
-			closesocket(sock);
-			return STATUS_ERROR;
-		}
-		timeout.tv_sec = timeoutval / 1000000;
-		timeout.tv_usec = timeoutval % 1000000;
-
-		FD_ZERO(&fdset);
-		FD_SET(sock, &fdset);
-
-		r = select(sock + 1, &fdset, NULL, NULL, &timeout);
-		if (r < 0)
-		{
-			if (errno == EINTR)
-				continue;
-
-			/* Anything else is an actual error */
-			ereport(LOG,
-					(errmsg("could not check status on RADIUS socket: %m")));
-			closesocket(sock);
-			return STATUS_ERROR;
-		}
-		if (r == 0)
-		{
-			ereport(LOG,
-					(errmsg("timeout waiting for RADIUS response from %s",
-							server)));
-			closesocket(sock);
-			return STATUS_ERROR;
-		}
-
-		/*
-		 * Attempt to read the response packet, and verify the contents.
-		 *
-		 * Any packet that's not actually a RADIUS packet, or otherwise does
-		 * not validate as an explicit reject, is just ignored and we retry
-		 * for another packet (until we reach the timeout). This is to avoid
-		 * the possibility to denial-of-service the login by flooding the
-		 * server with invalid packets on the port that we're expecting the
-		 * RADIUS response on.
-		 */
-
-		addrsize = sizeof(remoteaddr);
-		packetlength = recvfrom(sock, receive_buffer, RADIUS_BUFFER_SIZE, 0,
-								(struct sockaddr *) &remoteaddr, &addrsize);
-		if (packetlength < 0)
-		{
-			ereport(LOG,
-					(errmsg("could not read RADIUS response: %m")));
-			closesocket(sock);
-			return STATUS_ERROR;
-		}
-
-		if (remoteaddr.sin6_port != pg_hton16(port))
-		{
-			ereport(LOG,
-					(errmsg("RADIUS response from %s was sent from incorrect port: %d",
-							server, pg_ntoh16(remoteaddr.sin6_port))));
-			continue;
-		}
-
-		if (packetlength < RADIUS_HEADER_LENGTH)
-		{
-			ereport(LOG,
-					(errmsg("RADIUS response from %s too short: %d", server, packetlength)));
-			continue;
-		}
-
-		if (packetlength != pg_ntoh16(receivepacket->length))
-		{
-			ereport(LOG,
-					(errmsg("RADIUS response from %s has corrupt length: %d (actual length %d)",
-							server, pg_ntoh16(receivepacket->length), packetlength)));
-			continue;
-		}
-
-		if (packet->id != receivepacket->id)
-		{
-			ereport(LOG,
-					(errmsg("RADIUS response from %s is to a different request: %d (should be %d)",
-							server, receivepacket->id, packet->id)));
-			continue;
-		}
-
-		/*
-		 * Verify the response authenticator, which is calculated as
-		 * MD5(Code+ID+Length+RequestAuthenticator+Attributes+Secret)
-		 */
-		cryptvector = palloc(packetlength + strlen(secret));
-
-		memcpy(cryptvector, receivepacket, 4);	/* code+id+length */
-		memcpy(cryptvector + 4, packet->vector, RADIUS_VECTOR_LENGTH);	/* request
-																		 * authenticator, from
-																		 * original packet */
-		if (packetlength > RADIUS_HEADER_LENGTH)	/* there may be no
-													 * attributes at all */
-			memcpy(cryptvector + RADIUS_HEADER_LENGTH, receive_buffer + RADIUS_HEADER_LENGTH, packetlength - RADIUS_HEADER_LENGTH);
-		memcpy(cryptvector + packetlength, secret, strlen(secret));
-
-		if (!pg_md5_binary(cryptvector,
-						   packetlength + strlen(secret),
-						   encryptedpassword, &errstr))
-		{
-			ereport(LOG,
-					(errmsg("could not perform MD5 encryption of received packet: %s",
-							errstr)));
-			pfree(cryptvector);
-			continue;
-		}
-		pfree(cryptvector);
-
-		if (memcmp(receivepacket->vector, encryptedpassword, RADIUS_VECTOR_LENGTH) != 0)
-		{
-			ereport(LOG,
-					(errmsg("RADIUS response from %s has incorrect MD5 signature",
-							server)));
-			continue;
-		}
-
-		if (receivepacket->code == RADIUS_ACCESS_ACCEPT)
-		{
-			closesocket(sock);
-			return STATUS_OK;
-		}
-		else if (receivepacket->code == RADIUS_ACCESS_REJECT)
-		{
-			closesocket(sock);
-			return STATUS_EOF;
-		}
-		else
-		{
-			ereport(LOG,
-					(errmsg("RADIUS response from %s has invalid code (%d) for user \"%s\"",
-							server, receivepacket->code, user_name)));
-			continue;
-		}
-	}							/* while (true) */
+	 * wait with a timeout, since somebody can be sending invalid packets to
+	 * our port thus causing us to retry in a loop and never time out.
+	 */
+	endtime = TimestampTzPlusMilliseconds(GetCurrentTimestamp(), radius_timeout);
+
+	/*
+	 * CHECK_FOR_INTERRUPTS() might throw an error, so use PG_FINALLY() to
+	 * make sure the socket is closed however we leave.  Don't "return" from
+	 * inside the PG_TRY() block: that would skip PG_FINALLY() and leave
+	 * PG_exception_stack pointing into a stack frame that no longer exists.
+	 * Instead, record the outcome in "result" and break out of the loop.
+	 */
+	PG_TRY();
+	{
+		while (true)
+		{
+			long		timeoutval;
+			const char *errstr = NULL;
+
+			/* Remaining time, rounded up to the nearest millisecond. */
+			timeoutval = TimestampDifferenceMilliseconds(GetCurrentTimestamp(),
+														 endtime);
+			if (timeoutval <= 0)
+			{
+				ereport(LOG,
+						(errmsg("timeout waiting for RADIUS response from %s",
+								server)));
+				result = STATUS_ERROR;
+				break;
+			}
+
+			/*
+			 * No point in supplying a wait_event value as we don't have a
+			 * row in pg_stat_activity yet.
+			 */
+			r = WaitLatchOrSocket(MyLatch,
+								  WL_SOCKET_READABLE | WL_EXIT_ON_PM_DEATH |
+								  WL_LATCH_SET | WL_TIMEOUT,
+								  sock,
+								  timeoutval,
+								  0);
+			if (r & WL_LATCH_SET)
+			{
+				ResetLatch(MyLatch);
+				CHECK_FOR_INTERRUPTS();
+				continue;
+			}
+			if (r & WL_TIMEOUT)
+			{
+				ereport(LOG,
+						(errmsg("timeout waiting for RADIUS response from %s",
+								server)));
+				result = STATUS_ERROR;
+				break;
+			}
+			Assert(r & WL_SOCKET_READABLE);
+
+			/*
+			 * Attempt to read the response packet, and verify the contents.
+			 *
+			 * Any packet that's not actually a RADIUS packet, or otherwise
+			 * does not validate as an explicit reject, is just ignored and we
+			 * retry for another packet (until we reach the timeout). This is
+			 * to avoid the possibility to denial-of-service the login by
+			 * flooding the server with invalid packets on the port that we're
+			 * expecting the RADIUS response on.
+			 */
+			addrsize = sizeof(remoteaddr);
+			packetlength = recvfrom(sock, receive_buffer, RADIUS_BUFFER_SIZE, 0,
+									(struct sockaddr *) &remoteaddr, &addrsize);
+			if (packetlength < 0)
+			{
+				ereport(LOG,
+						(errmsg("could not read RADIUS response: %m")));
+				result = STATUS_ERROR;
+				break;
+			}
+
+			if (remoteaddr.sin6_port != pg_hton16(port))
+			{
+				ereport(LOG,
+						(errmsg("RADIUS response from %s was sent from incorrect port: %d",
+								server, pg_ntoh16(remoteaddr.sin6_port))));
+				continue;
+			}
+
+			if (packetlength < RADIUS_HEADER_LENGTH)
+			{
+				ereport(LOG,
+						(errmsg("RADIUS response from %s too short: %d", server, packetlength)));
+				continue;
+			}
+
+			if (packetlength != pg_ntoh16(receivepacket->length))
+			{
+				ereport(LOG,
+						(errmsg("RADIUS response from %s has corrupt length: %d (actual length %d)",
+								server, pg_ntoh16(receivepacket->length), packetlength)));
+				continue;
+			}
+
+			if (packet->id != receivepacket->id)
+			{
+				ereport(LOG,
+						(errmsg("RADIUS response from %s is to a different request: %d (should be %d)",
+								server, receivepacket->id, packet->id)));
+				continue;
+			}
+
+			/*
+			 * Verify the response authenticator, which is calculated as
+			 * MD5(Code+ID+Length+RequestAuthenticator+Attributes+Secret)
+			 */
+			cryptvector = palloc(packetlength + strlen(secret));
+
+			/* code+id+length */
+			memcpy(cryptvector, receivepacket, 4);
+			/* request authenticator, from the original packet */
+			memcpy(cryptvector + 4, packet->vector, RADIUS_VECTOR_LENGTH);
+			/* attributes, if any (there may be none at all) */
+			if (packetlength > RADIUS_HEADER_LENGTH)
+				memcpy(cryptvector + RADIUS_HEADER_LENGTH, receive_buffer + RADIUS_HEADER_LENGTH, packetlength - RADIUS_HEADER_LENGTH);
+			memcpy(cryptvector + packetlength, secret, strlen(secret));
+
+			if (!pg_md5_binary(cryptvector,
+							   packetlength + strlen(secret),
+							   encryptedpassword, &errstr))
+			{
+				ereport(LOG,
+						(errmsg("could not perform MD5 encryption of received packet: %s",
+								errstr)));
+				pfree(cryptvector);
+				continue;
+			}
+			pfree(cryptvector);
+
+			if (memcmp(receivepacket->vector, encryptedpassword, RADIUS_VECTOR_LENGTH) != 0)
+			{
+				ereport(LOG,
+						(errmsg("RADIUS response from %s has incorrect MD5 signature",
+								server)));
+				continue;
+			}
+
+			if (receivepacket->code == RADIUS_ACCESS_ACCEPT)
+			{
+				result = STATUS_OK;
+				break;
+			}
+			else if (receivepacket->code == RADIUS_ACCESS_REJECT)
+			{
+				result = STATUS_EOF;
+				break;
+			}
+			else
+			{
+				ereport(LOG,
+						(errmsg("RADIUS response from %s has invalid code (%d) for user \"%s\"",
+								server, receivepacket->code, user_name)));
+				continue;
+			}
+		}						/* while (true) */
+	}
+	PG_FINALLY();
+	{
+		closesocket(sock);
+	}
+	PG_END_TRY();
+
+	return result;
 }
diff --git a/src/backend/utils/misc/guc_tables.c b/src/backend/utils/misc/guc_tables.c
index 68328b1402..17bbe9d940 100644
--- a/src/backend/utils/misc/guc_tables.c
+++ b/src/backend/utils/misc/guc_tables.c
@@ -3387,6 +3387,17 @@ struct config_int ConfigureNamesInt[] =
 		NULL, assign_tcp_user_timeout, show_tcp_user_timeout
 	},
 
+	{
+		{"radius_timeout", PGC_SIGHUP, CONN_AUTH_AUTH,
+			gettext_noop("RADIUS authentication timeout."),
+			NULL,
+			GUC_UNIT_MS
+		},
+		&radius_timeout,
+		3000, 1, INT_MAX,	/* in ms: default 3s, minimum 1ms */
+		NULL, NULL, NULL
+	},
+
 	{
 		{"huge_page_size", PGC_POSTMASTER, RESOURCES_MEM,
 			gettext_noop("The size of huge page that should be requested."),
diff --git a/src/backend/utils/misc/postgresql.conf.sample b/src/backend/utils/misc/postgresql.conf.sample
index 5afdeb04de..3f95c52f93 100644
--- a/src/backend/utils/misc/postgresql.conf.sample
+++ b/src/backend/utils/misc/postgresql.conf.sample
@@ -95,6 +95,7 @@
 #authentication_timeout = 1min		# 1s-600s
 #password_encryption = scram-sha-256	# scram-sha-256 or md5
 #db_user_namespace = off
+#radius_timeout = 3s			# time to wait for a RADIUS server response
 
 # GSSAPI using Kerberos
 #krb_server_keyfile = 'FILE:${sysconfdir}/krb5.keytab'
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 137bee7c45..86d246b1f8 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -19,6 +19,7 @@
 extern PGDLLIMPORT char *pg_krb_server_keyfile;
 extern PGDLLIMPORT bool pg_krb_caseins_users;
 extern PGDLLIMPORT char *pg_krb_realm;
+extern PGDLLIMPORT int radius_timeout;	/* GUC, in milliseconds */
 
 extern void ClientAuthentication(Port *port);
 extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
diff --git a/src/test/radius/t/001_auth.pl b/src/test/radius/t/001_auth.pl
index 44db62a3d7..ebfd0d6301 100644
--- a/src/test/radius/t/001_auth.pl
+++ b/src/test/radius/t/001_auth.pl
@@ -60,6 +60,7 @@ else
 }
 
 my $radius_port     = PostgreSQL::Test::Cluster::get_free_port();
+my $not_radius_port = PostgreSQL::Test::Cluster::get_free_port();
 
 note "setting up radiusd";
 
@@ -131,6 +132,7 @@ note "setting up PostgreSQL instance";
 my $node = PostgreSQL::Test::Cluster->new('node');
 $node->init;
 $node->append_conf('postgresql.conf', "log_connections = on\n");
+$node->append_conf('postgresql.conf', "radius_timeout = '${PostgreSQL::Test::Utils::timeout_default}s'\n");  # generous, for slow/overloaded machines
 $node->start;
 
 $node->safe_psql('postgres', 'CREATE USER test1;');
@@ -184,4 +186,17 @@ test_access(
 		qr/connection authenticated: identity="test2" method=radius/
 	],);
 
+# Tiny timeout, aimed at a port with no RADIUS server: UDP gets no reply, so authentication must time out
+$node->append_conf('postgresql.conf', "radius_timeout = '2ms'\n");
+unlink($node->data_dir . '/pg_hba.conf');
+$node->append_conf('pg_hba.conf',
+	qq{local all all radius radiusservers="127.0.0.1" radiussecrets="secret" radiusports="$not_radius_port"}
+);
+$node->restart;
+
+test_access(
+	$node, 'test2', 2,
+	'authentication fails with timeout',
+	log_like => [qr/timeout waiting for RADIUS response/]);
+
 done_testing();
-- 
2.38.1

