From 384845bea72d28952d88e58e55f81aaa5addd930 Mon Sep 17 00:00:00 2001
From: Andres Freund <andres@anarazel.de>
Date: Tue, 12 Jul 2016 01:01:28 -0700
Subject: [PATCH] WIP: Only perform one projection in aggregation.

---
 src/backend/executor/nodeAgg.c | 112 ++++++++++++++++++++++++++++++++---------
 1 file changed, 88 insertions(+), 24 deletions(-)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index f655aec..4499d5f 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -210,6 +210,9 @@ typedef struct AggStatePerTransData
 	 */
 	int			numInputs;
 
+	/* offset of this transition's input columns in AggState->evalslot */
+	int			inputoff;
+
 	/*
 	 * Number of aggregated input columns to pass to the transfn.  This
 	 * includes the ORDER BY columns for ordered-set aggs, but not for plain
@@ -836,14 +839,20 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 	int			setno = 0;
 	int			numGroupingSets = Max(aggstate->phase->numsets, 1);
 	int			numTrans = aggstate->numtrans;
+	TupleTableSlot *slot = aggstate->evalslot;
+	AggStatePerTrans pertrans;
 
-	for (transno = 0; transno < numTrans; transno++)
+	/* compute input for all aggregates */
+	if (aggstate->evalproj)
+		ExecProjectIntoSlot(aggstate->evalproj, aggstate->evalslot);
+
+	for (transno = 0, pertrans = aggstate->pertrans; transno < numTrans;
+		 transno++, pertrans++)
 	{
-		AggStatePerTrans pertrans = &aggstate->pertrans[transno];
 		ExprState  *filter = pertrans->aggfilter;
 		int			numTransInputs = pertrans->numTransInputs;
 		int			i;
-		TupleTableSlot *slot;
+		int			inputoff = pertrans->inputoff;
 
 		/* Skip anything FILTERed out */
 		if (filter)
@@ -857,13 +866,10 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 				continue;
 		}
 
-		/* Evaluate the current input expressions for this aggregate */
-		slot = ExecProject(pertrans->evalproj, NULL);
-
 		if (pertrans->numSortCols > 0)
 		{
 			/* DISTINCT and/or ORDER BY case */
-			Assert(slot->tts_nvalid == pertrans->numInputs);
+			Assert(slot->tts_nvalid >= pertrans->numInputs);
 
 			/*
 			 * If the transfn is strict, we want to check for nullity before
@@ -876,7 +882,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 			{
 				for (i = 0; i < numTransInputs; i++)
 				{
-					if (slot->tts_isnull[i])
+					if (slot->tts_isnull[i + inputoff])
 						break;
 				}
 				if (i < numTransInputs)
@@ -888,10 +894,22 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 				/* OK, put the tuple into the tuplesort object */
 				if (pertrans->numInputs == 1)
 					tuplesort_putdatum(pertrans->sortstates[setno],
-									   slot->tts_values[0],
-									   slot->tts_isnull[0]);
+									   slot->tts_values[inputoff],
+									   slot->tts_isnull[inputoff]);
 				else
-					tuplesort_puttupleslot(pertrans->sortstates[setno], slot);
+				{
+					/* copy slot contents starting from inputoff, into sort slot */
+					ExecClearTuple(pertrans->evalslot);
+					memcpy(pertrans->evalslot->tts_values,
+						   &slot->tts_values[inputoff],
+						   pertrans->numInputs * sizeof(Datum));
+					memcpy(pertrans->evalslot->tts_isnull,
+						   &slot->tts_isnull[inputoff],
+						   pertrans->numInputs * sizeof(bool));
+					pertrans->evalslot->tts_nvalid = pertrans->numInputs;
+					ExecStoreVirtualTuple(pertrans->evalslot);
+					tuplesort_puttupleslot(pertrans->sortstates[setno], pertrans->evalslot);
+				}
 			}
 		}
 		else
@@ -904,8 +922,8 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 			Assert(slot->tts_nvalid >= numTransInputs);
 			for (i = 0; i < numTransInputs; i++)
 			{
-				fcinfo->arg[i + 1] = slot->tts_values[i];
-				fcinfo->argnull[i + 1] = slot->tts_isnull[i];
+				fcinfo->arg[i + 1] = slot->tts_values[i + inputoff];
+				fcinfo->argnull[i + 1] = slot->tts_isnull[i + inputoff];
 			}
 
 			for (setno = 0; setno < numGroupingSets; setno++)
