From 9c3e0c5d3457322b9fa4aa618ffdeb88978e9109 Mon Sep 17 00:00:00 2001
From: Greg Stark <stark@mit.edu>
Date: Mon, 20 Mar 2023 14:09:30 -0400
Subject: [PATCH v5 4/6] alpn support

---
 src/backend/libpq/be-secure-openssl.c    | 66 ++++++++++++++++++++++++
 src/backend/libpq/be-secure.c            |  3 ++
 src/backend/postmaster/postmaster.c      | 25 +++++++++
 src/backend/utils/misc/guc_tables.c      |  9 ++++
 src/bin/psql/command.c                   |  7 ++-
 src/include/libpq/libpq-be.h             |  1 +
 src/include/libpq/libpq.h                |  1 +
 src/include/libpq/pqcomm.h               | 19 +++++++
 src/interfaces/libpq/fe-connect.c        |  5 ++
 src/interfaces/libpq/fe-secure-openssl.c | 31 +++++++++++
 src/interfaces/libpq/libpq-int.h         |  1 +
 11 files changed, 166 insertions(+), 2 deletions(-)

diff --git a/src/backend/libpq/be-secure-openssl.c b/src/backend/libpq/be-secure-openssl.c
index 685aa2ed69..620ffafb0b 100644
--- a/src/backend/libpq/be-secure-openssl.c
+++ b/src/backend/libpq/be-secure-openssl.c
@@ -67,6 +67,12 @@ static int	ssl_external_passwd_cb(char *buf, int size, int rwflag, void *userdat
 static int	dummy_ssl_passwd_cb(char *buf, int size, int rwflag, void *userdata);
 static int	verify_cb(int ok, X509_STORE_CTX *ctx);
 static void info_cb(const SSL *ssl, int type, int args);
+static int  alpn_cb(SSL *ssl,
+					const unsigned char **out,
+					unsigned char *outlen,
+					const unsigned char *in,
+					unsigned int inlen,
+					void *userdata);
 static bool initialize_dh(SSL_CTX *context, bool isServerStart);
 static bool initialize_ecdh(SSL_CTX *context, bool isServerStart);
 static const char *SSLerrmessage(unsigned long ecode);
@@ -432,6 +438,13 @@ be_tls_open_server(Port *port)
 	/* set up debugging/info callback */
 	SSL_CTX_set_info_callback(SSL_context, info_cb);
 
+	if (ssl_enable_alpn) {
+		elog(DEBUG2, "Enabling OpenSSL ALPN callback");
+		SSL_CTX_set_alpn_select_cb(SSL_context, alpn_cb, port);
+	} else {
+		elog(DEBUG2, "OpenSSL ALPN is disabled, not setting callback");
+	}
+
 	if (!(port->ssl = SSL_new(SSL_context)))
 	{
 		ereport(COMMERROR,
@@ -692,6 +705,12 @@ be_tls_close(Port *port)
 		pfree(port->peer_dn);
 		port->peer_dn = NULL;
 	}
+
+	if (port->ssl_alpn_protocol)
+	{
+		pfree(port->ssl_alpn_protocol);
+		port->ssl_alpn_protocol = NULL;
+	}
 }
 
 ssize_t
@@ -1250,6 +1269,53 @@ info_cb(const SSL *ssl, int type, int args)
 	}
 }
 
