diff --git a/src/backend/utils/adt/orderedsetaggs.c b/src/backend/utils/adt/orderedsetaggs.c
index d116a63..92fa9f3 100644
--- a/src/backend/utils/adt/orderedsetaggs.c
+++ b/src/backend/utils/adt/orderedsetaggs.c
@@ -47,8 +47,6 @@ typedef struct OSAPerQueryState
 	Aggref	   *aggref;
 	/* Memory context containing this struct and other per-query data: */
 	MemoryContext qcontext;
-	/* Memory context containing per-group data: */
-	MemoryContext gcontext;
 
 	/* These fields are used only when accumulating tuples: */
 
@@ -86,6 +84,8 @@ typedef struct OSAPerGroupState
 {
 	/* Link to the per-query state for this aggregate: */
 	OSAPerQueryState *qstate;
+	/* MemoryContext for per-group data */
+	MemoryContext gcontext;
 	/* Sort object we're accumulating data in: */
 	Tuplesortstate *sortstate;
 	/* Number of normal rows inserted into sortstate: */
@@ -104,6 +104,15 @@ ordered_set_startup(FunctionCallInfo fcinfo, bool use_tuples)
 	OSAPerGroupState *osastate;
 	OSAPerQueryState *qstate;
 	MemoryContext oldcontext;
+	MemoryContext gcontext;
+
+	/*
+	 * Check we're called as aggregate (and not a window function), and
+	 * get the Agg node's group-lifespan context (which might not be the
+	 * same throughout the query, but will be the same for each transval)
+	 */
+	if (AggCheckCallContext(fcinfo, &gcontext) != AGG_CONTEXT_AGGREGATE)
+		elog(ERROR, "ordered-set aggregate called in non-aggregate context");
 
 	/*
 	 * We keep a link to the per-query state in fn_extra; if it's not there,
@@ -114,16 +123,9 @@ ordered_set_startup(FunctionCallInfo fcinfo, bool use_tuples)
 	{
 		Aggref	   *aggref;
 		MemoryContext qcontext;
-		MemoryContext gcontext;
 		List	   *sortlist;
 		int			numSortCols;
 
-		/*
-		 * Check we're called as aggregate (and not a window function), and
-		 * get the Agg node's group-lifespan context
-		 */
-		if (AggCheckCallContext(fcinfo, &gcontext) != AGG_CONTEXT_AGGREGATE)
-			elog(ERROR, "ordered-set aggregate called in non-aggregate context");
 		/* Need the Aggref as well */
 		aggref = AggGetAggref(fcinfo);
 		if (!aggref)
@@ -142,7 +144,6 @@ ordered_set_startup(FunctionCallInfo fcinfo, bool use_tuples)
 		qstate = (OSAPerQueryState *) palloc0(sizeof(OSAPerQueryState));
 		qstate->aggref = aggref;
 		qstate->qcontext = qcontext;
-		qstate->gcontext = gcontext;
 
 		/* Extract the sort information */
 		sortlist = aggref->aggorder;
@@ -259,10 +260,11 @@ ordered_set_startup(FunctionCallInfo fcinfo, bool use_tuples)
 	}
 
 	/* Now build the stuff we need in group-lifespan context */
-	oldcontext = MemoryContextSwitchTo(qstate->gcontext);
+	oldcontext = MemoryContextSwitchTo(gcontext);
 
 	osastate = (OSAPerGroupState *) palloc(sizeof(OSAPerGroupState));
 	osastate->qstate = qstate;
+	osastate->gcontext = gcontext;
 
 	/* Initialize tuplesort object */
 	if (use_tuples)
