From 030fca4e56132dd516b8aba23ed36c47311b33f5 Mon Sep 17 00:00:00 2001
From: Tomas Vondra <tomas@vondra.me>
Date: Tue, 31 Dec 2024 17:08:17 +0100
Subject: [PATCH v20241231-multi-spill] Hash join with multiple spill files

We can't use an arbitrary number of batches in the hash join, because
that can cause substantial memory usage, ignored by the memory limit.
Instead, decide how many batches we can keep in memory, and open files
only for this "slice" of batches. For future batches we keep one spill
file per slice.

Then use these spill files after "switching" to the first batch in each
of those non-in-memory slices.

NB: Compared to the "single-spill" patch, this is less strict, as we
need one spill file per slice, and the number of slices may grow during
execution. It's better than the current behavior, because even with
modest work_mem values we can fit 32-128 batches into memory, thus
reducing the number of files to 1/32x - 1/128x. But with many
batches (as with batch explosion) this can still use substantial
amounts of memory (certainly more than work_mem).

Ultimately, it behaves similarly to the "adjustment" patch, but with
more complexity. And it can't handle the "oversized batch" really well,
because the limit is not adjusted.
---
 src/backend/commands/explain.c        |   6 +-
 src/backend/executor/nodeHash.c       | 256 +++++++++++++++++++++++---
 src/backend/executor/nodeHashjoin.c   |  75 +++++---
 src/backend/optimizer/path/costsize.c |   2 +
 src/include/executor/hashjoin.h       |   4 +
 src/include/executor/nodeHash.h       |   8 +
 src/include/nodes/execnodes.h         |   1 +
 7 files changed, 300 insertions(+), 52 deletions(-)

diff --git a/src/backend/commands/explain.c b/src/backend/commands/explain.c
index a201ed30824..aa3fbeda3dd 100644
--- a/src/backend/commands/explain.c
+++ b/src/backend/commands/explain.c
@@ -3466,19 +3466,23 @@ show_hash_info(HashState *hashstate, ExplainState *es)
 								   hinstrument.nbatch, es);
 			ExplainPropertyInteger("Original Hash Batches", NULL,
 								   hinstrument.nbatch_original, es);
+			ExplainPropertyInteger("In-Memory Hash Batches", NULL,
+								   hinstrument.nbatch_inmemory, es);
 			ExplainPropertyUInteger("Peak Memory Usage", "kB",
 									spacePeakKb, es);
 		}
 		else if (hinstrument.nbatch_original != hinstrument.nbatch ||
+				 hinstrument.nbatch_inmemory != hinstrument.nbatch ||
 				 hinstrument.nbuckets_original != hinstrument.nbuckets)
 		{
 			ExplainIndentText(es);
 			appendStringInfo(es->str,
-							 "Buckets: %d (originally %d)  Batches: %d (originally %d)  Memory Usage: " UINT64_FORMAT "kB\n",
+							 "Buckets: %d (originally %d)  Batches: %d (originally %d, in-memory %d)  Memory Usage: " UINT64_FORMAT "kB\n",
 							 hinstrument.nbuckets,
 							 hinstrument.nbuckets_original,
 							 hinstrument.nbatch,
 							 hinstrument.nbatch_original,
+							 hinstrument.nbatch_inmemory,
 							 spacePeakKb);
 		}
 		else
diff --git a/src/backend/executor/nodeHash.c b/src/backend/executor/nodeHash.c
index 3e22d50e3a4..e9d482577ac 100644
--- a/src/backend/executor/nodeHash.c
+++ b/src/backend/executor/nodeHash.c
@@ -80,6 +80,7 @@ static bool ExecParallelHashTuplePrealloc(HashJoinTable hashtable,
 static void ExecParallelHashMergeCounters(HashJoinTable hashtable);
 static void ExecParallelHashCloseBatchAccessors(HashJoinTable hashtable);
 
+static void ExecHashUpdateSpacePeak(HashJoinTable hashtable);
 
 /* ----------------------------------------------------------------
  *		ExecHash
@@ -199,10 +200,8 @@ MultiExecPrivateHash(HashState *node)
 	if (hashtable->nbuckets != hashtable->nbuckets_optimal)
 		ExecHashIncreaseNumBuckets(hashtable);
 
-	/* Account for the buckets in spaceUsed (reported in EXPLAIN ANALYZE) */
-	hashtable->spaceUsed += hashtable->nbuckets * sizeof(HashJoinTuple);
-	if (hashtable->spaceUsed > hashtable->spacePeak)
-		hashtable->spacePeak = hashtable->spaceUsed;
+	/* refresh info about peak used memory */
+	ExecHashUpdateSpacePeak(hashtable);
 
 	hashtable->partialTuples = hashtable->totalTuples;
 }
