From f587f0bc67995f4f9919d35180e2281efc7ddd4e Mon Sep 17 00:00:00 2001
From: Dean Rasheed <dean.a.rasheed@gmail.com>
Date: Sat, 28 Sep 2024 17:20:38 +0100
Subject: [PATCH v2 1/2] Fix incorrect non-strict join recheck in MERGE WHEN
 NOT MATCHED BY SOURCE.

If a MERGE command contains WHEN NOT MATCHED BY SOURCE actions, the
merge join condition is used by the executor to distinguish MATCHED
from NOT MATCHED BY SOURCE cases. However, this qual is evaluated using
the output of the join subplan node, which nulls the output from the
source relation in the NOT MATCHED BY SOURCE case, and so the result
may be incorrect if the join condition is "non-strict" -- for example,
something like "src.col IS NOT DISTINCT FROM tgt.col", which evaluates
to true when tgt.col is NULL and the source row has been nulled, causing
the row to be wrongly treated as MATCHED.

Fix this by enhancing the join condition with an additional "src IS
NOT NULL" check, so that it does the right thing when evaluated using
the output of the join subplan.
---
 src/backend/optimizer/prep/prepjointree.c | 66 +++++++++++++++++++++--
 src/test/regress/expected/merge.out       | 22 ++++++++
 src/test/regress/sql/merge.sql            | 18 +++++++
 3 files changed, 101 insertions(+), 5 deletions(-)

diff --git a/src/backend/optimizer/prep/prepjointree.c b/src/backend/optimizer/prep/prepjointree.c
index a70404558f..ca15994fdc 100644
--- a/src/backend/optimizer/prep/prepjointree.c
+++ b/src/backend/optimizer/prep/prepjointree.c
@@ -158,6 +158,9 @@ transform_MERGE_to_join(Query *parse)
 	int			joinrti;
 	List	   *vars;
 	RangeTblRef *rtr;
+	FromExpr   *target;
+	Node	   *source;
+	int			sourcerti;
 
 	if (parse->commandType != CMD_MERGE)
 		return;
@@ -226,13 +229,36 @@ transform_MERGE_to_join(Query *parse)
 	 * parse->jointree->quals are restrictions on the target relation (if the
 	 * target relation is an auto-updatable view).
 	 */
+	/* target rel, with any quals */
 	rtr = makeNode(RangeTblRef);
 	rtr->rtindex = parse->mergeTargetRelation;
+	target = makeFromExpr(list_make1(rtr), parse->jointree->quals);
+
+	/* source rel (expect exactly one -- see transformMergeStmt()) */
+	Assert(list_length(parse->jointree->fromlist) == 1);
+	source = linitial(parse->jointree->fromlist);
+
+	/*
+	 * index of source rel (expect either a RangeTblRef or a JoinExpr -- see
+	 * transformFromClauseItem()).
+	 */
+	if (IsA(source, RangeTblRef))
+		sourcerti = ((RangeTblRef *) source)->rtindex;
+	else if (IsA(source, JoinExpr))
+		sourcerti = ((JoinExpr *) source)->rtindex;
+	else
+	{
+		elog(ERROR, "unrecognized source node type: %d",
+			 (int) nodeTag(source));
+		sourcerti = 0;			/* keep compiler quiet */
+	}
+
+	/* Join the source and target */
 	joinexpr = makeNode(JoinExpr);
 	joinexpr->jointype = jointype;
 	joinexpr->isNatural = false;
-	joinexpr->larg = (Node *) makeFromExpr(list_make1(rtr), parse->jointree->quals);
-	joinexpr->rarg = linitial(parse->jointree->fromlist);	/* source rel */
+	joinexpr->larg = (Node *) target;
+	joinexpr->rarg = source;
 	joinexpr->usingClause = NIL;
 	joinexpr->join_using_alias = NULL;
 	joinexpr->quals = parse->mergeJoinCondition;
@@ -261,9 +287,39 @@ transform_MERGE_to_join(Query *parse)
 	 * use the join condition to distinguish between MATCHED and NOT MATCHED
 	 * BY SOURCE cases.  Otherwise, it's no longer needed, and we set it to
 	 * NULL, saving cycles during planning and execution.
