diff --git a/src/backend/executor/nodeTableFuncscan.c b/src/backend/executor/nodeTableFuncscan.c
index fed6f2b3a5..421510264b 100644
--- a/src/backend/executor/nodeTableFuncscan.c
+++ b/src/backend/executor/nodeTableFuncscan.c
@@ -282,6 +282,8 @@ tfuncFetchRows(TableFuncScanState *tstate, ExprContext *econtext)
 	oldcxt = MemoryContextSwitchTo(econtext->ecxt_per_query_memory);
 	tstate->tupstore = tuplestore_begin_heap(false, false, work_mem);
 
+	MemoryContextSwitchTo(tstate->perValueCxt);
+
 	PG_TRY();
 	{
 		routine->InitOpaque(tstate,
@@ -313,15 +315,16 @@ tfuncFetchRows(TableFuncScanState *tstate, ExprContext *econtext)
 	}
 	PG_END_TRY();
 
-	/* return to original memory context, and clean up */
-	MemoryContextSwitchTo(oldcxt);
-
+	/* clean up and return to original memory context */
 	if (tstate->opaque != NULL)
 	{
 		routine->DestroyOpaque(tstate);
 		tstate->opaque = NULL;
 	}
 
+	MemoryContextSwitchTo(oldcxt);
+	MemoryContextReset(tstate->perValueCxt);
+
 	return;
 }
 
@@ -423,12 +426,10 @@ tfuncLoadRows(TableFuncScanState *tstate, ExprContext *econtext)
 	Datum	   *values = slot->tts_values;
 	bool	   *nulls = slot->tts_isnull;
 	int			natts = tupdesc->natts;
-	MemoryContext oldcxt;
 	int			ordinalitycol;
 
 	ordinalitycol =
 		((TableFuncScan *) (tstate->ss.ps.plan))->tablefunc->ordinalitycol;
-	oldcxt = MemoryContextSwitchTo(tstate->perValueCxt);
 
 	/*
 	 * Keep requesting rows from the table builder until there aren't any.
@@ -492,9 +493,5 @@ tfuncLoadRows(TableFuncScanState *tstate, ExprContext *econtext)
 		}
 
 		tuplestore_putvalues(tstate->tupstore, tupdesc, values, nulls);
-
-		MemoryContextReset(tstate->perValueCxt);
 	}
-
-	MemoryContextSwitchTo(oldcxt);
 }