@@ -451,6 +450,7 @@ ExecHashTableCreate(HashState *state)
 	size_t		space_allowed;
 	int			nbuckets;
 	int			nbatch;
+	int			nbatch_inmemory;
 	double		rows;
 	int			num_skew_mcvs;
 	int			log2_nbuckets;
@@ -477,7 +477,8 @@ ExecHashTableCreate(HashState *state)
 							state->parallel_state != NULL ?
 							state->parallel_state->nparticipants - 1 : 0,
 							&space_allowed,
-							&nbuckets, &nbatch, &num_skew_mcvs);
+							&nbuckets, &nbatch, &nbatch_inmemory,
+							&num_skew_mcvs);
 
 	/* nbuckets must be a power of 2 */
 	log2_nbuckets = my_log2(nbuckets);
@@ -503,6 +504,7 @@ ExecHashTableCreate(HashState *state)
 	hashtable->nSkewBuckets = 0;
 	hashtable->skewBucketNums = NULL;
 	hashtable->nbatch = nbatch;
+	hashtable->nbatch_inmemory = nbatch_inmemory;
 	hashtable->curbatch = 0;
 	hashtable->nbatch_original = nbatch;
 	hashtable->nbatch_outstart = nbatch;
@@ -512,6 +514,8 @@ ExecHashTableCreate(HashState *state)
 	hashtable->skewTuples = 0;
 	hashtable->innerBatchFile = NULL;
 	hashtable->outerBatchFile = NULL;
+	hashtable->innerOverflowFiles = NULL;
+	hashtable->outerOverflowFiles = NULL;
 	hashtable->spaceUsed = 0;
 	hashtable->spacePeak = 0;
 	hashtable->spaceAllowed = space_allowed;
@@ -552,6 +556,7 @@ ExecHashTableCreate(HashState *state)
 	if (nbatch > 1 && hashtable->parallel_state == NULL)
 	{
 		MemoryContext oldctx;
+		int	cnt = Min(nbatch, nbatch_inmemory);
 
 		/*
 		 * allocate and initialize the file arrays in hashCxt (not needed for
@@ -559,8 +564,19 @@ ExecHashTableCreate(HashState *state)
 		 */
 		oldctx = MemoryContextSwitchTo(hashtable->spillCxt);
 
-		hashtable->innerBatchFile = palloc0_array(BufFile *, nbatch);
-		hashtable->outerBatchFile = palloc0_array(BufFile *, nbatch);
+		hashtable->innerBatchFile = palloc0_array(BufFile *, cnt);
+		hashtable->outerBatchFile = palloc0_array(BufFile *, cnt);
+
+		/* also allocate files for overflow batches */
+		if (nbatch > nbatch_inmemory)
+		{
+			int nslices = (nbatch / nbatch_inmemory);
+
+			Assert(nslices % 2 == 0);
+
+			hashtable->innerOverflowFiles = palloc0_array(BufFile *, nslices + 1);
+			hashtable->outerOverflowFiles = palloc0_array(BufFile *, nslices + 1);
+		}
 
 		MemoryContextSwitchTo(oldctx);
 
@@ -661,6 +677,7 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 						size_t *space_allowed,
 						int *numbuckets,
 						int *numbatches,
+						int *numbatches_inmemory,
 						int *num_skew_mcvs)
 {
 	int			tupsize;
@@ -669,6 +686,7 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 	size_t		bucket_bytes;
 	size_t		max_pointers;
 	int			nbatch = 1;
+	int			nbatch_inmemory = 1;
 	int			nbuckets;
 	double		dbuckets;
 
@@ -811,6 +829,7 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 									space_allowed,
 									numbuckets,
 									numbatches,
+									numbatches_inmemory,
 									num_skew_mcvs);
 			return;
 		}
@@ -848,11 +867,24 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 		nbatch = pg_nextpower2_32(Max(2, minbatch));
 	}
 