+/* See pqcomm.h comments on OpenSSL implementation of ALPN (RFC 7301) */
+static unsigned char alpn_protos[] = PG_ALPN_PROTOCOL_VECTOR;
+
+/*
+ * Server callback for ALPN negotiation (RFC 7301).  "in"/"inlen" is the
+ * client's wire-format protocol list; on a match "out"/"outlen" point at the
+ * selected entry.  We use the standard helper function even though we
+ * currently accept only one value.  A match is copied, NUL-terminated, into
+ * Port->ssl_alpn_protocol (userdata is the Port); postmaster.c decides what
+ * to do with it.  Without a match we return NOACK and leave the field alone.
+ */
+static int  alpn_cb(SSL *ssl,
+					const unsigned char **out,
+					unsigned char *outlen,
+					const unsigned char *in,
+					unsigned int inlen,
+					void *userdata)
+{
+	struct Port *port = (struct Port *) userdata;
+	int			retval;
+
+	Assert(userdata != NULL);
+	Assert(out != NULL);
+	Assert(outlen != NULL);
+	Assert(in != NULL);
+
+	/* The helper wants a non-const "out"; the cast only drops that qualifier. */
+	retval = SSL_select_next_proto((unsigned char **) out, outlen,
+								   alpn_protos, sizeof(alpn_protos),
+								   in, inlen);
+	/* Sanity-check what the helper handed back before copying from it. */
+	if (*out == NULL || *outlen == 0 || *outlen > sizeof(alpn_protos))
+		return SSL_TLSEXT_ERR_NOACK;	/* can't happen */
+
+	/* No overlap, or any unexpected result: carry on without ALPN. */
+	if (retval != OPENSSL_NPN_NEGOTIATED)
+		return SSL_TLSEXT_ERR_NOACK;
+
+	/* Release a value left by an earlier invocation so it is not leaked. */
+	if (port->ssl_alpn_protocol)
+		pfree(port->ssl_alpn_protocol);
+	port->ssl_alpn_protocol = MemoryContextAllocZero(TopMemoryContext, *outlen + 1);
+	memcpy(port->ssl_alpn_protocol, *out, *outlen);
+	return SSL_TLSEXT_ERR_OK;
+}
+
+
 /*
  * Set DH parameters for generating ephemeral DH keys.  The
  * DH parameters can take a long time to compute, so they must be
diff --git a/src/backend/libpq/be-secure.c b/src/backend/libpq/be-secure.c
index 1020b3adb0..79a61900ba 100644
--- a/src/backend/libpq/be-secure.c
+++ b/src/backend/libpq/be-secure.c
@@ -61,6 +61,9 @@ bool		SSLPreferServerCiphers;
 int			ssl_min_protocol_version = PG_TLS1_2_VERSION;
 int			ssl_max_protocol_version = PG_TLS_ANY;
 
+/* GUC variable: if false ignore ALPN negotiation */
+bool		ssl_enable_alpn = true;
+
 /* ------------------------------------------------------------ */
 /*			 Procedures common to all secure sessions			*/
 /* ------------------------------------------------------------ */
diff --git a/src/backend/postmaster/postmaster.c b/src/backend/postmaster/postmaster.c
index ec1d895a23..9b4b37b997 100644
--- a/src/backend/postmaster/postmaster.c
+++ b/src/backend/postmaster/postmaster.c
@@ -1934,6 +1934,10 @@ ServerLoop(void)
  * any bytes from the stream if it's not a direct SSL connection.
  */
 
+#ifdef USE_SSL
+static const char *expected_alpn_protocol = PG_ALPN_PROTOCOL;
+#endif
+
 static int
 ProcessSSLStartup(Port *port)
 {
@@ -1970,6 +1974,10 @@ ProcessSSLStartup(Port *port)
 		char *buf = NULL;
 		elog(LOG, "Detected direct SSL handshake");
 
+		if (!ssl_enable_alpn) {
+			elog(WARNING, "Received direct SSL connection without ssl_enable_alpn enabled");
+		}
+
 		/* push unencrypted buffered data back through SSL setup */
 		len = pq_buffer_has_data();
 		if (len > 0)
@@ -2000,6 +2008,23 @@ ProcessSSLStartup(Port *port)
 			return STATUS_ERROR;
 		}
 		pfree(port->raw_buf);
+
+		if (port->ssl_alpn_protocol == NULL)
+		{
+			ereport(COMMERROR,
+					(errcode(ERRCODE_PROTOCOL_VIOLATION),
+					 errmsg("Received direct SSL connection request without required ALPN protocol negotiation extension")));
+			return STATUS_ERROR;
+		}
+		if (strcmp(port->ssl_alpn_protocol, expected_alpn_protocol) != 0)
+		{
+			ereport(COMMERROR,
+					(errcode(ERRCODE_PROTOCOL_VIOLATION),
+					 errmsg("Received direct SSL connection request with unexpected ALPN protocol \"%s\" expected \"%s\"",
+							port->ssl_alpn_protocol,
+							expected_alpn_protocol)));
+			return STATUS_ERROR;
+		}
 #else
 		ereport(COMMERROR,
 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
diff --git a/src/backend/utils/misc/guc_tables.c b/src/backend/utils/misc/guc_tables.c
index 8062589efd..79a1d0a100 100644
--- a/src/backend/utils/misc/guc_tables.c
+++ b/src/backend/utils/misc/guc_tables.c
@@ -1088,6 +1088,15 @@ struct config_bool ConfigureNamesBool[] =
 		true,
 		NULL, NULL, NULL
 	},