-	 */
-	if (!have_action[MERGE_WHEN_NOT_MATCHED_BY_SOURCE])
-		parse->mergeJoinCondition = NULL;
+	 *
+	 * We need to be careful though: the executor evaluates this condition
+	 * using the output of the join subplan node, which nulls the output from
+	 * the source relation when the join condition doesn't match.  That risks
+	 * producing incorrect results when rechecking using a "non-strict" join
+	 * condition, such as "src.col IS NOT DISTINCT FROM tgt.col".  To guard
+	 * against that, we add an additional "src IS NOT NULL" check to the join
+	 * condition, so that it does the right thing when performing a recheck
+	 * based on the output of the join subplan.
+	 */
+	if (have_action[MERGE_WHEN_NOT_MATCHED_BY_SOURCE])
+	{
+		Var		   *var;
+		NullTest   *ntest;
+
+		/* source wholerow Var (nullable by the new join) */
+		var = makeWholeRowVar(rt_fetch(sourcerti, parse->rtable),
+							  sourcerti, 0, false);
+		var->varnullingrels = bms_make_singleton(joinrti);
+
+		/* "src IS NOT NULL" check */
+		ntest = makeNode(NullTest);
+		ntest->arg = (Expr *) var;
+		ntest->nulltesttype = IS_NOT_NULL;
+		ntest->argisrow = false;
+		ntest->location = -1;
+
+		/* combine it with the original join condition */
+		parse->mergeJoinCondition =
+			(Node *) make_and_qual((Node *) ntest, parse->mergeJoinCondition);
+	}
+	else
+		parse->mergeJoinCondition = NULL;	/* join condition not needed */
 }
 
 /*
diff --git a/src/test/regress/expected/merge.out b/src/test/regress/expected/merge.out
index 3d33259e8f..0e59bae1a7 100644
--- a/src/test/regress/expected/merge.out
+++ b/src/test/regress/expected/merge.out
@@ -2689,6 +2689,28 @@ DETAIL:  drop cascades to table measurement_y2006m02
 drop cascades to table measurement_y2006m03
 drop cascades to table measurement_y2007m01
 DROP FUNCTION measurement_insert_trigger();
+--
+-- test non-strict join clause
+--
+CREATE TABLE src (a int, b text);
+INSERT INTO src VALUES (1, 'src row');
+CREATE TABLE tgt (a int, b text);
+INSERT INTO tgt VALUES (NULL, 'tgt row');
+MERGE INTO tgt USING src ON tgt.a IS NOT DISTINCT FROM src.a
+  WHEN MATCHED THEN UPDATE SET a = src.a, b = src.b
+  WHEN NOT MATCHED BY SOURCE THEN DELETE
+  RETURNING merge_action(), src.*, tgt.*;
+ merge_action | a | b | a |    b    
+--------------+---+---+---+---------
+ DELETE       |   |   |   | tgt row
+(1 row)
+
+SELECT * FROM tgt;
+ a | b 
+---+---
+(0 rows)
+
+DROP TABLE src, tgt;
 -- prepare
 RESET SESSION AUTHORIZATION;
 -- try a system catalog
diff --git a/src/test/regress/sql/merge.sql b/src/test/regress/sql/merge.sql
index 92163ec9fe..2a7753c65b 100644
--- a/src/test/regress/sql/merge.sql
+++ b/src/test/regress/sql/merge.sql
@@ -1710,6 +1710,24 @@ SELECT * FROM new_measurement ORDER BY city_id, logdate;
 DROP TABLE measurement, new_measurement CASCADE;
 DROP FUNCTION measurement_insert_trigger();
 
+--
+-- test non-strict join clause
+--
+CREATE TABLE src (a int, b text);
+INSERT INTO src VALUES (1, 'src row');
+
+CREATE TABLE tgt (a int, b text);
+INSERT INTO tgt VALUES (NULL, 'tgt row');
+
+MERGE INTO tgt USING src ON tgt.a IS NOT DISTINCT FROM src.a
+  WHEN MATCHED THEN UPDATE SET a = src.a, b = src.b
+  WHEN NOT MATCHED BY SOURCE THEN DELETE
+  RETURNING merge_action(), src.*, tgt.*;
+
+SELECT * FROM tgt;
+
+DROP TABLE src, tgt;
+
 -- prepare
 
 RESET SESSION AUTHORIZATION;
-- 
2.43.0

