From fbd17c7b77251ed66eed00d80efc58abb5eeb84a Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Wed, 30 Jun 2021 09:27:40 -0700
Subject: [PATCH v3 2/2] auth: pull backend SASL exchange into its own file

This code motion is pulled into a separate commit to ease review.

Move SASL_exchange to its own file and rename it to CheckSASLAuth, which
callers in auth.c (currently CheckPWChallengeAuth()) now invoke directly
with the desired mechanism. This replaces the CheckSCRAMAuth() wrapper.
---
 src/backend/libpq/Makefile    |   1 +
 src/backend/libpq/auth-sasl.c | 187 ++++++++++++++++++++++++++++++++++
 src/backend/libpq/auth.c      | 178 +-------------------------------
 src/include/libpq/auth.h      |   2 +
 src/include/libpq/sasl.h      |  13 +++
 5 files changed, 207 insertions(+), 174 deletions(-)
 create mode 100644 src/backend/libpq/auth-sasl.c

diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 8d1d16b0fc..6d385fd6a4 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,6 +15,7 @@ include $(top_builddir)/src/Makefile.global
 # be-fsstubs is here for historical reasons, probably belongs elsewhere
 
 OBJS = \
+	auth-sasl.o \
 	auth-scram.o \
 	auth.o \
 	be-fsstubs.o \
