From 8743b443d91b51371fe5df17795a88210aa0264a Mon Sep 17 00:00:00 2001
From: Peter Geoghegan <pg@bowt.ie>
Date: Thu, 14 Dec 2017 21:22:41 -0800
Subject: [PATCH] Fix double free of tuple with grouping sets/tuplesort.c.

Have tuplesort_gettupleslot() copy the contents of its current table
slot into caller's memory context if and when caller is obligated to
free memory itself.  Otherwise, continue to allocate the memory in
tuplesort-owned context, and never tell the caller to free it
themselves.  Repeat this for all "fetch tuple" routines, to be
consistent.

This fixes a crash that can occur on at least 9.5 and 9.6, where some
grouping sets queries may call tuplesort_end() at node shutdown time,
and free tuple memory that the executor imagines it is responsible for,
and must free within ExecClearTuple().

In passing, change the memory context switch point within
tuplesort_getdatum() to match tuplesort_gettupleslot().  It's not clear
that this can cause a crash, but it's still wrong for pass-by-value
Datum tuplesorts.

Only the tuplesort_getdatum() fix is required for v10, because v10
changed the contract that callers have with tuple-fetch tuplesort
routines to something more robust.  Fetch tuple routines create full
copies of tuples iff callers explicitly opt in on v10, accepting the
risk that it won't be safe to reference the tuple in the very next call
to the fetch routine.

Author: Peter Geoghegan
Reported-By: Bernd Helmle
Analyzed-By: Peter Geoghegan, Andreas Seltenreich, Bernd Helmle
Discussion: https://www.postgresql.org/message-id/1512661638.9720.34.camel@oopsware.de
---
 src/backend/utils/sort/tuplesort.c | 191 ++++++++++++++++++++++++-------------
 1 file changed, 124 insertions(+), 67 deletions(-)

diff --git a/src/backend/utils/sort/tuplesort.c b/src/backend/utils/sort/tuplesort.c
index 4dd5407..f7f8b83 100644
--- a/src/backend/utils/sort/tuplesort.c
+++ b/src/backend/utils/sort/tuplesort.c
@@ -264,6 +264,7 @@ struct Tuplesortstate
 	int			tapeRange;		/* maxTapes-1 (Knuth's P) */
 	MemoryContext sortcontext;	/* memory context holding most sort data */
 	MemoryContext tuplecontext; /* sub-context of sortcontext for tuple data */
+	MemoryContext callercontext; /* tuplesort caller's own context */
 	LogicalTapeSet *tapeset;	/* logtape.c object for tapes in a temp file */
 
 	/*
@@ -390,10 +391,11 @@ struct Tuplesortstate
 	 * sorted order).
 	 */
 	int64		spacePerTape;	/* Space (memory) for tuples (not slots) */
+	int64	   *mergeosize;		/* "overflow" allocation size */
 	char	  **mergetuples;	/* Each tape's memory allocation */
 	char	  **mergecurrent;	/* Current offset into each tape's memory */
 	char	  **mergetail;		/* Last item's start point for each tape */
-	char	  **mergeoverflow;	/* Retail palloc() "overflow" for each tape */
+	char	  **mergeoverflow;	/* "overflow" scratch buffer for each tape */
 
 	/*
 	 * Variables for Algorithm D.  Note that destTape is a "logical" tape
@@ -534,7 +536,11 @@ struct Tuplesortstate
  * Note that this places a responsibility on readtup and copytup routines
  * to use the right memory context for these tuples (and to not use the
  * reset context for anything whose lifetime needs to span multiple
- * external sort runs).
+ * external sort runs).  Note also that readtup routines need to make sure
+ * that memory that caller must free is in fact allocated in caller
+ * context, not in either of the two tuplesort memory contexts.  Callers
+ * must either have full ownership of memory for returned tuples, or
+ * leave it all entirely up to us.
  */
 
 /* When using this macro, beware of double evaluation of len */
@@ -691,6 +697,7 @@ tuplesort_begin_common(int workMem, bool randomAccess)
 	state->availMem = state->allowedMem;
 	state->sortcontext = sortcontext;
 	state->tuplecontext = tuplecontext;
+	state->callercontext = NULL;
 	state->tapeset = NULL;
 
 	state->memtupcount = 0;
@@ -721,6 +728,12 @@ tuplesort_begin_common(int workMem, bool randomAccess)
 
 	state->result_tape = -1;	/* flag that result tape has not been formed */
 