@@ -932,20 +950,23 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 {
 	int			transno;
 	int			numTrans = aggstate->numtrans;
+	TupleTableSlot *slot = NULL;
 
 	/* combine not supported with grouping sets */
 	Assert(aggstate->phase->numsets == 0);
 
+	/* compute input for all aggregates */
+	if (aggstate->evalproj)
+		slot = ExecProject(aggstate->evalproj, NULL);
+
 	for (transno = 0; transno < numTrans; transno++)
 	{
 		AggStatePerTrans pertrans = &aggstate->pertrans[transno];
 		AggStatePerGroup pergroupstate = &pergroup[transno];
-		TupleTableSlot *slot;
 		FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
+		int			inputoff = pertrans->inputoff;
 
-		/* Evaluate the current input expressions for this aggregate */
-		slot = ExecProject(pertrans->evalproj, NULL);
-		Assert(slot->tts_nvalid >= 1);
+		Assert(slot->tts_nvalid + inputoff >= 1);
 
 		/*
 		 * deserialfn_oid will be set if we must deserialize the input state
@@ -954,18 +975,18 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 		if (OidIsValid(pertrans->deserialfn_oid))
 		{
 			/* Don't call a strict deserialization function with NULL input */
-			if (pertrans->deserialfn.fn_strict && slot->tts_isnull[0])
+			if (pertrans->deserialfn.fn_strict && slot->tts_isnull[0 + inputoff])
 			{
-				fcinfo->arg[1] = slot->tts_values[0];
-				fcinfo->argnull[1] = slot->tts_isnull[0];
+				fcinfo->arg[1] = slot->tts_values[0 + inputoff];
+				fcinfo->argnull[1] = slot->tts_isnull[0 + inputoff];
 			}
 			else
 			{
 				FunctionCallInfo dsinfo = &pertrans->deserialfn_fcinfo;
 				MemoryContext oldContext;
 
-				dsinfo->arg[0] = slot->tts_values[0];
-				dsinfo->argnull[0] = slot->tts_isnull[0];
+				dsinfo->arg[0] = slot->tts_values[0 + inputoff];
+				dsinfo->argnull[0] = slot->tts_isnull[0 + inputoff];
 				/* Dummy second argument for type-safety reasons */
 				dsinfo->arg[1] = PointerGetDatum(NULL);
 				dsinfo->argnull[1] = false;
@@ -984,8 +1005,8 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 		}
 		else
 		{
-			fcinfo->arg[1] = slot->tts_values[0];
-			fcinfo->argnull[1] = slot->tts_isnull[0];
+			fcinfo->arg[1] = slot->tts_values[0 + inputoff];
+			fcinfo->argnull[1] = slot->tts_isnull[0 + inputoff];
 		}
 
 		advance_combine_function(aggstate, pertrans, pergroupstate);
@@ -2890,6 +2911,49 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 	aggstate->numaggs = aggno + 1;
 	aggstate->numtrans = transno + 1;
 
+	/*
+	 * Build one projection evaluating the arguments of all transition states.
+	 */
+	{
+		List *inputeval = NIL;
+		int offset = 0;
+
+		for (transno = 0; transno < aggstate->numtrans; transno++)
+		{
+			AggStatePerTrans pertrans = &pertransstates[transno];
+			ListCell *arg;
+
+			pertrans->inputoff = offset;
+
+			/*
+			 * Adjust resno in a copied target entry, to point in the combined
+			 * slot.
+			 */
+			foreach(arg, pertrans->aggref->args)
+			{
+				TargetEntry *tle;
+
+				Assert(IsA(lfirst(arg), TargetEntry));
+				tle = copyObject(lfirst(arg));
+				tle->resno += offset;
+
+				inputeval = lappend(inputeval, tle);
+			}
+
+			offset += list_length(pertrans->aggref->args);
+		}
+
+		aggstate->evaldesc = ExecTypeFromTL(inputeval, false);
+
+		aggstate->evalslot = ExecInitExtraTupleSlot(estate);
+
+		aggstate->evalproj = ExecBuildProjectionInfo(inputeval,
+													 aggstate->tmpcontext,
+													 aggstate->evalslot,
+													 (PlanState *) aggstate,
+													 NULL);
+		ExecSetSlotDescriptor(aggstate->evalslot, aggstate->evaldesc);
+	}
 	return aggstate;
 }
 
-- 
2.9.3