diff --git a/src/backend/libpq/auth-sasl.c b/src/backend/libpq/auth-sasl.c
new file mode 100644
index 0000000000..b7cdb2ecf6
--- /dev/null
+++ b/src/backend/libpq/auth-sasl.c
@@ -0,0 +1,187 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth-sasl.c
+ *	  Routines to handle network authentication via SASL
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ *	  src/backend/libpq/auth-sasl.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include "libpq/auth.h"
+#include "libpq/libpq.h"
+#include "libpq/pqformat.h"
+#include "libpq/sasl.h"
+
+/*
+ * Perform a SASL exchange with a libpq client, using a specific mechanism
+ * implementation.
+ *
+ * shadow_pass is an optional pointer to the shadow entry for the client's
+ * presented user name. For mechanisms that use shadowed passwords, a NULL
+ * pointer here means that an entry could not be found for the user (or the user
+ * does not exist), and the mechanism should fail the authentication exchange.
+ *
+ * Mechanisms must take care not to reveal to the client that a user entry does
+ * not exist; ideally, the external failure mode is identical to that of an
+ * incorrect password. Mechanisms may instead use the logdetail output parameter
+ * to internally differentiate between failure cases and assist debugging by the
+ * server admin.
+ *
+ * A mechanism is not required to utilize a shadow entry, or even a password
+ * system at all; for these cases, shadow_pass may be ignored and the caller
+ * should just pass NULL.
+ *
+ * Returns STATUS_OK on success, STATUS_EOF if the client disconnected instead
+ * of sending its next message, or STATUS_ERROR on failure.
+ */
+int
+CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
+			  char **logdetail)
+{
+	StringInfoData sasl_mechs;
+	int			mtype;
+	StringInfoData buf;
+	void	   *opaq = NULL;
+	char	   *output = NULL;
+	int			outputlen = 0;
+	const char *input;
+	int			inputlen;
+	int			result;
+	bool		initial;
+
+	/*
+	 * Send the SASL authentication request to user.  It includes the list of
+	 * authentication mechanisms that are supported.
+	 */
+	initStringInfo(&sasl_mechs);
+
+	mech->get_mechanisms(port, &sasl_mechs);
+	/* Put another '\0' to mark that list is finished. */
+	appendStringInfoChar(&sasl_mechs, '\0');
+
+	sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len);
+	pfree(sasl_mechs.data);
+
+	/*
+	 * Loop through SASL message exchange.  This exchange can consist of
+	 * multiple messages sent in both directions.  First message is always
+	 * from the client.  All messages from client to server are password
+	 * packets (type 'p').
+	 */
+	initial = true;
+	do
+	{
+		pq_startmsgread();
+		mtype = pq_getbyte();
+		if (mtype != 'p')
+		{
+			/* Only log error if client didn't disconnect. */
+			if (mtype == EOF)
+				return STATUS_EOF;
+			ereport(ERROR,
+					(errcode(ERRCODE_PROTOCOL_VIOLATION),
+					 errmsg("expected SASL response, got message type %d",
+							mtype)));
+		}
+
+		/* Get the actual SASL message */
+		initStringInfo(&buf);
+		if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
+		{
+			/* EOF - pq_getmessage already logged error */
+			pfree(buf.data);
+			return STATUS_ERROR;
+		}
+
+		elog(DEBUG4, "processing received SASL response of length %d", buf.len);
+
+		/*
+		 * The first SASLInitialResponse message is different from the others.
+		 * It indicates which SASL mechanism the client selected, and contains
+		 * an optional Initial Client Response payload.  The subsequent
+		 * SASLResponse messages contain just the SASL payload.
+		 */
+		if (initial)
+		{
+			const char *selected_mech;
+
+			selected_mech = pq_getmsgrawstring(&buf);
+
+			/*
+			 * Initialize the status tracker for message exchanges.
+			 *
+			 * If the user doesn't exist, or doesn't have a valid password, or
+			 * it's expired, we still go through the motions of SASL
+			 * authentication, but tell the authentication method that the
+			 * authentication is "doomed". That is, it's going to fail, no
+			 * matter what.
+			 *
+			 * This is because we don't want to reveal to an attacker what
+			 * usernames are valid, nor which users have a valid password.
+			 */
+			opaq = mech->init(port, selected_mech, shadow_pass);
+
+			inputlen = pq_getmsgint(&buf, 4);
+			if (inputlen == -1)
+				input = NULL;
+			else
+				input = pq_getmsgbytes(&buf, inputlen);
+
+			initial = false;
+		}
+		else
+		{
+			inputlen = buf.len;
+			input = pq_getmsgbytes(&buf, buf.len);
+		}
+		pq_getmsgend(&buf);
+
+		/*
+		 * The StringInfo guarantees that there's a \0 byte after the
+		 * response.
+		 */
+		Assert(input == NULL || input[inputlen] == '\0');
+
+		/*
+		 * Hand the incoming message to the mechanism implementation.
+		 */
+		result = mech->exchange(opaq, input, inputlen,
+								&output, &outputlen,
+								logdetail);
+
+		/* input buffer no longer used */
+		pfree(buf.data);
+
+		if (output)
+		{
+			/* SASL forbids a failure outcome that carries additional data. */
+			if (result == PG_SASL_EXCHANGE_FAILURE)
+				elog(ERROR, "output message found after SASL exchange failure");
+
+			/* Negotiation generated data to be sent to the client. */
+			elog(DEBUG4, "sending SASL challenge of length %d", outputlen);
+
+			if (result == PG_SASL_EXCHANGE_SUCCESS)
+				sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
+			else
+				sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
+
+			pfree(output);
+			output = NULL;		/* never resend or re-free a stale buffer */
+		}
+	} while (result == PG_SASL_EXCHANGE_CONTINUE);
+
+	/* Anything other than a successful outcome fails authentication. */
+	if (result != PG_SASL_EXCHANGE_SUCCESS)
+		return STATUS_ERROR;
+
+	return STATUS_OK;
+}
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 82f043a343..ac6fe4a747 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -45,19 +45,10 @@
  * Global authentication functions
  *----------------------------------------------------------------
  */
-static void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
-							int extralen);
 static void auth_failed(Port *port, int status, char *logdetail);
 static char *recv_password_packet(Port *port);
 static void set_authn_id(Port *port, const char *id);
 
