From c1a619500a41dee654868efe480b0b8a7c67dc4f Mon Sep 17 00:00:00 2001
From: Jeff Davis <jeff@j-davis.com>
Date: Mon, 3 Feb 2025 15:13:03 -0800
Subject: [PATCH v2 2/2] Minor refactor of hash_agg_set_limits().

Avoid the implicit assumption that input_groups is greater than or equal
to one. The assumption appears to hold, but the code is easier to read
without relying on it.

Also, a branch that avoided unsigned underflow while calculating
mem_limit was hard to read; replace it with ssize_t types and a Max().
---
 src/backend/executor/nodeAgg.c | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 35cf18e5282..fb467ec9877 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -1809,8 +1809,8 @@ hash_agg_set_limits(double hashentrysize, double input_groups, int used_bits,
 					int *num_partitions)
 {
 	int			npartitions;
-	Size		partition_mem;
-	Size		hash_mem_limit = get_hash_memory_limit();
+	ssize_t		partition_mem;
+	ssize_t		hash_mem_limit = get_hash_memory_limit();
 
 	/* if not expected to spill, use all of hash_mem */
 	if (input_groups * hashentrysize <= hash_mem_limit)
@@ -1818,7 +1818,7 @@ hash_agg_set_limits(double hashentrysize, double input_groups, int used_bits,
 		if (num_partitions != NULL)
 			*num_partitions = 0;
 		*mem_limit = hash_mem_limit;
-		*ngroups_limit = hash_mem_limit / hashentrysize;
+		*ngroups_limit = Max(*mem_limit / hashentrysize, 1);
 		return;
 	}
 
@@ -1841,17 +1841,11 @@ hash_agg_set_limits(double hashentrysize, double input_groups, int used_bits,
 	/*
 	 * Don't set the limit below 3/4 of hash_mem. In that case, we are at the
 	 * minimum number of partitions, so we aren't going to dramatically exceed
-	 * work mem anyway.
+	 * work mem anyway. Use ssize_t to avoid underflow during subtraction.
 	 */
-	if (hash_mem_limit > 4 * partition_mem)
-		*mem_limit = hash_mem_limit - partition_mem;
-	else
-		*mem_limit = hash_mem_limit * 0.75;
+	*mem_limit = Max(hash_mem_limit - partition_mem, hash_mem_limit * 0.75);
 
-	if (*mem_limit > hashentrysize)
-		*ngroups_limit = *mem_limit / hashentrysize;
-	else
-		*ngroups_limit = 1;
+	*ngroups_limit = Max(*mem_limit / hashentrysize, 1);
 }
 
 /*
-- 
2.34.1

