From f8dc7c84c4db7625493d7061ac53acbbedd6da58 Mon Sep 17 00:00:00 2001
From: "okbob@github.com" <okbob@github.com>
Date: Wed, 12 Jun 2024 21:34:05 +0200
Subject: [PATCH 1/3] use strict rules for parsing PL/pgSQL expressions

The PLpgSQL_Expr rule allows almost all SQL clauses. It was originally designed
to allow the old undocumented syntax

    var := col FROM tab;

The reason for supporting this "strange" syntax was technical. The PL/pgSQL parser
could not use the SQL parser accurately (it was really primitive), and people found
this undocumented syntax. Later, when exact parsing became possible, the parsing of
PL/pgSQL expressions kept allowing the described syntax for compatibility reasons.

Unfortunately, because almost all SQL clauses are supported, PL/pgSQL can accept
really broken code like

    DO $$
    DECLARE
      l_cnt int;
    BEGIN
      l_cnt := 1
      DELETE FROM foo3 WHERE id=1;
    END; $$;

The proposed patch introduces a new extra check, strict_expr_check, that solves
this issue.
---
 doc/src/sgml/plpgsql.sgml             |  18 ++++
 src/pl/plpgsql/src/pl_comp.c          |   7 ++
 src/pl/plpgsql/src/pl_gram.y          | 138 ++++++++++++++++++++++----
 src/pl/plpgsql/src/pl_handler.c       |   2 +
 src/pl/plpgsql/src/plpgsql.h          |   1 +
 src/test/regress/expected/plpgsql.out |  14 +++
 src/test/regress/sql/plpgsql.sql      |  14 +++
 7 files changed, 176 insertions(+), 18 deletions(-)

diff --git a/doc/src/sgml/plpgsql.sgml b/doc/src/sgml/plpgsql.sgml
index 78e4983139b..a27700d6cb3 100644
--- a/doc/src/sgml/plpgsql.sgml
+++ b/doc/src/sgml/plpgsql.sgml
@@ -5386,6 +5386,24 @@ a_output := a_output || $$ if v_$$ || referrer_keys.kind || $$ like '$$
        </para>
       </listitem>
      </varlistentry>
+
+     <varlistentry id="plpgsql-extra-checks-strict-expr-check">
+      <term><varname>strict_expr_check</varname></term>
+      <listitem>
+       <para>
+        Enabling this check will cause <application>PL/pgSQL</application> to
+        check if a <application>PL/pgSQL</application> expression is just an
+        expression without any SQL clauses such as <literal>FROM</literal> or
+        <literal>ORDER BY</literal>. The undocumented form of expressions with
+        such clauses is allowed for compatibility reasons, but in some special
+        cases it prevents broken code from being detected.
+       </para>
+
+       <para>
+        This check is allowed only in <varname>plpgsql.extra_errors</varname>.
+       </para>
+      </listitem>
+     </varlistentry>
     </variablelist>
 
     The following example shows the effect of <varname>plpgsql.extra_warnings</varname>
diff --git a/src/pl/plpgsql/src/pl_comp.c b/src/pl/plpgsql/src/pl_comp.c
index a2de0880fb7..8174721a5a1 100644
--- a/src/pl/plpgsql/src/pl_comp.c
+++ b/src/pl/plpgsql/src/pl_comp.c
@@ -904,6 +904,13 @@ plpgsql_compile_inline(char *proc_source)
 	function->extra_warnings = 0;
 	function->extra_errors = 0;
 
+	/*
+	 * Although function->extra_errors is otherwise disabled, we want to
+	 * do strict_expr_check inside anonymous blocks too.
+	 */
+	if (plpgsql_extra_errors & PLPGSQL_XCHECK_STRICTEXPRCHECK)
+		function->extra_errors = PLPGSQL_XCHECK_STRICTEXPRCHECK;
+
 	function->nstatements = 0;
 	function->requires_procedure_resowner = false;
 
diff --git a/src/pl/plpgsql/src/pl_gram.y b/src/pl/plpgsql/src/pl_gram.y
index 64d2c362bf9..84a7fd4b762 100644
--- a/src/pl/plpgsql/src/pl_gram.y
+++ b/src/pl/plpgsql/src/pl_gram.y
@@ -18,6 +18,7 @@
 #include "catalog/namespace.h"
 #include "catalog/pg_proc.h"
 #include "catalog/pg_type.h"
+#include "nodes/nodeFuncs.h"
 #include "parser/parser.h"
 #include "parser/parse_type.h"
 #include "parser/scanner.h"
