From 533001bf070053c62956a6ad9921f7c3b370fb96 Mon Sep 17 00:00:00 2001
From: "Andrey V. Lepikhov" <a.lepikhov@postgrespro.ru>
Date: Tue, 17 May 2022 17:48:49 +0500
Subject: [PATCH] Fix potential negative value of the number of groups.
 Add the clamp_cardinality_to_long routine to safely cast a Cardinality
 value to long everywhere in the code.

---
 src/backend/executor/nodeSubplan.c       |  3 +-
 src/backend/optimizer/path/costsize.c    | 25 +++++++++++++++
 src/backend/optimizer/plan/createplan.c  |  6 ++--
 src/include/optimizer/optimizer.h        |  1 +
 src/test/regress/expected/aggregates.out | 41 ++++++++++++++++++++++++
 src/test/regress/sql/aggregates.sql      | 12 +++++++
 6 files changed, 84 insertions(+), 4 deletions(-)

diff --git a/src/backend/executor/nodeSubplan.c b/src/backend/executor/nodeSubplan.c
index 60d2290030..0189dc66ee 100644
--- a/src/backend/executor/nodeSubplan.c
+++ b/src/backend/executor/nodeSubplan.c
@@ -35,6 +35,7 @@
 #include "miscadmin.h"
 #include "nodes/makefuncs.h"
 #include "nodes/nodeFuncs.h"
+#include "optimizer/optimizer.h"
 #include "utils/array.h"
 #include "utils/lsyscache.h"
 #include "utils/memutils.h"
@@ -498,7 +499,7 @@ buildSubPlanHash(SubPlanState *node, ExprContext *econtext)
 	node->havehashrows = false;
 	node->havenullrows = false;
 
-	nbuckets = (long) Min(planstate->plan->plan_rows, (double) LONG_MAX);
+	nbuckets = clamp_cardinality_to_long(planstate->plan->plan_rows);
 	if (nbuckets < 1)
 		nbuckets = 1;
 
diff --git a/src/backend/optimizer/path/costsize.c b/src/backend/optimizer/path/costsize.c
index ed98ba7dbd..322b959993 100644
--- a/src/backend/optimizer/path/costsize.c
+++ b/src/backend/optimizer/path/costsize.c
@@ -215,6 +215,31 @@ clamp_row_est(double nrows)
 	return nrows;
 }
 
+/*
+ * clamp_cardinality_to_long
+ *		Cast a Cardinality value to a sane long value.
+ *
+ * Negative inputs yield 0; NaN, +Inf and anything >= LONG_MAX yield LONG_MAX.
+ */
+long
+clamp_cardinality_to_long(Cardinality x)
+{
+	/* Be paranoid about garbage estimates: map NaN and negatives sanely. */
+	if (isnan(x))
+		return LONG_MAX;
+	if (x < 0)
+		return 0;
+
+	/*
+	 * If long is 64 bits, LONG_MAX (2^63 - 1) has no exact double
+	 * representation, and "(double) LONG_MAX" may round up to 2^63.  Since
+	 * converting an out-of-range double to long is undefined behavior, use
+	 * a strict "<" so that every value passing the test fits; +Inf fails it
+	 * and clamps to LONG_MAX.  Values just below LONG_MAX may clamp or not
+	 * depending on which way the implementation rounds; either is harmless.
+	 */
+	return (x < (double) LONG_MAX) ? (long) x : LONG_MAX;
+}
 
 /*
  * cost_seqscan
diff --git a/src/backend/optimizer/plan/createplan.c b/src/backend/optimizer/plan/createplan.c
index f4cc56039c..a7de85a869 100644
--- a/src/backend/optimizer/plan/createplan.c
+++ b/src/backend/optimizer/plan/createplan.c
@@ -2724,7 +2724,7 @@ create_setop_plan(PlannerInfo *root, SetOpPath *best_path, int flags)
 								  flags | CP_LABEL_TLIST);
 
 	/* Convert numGroups to long int --- but 'ware overflow! */
