From 31f06781bf6a53de34c20fbaaed0d81367e0e7f4 Mon Sep 17 00:00:00 2001
From: Andres Freund <andres@anarazel.de>
Date: Thu, 3 Aug 2017 15:23:40 -0700
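Subject: [PATCH 09/16] Simplify aggregate code a bit.

Store the per-group transition states for ordered (non-hashed) grouping
as an array of per-grouping-set pointers (AggState->pergroups) instead
of one flat array indexed by transno + setno * numTrans.  This is the
same layout already used for hashed grouping sets, so
advance_aggregates() can index both the same way.

initialize_aggregates() now takes that pointer array plus an isHash
flag and calls select_current_set() itself for every set it resets.
That replaces the numReset == -1 special case previously used when
initializing a new hash table entry.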

---
 src/backend/executor/nodeAgg.c | 94 ++++++++++++++++++++----------------------
 src/include/nodes/execnodes.h  |  6 ++-
 2 files changed, 48 insertions(+), 52 deletions(-)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 1783f38f14..7e521459d6 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -522,13 +522,13 @@ static void select_current_set(AggState *aggstate, int setno, bool is_hash);
 static void initialize_phase(AggState *aggstate, int newphase);
 static TupleTableSlot *fetch_input_tuple(AggState *aggstate);
 static void initialize_aggregates(AggState *aggstate,
-					  AggStatePerGroup pergroup,
-					  int numReset);
+					  AggStatePerGroup *pergroups,
+					  bool isHash, int numReset);
 static void advance_transition_function(AggState *aggstate,
 							AggStatePerTrans pertrans,
 							AggStatePerGroup pergroupstate);
-static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup,
-				   AggStatePerGroup *pergroups);
+static void advance_aggregates(AggState *aggstate, AggStatePerGroup *sort_pergroups,
+				   AggStatePerGroup *hash_pergroups);
 static void advance_combine_function(AggState *aggstate,
 						 AggStatePerTrans pertrans,
 						 AggStatePerGroup pergroupstate);
@@ -782,15 +782,14 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
  * If there are multiple grouping sets, we initialize only the first numReset
  * of them (the grouping sets are ordered so that the most specific one, which
  * is reset most often, is first). As a convenience, if numReset is 0, we
- * reinitialize all sets. numReset is -1 to initialize a hashtable entry, in
- * which case the caller must have used select_current_set appropriately.
+ * reinitialize all sets. isHash is passed on to select_current_set().
  *
  * When called, CurrentMemoryContext should be the per-query context.
  */
 static void
 initialize_aggregates(AggState *aggstate,
-					  AggStatePerGroup pergroup,
-					  int numReset)
+					  AggStatePerGroup *pergroups,
+					  bool isHash, int numReset)
 {
 	int			transno;
 	int			numGroupingSets = Max(aggstate->phase->numsets, 1);
@@ -801,31 +800,19 @@ initialize_aggregates(AggState *aggstate,
 	if (numReset == 0)
 		numReset = numGroupingSets;
 
-	for (transno = 0; transno < numTrans; transno++)
+	for (setno = 0; setno < numReset; setno++)
 	{
-		AggStatePerTrans pertrans = &transstates[transno];
+		AggStatePerGroup pergroup = pergroups[setno];
 
-		if (numReset < 0)
+		select_current_set(aggstate, setno, isHash);
+
+		for (transno = 0; transno < numTrans; transno++)
 		{
-			AggStatePerGroup pergroupstate;
-
-			pergroupstate = &pergroup[transno];
+			AggStatePerTrans pertrans = &transstates[transno];
+			AggStatePerGroup pergroupstate = &pergroup[transno];
 
 			initialize_aggregate(aggstate, pertrans, pergroupstate);
 		}
-		else
-		{
-			for (setno = 0; setno < numReset; setno++)
-			{
-				AggStatePerGroup pergroupstate;
-
-				pergroupstate = &pergroup[transno + (setno * numTrans)];
-
-				select_current_set(aggstate, setno, false);
-
-				initialize_aggregate(aggstate, pertrans, pergroupstate);
-			}
-		}
 	}
 }
 
@@ -965,7 +952,7 @@ advance_transition_function(AggState *aggstate,
  * When called, CurrentMemoryContext should be the per-query context.
  */
 static void
-advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGroup *pergroups)
+advance_aggregates(AggState *aggstate, AggStatePerGroup *sort_pergroups, AggStatePerGroup *hash_pergroups)
 {
 	int			transno;
 	int			setno = 0;
@@ -1002,7 +989,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
 		{
 			/* DISTINCT and/or ORDER BY case */
 			Assert(slot->tts_nvalid >= (pertrans->numInputs + inputoff));
-			Assert(!pergroups);
+			Assert(!hash_pergroups);
 
 			/*
 			 * If the transfn is strict, we want to check for nullity before
@@ -1063,9 +1050,9 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
 				fcinfo->argnull[i + 1] = slot->tts_isnull[i + inputoff];
 			}
 
-			if (pergroup)
+			if (sort_pergroups)
 			{
 				/* advance transition states for ordered grouping */
 
 				for (setno = 0; setno < numGroupingSets; setno++)
 				{
@@ -1073,13 +1060,13 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
 
 					select_current_set(aggstate, setno, false);
 
-					pergroupstate = &pergroup[transno + (setno * numTrans)];
+					pergroupstate = &sort_pergroups[setno][transno];
 
 					advance_transition_function(aggstate, pertrans, pergroupstate);
 				}
 			}
 
-			if (pergroups)
+			if (hash_pergroups)
 			{
 				/* advance transition states for hashed grouping */
 
@@ -1089,7 +1076,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
 
 					select_current_set(aggstate, setno, true);
 
-					pergroupstate = &pergroups[setno][transno];
+					pergroupstate = &hash_pergroups[setno][transno];
 
 					advance_transition_function(aggstate, pertrans, pergroupstate);
 				}
@@ -2061,8 +2048,8 @@ lookup_hash_entry(AggState *aggstate)
 			MemoryContextAlloc(perhash->hashtable->tablecxt,
 							   sizeof(AggStatePerGroupData) * aggstate->numtrans);
 		/* initialize aggregates for new tuple group */
-		initialize_aggregates(aggstate, (AggStatePerGroup) entry->additional,
-							  -1);
+		initialize_aggregates(aggstate, (AggStatePerGroup *) &entry->additional,
+							  true, 1);
 	}
 
 	return entry;
@@ -2146,7 +2133,7 @@ agg_retrieve_direct(AggState *aggstate)
 	ExprContext *econtext;
 	ExprContext *tmpcontext;
 	AggStatePerAgg peragg;
-	AggStatePerGroup pergroup;
+	AggStatePerGroup *pergroups;
 	AggStatePerGroup *hash_pergroups = NULL;
 	TupleTableSlot *outerslot;
 	TupleTableSlot *firstSlot;
@@ -2169,7 +2156,7 @@ agg_retrieve_direct(AggState *aggstate)
 	tmpcontext = aggstate->tmpcontext;
 
 	peragg = aggstate->peragg;
-	pergroup = aggstate->pergroup;
+	pergroups = aggstate->pergroups;
 	firstSlot = aggstate->ss.ss_ScanTupleSlot;
 
 	/*
@@ -2371,7 +2358,7 @@ agg_retrieve_direct(AggState *aggstate)
 			/*
 			 * Initialize working state for a new input tuple group.
 			 */
-			initialize_aggregates(aggstate, pergroup, numReset);
+			initialize_aggregates(aggstate, pergroups, false, numReset);
 
 			if (aggstate->grp_firstTuple != NULL)
 			{
@@ -2408,9 +2395,9 @@ agg_retrieve_direct(AggState *aggstate)
 						hash_pergroups = NULL;
 
 					if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
-						combine_aggregates(aggstate, pergroup);
+						combine_aggregates(aggstate, pergroups[0]);
 					else
-						advance_aggregates(aggstate, pergroup, hash_pergroups);
+						advance_aggregates(aggstate, pergroups, hash_pergroups);
 
 					/* Reset per-input-tuple context after each tuple */
 					ResetExprContext(tmpcontext);
@@ -2474,7 +2461,7 @@ agg_retrieve_direct(AggState *aggstate)
 
 		finalize_aggregates(aggstate,
 							peragg,
-							pergroup + (currentSet * aggstate->numtrans));
+							pergroups[currentSet]);
 
 		/*
 		 * If there's no row to project right now, we must continue rather
@@ -2715,7 +2702,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 	aggstate->curpertrans = NULL;
 	aggstate->input_done = false;
 	aggstate->agg_done = false;
-	aggstate->pergroup = NULL;
+	aggstate->pergroups = NULL;
 	aggstate->grp_firstTuple = NULL;
 	aggstate->sort_in = NULL;
 	aggstate->sort_out = NULL;
@@ -3019,13 +3006,17 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 
 	if (node->aggstrategy != AGG_HASHED)
 	{
-		AggStatePerGroup pergroup;
+		AggStatePerGroup *pergroups =
+			(AggStatePerGroup *) palloc0(sizeof(AggStatePerGroup)
+										 * numGroupingSets);
 
-		pergroup = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData)
-											  * numaggs
-											  * numGroupingSets);
+		for (i = 0; i < numGroupingSets; i++)
+		{
+			pergroups[i] = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData)
+													 * numaggs);
+		}
 
-		aggstate->pergroup = pergroup;
+		aggstate->pergroups = pergroups;
 	}
 
 	/*
@@ -3988,8 +3979,11 @@ ExecReScanAgg(AggState *node)
 		/*
 		 * Reset the per-group state (in particular, mark transvalues null)
 		 */
-		MemSet(node->pergroup, 0,
-			   sizeof(AggStatePerGroupData) * node->numaggs * numGroupingSets);
+		for (setno = 0; setno < numGroupingSets; setno++)
+		{
+			MemSet(node->pergroups[setno], 0,
+				   sizeof(AggStatePerGroupData) * node->numaggs);
+		}
 
 		/* reset to phase 1 */
 		initialize_phase(node, 1);
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index 8ae8179ee7..bc5874f1ee 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -1823,13 +1823,15 @@ typedef struct AggState
 	Tuplesortstate *sort_out;	/* input is copied here for next phase */
 	TupleTableSlot *sort_slot;	/* slot for sort results */
 	/* these fields are used in AGG_PLAIN and AGG_SORTED modes: */
-	AggStatePerGroup pergroup;	/* per-Aggref-per-group working state */
+	AggStatePerGroup *pergroups;	/* grouping set indexed array of per-group
+									 * pointers */
 	HeapTuple	grp_firstTuple; /* copy of first tuple of current group */
 	/* these fields are used in AGG_HASHED and AGG_MIXED modes: */
 	bool		table_filled;	/* hash table filled yet? */
 	int			num_hashes;
 	AggStatePerHash perhash;
-	AggStatePerGroup *hash_pergroup;	/* array of per-group pointers */
+	AggStatePerGroup *hash_pergroup;	/* grouping set indexed array of
+										 * per-group pointers */
 	/* support for evaluation of agg inputs */
 	TupleTableSlot *evalslot;	/* slot for agg inputs */
 	ProjectionInfo *evalproj;	/* projection machinery */
-- 
2.14.1.2.g4274c698f4.dirty