@@ -67,6 +68,7 @@ static	PLpgSQL_expr	*read_sql_construct(int until,
 											const char *expected,
 											RawParseMode parsemode,
 											bool isexpression,
+											bool allowlist,
 											bool valid_sql,
 											int *startloc,
 											int *endtoken,
@@ -102,7 +104,7 @@ static	PLpgSQL_row		*make_scalar_list1(char *initial_name,
 										   PLpgSQL_datum *initial_datum,
 										   int lineno, int location, yyscan_t yyscanner);
 static	void			 check_sql_expr(const char *stmt,
-										RawParseMode parseMode, int location, yyscan_t yyscanner);
+										RawParseMode parseMode, bool allowlist, int location, yyscan_t yyscanner);
 static	void			 plpgsql_sql_error_callback(void *arg);
 static	PLpgSQL_type	*parse_datatype(const char *string, int location, yyscan_t yyscanner);
 static	void			 check_labels(const char *start_label,
@@ -113,6 +115,7 @@ static	PLpgSQL_expr	*read_cursor_args(PLpgSQL_var *cursor, int until,
 										  YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner);
 static	List			*read_raise_options(YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner);
 static	void			check_raise_parameters(PLpgSQL_stmt_raise *stmt);
+static	bool			is_strict_expr(List *parsetree, int *errpos, bool allowlist);
 
 %}
 
@@ -189,6 +192,7 @@ static	void			check_raise_parameters(PLpgSQL_stmt_raise *stmt);
 %type <expr>	expr_until_semi
 %type <expr>	expr_until_then expr_until_loop opt_expr_until_when
 %type <expr>	opt_exitcond
+%type <expr>	expressions_until_then
 
 %type <var>		cursor_variable
 %type <datum>	decl_cursor_arg
@@ -906,7 +910,7 @@ stmt_perform	: K_PERFORM
 						 */
 						new->expr = read_sql_construct(';', 0, 0, ";",
 													   RAW_PARSE_DEFAULT,
-													   false, false,
+													   false, false, false,
 													   &startloc, NULL,
 													   &yylval, &yylloc, yyscanner);
 						/* overwrite "perform" ... */
@@ -916,7 +920,7 @@ stmt_perform	: K_PERFORM
 								strlen(new->expr->query));
 						/* offset syntax error position to account for that */
 						check_sql_expr(new->expr->query, new->expr->parseMode,
-									   startloc + 1, yyscanner);
+									   false, startloc + 1, yyscanner);
 
 						$$ = (PLpgSQL_stmt *) new;
 					}
@@ -993,7 +997,7 @@ stmt_assign		: T_DATUM
 						plpgsql_push_back_token(T_DATUM, &yylval, &yylloc, yyscanner);
 						new->expr = read_sql_construct(';', 0, 0, ";",
 													   pmode,
-													   false, true,
+													   false, false, true,
 													   NULL, NULL,
 													   &yylval, &yylloc, yyscanner);
 
@@ -1253,7 +1257,7 @@ case_when_list	: case_when_list case_when
 					}
 				;
 
-case_when		: K_WHEN expr_until_then proc_sect
+case_when		: K_WHEN expressions_until_then proc_sect
 					{
 						PLpgSQL_case_when *new = palloc(sizeof(PLpgSQL_case_when));
 
@@ -1283,6 +1287,15 @@ opt_case_else	:
 					}
 				;
 
+expressions_until_then :	/* reads up to THEN; more than one expression is accepted */
+					{
+						$$ = read_sql_construct(K_THEN, 0, 0, "THEN",
+												RAW_PARSE_PLPGSQL_EXPR, /* expr_list */
+												true, true, true, NULL, NULL,	/* isexpression, allowlist, valid_sql */
+												&yylval, &yylloc, yyscanner);
+					}
+				;
+
 stmt_loop		: opt_loop_label K_LOOP loop_body
 					{
 						PLpgSQL_stmt_loop *new;
@@ -1486,6 +1499,7 @@ for_control		: for_variable K_IN
 													   RAW_PARSE_DEFAULT,
 													   true,
 													   false,
+													   false,
 													   &expr1loc,
 													   &tok,
 													   &yylval, &yylloc, yyscanner);
