diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 2a6f44a6274..1ba62879cf6 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -1637,16 +1637,52 @@ find_hash_columns(AggState *aggstate)
 
 /*
  * Estimate per-hash-table-entry overhead.
+ *
+ * It's important to account for the overhead of individual memory
+ * allocations, which can be significant. If the caller specifies a memory
+ * context, use that to estimate the total allocation sizes; otherwise, use
+ * CurrentMemoryContext.
+ *
+ * numTrans is the number of per-group transition states, tupleWidth the
+ * estimated data width of the tuple stored in each entry, and
+ * transitionSpace any additional per-group transition space.  Each nonzero
+ * component is charged as a single allocation in the chosen context.
  */
 Size
-hash_agg_entry_size(int numAggs, Size tupleWidth, Size transitionSpace)
+hash_agg_entry_size(int numTrans, Size tupleWidth, Size transitionSpace,
+					MemoryContext context)
 {
+	Size		tupleChunkSize;
+	Size		pergroupChunkSize;
+	Size		transitionChunkSize;
+	Size		tupleSize = (MAXALIGN(SizeofMinimalTupleHeader) +
+						 tupleWidth);
+	Size		pergroupSize = numTrans * sizeof(AggStatePerGroupData);
+
+	Assert(numTrans >= 0);
+
+	if (context == NULL)
+		context = CurrentMemoryContext;
+
+	tupleChunkSize = EstimateMemoryChunkSpace(context, tupleSize);
+
+	/* components of size zero are not charged at all */
+	if (pergroupSize > 0)
+		pergroupChunkSize = EstimateMemoryChunkSpace(context, pergroupSize);
+	else
+		pergroupChunkSize = 0;
+
+	if (transitionSpace > 0)
+		transitionChunkSize = EstimateMemoryChunkSpace(context,
+													   transitionSpace);
+	else
+		transitionChunkSize = 0;
+
 	return
-		MAXALIGN(SizeofMinimalTupleHeader) +
-		MAXALIGN(tupleWidth) +
-		MAXALIGN(sizeof(TupleHashEntryData) +
-				 numAggs * sizeof(AggStatePerGroupData)) +
-		transitionSpace;
+		sizeof(TupleHashEntryData) +
+		tupleChunkSize +
+		pergroupChunkSize +
+		transitionChunkSize;
 }
 
 /*
@@ -3549,7 +3585,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 		aggstate->hash_pergroup = pergroups;
 
 		aggstate->hashentrysize = hash_agg_entry_size(
-			aggstate->numtrans, outerplan->plan_width, node->transitionSpace);
+			aggstate->numtrans, outerplan->plan_width, node->transitionSpace,
+			aggstate->hashcontext->ecxt_per_tuple_memory);
 
 		/*
 		 * Consider all of the grouping sets together when setting the limits
diff --git a/src/backend/optimizer/path/costsize.c b/src/backend/optimizer/path/costsize.c
index 8cf694b61dc..adc191704c2 100644
--- a/src/backend/optimizer/path/costsize.c
+++ b/src/backend/optimizer/path/costsize.c
@@ -2272,7 +2272,8 @@ cost_agg(Path *path, PlannerInfo *root,
 		 * otherwise we expect to spill.
 		 */
 		hashentrysize = hash_agg_entry_size(
-			aggcosts->numAggs, input_width, aggcosts->transitionSpace);
+			aggcosts->numAggs, input_width, aggcosts->transitionSpace,
+			NULL);			/* estimate using CurrentMemoryContext */
 		hash_agg_set_limits(hashentrysize, numGroups, 0, &mem_limit,
 							&ngroups_limit, &num_partitions);
 
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index b65abf6046d..5c423a3d14c 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -4866,7 +4866,8 @@ create_distinct_paths(PlannerInfo *root,
 	else
 	{
 		Size		hashentrysize = hash_agg_entry_size(
-			0, cheapest_input_path->pathtarget->width, 0);
+			0, cheapest_input_path->pathtarget->width, 0,
+			NULL);			/* estimate using CurrentMemoryContext */
 
 		allow_hash = enable_hashagg_disk ||
 			(hashentrysize * numDistinctRows <= work_mem * 1024L);
diff --git a/src/backend/utils/adt/selfuncs.c b/src/backend/utils/adt/selfuncs.c
index 8339f4cd7a2..cc5788fc954 100644
--- a/src/backend/utils/adt/selfuncs.c
+++ b/src/backend/utils/adt/selfuncs.c
@@ -3527,8 +3527,10 @@ double
 estimate_hashagg_tablesize(Path *path, const AggClauseCosts *agg_costs,
 						   double dNumGroups)
 {
-	Size		hashentrysize = hash_agg_entry_size(
-		agg_costs->numAggs, path->pathtarget->width, agg_costs->transitionSpace);
+	Size		hashentrysize = hash_agg_entry_size(agg_costs->numAggs,
+												path->pathtarget->width,
+												agg_costs->transitionSpace,
+												NULL);
 
 	/*
 	 * Note that this disregards the effect of fill-factor and growth policy
diff --git a/src/backend/utils/mmgr/aset.c b/src/backend/utils/mmgr/aset.c
index c0623f106d2..107afce953b 100644
--- a/src/backend/utils/mmgr/aset.c
+++ b/src/backend/utils/mmgr/aset.c
@@ -272,6 +272,7 @@ static void *AllocSetRealloc(MemoryContext context, void *pointer, Size size);
 static void AllocSetReset(MemoryContext context);
 static void AllocSetDelete(MemoryContext context);
 static Size AllocSetGetChunkSpace(MemoryContext context, void *pointer);
+static Size AllocSetEstimateChunkSpace(MemoryContext context, Size size);
 static bool AllocSetIsEmpty(MemoryContext context);
 static void AllocSetStats(MemoryContext context,
 						  MemoryStatsPrintFunc printfunc, void *passthru,
@@ -291,6 +292,7 @@ static const MemoryContextMethods AllocSetMethods = {
 	AllocSetReset,
 	AllocSetDelete,
 	AllocSetGetChunkSpace,
+	AllocSetEstimateChunkSpace,
 	AllocSetIsEmpty,
 	AllocSetStats
 #ifdef MEMORY_CONTEXT_CHECKING
@@ -1337,6 +1339,32 @@ AllocSetGetChunkSpace(MemoryContext context, void *pointer)
 	return result;
 }
 
+/*
+ * AllocSetEstimateChunkSpace
+ *		Estimate total memory consumed for a chunk of the requested size.
+ *
+ * Requests above allocChunkLimit are charged their MAXALIGN'd size plus a
+ * block header and a chunk header.  Smaller requests are charged the
+ * power-of-2 chunk size selected via AllocSetFreeIndex plus a chunk header.
+ */
+static Size
+AllocSetEstimateChunkSpace(MemoryContext context, Size size)
+{
+	AllocSet	set = (AllocSet) context;
+	Size		chunk_size;
+
+	if (size <= set->allocChunkLimit)
+	{
+		/* small request: power-of-2 size class from AllocSetFreeIndex */
+		chunk_size = (1 << ALLOC_MINBITS) << AllocSetFreeIndex(size);
+		return chunk_size + ALLOC_CHUNKHDRSZ;
+	}
+
+	/* large request: also pay for a block header */
+	chunk_size = MAXALIGN(size);
+	return chunk_size + ALLOC_BLOCKHDRSZ + ALLOC_CHUNKHDRSZ;
+}
+
 /*
  * AllocSetIsEmpty
  *		Is an allocset empty of any allocated space?
diff --git a/src/backend/utils/mmgr/generation.c b/src/backend/utils/mmgr/generation.c
index 56651d06931..b7257bf351c 100644
--- a/src/backend/utils/mmgr/generation.c
+++ b/src/backend/utils/mmgr/generation.c
@@ -152,6 +152,7 @@ static void *GenerationRealloc(MemoryContext context, void *pointer, Size size);
 static void GenerationReset(MemoryContext context);
 static void GenerationDelete(MemoryContext context);
 static Size GenerationGetChunkSpace(MemoryContext context, void *pointer);
+static Size GenerationEstimateChunkSpace(MemoryContext context, Size size);
 static bool GenerationIsEmpty(MemoryContext context);
 static void GenerationStats(MemoryContext context,
 							MemoryStatsPrintFunc printfunc, void *passthru,
@@ -171,6 +172,7 @@ static const MemoryContextMethods GenerationMethods = {
 	GenerationReset,
 	GenerationDelete,
 	GenerationGetChunkSpace,
+	GenerationEstimateChunkSpace,
 	GenerationIsEmpty,
 	GenerationStats
 #ifdef MEMORY_CONTEXT_CHECKING
@@ -666,6 +668,27 @@ GenerationGetChunkSpace(MemoryContext context, void *pointer)
 	return result;
 }
 
+/*
+ * GenerationEstimateChunkSpace
+ *		Estimate total memory consumed for a chunk of the requested size.
+ *
+ * Every chunk is charged its MAXALIGN'd size plus a chunk header; chunks
+ * larger than blockSize / 8 are additionally charged a block header.
+ */
+static Size
+GenerationEstimateChunkSpace(MemoryContext context, Size size)
+{
+	GenerationContext *set = (GenerationContext *) context;
+	Size		chunk_size = MAXALIGN(size);
+	Size		overhead = Generation_CHUNKHDRSZ;
+
+	/* over-sized chunk will allocate special block */
+	if (chunk_size > set->blockSize / 8)
+		overhead += Generation_BLOCKHDRSZ;
+
+	return chunk_size + overhead;
+}
+
 /*
  * GenerationIsEmpty
  *		Is a GenerationContext empty of any allocated space?
diff --git a/src/backend/utils/mmgr/mcxt.c b/src/backend/utils/mmgr/mcxt.c
index 9e24fec72d6..db7a758db53 100644
--- a/src/backend/utils/mmgr/mcxt.c
+++ b/src/backend/utils/mmgr/mcxt.c
@@ -431,6 +431,22 @@ GetMemoryChunkSpace(void *pointer)
 	return context->methods->get_chunk_space(context, pointer);
 }
 
+/*
+ * EstimateMemoryChunkSpace
+ *		Estimate the total space a chunk of the requested size would consume
+ *		in the given context, including the context type's chunk overhead.
+ *
+ * This wrapper allocates nothing itself; it just dispatches to the
+ * context's estimate_chunk_space method.
+ */
+Size
+EstimateMemoryChunkSpace(MemoryContext context, Size size)
+{
+	Assert(MemoryContextIsValid(context));
+
+	return context->methods->estimate_chunk_space(context, size);
+}
+
 /*
  * MemoryContextGetParent
  *		Get the parent context (if any) of the specified context
diff --git a/src/backend/utils/mmgr/slab.c b/src/backend/utils/mmgr/slab.c
index c928476c479..d9e40b7b17b 100644
--- a/src/backend/utils/mmgr/slab.c
+++ b/src/backend/utils/mmgr/slab.c
@@ -132,6 +132,7 @@ static void *SlabRealloc(MemoryContext context, void *pointer, Size size);
 static void SlabReset(MemoryContext context);
 static void SlabDelete(MemoryContext context);
 static Size SlabGetChunkSpace(MemoryContext context, void *pointer);
+static Size SlabEstimateChunkSpace(MemoryContext context, Size size);
 static bool SlabIsEmpty(MemoryContext context);
 static void SlabStats(MemoryContext context,
 					  MemoryStatsPrintFunc printfunc, void *passthru,
@@ -150,6 +151,7 @@ static const MemoryContextMethods SlabMethods = {
 	SlabReset,
 	SlabDelete,
 	SlabGetChunkSpace,
+	SlabEstimateChunkSpace,
 	SlabIsEmpty,
 	SlabStats
 #ifdef MEMORY_CONTEXT_CHECKING
@@ -630,6 +632,26 @@ SlabGetChunkSpace(MemoryContext context, void *pointer)
 	return slab->fullChunkSize;
 }
 
+/*
+ * SlabEstimateChunkSpace
+ *		Estimate total memory consumed for a chunk of the requested size.
+ *
+ * Every chunk of a slab context is charged the same fullChunkSize, so the
+ * requested size is not consulted.  NOTE(review): size is not validated
+ * against the slab's configured chunk size.
+ */
+static Size
+SlabEstimateChunkSpace(MemoryContext context, Size size)
+{
+	SlabContext *slab = castNode(SlabContext, context);
+	Size		chunk_space;
+
+	Assert(slab);
+	chunk_space = slab->fullChunkSize;
+
+	return chunk_space;
+}
+
 /*
  * SlabIsEmpty
  *		Is an Slab empty of any allocated space?
diff --git a/src/include/executor/nodeAgg.h b/src/include/executor/nodeAgg.h
index a5b8a004d1e..29a6bd947ed 100644
--- a/src/include/executor/nodeAgg.h
+++ b/src/include/executor/nodeAgg.h
@@ -315,7 +315,7 @@ extern void ExecEndAgg(AggState *node);
 extern void ExecReScanAgg(AggState *node);
 
-extern Size hash_agg_entry_size(int numAggs, Size tupleWidth,
-								Size transitionSpace);
+extern Size hash_agg_entry_size(int numTrans, Size tupleWidth,
+								Size transitionSpace, MemoryContext context);
 extern void hash_agg_set_limits(double hashentrysize, uint64 input_groups,
 								int used_bits, Size *mem_limit,
 								uint64 *ngroups_limit, int *num_partitions);
diff --git a/src/include/nodes/memnodes.h b/src/include/nodes/memnodes.h
index c9f2bbcb367..896db9b54d8 100644
--- a/src/include/nodes/memnodes.h
+++ b/src/include/nodes/memnodes.h
@@ -63,6 +63,8 @@ typedef struct MemoryContextMethods
 	void		(*reset) (MemoryContext context);
 	void		(*delete_context) (MemoryContext context);
 	Size		(*get_chunk_space) (MemoryContext context, void *pointer);
+	/* space a chunk of the given size would occupy, including overhead */
+	Size		(*estimate_chunk_space) (MemoryContext context, Size size);
 	bool		(*is_empty) (MemoryContext context);
 	void		(*stats) (MemoryContext context,
 						  MemoryStatsPrintFunc printfunc, void *passthru,
diff --git a/src/include/utils/memutils.h b/src/include/utils/memutils.h
index 909bc2e9888..e3571efaaa7 100644
--- a/src/include/utils/memutils.h
+++ b/src/include/utils/memutils.h
@@ -80,6 +80,8 @@ extern void MemoryContextSetIdentifier(MemoryContext context, const char *id);
 extern void MemoryContextSetParent(MemoryContext context,
 								   MemoryContext new_parent);
 extern Size GetMemoryChunkSpace(void *pointer);
+/* estimate the space a chunk of the given size would use in the context */
+extern Size EstimateMemoryChunkSpace(MemoryContext context, Size size);
 extern MemoryContext MemoryContextGetParent(MemoryContext context);
 extern bool MemoryContextIsEmpty(MemoryContext context);
 extern Size MemoryContextMemAllocated(MemoryContext context, bool recurse);
