From ecd692fbefa352ebe5b9698153f47102a93b7e16 Mon Sep 17 00:00:00 2001
From: Ashutosh Bapat <ashutosh.bapat@2ndquadrant.com>
Date: Thu, 22 Oct 2020 11:52:24 +0530
Subject: [PATCH 2/2] Functions to send and receive LogicalRepMsgType

Add wrappers around pq_sendbyte() and pq_getmsgbyte() that send and
receive a logical replication message type, respectively.

This also removes the default case from apply_dispatch(), so that the
compiler can warn about any LogicalRepMsgType that function does not handle.

Ashutosh Bapat
---
 src/backend/replication/logical/proto.c  | 37 +++++++++++++--------
 src/backend/replication/logical/worker.c | 41 ++++++++++++++++++++----
 2 files changed, 59 insertions(+), 19 deletions(-)

diff --git a/src/backend/replication/logical/proto.c b/src/backend/replication/logical/proto.c
index fdb31182d7..8c57ff03ec 100644
--- a/src/backend/replication/logical/proto.c
+++ b/src/backend/replication/logical/proto.c
@@ -38,13 +38,24 @@ static void logicalrep_read_tuple(StringInfo in, LogicalRepTupleData *tuple);
 static void logicalrep_write_namespace(StringInfo out, Oid nspid);
 static const char *logicalrep_read_namespace(StringInfo in);
 
+/*
+ * Append "msgtype" to "out" as one protocol byte, using pq_sendbyte().
+ * Every LogicalRepMsgType value is expected to fit in a single char.
+ */
+static void
+pq_send_logicalrep_msg_type(StringInfo out, LogicalRepMsgType msgtype)
+{
+	Assert(msgtype == (char) msgtype);
+	pq_sendbyte(out, (char) msgtype);
+}
+
 /*
  * Write BEGIN to the output stream.
  */
 void
 logicalrep_write_begin(StringInfo out, ReorderBufferTXN *txn)
 {
-	pq_sendbyte(out, LOGICAL_REP_MSG_BEGIN);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_BEGIN);
 
 	/* fixed fields */
 	pq_sendint64(out, txn->final_lsn);
@@ -76,7 +87,7 @@ logicalrep_write_commit(StringInfo out, ReorderBufferTXN *txn,
 {
 	uint8		flags = 0;
 
-	pq_sendbyte(out, LOGICAL_REP_MSG_COMMIT);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_COMMIT);
 
 	/* send the flags field (unused for now) */
 	pq_sendbyte(out, flags);
@@ -112,7 +123,7 @@ void
 logicalrep_write_origin(StringInfo out, const char *origin,
 						XLogRecPtr origin_lsn)
 {
-	pq_sendbyte(out, LOGICAL_REP_MSG_ORIGIN);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_ORIGIN);
 
 	/* fixed fields */
 	pq_sendint64(out, origin_lsn);
@@ -141,7 +152,7 @@ void
 logicalrep_write_insert(StringInfo out, TransactionId xid, Relation rel,
 						HeapTuple newtuple, bool binary)
 {
-	pq_sendbyte(out, LOGICAL_REP_MSG_INSERT);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_INSERT);
 
 	/* transaction ID (if not valid, we're not streaming) */
 	if (TransactionIdIsValid(xid))
@@ -185,7 +196,7 @@ void
 logicalrep_write_update(StringInfo out, TransactionId xid, Relation rel,
 						HeapTuple oldtuple, HeapTuple newtuple, bool binary)
 {
-	pq_sendbyte(out, LOGICAL_REP_MSG_UPDATE);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_UPDATE);
 
 	Assert(rel->rd_rel->relreplident == REPLICA_IDENTITY_DEFAULT ||
 		   rel->rd_rel->relreplident == REPLICA_IDENTITY_FULL ||
@@ -263,7 +274,7 @@ logicalrep_write_delete(StringInfo out, TransactionId xid, Relation rel,
 		   rel->rd_rel->relreplident == REPLICA_IDENTITY_FULL ||
 		   rel->rd_rel->relreplident == REPLICA_IDENTITY_INDEX);
 
-	pq_sendbyte(out, LOGICAL_REP_MSG_DELETE);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_DELETE);
 
 	/* transaction ID (if not valid, we're not streaming) */
 	if (TransactionIdIsValid(xid))
@@ -317,7 +328,7 @@ logicalrep_write_truncate(StringInfo out,
 	int			i;
 	uint8		flags = 0;
 
-	pq_sendbyte(out, LOGICAL_REP_MSG_TRUNCATE);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_TRUNCATE);
 
 	/* transaction ID (if not valid, we're not streaming) */
 	if (TransactionIdIsValid(xid))
@@ -369,7 +380,7 @@ logicalrep_write_rel(StringInfo out, TransactionId xid, Relation rel)
 {
 	char	   *relname;
 
-	pq_sendbyte(out, LOGICAL_REP_MSG_RELATION);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_RELATION);
 
 	/* transaction ID (if not valid, we're not streaming) */
 	if (TransactionIdIsValid(xid))
@@ -425,7 +436,7 @@ logicalrep_write_typ(StringInfo out, TransactionId xid, Oid typoid)
 	HeapTuple	tup;
 	Form_pg_type typtup;
 
-	pq_sendbyte(out, LOGICAL_REP_MSG_TYPE);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_TYPE);
 
 	/* transaction ID (if not valid, we're not streaming) */
 	if (TransactionIdIsValid(xid))
