From 22cd26de5266880d2cc5419ce80428ec5c25bf5f Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 13 Apr 2021 10:25:48 -0700
Subject: [PATCH v3 1/2] auth: generalize SASL mechanisms

Split the SASL logic out from the SCRAM implementation, so that it can
be reused by other mechanisms.  New implementations will implement both
a pg_fe_sasl_mech and a pg_be_sasl_mech.
---
 src/backend/libpq/auth-scram.c       |  48 ++++++----
 src/backend/libpq/auth.c             |  40 +++++---
 src/include/libpq/sasl.h             | 127 ++++++++++++++++++++++++++
 src/include/libpq/scram.h            |  13 +--
 src/interfaces/libpq/fe-auth-sasl.h  | 131 +++++++++++++++++++++++++++
 src/interfaces/libpq/fe-auth-scram.c |  40 +++++---
 src/interfaces/libpq/fe-auth.c       |  22 ++++-
 src/interfaces/libpq/fe-auth.h       |  11 +--
 src/interfaces/libpq/fe-connect.c    |   6 +-
 src/interfaces/libpq/libpq-int.h     |   2 +
 10 files changed, 367 insertions(+), 73 deletions(-)
 create mode 100644 src/include/libpq/sasl.h
 create mode 100644 src/interfaces/libpq/fe-auth-sasl.h

diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index f9e1026a12..2965ea2ddb 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -101,11 +101,25 @@
 #include "common/sha2.h"
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
+#include "libpq/sasl.h"
 #include "libpq/scram.h"
 #include "miscadmin.h"
 #include "utils/builtins.h"
 #include "utils/timestamp.h"
 
+static void  scram_get_mechanisms(Port *port, StringInfo buf);
+static void *scram_init(Port *port, const char *selected_mech,
+						const char *shadow_pass);
+static int   scram_exchange(void *opaq, const char *input, int inputlen,
+							char **output, int *outputlen, char **logdetail);
+
+/* Mechanism declaration */
+const pg_be_sasl_mech pg_be_scram_mech = {
+	scram_get_mechanisms,
+	scram_init,
+	scram_exchange,
+};
+
 /*
  * Status data for a SCRAM authentication exchange.  This should be kept
  * internal to this file.
@@ -170,16 +184,14 @@ static char *sanitize_str(const char *s);
 static char *scram_mock_salt(const char *username);
 
 /*
- * pg_be_scram_get_mechanisms
- *
  * Get a list of SASL mechanisms that this module supports.
  *
  * For the convenience of building the FE/BE packet that lists the
  * mechanisms, the names are appended to the given StringInfo buffer,
  * separated by '\0' bytes.
  */
