diff --git a/src/backend/replication/logical/worker.c b/src/backend/replication/logical/worker.c
index 6ba447ea97..9e9b47ce4f 100644
--- a/src/backend/replication/logical/worker.c
+++ b/src/backend/replication/logical/worker.c
@@ -282,30 +282,41 @@ should_apply_changes_for_rel(LogicalRepRelMapEntry *rel)
 }
 
 /*
- * Make sure that we started local transaction.
+ * Begin one step (one INSERT, UPDATE, etc.) of a replication transaction.
  *
- * Also switches to ApplyMessageContext as necessary.
+ * Starts a transaction if none is in progress (otherwise the existing one
+ * is reused), pushes the transaction snapshot as the active snapshot, and
+ * switches to ApplyMessageContext.  Pair with end_transaction_step().
  */
-static bool
-ensure_transaction(void)
+static void
+begin_transaction_step(void)
 {
-	if (IsTransactionState())
-	{
-		SetCurrentStatementStartTimestamp();
-
-		if (CurrentMemoryContext != ApplyMessageContext)
-			MemoryContextSwitchTo(ApplyMessageContext);
+	SetCurrentStatementStartTimestamp();
 
-		return false;
+	if (!IsTransactionState())
+	{
+		StartTransactionCommand();
+		maybe_reread_subscription();
 	}
 
-	SetCurrentStatementStartTimestamp();
-	StartTransactionCommand();
-
-	maybe_reread_subscription();
+	PushActiveSnapshot(GetTransactionSnapshot());
 
 	MemoryContextSwitchTo(ApplyMessageContext);
-	return true;
+}
+
+/*
+ * Finish up one step of a replication transaction: pop the snapshot pushed
+ * by begin_transaction_step(), whose callers must always call this too.
+ *
+ * We don't close out the transaction here, but we should increment
+ * the command counter to make the effects of this step visible.
+ */
+static void
+end_transaction_step(void)
+{
+	PopActiveSnapshot();
+
+	CommandCounterIncrement();
 }
 
 /*
@@ -359,13 +370,6 @@ create_edata_for_relation(LogicalRepRelMapEntry *rel)
 	RangeTblEntry *rte;
 	ResultRelInfo *resultRelInfo;
 
-	/*
-	 * Input functions may need an active snapshot, as may AFTER triggers
-	 * invoked during finish_edata.  For safety, ensure an active snapshot
-	 * exists throughout all our usage of the executor.
-	 */
-	PushActiveSnapshot(GetTransactionSnapshot());
-
 	edata = (ApplyExecutionData *) palloc0(sizeof(ApplyExecutionData));
 	edata->targetRel = rel;
 
@@ -433,8 +437,6 @@ finish_edata(ApplyExecutionData *edata)
 	ExecResetTupleTable(estate->es_tupleTable, false);
 	FreeExecutorState(estate);
 	pfree(edata);
-
-	PopActiveSnapshot();
 }
 
 /*
@@ -831,7 +833,7 @@ apply_handle_stream_start(StringInfo s)
 	 * transaction for handling the buffile, used for serializing the
 	 * streaming data and subxact info.
 	 */
-	ensure_transaction();
+	begin_transaction_step();
 
 	/* notify handle methods we're processing a remote transaction */
 	in_streamed_transaction = true;
@@ -861,6 +863,8 @@ apply_handle_stream_start(StringInfo s)
 		subxact_info_read(MyLogicalRepWorker->subid, stream_xid);
 
 	pgstat_report_activity(STATE_RUNNING, NULL);
+
+	end_transaction_step();
 }
 
 /*
@@ -937,7 +941,7 @@ apply_handle_stream_abort(StringInfo s)
 		StreamXidHash *ent;
 
 		subidx = -1;
-		ensure_transaction();
+		begin_transaction_step();
 		subxact_info_read(MyLogicalRepWorker->subid, xid);
 
 		for (i = subxact_data.nsubxacts; i > 0; i--)
@@ -958,7 +962,7 @@ apply_handle_stream_abort(StringInfo s)
 		{
 			/* Cleanup the subxact info */
 			cleanup_subxact_info();
-
+			end_transaction_step();
 			CommitTransactionCommand();
 			return;
 		}
