From 37b5414b74c873b787505c5f7d42b8b22cded3de Mon Sep 17 00:00:00 2001
From: Jeff Davis <jeff@j-davis.com>
Date: Thu, 14 Aug 2025 12:56:29 -0700
Subject: [PATCH v1] Use per-PlanState memory context for work mem.

Allocate working memory in a per-PlanState memory context (or
subcontext), making it easier to consistently track the total memory
used by a node. Does not change the enforcement mechanism.
---
 src/backend/executor/execProcnode.c        |   6 +
 src/backend/executor/execUtils.c           | 257 ++++++++++++++++++---
 src/backend/executor/nodeAgg.c             |  58 ++---
 src/backend/executor/nodeHash.c            |   6 +-
 src/backend/executor/nodeIncrementalSort.c |  13 ++
 src/backend/executor/nodeMaterial.c        |   5 +
 src/backend/executor/nodeSort.c            |   9 +
 src/include/executor/executor.h            |  28 ++-
 src/include/nodes/execnodes.h              |   3 +-
 9 files changed, 311 insertions(+), 74 deletions(-)

diff --git a/src/backend/executor/execProcnode.c b/src/backend/executor/execProcnode.c
index f5f9cfbeead..821dc087eba 100644
--- a/src/backend/executor/execProcnode.c
+++ b/src/backend/executor/execProcnode.c
@@ -388,6 +388,9 @@ ExecInitNode(Plan *node, EState *estate, int eflags)
 			break;
 	}
 