+	/*
+	 * See how many batches we can fit into memory (driven mostly by size
+	 * of BufFile, with PGAlignedBlock being the largest part of that).
+	 * We need one BufFile for inner and outer side, so we count it twice
+	 * for each batch, and we stop once we exceed (work_mem/2).
+	 */
+	while ((nbatch_inmemory * 2) * sizeof(PGAlignedBlock) * 2
+			<= (work_mem * 1024L / 2))
+		nbatch_inmemory *= 2;
+
+
+
 	Assert(nbuckets > 0);
 	Assert(nbatch > 0);
 
 	*numbuckets = nbuckets;
 	*numbatches = nbatch;
+	*numbatches_inmemory = nbatch_inmemory;
 }
 
 
@@ -874,13 +906,27 @@ ExecHashTableDestroy(HashJoinTable hashtable)
 	 */
 	if (hashtable->innerBatchFile != NULL)
 	{
-		for (i = 1; i < hashtable->nbatch; i++)
+		int n = Min(hashtable->nbatch, hashtable->nbatch_inmemory);
+
+		for (i = 1; i < n; i++)
 		{
 			if (hashtable->innerBatchFile[i])
 				BufFileClose(hashtable->innerBatchFile[i]);
 			if (hashtable->outerBatchFile[i])
 				BufFileClose(hashtable->outerBatchFile[i]);
 		}
+
+		/* overflow files exist only when nbatch > nbatch_inmemory */
+		n = (hashtable->innerOverflowFiles != NULL) ?
+			(hashtable->nbatch / hashtable->nbatch_inmemory) + 1 : 0;
+
+		for (i = 1; i < n; i++)
+		{
+			if (hashtable->innerOverflowFiles[i])
+				BufFileClose(hashtable->innerOverflowFiles[i]);
+			if (hashtable->outerOverflowFiles[i])
+				BufFileClose(hashtable->outerOverflowFiles[i]);
+		}
 	}
 
 	/* Release working memory (batchCxt is a child, so it goes away too) */
@@ -923,11 +969,14 @@ ExecHashIncreaseNumBatches(HashJoinTable hashtable)
 
 	if (hashtable->innerBatchFile == NULL)
 	{
+		/* XXX nbatch=1, no need to deal with nbatch_inmemory here */
+		int nbatch_tmp = Min(nbatch, hashtable->nbatch_inmemory);
+
 		MemoryContext oldcxt = MemoryContextSwitchTo(hashtable->spillCxt);
 
 		/* we had no file arrays before */
-		hashtable->innerBatchFile = palloc0_array(BufFile *, nbatch);
-		hashtable->outerBatchFile = palloc0_array(BufFile *, nbatch);
+		hashtable->innerBatchFile = palloc0_array(BufFile *, nbatch_tmp);
+		hashtable->outerBatchFile = palloc0_array(BufFile *, nbatch_tmp);
 
 		MemoryContextSwitchTo(oldcxt);
 
@@ -936,9 +985,35 @@ ExecHashIncreaseNumBatches(HashJoinTable hashtable)
 	}
 	else
 	{
+		int nbatch_tmp = Min(nbatch, hashtable->nbatch_inmemory);
+		int	oldnbatch_tmp = Min(oldnbatch, hashtable->nbatch_inmemory);
+
 		/* enlarge arrays and zero out added entries */
-		hashtable->innerBatchFile = repalloc0_array(hashtable->innerBatchFile, BufFile *, oldnbatch, nbatch);
-		hashtable->outerBatchFile = repalloc0_array(hashtable->outerBatchFile, BufFile *, oldnbatch, nbatch);
+		hashtable->innerBatchFile = repalloc0_array(hashtable->innerBatchFile, BufFile *, oldnbatch_tmp, nbatch_tmp);
+		hashtable->outerBatchFile = repalloc0_array(hashtable->outerBatchFile, BufFile *, oldnbatch_tmp, nbatch_tmp);
+
+		if (nbatch > hashtable->nbatch_inmemory)
+		{
+			int nslices = (nbatch / hashtable->nbatch_inmemory);
+			int oldnslices = (oldnbatch / hashtable->nbatch_inmemory);
+
+			Assert(nslices > 1);
+			Assert(nslices % 2 == 0);
+			Assert((oldnslices == 1) || (oldnslices % 2 == 0));
+			Assert(oldnslices <= nslices);
+
+			if (hashtable->innerOverflowFiles == NULL)
+			{
+				hashtable->innerOverflowFiles = palloc0_array(BufFile *, nslices + 1);
+				hashtable->outerOverflowFiles = palloc0_array(BufFile *, nslices + 1);
+			}
+			else
+			{
+				hashtable->innerOverflowFiles = repalloc0_array(hashtable->innerOverflowFiles, BufFile *, oldnslices + 1, nslices + 1);
+				hashtable->outerOverflowFiles = repalloc0_array(hashtable->outerOverflowFiles, BufFile *, oldnslices + 1, nslices + 1);
+			}
+		}
+
 	}
 
 	hashtable->nbatch = nbatch;