+	{
+		{"ssl_enable_alpn", PGC_SIGHUP, CONN_AUTH_SSL,
+			gettext_noop("Respond to TLS ALPN Extension Requests."),
+			NULL,
+		},
+		&ssl_enable_alpn,
+		true,
+		NULL, NULL, NULL
+	},
 	{
 		{"fsync", PGC_SIGHUP, WAL_SETTINGS,
 			gettext_noop("Forces synchronization of updates to disk."),
diff --git a/src/bin/psql/command.c b/src/bin/psql/command.c
index d7731234b6..816b336973 100644
--- a/src/bin/psql/command.c
+++ b/src/bin/psql/command.c
@@ -3715,6 +3715,7 @@ printSSLInfo(void)
 	const char *protocol;
 	const char *cipher;
 	const char *compression;
+	const char *alpn;
 
 	if (!PQsslInUse(pset.db))
 		return;					/* no SSL */
@@ -3722,11 +3723,13 @@ printSSLInfo(void)
 	protocol = PQsslAttribute(pset.db, "protocol");
 	cipher = PQsslAttribute(pset.db, "cipher");
 	compression = PQsslAttribute(pset.db, "compression");
+	alpn = PQsslAttribute(pset.db, "alpn");
 
-	printf(_("SSL connection (protocol: %s, cipher: %s, compression: %s)\n"),
+	printf(_("SSL connection (protocol: %s, cipher: %s, compression: %s, ALPN: %s)\n"),
 		   protocol ? protocol : _("unknown"),
 		   cipher ? cipher : _("unknown"),
-		   (compression && strcmp(compression, "off") != 0) ? _("on") : _("off"));
+		   (compression && strcmp(compression, "off") != 0) ? _("on") : _("off"),
+		   alpn ? alpn : _("none"));
 }
 
 /*
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 824b28e824..2258241770 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -217,6 +217,7 @@ typedef struct Port
 	char	   *peer_cn;
 	char	   *peer_dn;
 	bool		peer_cert_valid;
+	char       *ssl_alpn_protocol;
 
 	/*
 	 * OpenSSL structures. (Keep these last so that the locations of other
diff --git a/src/include/libpq/libpq.h b/src/include/libpq/libpq.h
index 2b02f67257..e7adf4a989 100644
--- a/src/include/libpq/libpq.h
+++ b/src/include/libpq/libpq.h
@@ -123,6 +123,7 @@ extern PGDLLIMPORT char *SSLECDHCurve;
 extern PGDLLIMPORT bool SSLPreferServerCiphers;
 extern PGDLLIMPORT int ssl_min_protocol_version;
 extern PGDLLIMPORT int ssl_max_protocol_version;
+extern PGDLLIMPORT bool ssl_enable_alpn;
 
 enum ssl_protocol_versions
 {
diff --git a/src/include/libpq/pqcomm.h b/src/include/libpq/pqcomm.h
index c85090259d..d9fc02adfc 100644
--- a/src/include/libpq/pqcomm.h
+++ b/src/include/libpq/pqcomm.h
@@ -152,6 +152,25 @@ typedef struct CancelRequestPacket
 	uint32		cancelAuthCode; /* secret key to authorize cancel */
 } CancelRequestPacket;
 
