From edcaac0705193acea33c4423bf0d59128219a46f Mon Sep 17 00:00:00 2001
From: Tomas Vondra <tomas@vondra.me>
Date: Tue, 31 Dec 2024 17:09:40 +0100
Subject: [PATCH v20241231-adjust-limit] Account for batch files in hash join
 spilling

Hash joins try to limit the amount of memory used by the Hash node by
only keeping a single batch in memory and spilling future batches to
disk. Unfortunately, the implementation does not account for the files
used for spilling, which can lead to issues with many batches. Each file
keeps a BLCKSZ buffer in memory, and we need 2*nbatches of those files,
so with many batches this may use substantial amounts of memory.

The hash join code, however, assumes that adding batches is virtually free (in
terms of memory needed), ignoring this issue. It increases the number of
batches, possibly keeping the current batch within the limit, but ends
up using much more memory for the files.

This can be particularly painful with adversarial data sets, with a batch
that can't be split. This may happen due to hash collisions (overlaps in
the part used to calculate "batch"), or a value with many duplicates
that however didn't make it to the MCV list (and thus the skew table
can't help). In these cases the hash join can get into a cycle of
increasing the number of batches, often reaching 256k or 512k batches
before exhausting the available hash space (32 bits).

If this happens, there's not much point in enforcing the original memory
limit. That's simply not feasible - especially in the case of a single
batch exceeding the allowed space.

Instead, the best we can do is to relax the limit and focus on using
as little total memory as possible. By allowing the batches to be
larger, we reduce the number of batch files. The adjustment formula is
based on the observation that doubling the number of batches doubles the
amount of memory needed for the files, while cutting the batch size in
half. This defines the "break even" point for the next batch increase.
---
 src/backend/executor/nodeHash.c         | 153 +++++++++++++++++++++---
 src/test/regress/expected/join_hash.out |   4 +-
 src/test/regress/sql/join_hash.sql      |   4 +-
 3 files changed, 142 insertions(+), 19 deletions(-)

diff --git a/src/backend/executor/nodeHash.c b/src/backend/executor/nodeHash.c
index 3e22d50e3a4..680d2897738 100644
--- a/src/backend/executor/nodeHash.c
+++ b/src/backend/executor/nodeHash.c
@@ -54,6 +54,9 @@ static void ExecHashSkewTableInsert(HashJoinTable hashtable,
 									uint32 hashvalue,
 									int bucketNumber);
 static void ExecHashRemoveNextSkewBucket(HashJoinTable hashtable);
+static void ExecHashUpdateSpacePeak(HashJoinTable hashtable);
+static bool ExecHashExceededMemoryLimit(HashJoinTable hashtable);
+static void ExecHashAdjustMemoryLimit(HashJoinTable hashtable);
 
 static void *dense_alloc(HashJoinTable hashtable, Size size);
 static HashJoinTuple ExecParallelHashTupleAlloc(HashJoinTable hashtable,
@@ -199,10 +202,8 @@ MultiExecPrivateHash(HashState *node)
 	if (hashtable->nbuckets != hashtable->nbuckets_optimal)
 		ExecHashIncreaseNumBuckets(hashtable);
 
-	/* Account for the buckets in spaceUsed (reported in EXPLAIN ANALYZE) */
-	hashtable->spaceUsed += hashtable->nbuckets * sizeof(HashJoinTuple);
-	if (hashtable->spaceUsed > hashtable->spacePeak)
-		hashtable->spacePeak = hashtable->spaceUsed;
+	/* update peak usage; buckets are counted there, no longer in spaceUsed */
+	ExecHashUpdateSpacePeak(hashtable);
 
 	hashtable->partialTuples = hashtable->totalTuples;
 }
@@ -1036,6 +1037,9 @@ ExecHashIncreaseNumBatches(HashJoinTable hashtable)
 		   hashtable, nfreed, ninmemory, hashtable->spaceUsed);
 #endif
 
+	/* relax spaceAllowed to account for the files of the doubled nbatch */
+	ExecHashAdjustMemoryLimit(hashtable);
+
 	/*
 	 * If we dumped out either all or none of the tuples in the table, disable
 	 * further expansion of nbatch.  This situation implies that we have
@@ -1673,11 +1677,12 @@ ExecHashTableInsert(HashJoinTable hashtable,
 
 		/* Account for space used, and back off if we've used too much */
 		hashtable->spaceUsed += hashTupleSize;
-		if (hashtable->spaceUsed > hashtable->spacePeak)
-			hashtable->spacePeak = hashtable->spaceUsed;
-		if (hashtable->spaceUsed +
-			hashtable->nbuckets_optimal * sizeof(HashJoinTuple)
-			> hashtable->spaceAllowed)
+
+		/* update info about peak memory usage */
+		ExecHashUpdateSpacePeak(hashtable);
+
+		/* add batches if tuples + buckets + batch files exceed spaceAllowed */
+		if (ExecHashExceededMemoryLimit(hashtable))
 			ExecHashIncreaseNumBatches(hashtable);
 	}
 	else