-	numGroups = (long) Min(best_path->numGroups, (double) LONG_MAX);
+	numGroups = clamp_cardinality_to_long(best_path->numGroups);
 
 	plan = make_setop(best_path->cmd,
 					  best_path->strategy,
@@ -2761,7 +2761,7 @@ create_recursiveunion_plan(PlannerInfo *root, RecursiveUnionPath *best_path)
 	tlist = build_path_tlist(root, &best_path->path);
 
 	/* Convert numGroups to long int --- but 'ware overflow! */
-	numGroups = (long) Min(best_path->numGroups, (double) LONG_MAX);
+	numGroups = clamp_cardinality_to_long(best_path->numGroups);
 
 	plan = make_recursive_union(tlist,
 								leftplan,
@@ -6554,7 +6554,7 @@ make_agg(List *tlist, List *qual,
 	long		numGroups;
 
 	/* Reduce to long, but 'ware overflow! */
-	numGroups = (long) Min(dNumGroups, (double) LONG_MAX);
+	numGroups = clamp_cardinality_to_long(dNumGroups);
 
 	node->aggstrategy = aggstrategy;
 	node->aggsplit = aggsplit;
diff --git a/src/include/optimizer/optimizer.h b/src/include/optimizer/optimizer.h
index d40ce2eae1..7be1e5906b 100644
--- a/src/include/optimizer/optimizer.h
+++ b/src/include/optimizer/optimizer.h
@@ -95,6 +95,7 @@ extern PGDLLIMPORT double recursive_worktable_factor;
 extern PGDLLIMPORT int effective_cache_size;
 
 extern double clamp_row_est(double nrows);
+extern long clamp_cardinality_to_long(Cardinality x);
 
 /* in path/indxpath.c: */
 
diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out
index 601047fa3d..0037981fc4 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -3049,10 +3049,51 @@ set work_mem to default;
 ----+----+----
 (0 rows)
 
+CREATE TABLE agg_group_5 AS (
+	SELECT 1 AS x, NULL AS x1 FROM generate_series(1,1e5) x);
+ANALYZE agg_group_5;
+-- Stress test for grouping: exceed number of 2^53 groups to check
+-- long <-> double conversion.
+EXPLAIN (COSTS OFF) SELECT t1.x1 FROM
+  agg_group_5 t1,agg_group_5 t2,agg_group_5 t3,agg_group_5 t4,agg_group_5 t5,
+  agg_group_5 t6,agg_group_5 t7,agg_group_5 t8,agg_group_5 t9
+GROUP BY (t1.x1,t2.x1,t3.x1,t4.x1,t5.x1,t6.x1,t7.x1,t8.x1,t9.x1);
+                                       QUERY PLAN                                       
+----------------------------------------------------------------------------------------
+ HashAggregate
+   Group Key: t1.x1, t2.x1, t3.x1, t4.x1, t5.x1, t6.x1, t7.x1, t8.x1, t9.x1
+   ->  Nested Loop
+         ->  Nested Loop
+               ->  Nested Loop
+                     ->  Nested Loop
+                           ->  Nested Loop
+                                 ->  Nested Loop
+                                       ->  Nested Loop
+                                             ->  Nested Loop
+                                                   ->  Seq Scan on agg_group_5 t1
+                                                   ->  Materialize
+                                                         ->  Seq Scan on agg_group_5 t2
+                                             ->  Materialize
+                                                   ->  Seq Scan on agg_group_5 t3
+                                       ->  Materialize
+                                             ->  Seq Scan on agg_group_5 t4
+                                 ->  Materialize
+                                       ->  Seq Scan on agg_group_5 t5
+                           ->  Materialize
+                                 ->  Seq Scan on agg_group_5 t6
+                     ->  Materialize
+                           ->  Seq Scan on agg_group_5 t7
+               ->  Materialize
+                     ->  Seq Scan on agg_group_5 t8
+         ->  Materialize
+               ->  Seq Scan on agg_group_5 t9
+(27 rows)
+
 drop table agg_group_1;
 drop table agg_group_2;
 drop table agg_group_3;
 drop table agg_group_4;
+drop table agg_group_5;
 drop table agg_hash_1;
 drop table agg_hash_2;
 drop table agg_hash_3;
diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql
index c6e0d7ba2b..e19d74fef5 100644
--- a/src/test/regress/sql/aggregates.sql
+++ b/src/test/regress/sql/aggregates.sql
@@ -1351,10 +1351,22 @@ set work_mem to default;
   union all
 (select * from agg_group_4 except select * from agg_hash_4);
 
+CREATE TABLE agg_group_5 AS (
+	SELECT 1 AS x, NULL AS x1 FROM generate_series(1,1e5) x);
+ANALYZE agg_group_5;
+
+-- Stress test for grouping: exceed number of 2^53 groups to check
+-- long <-> double conversion.
+EXPLAIN (COSTS OFF) SELECT t1.x1 FROM
+  agg_group_5 t1,agg_group_5 t2,agg_group_5 t3,agg_group_5 t4,agg_group_5 t5,
+  agg_group_5 t6,agg_group_5 t7,agg_group_5 t8,agg_group_5 t9
+GROUP BY (t1.x1,t2.x1,t3.x1,t4.x1,t5.x1,t6.x1,t7.x1,t8.x1,t9.x1);
+
 drop table agg_group_1;
 drop table agg_group_2;
 drop table agg_group_3;
 drop table agg_group_4;
+drop table agg_group_5;
 drop table agg_hash_1;
 drop table agg_hash_2;
 drop table agg_hash_3;
-- 
2.34.1