-void
-pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
+static void
+scram_get_mechanisms(Port *port, StringInfo buf)
 {
 	/*
 	 * Advertise the mechanisms in decreasing order of importance.  So the
@@ -199,8 +211,6 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
 }
 
 /*
- * pg_be_scram_init
- *
  * Initialize a new SCRAM authentication exchange status tracker.  This
  * needs to be called before doing any exchange.  It will be filled later
  * after the beginning of the exchange with authentication information.
@@ -215,10 +225,8 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
  * an authentication exchange, but it will fail, as if an incorrect password
  * was given.
  */
-void *
-pg_be_scram_init(Port *port,
-				 const char *selected_mech,
-				 const char *shadow_pass)
+static void *
+scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 {
 	scram_state *state;
 	bool		got_secret;
@@ -325,9 +333,9 @@ pg_be_scram_init(Port *port,
  * string at *logdetail that will be sent to the postmaster log (but not
  * the client).
  */
-int
-pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
-					 char **output, int *outputlen, char **logdetail)
+static int
+scram_exchange(void *opaq, const char *input, int inputlen,
+			   char **output, int *outputlen, char **logdetail)
 {
 	scram_state *state = (scram_state *) opaq;
 	int			result;
@@ -346,7 +354,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 
 		*output = pstrdup("");
 		*outputlen = 0;
-		return SASL_EXCHANGE_CONTINUE;
+		return PG_SASL_EXCHANGE_CONTINUE;
 	}
 
 	/*
@@ -379,7 +387,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			*output = build_server_first_message(state);
 
 			state->state = SCRAM_AUTH_SALT_SENT;
-			result = SASL_EXCHANGE_CONTINUE;
+			result = PG_SASL_EXCHANGE_CONTINUE;
 			break;
 
 		case SCRAM_AUTH_SALT_SENT:
@@ -408,7 +416,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			 * erroring out in an application-specific way.  We choose to do
 			 * the latter, so that the error message for invalid password is
 			 * the same for all authentication methods.  The caller will call
-			 * ereport(), when we return SASL_EXCHANGE_FAILURE with no output.
+			 * ereport(), when we return PG_SASL_EXCHANGE_FAILURE with no output.
 			 *
 			 * NB: the order of these checks is intentional.  We calculate the
 			 * client proof even in a mock authentication, even though it's
@@ -417,7 +425,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			 */
 			if (!verify_client_proof(state) || state->doomed)
 			{
-				result = SASL_EXCHANGE_FAILURE;
+				result = PG_SASL_EXCHANGE_FAILURE;
 				break;
 			}
 
@@ -425,16 +433,16 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			*output = build_server_final_message(state);
 
 			/* Success! */
-			result = SASL_EXCHANGE_SUCCESS;
+			result = PG_SASL_EXCHANGE_SUCCESS;
 			state->state = SCRAM_AUTH_FINISHED;
 			break;
 
 		default:
 			elog(ERROR, "invalid SCRAM exchange state");
-			result = SASL_EXCHANGE_FAILURE;
+			result = PG_SASL_EXCHANGE_FAILURE;
 	}
 
-	if (result == SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
+	if (result == PG_SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
 		*logdetail = state->logdetail;
 
 	if (*output)
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 967b5ef73c..82f043a343 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -26,11 +26,11 @@
 #include "commands/user.h"
 #include "common/ip.h"
 #include "common/md5.h"
-#include "common/scram-common.h"
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
 #include "libpq/libpq.h"
 #include "libpq/pqformat.h"
+#include "libpq/sasl.h"
 #include "libpq/scram.h"
 #include "miscadmin.h"
 #include "port/pg_bswap.h"
@@ -51,6 +51,13 @@ static void auth_failed(Port *port, int status, char *logdetail);
 static char *recv_password_packet(Port *port);
 static void set_authn_id(Port *port, const char *id);
 
+/*----------------------------------------------------------------
+ * SASL common authentication
+ *----------------------------------------------------------------
+ */
+static int	SASL_exchange(const pg_be_sasl_mech *mech, Port *port,
+						  char *shadow_pass, char **logdetail);
+
 
 /*----------------------------------------------------------------
  * Password-based authentication methods (password, md5, and scram-sha-256)
@@ -912,12 +919,13 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
 }
 
 static int
-CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
+SASL_exchange(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
+			  char **logdetail)
 {
 	StringInfoData sasl_mechs;
 	int			mtype;
 	StringInfoData buf;
-	void	   *scram_opaq = NULL;
+	void	   *opaq = NULL;
 	char	   *output = NULL;
 	int			outputlen = 0;
 	const char *input;
@@ -931,7 +939,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 	 */
 	initStringInfo(&sasl_mechs);
 
-	pg_be_scram_get_mechanisms(port, &sasl_mechs);
+	mech->get_mechanisms(port, &sasl_mechs);
 	/* Put another '\0' to mark that list is finished. */
 	appendStringInfoChar(&sasl_mechs, '\0');
 
@@ -998,7 +1006,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 			 * This is because we don't want to reveal to an attacker what
 			 * usernames are valid, nor which users have a valid password.
 			 */
-			scram_opaq = pg_be_scram_init(port, selected_mech, shadow_pass);
+			opaq = mech->init(port, selected_mech, shadow_pass);
 
 			inputlen = pq_getmsgint(&buf, 4);
 			if (inputlen == -1)
@@ -1022,12 +1030,11 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 		Assert(input == NULL || input[inputlen] == '\0');
 
 		/*
-		 * we pass 'logdetail' as NULL when doing a mock authentication,
-		 * because we should already have a better error message in that case
+		 * Hand the incoming message to the mechanism implementation.
 		 */
-		result = pg_be_scram_exchange(scram_opaq, input, inputlen,
-									  &output, &outputlen,
-									  logdetail);
+		result = mech->exchange(opaq, input, inputlen,
+								&output, &outputlen,
+								logdetail);
 
 		/* input buffer no longer used */
 		pfree(buf.data);
@@ -1039,17 +1046,18 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 			 */
 			elog(DEBUG4, "sending SASL challenge of length %u", outputlen);
 
-			if (result == SASL_EXCHANGE_SUCCESS)
+			/* TODO: PG_SASL_EXCHANGE_FAILURE with output is forbidden in SASL */
+			if (result == PG_SASL_EXCHANGE_SUCCESS)
 				sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
 			else
 				sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
 
 			pfree(output);
 		}
-	} while (result == SASL_EXCHANGE_CONTINUE);
+	} while (result == PG_SASL_EXCHANGE_CONTINUE);
 
 	/* Oops, Something bad happened */