@@ -1504,7 +1518,7 @@ for_control		: for_variable K_IN
 								 */
 								expr1->parseMode = RAW_PARSE_PLPGSQL_EXPR;
 								check_sql_expr(expr1->query, expr1->parseMode,
-											   expr1loc, yyscanner);
+											   false, expr1loc, yyscanner);
 
 								/* Read and check the second one */
 								expr2 = read_sql_expression2(K_LOOP, K_BY,
@@ -1561,7 +1575,7 @@ for_control		: for_variable K_IN
 
 								/* Check syntax as a regular query */
 								check_sql_expr(expr1->query, expr1->parseMode,
-											   expr1loc, yyscanner);
+											   false, expr1loc, yyscanner);
 
 								new = palloc0(sizeof(PLpgSQL_stmt_fors));
 								new->cmd_type = PLPGSQL_STMT_FORS;
@@ -1893,7 +1907,7 @@ stmt_raise		: K_RAISE
 									expr = read_sql_construct(',', ';', K_USING,
 															  ", or ; or USING",
 															  RAW_PARSE_PLPGSQL_EXPR,
-															  true, true,
+															  true, false, true,
 															  NULL, &tok,
 															  &yylval, &yylloc, yyscanner);
 									new->params = lappend(new->params, expr);
@@ -2031,7 +2045,7 @@ stmt_dynexecute : K_EXECUTE
 						expr = read_sql_construct(K_INTO, K_USING, ';',
 												  "INTO or USING or ;",
 												  RAW_PARSE_PLPGSQL_EXPR,
-												  true, true,
+												  true, false, true,
 												  NULL, &endtoken,
 												  &yylval, &yylloc, yyscanner);
 
@@ -2071,7 +2085,7 @@ stmt_dynexecute : K_EXECUTE
 									expr = read_sql_construct(',', ';', K_INTO,
 															  ", or ; or INTO",
 															  RAW_PARSE_PLPGSQL_EXPR,
-															  true, true,
+															  true, false, true,
 															  NULL, &endtoken,
 															  &yylval, &yylloc, yyscanner);
 									new->params = lappend(new->params, expr);