+	/*
+	 * Let caller initialize state->callercontext.  The usage pattern for
+	 * state->callercontext is that it must be set and used in externally
+	 * callable functions.  Functions with local linkage do a conventional
+	 * switch-and-unswitch, to avoid clobbering state->callercontext.
+	 */
 	MemoryContextSwitchTo(oldcontext);
 
 	return state;
@@ -733,11 +746,13 @@ tuplesort_begin_heap(TupleDesc tupDesc,
 					 bool *nullsFirstFlags,
 					 int workMem, bool randomAccess)
 {
-	Tuplesortstate *state = tuplesort_begin_common(workMem, randomAccess);
-	MemoryContext oldcontext;
+	Tuplesortstate *state;
+	MemoryContext oldcontext = CurrentMemoryContext;
 	int			i;
 
-	oldcontext = MemoryContextSwitchTo(state->sortcontext);
+	state = tuplesort_begin_common(workMem, randomAccess);
+	MemoryContextSwitchTo(state->sortcontext);
+	state->callercontext = oldcontext;
 
 	AssertArg(nkeys > 0);
 
@@ -804,14 +819,16 @@ tuplesort_begin_cluster(TupleDesc tupDesc,
 						Relation indexRel,
 						int workMem, bool randomAccess)
 {
-	Tuplesortstate *state = tuplesort_begin_common(workMem, randomAccess);
+	Tuplesortstate *state;
 	ScanKey		indexScanKey;
-	MemoryContext oldcontext;
+	MemoryContext oldcontext = CurrentMemoryContext;
 	int			i;
 
 	Assert(indexRel->rd_rel->relam == BTREE_AM_OID);
 
-	oldcontext = MemoryContextSwitchTo(state->sortcontext);
+	state = tuplesort_begin_common(workMem, randomAccess);
+	MemoryContextSwitchTo(state->sortcontext);
+	state->callercontext = oldcontext;
 
 #ifdef TRACE_SORT
 	if (trace_sort)
@@ -898,12 +915,14 @@ tuplesort_begin_index_btree(Relation heapRel,
 							bool enforceUnique,
 							int workMem, bool randomAccess)
 {
-	Tuplesortstate *state = tuplesort_begin_common(workMem, randomAccess);
+	Tuplesortstate *state;
 	ScanKey		indexScanKey;
-	MemoryContext oldcontext;
+	MemoryContext oldcontext = CurrentMemoryContext;
 	int			i;
 
-	oldcontext = MemoryContextSwitchTo(state->sortcontext);
+	state = tuplesort_begin_common(workMem, randomAccess);
+	MemoryContextSwitchTo(state->sortcontext);
+	state->callercontext = oldcontext;
 
 #ifdef TRACE_SORT
 	if (trace_sort)
@@ -974,10 +993,12 @@ tuplesort_begin_index_hash(Relation heapRel,
 						   uint32 hash_mask,
 						   int workMem, bool randomAccess)
 {
-	Tuplesortstate *state = tuplesort_begin_common(workMem, randomAccess);
-	MemoryContext oldcontext;
+	Tuplesortstate *state;
+	MemoryContext oldcontext = CurrentMemoryContext;
 
-	oldcontext = MemoryContextSwitchTo(state->sortcontext);
+	state = tuplesort_begin_common(workMem, randomAccess);
+	MemoryContextSwitchTo(state->sortcontext);
+	state->callercontext = oldcontext;
 
 #ifdef TRACE_SORT
 	if (trace_sort)
@@ -1010,12 +1031,14 @@ tuplesort_begin_datum(Oid datumType, Oid sortOperator, Oid sortCollation,
 					  bool nullsFirstFlag,
 					  int workMem, bool randomAccess)
 {
-	Tuplesortstate *state = tuplesort_begin_common(workMem, randomAccess);
-	MemoryContext oldcontext;
+	Tuplesortstate *state;
+	MemoryContext oldcontext = CurrentMemoryContext;
 	int16		typlen;
 	bool		typbyval;
 
-	oldcontext = MemoryContextSwitchTo(state->sortcontext);
+	state = tuplesort_begin_common(workMem, randomAccess);
+	MemoryContextSwitchTo(state->sortcontext);
+	state->callercontext = oldcontext;
 
 #ifdef TRACE_SORT
 	if (trace_sort)
@@ -1193,7 +1216,8 @@ tuplesort_end(Tuplesortstate *state)
 
 	/*
 	 * Free the per-sort memory context, thereby releasing all working memory,
-	 * including the Tuplesortstate struct itself.
+	 * including the Tuplesortstate struct itself.  Tuples for which
+	 * should_free was set to TRUE remain allocated, however.
 	 */
 	MemoryContextDelete(state->sortcontext);
 }
@@ -1335,9 +1359,10 @@ noalloc:
 void
 tuplesort_puttupleslot(Tuplesortstate *state, TupleTableSlot *slot)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
 	SortTuple	stup;
 
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
+
 	/*
 	 * Copy the given tuple into memory we control, and decrease availMem.
 	 * Then call the common code.
@@ -1346,7 +1371,7 @@ tuplesort_puttupleslot(Tuplesortstate *state, TupleTableSlot *slot)
 
 	puttuple_common(state, &stup);
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 }
 
 /*
@@ -1357,18 +1382,18 @@ tuplesort_puttupleslot(Tuplesortstate *state, TupleTableSlot *slot)
 void
 tuplesort_putheaptuple(Tuplesortstate *state, HeapTuple tup)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
 	SortTuple	stup;
 
 	/*
 	 * Copy the given tuple into memory we control, and decrease availMem.
 	 * Then call the common code.
 	 */
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
 	COPYTUP(state, &stup, (void *) tup);
 
 	puttuple_common(state, &stup);
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 }
 
 /*
@@ -1380,11 +1405,11 @@ tuplesort_putindextuplevalues(Tuplesortstate *state, Relation rel,
 							  ItemPointer self, Datum *values,
 							  bool *isnull)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->tuplecontext);
 	SortTuple	stup;
 	Datum		original;
 	IndexTuple	tuple;
 
+	state->callercontext = MemoryContextSwitchTo(state->tuplecontext);
 	stup.tuple = index_form_tuple(RelationGetDescr(rel), values, isnull);
 	tuple = ((IndexTuple) stup.tuple);
 	tuple->t_tid = *self;
@@ -1445,7 +1470,7 @@ tuplesort_putindextuplevalues(Tuplesortstate *state, Relation rel,
 
 	puttuple_common(state, &stup);
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 }
 
 /*
@@ -1456,7 +1481,6 @@ tuplesort_putindextuplevalues(Tuplesortstate *state, Relation rel,
 void
 tuplesort_putdatum(Tuplesortstate *state, Datum val, bool isNull)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->tuplecontext);
 	SortTuple	stup;
 
 	/*
@@ -1470,6 +1494,7 @@ tuplesort_putdatum(Tuplesortstate *state, Datum val, bool isNull)
 	 * abbreviated value if abbreviation is happening, otherwise it's
 	 * identical to stup.tuple.
 	 */
+	state->callercontext = MemoryContextSwitchTo(state->tuplecontext);
 
 	if (isNull || !state->tuples)
 	{
@@ -1529,7 +1554,7 @@ tuplesort_putdatum(Tuplesortstate *state, Datum val, bool isNull)
 
 	puttuple_common(state, &stup);
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 }
 
 /*
@@ -1743,7 +1768,7 @@ consider_abort_common(Tuplesortstate *state)
 void
 tuplesort_performsort(Tuplesortstate *state)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
 
 #ifdef TRACE_SORT
 	if (trace_sort)
@@ -1816,7 +1841,7 @@ tuplesort_performsort(Tuplesortstate *state)
 	}
 #endif
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 }
 
 /*
@@ -1824,6 +1849,9 @@ tuplesort_performsort(Tuplesortstate *state)
  * direction into *stup.  Returns FALSE if no more tuples.
  * If *should_free is set, the caller must pfree stup.tuple when done with it.
  * Otherwise, caller should not use tuple following next call here.
+ *
+ * Note:  Caller can expect tuple to be allocated in their own memory context
+ * when should_free is TRUE.
  */
 static bool
 tuplesort_gettuple_common(Tuplesortstate *state, bool forward,
@@ -2046,22 +2074,24 @@ tuplesort_gettuple_common(Tuplesortstate *state, bool forward,
  * NULL value in leading attribute will set abbreviated value to zeroed
  * representation, which caller may rely on in abbreviated inequality check.
  *
- * The slot receives a copied tuple (sometimes allocated in caller memory
- * context) that will stay valid regardless of future manipulations of the
- * tuplesort's state.
+ * The slot receives a copied tuple that will stay valid regardless of future
+ * manipulations of the tuplesort's state.  This includes calls to
+ * tuplesort_end() -- tuple will be allocated in caller's own context to make
+ * this work.
  */
 bool
 tuplesort_gettupleslot(Tuplesortstate *state, bool forward,
 					   TupleTableSlot *slot, Datum *abbrev)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
 	SortTuple	stup;
 	bool		should_free;
 
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
+
 	if (!tuplesort_gettuple_common(state, forward, &stup, &should_free))
 		stup.tuple = NULL;
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 
 	if (stup.tuple)
 	{
@@ -2089,18 +2119,20 @@ tuplesort_gettupleslot(Tuplesortstate *state, bool forward,
  * Returns NULL if no more tuples.  If *should_free is set, the
  * caller must pfree the returned tuple when done with it.
  * If it is not set, caller should not use tuple following next
- * call here.
+ * call here, and can expect tuple to be in their own memory
+ * context.
  */
 HeapTuple
 tuplesort_getheaptuple(Tuplesortstate *state, bool forward, bool *should_free)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
 	SortTuple	stup;
 
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
+
 	if (!tuplesort_gettuple_common(state, forward, &stup, should_free))
 		stup.tuple = NULL;
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 
 	return stup.tuple;
 }
@@ -2116,13 +2148,14 @@ IndexTuple
 tuplesort_getindextuple(Tuplesortstate *state, bool forward,
 						bool *should_free)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
 	SortTuple	stup;
 
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
+
 	if (!tuplesort_gettuple_common(state, forward, &stup, should_free))
 		stup.tuple = NULL;
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 
 	return (IndexTuple) stup.tuple;
 }
@@ -2132,7 +2165,7 @@ tuplesort_getindextuple(Tuplesortstate *state, bool forward,
  * Returns FALSE if no more datums.
  *
  * If the Datum is pass-by-ref type, the returned value is freshly palloc'd
- * and is now owned by the caller.
+ * in caller's context, and is now owned by the caller.
  *
  * Caller may optionally be passed back abbreviated value (on TRUE return
  * value) when abbreviation was used, which can be used to cheaply avoid
@@ -2145,16 +2178,19 @@ bool
 tuplesort_getdatum(Tuplesortstate *state, bool forward,
 				   Datum *val, bool *isNull, Datum *abbrev)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
 	SortTuple	stup;
 	bool		should_free;
 
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
+
 	if (!tuplesort_gettuple_common(state, forward, &stup, &should_free))
 	{
-		MemoryContextSwitchTo(oldcontext);
+		MemoryContextSwitchTo(state->callercontext);
 		return false;
 	}
 
+	MemoryContextSwitchTo(state->callercontext);
+
 	/* Record abbreviated key for caller */
 	if (state->sortKeys->abbrev_converter && abbrev)
 		*abbrev = stup.datum1;
@@ -2175,8 +2211,6 @@ tuplesort_getdatum(Tuplesortstate *state, bool forward,
 		*isNull = false;
 	}
 
-	MemoryContextSwitchTo(oldcontext);
-
 	return true;
 }
 
@@ -2188,8 +2222,6 @@ tuplesort_getdatum(Tuplesortstate *state, bool forward,
 bool
 tuplesort_skiptuples(Tuplesortstate *state, int64 ntuples, bool forward)
 {
-	MemoryContext oldcontext;
-
 	/*
 	 * We don't actually support backwards skip yet, because no callers need
 	 * it.  The API is designed to allow for that later, though.
@@ -2225,7 +2257,7 @@ tuplesort_skiptuples(Tuplesortstate *state, int64 ntuples, bool forward)
 			 * We could probably optimize these cases better, but for now it's
 			 * not worth the trouble.
 			 */
-			oldcontext = MemoryContextSwitchTo(state->sortcontext);
+			state->callercontext = MemoryContextSwitchTo(state->sortcontext);
 			while (ntuples-- > 0)
 			{
 				SortTuple	stup;
@@ -2234,14 +2266,14 @@ tuplesort_skiptuples(Tuplesortstate *state, int64 ntuples, bool forward)
 				if (!tuplesort_gettuple_common(state, forward,
 											   &stup, &should_free))
 				{
-					MemoryContextSwitchTo(oldcontext);
+					MemoryContextSwitchTo(state->callercontext);
 					return false;
 				}
 				if (should_free && stup.tuple)
 					pfree(stup.tuple);
 				CHECK_FOR_INTERRUPTS();
 			}
-			MemoryContextSwitchTo(oldcontext);
+			MemoryContextSwitchTo(state->callercontext);
 			return true;
 
 		default:
@@ -2363,6 +2395,7 @@ inittapes(Tuplesortstate *state)
 	state->mergelast = (int *) palloc0(maxTapes * sizeof(int));
 	state->mergeavailslots = (int *) palloc0(maxTapes * sizeof(int));
 	state->mergeavailmem = (int64 *) palloc0(maxTapes * sizeof(int64));
+	state->mergeosize = (int64 *) palloc0(maxTapes * sizeof(int64));
 	state->mergetuples = (char **) palloc0(maxTapes * sizeof(char *));
 	state->mergecurrent = (char **) palloc0(maxTapes * sizeof(char *));
 	state->mergetail = (char **) palloc0(maxTapes * sizeof(char *));
@@ -2978,6 +3011,7 @@ mergebatch(Tuplesortstate *state, int64 spacePerTape)
 		state->mergetuples[srcTape] = mergetuples;
 		state->mergecurrent[srcTape] = mergetuples;
 		state->mergetail[srcTape] = mergetuples;
+		state->mergeosize[srcTape] = 0;
 		state->mergeoverflow[srcTape] = NULL;
 	}
 
@@ -2993,7 +3027,8 @@ mergebatch(Tuplesortstate *state, int64 spacePerTape)
  * reset to indicate that all memory may be reused.
  *
  * This routine must deal with fixing up the tuple that is about to be returned
- * to the client, due to "overflow" allocations.
+ * to the client, due to "overflow" allocations, while making sure that tuple
+ * memory ends up in caller's own context.
  */
 static void
 mergebatchone(Tuplesortstate *state, int srcTape, SortTuple *rtup,
@@ -3033,17 +3068,28 @@ mergebatchone(Tuplesortstate *state, int srcTape, SortTuple *rtup,
 		/*
 		 * Handle an "overflow" retail palloc.
 		 *
-		 * This is needed when we run out of tuple memory for the tape.
+		 * This is needed when we run out of tuple memory for the tape.  Note
+		 * that we actually free the memory used for the overflow allocation
+		 * directly, and provide a copy allocated in caller's context to
+		 * caller.
 		 */
 		state->mergecurrent[srcTape] = state->mergetuples[srcTape];
 		state->mergetail[srcTape] = state->mergetuples[srcTape];
 
 		if (rtup->tuple)
 		{
+			Assert(state->mergeosize[srcTape] > 0);
 			Assert(rtup->tuple == (void *) state->mergeoverflow[srcTape]);
+			rtup->tuple = MemoryContextAlloc(state->callercontext,
+											 state->mergeosize[srcTape]);
+			MOVETUP(rtup->tuple, state->mergeoverflow[srcTape],
+					state->mergeosize[srcTape]);
 			/* Caller should free palloc'd tuple */
 			*should_free = true;
+			/* Free scratch buffer */
+			pfree(state->mergeoverflow[srcTape]);
 		}
+		state->mergeosize[srcTape] = 0;
 		state->mergeoverflow[srcTape] = NULL;
 	}
 }
@@ -3064,7 +3110,7 @@ mergebatchfreetape(Tuplesortstate *state, int srcTape, SortTuple *rtup,
 	Assert(state->status == TSS_FINALMERGE);
 
 	/*
-	 * Tuple may or may not already be an overflow allocation from
+	 * Tuple may or may not already be caller-owned allocation from
 	 * mergebatchone()
 	 */
 	if (!*should_free && rtup->tuple)
@@ -3072,17 +3118,15 @@ mergebatchfreetape(Tuplesortstate *state, int srcTape, SortTuple *rtup,
 		/*
 		 * Final tuple still in tape's batch allocation.
 		 *
-		 * Return palloc()'d copy to caller, and have it freed in a similar
-		 * manner to overflow allocation.  Otherwise, we'd free batch memory
-		 * and pass back a pointer to garbage.  Note that we deliberately
-		 * allocate this in the parent tuplesort context, to be on the safe
-		 * side.
+		 * Return palloc()'d copy to caller, in caller's own memory context.
+		 * Otherwise, we'd free batch memory and pass back a pointer to
+		 * garbage.
 		 */
 		Size		tuplen;
 		void	   *oldTuple = rtup->tuple;
 
 		tuplen = state->mergecurrent[srcTape] - state->mergetail[srcTape];
-		rtup->tuple = MemoryContextAlloc(state->sortcontext, tuplen);
+		rtup->tuple = MemoryContextAlloc(state->callercontext, tuplen);
 		MOVETUP(rtup->tuple, oldTuple, tuplen);
 		*should_free = true;
 	}
@@ -3138,11 +3182,14 @@ mergebatchalloc(Tuplesortstate *state, int tapenum, Size tuplen)
 		 * will be detected quickly, in a similar fashion to a LACKMEM()
 		 * condition, and should not happen again before a new round of
 		 * preloading for caller's tape.  Note that we deliberately allocate
-		 * this in the parent tuplesort context, to be on the safe side.
+		 * this in the parent tuplesort context, to be on the safe side.  Note
+		 * also that tuplesort caller will ultimately get their own copy of
+		 * this, in their own memory context.
 		 *
 		 * Sometimes, this does not happen because merging runs out of slots
 		 * before running out of memory.
 		 */
+		state->mergeosize[tapenum] = tuplen;
 		ret = state->mergeoverflow[tapenum] =
 			MemoryContextAlloc(state->sortcontext, tuplen);
 	}
@@ -3460,7 +3507,7 @@ dumpbatch(Tuplesortstate *state, bool alltuples)
 void
 tuplesort_rescan(Tuplesortstate *state)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
 
 	Assert(state->randomAccess);
 
@@ -3486,7 +3533,7 @@ tuplesort_rescan(Tuplesortstate *state)
 			break;
 	}
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 }
 
 /*
@@ -3495,7 +3542,7 @@ tuplesort_rescan(Tuplesortstate *state)
 void
 tuplesort_markpos(Tuplesortstate *state)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
 
 	Assert(state->randomAccess);
 
@@ -3517,7 +3564,7 @@ tuplesort_markpos(Tuplesortstate *state)
 			break;
 	}
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 }
 
 /*
@@ -3527,7 +3574,7 @@ tuplesort_markpos(Tuplesortstate *state)
 void
 tuplesort_restorepos(Tuplesortstate *state)
 {
-	MemoryContext oldcontext = MemoryContextSwitchTo(state->sortcontext);
+	state->callercontext = MemoryContextSwitchTo(state->sortcontext);
 
 	Assert(state->randomAccess);
 
@@ -3550,7 +3597,7 @@ tuplesort_restorepos(Tuplesortstate *state)
 			break;
 	}
 
-	MemoryContextSwitchTo(oldcontext);
+	MemoryContextSwitchTo(state->callercontext);
 }
 
 /*
@@ -3887,7 +3934,8 @@ markrunend(Tuplesortstate *state, int tapenum)
  * from tape's batch allocation.  Otherwise, callers must pfree() or
  * reset tuple child memory context, and account for that with a
  * FREEMEM().  Currently, this only ever needs to happen in WRITETUP()
- * routines.
+ * routines, or when memory is allocated in caller's own context, and
+ * so must be freed on caller's own terms.
  */
 static void *
 readtup_alloc(Tuplesortstate *state, int tapenum, Size tuplen)
@@ -3903,15 +3951,24 @@ readtup_alloc(Tuplesortstate *state, int tapenum, Size tuplen)
 		 */
 		return mergebatchalloc(state, tapenum, tuplen);
 	}
-	else
+	else if (state->status == TSS_BUILDRUNS)
 	{
 		char	   *ret;
 
-		/* Batch allocation yet to be performed */
+		/*
+		 * Batch allocation yet to be performed (non-final merge).  Use
+		 * tuplecontext to avoid fragmentation, much like tuplesort_putXXX and
+		 * copytup_XXX routines during initial run generation.
+		 */
 		ret = MemoryContextAlloc(state->tuplecontext, tuplen);
 		USEMEM(state, GetMemoryChunkSpace(ret));
 		return ret;
 	}
+	else
+	{
+		/* Must be allocated within tuplesort caller's context */
+		return MemoryContextAlloc(state->callercontext, tuplen);
+	}
 }
 
 
-- 
2.7.4