-/*----------------------------------------------------------------
- * SASL common authentication
- *----------------------------------------------------------------
- */
-static int	SASL_exchange(const pg_be_sasl_mech *mech, Port *port,
-						  char *shadow_pass, char **logdetail);
-
 
 /*----------------------------------------------------------------
  * Password-based authentication methods (password, md5, and scram-sha-256)
@@ -67,7 +58,6 @@ static int	CheckPasswordAuth(Port *port, char **logdetail);
 static int	CheckPWChallengeAuth(Port *port, char **logdetail);
 
 static int	CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail);
-static int	CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail);
 
 
 /*----------------------------------------------------------------
@@ -231,14 +221,6 @@ static int	PerformRadiusTransaction(const char *server, const char *secret, cons
  */
 #define PG_MAX_AUTH_TOKEN_LENGTH	65535
 
-/*
- * Maximum accepted size of SASL messages.
- *
- * The messages that the server or libpq generate are much smaller than this,
- * but have some headroom.
- */
-#define PG_MAX_SASL_MESSAGE_LENGTH	1024
-
 /*----------------------------------------------------------------
  * Global authentication functions
  *----------------------------------------------------------------
@@ -675,7 +657,7 @@ ClientAuthentication(Port *port)
 /*
  * Send an authentication request packet to the frontend.
  */
-static void
+void
 sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extralen)
 {
 	StringInfoData buf;
@@ -855,12 +837,13 @@ CheckPWChallengeAuth(Port *port, char **logdetail)
 	 * SCRAM secret, we must do SCRAM authentication.
 	 *
 	 * If MD5 authentication is not allowed, always use SCRAM.  If the user
-	 * had an MD5 password, CheckSCRAMAuth() will fail.
+	 * had an MD5 password, the SCRAM mechanism will fail.
 	 */
 	if (port->hba->auth_method == uaMD5 && pwtype == PASSWORD_TYPE_MD5)
 		auth_result = CheckMD5Auth(port, shadow_pass, logdetail);
 	else
-		auth_result = CheckSCRAMAuth(port, shadow_pass, logdetail);
+		auth_result = CheckSASLAuth(&pg_be_scram_mech, port, shadow_pass,
+									logdetail);
 
 	if (shadow_pass)
 		pfree(shadow_pass);
@@ -918,159 +901,6 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
 	return result;
 }
 
