From 378c86619933d9c712730e3d6a105a79854660cf Mon Sep 17 00:00:00 2001
From: Michael Paquier <michael@paquier.xyz>
Date: Wed, 14 Dec 2022 11:35:33 +0900
Subject: [PATCH] Remove dependency on hash type and key length in internal
 SCRAM code

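SCRAM_KEY_LEN had a hard dependency on SHA-256, which made it difficult
to add more hash methods to SCRAM; the many statically-sized buffers
based on it were one such problem.  This renames it to
SCRAM_SHA_256_KEY_LEN and makes the internal SCRAM routines take the
hash type and key length as arguments, allocating key buffers
dynamically instead.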
---
 src/include/common/scram-common.h    |  25 ++--
 src/include/libpq/scram.h            |   8 +-
 src/backend/libpq/auth-scram.c       | 165 ++++++++++++++---------
 src/backend/libpq/crypt.c            |  10 +-
 src/common/scram-common.c            | 189 ++++++++++++++++++---------
 src/interfaces/libpq/fe-auth-scram.c | 175 +++++++++++++++++--------
 6 files changed, 380 insertions(+), 192 deletions(-)

diff --git a/src/include/common/scram-common.h b/src/include/common/scram-common.h
index 4acf2a78ad..5b647e4b81 100644
--- a/src/include/common/scram-common.h
+++ b/src/include/common/scram-common.h
@@ -21,7 +21,7 @@
 #define SCRAM_SHA_256_PLUS_NAME "SCRAM-SHA-256-PLUS"	/* with channel binding */
 
 /* Length of SCRAM keys (client and server) */
-#define SCRAM_KEY_LEN				PG_SHA256_DIGEST_LENGTH
+#define SCRAM_SHA_256_KEY_LEN				PG_SHA256_DIGEST_LENGTH
 
 /*
  * Size of random nonce generated in the authentication exchange.  This
@@ -43,17 +43,22 @@
  */
 #define SCRAM_DEFAULT_ITERATIONS	4096
 
-extern int	scram_SaltedPassword(const char *password, const char *salt,
-								 int saltlen, int iterations, uint8 *result,
-								 const char **errstr);
-extern int	scram_H(const uint8 *input, int len, uint8 *result,
+extern int	scram_SaltedPassword(const char *password,
+								 pg_cryptohash_type hash_type, int key_length,
+								 const char *salt, int saltlen, int iterations,
+								 uint8 *result, const char **errstr);
+extern int	scram_H(const uint8 *input, pg_cryptohash_type hash_type,
+					int key_length, uint8 *result,
 					const char **errstr);
-extern int	scram_ClientKey(const uint8 *salted_password, uint8 *result,
-							const char **errstr);
-extern int	scram_ServerKey(const uint8 *salted_password, uint8 *result,
-							const char **errstr);
+extern int	scram_ClientKey(const uint8 *salted_password,
+							pg_cryptohash_type hash_type, int key_length,
+							uint8 *result, const char **errstr);
+extern int	scram_ServerKey(const uint8 *salted_password,
+							pg_cryptohash_type hash_type, int key_length,
+							uint8 *result, const char **errstr);
 
-extern char *scram_build_secret(const char *salt, int saltlen, int iterations,
+extern char *scram_build_secret(pg_cryptohash_type hash_type, int key_length,
+								const char *salt, int saltlen, int iterations,
 								const char *password, const char **errstr);
 
 #endif							/* SCRAM_COMMON_H */
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index c51e848c24..2662c9d703 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -13,6 +13,7 @@
 #ifndef PG_SCRAM_H
 #define PG_SCRAM_H
 
+#include "common/cryptohash.h"
 #include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
 #include "libpq/sasl.h"
@@ -22,8 +23,11 @@ extern PGDLLIMPORT const pg_be_sasl_mech pg_be_scram_mech;
 
 /* Routines to handle and check SCRAM-SHA-256 secret */
 extern char *pg_be_scram_build_secret(const char *password);
-extern bool parse_scram_secret(const char *secret, int *iterations, char **salt,
-							   uint8 *stored_key, uint8 *server_key);
+extern bool parse_scram_secret(const char *secret,
+							   int *iterations,
+							   pg_cryptohash_type *hash_type,
+							   int *key_length, char **salt,
+							   uint8 **stored_key, uint8 **server_key);
 extern bool scram_verify_plain_password(const char *username,
 										const char *password, const char *secret);
 
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index ee7f52218a..d083439c13 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -141,10 +141,14 @@ typedef struct
 	Port	   *port;
 	bool		channel_binding_in_use;
 
+	/* State data depending on the hash type */
+	pg_cryptohash_type	hash_type;
+	int			key_length;
+
 	int			iterations;
 	char	   *salt;			/* base64-encoded */
-	uint8		StoredKey[SCRAM_KEY_LEN];
-	uint8		ServerKey[SCRAM_KEY_LEN];
+	uint8	   *stored_key;		/* size of key_length */
+	uint8	   *server_key;		/* size of key_length */
 
 	/* Fields of the first message from client */
 	char		cbind_flag;
@@ -155,7 +159,7 @@ typedef struct
 	/* Fields from the last message from client */
 	char	   *client_final_message_without_proof;
 	char	   *client_final_nonce;
-	char		ClientProof[SCRAM_KEY_LEN];
+	char	   *client_proof;	/* size of key_length */
 
 	/* Fields generated in the server */
 	char	   *server_first_message;
@@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state);
 static char *build_server_final_message(scram_state *state);
 static bool verify_client_proof(scram_state *state);
 static bool verify_final_nonce(scram_state *state);
-static void mock_scram_secret(const char *username, int *iterations,
-							  char **salt, uint8 *stored_key, uint8 *server_key);
+static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
+							  int *iterations, int *key_length, char **salt,
+							  uint8 **stored_key, uint8 **server_key);
 static bool is_scram_printable(char *p);
 static char *sanitize_char(char c);
 static char *sanitize_str(const char *s);
-static char *scram_mock_salt(const char *username);
+static char *scram_mock_salt(const char *username,
+							 pg_cryptohash_type hash_type,
+							 int key_length);
 
 /*
  * Get a list of SASL mechanisms that this module supports.
@@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 
 		if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
 		{
-			if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt,
-								   state->StoredKey, state->ServerKey))
+			if (parse_scram_secret(shadow_pass, &state->iterations,
+								   &state->hash_type, &state->key_length,
+								   &state->salt,
+								   &state->stored_key,
+								   &state->server_key))
 				got_secret = true;
 			else
 			{
@@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 	 */
 	if (!got_secret)
 	{
-		mock_scram_secret(state->port->user_name, &state->iterations,
-						  &state->salt, state->StoredKey, state->ServerKey);
+		mock_scram_secret(state->port->user_name, &state->hash_type,
+						  &state->iterations, &state->key_length,
+						  &state->salt,
+						  &state->stored_key, &state->server_key);
 		state->doomed = true;
 	}
 
@@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password)
 				(errcode(ERRCODE_INTERNAL_ERROR),
 				 errmsg("could not generate random salt")));
 
-	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
+	result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
+								saltbuf, SCRAM_DEFAULT_SALT_LEN,
 								SCRAM_DEFAULT_ITERATIONS, password,
 								&errstr);
 
@@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password,
 	char	   *salt;
 	int			saltlen;
 	int			iterations;
-	uint8		salted_password[SCRAM_KEY_LEN];
-	uint8		stored_key[SCRAM_KEY_LEN];
-	uint8		server_key[SCRAM_KEY_LEN];
-	uint8		computed_key[SCRAM_KEY_LEN];
+	int			key_length = 0;
+	pg_cryptohash_type hash_type;
+	uint8	   *salted_password = NULL;	/* size of key_length */
+	uint8	   *stored_key = NULL;		/* size of key_length */
+	uint8	   *server_key = NULL;		/* size of key_length */
+	uint8	   *computed_key = NULL;	/* size of key_length */
 	char	   *prep_password;
 	pg_saslprep_rc rc;
 	const char *errstr = NULL;
 
-	if (!parse_scram_secret(secret, &iterations, &encoded_salt,
-							stored_key, server_key))
+	if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length,
+							&encoded_salt, &stored_key, &server_key))
 	{
 		/*
 		 * The password looked like a SCRAM secret, but could not be parsed.
@@ -524,6 +539,11 @@ scram_verify_plain_password(const char *username, const char *password,
 		return false;
 	}
 
+	/* allocated by parse_scram_secret() */
+	Assert(stored_key && server_key);
+	salted_password = (uint8 *) palloc(key_length * sizeof(uint8));
+	computed_key = (uint8 *) palloc(key_length * sizeof(uint8));
+
 	saltlen = pg_b64_dec_len(strlen(encoded_salt));
 	salt = palloc(saltlen);
 	saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
@@ -541,9 +561,11 @@ scram_verify_plain_password(const char *username, const char *password,
 		password = prep_password;
 
 	/* Compute Server Key based on the user-supplied plaintext password */
-	if (scram_SaltedPassword(password, salt, saltlen, iterations,
+	if (scram_SaltedPassword(password, hash_type, key_length,
+							 salt, saltlen, iterations,
 							 salted_password, &errstr) < 0 ||
-		scram_ServerKey(salted_password, computed_key, &errstr) < 0)
+		scram_ServerKey(salted_password, hash_type, key_length,
+						computed_key, &errstr) < 0)
 	{
 		elog(ERROR, "could not compute server key: %s", errstr);
 	}
@@ -555,24 +577,25 @@ scram_verify_plain_password(const char *username, const char *password,
 	 * Compare the secret's Server Key with the one computed from the
 	 * user-supplied password.
 	 */
-	return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
+	return memcmp(computed_key, server_key, key_length) == 0;
 }
 
 
 /*
  * Parse and validate format of given SCRAM secret.
  *
- * On success, the iteration count, salt, stored key, and server key are
- * extracted from the secret, and returned to the caller.  For 'stored_key'
- * and 'server_key', the caller must pass pre-allocated buffers of size
- * SCRAM_KEY_LEN.  Salt is returned as a base64-encoded, null-terminated
- * string.  The buffer for the salt is palloc'd by this function.
+ * On success, the iteration count, salt, key length, stored key, and
+ * server key are extracted from the secret, and returned to the caller.
+ * 'stored_key' and 'server_key' are palloc'd with a size of 'key_length'.
+ * Salt is returned as a base64-encoded, null-terminated string.  The buffer
+ * for the salt is palloc'd by this function.
  *
  * Returns true if the SCRAM secret has been parsed, and false otherwise.
  */
 bool
-parse_scram_secret(const char *secret, int *iterations, char **salt,
-				   uint8 *stored_key, uint8 *server_key)
+parse_scram_secret(const char *secret, int *iterations,
+				   pg_cryptohash_type *hash_type, int *key_length,
+				   char **salt, uint8 **stored_key, uint8 **server_key)
 {
 	char	   *v;
 	char	   *p;
@@ -606,6 +629,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
 	/* Parse the fields */
 	if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
 		goto invalid_secret;
+	*hash_type = PG_SHA256;
+	*key_length = SCRAM_SHA_256_KEY_LEN;
 
 	errno = 0;
 	*iterations = strtol(iterations_str, &p, 10);
@@ -631,17 +656,19 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
 	decoded_stored_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
 								decoded_stored_buf, decoded_len);
-	if (decoded_len != SCRAM_KEY_LEN)
+	if (decoded_len != *key_length)
 		goto invalid_secret;
-	memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
+	*stored_key = (uint8 *) palloc(*key_length * sizeof(uint8));
+	memcpy(*stored_key, decoded_stored_buf, *key_length * sizeof(uint8));
 
 	decoded_len = pg_b64_dec_len(strlen(serverkey_str));
 	decoded_server_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
 								decoded_server_buf, decoded_len);
-	if (decoded_len != SCRAM_KEY_LEN)
+	if (decoded_len != *key_length)
 		goto invalid_secret;
-	memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
+	*server_key = (uint8 *) palloc(*key_length * sizeof(uint8));
+	memcpy(*server_key, decoded_server_buf, *key_length * sizeof(uint8));
 
 	return true;
 
@@ -655,20 +682,25 @@ invalid_secret:
  *
  * In a normal authentication, these are extracted from the secret
  * stored in the server.  This function generates values that look
- * realistic, for when there is no stored secret.
+ * realistic, for when there is no stored secret, using SCRAM-SHA-256.
  *
- * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
- * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and
- * the buffer for the salt is palloc'd by this function.
+ * 'stored_key' and 'server_key' are palloc'd by this function with
+ * an arbitrary key length guessed from the hash type, and the buffer
+ * for the salt is palloc'd by this function.
  */
 static void
-mock_scram_secret(const char *username, int *iterations, char **salt,
-				  uint8 *stored_key, uint8 *server_key)
+mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
+				  int *iterations, int *key_length, char **salt,
+				  uint8 **stored_key, uint8 **server_key)
 {
 	char	   *raw_salt;
 	char	   *encoded_salt;
 	int			encoded_len;
 
+	/* Enforce the use of SHA-256, which would be realistic enough */
+	*hash_type = PG_SHA256;
+	*key_length = SCRAM_SHA_256_KEY_LEN;
+
 	/*
 	 * Generate deterministic salt.
 	 *
@@ -677,7 +709,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
 	 * as the salt generated for mock authentication uses the cluster's nonce
 	 * value.
 	 */
-	raw_salt = scram_mock_salt(username);
+	raw_salt = scram_mock_salt(username, *hash_type, *key_length);
 	if (raw_salt == NULL)
 		elog(ERROR, "could not encode salt");
 
@@ -695,8 +727,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
 	*iterations = SCRAM_DEFAULT_ITERATIONS;
 
 	/* StoredKey and ServerKey are not used in a doomed authentication */
-	memset(stored_key, 0, SCRAM_KEY_LEN);
-	memset(server_key, 0, SCRAM_KEY_LEN);
+	*stored_key = (uint8 *) palloc0(*key_length * sizeof(uint8));
+	*server_key = (uint8 *) palloc0(*key_length * sizeof(uint8));
 }
 
 /*
@@ -1111,10 +1143,13 @@ verify_final_nonce(scram_state *state)
 static bool
 verify_client_proof(scram_state *state)
 {
-	uint8		ClientSignature[SCRAM_KEY_LEN];
-	uint8		ClientKey[SCRAM_KEY_LEN];
-	uint8		client_StoredKey[SCRAM_KEY_LEN];
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	uint8	   *ClientSignature = (uint8 *) palloc(state->key_length *
+												   sizeof(uint8));
+	uint8	   *ClientKey = (uint8 *) palloc(state->key_length *
+											 sizeof(uint8));
+	uint8	   *client_StoredKey = (uint8 *) palloc(state->key_length *
+													sizeof(uint8));
+	pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
 	int			i;
 	const char *errstr = NULL;
 
@@ -1123,7 +1158,7 @@ verify_client_proof(scram_state *state)
 	 * here even when processing the calculations as this could involve a mock
 	 * authentication.
 	 */
-	if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, state->stored_key, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -1135,7 +1170,7 @@ verify_client_proof(scram_state *state)
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_final_message_without_proof,
 					   strlen(state->client_final_message_without_proof)) < 0 ||
-		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
+		pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
 	{
 		elog(ERROR, "could not calculate client signature: %s",
 			 pg_hmac_error(ctx));
@@ -1144,14 +1179,15 @@ verify_client_proof(scram_state *state)
 	pg_hmac_free(ctx);
 
 	/* Extract the ClientKey that the client calculated from the proof */
-	for (i = 0; i < SCRAM_KEY_LEN; i++)
-		ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
+	for (i = 0; i < state->key_length; i++)
+		ClientKey[i] = state->client_proof[i] ^ ClientSignature[i];
 
 	/* Hash it one more time, and compare with StoredKey */
-	if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0)
+	if (scram_H(ClientKey, state->hash_type, state->key_length,
+				client_StoredKey, &errstr) < 0)
 		elog(ERROR, "could not hash stored key: %s", errstr);
 
-	if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
+	if (memcmp(client_StoredKey, state->stored_key, state->key_length) != 0)
 		return false;
 
 	return true;
@@ -1349,12 +1385,13 @@ read_client_final_message(scram_state *state, const char *input)
 	client_proof_len = pg_b64_dec_len(strlen(value));
 	client_proof = palloc(client_proof_len);
 	if (pg_b64_decode(value, strlen(value), client_proof,
-					  client_proof_len) != SCRAM_KEY_LEN)
+					  client_proof_len) != state->key_length)
 		ereport(ERROR,
 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
 				 errmsg("malformed SCRAM message"),
 				 errdetail("Malformed proof in client-final-message.")));
-	memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN);
+	state->client_proof = palloc(state->key_length);
+	memcpy(state->client_proof, client_proof, state->key_length);
 	pfree(client_proof);
 
 	if (*p != '\0')
@@ -1374,13 +1411,14 @@ read_client_final_message(scram_state *state, const char *input)
 static char *
 build_server_final_message(scram_state *state)
 {
-	uint8		ServerSignature[SCRAM_KEY_LEN];
+	uint8	   *ServerSignature = (uint8 *) palloc(state->key_length *
+												   sizeof(uint8));
 	char	   *server_signature_base64;
 	int			siglen;
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
 
 	/* calculate ServerSignature */
-	if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, state->server_key, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -1392,7 +1430,7 @@ build_server_final_message(scram_state *state)
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_final_message_without_proof,
 					   strlen(state->client_final_message_without_proof)) < 0 ||
-		pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0)
+		pg_hmac_final(ctx, ServerSignature, state->key_length) < 0)
 	{
 		elog(ERROR, "could not calculate server signature: %s",
 			 pg_hmac_error(ctx));
@@ -1400,11 +1438,11 @@ build_server_final_message(scram_state *state)
 
 	pg_hmac_free(ctx);
 
-	siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
+	siglen = pg_b64_enc_len(state->key_length);
 	/* don't forget the zero-terminator */
 	server_signature_base64 = palloc(siglen + 1);
 	siglen = pg_b64_encode((const char *) ServerSignature,
-						   SCRAM_KEY_LEN, server_signature_base64,
+						   state->key_length, server_signature_base64,
 						   siglen);
 	if (siglen < 0)
 		elog(ERROR, "could not encode server signature");
@@ -1431,12 +1469,15 @@ build_server_final_message(scram_state *state)
- * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
+ * palloc'd buffer of key_length bytes, or NULL.
  */
 static char *
-scram_mock_salt(const char *username)
+scram_mock_salt(const char *username, pg_cryptohash_type hash_type,
+				int key_length)
 {
 	pg_cryptohash_ctx *ctx;
-	static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH];
+	uint8	   *sha_digest = (uint8 *) palloc(key_length * sizeof(uint8));
 	char	   *mock_auth_nonce = GetMockAuthenticationNonce();
 
+	Assert(hash_type == PG_SHA256);
+
 	/*
 	 * Generate salt using a SHA256 hash of the username and the cluster's
 	 * mock authentication nonce.  (This works as long as the salt length is
@@ -1446,11 +1487,11 @@ scram_mock_salt(const char *username)
 	StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
 					 "salt length greater than SHA256 digest length");
 
-	ctx = pg_cryptohash_create(PG_SHA256);
+	ctx = pg_cryptohash_create(hash_type);
 	if (pg_cryptohash_init(ctx) < 0 ||
 		pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 ||
 		pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
-		pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0)
+		pg_cryptohash_final(ctx, sha_digest, key_length) < 0)
 	{
 		pg_cryptohash_free(ctx);
 		return NULL;
diff --git a/src/backend/libpq/crypt.c b/src/backend/libpq/crypt.c
index 1ff8b0507d..4e2b7c99fe 100644
--- a/src/backend/libpq/crypt.c
+++ b/src/backend/libpq/crypt.c
@@ -90,15 +90,17 @@ get_password_type(const char *shadow_pass)
 {
 	char	   *encoded_salt;
 	int			iterations;
-	uint8		stored_key[SCRAM_KEY_LEN];
-	uint8		server_key[SCRAM_KEY_LEN];
+	int			key_length = 0;
+	pg_cryptohash_type hash_type;
+	uint8	   *stored_key;		/* size of key_length */
+	uint8	   *server_key;		/* size of key_length */
 
 	if (strncmp(shadow_pass, "md5", 3) == 0 &&
 		strlen(shadow_pass) == MD5_PASSWD_LEN &&
 		strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3)
 		return PASSWORD_TYPE_MD5;
-	if (parse_scram_secret(shadow_pass, &iterations, &encoded_salt,
-						   stored_key, server_key))
+	if (parse_scram_secret(shadow_pass, &iterations, &hash_type, &key_length,
+						   &encoded_salt, &stored_key, &server_key))
 		return PASSWORD_TYPE_SCRAM_SHA_256;
 	return PASSWORD_TYPE_PLAINTEXT;
 }
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
index 1268625929..d41be27ca4 100644
--- a/src/common/scram-common.c
+++ b/src/common/scram-common.c
@@ -33,6 +33,7 @@
  */
 int
 scram_SaltedPassword(const char *password,
+					 pg_cryptohash_type hash_type, int key_length,
 					 const char *salt, int saltlen, int iterations,
 					 uint8 *result, const char **errstr)
 {
@@ -40,9 +41,9 @@ scram_SaltedPassword(const char *password,
 	uint32		one = pg_hton32(1);
 	int			i,
 				j;
-	uint8		Ui[SCRAM_KEY_LEN];
-	uint8		Ui_prev[SCRAM_KEY_LEN];
-	pg_hmac_ctx *hmac_ctx = pg_hmac_create(PG_SHA256);
+	uint8	   *Ui;			/* size of key_length */
+	uint8	   *Ui_prev;	/* size of key_length */
+	pg_hmac_ctx *hmac_ctx = pg_hmac_create(hash_type);
 
 	if (hmac_ctx == NULL)
 	{
@@ -50,6 +51,19 @@ scram_SaltedPassword(const char *password,
 		return -1;
 	}
 
+#ifdef FRONTEND
+	Ui = (uint8 *) malloc(key_length * sizeof(uint8));
+	Ui_prev = (uint8 *) malloc(key_length * sizeof(uint8));
+	if (Ui == NULL || Ui_prev == NULL)
+	{
+		*errstr = _("out of memory");
+		goto error;
+	}
+#else
+	Ui = (uint8 *) palloc(key_length * sizeof(uint8));
+	Ui_prev = (uint8 *) palloc(key_length * sizeof(uint8));
+#endif
+
 	/*
 	 * Iterate hash calculation of HMAC entry using given salt.  This is
 	 * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
@@ -60,48 +74,70 @@ scram_SaltedPassword(const char *password,
 	if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
 		pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 ||
 		pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
-		pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0)
+		pg_hmac_final(hmac_ctx, Ui_prev, key_length) < 0)
 	{
 		*errstr = pg_hmac_error(hmac_ctx);
-		pg_hmac_free(hmac_ctx);
-		return -1;
+		goto error;
 	}
 
-	memcpy(result, Ui_prev, SCRAM_KEY_LEN);
+	memcpy(result, Ui_prev, key_length);
 
 	/* Subsequent iterations */
 	for (i = 2; i <= iterations; i++)
 	{
 		if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
-			pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 ||
-			pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0)
+			pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, key_length) < 0 ||
+			pg_hmac_final(hmac_ctx, Ui, key_length) < 0)
 		{
 			*errstr = pg_hmac_error(hmac_ctx);
-			pg_hmac_free(hmac_ctx);
-			return -1;
+			goto error;
 		}
 
-		for (j = 0; j < SCRAM_KEY_LEN; j++)
+		for (j = 0; j < key_length; j++)
 			result[j] ^= Ui[j];
-		memcpy(Ui_prev, Ui, SCRAM_KEY_LEN);
+		memcpy(Ui_prev, Ui, key_length);
 	}
 
+#ifdef FRONTEND
+	free(Ui);
+	free(Ui_prev);
+#else
+	pfree(Ui);
+	pfree(Ui_prev);
+#endif
 	pg_hmac_free(hmac_ctx);
 	return 0;
+
+error:
+#ifdef FRONTEND
+	if (Ui)
+		free(Ui);
+	if (Ui_prev)
+		free(Ui_prev);
+#else
+	if (Ui)
+		pfree(Ui);
+	if (Ui_prev)
+		pfree(Ui_prev);
+#endif
+
+	pg_hmac_free(hmac_ctx);
+	return -1;
 }
 
 
 /*
- * Calculate SHA-256 hash for a NULL-terminated string. (The NULL terminator is
+ * Calculate hash for a NULL-terminated string. (The NULL terminator is
  * not included in the hash).  Returns 0 on success, -1 on failure with *errstr
  * pointing to a message about the error details.
  */
 int
-scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
+scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length,
+		uint8 *result, const char **errstr)
 {
 	pg_cryptohash_ctx *ctx;
 
-	ctx = pg_cryptohash_create(PG_SHA256);
+	ctx = pg_cryptohash_create(hash_type);
 	if (ctx == NULL)
 	{
 		*errstr = pg_cryptohash_error(NULL);	/* returns OOM */
@@ -109,8 +145,8 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
 	}
 
 	if (pg_cryptohash_init(ctx) < 0 ||
-		pg_cryptohash_update(ctx, input, len) < 0 ||
-		pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0)
+		pg_cryptohash_update(ctx, input, key_length) < 0 ||
+		pg_cryptohash_final(ctx, result, key_length) < 0)
 	{
 		*errstr = pg_cryptohash_error(ctx);
 		pg_cryptohash_free(ctx);
@@ -126,10 +162,11 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
  * pointing to a message about the error details.
  */
 int
-scram_ClientKey(const uint8 *salted_password, uint8 *result,
-				const char **errstr)
+scram_ClientKey(const uint8 *salted_password,
+				pg_cryptohash_type hash_type, int key_length,
+				uint8 *result, const char **errstr)
 {
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
 
 	if (ctx == NULL)
 	{
@@ -137,9 +174,9 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
 		return -1;
 	}
 
-	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
 		pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
-		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
+		pg_hmac_final(ctx, result, key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
@@ -155,10 +192,11 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
  * pointing to a message about the error details.
  */
 int
-scram_ServerKey(const uint8 *salted_password, uint8 *result,
-				const char **errstr)
+scram_ServerKey(const uint8 *salted_password,
+				pg_cryptohash_type hash_type, int key_length,
+				uint8 *result, const char **errstr)
 {
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
 
 	if (ctx == NULL)
 	{
@@ -166,9 +204,9 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
 		return -1;
 	}
 
-	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
 		pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
-		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
+		pg_hmac_final(ctx, result, key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
@@ -192,13 +230,14 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
  * error details.
  */
 char *
-scram_build_secret(const char *salt, int saltlen, int iterations,
+scram_build_secret(pg_cryptohash_type hash_type, int key_length,
+				   const char *salt, int saltlen, int iterations,
 				   const char *password, const char **errstr)
 {
-	uint8		salted_password[SCRAM_KEY_LEN];
-	uint8		stored_key[SCRAM_KEY_LEN];
-	uint8		server_key[SCRAM_KEY_LEN];
-	char	   *result;
+	uint8	   *salted_password = NULL;
+	uint8	   *stored_key = NULL;
+	uint8	   *server_key = NULL;
+	char	   *result = NULL;
 	char	   *p;
 	int			maxlen;
 	int			encoded_salt_len;
@@ -206,19 +245,42 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	int			encoded_server_len;
 	int			encoded_result;
 
+	Assert(hash_type == PG_SHA256);
+
 	if (iterations <= 0)
 		iterations = SCRAM_DEFAULT_ITERATIONS;
 
+#ifdef FRONTEND
+	salted_password = (uint8 *) malloc(key_length * sizeof(uint8));
+	stored_key = (uint8 *) malloc(key_length * sizeof(uint8));
+	server_key = (uint8 *) malloc(key_length * sizeof(uint8));
+	if (salted_password == NULL ||
+		stored_key == NULL ||
+		server_key == NULL)
+	{
+		*errstr = _("out of memory");
+		goto error;
+	}
+#else
+	salted_password = (uint8 *) palloc(key_length * sizeof(uint8));
+	stored_key = (uint8 *) palloc(key_length * sizeof(uint8));
+	server_key = (uint8 *) palloc(key_length * sizeof(uint8));
+#endif
+
 	/* Calculate StoredKey and ServerKey */
-	if (scram_SaltedPassword(password, salt, saltlen, iterations,
+	if (scram_SaltedPassword(password, hash_type, key_length,
+							 salt, saltlen, iterations,
 							 salted_password, errstr) < 0 ||
-		scram_ClientKey(salted_password, stored_key, errstr) < 0 ||
-		scram_H(stored_key, SCRAM_KEY_LEN, stored_key, errstr) < 0 ||
-		scram_ServerKey(salted_password, server_key, errstr) < 0)
+		scram_ClientKey(salted_password, hash_type, key_length,
+						stored_key, errstr) < 0 ||
+		scram_H(stored_key, hash_type, key_length,
+				stored_key, errstr) < 0 ||
+		scram_ServerKey(salted_password, hash_type, key_length,
+						server_key, errstr) < 0)
 	{
 		/* errstr is filled already here */
 #ifdef FRONTEND
-		return NULL;
+		goto error;
 #else
 		elog(ERROR, "could not calculate stored key and server key: %s",
 			 *errstr);
@@ -231,8 +293,8 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	 *----------
 	 */
 	encoded_salt_len = pg_b64_enc_len(saltlen);
-	encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
-	encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	encoded_stored_len = pg_b64_enc_len(key_length);
+	encoded_server_len = pg_b64_enc_len(key_length);
 
 	maxlen = strlen("SCRAM-SHA-256") + 1
 		+ 10 + 1				/* iteration count */
@@ -245,7 +307,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	if (!result)
 	{
 		*errstr = _("out of memory");
-		return NULL;
+		goto error;
 	}
 #else
 	result = palloc(maxlen);
@@ -258,45 +320,30 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	if (encoded_result < 0)
 	{
 		*errstr = _("could not encode salt");
-#ifdef FRONTEND
-		free(result);
-		return NULL;
-#else
-		elog(ERROR, "%s", *errstr);
-#endif
+		goto error;
 	}
 	p += encoded_result;
 	*(p++) = '$';
 
 	/* stored key */
-	encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
+	encoded_result = pg_b64_encode((char *) stored_key, key_length, p,
 								   encoded_stored_len);
 	if (encoded_result < 0)
 	{
 		*errstr = _("could not encode stored key");
-#ifdef FRONTEND
-		free(result);
-		return NULL;
-#else
-		elog(ERROR, "%s", *errstr);
-#endif
+		goto error;
 	}
 
 	p += encoded_result;
 	*(p++) = ':';
 
 	/* server key */
-	encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
+	encoded_result = pg_b64_encode((char *) server_key, key_length, p,
 								   encoded_server_len);
 	if (encoded_result < 0)
 	{
 		*errstr = _("could not encode server key");
-#ifdef FRONTEND
-		free(result);
-		return NULL;
-#else
-		elog(ERROR, "%s", *errstr);
-#endif
+		goto error;
 	}
 
 	p += encoded_result;
@@ -304,5 +351,25 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 
 	Assert(p - result <= maxlen);
 
+#ifdef FRONTEND
+	free(salted_password);
+	free(stored_key);
+	free(server_key);
+#endif
 	return result;
+
+error:
+#ifdef FRONTEND
+	if (salted_password)
+		free(salted_password);
+	if (stored_key)
+		free(stored_key);
+	if (server_key)
+		free(server_key);
+	if (result)
+		free(result);
+#else
+	elog(ERROR, "%s", *errstr);
+#endif
+	return NULL;
 }
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index c500bea9e7..3b20062484 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -58,8 +58,12 @@ typedef struct
 	char	   *password;
 	char	   *sasl_mechanism;
 
+	/* State data depending on the hash type */
+	pg_cryptohash_type	hash_type;
+	int			key_length;
+
 	/* We construct these */
-	uint8		SaltedPassword[SCRAM_KEY_LEN];
+	uint8	   *salted_password;	/* size of key_length */
 	char	   *client_nonce;
 	char	   *client_first_message_bare;
 	char	   *client_final_message_without_proof;
@@ -73,7 +77,7 @@ typedef struct
 
 	/* These come from the server-final message */
 	char	   *server_final_message;
-	char		ServerSignature[SCRAM_KEY_LEN];
+	char	   *server_signature;	/* size of key_length */
 } fe_scram_state;
 
 static bool read_server_first_message(fe_scram_state *state, char *input);
@@ -106,35 +110,47 @@ scram_init(PGconn *conn,
 	memset(state, 0, sizeof(fe_scram_state));
 	state->conn = conn;
 	state->state = FE_SCRAM_INIT;
-	state->sasl_mechanism = strdup(sasl_mechanism);
+	state->key_length = SCRAM_SHA_256_KEY_LEN;
+	state->hash_type = PG_SHA256;
 
+	state->sasl_mechanism = strdup(sasl_mechanism);
 	if (!state->sasl_mechanism)
-	{
-		free(state);
-		return NULL;
-	}
+		goto oom_error;
+
+	state->salted_password = (uint8 *) malloc(state->key_length * sizeof(uint8));
+	if (state->salted_password == NULL)
+		goto oom_error;
+	state->server_signature = (char *) malloc(state->key_length * sizeof(char));
+	if (state->server_signature == NULL)
+		goto oom_error;
 
 	/* Normalize the password with SASLprep, if possible */
 	rc = pg_saslprep(password, &prep_password);
 	if (rc == SASLPREP_OOM)
-	{
-		free(state->sasl_mechanism);
-		free(state);
-		return NULL;
-	}
+		goto oom_error;
+
 	if (rc != SASLPREP_SUCCESS)
 	{
 		prep_password = strdup(password);
 		if (!prep_password)
-		{
-			free(state->sasl_mechanism);
-			free(state);
-			return NULL;
-		}
+			goto oom_error;
 	}
 	state->password = prep_password;
 
 	return state;
+
+oom_error:
+	if (state->salted_password)
+		free(state->salted_password);
+	if (state->server_signature)
+		free(state->server_signature);
+	if (state->password)
+		free(state->password);
+	if (state->sasl_mechanism)
+		free(state->sasl_mechanism);
+	if (state)
+		free(state);
+	return NULL;
 }
 
 /*
@@ -178,6 +194,7 @@ scram_free(void *opaq)
 	free(state->sasl_mechanism);
 
 	/* client messages */
+	free(state->salted_password);
 	free(state->client_nonce);
 	free(state->client_first_message_bare);
 	free(state->client_final_message_without_proof);
@@ -189,6 +206,7 @@ scram_free(void *opaq)
 
 	/* final message from server */
 	free(state->server_final_message);
+	free(state->server_signature);
 
 	free(state);
 }
@@ -450,13 +468,17 @@ build_client_final_message(fe_scram_state *state)
 {
 	PQExpBufferData buf;
 	PGconn	   *conn = state->conn;
-	uint8		client_proof[SCRAM_KEY_LEN];
+	uint8	   *client_proof;	/* size of key_length */
 	char	   *result;
 	int			encoded_len;
 	const char *errstr = NULL;
 
 	initPQExpBuffer(&buf);
 
+	client_proof = (uint8 *) malloc(state->key_length * sizeof(uint8));
+	if (client_proof == NULL)
+		goto oom_error;
+
 	/*
 	 * Construct client-final-message-without-proof.  We need to remember it
 	 * for verifying the server proof in the final step of authentication.
@@ -565,11 +587,11 @@ build_client_final_message(fe_scram_state *state)
 	}
 
 	appendPQExpBufferStr(&buf, ",p=");
-	encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	encoded_len = pg_b64_enc_len(state->key_length);
 	if (!enlargePQExpBuffer(&buf, encoded_len))
 		goto oom_error;
 	encoded_len = pg_b64_encode((char *) client_proof,
-								SCRAM_KEY_LEN,
+								state->key_length,
 								buf.data + buf.len,
 								encoded_len);
 	if (encoded_len < 0)
@@ -590,6 +612,7 @@ build_client_final_message(fe_scram_state *state)
 
 oom_error:
 	termPQExpBuffer(&buf);
+	free(client_proof);
 	libpq_append_conn_error(conn, "out of memory");
 	return NULL;
 }
@@ -738,13 +761,14 @@ read_server_final_message(fe_scram_state *state, char *input)
 										 strlen(encoded_server_signature),
 										 decoded_server_signature,
 										 server_signature_len);
-	if (server_signature_len != SCRAM_KEY_LEN)
+	if (server_signature_len != state->key_length)
 	{
 		free(decoded_server_signature);
 		libpq_append_conn_error(conn, "malformed SCRAM message (invalid server signature)");
 		return false;
 	}
-	memcpy(state->ServerSignature, decoded_server_signature, SCRAM_KEY_LEN);
+	memcpy(state->server_signature, decoded_server_signature,
+		   state->key_length);
 	free(decoded_server_signature);
 
 	return true;
@@ -760,35 +784,48 @@ calculate_client_proof(fe_scram_state *state,
 					   const char *client_final_message_without_proof,
 					   uint8 *result, const char **errstr)
 {
-	uint8		StoredKey[SCRAM_KEY_LEN];
-	uint8		ClientKey[SCRAM_KEY_LEN];
-	uint8		ClientSignature[SCRAM_KEY_LEN];
+	uint8	   *StoredKey = NULL;
+	uint8	   *ClientKey = NULL;
+	uint8	   *ClientSignature = NULL;
 	int			i;
-	pg_hmac_ctx *ctx;
+	pg_hmac_ctx *ctx = NULL;
 
-	ctx = pg_hmac_create(PG_SHA256);
+	StoredKey = malloc(state->key_length * sizeof(uint8));
+	ClientKey = malloc(state->key_length * sizeof(uint8));
+	ClientSignature = malloc(state->key_length * sizeof(uint8));
+	if (StoredKey == NULL ||
+		ClientKey == NULL ||
+		ClientSignature == NULL)
+	{
+		*errstr = libpq_gettext("out of memory");
+		goto error;
+	}
+
+	ctx = pg_hmac_create(state->hash_type);
 	if (ctx == NULL)
 	{
 		*errstr = pg_hmac_error(NULL);	/* returns OOM */
-		return false;
+		goto error;
 	}
 
 	/*
 	 * Calculate SaltedPassword, and store it in 'state' so that we can reuse
 	 * it later in verify_server_signature.
 	 */
-	if (scram_SaltedPassword(state->password, state->salt, state->saltlen,
-							 state->iterations, state->SaltedPassword,
+	if (scram_SaltedPassword(state->password, state->hash_type,
+							 state->key_length, state->salt, state->saltlen,
+							 state->iterations, state->salted_password,
 							 errstr) < 0 ||
-		scram_ClientKey(state->SaltedPassword, ClientKey, errstr) < 0 ||
-		scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey, errstr) < 0)
+		scram_ClientKey(state->salted_password, state->hash_type,
+						state->key_length, ClientKey, errstr) < 0 ||
+		scram_H(ClientKey, state->hash_type, state->key_length,
+				StoredKey, errstr) < 0)
 	{
 		/* errstr is already filled here */
-		pg_hmac_free(ctx);
-		return false;
+		goto error;
 	}
 
-	if (pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, StoredKey, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -800,18 +837,30 @@ calculate_client_proof(fe_scram_state *state,
 		pg_hmac_update(ctx,
 					   (uint8 *) client_final_message_without_proof,
 					   strlen(client_final_message_without_proof)) < 0 ||
-		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
+		pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
-		pg_hmac_free(ctx);
-		return false;
+		goto error;
 	}
 
-	for (i = 0; i < SCRAM_KEY_LEN; i++)
+	for (i = 0; i < state->key_length; i++)
 		result[i] = ClientKey[i] ^ ClientSignature[i];
 
+	free(StoredKey);
+	free(ClientKey);
+	free(ClientSignature);
 	pg_hmac_free(ctx);
 	return true;
+
+error:
+	if (StoredKey)
+		free(StoredKey);
+	if (ClientKey)
+		free(ClientKey);
+	if (ClientSignature)
+		free(ClientSignature);
+	pg_hmac_free(ctx);
+	return false;
 }
 
 /*
@@ -825,26 +874,35 @@ static bool
 verify_server_signature(fe_scram_state *state, bool *match,
 						const char **errstr)
 {
-	uint8		expected_ServerSignature[SCRAM_KEY_LEN];
-	uint8		ServerKey[SCRAM_KEY_LEN];
-	pg_hmac_ctx *ctx;
+	uint8	   *expected_ServerSignature = NULL;
+	uint8	   *ServerKey = NULL;
+	pg_hmac_ctx *ctx = NULL;
 
-	ctx = pg_hmac_create(PG_SHA256);
+	ServerKey = (uint8 *) malloc(state->key_length * sizeof(uint8));
+	expected_ServerSignature = (uint8 *) malloc(state->key_length * sizeof(uint8));
+
+	if (ServerKey == NULL || expected_ServerSignature == NULL)
+	{
+		*errstr = libpq_gettext("out of memory");
+		goto error;
+	}
+
+	ctx = pg_hmac_create(state->hash_type);
 	if (ctx == NULL)
 	{
 		*errstr = pg_hmac_error(NULL);	/* returns OOM */
-		return false;
+		goto error;
 	}
 
-	if (scram_ServerKey(state->SaltedPassword, ServerKey, errstr) < 0)
+	if (scram_ServerKey(state->salted_password, state->hash_type,
+						state->key_length, ServerKey, errstr) < 0)
 	{
 		/* errstr is filled already */
-		pg_hmac_free(ctx);
-		return false;
+		goto error;
 	}
 
 	/* calculate ServerSignature */
-	if (pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, ServerKey, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -857,22 +915,32 @@ verify_server_signature(fe_scram_state *state, bool *match,
 					   (uint8 *) state->client_final_message_without_proof,
 					   strlen(state->client_final_message_without_proof)) < 0 ||
 		pg_hmac_final(ctx, expected_ServerSignature,
-					  sizeof(expected_ServerSignature)) < 0)
+					  state->key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
-		pg_hmac_free(ctx);
-		return false;
+		goto error;
 	}
 
 	pg_hmac_free(ctx);
 
 	/* signature processed, so now check after it */
-	if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0)
+	if (memcmp(expected_ServerSignature, state->server_signature,
+			   state->key_length) != 0)
 		*match = false;
 	else
 		*match = true;
 
+	free(ServerKey);
+	free(expected_ServerSignature);
 	return true;
+
+error:
+	if (ServerKey)
+		free(ServerKey);
+	if (expected_ServerSignature)
+		free(expected_ServerSignature);
+	pg_hmac_free(ctx);
+	return false;
 }
 
 /*
@@ -912,7 +980,8 @@ pg_fe_scram_build_secret(const char *password, const char **errstr)
 		return NULL;
 	}
 
-	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
+	result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, saltbuf,
+								SCRAM_DEFAULT_SALT_LEN,
 								SCRAM_DEFAULT_ITERATIONS, password,
 								errstr);
 
-- 
2.38.1