@@ -1008,11 +1083,18 @@ ExecHashIncreaseNumBatches(HashJoinTable hashtable)
 			}
 			else
 			{
+				BufFile **batchFile;
+
 				/* dump it out */
 				Assert(batchno > curbatch);
+
+				batchFile = ExecHashGetBatchFile(hashtable, batchno,
+												 hashtable->innerBatchFile,
+												 hashtable->innerOverflowFiles);
+
 				ExecHashJoinSaveTuple(HJTUPLE_MINTUPLE(hashTuple),
 									  hashTuple->hashvalue,
-									  &hashtable->innerBatchFile[batchno],
+									  batchFile,
 									  hashtable);
 
 				hashtable->spaceUsed -= hashTupleSize;
@@ -1673,22 +1755,33 @@ ExecHashTableInsert(HashJoinTable hashtable,
 
 		/* Account for space used, and back off if we've used too much */
 		hashtable->spaceUsed += hashTupleSize;
-		if (hashtable->spaceUsed > hashtable->spacePeak)
-			hashtable->spacePeak = hashtable->spaceUsed;
+
+		/* refresh info about peak used memory */
+		ExecHashUpdateSpacePeak(hashtable);
+
+		/* Consider increasing number of batches if we filled work_mem. */
 		if (hashtable->spaceUsed +
-			hashtable->nbuckets_optimal * sizeof(HashJoinTuple)
+			hashtable->nbuckets_optimal * sizeof(HashJoinTuple) +
+			Min(hashtable->nbatch, hashtable->nbatch_inmemory) * sizeof(PGAlignedBlock) * 2	/* inner + outer */
 			> hashtable->spaceAllowed)
 			ExecHashIncreaseNumBatches(hashtable);
 	}
 	else
 	{
+		BufFile **batchFile;
+
 		/*
 		 * put the tuple into a temp file for later batches
 		 */
 		Assert(batchno > hashtable->curbatch);
+
+		batchFile = ExecHashGetBatchFile(hashtable, batchno,
+										 hashtable->innerBatchFile,
+										 hashtable->innerOverflowFiles);
+
 		ExecHashJoinSaveTuple(tuple,
 							  hashvalue,
-							  &hashtable->innerBatchFile[batchno],
+							  batchFile,
 							  hashtable);
 	}
 
@@ -1843,6 +1936,108 @@ ExecHashGetBucketAndBatch(HashJoinTable hashtable,
 	}
 }
 
