diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index 2bf48c5..27020d8 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -150,11 +150,16 @@ #include "utils/tuplesort.h" #include "utils/datum.h" - /* - * AggStatePerAggData - per-aggregate working state for the Agg scan + * AggStatePerAggStateData + * Stores data about an aggregate state and how the aggregate state must + * be calculated. This struct does not store anything which has any + * concept of how to produce the final aggregate result. In order to + * calculate the final result we must make use of an AggStatePerAggData. + * The reason for this is so that we can share an aggregate state between + * different aggregate functions, in order to save duplicating work. */ -typedef struct AggStatePerAggData +typedef struct AggStatePerAggStateData { /* * These values are set up during ExecInitAgg() and do not change @@ -186,25 +191,14 @@ typedef struct AggStatePerAggData */ int numTransInputs; - /* - * Number of arguments to pass to the finalfn. This is always at least 1 - * (the transition state value) plus any ordered-set direct args. If the - * finalfn wants extra args then we pass nulls corresponding to the - * aggregated input columns. - */ - int numFinalArgs; - - /* Oids of transfer functions */ + /* Oid of transfer function */ Oid transfn_oid; - Oid finalfn_oid; /* may be InvalidOid */ /* - * fmgr lookup data for transfer functions --- only valid when - * corresponding oid is not InvalidOid. Note in particular that fn_strict - * flags are kept here. + * fmgr lookup data for transfer function. + * Note in particular that the fn_strict flag is kept here. */ FmgrInfo transfn; - FmgrInfo finalfn; /* Input collation derived for aggregate */ Oid aggCollation; @@ -288,7 +282,44 @@ typedef struct AggStatePerAggData * worth the extra space consumption. 
*/ FunctionCallInfoData transfn_fcinfo; -} AggStatePerAggData; +} AggStatePerAggStateData; + +/* + * AggStatePerAggData + * Stores required details on how to produce a final aggregate result. + * To be of any use this must make use of an AggStatePerAggStateData + * before any actual result can be produced. Logical separation of the + * state and the final function data stored here makes sense as it allows + * us to re-use an aggregate's state for more than one aggregate function + * providing they share the same transfn and initValue. + */ +typedef struct AggStatePerAggData { + /* + * These values are set up during ExecInitAgg() and do not change + * thereafter: + */ + + /* index to the corresponding state which this agg should use */ + int stateno; + + /* Optional Oid of final function (may be InvalidOid) */ + Oid finalfn_oid; + + /* + * fmgr lookup data for final function --- only valid when + * finalfn_oid oid is not InvalidOid. + */ + FmgrInfo finalfn; + + /* + * Number of arguments to pass to the finalfn. This is always at least 1 + * (the transition state value) plus any ordered-set direct args. If the + * finalfn wants extra args then we pass nulls corresponding to the + * aggregated input columns. + */ + int numFinalArgs; + +} AggStatePerAggData; /* * AggStatePerGroupData - per-aggregate-per-group working state @@ -358,25 +389,35 @@ typedef struct AggHashEntryData AggStatePerGroupData pergroup[FLEXIBLE_ARRAY_MEMBER]; } AggHashEntryData; +/* + * enum states to mark compatibility between aggregate functions. + * These are used to enable various optimizations which are applied to similar + * aggregate functions. See comments for find_compatible_aggref() for details. + */ +typedef enum AggRefCompatibility { + AGGREF_NO_MATCH = 0, /* state is not compatible between aggregates. */ + AGGREF_STATE_MATCH, /* aggregates may share state only. */ + AGGREF_EXACT_MATCH /* aggregates may share state and finalfn. 
*/ +} AggRefCompatibility; static void initialize_phase(AggState *aggstate, int newphase); static TupleTableSlot *fetch_input_tuple(AggState *aggstate); static void initialize_aggregates(AggState *aggstate, - AggStatePerAgg peragg, + AggStatePerAggState peraggstates, AggStatePerGroup pergroup, int numReset); static void advance_transition_function(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAggState peraggstate, AggStatePerGroup pergroupstate); static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup); static void process_ordered_aggregate_single(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAggState peraggstate, AggStatePerGroup pergroupstate); static void process_ordered_aggregate_multi(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAggState peraggstate, AggStatePerGroup pergroupstate); static void finalize_aggregate(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAgg peragg, AggStatePerGroup pergroupstate, Datum *resultVal, bool *resultIsNull); static void prepare_projection_slot(AggState *aggstate, @@ -396,6 +437,10 @@ static TupleTableSlot *agg_retrieve_direct(AggState *aggstate); static void agg_fill_hash_table(AggState *aggstate); static TupleTableSlot *agg_retrieve_hash_table(AggState *aggstate); static Datum GetAggInitVal(Datum textInitVal, Oid transtype); +static AggRefCompatibility find_compatible_aggref(Aggref *newagg, + AggState *aggstate, int lastaggno, int *foundaggno); +static AggRefCompatibility aggref_has_compatible_states(Aggref *newagg, + AggStatePerAgg peragg, AggStatePerAggState peraggstate); /* @@ -498,7 +543,7 @@ fetch_input_tuple(AggState *aggstate) * When called, CurrentMemoryContext should be the per-query context. 
*/ static void -initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, +initialize_aggregate(AggState *aggstate, AggStatePerAggState peraggstate, AggStatePerGroup pergroupstate) { /* @@ -569,7 +614,7 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, } /* - * Initialize all aggregates for a new group of input values. + * Initialize all aggregate states for a new group of input values. * * If there are multiple grouping sets, we initialize only the first numReset * of them (the grouping sets are ordered so that the most specific one, which @@ -580,26 +625,26 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, */ static void initialize_aggregates(AggState *aggstate, - AggStatePerAgg peragg, + AggStatePerAggState peraggstates, AggStatePerGroup pergroup, int numReset) { - int aggno; + int stateno; int numGroupingSets = Max(aggstate->phase->numsets, 1); int setno = 0; if (numReset < 1) numReset = numGroupingSets; - for (aggno = 0; aggno < aggstate->numaggs; aggno++) + for (stateno = 0; stateno < aggstate->numstates; stateno++) { - AggStatePerAgg peraggstate = &peragg[aggno]; + AggStatePerAggState peraggstate = &peraggstates[stateno]; for (setno = 0; setno < numReset; setno++) { AggStatePerGroup pergroupstate; - pergroupstate = &pergroup[aggno + (setno * (aggstate->numaggs))]; + pergroupstate = &pergroup[stateno + (setno * (aggstate->numstates))]; aggstate->current_set = setno; @@ -610,7 +655,7 @@ initialize_aggregates(AggState *aggstate, /* * Given new input value(s), advance the transition function of one aggregate - * within one grouping set only (already set in aggstate->current_set) + * state within one grouping set only (already set in aggstate->current_set) * * The new values (and null flags) have been preloaded into argument positions * 1 and up in peraggstate->transfn_fcinfo, so that we needn't copy them again @@ -621,7 +666,7 @@ initialize_aggregates(AggState *aggstate, */ static void 
advance_transition_function(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAggState peraggstate, AggStatePerGroup pergroupstate) { FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; @@ -678,8 +723,8 @@ advance_transition_function(AggState *aggstate, /* We run the transition functions in per-input-tuple memory context */ oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory); - /* set up aggstate->curperagg for AggGetAggref() */ - aggstate->curperagg = peraggstate; + /* set up aggstate->curperaggstate for AggGetAggref() */ + aggstate->curperaggstate = peraggstate; /* * OK to call the transition function @@ -690,7 +735,7 @@ advance_transition_function(AggState *aggstate, newVal = FunctionCallInvoke(fcinfo); - aggstate->curperagg = NULL; + aggstate->curperaggstate = NULL; /* * If pass-by-ref datatype, must copy the new value into aggcontext and @@ -718,7 +763,7 @@ advance_transition_function(AggState *aggstate, } /* - * Advance all the aggregates for one input tuple. The input tuple + * Advance each aggregate state for one input tuple. The input tuple * has been stored in tmpcontext->ecxt_outertuple, so that it is accessible * to ExecEvalExpr. pergroup is the array of per-group structs to use * (this might be in a hashtable entry). 
@@ -728,14 +773,14 @@ advance_transition_function(AggState *aggstate, static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) { - int aggno; + int stateno; int setno = 0; int numGroupingSets = Max(aggstate->phase->numsets, 1); - int numAggs = aggstate->numaggs; + int numStates = aggstate->numstates; - for (aggno = 0; aggno < numAggs; aggno++) + for (stateno = 0; stateno < numStates; stateno++) { - AggStatePerAgg peraggstate = &aggstate->peragg[aggno]; + AggStatePerAggState peraggstate = &aggstate->peraggstate[stateno]; ExprState *filter = peraggstate->aggrefstate->aggfilter; int numTransInputs = peraggstate->numTransInputs; int i; @@ -806,7 +851,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) for (setno = 0; setno < numGroupingSets; setno++) { - AggStatePerGroup pergroupstate = &pergroup[aggno + (setno * numAggs)]; + AggStatePerGroup pergroupstate = &pergroup[stateno + (setno * numStates)]; aggstate->current_set = setno; @@ -841,7 +886,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) */ static void process_ordered_aggregate_single(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAggState peraggstate, AggStatePerGroup pergroupstate) { Datum oldVal = (Datum) 0; @@ -930,7 +975,7 @@ process_ordered_aggregate_single(AggState *aggstate, */ static void process_ordered_aggregate_multi(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAggState peraggstate, AggStatePerGroup pergroupstate) { MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory; @@ -1009,10 +1054,14 @@ process_ordered_aggregate_multi(AggState *aggstate, * * The finalfunction will be run, and the result delivered, in the * output-tuple context; caller's CurrentMemoryContext does not matter. + * + * The finalfn uses the state as set in the stateno. This also might be + * being used by another aggregate function, so it's important that we do + * nothing destructive here. 
*/ static void finalize_aggregate(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAgg peragg, AggStatePerGroup pergroupstate, Datum *resultVal, bool *resultIsNull) { @@ -1021,6 +1070,7 @@ finalize_aggregate(AggState *aggstate, MemoryContext oldContext; int i; ListCell *lc; + AggStatePerAggState peraggstate = &aggstate->peraggstate[peragg->stateno]; oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory); @@ -1046,14 +1096,14 @@ finalize_aggregate(AggState *aggstate, /* * Apply the agg's finalfn if one is provided, else return transValue. */ - if (OidIsValid(peraggstate->finalfn_oid)) + if (OidIsValid(peragg->finalfn_oid)) { - int numFinalArgs = peraggstate->numFinalArgs; + int numFinalArgs = peragg->numFinalArgs; - /* set up aggstate->curperagg for AggGetAggref() */ - aggstate->curperagg = peraggstate; + /* set up aggstate->curperaggstate for AggGetAggref() */ + aggstate->curperaggstate = peraggstate; - InitFunctionCallInfoData(fcinfo, &peraggstate->finalfn, + InitFunctionCallInfoData(fcinfo, &peragg->finalfn, numFinalArgs, peraggstate->aggCollation, (void *) aggstate, NULL); @@ -1082,7 +1132,7 @@ finalize_aggregate(AggState *aggstate, *resultVal = FunctionCallInvoke(&fcinfo); *resultIsNull = fcinfo.isnull; } - aggstate->curperagg = NULL; + aggstate->curperaggstate = NULL; } else { @@ -1173,7 +1223,7 @@ prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet */ static void finalize_aggregates(AggState *aggstate, - AggStatePerAgg peragg, + AggStatePerAgg peraggs, AggStatePerGroup pergroup, int currentSet) { @@ -1189,10 +1239,12 @@ finalize_aggregates(AggState *aggstate, for (aggno = 0; aggno < aggstate->numaggs; aggno++) { - AggStatePerAgg peraggstate = &peragg[aggno]; + AggStatePerAgg peragg = &peraggs[aggno]; + int stateno = peragg->stateno; + AggStatePerAggState peraggstate = &aggstate->peraggstate[stateno]; AggStatePerGroup pergroupstate; - pergroupstate = &pergroup[aggno + (currentSet * 
(aggstate->numaggs))]; + pergroupstate = &pergroup[stateno + (currentSet * (aggstate->numstates))]; if (peraggstate->numSortCols > 0) { @@ -1208,7 +1260,7 @@ finalize_aggregates(AggState *aggstate, pergroupstate); } - finalize_aggregate(aggstate, peraggstate, pergroupstate, + finalize_aggregate(aggstate, peragg, pergroupstate, &aggvalues[aggno], &aggnulls[aggno]); } } @@ -1428,7 +1480,7 @@ lookup_hash_entry(AggState *aggstate, TupleTableSlot *inputslot) if (isnew) { /* initialize aggregates for new tuple group */ - initialize_aggregates(aggstate, aggstate->peragg, entry->pergroup, 0); + initialize_aggregates(aggstate, aggstate->peraggstate, entry->pergroup, 0); } return entry; @@ -1505,6 +1557,7 @@ agg_retrieve_direct(AggState *aggstate) ExprContext *econtext; ExprContext *tmpcontext; AggStatePerAgg peragg; + AggStatePerAggState peraggstate; AggStatePerGroup pergroup; TupleTableSlot *outerslot; TupleTableSlot *firstSlot; @@ -1527,6 +1580,7 @@ agg_retrieve_direct(AggState *aggstate) tmpcontext = aggstate->tmpcontext; peragg = aggstate->peragg; + peraggstate = aggstate->peraggstate; pergroup = aggstate->pergroup; firstSlot = aggstate->ss.ss_ScanTupleSlot; @@ -1716,7 +1770,7 @@ agg_retrieve_direct(AggState *aggstate) /* * Initialize working state for a new input tuple group. 
*/ - initialize_aggregates(aggstate, peragg, pergroup, numReset); + initialize_aggregates(aggstate, peraggstate, pergroup, numReset); if (aggstate->grp_firstTuple != NULL) { @@ -1945,10 +1999,12 @@ AggState * ExecInitAgg(Agg *node, EState *estate, int eflags) { AggState *aggstate; - AggStatePerAgg peragg; + AggStatePerAgg peraggs; + AggStatePerAggState peraggstates; Plan *outerPlan; ExprContext *econtext; int numaggs, + stateno, aggno; int phase; ListCell *l; @@ -1971,12 +2027,14 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->aggs = NIL; aggstate->numaggs = 0; + aggstate->numstates = 0; aggstate->maxsets = 0; aggstate->hashfunctions = NULL; aggstate->projected_set = -1; aggstate->current_set = 0; aggstate->peragg = NULL; - aggstate->curperagg = NULL; + aggstate->peraggstate = NULL; + aggstate->curperaggstate = NULL; aggstate->agg_done = false; aggstate->input_done = false; aggstate->pergroup = NULL; @@ -2209,8 +2267,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) econtext->ecxt_aggvalues = (Datum *) palloc0(sizeof(Datum) * numaggs); econtext->ecxt_aggnulls = (bool *) palloc0(sizeof(bool) * numaggs); - peragg = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs); - aggstate->peragg = peragg; + peraggs = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData)* numaggs); + peraggstates = (AggStatePerAggState) palloc0(sizeof(AggStatePerAggStateData) * numaggs); + + aggstate->peragg = peraggs; + aggstate->peraggstate = peraggstates; if (node->aggstrategy == AGG_HASHED) { @@ -2232,18 +2293,17 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) /* * Perform lookups of aggregate function info, and initialize the - * unchanging fields of the per-agg data. We also detect duplicate - * aggregates (for example, "SELECT sum(x) ... HAVING sum(x) > 0"). When - * duplicates are detected, we only make an AggStatePerAgg struct for the - * first one. The clones are simply pointed at the same result entry by - * giving them duplicate aggno values. 
+ * unchanging fields of the per-agg data. */ aggno = -1; + stateno = -1; foreach(l, aggstate->aggs) { AggrefExprState *aggrefstate = (AggrefExprState *) lfirst(l); Aggref *aggref = (Aggref *) aggrefstate->xprstate.expr; - AggStatePerAgg peraggstate; + AggStatePerAgg peragg; + AggStatePerAggState peraggstate; + AggRefCompatibility agg_match; Oid inputTypes[FUNC_MAX_ARGS]; int numArguments; int numDirectArgs; @@ -2260,40 +2320,82 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) Expr *transfnexpr, *finalfnexpr; Datum textInitVal; - int i; + int existing_aggno; ListCell *lc; /* Planner should have assigned aggregate to correct level */ Assert(aggref->agglevelsup == 0); - /* Look for a previous duplicate aggregate */ - for (i = 0; i <= aggno; i++) + /* + * For performance reasons we detect duplicate aggregates (for example, + * "SELECT sum(x) ... HAVING sum(x) > 0"). When duplicates are + * detected, we only make an AggStatePerAgg struct for the first one. + * The clones are simply pointed at the same result entry by giving + * them duplicate aggno values. We also do our best to reuse duplicate + * aggregate states. The query may use 2 or more aggregate functions + * which share the same transition function and initial value therefore + * would end up calculating the same state. In this case we can just + * calculate the state once, however if the finalfns do not match then + * we must create a new peragg to store the varying finalfn. + */ + + /* check if we have previous agg or state matches that can be reused */ + agg_match = find_compatible_aggref(aggref, aggstate, aggno, + &existing_aggno); + + if (agg_match == AGGREF_EXACT_MATCH) { - if (equal(aggref, peragg[i].aggref) && - !contain_volatile_functions((Node *) aggref)) - break; + /* Exact match -- this must be using same aggregate function or + * have the same transfn and finalfn. Just reuse the existing agg. 
+ */ + aggrefstate->aggno = existing_aggno; + continue; } - if (i <= aggno) + + else if (agg_match == AGGREF_STATE_MATCH) { - /* Found a match to an existing entry, so just mark it */ - aggrefstate->aggno = i; - continue; + /* + * State only match. The state can be reused, but the finalfn are + * different. We'll need to create a new peragg for the new finalfn + */ + int existing_stateno = peraggs[existing_aggno].stateno; + peragg = &peraggs[++aggno]; + peraggstate = &peraggstates[existing_stateno]; + peragg->stateno = existing_stateno; + } + else /* AGGREF_NO_MATCH */ + { + /* Nothing matches, so assign a new state and a new per agg */ + peraggstate = &peraggstates[++stateno]; + peragg = &peraggs[++aggno]; + peragg->stateno = stateno; } - /* Nope, so assign a new PerAgg record */ - peraggstate = &peragg[++aggno]; + /* + * When we pass through the following code in a AGGREF_STATE_MATCH + * type match, the peraggstate will already have been setup by a + * previous iteration of the loop, so we'll try where possible to + * minimize as much rework of setting up the peraggstate as possible. + * In reality it shouldn't matter as we'll just be setting it up the + * same as it was previously, but for performance reasons we do skip + * over some more expensive parts the 2nd time around. 
+ */ - /* Mark Aggref state node with assigned index in the result array */ + /* Mark Aggref state node with the index of which agg it should use */ aggrefstate->aggno = aggno; - /* Begin filling in the peraggstate data */ - peraggstate->aggrefstate = aggrefstate; - peraggstate->aggref = aggref; - peraggstate->sortstates = (Tuplesortstate **) - palloc0(sizeof(Tuplesortstate *) * numGroupingSets); - - for (currentsortno = 0; currentsortno < numGroupingSets; currentsortno++) - peraggstate->sortstates[currentsortno] = NULL; + /* for state matches the peraggstate has already been setup */ + if (agg_match == AGGREF_NO_MATCH) + { + /* Begin filling in the peraggstate data */ + peraggstate->aggrefstate = aggrefstate; + peraggstate->aggref = aggref; + peraggstate->sortstates = (Tuplesortstate **) + palloc0(sizeof(Tuplesortstate *)* numGroupingSets); + + for (currentsortno = 0; currentsortno < numGroupingSets; currentsortno++) + peraggstate->sortstates[currentsortno] = NULL; + } /* Fetch the pg_aggregate row */ aggTuple = SearchSysCache1(AGGFNOID, @@ -2311,8 +2413,12 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) get_func_name(aggref->aggfnoid)); InvokeFunctionExecuteHook(aggref->aggfnoid); + /* when reusing the state the transfns should match! 
*/ + Assert(agg_match == AGGREF_NO_MATCH || + peraggstate->transfn_oid == aggform->aggtransfn); + peraggstate->transfn_oid = transfn_oid = aggform->aggtransfn; - peraggstate->finalfn_oid = finalfn_oid = aggform->aggfinalfn; + peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn; /* Check that aggregate owner has permission to call component fns */ { @@ -2327,12 +2433,20 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggOwner = ((Form_pg_proc) GETSTRUCT(procTuple))->proowner; ReleaseSysCache(procTuple); - aclresult = pg_proc_aclcheck(transfn_oid, aggOwner, - ACL_EXECUTE); - if (aclresult != ACLCHECK_OK) - aclcheck_error(aclresult, ACL_KIND_PROC, - get_func_name(transfn_oid)); - InvokeFunctionExecuteHook(transfn_oid); + /* + * If we're reusing an existing state then the permissions for + * transfn were already checked when we setup that state. + */ + if (agg_match == AGGREF_NO_MATCH) + { + aclresult = pg_proc_aclcheck(transfn_oid, aggOwner, + ACL_EXECUTE); + if (aclresult != ACLCHECK_OK) + aclcheck_error(aclresult, ACL_KIND_PROC, + get_func_name(transfn_oid)); + InvokeFunctionExecuteHook(transfn_oid); + } + if (OidIsValid(finalfn_oid)) { aclresult = pg_proc_aclcheck(finalfn_oid, aggOwner, @@ -2367,9 +2481,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) /* Detect how many arguments to pass to the finalfn */ if (aggform->aggfinalextra) - peraggstate->numFinalArgs = numArguments + 1; + peragg->numFinalArgs = numArguments + 1; else - peraggstate->numFinalArgs = numDirectArgs + 1; + peragg->numFinalArgs = numDirectArgs + 1; /* resolve actual type of transition state, if polymorphic */ aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid, @@ -2377,32 +2491,62 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) inputTypes, numArguments); - /* build expression trees using actual argument & result types */ - build_aggregate_fnexprs(inputTypes, - numArguments, - numDirectArgs, - peraggstate->numFinalArgs, - aggref->aggvariadic, - aggtranstype, - 
aggref->aggtype, - aggref->inputcollid, - transfn_oid, - InvalidOid, /* invtrans is not needed here */ - finalfn_oid, - &transfnexpr, - NULL, - &finalfnexpr); - - /* set up infrastructure for calling the transfn and finalfn */ - fmgr_info(transfn_oid, &peraggstate->transfn); - fmgr_info_set_expr((Node *) transfnexpr, &peraggstate->transfn); + if (agg_match == AGGREF_NO_MATCH) + { + /* build expression trees using actual argument & result types */ + build_aggregate_fnexprs(inputTypes, + numArguments, + numDirectArgs, + peragg->numFinalArgs, + aggref->aggvariadic, + aggtranstype, + aggref->aggtype, + aggref->inputcollid, + transfn_oid, + InvalidOid, /* invtrans is not needed here */ + finalfn_oid, + &transfnexpr, + NULL, + &finalfnexpr); + + /* set up infrastructure for calling the transfn and finalfn */ + fmgr_info(transfn_oid, &peraggstate->transfn); + fmgr_info_set_expr((Node *) transfnexpr, &peraggstate->transfn); + } + else if (OidIsValid(finalfn_oid)) + { + /* + * AGGREF_STATE_MATCH -- transfn calling infrastructure already + * built for this state + */ + build_aggregate_fnexprs(inputTypes, + numArguments, + numDirectArgs, + peragg->numFinalArgs, + aggref->aggvariadic, + aggtranstype, + aggref->aggtype, + aggref->inputcollid, + transfn_oid, + InvalidOid, /* invtrans is not needed here */ + finalfn_oid, + NULL, /* transfn already done */ + NULL, + &finalfnexpr); + } if (OidIsValid(finalfn_oid)) { - fmgr_info(finalfn_oid, &peraggstate->finalfn); - fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn); + fmgr_info(finalfn_oid, &peragg->finalfn); + fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn); } + /* if it's a state match then everything else has already been done */ + if (agg_match != AGGREF_NO_MATCH) + { + ReleaseSysCache(aggTuple); + continue; + } peraggstate->aggCollation = aggref->inputcollid; InitFunctionCallInfoData(peraggstate->transfn_fcinfo, @@ -2574,8 +2718,12 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) 
ReleaseSysCache(aggTuple);
 	}
 
-	/* Update numaggs to match number of unique aggregates found */
+	/*
+	 * Update numaggs to match the number of unique aggregates found.
+	 * Also set numstates to the number of unique aggregate states found.
+	 */
 	aggstate->numaggs = aggno + 1;
+	aggstate->numstates = stateno + 1;
 
 	return aggstate;
 }
@@ -2596,11 +2744,195 @@ GetAggInitVal(Datum textInitVal, Oid transtype)
 	return initVal;
 }
 
+/*
+ * find_compatible_aggref
+ *	  Searches the previously looked at aggregates in order to find a
+ *	  compatible aggregate or aggregate state. If a positive match is found
+ *	  then foundaggno is set to the aggregate which matches.
+ *	  When AGGREF_STATE_MATCH is returned the caller must only use the state
+ *	  of the foundaggno, not the actual aggno itself.
+ *	  When AGGREF_EXACT_MATCH is returned the caller may use both the aggno
+ *	  and the state which that aggno uses.
+ *
+ * Scenario 1 -- An aggregate function appears more than once in query:
+ *
+ *	  SELECT SUM(x) FROM ... HAVING SUM(x) > 0
+ *
+ * Since in this case the aggregates are both the same we can optimize by
+ * only calculating aggregate state and calling the finalfn just once. This
+ * would be an AGGREF_EXACT_MATCH, meaning both the state and the final
+ * function call are shared.
+ *
+ * Scenario 2 -- Two different aggregate functions appear in the query but
+ *				 the two functions happen to share the same transfn, but have
+ *				 different finalfn.
+ *
+ *	  SELECT SUM(x), AVG(x) FROM ...
+ *
+ * In our case these two aggregates share the same transfn but, naturally,
+ * have different finalfns. This situation is classed as an
+ * AGGREF_STATE_MATCH. This means that the same state can be shared by both
+ * aggregates. Since the finalfn call is not the same this cannot be reused.
+ * This match is only reported when neither aggregate has an INITCOND; the
+ * INITCOND values themselves are not compared (see the XXX note below).
+ *
+ * Scenario 3 -- The same aggregate function is called with different
+ *				 parameters.
+ *
+ *	  SELECT SUM(x),SUM(DISTINCT x) FROM ...
+ *	  SELECT SUM(x),SUM(y) FROM ...
+ *	  SELECT SUM(x),SUM(x) FILTER(WHERE x > 0) FROM ...
+ *
+ * All three of the above queries cannot share the same state and have to be
+ * calculated independently.
+ *
+ * Scenario 4 -- Different aggregates with the same parameters and the same
+ *				 transfn and finalfn.
+ *
+ *	  SELECT SUM(x),SUM2(x) FROM ...
+ *
+ * A perhaps unlikely scenario where two aggregate functions exist which have
+ * both the same transfn and the same finalfn. In this case we can report an
+ * AGGREF_EXACT_MATCH, provided that neither aggregate has an INITCOND.
+ */
+static AggRefCompatibility
+find_compatible_aggref(Aggref *newagg, AggState *aggstate,
+					   int lastaggno, int *foundaggno)
+{
+	int			aggno;
+	int			statematchaggno;
+	AggStatePerAggState peraggstates;
+	AggStatePerAgg peraggs;
+
+	/* we mustn't reuse the aggref if it contains volatile function calls */
+	if (contain_volatile_functions((Node *)newagg))
+		return AGGREF_NO_MATCH;
+
+	statematchaggno = -1;
+	peraggstates = aggstate->peraggstate;
+	peraggs = aggstate->peragg;
+
+	/*
+	 * Search through the list of already seen aggregates. We'll stop when we
+	 * find an exact match, but until then we'll note any state matches that
+	 * we find. We may have to fall back on these should we fail to find an
+	 * exact match.
+	 */
+	for (aggno = 0; aggno <= lastaggno; aggno++)
+	{
+		AggRefCompatibility matchtype;
+		AggStatePerAgg peragg;
+		AggStatePerAggState peraggstate;
+
+		peragg = &peraggs[aggno];
+		peraggstate = &peraggstates[peragg->stateno];
+
+		/* lookup the match type of this agg */
+		matchtype = aggref_has_compatible_states(newagg, peragg, peraggstate);
+
+		/* if it's an exact match then we're done. */
+		if (matchtype == AGGREF_EXACT_MATCH)
+		{
+			*foundaggno = aggno;
+			return AGGREF_EXACT_MATCH;
+		}
+
+		/* remember any state matches, but keep on looking... */
+		else if (matchtype == AGGREF_STATE_MATCH)
+			statematchaggno = aggno;
+	}
+
+	/* no exact match found, but did we find a state match? */
+	if (statematchaggno >= 0)
+	{
+		*foundaggno = statematchaggno;
+		return AGGREF_STATE_MATCH;
+	}
+
+	return AGGREF_NO_MATCH;
+}
+
+/*
+ * aggref_has_compatible_states
+ *	  Reports whether newagg can share the given existing aggregate's state
+ *	  (and finalfn).  See comments in find_compatible_aggref() for details.
+ */
+static AggRefCompatibility
+aggref_has_compatible_states(Aggref *newagg, AggStatePerAgg peragg,
+							 AggStatePerAggState peraggstate)
+{
+	Aggref	   *existingRef = peraggstate->aggref;
+
+	/* all of the following must be the same or it's no match */
+	if (newagg->aggtype != existingRef->aggtype ||
+		newagg->aggcollid != existingRef->aggcollid ||
+		newagg->inputcollid != existingRef->inputcollid ||
+		newagg->aggstar != existingRef->aggstar ||
+		newagg->aggvariadic != existingRef->aggvariadic ||
+		newagg->aggkind != existingRef->aggkind ||
+		!equal(newagg->aggdirectargs, existingRef->aggdirectargs) ||
+		!equal(newagg->args, existingRef->args) ||
+		!equal(newagg->aggorder, existingRef->aggorder) ||
+		!equal(newagg->aggdistinct, existingRef->aggdistinct) ||
+		!equal(newagg->aggfilter, existingRef->aggfilter))
+		return AGGREF_NO_MATCH;
+
+	/* if it's the same aggregate function then report exact match */
+	if (newagg->aggfnoid == existingRef->aggfnoid)
+		return AGGREF_EXACT_MATCH;
+	else
+	{
+		/*
+		 * Aggregate functions differ.  If the transfns match and neither
+		 * aggregate has an INITCOND the state can be shared; if the finalfns
+		 * match too it is still an exact match.  Everything we need is copied
+		 * out of the pg_aggregate tuple before it is released, since aggform
+		 * points into the tuple and must not be used after ReleaseSysCache().
+		 */
+
+		HeapTuple	aggTuple;
+		Form_pg_aggregate aggform;
+		Oid			transfn_oid;
+		Oid			finalfn_oid;
+		bool		initValueIsNull;
+
+		/* Fetch the pg_aggregate row */
+		aggTuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(newagg->aggfnoid));
+		if (!HeapTupleIsValid(aggTuple))
+			elog(ERROR, "cache lookup failed for aggregate %u", newagg->aggfnoid);
+		aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
+		transfn_oid = aggform->aggtransfn;
+		finalfn_oid = aggform->aggfinalfn;
+		(void) SysCacheGetAttr(AGGFNOID, aggTuple,
+							   Anum_pg_aggregate_agginitval, &initValueIsNull);
+		ReleaseSysCache(aggTuple);
+
+		/* if the transfns are not the same then the state can't be shared */
+		if (transfn_oid != peraggstate->transfn_oid)
+			return AGGREF_NO_MATCH;
+
+		/* Both INITCONDs null: the outcome depends on whether finalfns match */
+		if (initValueIsNull && peraggstate->initValueIsNull)
+		{
+			if (finalfn_oid != peragg->finalfn_oid)
+				return AGGREF_STATE_MATCH;
+			else
+				return AGGREF_EXACT_MATCH;
+		}
+
+		/*
+		 * XXX perhaps we should check the value of the initValue to see if
+		 * they match?
+		 */
+		return AGGREF_NO_MATCH;
+	}
+}
+
 void
 ExecEndAgg(AggState *node)
 {
 	PlanState  *outerPlan;
-	int			aggno;
+	int			stateno;
 	int			numGroupingSets = Max(node->maxsets, 1);
 	int			setno;
 
@@ -2611,9 +2943,9 @@ ExecEndAgg(AggState *node)
 	if (node->sort_out)
 		tuplesort_end(node->sort_out);
 
-	for (aggno = 0; aggno < node->numaggs; aggno++)
+	for (stateno = 0; stateno < node->numstates; stateno++)
 	{
-		AggStatePerAgg peraggstate = &node->peragg[aggno];
+		AggStatePerAggState peraggstate = &node->peraggstate[stateno];
 
 		for (setno = 0; setno < numGroupingSets; setno++)
 		{
@@ -2646,7 +2978,7 @@ ExecReScanAgg(AggState *node)
 	ExprContext *econtext = node->ss.ps.ps_ExprContext;
 	PlanState  *outerPlan = outerPlanState(node);
 	Agg		   *aggnode = (Agg *) node->ss.ps.plan;
-	int			aggno;
+	int			stateno;
 	int			numGroupingSets = Max(node->maxsets, 1);
 	int			setno;
 
@@ -2678,11 +3010,11 @@ ExecReScanAgg(AggState *node)
 	}
 
 	/* Make sure we have closed any open tuplesorts */
-	for (aggno = 0; aggno < node->numaggs; aggno++)
+	for
(stateno = 0; stateno < node->numstates; stateno++) { for (setno = 0; setno < numGroupingSets; setno++) { - AggStatePerAgg peraggstate = &node->peragg[aggno]; + AggStatePerAggState peraggstate = &node->peraggstate[stateno]; if (peraggstate->sortstates[setno]) { @@ -2811,10 +3143,12 @@ AggGetAggref(FunctionCallInfo fcinfo) { if (fcinfo->context && IsA(fcinfo->context, AggState)) { - AggStatePerAgg curperagg = ((AggState *) fcinfo->context)->curperagg; + AggStatePerAggState curperaggstate; + + curperaggstate = ((AggState *)fcinfo->context)->curperaggstate; - if (curperagg) - return curperagg->aggref; + if (curperaggstate) + return curperaggstate->aggref; } return NULL; } diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index 478d8ca..123cccb 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -1863,42 +1863,45 @@ build_aggregate_fnexprs(Oid *agg_input_types, FuncExpr *fexpr; int i; - /* - * Build arg list to use in the transfn FuncExpr node. We really only care - * that transfn can discover the actual argument types at runtime using - * get_fn_expr_argtype(), so it's okay to use Param nodes that don't - * correspond to any real Param. - */ - argp = makeNode(Param); - argp->paramkind = PARAM_EXEC; - argp->paramid = -1; - argp->paramtype = agg_state_type; - argp->paramtypmod = -1; - argp->paramcollid = agg_input_collation; - argp->location = -1; - - args = list_make1(argp); - - for (i = agg_num_direct_inputs; i < agg_num_inputs; i++) + if (transfnexpr != NULL) { + /* + * Build arg list to use in the transfn FuncExpr node. We really only care + * that transfn can discover the actual argument types at runtime using + * get_fn_expr_argtype(), so it's okay to use Param nodes that don't + * correspond to any real Param. 
+ */ argp = makeNode(Param); argp->paramkind = PARAM_EXEC; argp->paramid = -1; - argp->paramtype = agg_input_types[i]; + argp->paramtype = agg_state_type; argp->paramtypmod = -1; argp->paramcollid = agg_input_collation; argp->location = -1; - args = lappend(args, argp); - } - fexpr = makeFuncExpr(transfn_oid, - agg_state_type, - args, - InvalidOid, - agg_input_collation, - COERCE_EXPLICIT_CALL); - fexpr->funcvariadic = agg_variadic; - *transfnexpr = (Expr *) fexpr; + args = list_make1(argp); + + for (i = agg_num_direct_inputs; i < agg_num_inputs; i++) + { + argp = makeNode(Param); + argp->paramkind = PARAM_EXEC; + argp->paramid = -1; + argp->paramtype = agg_input_types[i]; + argp->paramtypmod = -1; + argp->paramcollid = agg_input_collation; + argp->location = -1; + args = lappend(args, argp); + } + + fexpr = makeFuncExpr(transfn_oid, + agg_state_type, + args, + InvalidOid, + agg_input_collation, + COERCE_EXPLICIT_CALL); + fexpr->funcvariadic = agg_variadic; + *transfnexpr = (Expr *) fexpr; + } /* * Build invtransfn expression if requested, with same args as transfn diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h index db5bd7f..af03214 100644 --- a/src/include/nodes/execnodes.h +++ b/src/include/nodes/execnodes.h @@ -1815,6 +1815,7 @@ typedef struct GroupState */ /* these structs are private in nodeAgg.c: */ typedef struct AggStatePerAggData *AggStatePerAgg; +typedef struct AggStatePerAggStateData *AggStatePerAggState; typedef struct AggStatePerGroupData *AggStatePerGroup; typedef struct AggStatePerPhaseData *AggStatePerPhase; @@ -1823,14 +1824,16 @@ typedef struct AggState ScanState ss; /* its first field is NodeTag */ List *aggs; /* all Aggref nodes in targetlist & quals */ int numaggs; /* length of list (could be zero!) 
*/ + int numstates; /* number of peraggstate items */ AggStatePerPhase phase; /* pointer to current phase data */ int numphases; /* number of phases */ int current_phase; /* current phase number */ FmgrInfo *hashfunctions; /* per-grouping-field hash fns */ AggStatePerAgg peragg; /* per-Aggref information */ + AggStatePerAggState peraggstate; /* per-Agg State information */ ExprContext **aggcontexts; /* econtexts for long-lived data (per GS) */ ExprContext *tmpcontext; /* econtext for input expressions */ - AggStatePerAgg curperagg; /* identifies currently active aggregate */ + AggStatePerAggState curperaggstate; /* identifies currently active aggregate */ bool input_done; /* indicates end of input */ bool agg_done; /* indicates completion of Agg scan */ int projected_set; /* The last projected grouping set */ diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out index 8852051..4dad4fe 100644 --- a/src/test/regress/expected/aggregates.out +++ b/src/test/regress/expected/aggregates.out @@ -1580,3 +1580,171 @@ select least_agg(variadic array[q1,q2]) from int8_tbl; -4567890123456789 (1 row) +-- test aggregates with common transition functions share the same states +begin work; +create type avg_state as (total bigint, count bigint); +create or replace function avg_transfn(state avg_state, n int) returns avg_state as +$$ +declare new_state avg_state; +begin + raise notice 'avg_transfn called with %', n; + if state is null then + if n is not null then + new_state.total := n; + new_state.count := 1; + return new_state; + end if; + return null; + elsif n is not null then + state.total := state.total + n; + state.count := state.count + 1; + return state; + end if; + + return null; +end +$$ language plpgsql; +create function avg_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total / state.count; + end if; +end +$$ language plpgsql; +create function sum_finalfn(state 
avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total; + end if; +end +$$ language plpgsql; +create aggregate my_avg(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn +); +create aggregate my_sum(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn +); +-- aggregate state should be shared as transfn is the same for both aggs. +select my_avg(one),my_sum(one) from (values(1,2),(3,4)) t(one,two); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 2 | 4 +(1 row) + +-- shouldn't share states due to the distinctness not matching. +select my_avg(distinct one),my_sum(one) from (values(1,2),(3,4)) t(one,two); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 2 | 4 +(1 row) + +-- this should not share the state due to different input columns. +select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two); +NOTICE: avg_transfn called with 2 +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 4 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 2 | 6 +(1 row) + +create aggregate my_sum_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn, + initcond = '(10,0)' +); +create aggregate my_avg_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(5,0)' +); +-- Varying INITCONDs should cause the states not to be shared. 
+select my_avg_init(one),my_sum_init(one) from (values(1,2),(3,4)) t(one,two); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 +NOTICE: avg_transfn called with 3 + my_avg_init | my_sum_init +-------------+------------- + 4 | 14 +(1 row) + +rollback; +-- test aggregate state sharing to ensure it works if one aggregate has a +-- finalfn and the other one has none. +begin work; +create or replace function sum_transfn(state int4, n int4) returns int4 as +$$ +declare new_state int4; +begin + raise notice 'sum_transfn called with %', n; + if state is null then + if n is not null then + new_state := n; + return new_state; + end if; + return null; + elsif n is not null then + state := state + n; + return state; + end if; + + return null; +end +$$ language plpgsql; +create function halfsum_finalfn(state int4) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state / 2; + end if; +end +$$ language plpgsql; +create aggregate my_sum(int4) +( + stype = int4, + sfunc = sum_transfn +); +create aggregate my_half_sum(int4) +( + stype = int4, + sfunc = sum_transfn, + finalfunc = halfsum_finalfn +); +-- Agg state should be shared even though my_sum has no finalfn +select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one); +NOTICE: sum_transfn called with 1 +NOTICE: sum_transfn called with 2 +NOTICE: sum_transfn called with 3 +NOTICE: sum_transfn called with 4 + my_sum | my_half_sum +--------+------------- + 10 | 5 +(1 row) + +rollback; diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql index a84327d..42c3b3c 100644 --- a/src/test/regress/sql/aggregates.sql +++ b/src/test/regress/sql/aggregates.sql @@ -590,3 +590,151 @@ drop view aggordview1; -- variadic aggregates select least_agg(q1,q2) from int8_tbl; select least_agg(variadic array[q1,q2]) from int8_tbl; + + +-- test aggregates with common transition functions share the same states +begin work; 
+ +create type avg_state as (total bigint, count bigint); + +create or replace function avg_transfn(state avg_state, n int) returns avg_state as +$$ +declare new_state avg_state; +begin + raise notice 'avg_transfn called with %', n; + if state is null then + if n is not null then + new_state.total := n; + new_state.count := 1; + return new_state; + end if; + return null; + elsif n is not null then + state.total := state.total + n; + state.count := state.count + 1; + return state; + end if; + + return null; +end +$$ language plpgsql; + +create function avg_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total / state.count; + end if; +end +$$ language plpgsql; + +create function sum_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total; + end if; +end +$$ language plpgsql; + +create aggregate my_avg(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn +); + +create aggregate my_sum(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn +); + +-- aggregate state should be shared as transfn is the same for both aggs. +select my_avg(one),my_sum(one) from (values(1,2),(3,4)) t(one,two); + +-- shouldn't share states due to the distinctness not matching. +select my_avg(distinct one),my_sum(one) from (values(1,2),(3,4)) t(one,two); + +-- this should not share the state due to different input columns. +select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two); + + +create aggregate my_sum_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn, + initcond = '(10,0)' +); + +create aggregate my_avg_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(5,0)' +); + +-- Varying INITCONDs should cause the states not to be shared. 
+select my_avg_init(one),my_sum_init(one) from (values(1,2),(3,4)) t(one,two); + +rollback; + +-- test aggregate state sharing to ensure it works if one aggregate has a +-- finalfn and the other one has none. +begin work; + +create or replace function sum_transfn(state int4, n int4) returns int4 as +$$ +declare new_state int4; +begin + raise notice 'sum_transfn called with %', n; + if state is null then + if n is not null then + new_state := n; + return new_state; + end if; + return null; + elsif n is not null then + state := state + n; + return state; + end if; + + return null; +end +$$ language plpgsql; + +create function halfsum_finalfn(state int4) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state / 2; + end if; +end +$$ language plpgsql; + +create aggregate my_sum(int4) +( + stype = int4, + sfunc = sum_transfn +); + +create aggregate my_half_sum(int4) +( + stype = int4, + sfunc = sum_transfn, + finalfunc = halfsum_finalfn +); + +-- Agg state should be shared even though my_sum has no finalfn +select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one); + +rollback;