From 4508f872720a0977cf00041a865d76a4d5f77028 Mon Sep 17 00:00:00 2001
From: Greg Stark <stark@mit.edu>
Date: Wed, 18 Jan 2023 15:34:34 -0500
Subject: [PATCH v1] Direct SSL Connections

---
 src/backend/libpq/be-secure.c       |  13 +++
 src/backend/libpq/pqcomm.c          |   9 +-
 src/backend/postmaster/postmaster.c | 140 ++++++++++++++++++++++------
 src/include/libpq/libpq-be.h        |   3 +
 src/include/libpq/libpq.h           |   2 +-
 5 files changed, 134 insertions(+), 33 deletions(-)

diff --git a/src/backend/libpq/be-secure.c b/src/backend/libpq/be-secure.c
index a0f7084018..39366d04dd 100644
--- a/src/backend/libpq/be-secure.c
+++ b/src/backend/libpq/be-secure.c
@@ -235,6 +235,23 @@ secure_raw_read(Port *port, void *ptr, size_t len)
 {
 	ssize_t		n;
 
+	/*
+	 * Bytes that ProcessSSLStartup() already pulled off the connection while
+	 * detecting a direct SSL handshake are fed to SSL first, before anything
+	 * more is read from the socket.
+	 */
+	if (port->raw_buf_remaining > 0)
+	{
+		/* consume up to len bytes from the raw_buf */
+		if (len > (size_t) port->raw_buf_remaining)
+			len = (size_t) port->raw_buf_remaining;
+		Assert(port->raw_buf);
+		memcpy(ptr, port->raw_buf + port->raw_buf_consumed, len);
+		port->raw_buf_consumed += len;
+		port->raw_buf_remaining -= len;
+		return len;
+	}
+
 	/*
 	 * Try to read from the socket without blocking. If it succeeds we're
 	 * done, otherwise we'll wait for the socket using the latch mechanism.
diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c
index 864c9debe8..60fab6a52b 100644
--- a/src/backend/libpq/pqcomm.c
+++ b/src/backend/libpq/pqcomm.c
@@ -1119,13 +1119,17 @@ pq_discardbytes(size_t len)
 /* --------------------------------
  *		pq_buffer_has_data		- is any buffered data available to read?
  *
- * This will *not* attempt to read more data.
+ * Returns the number of bytes that have been received but not yet consumed
+ * (zero if none), so the result can still be tested as a boolean.
+ *
+ * This will *not* attempt to read more data, and reading up to that many
+ * bytes will not cause any more data to be read either.
  * --------------------------------
  */
-bool
+size_t
 pq_buffer_has_data(void)
 {
-	return (PqRecvPointer < PqRecvLength);
+	return (size_t) (PqRecvLength - PqRecvPointer);
 }
 
 
diff --git a/src/backend/postmaster/postmaster.c b/src/backend/postmaster/postmaster.c
index 9cedc1b9f0..b1631e0830 100644
--- a/src/backend/postmaster/postmaster.c
+++ b/src/backend/postmaster/postmaster.c
@@ -412,6 +412,7 @@ static void BackendRun(Port *port) pg_attribute_noreturn();
 static void ExitPostmaster(int status) pg_attribute_noreturn();
 static int	ServerLoop(void);
 static int	BackendStartup(Port *port);
+static int	ProcessSSLStartup(Port *port);
 static int	ProcessStartupPacket(Port *port, bool ssl_done, bool gss_done);
 static void SendNegotiateProtocolVersion(List *unrecognized_protocol_options);
 static void processCancelRequest(Port *port, void *pkt);
@@ -1909,6 +1910,128 @@ ServerLoop(void)
 	}
 }
 