+int
+ExecHashGetBatchIndex(HashJoinTable hashtable, int batchno)
+{
+	int	slice,
+		curslice;
+
+	if (hashtable->nbatch < hashtable->nbatch_inmemory)
+		return batchno;
+
+	slice = batchno / hashtable->nbatch_inmemory;
+	curslice = hashtable->curbatch / hashtable->nbatch_inmemory;
+
+	/* slices can't go backwards */
+	Assert(slice >= curslice);
+
+	/* overflow slice */
+	if (slice > curslice)
+		return -1;
+
+	/* current slice, compute index in the current array */
+	return (batchno % hashtable->nbatch_inmemory);
+}
+
+BufFile **
+ExecHashGetBatchFile(HashJoinTable hashtable, int batchno,
+					 BufFile **batchFiles, BufFile **overflowFiles)
+{
+	int		idx = ExecHashGetBatchIndex(hashtable, batchno);
+
+	/* get the right overflow file */
+	if (idx == -1)
+	{
+		int slice = (batchno / hashtable->nbatch_inmemory);
+
+		return &overflowFiles[slice];
+	}
+
+	/* batch file in the current slice */
+	return &batchFiles[idx];
+}
+
+void
+ExecHashSwitchToNextBatchSlice(HashJoinTable hashtable)
+{
+	int	slice = (hashtable->curbatch / hashtable->nbatch_inmemory);
+
+	memset(hashtable->innerBatchFile, 0,
+		   hashtable->nbatch_inmemory * sizeof(BufFile *));
+
+	hashtable->innerBatchFile[0] = hashtable->innerOverflowFiles[slice];
+	hashtable->innerOverflowFiles[slice] = NULL;
+
+	memset(hashtable->outerBatchFile, 0,
+		   hashtable->nbatch_inmemory * sizeof(BufFile *));
+
+	hashtable->outerBatchFile[0] = hashtable->outerOverflowFiles[slice];
+	hashtable->outerOverflowFiles[slice] = NULL;
+}
+
+int
+ExecHashSwitchToNextBatch(HashJoinTable hashtable)
+{
+	int		batchidx;
+
+	hashtable->curbatch++;
+
+	/* see if we skipped to the next batch slice */
+	batchidx = ExecHashGetBatchIndex(hashtable, hashtable->curbatch);
+
+	/* Can't be -1, current batch is in the current slice by definition. */
+	Assert(batchidx >= 0 && batchidx < hashtable->nbatch_inmemory);
+
+	/*
+	 * If we skipped to the next slice of batches, reset the array of files
+	 * and use the overflow file as the first batch.
+	 */
+	if (batchidx == 0)
+		ExecHashSwitchToNextBatchSlice(hashtable);
+
+	return hashtable->curbatch;
+}
+
+static void
+ExecHashUpdateSpacePeak(HashJoinTable hashtable)
+{
+	Size	spaceUsed = hashtable->spaceUsed;
+
+	/* Account for the buckets in spaceUsed (reported in EXPLAIN ANALYZE) */
+	spaceUsed += hashtable->nbuckets * sizeof(HashJoinTuple);
+
+	/* Account for memory used for batch files (inner + outer) */
+	spaceUsed += Min(hashtable->nbatch, hashtable->nbatch_inmemory) *
+				 sizeof(PGAlignedBlock) * 2;
+
+	/* Account for slice files (inner + outer) */
+	spaceUsed += (hashtable->nbatch / hashtable->nbatch_inmemory) *
+				 sizeof(PGAlignedBlock) * 2;
+
+	if (spaceUsed > hashtable->spacePeak)
+		hashtable->spacePeak = spaceUsed;
+}
+
 /*
  * ExecScanHashBucket
  *		scan a hash bucket for matches to the current outer tuple
@@ -2349,8 +2544,9 @@ ExecHashBuildSkewHash(HashState *hashstate, HashJoinTable hashtable,
 			+ mcvsToUse * sizeof(int);
 		hashtable->spaceUsedSkew += nbuckets * sizeof(HashSkewBucket *)
 			+ mcvsToUse * sizeof(int);
-		if (hashtable->spaceUsed > hashtable->spacePeak)
-			hashtable->spacePeak = hashtable->spaceUsed;
+
+		/* refresh info about peak used memory */
+		ExecHashUpdateSpacePeak(hashtable);
 
 		/*
 		 * Create a skew bucket for each MCV hash value.
@@ -2399,8 +2595,9 @@ ExecHashBuildSkewHash(HashState *hashstate, HashJoinTable hashtable,
 			hashtable->nSkewBuckets++;
 			hashtable->spaceUsed += SKEW_BUCKET_OVERHEAD;
 			hashtable->spaceUsedSkew += SKEW_BUCKET_OVERHEAD;
-			if (hashtable->spaceUsed > hashtable->spacePeak)
-				hashtable->spacePeak = hashtable->spaceUsed;
+
+			/* refresh info about peak used memory */
+			ExecHashUpdateSpacePeak(hashtable);
 		}
 
 		free_attstatsslot(&sslot);