+	if (GetWorkMemLimit(result) == 0)
+		SetWorkMemLimit(result, work_mem * (Size) 1024);
+
 	ExecSetExecProcNode(result, result->ExecProcNode);
 
 	/*
@@ -580,6 +583,9 @@ ExecEndNode(PlanState *node)
 		node->chgParam = NULL;
 	}
 
+	if (node->ps_WorkMem != NULL)	/* XXX dumps stats to stderr for every node at shutdown -- debugging leftover? */
+		MemoryContextStats(node->ps_WorkMem);
+
 	switch (nodeTag(node))
 	{
 			/*
diff --git a/src/backend/executor/execUtils.c b/src/backend/executor/execUtils.c
index fdc65c2b42b..71727da7c84 100644
--- a/src/backend/executor/execUtils.c
+++ b/src/backend/executor/execUtils.c
@@ -229,13 +229,21 @@ FreeExecutorState(EState *estate)
 	MemoryContextDelete(estate->es_query_cxt);
 }
 
-/*
- * Internal implementation for CreateExprContext() and CreateWorkExprContext()
- * that allows control over the AllocSet parameters.
+/* ----------------
+ *		CreateExprContext
+ *
+ *		Create a context for expression evaluation within an EState.
+ *
+ * An executor run may require multiple ExprContexts (we usually make one
+ * for each Plan node, and a separate one for per-output-tuple processing
+ * such as constraint checking).  Each ExprContext has its own "per-tuple"
+ * memory context.
+ *
+ * Note we make no assumption about the caller's memory context.
+ * ----------------
  */
-static ExprContext *
-CreateExprContextInternal(EState *estate, Size minContextSize,
-						  Size initBlockSize, Size maxBlockSize)
+ExprContext *
+CreateExprContext(EState *estate)
 {
 	ExprContext *econtext;
 	MemoryContext oldcontext;
@@ -258,9 +266,7 @@ CreateExprContextInternal(EState *estate, Size minContextSize,
 	econtext->ecxt_per_tuple_memory =
 		AllocSetContextCreate(estate->es_query_cxt,
 							  "ExprContext",
-							  minContextSize,
-							  initBlockSize,
-							  maxBlockSize);
+							  ALLOCSET_DEFAULT_SIZES);
 
 	econtext->ecxt_param_exec_vals = estate->es_param_exec_vals;
 	econtext->ecxt_param_list_info = estate->es_param_list_info;
@@ -290,49 +296,228 @@ CreateExprContextInternal(EState *estate, Size minContextSize,
 	return econtext;
 }
 
-/* ----------------
- *		CreateExprContext
- *
- *		Create a context for expression evaluation within an EState.
- *
- * An executor run may require multiple ExprContexts (we usually make one
- * for each Plan node, and a separate one for per-output-tuple processing
- * such as constraint checking).  Each ExprContext has its own "per-tuple"
- * memory context.
- *
- * Note we make no assumption about the caller's memory context.
- * ----------------
+/*
+ * Based on the working memory limit, compute a reasonable maximum block size
+ * for a memory context, so that a small allocation that requires a new block
+ * does not dramatically overshoot the limit.  NOTE(review): unlike the old
+ * CreateWorkExprContext(), no ALLOCSET_DEFAULT_INITSIZE floor is applied here.
  */
-ExprContext *
-CreateExprContext(EState *estate)
+static size_t
+WorkMemClampBlockSize(PlanState *ps, size_t size)
+{
+	size_t maxBlockSize = pg_prevpower2_size_t(ps->ps_WorkMemLimit / 16);
+
+	return Min(maxBlockSize, size);
+}
+
+static MemoryContext
+ValidateWorkMemParent(PlanState *ps, MemoryContext parent)
+{
+	if (parent != NULL)
+	{
+#ifdef USE_ASSERT_CHECKING
+		MemoryContext cur = parent;
+		MemoryContext workmem = GetWorkMem(ps);
+		while (cur != NULL && cur != workmem)
+			cur = MemoryContextGetParent(cur);
+		/* parent must be a subcontext of the working memory context */
+		Assert(cur == workmem);
+#endif
+		return parent;
+	}
+	else
+		return GetWorkMem(ps);
+}
+
+/*
+ * Get the limit for working memory in bytes, as previously established by
+ * SetWorkMemLimit().
+ */
+size_t
+GetWorkMemLimit(PlanState *ps)
+{
+	return ps->ps_WorkMemLimit;
+}
+
+/*
+ * Set limit for working memory in bytes.  Caller must use CheckWorkMemLimit()
+ * to test and enforce the limit.
+ */
+void
+SetWorkMemLimit(PlanState *ps, size_t limit)
+{
+	ps->ps_WorkMemLimit = limit;
+}
+
+/*
+ * Return false if working memory limit has been exceeded.
+ */
+bool
+CheckWorkMemLimit(PlanState *ps)
+{
+	size_t allocated = MemoryContextMemAllocated(GetWorkMem(ps), true);
+	return allocated <= ps->ps_WorkMemLimit;
+}
+
+MemoryContext
+GetWorkMem(PlanState *ps)
+{
+	if (ps->ps_WorkMem == NULL)
+	{
+		const char *name;
+
+		switch (nodeTag(ps))
+		{
+			case T_AggState:
+				name = "Aggregate WorkMem";
+				break;
+			case T_HashJoinState:
+				name = "HashJoin WorkMem";
+				break;
+			case T_HashState:
+				name = "Hash WorkMem";
+				break;
+			case T_SortState:
+				name = "Sort WorkMem";
+				break;
+			default:
+				name = "WorkMem";
+				break;
+		}
+
+		/*
+		 * We are sure that "name" points to a compile-time constant, so we
+		 * can call the internal version.
+		 */
+		ps->ps_WorkMem = AllocSetContextCreateInternal(
+			ps->state->es_query_cxt,
+			name,
+			WorkMemClampBlockSize(ps, ALLOCSET_DEFAULT_MINSIZE),
+			WorkMemClampBlockSize(ps, ALLOCSET_DEFAULT_INITSIZE),
+			WorkMemClampBlockSize(ps, ALLOCSET_DEFAULT_MAXSIZE));
+	}
+
+	return ps->ps_WorkMem;
+}
+
+MemoryContext
+CreateWorkMemAllocSet_Internal(PlanState *ps, MemoryContext parent, const char *name,
+							   size_t minContextSize, size_t initBlockSize,
+							   size_t maxBlockSize)
 {
-	return CreateExprContextInternal(estate, ALLOCSET_DEFAULT_SIZES);
+	MemoryContext context;
+
+	parent = ValidateWorkMemParent(ps, parent);
+
+	/*
+	 * CreateWorkMemAllocSet() macro ensures that "name" points to a
+	 * compile-time constant, so we can call the internal version.
+	 */
+	context = AllocSetContextCreateInternal(parent,
+											name,
+											WorkMemClampBlockSize(ps, minContextSize),
+											WorkMemClampBlockSize(ps, initBlockSize),
+											WorkMemClampBlockSize(ps, maxBlockSize));
+
+	return context;
 }
 
+MemoryContext
+CreateWorkMemBump(PlanState *ps, MemoryContext parent, const char *name,
+				  size_t minContextSize, size_t initBlockSize,
+				  size_t maxBlockSize)
+{
+	MemoryContext context;
+
+	parent = ValidateWorkMemParent(ps, parent);
+
+	/*
+	 * Unlike CreateWorkMemAllocSet(), no wrapper macro checks "name" here;
+	 * callers must pass a compile-time constant string.
+	 */
+	context = BumpContextCreate(parent,
+								name,
+								WorkMemClampBlockSize(ps, minContextSize),
+								WorkMemClampBlockSize(ps, initBlockSize),
+								WorkMemClampBlockSize(ps, maxBlockSize));
+
+	return context;
+}
+
+/*
+ * Release all working memory including subcontexts.
+ */
+void
+DestroyWorkMem(PlanState *ps)
+{
+	if (ps->ps_WorkMem != NULL)
+	{
+		MemoryContextDelete(ps->ps_WorkMem);
+		ps->ps_WorkMem = NULL;
+	}
+}
 
 /* ----------------
  *		CreateWorkExprContext
  *
- * Like CreateExprContext, but specifies the AllocSet sizes to be reasonable
- * in proportion to work_mem. If the maximum block allocation size is too
- * large, it's easy to skip right past work_mem with a single allocation.
+ * Like CreateExprContext, but intended for operations where work_mem should
+ * be enforced for ecxt_per_tuple_memory.  The caller is responsible for
+ * calling ReScanExprContext() when necessary.
+ *
  * ----------------
  */
 ExprContext *
-CreateWorkExprContext(EState *estate)
+CreateWorkExprContext(PlanState *ps)
 {
-	Size		maxBlockSize = ALLOCSET_DEFAULT_MAXSIZE;
+	ExprContext		*econtext;
+	EState			*estate = ps->state;
+	MemoryContext	 oldcontext;
 
-	maxBlockSize = pg_prevpower2_size_t(work_mem * (Size) 1024 / 16);
+	/* Create the ExprContext node within the per-query memory context */
+	oldcontext = MemoryContextSwitchTo(estate->es_query_cxt);
 
-	/* But no bigger than ALLOCSET_DEFAULT_MAXSIZE */
-	maxBlockSize = Min(maxBlockSize, ALLOCSET_DEFAULT_MAXSIZE);
+	econtext = makeNode(ExprContext);
 
-	/* and no smaller than ALLOCSET_DEFAULT_INITSIZE */
-	maxBlockSize = Max(maxBlockSize, ALLOCSET_DEFAULT_INITSIZE);
+	/* Initialize fields of ExprContext */
+	econtext->ecxt_scantuple = NULL;
+	econtext->ecxt_innertuple = NULL;
+	econtext->ecxt_outertuple = NULL;
 
-	return CreateExprContextInternal(estate, ALLOCSET_DEFAULT_MINSIZE,
-									 ALLOCSET_DEFAULT_INITSIZE, maxBlockSize);
+	econtext->ecxt_per_query_memory = estate->es_query_cxt;
+
+	/*
+	 * Create working memory for expression evaluation in this context.
+	 */
+	econtext->ecxt_per_tuple_memory = CreateWorkMemAllocSet(ps, NULL,
+															"WorkExprContext",
+															ALLOCSET_DEFAULT_SIZES);
+
+	econtext->ecxt_param_exec_vals = estate->es_param_exec_vals;
+	econtext->ecxt_param_list_info = estate->es_param_list_info;
+
+	econtext->ecxt_aggvalues = NULL;
+	econtext->ecxt_aggnulls = NULL;
+
+	econtext->caseValue_datum = (Datum) 0;
+	econtext->caseValue_isNull = true;
+
+	econtext->domainValue_datum = (Datum) 0;
+	econtext->domainValue_isNull = true;
+
+	econtext->ecxt_estate = estate;
+
+	econtext->ecxt_callbacks = NULL;
+
+	/*
+	 * Link the ExprContext into the EState to ensure it is shut down when the
+	 * EState is freed.  Because we use lcons(), shutdowns will occur in
+	 * reverse order of creation, which may not be essential but can't hurt.
+	 */
+	estate->es_exprcontexts = lcons(econtext, estate->es_exprcontexts);
+
+	MemoryContextSwitchTo(oldcontext);
+
+	return econtext;
 }
 
 /* ----------------
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 377e016d732..a8f56e18844 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -522,6 +522,9 @@ initialize_phase(AggState *aggstate, int newphase)
 		Sort	   *sortnode = aggstate->phases[newphase + 1].sortnode;
 		PlanState  *outerNode = outerPlanState(aggstate);
 		TupleDesc	tupDesc = ExecGetResultType(outerNode);
+		MemoryContext oldcontext;
+
+		oldcontext = MemoryContextSwitchTo(GetWorkMem(&aggstate->ss.ps));
 
 		aggstate->sort_out = tuplesort_begin_heap(tupDesc,
 												  sortnode->numCols,
@@ -531,6 +534,8 @@ initialize_phase(AggState *aggstate, int newphase)
 												  sortnode->nullsFirst,
 												  work_mem,
 												  NULL, TUPLESORT_NONE);
+
+		MemoryContextSwitchTo(oldcontext);
 	}
 
 	aggstate->current_phase = newphase;
@@ -585,6 +590,8 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
 	 */
 	if (pertrans->aggsortrequired)
 	{
+		MemoryContext oldcontext;
+
 		/*
 		 * In case of rescan, maybe there could be an uncompleted sort
 		 * operation?  Clean it up if so.
@@ -592,6 +599,7 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
 		if (pertrans->sortstates[aggstate->current_set])
 			tuplesort_end(pertrans->sortstates[aggstate->current_set]);
 
+		oldcontext = MemoryContextSwitchTo(GetWorkMem(&aggstate->ss.ps));
 
 		/*
 		 * We use a plain Datum sorter when there's a single input column;
@@ -618,6 +626,8 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
 									 pertrans->sortCollations,
 									 pertrans->sortNullsFirst,
 									 work_mem, NULL, TUPLESORT_NONE);
+
+		MemoryContextSwitchTo(oldcontext);
 	}
 
 	/*
@@ -1481,7 +1491,7 @@ build_hash_tables(AggState *aggstate)
 
 		Assert(perhash->aggnode->numGroups > 0);
 
-		memory = aggstate->hash_mem_limit / aggstate->num_hashes;
+		memory = GetWorkMemLimit(&aggstate->ss.ps) / aggstate->num_hashes;
 
 		/* choose reasonable number of buckets per hashtable */
 		nbuckets = hash_choose_num_buckets(aggstate->hashentrysize,
@@ -1867,13 +1877,6 @@ static void
 hash_agg_check_limits(AggState *aggstate)
 {
 	uint64		ngroups = aggstate->hash_ngroups_current;
-	Size		meta_mem = MemoryContextMemAllocated(aggstate->hash_metacxt,
-													 true);
-	Size		entry_mem = MemoryContextMemAllocated(aggstate->hash_tablecxt,
-													  true);
-	Size		tval_mem = MemoryContextMemAllocated(aggstate->hashcontext->ecxt_per_tuple_memory,
-													 true);
-	Size		total_mem = meta_mem + entry_mem + tval_mem;
 	bool		do_spill = false;
 
 #ifdef USE_INJECTION_POINTS
@@ -1892,7 +1895,7 @@ hash_agg_check_limits(AggState *aggstate)
 	 * can be sure to make progress even in edge cases.
 	 */
 	if (aggstate->hash_ngroups_current > 0 &&
-		(total_mem > aggstate->hash_mem_limit ||
+		(!CheckWorkMemLimit(&aggstate->ss.ps) ||
 		 ngroups > aggstate->hash_ngroups_limit))
 	{
 		do_spill = true;
@@ -1999,13 +2002,11 @@ hash_agg_update_metrics(AggState *aggstate, bool from_tape, int npartitions)
 static void
 hash_create_memory(AggState *aggstate)
 {
-	Size		maxBlockSize = ALLOCSET_DEFAULT_MAXSIZE;
-
 	/*
 	 * The hashcontext's per-tuple memory will be used for byref transition
 	 * values and returned by AggCheckCallContext().
 	 */
-	aggstate->hashcontext = CreateWorkExprContext(aggstate->ss.ps.state);
+	aggstate->hashcontext = CreateWorkExprContext(&aggstate->ss.ps);
 
 	/*
 	 * The meta context will be used for the bucket array of
@@ -2015,7 +2016,7 @@ hash_create_memory(AggState *aggstate)
 	 * the large allocation path will be used, so it's not worth worrying
 	 * about wasting space due to power-of-two allocations.
 	 */
-	aggstate->hash_metacxt = AllocSetContextCreate(aggstate->ss.ps.state->es_query_cxt,
+	aggstate->hash_metacxt = CreateWorkMemAllocSet(&aggstate->ss.ps, NULL,
 												   "HashAgg meta context",
 												   ALLOCSET_DEFAULT_SIZES);
 
@@ -2030,25 +2031,9 @@ hash_create_memory(AggState *aggstate)
 	 * Like CreateWorkExprContext(), use smaller sizings for smaller work_mem,
 	 * to avoid large jumps in memory usage.
 	 */
-
-	/*
-	 * Like CreateWorkExprContext(), use smaller sizings for smaller work_mem,
-	 * to avoid large jumps in memory usage.
-	 */
-	maxBlockSize = pg_prevpower2_size_t(work_mem * (Size) 1024 / 16);
-
-	/* But no bigger than ALLOCSET_DEFAULT_MAXSIZE */
-	maxBlockSize = Min(maxBlockSize, ALLOCSET_DEFAULT_MAXSIZE);
-
-	/* and no smaller than ALLOCSET_DEFAULT_INITSIZE */
-	maxBlockSize = Max(maxBlockSize, ALLOCSET_DEFAULT_INITSIZE);
-
-	aggstate->hash_tablecxt = BumpContextCreate(aggstate->ss.ps.state->es_query_cxt,
+	aggstate->hash_tablecxt = CreateWorkMemBump(&aggstate->ss.ps, NULL,
 												"HashAgg table context",
-												ALLOCSET_DEFAULT_MINSIZE,
-												ALLOCSET_DEFAULT_INITSIZE,
-												maxBlockSize);
-
+												ALLOCSET_DEFAULT_SIZES);
 }
 
 /*
@@ -2684,6 +2669,7 @@ agg_refill_hash_table(AggState *aggstate)
 	HashAggSpill spill;
 	LogicalTapeSet *tapeset = aggstate->hash_tapeset;
 	bool		spill_initialized = false;
+	size_t		mem_limit;
 
 	if (aggstate->hash_batches == NIL)
 		return false;
@@ -2693,8 +2679,9 @@ agg_refill_hash_table(AggState *aggstate)
 	aggstate->hash_batches = list_delete_last(aggstate->hash_batches);
 
 	hash_agg_set_limits(aggstate->hashentrysize, batch->input_card,
-						batch->used_bits, &aggstate->hash_mem_limit,
+						batch->used_bits, &mem_limit,
 						&aggstate->hash_ngroups_limit, NULL);
+	SetWorkMemLimit(&aggstate->ss.ps, mem_limit);
 
 	/*
 	 * Each batch only processes one grouping set; set the rest to NULL so
@@ -3371,6 +3358,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 	aggstate->aggcontexts = (ExprContext **)
 		palloc0(sizeof(ExprContext *) * numGroupingSets);
 
+	SetWorkMemLimit(&aggstate->ss.ps, get_hash_memory_limit());	/* XXX applies hash limit even for sorted aggregation -- confirm */
+
 	/*
 	 * Create expression contexts.  We need three or more, one for
 	 * per-input-tuple processing, one for per-output-tuple processing, one
@@ -3689,6 +3678,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 	{
 		Plan	   *outerplan = outerPlan(node);
 		uint64		totalGroups = 0;
+		size_t		mem_limit;
 
 		aggstate->hash_spill_rslot = ExecInitExtraTupleSlot(estate, scanDesc,
 															&TTSOpsMinimalTuple);
@@ -3712,9 +3702,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 			totalGroups += aggstate->perhash[k].aggnode->numGroups;
 
 		hash_agg_set_limits(aggstate->hashentrysize, totalGroups, 0,
-							&aggstate->hash_mem_limit,
+							&mem_limit,
 							&aggstate->hash_ngroups_limit,
 							&aggstate->hash_planned_partitions);
+		SetWorkMemLimit(&aggstate->ss.ps, mem_limit);
+
 		find_hash_columns(aggstate);
 
 		/* Skip massive memory allocation if we are just doing EXPLAIN */
diff --git a/src/backend/executor/nodeHash.c b/src/backend/executor/nodeHash.c
index 8d2201ab67f..fd5497b8176 100644
--- a/src/backend/executor/nodeHash.c
+++ b/src/backend/executor/nodeHash.c
@@ -533,15 +533,15 @@ ExecHashTableCreate(HashState *state)
 	 * Create temporary memory contexts in which to keep the hashtable working
 	 * storage.  See notes in executor/hashjoin.h.
 	 */
-	hashtable->hashCxt = AllocSetContextCreate(CurrentMemoryContext,
+	hashtable->hashCxt = CreateWorkMemAllocSet(&state->ps, NULL,
 											   "HashTableContext",
 											   ALLOCSET_DEFAULT_SIZES);
 
-	hashtable->batchCxt = AllocSetContextCreate(hashtable->hashCxt,
+	hashtable->batchCxt = CreateWorkMemAllocSet(&state->ps, hashtable->hashCxt,
 												"HashBatchContext",
 												ALLOCSET_DEFAULT_SIZES);
 
-	hashtable->spillCxt = AllocSetContextCreate(hashtable->hashCxt,
+	hashtable->spillCxt = CreateWorkMemAllocSet(&state->ps, hashtable->hashCxt,
 												"HashSpillContext",
 												ALLOCSET_DEFAULT_SIZES);
 
diff --git a/src/backend/executor/nodeIncrementalSort.c b/src/backend/executor/nodeIncrementalSort.c
index 975b0397e7a..943539d575e 100644
--- a/src/backend/executor/nodeIncrementalSort.c
+++ b/src/backend/executor/nodeIncrementalSort.c
@@ -301,6 +301,9 @@ switchToPresortedPrefixMode(PlanState *pstate)
 	{
 		Tuplesortstate *prefixsort_state;
 		int			nPresortedCols = plannode->nPresortedCols;
+		MemoryContext oldcontext;
+
+		oldcontext = MemoryContextSwitchTo(GetWorkMem(pstate));
 
 		/*
 		 * Optimize the sort by assuming the prefix columns are all equal and
@@ -315,6 +318,9 @@ switchToPresortedPrefixMode(PlanState *pstate)
 												work_mem,
 												NULL,
 												node->bounded ? TUPLESORT_ALLOWBOUNDED : TUPLESORT_NONE);
+
+		MemoryContextSwitchTo(oldcontext);
+
 		node->prefixsort_state = prefixsort_state;
 	}
 	else
@@ -591,6 +597,8 @@ ExecIncrementalSort(PlanState *pstate)
 		 */
 		if (fullsort_state == NULL)
 		{
+			MemoryContext oldcontext;
+
 			/*
 			 * Initialize presorted column support structures for
 			 * isCurrentGroup(). It's correct to do this along with the
@@ -600,6 +608,8 @@ ExecIncrementalSort(PlanState *pstate)
 			 */
 			preparePresortedCols(node);
 
+			oldcontext = MemoryContextSwitchTo(GetWorkMem(pstate));
+
 			/*
 			 * Since we optimize small prefix key groups by accumulating a
 			 * minimum number of tuples before sorting, we can't assume that a
@@ -618,6 +628,9 @@ ExecIncrementalSort(PlanState *pstate)
 												  node->bounded ?
 												  TUPLESORT_ALLOWBOUNDED :
 												  TUPLESORT_NONE);
+
+			MemoryContextSwitchTo(oldcontext);
+
 			node->fullsort_state = fullsort_state;
 		}
 		else
diff --git a/src/backend/executor/nodeMaterial.c b/src/backend/executor/nodeMaterial.c
index 9798bb75365..9df834e8597 100644
--- a/src/backend/executor/nodeMaterial.c
+++ b/src/backend/executor/nodeMaterial.c
@@ -61,7 +61,12 @@ ExecMaterial(PlanState *pstate)
 	 */
 	if (tuplestorestate == NULL && node->eflags != 0)
 	{
+		MemoryContext oldcontext = MemoryContextSwitchTo(GetWorkMem(pstate));
+
 		tuplestorestate = tuplestore_begin_heap(true, false, work_mem);
+
+		MemoryContextSwitchTo(oldcontext);
+
 		tuplestore_set_eflags(tuplestorestate, node->eflags);
 		if (node->eflags & EXEC_FLAG_MARK)
 		{
diff --git a/src/backend/executor/nodeSort.c b/src/backend/executor/nodeSort.c
index f603337ecd3..977580b3792 100644
--- a/src/backend/executor/nodeSort.c
+++ b/src/backend/executor/nodeSort.c
@@ -78,6 +78,7 @@ ExecSort(PlanState *pstate)
 		PlanState  *outerNode;
 		TupleDesc	tupDesc;
 		int			tuplesortopts = TUPLESORT_NONE;
+		MemoryContext oldcontext;
 
 		SO1_printf("ExecSort: %s\n",
 				   "sorting subplan");
@@ -102,6 +103,8 @@ ExecSort(PlanState *pstate)
 		if (node->bounded)
 			tuplesortopts |= TUPLESORT_ALLOWBOUNDED;
 
+		oldcontext = MemoryContextSwitchTo(GetWorkMem(pstate));
+
 		if (node->datumSort)
 			tuplesortstate = tuplesort_begin_datum(TupleDescAttr(tupDesc, 0)->atttypid,
 												   plannode->sortOperators[0],
@@ -120,6 +123,9 @@ ExecSort(PlanState *pstate)
 												  work_mem,
 												  NULL,
 												  tuplesortopts);
+
+		MemoryContextSwitchTo(oldcontext);
+
 		if (node->bounded)
 			tuplesort_set_bound(tuplesortstate, node->bound);
 		node->tuplesortstate = tuplesortstate;
@@ -179,6 +185,7 @@ ExecSort(PlanState *pstate)
 			si = &node->shared_info->sinstrument[ParallelWorkerNumber];
 			tuplesort_get_stats(tuplesortstate, si);
 		}
+
 		SO1_printf("ExecSort: %s\n", "sorting done");
 	}
 
@@ -254,6 +261,8 @@ ExecInitSort(Sort *node, EState *estate, int eflags)
 	 * ExecQual or ExecProject.
 	 */
 
+	SetWorkMemLimit(&sortstate->ss.ps, work_mem * (Size) 1024);
+
 	/*
 	 * initialize child nodes
 	 *
diff --git a/src/include/executor/executor.h b/src/include/executor/executor.h
index 10dcea037c3..29098075689 100644
--- a/src/include/executor/executor.h
+++ b/src/include/executor/executor.h
@@ -638,7 +638,18 @@ extern void end_tup_output(TupOutputState *tstate);
 extern EState *CreateExecutorState(void);
 extern void FreeExecutorState(EState *estate);
 extern ExprContext *CreateExprContext(EState *estate);
-extern ExprContext *CreateWorkExprContext(EState *estate);
+extern MemoryContext GetWorkMem(PlanState *ps);
+extern void DestroyWorkMem(PlanState *ps);
+extern size_t GetWorkMemLimit(PlanState *ps);
+extern void SetWorkMemLimit(PlanState *ps, size_t limit);
+extern bool CheckWorkMemLimit(PlanState *ps);
+extern MemoryContext CreateWorkMemAllocSet_Internal(PlanState *ps, MemoryContext parent,
+													const char *name, size_t minContextSize,
+													size_t initBlockSize, size_t maxBlockSize);
+extern MemoryContext CreateWorkMemBump(PlanState *ps, MemoryContext parent,
+									   const char *name, size_t minContextSize,
+									   size_t initBlockSize, size_t maxBlockSize);
+extern ExprContext *CreateWorkExprContext(PlanState *ps);
 extern ExprContext *CreateStandaloneExprContext(void);
 extern void FreeExprContext(ExprContext *econtext, bool isCommit);
 extern void ReScanExprContext(ExprContext *econtext);
@@ -664,6 +675,21 @@ extern ExprContext *MakePerTupleExprContext(EState *estate);
 			ResetExprContext((estate)->es_per_tuple_exprcontext); \
 	} while (0)
 
+/*
+ * This wrapper macro exists to check for non-constant strings used as context
+ * names; that's no longer supported.  (Use MemoryContextSetIdentifier if you
+ * want to provide a variable identifier.)
+ */
+#ifdef HAVE__BUILTIN_CONSTANT_P
+#define CreateWorkMemAllocSet(ps, parent, name, ...) \
+	(StaticAssertExpr(__builtin_constant_p(name), \
+					  "memory context names must be constant strings"), \
+	 CreateWorkMemAllocSet_Internal(ps, parent, name, __VA_ARGS__))
+#else
+#define CreateWorkMemAllocSet \
+	CreateWorkMemAllocSet_Internal
+#endif
+
 extern void ExecAssignExprContext(EState *estate, PlanState *planstate);
 extern TupleDesc ExecGetResultType(PlanState *planstate);
 extern const TupleTableSlotOps *ExecGetResultSlotOps(PlanState *planstate,
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index de782014b2d..6cec630a395 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -1199,6 +1199,8 @@ typedef struct PlanState
 	TupleDesc	ps_ResultTupleDesc; /* node's return type */
 	TupleTableSlot *ps_ResultTupleSlot; /* slot for my result tuples */
 	ExprContext *ps_ExprContext;	/* node's expression-evaluation context */
+	MemoryContext	ps_WorkMem;	/* parent context for all workmem allocations */
+	size_t			ps_WorkMemLimit; /* enforceable work_mem limit in bytes */
 	ProjectionInfo *ps_ProjInfo;	/* info for doing tuple projection */
 
 	bool		async_capable;	/* true if node is async-capable */
@@ -2580,7 +2582,6 @@ typedef struct AggState
 	bool		hash_ever_spilled;	/* ever spilled during this execution? */
 	bool		hash_spill_mode;	/* we hit a limit during the current batch
 									 * and we must not create new groups */
-	Size		hash_mem_limit; /* limit before spilling hash table */
 	uint64		hash_ngroups_limit; /* limit before spilling hash table */
 	int			hash_planned_partitions;	/* number of partitions planned
 											 * for first pass */
-- 
2.43.0