@@ -755,7 +766,7 @@ void
 logicalrep_write_stream_start(StringInfo out,
 							  TransactionId xid, bool first_segment)
 {
-	pq_sendbyte(out, LOGICAL_REP_MSG_STREAM_START);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_STREAM_START);
 
 	Assert(TransactionIdIsValid(xid));
 
@@ -788,7 +799,7 @@ logicalrep_read_stream_start(StringInfo in, bool *first_segment)
 void
 logicalrep_write_stream_stop(StringInfo out)
 {
-	pq_sendbyte(out, LOGICAL_REP_MSG_STREAM_END);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_STREAM_END);
 }
 
 /*
@@ -800,7 +811,7 @@ logicalrep_write_stream_commit(StringInfo out, ReorderBufferTXN *txn,
 {
 	uint8		flags = 0;
 
-	pq_sendbyte(out, LOGICAL_REP_MSG_STREAM_COMMIT);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_STREAM_COMMIT);
 
 	Assert(TransactionIdIsValid(txn->xid));
 
@@ -849,7 +860,7 @@ void
 logicalrep_write_stream_abort(StringInfo out, TransactionId xid,
 							  TransactionId subxid)
 {
-	pq_sendbyte(out, LOGICAL_REP_MSG_STREAM_ABORT);
+	pq_send_logicalrep_msg_type(out, LOGICAL_REP_MSG_STREAM_ABORT);
 
 	Assert(TransactionIdIsValid(xid) && TransactionIdIsValid(subxid));
 
diff --git a/src/backend/replication/logical/worker.c b/src/backend/replication/logical/worker.c
index ec21cc55e5..f516c056a1 100644
--- a/src/backend/replication/logical/worker.c
+++ b/src/backend/replication/logical/worker.c
@@ -1890,6 +1890,40 @@ apply_handle_truncate(StringInfo s)
 	CommandCounterIncrement();
 }
 
+/*
+ * Read a message type byte from "s"; ERROR out unless it is a known type.
+ */
+static LogicalRepMsgType
+pq_get_logicalrep_msg_type(StringInfo s)
+{
+	int			msgbyte = pq_getmsgbyte(s);	/* raw byte, not yet validated */
+
+	switch ((LogicalRepMsgType) msgbyte)
+	{
+		case LOGICAL_REP_MSG_BEGIN:
+		case LOGICAL_REP_MSG_COMMIT:
+		case LOGICAL_REP_MSG_INSERT:
+		case LOGICAL_REP_MSG_UPDATE:
+		case LOGICAL_REP_MSG_DELETE:
+		case LOGICAL_REP_MSG_TRUNCATE:
+		case LOGICAL_REP_MSG_RELATION:
+		case LOGICAL_REP_MSG_TYPE:
+		case LOGICAL_REP_MSG_ORIGIN:
+		case LOGICAL_REP_MSG_STREAM_START:
+		case LOGICAL_REP_MSG_STREAM_END:
+		case LOGICAL_REP_MSG_STREAM_ABORT:
+		case LOGICAL_REP_MSG_STREAM_COMMIT:
+			return (LogicalRepMsgType) msgbyte;
+	}
+
+	/* Report the byte numerically: it may be NUL or non-printable. */
+	ereport(ERROR,
+			(errcode(ERRCODE_PROTOCOL_VIOLATION),
+			 errmsg("invalid logical replication message type \"??? (%d)\"",
+					msgbyte)));
+
+	return (LogicalRepMsgType) msgbyte;	/* keep compiler quiet */
+}
 
 /*
  * Logical replication protocol message dispatcher.
@@ -1897,7 +1931,7 @@ apply_handle_truncate(StringInfo s)
 static void
 apply_dispatch(StringInfo s)
 {
-	LogicalRepMsgType action = pq_getmsgbyte(s);
+	LogicalRepMsgType action = pq_get_logicalrep_msg_type(s);
 
 	switch (action)
 	{
@@ -1952,11 +1986,6 @@ apply_dispatch(StringInfo s)
 		case LOGICAL_REP_MSG_STREAM_COMMIT:
 			apply_handle_stream_commit(s);
 			break;
-
-		default:
-			ereport(ERROR,
-					(errcode(ERRCODE_PROTOCOL_VIOLATION),
-					 errmsg("invalid logical replication message type \"%c\"", action)));
 	}
 }
 
-- 
2.17.1