@@ -2489,8 +2686,10 @@ ExecHashSkewTableInsert(HashJoinTable hashtable,
 	/* Account for space used, and back off if we've used too much */
 	hashtable->spaceUsed += hashTupleSize;
 	hashtable->spaceUsedSkew += hashTupleSize;
-	if (hashtable->spaceUsed > hashtable->spacePeak)
-		hashtable->spacePeak = hashtable->spaceUsed;
+
+	/* refresh info about peak used memory */
+	ExecHashUpdateSpacePeak(hashtable);
+
 	while (hashtable->spaceUsedSkew > hashtable->spaceAllowedSkew)
 		ExecHashRemoveNextSkewBucket(hashtable);
 
@@ -2569,10 +2768,17 @@ ExecHashRemoveNextSkewBucket(HashJoinTable hashtable)
 		}
 		else
 		{
+			BufFile **batchFile;
+
 			/* Put the tuple into a temp file for later batches */
 			Assert(batchno > hashtable->curbatch);
+
+			batchFile = ExecHashGetBatchFile(hashtable, batchno,
+											 hashtable->innerBatchFile,
+											 hashtable->innerOverflowFiles);
+
 			ExecHashJoinSaveTuple(tuple, hashvalue,
-								  &hashtable->innerBatchFile[batchno],
+								  batchFile,
 								  hashtable);
 			pfree(hashTuple);
 			hashtable->spaceUsed -= tupleSize;
@@ -2750,6 +2956,8 @@ ExecHashAccumInstrumentation(HashInstrumentation *instrument,
 							 hashtable->nbatch);
 	instrument->nbatch_original = Max(instrument->nbatch_original,
 									  hashtable->nbatch_original);
+	instrument->nbatch_inmemory = Max(instrument->nbatch_inmemory,
+									  Min(hashtable->nbatch, hashtable->nbatch_inmemory));
 	instrument->space_peak = Max(instrument->space_peak,
 								 hashtable->spacePeak);
 }