-static int
-SASL_exchange(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
-			  char **logdetail)
-{
-	StringInfoData sasl_mechs;
-	int			mtype;
-	StringInfoData buf;
-	void	   *opaq = NULL;
-	char	   *output = NULL;
-	int			outputlen = 0;
-	const char *input;
-	int			inputlen;
-	int			result;
-	bool		initial;
-
-	/*
-	 * Send the SASL authentication request to user.  It includes the list of
-	 * authentication mechanisms that are supported.
-	 */
-	initStringInfo(&sasl_mechs);
-
-	mech->get_mechanisms(port, &sasl_mechs);
-	/* Put another '\0' to mark that list is finished. */
-	appendStringInfoChar(&sasl_mechs, '\0');
-
-	sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len);
-	pfree(sasl_mechs.data);
-
-	/*
-	 * Loop through SASL message exchange.  This exchange can consist of
-	 * multiple messages sent in both directions.  First message is always
-	 * from the client.  All messages from client to server are password
-	 * packets (type 'p').
-	 */
-	initial = true;
-	do
-	{
-		pq_startmsgread();
-		mtype = pq_getbyte();
-		if (mtype != 'p')
-		{
-			/* Only log error if client didn't disconnect. */
-			if (mtype != EOF)
-			{
-				ereport(ERROR,
-						(errcode(ERRCODE_PROTOCOL_VIOLATION),
-						 errmsg("expected SASL response, got message type %d",
-								mtype)));
-			}
-			else
-				return STATUS_EOF;
-		}
-
-		/* Get the actual SASL message */
-		initStringInfo(&buf);
-		if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
-		{
-			/* EOF - pq_getmessage already logged error */
-			pfree(buf.data);
-			return STATUS_ERROR;
-		}
-
-		elog(DEBUG4, "processing received SASL response of length %d", buf.len);
-
-		/*
-		 * The first SASLInitialResponse message is different from the others.
-		 * It indicates which SASL mechanism the client selected, and contains
-		 * an optional Initial Client Response payload.  The subsequent
-		 * SASLResponse messages contain just the SASL payload.
-		 */
-		if (initial)
-		{
-			const char *selected_mech;
-
-			selected_mech = pq_getmsgrawstring(&buf);
-
-			/*
-			 * Initialize the status tracker for message exchanges.
-			 *
-			 * If the user doesn't exist, or doesn't have a valid password, or
-			 * it's expired, we still go through the motions of SASL
-			 * authentication, but tell the authentication method that the
-			 * authentication is "doomed". That is, it's going to fail, no
-			 * matter what.
-			 *
-			 * This is because we don't want to reveal to an attacker what
-			 * usernames are valid, nor which users have a valid password.
-			 */
-			opaq = mech->init(port, selected_mech, shadow_pass);
-
-			inputlen = pq_getmsgint(&buf, 4);
-			if (inputlen == -1)
-				input = NULL;
-			else
-				input = pq_getmsgbytes(&buf, inputlen);
-
-			initial = false;
-		}
-		else
-		{
-			inputlen = buf.len;
-			input = pq_getmsgbytes(&buf, buf.len);
-		}
-		pq_getmsgend(&buf);
-
-		/*
-		 * The StringInfo guarantees that there's a \0 byte after the
-		 * response.
-		 */
-		Assert(input == NULL || input[inputlen] == '\0');
-
-		/*
-		 * Hand the incoming message to the mechanism implementation.
-		 */
-		result = mech->exchange(opaq, input, inputlen,
-								&output, &outputlen,
-								logdetail);
-
-		/* input buffer no longer used */
-		pfree(buf.data);
-
-		if (output)
-		{
-			/*
-			 * Negotiation generated data to be sent to the client.
-			 */
-			elog(DEBUG4, "sending SASL challenge of length %u", outputlen);
-
-			/* TODO: PG_SASL_EXCHANGE_FAILURE with output is forbidden in SASL */
-			if (result == PG_SASL_EXCHANGE_SUCCESS)
-				sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
-			else
-				sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
-
-			pfree(output);
-		}
-	} while (result == PG_SASL_EXCHANGE_CONTINUE);
-
-	/* Oops, Something bad happened */
-	if (result != PG_SASL_EXCHANGE_SUCCESS)
-	{
-		return STATUS_ERROR;
-	}
-
-	return STATUS_OK;
-}
-
-static int
-CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
-{
-	return SASL_exchange(&pg_be_scram_mech, port, shadow_pass, logdetail);
-}
-
 
 /*----------------------------------------------------------------
  * GSSAPI authentication system
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 3610fae3ff..3d6734f253 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -21,6 +21,8 @@ extern bool pg_krb_caseins_users;
 extern char *pg_krb_realm;
 
 extern void ClientAuthentication(Port *port);
+extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
+							int extralen);
 
 /* Hook for plugins to get control in ClientAuthentication() */
 typedef void (*ClientAuthentication_hook_type) (Port *, int);
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
index c732f35564..dad04d8ecd 100644
--- a/src/include/libpq/sasl.h
+++ b/src/include/libpq/sasl.h
@@ -16,6 +16,7 @@
 #ifndef PG_SASL_H
 #define PG_SASL_H
 
+#include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
 
 /* Status codes for message exchange */
@@ -23,6 +24,14 @@
 #define PG_SASL_EXCHANGE_SUCCESS		1
 #define PG_SASL_EXCHANGE_FAILURE		2
 
+/*
+ * Maximum accepted size of SASL messages.
+ *
+ * The messages that the server or libpq generate are much smaller than this,
+ * but have some headroom.
+ */
+#define PG_MAX_SASL_MESSAGE_LENGTH	1024
+
 /*
  * Backend mechanism API
  *
@@ -124,4 +133,8 @@ typedef struct
 	pg_be_sasl_exchange_func	exchange;
 } pg_be_sasl_mech;
 
+/* Common implementation for auth.c */
+extern int CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port,
+						 char *shadow_pass, char **logdetail);
+
 #endif /* PG_SASL_H */
-- 
2.25.1

