diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index d01fc4f52e..b061162961 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -2522,8 +2522,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) int existing_aggno; int existing_transno; List *same_input_transnos; - Oid inputTypes[FUNC_MAX_ARGS]; + Oid transFnInputTypes[FUNC_MAX_ARGS]; int numArguments; + int numTransFnArgs; int numDirectArgs; HeapTuple aggTuple; Form_pg_aggregate aggform; @@ -2701,14 +2702,23 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) * could be different from the agg's declared input types, when the * agg accepts ANY or a polymorphic type. */ - numArguments = get_aggregate_argtypes(aggref, inputTypes); + numTransFnArgs = get_aggregate_argtypes(aggref, transFnInputTypes); /* Count the "direct" arguments, if any */ numDirectArgs = list_length(aggref->aggdirectargs); + /* + * Combine functions always take two input params of the trans state type, so + numArguments is always set to 1 (we don't count the first trans state).
+ */ + if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit)) + numArguments = 1; + else + numArguments = numTransFnArgs; + /* Detect how many arguments to pass to the finalfn */ if (aggform->aggfinalextra) - peragg->numFinalArgs = numArguments + 1; + peragg->numFinalArgs = numTransFnArgs + 1; else peragg->numFinalArgs = numDirectArgs + 1; @@ -2722,7 +2732,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) */ if (OidIsValid(finalfn_oid)) { - build_aggregate_finalfn_expr(inputTypes, + build_aggregate_finalfn_expr(transFnInputTypes, peragg->numFinalArgs, aggtranstype, aggref->aggtype, @@ -2781,7 +2791,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggref, transfn_oid, aggtranstype, serialfn_oid, deserialfn_oid, initValue, initValueIsNull, - inputTypes, numArguments); + transFnInputTypes, numArguments); peragg->transno = transno; } ReleaseSysCache(aggTuple); @@ -2872,6 +2882,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) * to initialize the state for. 'aggtransfn', 'aggtranstype', and the rest * of the arguments could be calculated from 'aggref', but the caller has * calculated them already, so might as well pass them. + * When performing DO_AGGSPLIT_COMBINE, 'aggtransfn' is actually the Oid of + * the aggcombinefn. + * 'transFnInputTypes' must always be the input types of the true aggtransfn. + * 'numArguments' is the number of arguments, not counting the trans state, + * taken by the given 'aggtransfn'.
*/ static void build_pertrans_for_aggref(AggStatePerTrans pertrans, @@ -2880,7 +2895,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, Oid aggtransfn, Oid aggtranstype, Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, - Oid *inputTypes, int numArguments) + Oid *transFnInputTypes, int numArguments) { int numGroupingSets = Max(aggstate->maxsets, 1); Expr *serialfnexpr = NULL; @@ -2911,12 +2926,6 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, pertrans->aggtranstype = aggtranstype; - /* Detect how many arguments to pass to the transfn */ - if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) - pertrans->numTransInputs = numInputs; - else - pertrans->numTransInputs = numArguments; - /* * When combining states, we have no use at all for the aggregate * function's transfn. Instead we use the combinefn. In this case, the @@ -2927,6 +2936,14 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, { Expr *combinefnexpr; + pertrans->numTransInputs = numArguments; + + /* + * combinefn should always just have two trans type args (we don't + * count the initial one here). + */ + Assert(numArguments == 1); + build_aggregate_combinefn_expr(aggtranstype, aggref->inputcollid, aggtransfn, @@ -2956,13 +2973,21 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, else { Expr *transfnexpr; - size_t numInputs = pertrans->numTransInputs + 1; + size_t numTransInputs; + + /* Detect how many arguments to pass to the transfn */ + if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) + pertrans->numTransInputs = numInputs; + else + pertrans->numTransInputs = numArguments; + + numTransInputs = pertrans->numTransInputs + 1; /* * Set up infrastructure for calling the transfn. Note that invtrans * is not needed here. 
*/ - build_aggregate_transfn_expr(inputTypes, + build_aggregate_transfn_expr(transFnInputTypes, numArguments, numDirectArgs, aggref->aggvariadic, @@ -2976,10 +3001,10 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn); pertrans->transfn_fcinfo = - (FunctionCallInfo) palloc(SizeForFunctionCallInfo(numInputs)); + (FunctionCallInfo) palloc(SizeForFunctionCallInfo(numTransInputs)); InitFunctionCallInfoData(*pertrans->transfn_fcinfo, &pertrans->transfn, - numInputs, + numTransInputs, pertrans->aggCollation, (void *) aggstate, NULL); @@ -2994,7 +3019,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, if (pertrans->transfn.fn_strict && pertrans->initValueIsNull) { if (numArguments <= numDirectArgs || - !IsBinaryCoercible(inputTypes[numDirectArgs], + !IsBinaryCoercible(transFnInputTypes[numDirectArgs], aggtranstype)) ereport(ERROR, (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), @@ -3088,14 +3113,16 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, { /* * We don't implement DISTINCT or ORDER BY aggs in the HASHED case - * (yet) + * (yet). We also don't support partial aggregation in this case + * either. */ Assert(aggstate->aggstrategy != AGG_HASHED && aggstate->aggstrategy != AGG_MIXED); + Assert(!DO_AGGSPLIT_COMBINE(aggstate->aggsplit)); /* If we have only one input, we need its len/byval info. 
+ */ if (numInputs == 1) { - get_typlenbyval(inputTypes[numDirectArgs], + get_typlenbyval(transFnInputTypes[numDirectArgs], &pertrans->inputtypeLen, &pertrans->inputtypeByVal); } diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out index 129c1e5075..572f7f62b0 100644 --- a/src/test/regress/expected/aggregates.out +++ b/src/test/regress/expected/aggregates.out @@ -2204,8 +2204,9 @@ SET max_parallel_workers_per_gather = 4; SET enable_indexonlyscan = off; -- variance(int4) covers numeric_poly_combine -- sum(int8) covers int8_avg_combine +-- regr_count(float8, float8) covers int8inc_float8_float8 and aggregates with > 1 arg EXPLAIN (COSTS OFF) - SELECT variance(unique1::int4), sum(unique1::int8) FROM tenk1; + SELECT variance(unique1::int4), sum(unique1::int8),regr_count(unique1::float8, unique1::float8) FROM tenk1; QUERY PLAN ---------------------------------------------- Finalize Aggregate @@ -2215,10 +2216,10 @@ EXPLAIN (COSTS OFF) -> Parallel Seq Scan on tenk1 (5 rows) -SELECT variance(unique1::int4), sum(unique1::int8) FROM tenk1; - variance | sum -----------------------+---------- - 8334166.666666666667 | 49995000 +SELECT variance(unique1::int4), sum(unique1::int8),regr_count(unique1::float8, unique1::float8) FROM tenk1; + variance | sum | regr_count +----------------------+----------+------------ + 8334166.666666666667 | 49995000 | 10000 (1 row) ROLLBACK; diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql index d4fd657188..bd8b9e8b4f 100644 --- a/src/test/regress/sql/aggregates.sql +++ b/src/test/regress/sql/aggregates.sql @@ -963,10 +963,11 @@ SET enable_indexonlyscan = off; -- variance(int4) covers numeric_poly_combine -- sum(int8) covers int8_avg_combine +-- regr_count(float8, float8) covers int8inc_float8_float8 and aggregates with > 1 arg EXPLAIN (COSTS OFF) - SELECT variance(unique1::int4), 
sum(unique1::int8),regr_count(unique1::float8, unique1::float8) FROM tenk1; -SELECT variance(unique1::int4), sum(unique1::int8) FROM tenk1; +SELECT variance(unique1::int4), sum(unique1::int8),regr_count(unique1::float8, unique1::float8) FROM tenk1; ROLLBACK;