-	if (result != SASL_EXCHANGE_SUCCESS)
+	if (result != PG_SASL_EXCHANGE_SUCCESS)
 	{
 		return STATUS_ERROR;
 	}
@@ -1057,6 +1065,12 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
 	return STATUS_OK;
 }
 
+static int
+CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
+{
+	return SASL_exchange(&pg_be_scram_mech, port, shadow_pass, logdetail);
+}
+
 
 /*----------------------------------------------------------------
  * GSSAPI authentication system
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
new file mode 100644
index 0000000000..c732f35564
--- /dev/null
+++ b/src/include/libpq/sasl.h
@@ -0,0 +1,127 @@
+/*-------------------------------------------------------------------------
+ *
+ * sasl.h
+ *     Defines the SASL mechanism interface for the libpq backend. Each SASL
+ *     mechanism defines a frontend and a backend callback structure.
+ *
+ *     See src/interfaces/libpq/fe-auth-sasl.h for the frontend counterpart.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/libpq/sasl.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef PG_SASL_H
+#define PG_SASL_H
+
+#include "libpq/libpq-be.h"
+
+/* Status codes for message exchange */
+#define PG_SASL_EXCHANGE_CONTINUE		0
+#define PG_SASL_EXCHANGE_SUCCESS		1
+#define PG_SASL_EXCHANGE_FAILURE		2
+
+/*
+ * Backend mechanism API
+ *
+ * To implement a backend mechanism, declare a pg_be_sasl_mech struct with
+ * appropriate callback implementations. Then pass the mechanism to
+ * SASL_exchange() during ClientAuthentication(), once the server has decided
+ * which authentication method to use.
+ */
+
+/*
+ * mech.get_mechanisms()
+ *
+ * Retrieves the list of SASL mechanism names supported by this implementation.
+ * The names are appended into the provided buffer.
+ *
+ * Input parameters:
+ *
+ *   port: the client Port
+ *
+ * Output parameters:
+ *
+ *   buf: a StringInfo buffer that the callback should populate with supported
+ *        mechanism names. Null-terminated names should be printed to the buffer
+ *        using appendStringInfo*().
+ */
+typedef void  (*pg_be_sasl_mechanism_func)(Port *port, StringInfo buf);
+
+/*
+ * mech.init()
+ *
+ * Initializes mechanism-specific state for a connection. This callback must
+ * return a pointer to its allocated state, which will be passed as-is as the
+ * first argument to the other callbacks.
+ *
+ * Input parameters:
+ *
+ *   port:        the client Port
+ *
+ *   mech:        the actual mechanism name in use by the client
+ *
+ *   shadow_pass: the shadow entry for the user being authenticated, or NULL if
+ *                one does not exist. Mechanisms that do not use shadow entries
+ *                may ignore this parameter. If a mechanism uses shadow entries
+ *                but shadow_pass is NULL, the implementation must continue the
+ *                exchange as if the user existed and the password did not
+ *                match, to avoid disclosing valid user names.
+ */
+typedef void *(*pg_be_sasl_init_func)(Port *port, const char *mech,
+									  const char *shadow_pass);
+
+/*
+ * mech.exchange()
+ *
+ * Produces a server challenge to be sent to the client. The callback must
+ * return one of the PG_SASL_EXCHANGE_* values, depending on whether the
+ * exchange must continue, has finished successfully, or has failed.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ *
+ *   input:    the response data sent by the client, or NULL if the mechanism is
+ *             client-first but the client did not send an initial response.
+ *             (This can only happen during the first message from the client.)
+ *             This is guaranteed to be null-terminated for safety, but SASL
+ *             allows embedded nulls in responses, so mechanisms must be careful
+ *             to check inputlen.
+ *
+ *   inputlen: the length of the response data sent by the client, or -1 if the
+ *             client did not send an initial response
+ *
+ * Output parameters, to be set by the callback function:
+ *
+ *   output:    a palloc'd buffer containing either the server's next challenge
+ *              (if PG_SASL_EXCHANGE_CONTINUE is returned) or the server's
+ *              outcome data (if PG_SASL_EXCHANGE_SUCCESS is returned and the
+ *              mechanism requires data to be sent during a successful outcome).
+ *              The callback should set this to NULL if the exchange is over and
+ *              no output should be sent, which should correspond to either
+ *              PG_SASL_EXCHANGE_FAILURE or a PG_SASL_EXCHANGE_SUCCESS with no
+ *              outcome data.
+ *
+ *   outputlen: the length of the challenge data. Ignored if *output is NULL.
+ *
+ *   logdetail: set to an optional DETAIL message to be printed to the server
+ *              log, to disambiguate failure modes. (The client will only ever
+ *              see the same generic authentication failure message.) Ignored if
+ *              the exchange is completed with PG_SASL_EXCHANGE_SUCCESS.
+ */
+typedef int   (*pg_be_sasl_exchange_func)(void *state,
+										  const char *input, int inputlen,
+										  char **output, int *outputlen,
+										  char **logdetail);
+
+typedef struct
+{
+	pg_be_sasl_mechanism_func	get_mechanisms;
+	pg_be_sasl_init_func		init;
+	pg_be_sasl_exchange_func	exchange;
+} pg_be_sasl_mech;
+
+#endif /* PG_SASL_H */
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index 2c879150da..9e4540bde3 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -15,17 +15,10 @@
 
 #include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