@@ -2657,7 +2671,7 @@ read_sql_expression(int until, const char *expected, YYSTYPE *yylvalp, YYLTYPE *
 {
 	return read_sql_construct(until, 0, 0, expected,
 							  RAW_PARSE_PLPGSQL_EXPR,
-							  true, true, NULL, NULL,
+							  true, false, true, NULL, NULL,
 							  yylvalp, yyllocp, yyscanner);
 }
 
@@ -2668,7 +2682,7 @@ read_sql_expression2(int until, int until2, const char *expected,
 {
 	return read_sql_construct(until, until2, 0, expected,
 							  RAW_PARSE_PLPGSQL_EXPR,
-							  true, true, NULL, endtoken,
+							  true, false, true, NULL, endtoken,
 							  yylvalp, yyllocp, yyscanner);
 }
 
@@ -2678,7 +2692,7 @@ read_sql_stmt(YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner)
 {
 	return read_sql_construct(';', 0, 0, ";",
 							  RAW_PARSE_DEFAULT,
-							  false, true, NULL, NULL,
+							  false, false, true, NULL, NULL,
 							  yylvalp, yyllocp, yyscanner);
 }
 
@@ -2691,6 +2705,7 @@ read_sql_stmt(YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner)
  * expected:	text to use in complaining that terminator was not found
  * parsemode:	raw_parser() mode to use
  * isexpression: whether to say we're reading an "expression" or a "statement"
+ * allowlist:   whether the result can be a list of expressions
  * valid_sql:   whether to check the syntax of the expr
  * startloc:	if not NULL, location of first token is stored at *startloc
  * endtoken:	if not NULL, ending token is stored at *endtoken
@@ -2703,6 +2718,7 @@ read_sql_construct(int until,
 				   const char *expected,
 				   RawParseMode parsemode,
 				   bool isexpression,
+				   bool allowlist,
 				   bool valid_sql,
 				   int *startloc,
 				   int *endtoken,
@@ -2804,7 +2820,7 @@ read_sql_construct(int until,
 	pfree(ds.data);
 
 	if (valid_sql)
-		check_sql_expr(expr->query, expr->parseMode, startlocation, yyscanner);
+		check_sql_expr(expr->query, expr->parseMode, allowlist, startlocation, yyscanner);
 
 	return expr;
 }
@@ -3131,7 +3147,7 @@ make_execsql_stmt(int firsttoken, int location, PLword *word, YYSTYPE *yylvalp,
 	expr->ns = plpgsql_ns_top();
 	pfree(ds.data);
 
-	check_sql_expr(expr->query, expr->parseMode, location, yyscanner);
+	check_sql_expr(expr->query, expr->parseMode, false, location, yyscanner);
 
 	execsql = palloc0(sizeof(PLpgSQL_stmt_execsql));
 	execsql->cmd_type = PLPGSQL_STMT_EXECSQL;
@@ -3731,11 +3747,15 @@ make_scalar_list1(char *initial_name,
  * If no error cursor is provided, we'll just point at "location".
  */
 static void
-check_sql_expr(const char *stmt, RawParseMode parseMode, int location, yyscan_t yyscanner)
+check_sql_expr(const char *stmt,
+			   RawParseMode parseMode, bool allowlist,
+			   int location, yyscan_t yyscanner)
 {
 	sql_error_callback_arg cbarg;
 	ErrorContextCallback syntax_errcontext;
 	MemoryContext oldCxt;
+	List   *parsetree;
+	int		errpos;
 
 	if (!plpgsql_check_syntax)
 		return;
@@ -3749,11 +3769,25 @@ check_sql_expr(const char *stmt, RawParseMode parseMode, int location, yyscan_t
 	error_context_stack = &syntax_errcontext;
 
 	oldCxt = MemoryContextSwitchTo(plpgsql_compile_tmp_cxt);
-	(void) raw_parser(stmt, parseMode);
+	parsetree = raw_parser(stmt, parseMode);
 	MemoryContextSwitchTo(oldCxt);
 
 	/* Restore former ereport callback */
 	error_context_stack = syntax_errcontext.previous;
+
+	if (plpgsql_curr_compile->extra_warnings & PLPGSQL_XCHECK_STRICTEXPRCHECK ||
+		plpgsql_curr_compile->extra_errors & PLPGSQL_XCHECK_STRICTEXPRCHECK)
+	{
+		/* do this check only for expressions */
+		if (parseMode == RAW_PARSE_DEFAULT)
+			return;
+
+		if (!is_strict_expr(parsetree, &errpos, allowlist))
+			ereport(plpgsql_curr_compile->extra_errors & PLPGSQL_XCHECK_STRICTEXPRCHECK ? ERROR : WARNING,
+					(errcode(ERRCODE_SYNTAX_ERROR),
+					 errmsg("syntax of expression is not strict"),
+					 parser_errposition(errpos != -1 ? location + errpos : location)));
+	}
 }
 
 static void
@@ -3787,6 +3821,74 @@ plpgsql_sql_error_callback(void *arg)
 	errposition(0);
 }
 
+/*
+ * Returns true when the parsed expression is just a target list: none of the
+ * clauses tested below is set, and several targets appear only if allowlist.
+ */
+static bool
+is_strict_expr(List *parsetree, int *errpos, bool allowlist)
+{
+	RawStmt *rawstmt;
+	SelectStmt *select;
+	int		targets = 0;
+	ListCell *lc;
+
+	/* Only the first RawStmt of the parse tree is inspected */
+	rawstmt = castNode(RawStmt, linitial(parsetree));
+
+	if (IsA(rawstmt->stmt, SelectStmt))
+	{
+		select = (SelectStmt *) rawstmt->stmt;
+	}
+	else if (IsA(rawstmt->stmt, PLAssignStmt))
+	{
+		select = castNode(SelectStmt, ((PLAssignStmt *) rawstmt->stmt)->val);
+	}
+	else
+		elog(ERROR, "unexpected node type");
+
+	if (!select->targetList)
+	{
+		*errpos = -1;		/* no target whose location could be reported */
+		return false;
+	}
+	else
+		*errpos = exprLocation((Node *) select->targetList);
+
+	if (select->distinctClause ||
+		select->fromClause ||
+		select->whereClause ||
+		select->groupClause ||
+		select->groupDistinct ||
+		select->havingClause ||
+		select->windowClause ||
+		select->sortClause ||
+		select->limitOffset ||
+		select->limitCount ||
+		select->limitOption ||
+		select->lockingClause)
+		return false;		/* some clause beyond the target list is present */
+
+	foreach(lc, select->targetList)
+	{
+		ResTarget *rt = castNode(ResTarget, lfirst(lc));
+
+		if (targets++ >= 1 && !allowlist)	/* second or later target */
+		{
+			*errpos = exprLocation((Node *) rt);
+			return false;
+		}
+
+		if (rt->name)		/* target carries a name; rejected even with allowlist */
+		{
+			*errpos = exprLocation((Node *) rt);
+			return false;
+		}
+	}
+
+	return true;
+}
+
 /*
  * Parse a SQL datatype name and produce a PLpgSQL_type structure.
  *
@@ -3967,7 +4069,7 @@ read_cursor_args(PLpgSQL_var *cursor, int until, YYSTYPE *yylvalp, YYLTYPE *yyll
 		item = read_sql_construct(',', ')', 0,
 								  ",\" or \")",
 								  RAW_PARSE_PLPGSQL_EXPR,
-								  true, true,
+								  true, false, true,
 								  NULL, &endtoken,
 								  yylvalp, yyllocp, yyscanner);
 
diff --git a/src/pl/plpgsql/src/pl_handler.c b/src/pl/plpgsql/src/pl_handler.c
index 5af38d5773b..3ce196de58f 100644
--- a/src/pl/plpgsql/src/pl_handler.c
+++ b/src/pl/plpgsql/src/pl_handler.c
@@ -94,6 +94,8 @@ plpgsql_extra_checks_check_hook(char **newvalue, void **extra, GucSource source)
 				extrachecks |= PLPGSQL_XCHECK_TOOMANYROWS;
 			else if (pg_strcasecmp(tok, "strict_multi_assignment") == 0)
 				extrachecks |= PLPGSQL_XCHECK_STRICTMULTIASSIGNMENT;
+			else if (pg_strcasecmp(tok, "strict_expr_check") == 0)
+				extrachecks |= PLPGSQL_XCHECK_STRICTEXPRCHECK;
 			else if (pg_strcasecmp(tok, "all") == 0 || pg_strcasecmp(tok, "none") == 0)
 			{
 				GUC_check_errdetail("Key word \"%s\" cannot be combined with other key words.", tok);
diff --git a/src/pl/plpgsql/src/plpgsql.h b/src/pl/plpgsql/src/plpgsql.h
index 441df5354e2..1cb2e80210e 100644
--- a/src/pl/plpgsql/src/plpgsql.h
+++ b/src/pl/plpgsql/src/plpgsql.h
@@ -1204,6 +1204,7 @@ extern bool plpgsql_check_asserts;
 #define PLPGSQL_XCHECK_SHADOWVAR				(1 << 1)
 #define PLPGSQL_XCHECK_TOOMANYROWS				(1 << 2)
 #define PLPGSQL_XCHECK_STRICTMULTIASSIGNMENT	(1 << 3)
+#define PLPGSQL_XCHECK_STRICTEXPRCHECK			(1 << 4)
 #define PLPGSQL_XCHECK_ALL						((int) ~0)
 
 extern int	plpgsql_extra_warnings;
diff --git a/src/test/regress/expected/plpgsql.out b/src/test/regress/expected/plpgsql.out
index 0a6945581bd..4e21bd714c7 100644
--- a/src/test/regress/expected/plpgsql.out
+++ b/src/test/regress/expected/plpgsql.out
@@ -3083,6 +3083,20 @@ select shadowtest(1);
  t
 (1 row)
 
+-- test of strict expression check
+set plpgsql.extra_errors to 'strict_expr_check';
+create or replace function strict_expr_check_func()
+returns void as $$
+declare var int;
+begin
+  var = 1
+  delete from pg_class where false;
+end;
+$$ language plpgsql;
+ERROR:  syntax of expression is not strict
+LINE 5:   var = 1
+                ^
+reset plpgsql.extra_errors;
 -- runtime extra checks
 set plpgsql.extra_warnings to 'too_many_rows';
 do $$
diff --git a/src/test/regress/sql/plpgsql.sql b/src/test/regress/sql/plpgsql.sql
index 18c91572ae1..a5d40bd95c7 100644
--- a/src/test/regress/sql/plpgsql.sql
+++ b/src/test/regress/sql/plpgsql.sql
@@ -2617,6 +2617,20 @@ declare f1 int; begin return 1; end $$ language plpgsql;
 
 select shadowtest(1);
 
+-- test of strict expression check
+set plpgsql.extra_errors to 'strict_expr_check';
+
+create or replace function strict_expr_check_func()
+returns void as $$
+declare var int;
+begin
+  var = 1
+  delete from pg_class where false;
+end;
+$$ language plpgsql;
+
+reset plpgsql.extra_errors;
+
 -- runtime extra checks
 set plpgsql.extra_warnings to 'too_many_rows';
 
-- 
2.48.1