@@ -1843,6 +1848,120 @@ ExecHashGetBucketAndBatch(HashJoinTable hashtable,
 	}
 }
 
+/*
+ * ExecHashUpdateSpacePeak
+ *		Update information about peak memory usage.
+ *
+ * This considers tuples added to the hash table, buckets of the hash table
+ * itself, and also the buffered batch files on both the inner and outer side.
+ * Each file has a BLCKSZ buffer, so with enough batches this may actually
+ * represent most of the memory used by the hash join node.
+ */
+static void
+ExecHashUpdateSpacePeak(HashJoinTable hashtable)
+{
+	Size	spaceUsed = hashtable->spaceUsed;
+
+	/* buckets of the hash table */
+	spaceUsed += hashtable->nbuckets * sizeof(HashJoinTuple);
+
+	/* buffered batch files (inner + outer), each has a BLCKSZ buffer */
+	spaceUsed += hashtable->nbatch * sizeof(PGAlignedBlock) * 2;
+
+	/* if we exceeded the current peak, remember the new one */
+	if (spaceUsed > hashtable->spacePeak)
+		hashtable->spacePeak = spaceUsed;
+}
+
+/*
+ * ExecHashExceededMemoryLimit
+ *		Check if the amount of memory used exceeds spaceAllowed.
+ *
+ * Check if the total amount of space used by the hash join exceeds the
+ * current value of spaceAllowed, and we should try to increase the number
+ * of batches.
+ *
+ * We need to consider both the data added to the hash and the hashtable
+ * itself (i.e. buckets), but also the files used for future batches.
+ * Each batch needs a file for inner/outer side, so we need (2*nbatch)
+ * files in total, and each BufFile has a BLCKSZ buffer. If we ignored
+ * the files and simply doubled the number of batches, we could easily
+ * increase the total amount of memory because while we expect to cut the
+ * batch size in half, doubling the number of batches also doubles the
+ * amount of memory allocated by BufFile.
+ *
+ * That means doubling the number of batches is pointless when
+ *
+ *		(spaceUsed / 2) < 2 * (nbatches * BLCKSZ)
+ *
+ * because it would result in allocating more memory than it saves.
+ *
+ * ExecHashAdjustMemoryLimit applies this by raising spaceAllowed; adding
+ * batches resumes once the hash table grows enough to make it a win again.
+ */
+static bool
+ExecHashExceededMemoryLimit(HashJoinTable hashtable)
+{
+	return (hashtable->spaceUsed +
+			hashtable->nbuckets_optimal * sizeof(HashJoinTuple) +
+			hashtable->nbatch * sizeof(PGAlignedBlock) * 2
+			> hashtable->spaceAllowed);
+}
+
+/*
+ * ExecHashAdjustMemoryLimit
+ *		Adjust the memory limit after increasing the number of batches.
+ *
+ * We can't keep the same spaceAllowed value, because as we keep adding
+ * batches we're guaranteed to exceed the older values simply thanks to
+ * the BufFile allocations.
+ *
+ * Instead, we consider the "break even" threshold for the current number
+ * of batches, add a bit of slack (so that we don't get into a cycle of
+ * incrementing number of batches), and calculate the new limit from that.
+ *
+ * For well estimated cases this should do nothing, as the batches are
+ * expected to account only for a small fraction of work_mem. But if we
+ * significantly underestimate the number of batches, or if one batch
+ * happens to be very large, this will relax the limit a bit.
+ *
+ * This means we won't enforce the work_mem limit strictly - but without
+ * adjusting the limit that wouldn't be the case either, we'd just use
+ * a lot of memory for the BufFiles without accounting for that. This
+ * way we do our best to minimize the amount of memory used.
+ */
+static void
+ExecHashAdjustMemoryLimit(HashJoinTable hashtable)
+{
+	Size	newSpaceAllowed;
+
+	/*
+	 * The next time increasing the number of batches "breaks even" is when
+	 *
+	 * (spaceUsed / 2) == (2 * nbatches * BLCKSZ)
+	 *
+	 * which means
+	 *
+	 * spaceUsed == (4 * nbatches * BLCKSZ)
+	 *
+	 * However, this is a "break even" threshold, when we shrink the hash
+	 * table just enough to compensate for the new batches, and we'd hit the
+	 * new threshold almost immediately again. In practice we want to free
+	 * more memory to allow new data before having to increase the number
+	 * of batches again. So we allow 25% more space.
+	 */
+	newSpaceAllowed
+		= 1.25 * (4 * hashtable->nbatch * sizeof(PGAlignedBlock));
+
+	/* but also account for the buckets, and the current batch files */
+	newSpaceAllowed += hashtable->nbuckets_optimal * sizeof(HashJoinTuple);
+	newSpaceAllowed += (2 * hashtable->nbatch * sizeof(PGAlignedBlock));
+
+	/* shouldn't go down, but use Max() to make sure */
+	hashtable->spaceAllowed = Max(hashtable->spaceAllowed,
+								  newSpaceAllowed);
+}
+
 /*
  * ExecScanHashBucket
  *		scan a hash bucket for matches to the current outer tuple
@@ -2349,8 +2468,9 @@ ExecHashBuildSkewHash(HashState *hashstate, HashJoinTable hashtable,
 			+ mcvsToUse * sizeof(int);
 		hashtable->spaceUsedSkew += nbuckets * sizeof(HashSkewBucket *)
 			+ mcvsToUse * sizeof(int);
-		if (hashtable->spaceUsed > hashtable->spacePeak)
-			hashtable->spacePeak = hashtable->spaceUsed;
+
+		/* refresh peak usage, now including the skew hash arrays */
+		ExecHashUpdateSpacePeak(hashtable);
 
 		/*
 		 * Create a skew bucket for each MCV hash value.
@@ -2399,8 +2519,9 @@ ExecHashBuildSkewHash(HashState *hashstate, HashJoinTable hashtable,
 			hashtable->nSkewBuckets++;
 			hashtable->spaceUsed += SKEW_BUCKET_OVERHEAD;
 			hashtable->spaceUsedSkew += SKEW_BUCKET_OVERHEAD;
-			if (hashtable->spaceUsed > hashtable->spacePeak)
-				hashtable->spacePeak = hashtable->spaceUsed;
+
+			/* refresh peak usage to include the new skew bucket's overhead */
+			ExecHashUpdateSpacePeak(hashtable);
 		}
 
 		free_attstatsslot(&sslot);