+#include "libpq/sasl.h"
 
-/* Status codes for message exchange */
-#define SASL_EXCHANGE_CONTINUE		0
-#define SASL_EXCHANGE_SUCCESS		1
-#define SASL_EXCHANGE_FAILURE		2
-
-/* Routines dedicated to authentication */
-extern void pg_be_scram_get_mechanisms(Port *port, StringInfo buf);
-extern void *pg_be_scram_init(Port *port, const char *selected_mech, const char *shadow_pass);
-extern int	pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
-								 char **output, int *outputlen, char **logdetail);
+/* Implementation */
+extern const pg_be_sasl_mech pg_be_scram_mech;
 
 /* Routines to handle and check SCRAM-SHA-256 secret */
 extern char *pg_be_scram_build_secret(const char *password);
diff --git a/src/interfaces/libpq/fe-auth-sasl.h b/src/interfaces/libpq/fe-auth-sasl.h
new file mode 100644
index 0000000000..1409e51287
--- /dev/null
+++ b/src/interfaces/libpq/fe-auth-sasl.h
@@ -0,0 +1,131 @@
+/*-------------------------------------------------------------------------
+ *
+ * fe-auth-sasl.h
+ *    Defines the SASL mechanism interface for the libpq frontend. Each SASL
+ *    mechanism defines a frontend and a backend callback structure. This is not
+ *    part of the public API for applications.
+ *
+ *    See src/include/libpq/sasl.h for the backend counterpart.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/interfaces/libpq/fe-auth-sasl.h
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#ifndef FE_AUTH_SASL_H
+#define FE_AUTH_SASL_H
+
+#include "libpq-fe.h"
+
+/*
+ * Frontend mechanism API
+ *
+ * To implement a frontend mechanism, declare a pg_fe_sasl_mech struct with
+ * appropriate callback implementations, then hook it into conn->sasl during
+ * pg_SASL_init()'s mechanism negotiation.
+ */
+
+/*
+ * mech.init()
+ *
+ * Initializes mechanism-specific state for a connection. This callback must
+ * return a pointer to its allocated state, which will be passed as-is as the
+ * first argument to the other callbacks. mech.free() will be called to release
+ * any state resources.
+ *
+ * If state allocation fails, the implementation should return NULL to fail the
+ * authentication exchange.
+ *
+ * Input parameters:
+ *
+ *   conn:     the connection to the server
+ *
+ *   password: the user's supplied password for the current connection
+ *
+ *   mech:     the mechanism name in use, for implementations that may advertise
+ *             more than one name (such as *-PLUS variants)
+ */
+typedef void *(*pg_fe_sasl_init_func)(PGconn *conn, const char *password,
+									  const char *mech);
+
+/*
+ * mech.exchange()
+ *
+ * Produces a client response to a server challenge. As a special case for
+ * client-first SASL mechanisms, exchange() is called with a NULL server
+ * challenge once at the start of the authentication exchange to generate an
+ * initial response.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ *
+ *   input:    the challenge data sent by the server, or NULL when generating a
+ *             client-first initial response (that is, when the server expects
+ *             the client to send a message to start the exchange). This is
+ *             guaranteed to be null-terminated for safety, but SASL allows
+ *             embedded nulls in challenges, so mechanisms must be careful to
+ *             check inputlen.
+ *
+ *   inputlen: the length of the challenge data sent by the server, or -1
+ *             during client-first initial response generation.
+ *
+ * Output parameters, to be set by the callback function:
+ *
+ *   output:    a malloc'd buffer containing the client's response to the
+ *              server, or NULL if the exchange should be aborted. (*success
+ *              should be set to false in the latter case.)
+ *
+ *   outputlen: the length of the client response buffer, or zero if no data
+ *              should be sent due to an exchange failure
+ *
+ *   done:      set to true if the SASL exchange should not continue, because
+ *              the exchange is either complete or failed
+ *
+ *   success:   set to true if the SASL exchange completed successfully. Ignored
+ *              if *done is false.
+ */
+typedef void  (*pg_fe_sasl_exchange_func)(void *state,
+										  char *input, int inputlen,
+										  char **output, int *outputlen,
+										  bool *done, bool *success);
+
+/*
+ * mech.channel_bound()
+ *
+ * Returns true if the connection has an established channel binding. A
+ * mechanism implementation must ensure that a SASL exchange has actually been
+ * completed, in addition to checking that channel binding is in use.
+ *
+ * Mechanisms that do not implement channel binding may simply return false.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ */
+typedef bool  (*pg_fe_sasl_channel_bound_func)(void *);
+
+/*
+ * mech.free()
+ *
+ * Frees the state allocated by mech.init(). This is called when the connection
+ * is dropped, not when the exchange is completed.
+ *
+ * Input parameters:
+ *
+ *   state:    the opaque mechanism state returned by mech.init()
+ */
+typedef void  (*pg_fe_sasl_free_func)(void *);
+
+typedef struct
+{
+	pg_fe_sasl_init_func			init;
+	pg_fe_sasl_exchange_func		exchange;
+	pg_fe_sasl_channel_bound_func	channel_bound;
+	pg_fe_sasl_free_func			free;
+} pg_fe_sasl_mech;
+
+#endif /* FE_AUTH_SASL_H */
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index 5881386e37..515ef66f37 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -21,6 +21,22 @@
 #include "fe-auth.h"
 
 