diff --git a/src/backend/executor/nodeHashjoin.c b/src/backend/executor/nodeHashjoin.c
index ea0045bc0f3..580b5da93bd 100644
--- a/src/backend/executor/nodeHashjoin.c
+++ b/src/backend/executor/nodeHashjoin.c
@@ -481,10 +481,15 @@ ExecHashJoinImpl(PlanState *pstate, bool parallel)
 				if (batchno != hashtable->curbatch &&
 					node->hj_CurSkewBucketNo == INVALID_SKEW_BUCKET_NO)
 				{
+					BufFile	  **batchFile;
 					bool		shouldFree;
 					MinimalTuple mintuple = ExecFetchSlotMinimalTuple(outerTupleSlot,
 																	  &shouldFree);
 
+					batchFile = ExecHashGetBatchFile(hashtable, batchno,
+													 hashtable->outerBatchFile,
+													 hashtable->outerOverflowFiles);
+
 					/*
 					 * Need to postpone this outer tuple to a later batch.
 					 * Save it in the corresponding outer-batch file.
@@ -492,7 +497,7 @@ ExecHashJoinImpl(PlanState *pstate, bool parallel)
 					Assert(parallel_state == NULL);
 					Assert(batchno > hashtable->curbatch);
 					ExecHashJoinSaveTuple(mintuple, hashvalue,
-										  &hashtable->outerBatchFile[batchno],
+										  batchFile,
 										  hashtable);
 
 					if (shouldFree)
@@ -1030,17 +1035,19 @@ ExecHashJoinOuterGetTuple(PlanState *outerNode,
 	}
 	else if (curbatch < hashtable->nbatch)
 	{
-		BufFile    *file = hashtable->outerBatchFile[curbatch];
+		BufFile    **file = ExecHashGetBatchFile(hashtable, curbatch,
+												 hashtable->outerBatchFile,
+												 hashtable->outerOverflowFiles);
 
 		/*
 		 * In outer-join cases, we could get here even though the batch file
 		 * is empty.
 		 */
-		if (file == NULL)
+		if (*file == NULL)
 			return NULL;
 
 		slot = ExecHashJoinGetSavedTuple(hjstate,
-										 file,
+										 *file,
 										 hashvalue,
 										 hjstate->hj_OuterTupleSlot);
 		if (!TupIsNull(slot))
@@ -1135,9 +1142,18 @@ ExecHashJoinNewBatch(HashJoinState *hjstate)
 	BufFile    *innerFile;
 	TupleTableSlot *slot;
 	uint32		hashvalue;
+	int			batchidx;
+	int			curbatch_old;
 
 	nbatch = hashtable->nbatch;
 	curbatch = hashtable->curbatch;
+	curbatch_old = curbatch;
+
+	/* index of the old batch */
+	batchidx = ExecHashGetBatchIndex(hashtable, curbatch);
+
+	/* has to be in the current slice of batches */
+	Assert(batchidx >= 0 && batchidx < hashtable->nbatch_inmemory);
 
 	if (curbatch > 0)
 	{
@@ -1145,9 +1161,9 @@ ExecHashJoinNewBatch(HashJoinState *hjstate)
 		 * We no longer need the previous outer batch file; close it right
 		 * away to free disk space.
 		 */
-		if (hashtable->outerBatchFile[curbatch])
-			BufFileClose(hashtable->outerBatchFile[curbatch]);
-		hashtable->outerBatchFile[curbatch] = NULL;
+		if (hashtable->outerBatchFile[batchidx])
+			BufFileClose(hashtable->outerBatchFile[batchidx]);
+		hashtable->outerBatchFile[batchidx] = NULL;
 	}
 	else						/* we just finished the first batch */
 	{
@@ -1182,45 +1198,50 @@ ExecHashJoinNewBatch(HashJoinState *hjstate)
 	 * scan, we have to rescan outer batches in case they contain tuples that
 	 * need to be reassigned.
 	 */
-	curbatch++;
+	curbatch = ExecHashSwitchToNextBatch(hashtable);
+	batchidx = ExecHashGetBatchIndex(hashtable, curbatch);
+
 	while (curbatch < nbatch &&
-		   (hashtable->outerBatchFile[curbatch] == NULL ||
-			hashtable->innerBatchFile[curbatch] == NULL))
+		   (hashtable->outerBatchFile[batchidx] == NULL ||
+			hashtable->innerBatchFile[batchidx] == NULL))
 	{
-		if (hashtable->outerBatchFile[curbatch] &&
+		if (hashtable->outerBatchFile[batchidx] &&
 			HJ_FILL_OUTER(hjstate))
 			break;				/* must process due to rule 1 */
-		if (hashtable->innerBatchFile[curbatch] &&
+		if (hashtable->innerBatchFile[batchidx] &&
 			HJ_FILL_INNER(hjstate))
 			break;				/* must process due to rule 1 */
-		if (hashtable->innerBatchFile[curbatch] &&
+		if (hashtable->innerBatchFile[batchidx] &&
 			nbatch != hashtable->nbatch_original)
 			break;				/* must process due to rule 2 */
-		if (hashtable->outerBatchFile[curbatch] &&
+		if (hashtable->outerBatchFile[batchidx] &&
 			nbatch != hashtable->nbatch_outstart)
 			break;				/* must process due to rule 3 */
 		/* We can ignore this batch. */
 		/* Release associated temp files right away. */
-		if (hashtable->innerBatchFile[curbatch])
-			BufFileClose(hashtable->innerBatchFile[curbatch]);
-		hashtable->innerBatchFile[curbatch] = NULL;
-		if (hashtable->outerBatchFile[curbatch])
-			BufFileClose(hashtable->outerBatchFile[curbatch]);
-		hashtable->outerBatchFile[curbatch] = NULL;
-		curbatch++;
+		if (hashtable->innerBatchFile[batchidx])
+			BufFileClose(hashtable->innerBatchFile[batchidx]);
+		hashtable->innerBatchFile[batchidx] = NULL;
+		if (hashtable->outerBatchFile[batchidx])
+			BufFileClose(hashtable->outerBatchFile[batchidx]);
+		hashtable->outerBatchFile[batchidx] = NULL;
+
+		curbatch = ExecHashSwitchToNextBatch(hashtable);
+		batchidx = ExecHashGetBatchIndex(hashtable, curbatch);
 	}
 
 	if (curbatch >= nbatch)
+	{
+		hashtable->curbatch = curbatch_old;
 		return false;			/* no more batches */
-
-	hashtable->curbatch = curbatch;
+	}
 
 	/*
 	 * Reload the hash table with the new inner batch (which could be empty)
 	 */
 	ExecHashTableReset(hashtable);
 
-	innerFile = hashtable->innerBatchFile[curbatch];
+	innerFile = hashtable->innerBatchFile[batchidx];
 
 	if (innerFile != NULL)
 	{
@@ -1246,15 +1267,15 @@ ExecHashJoinNewBatch(HashJoinState *hjstate)
 		 * needed
 		 */
 		BufFileClose(innerFile);
-		hashtable->innerBatchFile[curbatch] = NULL;
+		hashtable->innerBatchFile[batchidx] = NULL;
 	}
 
 	/*
 	 * Rewind outer batch file (if present), so that we can start reading it.
 	 */
-	if (hashtable->outerBatchFile[curbatch] != NULL)
+	if (hashtable->outerBatchFile[batchidx] != NULL)
 	{
-		if (BufFileSeek(hashtable->outerBatchFile[curbatch], 0, 0, SEEK_SET))
+		if (BufFileSeek(hashtable->outerBatchFile[batchidx], 0, 0, SEEK_SET))
 			ereport(ERROR,
 					(errcode_for_file_access(),
 					 errmsg("could not rewind hash-join temporary file")));
diff --git a/src/backend/optimizer/path/costsize.c b/src/backend/optimizer/path/costsize.c
index c36687aa4df..6f40600d10b 100644
--- a/src/backend/optimizer/path/costsize.c
+++ b/src/backend/optimizer/path/costsize.c
@@ -4173,6 +4173,7 @@ initial_cost_hashjoin(PlannerInfo *root, JoinCostWorkspace *workspace,
 	int			num_hashclauses = list_length(hashclauses);
 	int			numbuckets;
 	int			numbatches;
+	int			numbatches_inmemory;
 	int			num_skew_mcvs;
 	size_t		space_allowed;	/* unused */
 
@@ -4227,6 +4228,7 @@ initial_cost_hashjoin(PlannerInfo *root, JoinCostWorkspace *workspace,
 							&space_allowed,
 							&numbuckets,
 							&numbatches,
+							&numbatches_inmemory,
 							&num_skew_mcvs);
 
 	/*
diff --git a/src/include/executor/hashjoin.h b/src/include/executor/hashjoin.h
index 2d8ed8688cd..76c983dbd4d 100644
--- a/src/include/executor/hashjoin.h
+++ b/src/include/executor/hashjoin.h
@@ -320,6 +320,7 @@ typedef struct HashJoinTableData
 	int		   *skewBucketNums; /* array indexes of active skew buckets */
 
 	int			nbatch;			/* number of batches */
+	int			nbatch_inmemory;	/* max number of in-memory batches */
 	int			curbatch;		/* current batch #; 0 during 1st pass */
 
 	int			nbatch_original;	/* nbatch when we started inner scan */
@@ -331,6 +332,9 @@ typedef struct HashJoinTableData
 	double		partialTuples;	/* # tuples obtained from inner plan by me */
 	double		skewTuples;		/* # tuples inserted into skew tuples */
 
+	BufFile	  **innerOverflowFiles;	/* temp file for inner overflow batches */
+	BufFile	  **outerOverflowFiles;	/* temp file for outer overflow batches */
+
 	/*
 	 * These arrays are allocated for the life of the hash join, but only if
 	 * nbatch > 1.  A file is opened only when we first write a tuple into it
diff --git a/src/include/executor/nodeHash.h b/src/include/executor/nodeHash.h
index e4eb7bc6359..ef1d2bec15a 100644
--- a/src/include/executor/nodeHash.h
+++ b/src/include/executor/nodeHash.h
@@ -16,6 +16,7 @@
 
 #include "access/parallel.h"
 #include "nodes/execnodes.h"
+#include "storage/buffile.h"
 
 struct SharedHashJoinBatch;
 
@@ -46,6 +47,12 @@ extern void ExecHashGetBucketAndBatch(HashJoinTable hashtable,
 									  uint32 hashvalue,
 									  int *bucketno,
 									  int *batchno);
+extern int ExecHashGetBatchIndex(HashJoinTable hashtable, int batchno);
+extern BufFile **ExecHashGetBatchFile(HashJoinTable hashtable, int batchno,
+									  BufFile **batchFiles,
+									  BufFile **overflowFiles);
+extern void ExecHashSwitchToNextBatchSlice(HashJoinTable hashtable);
+extern int ExecHashSwitchToNextBatch(HashJoinTable hashtable);
 extern bool ExecScanHashBucket(HashJoinState *hjstate, ExprContext *econtext);
 extern bool ExecParallelScanHashBucket(HashJoinState *hjstate, ExprContext *econtext);
 extern void ExecPrepHashTableForUnmatched(HashJoinState *hjstate);
@@ -62,6 +69,7 @@ extern void ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 									size_t *space_allowed,
 									int *numbuckets,
 									int *numbatches,
+									int *numbatches_inmemory,
 									int *num_skew_mcvs);
 extern int	ExecHashGetSkewBucket(HashJoinTable hashtable, uint32 hashvalue);
 extern void ExecHashEstimate(HashState *node, ParallelContext *pcxt);
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index 1590b643920..eb8eaea89ab 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -2755,6 +2755,7 @@ typedef struct HashInstrumentation
 	int			nbuckets_original;	/* planned number of buckets */
 	int			nbatch;			/* number of batches at end of execution */
 	int			nbatch_original;	/* planned number of batches */
+	int			nbatch_inmemory;	/* number of batches kept in memory */
 	Size		space_peak;		/* peak memory usage in bytes */
 } HashInstrumentation;
 
-- 
2.47.1