+/*
+ * Check for a native direct SSL connection.
+ *
+ * This happens before the startup packet is read, so we are careful not to
+ * actually consume any bytes from the stream unless it really is a direct
+ * SSL connection: the first byte is only peeked at.
+ *
+ * Returns STATUS_OK if the connection should go on to ProcessStartupPacket()
+ * (with port->ssl_in_use set if a direct SSL session was established), or
+ * STATUS_ERROR if the connection should be closed.
+ */
+static int
+ProcessSSLStartup(Port *port)
+{
+	int			firstbyte;
+
+	pq_startmsgread();
+
+	firstbyte = pq_peekbyte();
+	if (firstbyte == EOF)
+	{
+		/*
+		 * If we get no data at all, don't clutter the log with a complaint;
+		 * such cases often occur for legitimate reasons.  Service-monitoring
+		 * software often just opens and closes a connection without sending
+		 * anything.  (So do port scanners, which may be less benign, but it's
+		 * not really our job to notice those.)
+		 */
+		return STATUS_ERROR;
+	}
+
+	/*
+	 * A first byte of 0x16 is the TLS record content type for a handshake
+	 * message, i.e. the start of a standard SSL handshake.
+	 *
+	 * (It can't be a Postgres startup length because in network byte order
+	 * that would be a startup packet hundreds of megabytes long)
+	 */
+	if (firstbyte == 0x16)
+	{
+#ifdef USE_SSL
+		ssize_t		len;
+		char	   *buf = NULL;
+
+		ereport(DEBUG2,
+				(errmsg_internal("Detected direct SSL handshake")));
+
+		/*
+		 * Apply the same policy as for NEGOTIATE_SSL_CODE: no SSL when it is
+		 * disabled or on Unix sockets.
+		 */
+		if (!LoadedSSL || port->laddr.addr.ss_family == AF_UNIX)
+		{
+			ereport(COMMERROR,
+					(errcode(ERRCODE_PROTOCOL_VIOLATION),
+					 errmsg("received direct SSL connection request, but SSL is not enabled on this socket")));
+			return STATUS_ERROR;
+		}
+
+		/*
+		 * pq_peekbyte() left the start of the handshake (possibly more than
+		 * just the first byte) in pqcomm's receive buffer.  Move all of it
+		 * into port->raw_buf so that secure_raw_read() replays it to the SSL
+		 * library before reading from the socket again.
+		 */
+		len = pq_buffer_has_data();
+		if (len > 0)
+		{
+			buf = palloc(len);
+			if (pq_getbytes(buf, len) == EOF)
+				return STATUS_ERROR;	/* shouldn't be possible */
+			port->raw_buf = buf;
+			port->raw_buf_remaining = len;
+			port->raw_buf_consumed = 0;
+		}
+
+		Assert(pq_buffer_has_data() == 0);
+		if (secure_open_server(port) == -1)
+		{
+			ereport(COMMERROR,
+					(errcode(ERRCODE_PROTOCOL_VIOLATION),
+					 errmsg("SSL Protocol Error during direct SSL connection initiation")));
+			return STATUS_ERROR;
+		}
+
+		if (port->raw_buf_remaining > 0)
+		{
+			/*
+			 * This shouldn't be possible -- it would mean the client sent
+			 * encrypted data before we established a session key...
+			 */
+			elog(LOG, "Buffered unencrypted data remains after negotiating native SSL connection");
+			return STATUS_ERROR;
+		}
+
+		/* The stashed handshake bytes have all been replayed; release them. */
+		if (port->raw_buf != NULL)
+		{
+			pfree(port->raw_buf);
+			port->raw_buf = NULL;
+		}
+#else
+		ereport(COMMERROR,
+				(errcode(ERRCODE_PROTOCOL_VIOLATION),
+				 errmsg("Received direct SSL connection request with no SSL support")));
+		return STATUS_ERROR;
+#endif
+	}
+
+	pq_endmsgread();
+
+	if (port->ssl_in_use)
+		ereport(DEBUG2,
+				(errmsg_internal("Direct native SSL connection set up")));
+	else
+		ereport(DEBUG2,
+				(errmsg_internal("Direct native SSL connection NOT set up")));
+
+	return STATUS_OK;
+}
+
+
 /*
  * Read a client's startup packet and do something according to it.
  *
@@ -1937,28 +2036,13 @@ ProcessStartupPacket(Port *port, bool ssl_done, bool gss_done)
 
 	pq_startmsgread();
 
-	/*
-	 * Grab the first byte of the length word separately, so that we can tell
-	 * whether we have no data at all or an incomplete packet.  (This might
-	 * sound inefficient, but it's not really, because of buffering in
-	 * pqcomm.c.)
-	 */
-	if (pq_getbytes((char *) &len, 1) == EOF)
-	{
-		/*
-		 * If we get no data at all, don't clutter the log with a complaint;
-		 * such cases often occur for legitimate reasons.  An example is that
-		 * we might be here after responding to NEGOTIATE_SSL_CODE, and if the
-		 * client didn't like our response, it'll probably just drop the
-		 * connection.  Service-monitoring software also often just opens and
-		 * closes a connection without sending anything.  (So do port
-		 * scanners, which may be less benign, but it's not really our job to
-		 * notice those.)
-		 */
-		return STATUS_ERROR;
-	}
-
-	if (pq_getbytes(((char *) &len) + 1, 3) == EOF)
+	/*
+	 * There is no separate first-byte read to detect a client that sent
+	 * nothing at all: when called from BackendInitialize(), the connection
+	 * has already been through ProcessSSLStartup(), which peeks at the first
+	 * byte and quietly gives up on a connection that sent no data.
+	 */
+	if (pq_getbytes(((char *) &len), 4) == EOF)
 	{
 		/* Got a partial length word, so bleat about that */
 		if (!ssl_done && !gss_done)
@@ -2015,8 +2093,12 @@ ProcessStartupPacket(Port *port, bool ssl_done, bool gss_done)
 		char		SSLok;
 
 #ifdef USE_SSL
-		/* No SSL when disabled or on Unix sockets */
-		if (!LoadedSSL || port->laddr.addr.ss_family == AF_UNIX)
+		/*
+		 * No SSL when disabled or on Unix sockets.  Also refuse SSL
+		 * negotiation if a direct SSL connection is already in use.
+		 */
+		if (!LoadedSSL || port->laddr.addr.ss_family == AF_UNIX ||
+			port->ssl_in_use)
 			SSLok = 'N';
 		else
 			SSLok = 'S';		/* Support for SSL */
@@ -2024,11 +2105,14 @@ ProcessStartupPacket(Port *port, bool ssl_done, bool gss_done)
 		SSLok = 'N';			/* No support for SSL */
 #endif
 
-retry1:
-		if (send(port->sock, &SSLok, 1, 0) != 1)
+		/*
+		 * Use secure_write() rather than send() so that, over a direct SSL
+		 * connection, the reply is sent inside the SSL session.
+		 */
+		while (secure_write(port, &SSLok, 1) != 1)
 		{
 			if (errno == EINTR)
-				goto retry1;	/* if interrupted, just retry */
+				continue;	/* if interrupted, just retry */
 			ereport(COMMERROR,
 					(errcode_for_socket_access(),
 					 errmsg("failed to send SSL negotiation response: %m")));
@@ -2069,7 +2149,14 @@ retry1:
 			GSSok = 'G';
 #endif
 
-		while (send(port->sock, &GSSok, 1, 0) != 1)
+		/*
+		 * Refuse GSSAPI encryption when a direct SSL connection is already
+		 * in use; the refusal itself goes out inside that SSL session.
+		 */
+		if (port->ssl_in_use)
+			GSSok = 'N';
+
+		while (secure_write(port, &GSSok, 1) != 1)
 		{
 			if (errno == EINTR)
 				continue;
@@ -4372,7 +4452,9 @@ BackendInitialize(Port *port)
 	 * Receive the startup packet (which might turn out to be a cancel request
 	 * packet).
 	 */
-	status = ProcessStartupPacket(port, false, false);
+	status = ProcessSSLStartup(port);	/* direct SSL handshake, if any */
+	if (status == STATUS_OK)
+		status = ProcessStartupPacket(port, false, false);
 
 	/*
 	 * Disable the timeout, and prevent SIGTERM again.
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 8c70b2fd5b..c2402ea8b8 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -226,6 +226,14 @@ typedef struct Port
 	SSL		   *ssl;
 	X509	   *peer;
 #endif
+	/*
+	 * Bytes already read from the client before a direct SSL handshake was
+	 * recognized (see ProcessSSLStartup()); secure_raw_read() returns these
+	 * before reading from the socket again.
+	 */
+	char	   *raw_buf;
+	ssize_t		raw_buf_consumed;	/* bytes of raw_buf already returned */
+	ssize_t		raw_buf_remaining;	/* bytes of raw_buf still to return */
 } Port;
 
 #ifdef USE_SSL
diff --git a/src/include/libpq/libpq.h b/src/include/libpq/libpq.h
index 50fc781f47..2b02f67257 100644
--- a/src/include/libpq/libpq.h
+++ b/src/include/libpq/libpq.h
@@ -80,7 +80,7 @@ extern int	pq_getmessage(StringInfo s, int maxlen);
 extern int	pq_getbyte(void);
 extern int	pq_peekbyte(void);
 extern int	pq_getbyte_if_available(unsigned char *c);
-extern bool pq_buffer_has_data(void);
+extern size_t pq_buffer_has_data(void);	/* count of unread buffered bytes */
 extern int	pq_putmessage_v2(char msgtype, const char *s, size_t len);
 extern bool pq_check_connection(void);
 
-- 
2.39.0