+/* The exported SCRAM callback mechanism. */
+static void *scram_init(PGconn *conn, const char *password,
+						const char *sasl_mechanism);
+static void scram_exchange(void *opaq, char *input, int inputlen,
+						   char **output, int *outputlen,
+						   bool *done, bool *success);
+static bool scram_channel_bound(void *opaq);
+static void scram_free(void *opaq);
+
+const pg_fe_sasl_mech pg_scram_mech = {
+	scram_init,
+	scram_exchange,
+	scram_channel_bound,
+	scram_free,
+};
+
 /*
  * Status of exchange messages used for SCRAM authentication via the
  * SASL protocol.
@@ -72,10 +88,10 @@ static bool calculate_client_proof(fe_scram_state *state,
 /*
  * Initialize SCRAM exchange status.
  */
-void *
-pg_fe_scram_init(PGconn *conn,
-				 const char *password,
-				 const char *sasl_mechanism)
+static void *
+scram_init(PGconn *conn,
+		   const char *password,
+		   const char *sasl_mechanism)
 {
 	fe_scram_state *state;
 	char	   *prep_password;
@@ -128,8 +144,8 @@ pg_fe_scram_init(PGconn *conn,
  * Note that the caller must also ensure that the exchange was actually
  * successful.
  */
-bool
-pg_fe_scram_channel_bound(void *opaq)
+static bool
+scram_channel_bound(void *opaq)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 
@@ -152,8 +168,8 @@ pg_fe_scram_channel_bound(void *opaq)
 /*
  * Free SCRAM exchange status
  */
-void
-pg_fe_scram_free(void *opaq)
+static void
+scram_free(void *opaq)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 
@@ -188,10 +204,10 @@ pg_fe_scram_free(void *opaq)
 /*
  * Exchange a SCRAM message with backend.
  */
-void
-pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
-					 char **output, int *outputlen,
-					 bool *done, bool *success)
+static void
+scram_exchange(void *opaq, char *input, int inputlen,
+			   char **output, int *outputlen,
+			   bool *done, bool *success)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 	PGconn	   *conn = state->conn;
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index e8062647e6..f299e72e7e 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -41,6 +41,7 @@
 #include "common/md5.h"
 #include "common/scram-common.h"
 #include "fe-auth.h"
+#include "fe-auth-sasl.h"
 #include "libpq-fe.h"
 
 #ifdef ENABLE_GSS
@@ -482,7 +483,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 				 * channel_binding is not disabled.
 				 */
 				if (conn->channel_binding[0] != 'd')	/* disable */
+				{
 					selected_mechanism = SCRAM_SHA_256_PLUS_NAME;
+					conn->sasl = &pg_scram_mech;
+				}
 #else
 				/*
 				 * The client does not support channel binding.  If it is
@@ -516,7 +520,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 		}
 		else if (strcmp(mechanism_buf.data, SCRAM_SHA_256_NAME) == 0 &&
 				 !selected_mechanism)
+		{
 			selected_mechanism = SCRAM_SHA_256_NAME;
+			conn->sasl = &pg_scram_mech;
+		}
 	}
 
 	if (!selected_mechanism)
@@ -555,20 +562,22 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 		goto error;
 	}
 
+	Assert(conn->sasl);
+
 	/*
 	 * Initialize the SASL state information with all the information gathered
 	 * during the initial exchange.
 	 *
 	 * Note: Only tls-unique is supported for the moment.
 	 */
-	conn->sasl_state = pg_fe_scram_init(conn,
+	conn->sasl_state = conn->sasl->init(conn,
 										password,
 										selected_mechanism);
 	if (!conn->sasl_state)
 		goto oom_error;
 
 	/* Get the mechanism-specific Initial Client Response, if any */
-	pg_fe_scram_exchange(conn->sasl_state,
+	conn->sasl->exchange(conn->sasl_state,
 						 NULL, -1,
 						 &initialresponse, &initialresponselen,
 						 &done, &success);
@@ -649,7 +658,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
 	/* For safety and convenience, ensure the buffer is NULL-terminated. */
 	challenge[payloadlen] = '\0';
 
-	pg_fe_scram_exchange(conn->sasl_state,
+	conn->sasl->exchange(conn->sasl_state,
 						 challenge, payloadlen,
 						 &output, &outputlen,
 						 &done, &success);
@@ -664,6 +673,11 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
 							 libpq_gettext("AuthenticationSASLFinal received from server, but SASL authentication was not completed\n"));
 		return STATUS_ERROR;
 	}
+	/*
+	 * TODO SASL requires us to accommodate zero-length responses.
+	 * TODO is it legal for a client not to send a response to a server
+	 * challenge, if the exchange isn't being aborted?
+	 */
 	if (outputlen != 0)
 	{
 		/*
@@ -830,7 +844,7 @@ check_expected_areq(AuthRequest areq, PGconn *conn)
 			case AUTH_REQ_SASL_FIN:
 				break;
 			case AUTH_REQ_OK:
-				if (!pg_fe_scram_channel_bound(conn->sasl_state))
+				if (!conn->sasl || !conn->sasl->channel_bound(conn->sasl_state))
 				{
 					appendPQExpBufferStr(&conn->errorMessage,
 										 libpq_gettext("channel binding required, but server authenticated client without channel binding\n"));
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 7877dcbd09..63927480ee 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -22,15 +22,8 @@
 extern int	pg_fe_sendauth(AuthRequest areq, int payloadlen, PGconn *conn);
 extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
-/* Prototypes for functions in fe-auth-scram.c */
-extern void *pg_fe_scram_init(PGconn *conn,
-							  const char *password,
-							  const char *sasl_mechanism);
-extern bool pg_fe_scram_channel_bound(void *opaq);
-extern void pg_fe_scram_free(void *opaq);
-extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
-								 char **output, int *outputlen,
-								 bool *done, bool *success);
+/* Mechanisms in fe-auth-scram.c */
+extern const pg_fe_sasl_mech pg_scram_mech;
 extern char *pg_fe_scram_build_secret(const char *password);
 
 #endif							/* FE_AUTH_H */
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index fc65e490ef..e950b41374 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -516,11 +516,7 @@ pqDropConnection(PGconn *conn, bool flushInput)
 #endif
 	if (conn->sasl_state)
 	{
-		/*
-		 * XXX: if support for more authentication mechanisms is added, this
-		 * needs to call the right 'free' function.
-		 */
-		pg_fe_scram_free(conn->sasl_state);
+		conn->sasl->free(conn->sasl_state);
 		conn->sasl_state = NULL;
 	}
 }
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 6b7fd2c267..e9f214b61b 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -41,6 +41,7 @@
 #include "getaddrinfo.h"
 #include "libpq/pqcomm.h"
 /* include stuff found in fe only */
+#include "fe-auth-sasl.h"
 #include "pqexpbuffer.h"
 
 #ifdef ENABLE_GSS
@@ -500,6 +501,7 @@ struct pg_conn
 	PGresult   *next_result;	/* next result (used in single-row mode) */
 
 	/* Assorted state for SASL, SSL, GSS, etc */
+	const pg_fe_sasl_mech *sasl;
 	void	   *sasl_state;
 
 	/* SSL structures */
-- 
2.25.1