@@ -2489,8 +2610,10 @@ ExecHashSkewTableInsert(HashJoinTable hashtable,
 	/* Account for space used, and back off if we've used too much */
 	hashtable->spaceUsed += hashTupleSize;
 	hashtable->spaceUsedSkew += hashTupleSize;
-	if (hashtable->spaceUsed > hashtable->spacePeak)
-		hashtable->spacePeak = hashtable->spaceUsed;
+
+	/* refresh peak usage; skew space is bounded by spaceAllowedSkew below */
+	ExecHashUpdateSpacePeak(hashtable);
+
 	while (hashtable->spaceUsedSkew > hashtable->spaceAllowedSkew)
 		ExecHashRemoveNextSkewBucket(hashtable);
 
diff --git a/src/test/regress/expected/join_hash.out b/src/test/regress/expected/join_hash.out
index 4fc34a0e72a..8d54822eb8c 100644
--- a/src/test/regress/expected/join_hash.out
+++ b/src/test/regress/expected/join_hash.out
@@ -198,7 +198,7 @@ rollback to settings;
 -- non-parallel
 savepoint settings;
 set local max_parallel_workers_per_gather = 0;
-set local work_mem = '128kB';
+set local work_mem = '512kB';
 set local hash_mem_multiplier = 1.0;
 explain (costs off)
   select count(*) from simple r join simple s using (id);
@@ -232,7 +232,7 @@ rollback to settings;
 -- parallel with parallel-oblivious hash join
 savepoint settings;
 set local max_parallel_workers_per_gather = 2;
-set local work_mem = '128kB';
+set local work_mem = '512kB';
 set local hash_mem_multiplier = 1.0;
 set local enable_parallel_hash = off;
 explain (costs off)
diff --git a/src/test/regress/sql/join_hash.sql b/src/test/regress/sql/join_hash.sql
index 6b0688ab0a6..ca8758900aa 100644
--- a/src/test/regress/sql/join_hash.sql
+++ b/src/test/regress/sql/join_hash.sql
@@ -145,7 +145,7 @@ rollback to settings;
 -- non-parallel
 savepoint settings;
 set local max_parallel_workers_per_gather = 0;
-set local work_mem = '128kB';
+set local work_mem = '512kB';
 set local hash_mem_multiplier = 1.0;
 explain (costs off)
   select count(*) from simple r join simple s using (id);
@@ -160,7 +160,7 @@ rollback to settings;
 -- parallel with parallel-oblivious hash join
 savepoint settings;
 set local max_parallel_workers_per_gather = 2;
-set local work_mem = '128kB';
+set local work_mem = '512kB';
 set local hash_mem_multiplier = 1.0;
 set local enable_parallel_hash = off;
 explain (costs off)
-- 
2.47.1