+/* Application-Layer Protocol Negotiation is required for direct connections
+ * to avoid protocol confusion attacks (e.g. https://alpaca-attack.com/).
+ *
+ * ALPN is specified in RFC 7301
+ *
+ * This string should be registered at:
+ * https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids
+ *
+ * OpenSSL uses the RFC 7301 wire format (each name preceded by a one-byte
+ * length) for the list of ALPN protocols even in its API. Client and server
+ * pass the same format: the client sends its list to the server as-is, while
+ * on the server the list gives the preference order used to pick the reply.
+ *
+ * cf. https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_alpn_select_cb.html
+ *
+ * PG_ALPN_PROTOCOL_VECTOR can initialize an unsigned char[] passed directly to the API
+ */
+#define PG_ALPN_PROTOCOL "TBD-pgsql"
+#define PG_ALPN_PROTOCOL_VECTOR { 9, 'T','B','D','-','p','g','s','q','l' }
 
 /*
  * A client can also start by sending a SSL or GSSAPI negotiation request to
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 7cd0eb261f..3118c19edd 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -314,6 +314,10 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
 		"SSL-SNI", "", 1,
 	offsetof(struct pg_conn, sslsni)},
 
+	{"sslalpn", "PGSSLALPN", "1", NULL,
+		"SSL-ALPN", "", 1,
+	offsetof(struct pg_conn, sslalpn)},
+
 	{"requirepeer", "PGREQUIREPEER", NULL, NULL,
 		"Require-Peer", "", 10,
 	offsetof(struct pg_conn, requirepeer)},
@@ -4487,6 +4491,7 @@ freePGconn(PGconn *conn)
 	free(conn->sslcrldir);
 	free(conn->sslcompression);
 	free(conn->sslsni);
+	free(conn->sslalpn);
 	free(conn->requirepeer);
 	free(conn->require_auth);
 	free(conn->ssl_min_protocol_version);
diff --git a/src/interfaces/libpq/fe-secure-openssl.c b/src/interfaces/libpq/fe-secure-openssl.c
index 4d1e4009ef..02af7bf571 100644
--- a/src/interfaces/libpq/fe-secure-openssl.c
+++ b/src/interfaces/libpq/fe-secure-openssl.c
@@ -912,6 +912,9 @@ destroy_ssl_system(void)
 #endif
 }
 
+/* See pqcomm.h comments on OpenSSL implementation of ALPN (RFC 7301) */
+static unsigned char alpn_protos[] = PG_ALPN_PROTOCOL_VECTOR;
+
 /*
  *	Create per-connection SSL object, and load the client certificate,
  *	private key, and trusted CA certs.
@@ -1240,6 +1243,20 @@ initialize_SSL(PGconn *conn)
 		}
 	}
 
+	if (conn->sslalpn && conn->sslalpn[0] == '1')
+	{
+		int retval;
+		retval = SSL_set_alpn_protos(conn->ssl, alpn_protos, sizeof(alpn_protos));
+
+		if (retval != 0)
+		{
+			char	   *err = SSLerrmessage(ERR_get_error());
+			libpq_append_conn_error(conn, "could not set ssl alpn extension: %s", err);
+			SSLerrfree(err);
+			return -1;
+		}
+	}
+
 	/*
 	 * Read the SSL key. If a key is specified, treat it as an engine:key
 	 * combination if there is colon present - we don't support files with
@@ -1723,6 +1740,7 @@ PQsslAttributeNames(PGconn *conn)
 		"cipher",
 		"compression",
 		"protocol",
+		"alpn",
 		NULL
 	};
 	static const char *const empty_attrs[] = {NULL};
@@ -1777,6 +1795,19 @@ PQsslAttribute(PGconn *conn, const char *attribute_name)
 	if (strcmp(attribute_name, "protocol") == 0)
 		return SSL_get_version(conn->ssl);
 
+	if (strcmp(attribute_name, "alpn") == 0)
+	{
+		const unsigned char *data;
+		unsigned int len;
+		static char alpn_str[256]; /* alpn doesn't support longer than 255 bytes */
+		SSL_get0_alpn_selected(conn->ssl, &data, &len);
+		if (data == NULL || len==0 || len > sizeof(alpn_str)-1)
+			return NULL;
+		memcpy(alpn_str, data, len);
+		alpn_str[len] = 0;
+		return alpn_str;
+	}
+
 	return NULL;				/* unknown attribute */
 }
 
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index db63afd786..17e5b2528f 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -400,6 +400,7 @@ struct pg_conn
 	char	   *sslcrl;			/* certificate revocation list filename */
 	char	   *sslcrldir;		/* certificate revocation list directory name */
 	char	   *sslsni;			/* use SSL SNI extension (0 or 1) */
+	char       *sslalpn;        /* use SSL ALPN extension (0 or 1) */
 	char	   *requirepeer;	/* required peer credentials for local sockets */
 	char	   *gssencmode;		/* GSS mode (require,prefer,disable) */
 	char	   *krbsrvname;		/* Kerberos service name */
-- 
2.40.0

