From 944779537c256179747cd1cf77a11c8a88cf57db Mon Sep 17 00:00:00 2001
From: Tom Lane <tgl@sss.pgh.pa.us>
Date: Wed, 15 Jan 2025 12:39:21 -0500
Subject: [PATCH v3 2/4] Detect whether plpgsql assignment targets are "local"
 variables.

Mark whether the target of a potentially optimizable assignment
is "local", in the sense of being declared inside the innermost
exception block that could trap an error thrown from the assignment.
(This implies that we needn't preserve the variable's value
in case of an error.)

Normally, this requires a post-parsing scan of the function's
parse tree, since we don't know while parsing a BEGIN ...
construct whether we will find EXCEPTION at its end.  However,
if there are no BEGIN ... EXCEPTION blocks in the function at
all, then all assignments are local, even those to variables
representing function arguments.  We optimize that common case
by initializing the target_is_local flags to "true", and fixing
them up with a post-scan only if we found EXCEPTION.

The scan is implemented by code that's largely copied-and-pasted
from the nearby code to scan a plpgsql parse tree for deletion.
It's a bit annoying to have three copies of that now, but I'm
not seeing a way to refactor it that would save much code on net.

Note that variables' default-value expressions are never interesting
for expanded-variable optimization, since they couldn't contain a
reference to the target variable.  But the code is set up to compute
their target_param and target_is_local correctly anyway, for
consistency and in case someone thinks of a use for that data.

I added a bit of plpgsql_dumptree support to help verify that
this code sets the flags as expected.  I'm not set on keeping
that, but I do want to keep the addition of a plpgsql_dumptree
call in plpgsql_compile_inline.  It's at best an oversight that
"#option dump" doesn't work in a DO block.

Discussion: https://postgr.es/m/CACxu=vJaKFNsYxooSnW1wEgsAO5u_v1XYBacfVJ14wgJV_PYeg@mail.gmail.com
---
 src/pl/plpgsql/src/pl_comp.c  |  12 +
 src/pl/plpgsql/src/pl_funcs.c | 398 ++++++++++++++++++++++++++++++++++
 src/pl/plpgsql/src/pl_gram.y  |  15 ++
 src/pl/plpgsql/src/plpgsql.h  |   7 +-
 4 files changed, 431 insertions(+), 1 deletion(-)

diff --git a/src/pl/plpgsql/src/pl_comp.c b/src/pl/plpgsql/src/pl_comp.c
index 9dc8218292..56b899693b 100644
--- a/src/pl/plpgsql/src/pl_comp.c
+++ b/src/pl/plpgsql/src/pl_comp.c
@@ -373,6 +373,7 @@ do_compile(FunctionCallInfo fcinfo,
 
 	function->nstatements = 0;
 	function->requires_procedure_resowner = false;
+	function->has_exception_block = false;
 
 	/*
 	 * Initialize the compiler, particularly the namespace stack.  The
@@ -814,6 +815,9 @@ do_compile(FunctionCallInfo fcinfo,
 
 	plpgsql_finish_datums(function);
 
+	if (function->has_exception_block)
+		plpgsql_mark_local_assignment_targets(function);
+
 	/* Debug dump for completed functions */
 	if (plpgsql_DumpExecTree)
 		plpgsql_dumptree(function);
@@ -909,6 +913,7 @@ plpgsql_compile_inline(char *proc_source)
 
 	function->nstatements = 0;
 	function->requires_procedure_resowner = false;
+	function->has_exception_block = false;
 
 	plpgsql_ns_init();
 	plpgsql_ns_push(func_name, PLPGSQL_LABEL_BLOCK);
@@ -966,6 +971,13 @@ plpgsql_compile_inline(char *proc_source)
 
 	plpgsql_finish_datums(function);
 
+	if (function->has_exception_block)
+		plpgsql_mark_local_assignment_targets(function);
+
+	/* Debug dump for completed functions */
+	if (plpgsql_DumpExecTree)
+		plpgsql_dumptree(function);
+
 	/*
 	 * Pop the error context stack
 	 */
diff --git a/src/pl/plpgsql/src/pl_funcs.c b/src/pl/plpgsql/src/pl_funcs.c
index 8c827fe5cc..549e5d9292 100644
--- a/src/pl/plpgsql/src/pl_funcs.c
+++ b/src/pl/plpgsql/src/pl_funcs.c
@@ -333,6 +333,401 @@ plpgsql_getdiag_kindname(PLpgSQL_getdiag_kind kind)
 }
 
 