@@ -986,6 +990,7 @@ apply_handle_stream_abort(StringInfo s)
 		/* write the updated subxact list */
 		subxact_info_write(MyLogicalRepWorker->subid, xid);
 
+		end_transaction_step();
 		CommitTransactionCommand();
 	}
 }
@@ -1013,7 +1018,8 @@ apply_handle_stream_commit(StringInfo s)
 
 	elog(DEBUG1, "received commit for streamed transaction %u", xid);
 
-	ensure_transaction();
+	/* Make sure we have an open transaction */
+	begin_transaction_step();
 
 	/*
 	 * Allocate file handle and memory required to process all the messages in
@@ -1046,6 +1052,8 @@ apply_handle_stream_commit(StringInfo s)
 	in_remote_transaction = true;
 	pgstat_report_activity(STATE_RUNNING, NULL);
 
+	end_transaction_step();
+
 	/*
 	 * Read the entries one by one and pass them through the same logic as in
 	 * apply_dispatch.
@@ -1227,7 +1235,7 @@ apply_handle_insert(StringInfo s)
 	if (handle_streamed_transaction(LOGICAL_REP_MSG_INSERT, s))
 		return;
 
-	ensure_transaction();
+	begin_transaction_step();
 
 	relid = logicalrep_read_insert(s, &newtup);
 	rel = logicalrep_rel_open(relid, RowExclusiveLock);
@@ -1238,6 +1246,7 @@ apply_handle_insert(StringInfo s)
 		 * transaction so it's safe to unlock it.
 		 */
 		logicalrep_rel_close(rel, RowExclusiveLock);
+		end_transaction_step();
 		return;
 	}
 
@@ -1266,7 +1275,7 @@ apply_handle_insert(StringInfo s)
 
 	logicalrep_rel_close(rel, NoLock);
 
-	CommandCounterIncrement();
+	end_transaction_step();
 }
 
 /*
@@ -1346,7 +1355,7 @@ apply_handle_update(StringInfo s)
 	if (handle_streamed_transaction(LOGICAL_REP_MSG_UPDATE, s))
 		return;
 
-	ensure_transaction();
+	begin_transaction_step();
 
 	relid = logicalrep_read_update(s, &has_oldtup, &oldtup,
 								   &newtup);
@@ -1358,6 +1367,7 @@ apply_handle_update(StringInfo s)
 		 * transaction so it's safe to unlock it.
 		 */
 		logicalrep_rel_close(rel, RowExclusiveLock);
+		end_transaction_step();
 		return;
 	}
 
@@ -1416,7 +1426,7 @@ apply_handle_update(StringInfo s)
 
 	logicalrep_rel_close(rel, NoLock);
 
-	CommandCounterIncrement();
+	end_transaction_step();
 }
 
 /*
@@ -1501,7 +1511,7 @@ apply_handle_delete(StringInfo s)
 	if (handle_streamed_transaction(LOGICAL_REP_MSG_DELETE, s))
 		return;
 
-	ensure_transaction();
+	begin_transaction_step();
 
 	relid = logicalrep_read_delete(s, &oldtup);
 	rel = logicalrep_rel_open(relid, RowExclusiveLock);
@@ -1512,6 +1522,7 @@ apply_handle_delete(StringInfo s)
 		 * transaction so it's safe to unlock it.
 		 */
 		logicalrep_rel_close(rel, RowExclusiveLock);
+		end_transaction_step();
 		return;
 	}
 
@@ -1542,7 +1553,7 @@ apply_handle_delete(StringInfo s)
 
 	logicalrep_rel_close(rel, NoLock);
 
-	CommandCounterIncrement();
+	end_transaction_step();
 }
 
 /*
@@ -1867,7 +1878,7 @@ apply_handle_truncate(StringInfo s)
 	if (handle_streamed_transaction(LOGICAL_REP_MSG_TRUNCATE, s))
 		return;
 
-	ensure_transaction();
+	begin_transaction_step();
 
 	remote_relids = logicalrep_read_truncate(s, &cascade, &restart_seqs);
 
@@ -1958,7 +1969,7 @@ apply_handle_truncate(StringInfo s)
 		table_close(rel, NoLock);
 	}
 
-	CommandCounterIncrement();
+	end_transaction_step();
 }
 
 