+/**********************************************************************
+ * Mark assignment source expressions that have local target variables,
+ * that is, variables declared within the exception block most closely
+ * containing the assignment itself.  (Such target variables need not be
+ * preserved if the assignment's source expression raises an error,
+ * allowing better optimization.)
+ *
+ * This code need not be called if the plpgsql function contains no exception
+ * blocks, because expr_is_assignment_source() will already have set the flag
+ * to true for every plain-VAR assignment target.  Also, we need not examine
+ * variables' default-value expressions, because variable declarations are
+ * necessarily within the nearest exception block.  (In DECLARE ... BEGIN ...
+ * EXCEPTION ... END, the variable initializations are done before entering
+ * the exception scope.)  So it's sufficient to find assignment statements.
+ *
+ * Within the recursion, local_dnos is a Bitmapset of dnos of variables known
+ * to be declared within the current exception level (NULL if there are none).
+ **********************************************************************/
+static void mark_stmt(PLpgSQL_stmt *stmt, Bitmapset *local_dnos);
+static void mark_block(PLpgSQL_stmt_block *block, Bitmapset *local_dnos);
+static void mark_assign(PLpgSQL_stmt_assign *stmt, Bitmapset *local_dnos);
+static void mark_if(PLpgSQL_stmt_if *stmt, Bitmapset *local_dnos);
+static void mark_case(PLpgSQL_stmt_case *stmt, Bitmapset *local_dnos);
+static void mark_loop(PLpgSQL_stmt_loop *stmt, Bitmapset *local_dnos);
+static void mark_while(PLpgSQL_stmt_while *stmt, Bitmapset *local_dnos);
+static void mark_fori(PLpgSQL_stmt_fori *stmt, Bitmapset *local_dnos);
+static void mark_fors(PLpgSQL_stmt_fors *stmt, Bitmapset *local_dnos);
+static void mark_forc(PLpgSQL_stmt_forc *stmt, Bitmapset *local_dnos);
+static void mark_foreach_a(PLpgSQL_stmt_foreach_a *stmt, Bitmapset *local_dnos);
+static void mark_exit(PLpgSQL_stmt_exit *stmt, Bitmapset *local_dnos);
+static void mark_return(PLpgSQL_stmt_return *stmt, Bitmapset *local_dnos);
+static void mark_return_next(PLpgSQL_stmt_return_next *stmt, Bitmapset *local_dnos);
+static void mark_return_query(PLpgSQL_stmt_return_query *stmt, Bitmapset *local_dnos);
+static void mark_raise(PLpgSQL_stmt_raise *stmt, Bitmapset *local_dnos);
+static void mark_assert(PLpgSQL_stmt_assert *stmt, Bitmapset *local_dnos);
+static void mark_execsql(PLpgSQL_stmt_execsql *stmt, Bitmapset *local_dnos);
+static void mark_dynexecute(PLpgSQL_stmt_dynexecute *stmt, Bitmapset *local_dnos);
+static void mark_dynfors(PLpgSQL_stmt_dynfors *stmt, Bitmapset *local_dnos);
+static void mark_getdiag(PLpgSQL_stmt_getdiag *stmt, Bitmapset *local_dnos);
+static void mark_open(PLpgSQL_stmt_open *stmt, Bitmapset *local_dnos);
+static void mark_fetch(PLpgSQL_stmt_fetch *stmt, Bitmapset *local_dnos);
+static void mark_close(PLpgSQL_stmt_close *stmt, Bitmapset *local_dnos);
+static void mark_perform(PLpgSQL_stmt_perform *stmt, Bitmapset *local_dnos);
+static void mark_call(PLpgSQL_stmt_call *stmt, Bitmapset *local_dnos);
+static void mark_commit(PLpgSQL_stmt_commit *stmt, Bitmapset *local_dnos);
+static void mark_rollback(PLpgSQL_stmt_rollback *stmt, Bitmapset *local_dnos);
+
+/* Dispatch a statement to the mark_xxx() routine for its cmd_type */
+static void
+mark_stmt(PLpgSQL_stmt *stmt, Bitmapset *local_dnos)
+{
+	switch (stmt->cmd_type)
+	{
+		case PLPGSQL_STMT_BLOCK:
+			mark_block((PLpgSQL_stmt_block *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_ASSIGN:
+			mark_assign((PLpgSQL_stmt_assign *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_IF:
+			mark_if((PLpgSQL_stmt_if *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_CASE:
+			mark_case((PLpgSQL_stmt_case *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_LOOP:
+			mark_loop((PLpgSQL_stmt_loop *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_WHILE:
+			mark_while((PLpgSQL_stmt_while *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_FORI:
+			mark_fori((PLpgSQL_stmt_fori *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_FORS:
+			mark_fors((PLpgSQL_stmt_fors *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_FORC:
+			mark_forc((PLpgSQL_stmt_forc *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_FOREACH_A:
+			mark_foreach_a((PLpgSQL_stmt_foreach_a *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_EXIT:
+			mark_exit((PLpgSQL_stmt_exit *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_RETURN:
+			mark_return((PLpgSQL_stmt_return *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_RETURN_NEXT:
+			mark_return_next((PLpgSQL_stmt_return_next *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_RETURN_QUERY:
+			mark_return_query((PLpgSQL_stmt_return_query *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_RAISE:
+			mark_raise((PLpgSQL_stmt_raise *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_ASSERT:
+			mark_assert((PLpgSQL_stmt_assert *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_EXECSQL:
+			mark_execsql((PLpgSQL_stmt_execsql *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_DYNEXECUTE:
+			mark_dynexecute((PLpgSQL_stmt_dynexecute *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_DYNFORS:
+			mark_dynfors((PLpgSQL_stmt_dynfors *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_GETDIAG:
+			mark_getdiag((PLpgSQL_stmt_getdiag *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_OPEN:
+			mark_open((PLpgSQL_stmt_open *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_FETCH:
+			mark_fetch((PLpgSQL_stmt_fetch *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_CLOSE:
+			mark_close((PLpgSQL_stmt_close *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_PERFORM:
+			mark_perform((PLpgSQL_stmt_perform *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_CALL:
+			mark_call((PLpgSQL_stmt_call *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_COMMIT:
+			mark_commit((PLpgSQL_stmt_commit *) stmt, local_dnos);
+			break;
+		case PLPGSQL_STMT_ROLLBACK:
+			mark_rollback((PLpgSQL_stmt_rollback *) stmt, local_dnos);
+			break;
+		default:
+			elog(ERROR, "unrecognized cmd_type: %d", stmt->cmd_type);
+			break;
+	}
+}
+
+/* Apply mark_stmt() to each member of a statement list, in order */
+static void
+mark_stmts(List *stmts, Bitmapset *local_dnos)
+{
+	/* every statement in the list sees the same local_dnos */
+	foreach_ptr(PLpgSQL_stmt, stmt, stmts)
+	{
+		mark_stmt(stmt, local_dnos);
+	}
+}
+
+static void
+mark_block(PLpgSQL_stmt_block *block, Bitmapset *local_dnos)
+{
+	if (block->exceptions)
+	{
+		ListCell   *e;
+
+		/*
+		 * The block creates a new exception scope, so variables declared at
+		 * outer levels or in the block's DECLARE section are nonlocal within
+		 * its body.  Handler actions are outside that scope (errors thrown
+		 * there propagate outward), so they can use the caller's local_dnos.
+		 */
+		mark_stmts(block->body, NULL);
+
+		foreach(e, block->exceptions->exc_list)
+		{
+			PLpgSQL_exception *exc = (PLpgSQL_exception *) lfirst(e);
+
+			mark_stmts(exc->action, local_dnos);
+		}
+	}
+	else
+	{
+		/*
+		 * Otherwise, the block does not create a new exception scope, and any
+		 * variables it declares can also be considered local within it.  Note
+		 * that only initializable datum types (VAR, REC) are included in
+		 * initvarnos; but that's sufficient for our purposes.
+		 */
+		local_dnos = bms_copy(local_dnos);
+		for (int i = 0; i < block->n_initvars; i++)
+			local_dnos = bms_add_member(local_dnos, block->initvarnos[i]);
+		mark_stmts(block->body, local_dnos);
+		bms_free(local_dnos);
+	}
+}
+
+static void
+mark_assign(PLpgSQL_stmt_assign *stmt, Bitmapset *local_dnos)
+{
+	PLpgSQL_expr *expr = stmt->expr;
+
+	/* Only a plain DTYPE_VAR target (target_param >= 0) is of interest */
+	if (expr->target_param < 0)
+		return;
+
+	/* It's local if declared within the current exception level */
+	expr->target_is_local = bms_is_member(expr->target_param, local_dnos);
+}
+
+static void
+mark_if(PLpgSQL_stmt_if *stmt, Bitmapset *local_dnos)
+{
+	ListCell   *l;
+
+	/* stmt->cond cannot be an assignment source */
+	mark_stmts(stmt->then_body, local_dnos);
+	foreach(l, stmt->elsif_list)
+	{
+		PLpgSQL_if_elsif *elif = (PLpgSQL_if_elsif *) lfirst(l);
+
+		/* elif->cond cannot be an assignment source */
+		mark_stmts(elif->stmts, local_dnos);
+	}
+	mark_stmts(stmt->else_body, local_dnos);
+}
+
+static void
+mark_case(PLpgSQL_stmt_case *stmt, Bitmapset *local_dnos)
+{
+	ListCell   *l;
+
+	/* stmt->t_expr cannot be an assignment source */
+	foreach(l, stmt->case_when_list)
+	{
+		PLpgSQL_case_when *cwt = (PLpgSQL_case_when *) lfirst(l);
+
+		/* cwt->expr cannot be an assignment source */
+		mark_stmts(cwt->stmts, local_dnos);
+	}
+	mark_stmts(stmt->else_stmts, local_dnos);
+}
+
+static void
+mark_loop(PLpgSQL_stmt_loop *stmt, Bitmapset *local_dnos)
+{
+	mark_stmts(stmt->body, local_dnos);
+}
+
+static void
+mark_while(PLpgSQL_stmt_while *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->cond cannot be an assignment source */
+	mark_stmts(stmt->body, local_dnos);
+}
+
+static void
+mark_fori(PLpgSQL_stmt_fori *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->lower, upper, step cannot be assignment sources */
+	mark_stmts(stmt->body, local_dnos);
+}
+
+static void
+mark_fors(PLpgSQL_stmt_fors *stmt, Bitmapset *local_dnos)
+{
+	mark_stmts(stmt->body, local_dnos);
+	/* stmt->query cannot be an assignment source */
+}
+
+static void
+mark_forc(PLpgSQL_stmt_forc *stmt, Bitmapset *local_dnos)
+{
+	mark_stmts(stmt->body, local_dnos);
+	/* stmt->argquery cannot be an assignment source */
+}
+
+static void
+mark_foreach_a(PLpgSQL_stmt_foreach_a *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->expr cannot be an assignment source */
+	mark_stmts(stmt->body, local_dnos);
+}
+
+static void
+mark_open(PLpgSQL_stmt_open *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->argquery, query, dynquery cannot be assignment sources */
+	/* stmt->params cannot contain an assignment source */
+}
+
+static void
+mark_fetch(PLpgSQL_stmt_fetch *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->expr cannot be an assignment source */
+}
+
+static void
+mark_close(PLpgSQL_stmt_close *stmt, Bitmapset *local_dnos)
+{
+}
+
+static void
+mark_perform(PLpgSQL_stmt_perform *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->expr cannot be an assignment source */
+}
+
+static void
+mark_call(PLpgSQL_stmt_call *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->expr cannot be an assignment source */
+}
+
+static void
+mark_commit(PLpgSQL_stmt_commit *stmt, Bitmapset *local_dnos)
+{
+}
+
+static void
+mark_rollback(PLpgSQL_stmt_rollback *stmt, Bitmapset *local_dnos)
+{
+}
+
+static void
+mark_exit(PLpgSQL_stmt_exit *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->cond cannot be an assignment source */
+}
+
+static void
+mark_return(PLpgSQL_stmt_return *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->expr cannot be an assignment source */
+}
+
+static void
+mark_return_next(PLpgSQL_stmt_return_next *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->expr cannot be an assignment source */
+}
+
+static void
+mark_return_query(PLpgSQL_stmt_return_query *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->query, dynquery cannot be assignment sources */
+	/* stmt->params cannot contain an assignment source */
+}
+
+static void
+mark_raise(PLpgSQL_stmt_raise *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->params cannot contain an assignment source */
+	/* stmt->options cannot contain an assignment source */
+}
+
+static void
+mark_assert(PLpgSQL_stmt_assert *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->cond, message cannot be assignment sources */
+}
+
+static void
+mark_execsql(PLpgSQL_stmt_execsql *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->sqlstmt cannot be an assignment source */
+}
+
+static void
+mark_dynexecute(PLpgSQL_stmt_dynexecute *stmt, Bitmapset *local_dnos)
+{
+	/* stmt->query cannot be an assignment source */
+	/* stmt->params cannot contain an assignment source */
+}
+
+static void
+mark_dynfors(PLpgSQL_stmt_dynfors *stmt, Bitmapset *local_dnos)
+{
+	mark_stmts(stmt->body, local_dnos);
+	/* stmt->query cannot be an assignment source */
+	/* stmt->params cannot contain an assignment source */
+}
+
+static void
+mark_getdiag(PLpgSQL_stmt_getdiag *stmt, Bitmapset *local_dnos)
+{
+}
+
+void
+plpgsql_mark_local_assignment_targets(PLpgSQL_function *func)
+{
+	Bitmapset  *local_dnos = NULL;
+
+	/* At the outermost level, the function's parameters count as local */
+	for (int argno = 0; argno < func->fn_nargs; argno++)
+		local_dnos = bms_add_member(local_dnos, func->fn_argvarnos[argno]);
+
+	if (func->action != NULL)
+		mark_block(func->action, local_dnos);
+	bms_free(local_dnos);
+}
+
+
 /**********************************************************************
  * Release memory when a PL/pgSQL function is no longer needed
  *
@@ -1594,6 +1989,9 @@ static void
 dump_expr(PLpgSQL_expr *expr)
 {
 	printf("'%s'", expr->query);
+	if (expr->target_param >= 0)
+		printf(" target %d%s", expr->target_param,
+			   expr->target_is_local ? " (local)" : "");
 }
 
 void
diff --git a/src/pl/plpgsql/src/pl_gram.y b/src/pl/plpgsql/src/pl_gram.y
index 7ff6b663e3..2426ca4a04 100644
--- a/src/pl/plpgsql/src/pl_gram.y
+++ b/src/pl/plpgsql/src/pl_gram.y
@@ -2327,6 +2327,8 @@ exception_sect	:
 						PLpgSQL_exception_block *new = palloc(sizeof(PLpgSQL_exception_block));
 						PLpgSQL_variable *var;
 
+						plpgsql_curr_compile->has_exception_block = true;
+
 						var = plpgsql_build_variable("sqlstate", lineno,
 													 plpgsql_build_datatype(TEXTOID,
 																			-1,
@@ -2672,6 +2674,7 @@ make_plpgsql_expr(const char *query,
 	expr->ns = plpgsql_ns_top();
 	/* might get changed later during parsing: */
 	expr->target_param = -1;
+	expr->target_is_local = false;
 	/* other fields are left as zeroes until first execution */
 	return expr;
 }
@@ -2686,9 +2689,21 @@ expr_is_assignment_source(PLpgSQL_expr *expr, PLpgSQL_datum *target)
 	 * other DTYPEs.
 	 */
 	if (target->dtype == PLPGSQL_DTYPE_VAR)
+	{
 		expr->target_param = target->dno;
+
+		/*
+		 * For now, assume the target is local to the nearest enclosing
+		 * exception block.  That's correct if the function contains no
+		 * exception blocks; otherwise we'll update this later.
+		 */
+		expr->target_is_local = true;
+	}
 	else
+	{
 		expr->target_param = -1;	/* should be that already */
+		expr->target_is_local = false; /* ditto */
+	}
 }
 
 /* Convenience routine to read an expression with one possible terminator */
diff --git a/src/pl/plpgsql/src/plpgsql.h b/src/pl/plpgsql/src/plpgsql.h
index 67fdfb3141..762af78a5e 100644
--- a/src/pl/plpgsql/src/plpgsql.h
+++ b/src/pl/plpgsql/src/plpgsql.h
@@ -225,9 +225,12 @@ typedef struct PLpgSQL_expr
 	/*
 	 * These fields are used to help optimize assignments to expanded-datum
 	 * variables.  If this expression is the source of an assignment to a
-	 * simple variable, target_param holds that variable's dno (else it's -1).
+	 * simple variable, target_param holds that variable's dno (else it's -1),
+	 * and target_is_local indicates whether the target is declared inside the
+	 * closest exception block containing the assignment.
 	 */
 	int			target_param;	/* dno of assign target, or -1 if none */
+	bool		target_is_local;	/* is it within nearest exception block? */
 
 	/*
 	 * Fields above are set during plpgsql parsing.  Remaining fields are left
@@ -1014,6 +1017,7 @@ typedef struct PLpgSQL_function
 	/* data derived while parsing body */
 	unsigned int nstatements;	/* counter for assigning stmtids */
 	bool		requires_procedure_resowner;	/* contains CALL or DO? */
+	bool		has_exception_block;	/* contains BEGIN...EXCEPTION? */
 
 	/* these fields change when the function is used */
 	struct PLpgSQL_execstate *cur_estate;
@@ -1314,6 +1318,7 @@ extern PLpgSQL_nsitem *plpgsql_ns_find_nearest_loop(PLpgSQL_nsitem *ns_cur);
  */
 extern PGDLLEXPORT const char *plpgsql_stmt_typename(PLpgSQL_stmt *stmt);
 extern const char *plpgsql_getdiag_kindname(PLpgSQL_getdiag_kind kind);
+extern void plpgsql_mark_local_assignment_targets(PLpgSQL_function *func);
 extern void plpgsql_free_function_memory(PLpgSQL_function *func);
 extern void plpgsql_dumptree(PLpgSQL_function *func);
 
-- 
2.43.5

