From 995ed2ad31a24cb36e20beae2aa36d3e58fc6298 Mon Sep 17 00:00:00 2001
From: Joel Jakobsson
Date: Sun, 7 Jul 2024 19:21:35 +0200
Subject: [PATCH] Optimize mul_var() for var1ndigits >= 8

The idea is to reduce the "n" in O(n^2) by a factor of two.

This is achieved by first converting the ndigits int16 NBASE digits of each
input into ceil(ndigits/2) NBASE^2 digits, as well as upgrading the int32
variables to int64 variables so that the products and carry values fit.

The existing multiplication algorithm is then executed without change.
The rscale-based truncation limit is converted from NBASE digits to
NBASE^2 digit pairs, so that truncation still skips work that cannot
affect the rounded result.

Finally, the NBASE^2 result digits are converted back to twice the number
of int16 NBASE digits, directly into the result variable.

This adds overhead of approximately 4 * O(n), due to the conversion.

Benchmarks indicate it's a win when var1 has at least 8 ndigits.
---
 src/backend/utils/adt/numeric.c | 258 ++++++++++++++++++++++++++++++++
 1 file changed, 258 insertions(+)

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index 5510a203b0..ddfc71feda 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -101,6 +101,8 @@ typedef signed char NumericDigit;
 typedef int16 NumericDigit;
 #endif
 
+#define SQUARE_NBASE (NBASE * NBASE)
+
 /*
  * The Numeric type as stored on disk.
  *
@@ -551,6 +553,8 @@ static void sub_var(const NumericVar *var1, const NumericVar *var2,
 static void mul_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale);
+static void mul_var_large(const NumericVar *var1, const NumericVar *var2,
+						  NumericVar *result, int rscale);
 static void div_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale, bool round);
@@ -8715,6 +8719,16 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 		return;
 	}
 
+	/*
+	 * If var1 has at least 8 digits, delegate to mul_var_large(), which
+	 * processes the digits in pairs and is faster for large multiplicands.
+	 */
+	if (var1ndigits >= 8)
+	{
+		mul_var_large(var1, var2, result, rscale);
+		return;
+	}
+
 	/* Determine result sign and (maximum possible) weight */
 	if (var1->sign == var2->sign)
 		res_sign = NUMERIC_POS;
@@ -8864,6 +8878,250 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 	strip_var(result);
 }
 
+/*
+ * mul_var_large() -
+ *
+ *	Multiplication of two numbers where var1 has at least 8 NBASE digits.
+ *
+ *	This reduces the "n" in the O(n^2) schoolbook algorithm by a factor of
+ *	two.  Adjacent pairs of NBASE digits of each input are first combined
+ *	into single base-SQUARE_NBASE (NBASE^2) digits, held in int64 so that
+ *	the products and the unnormalized column sums fit.  The multiplication
+ *	loop of mul_var() is then run unchanged on these digit pairs, and
+ *	finally each base-SQUARE_NBASE result digit is split back into two
+ *	NBASE digits.
+ *
+ *	The conversions add O(n) overhead, which benchmarks indicate is a net
+ *	win once var1 has at least 8 digits.
+ *
+ *	var1 must be the shorter input.  All reads of var1 and var2 happen before
+ *	result is modified, so result may be the same variable as either input.
+ */
+static void
+mul_var_large(const NumericVar *var1, const NumericVar *var2,
+			  NumericVar *result, int rscale)
+{
+	int			var1ndigitpairs = (var1->ndigits + 1) / 2;
+	int			var2ndigitpairs = (var2->ndigits + 1) / 2;
+	int			var1pad = var1->ndigits % 2;
+	int			var2pad = var2->ndigits % 2;
+	int			res_ndigitpairs;
+	int			res_sign;
+	int			res_weight;
+	int			maxdigits;
+	int64	   *dig;
+	int64	   *var1pairs;
+	int64	   *var2pairs;
+	int64		carry;
+	int64		maxdig;
+	int64		newdig;
+	int			i,
+				i1,
+				i2;
+
+	/* Check preconditions */
+	Assert(var1->ndigits >= 8);
+	Assert(var2->ndigits >= var1->ndigits);
+
+	/* Determine result sign */
+	if (var1->sign == var2->sign)
+		res_sign = NUMERIC_POS;
+	else
+		res_sign = NUMERIC_NEG;
+
+	/*
+	 * Determine the (maximum possible) weight of the result, in NBASE digits.
+	 *
+	 * An odd-length input is treated as having one extra leading zero NBASE
+	 * digit so that its digits pair up evenly, which raises its effective
+	 * weight by one (var1pad/var2pad).  The low NBASE digit of the first pair
+	 * of an input with weight w and padding p has weight w + p - 1.  The
+	 * product of the first pairs of var1 and var2 is accumulated into dig[2],
+	 * whose low half becomes NBASE digit 5 of the result, so result digit 0
+	 * has weight var1->weight + var2->weight + var1pad + var2pad + 3.
+	 */
+	res_weight = var1->weight + var2->weight + var1pad + var2pad + 3;
+
+	/*
+	 * Determine the number of result digit pairs to compute.  If the exact
+	 * result would have more than rscale fractional digits, truncate the
+	 * computation with MUL_GUARD_DIGITS guard digits, i.e., ignore input
+	 * digits that would only contribute to the right of that.  (This will
+	 * give the exact rounded-to-rscale answer unless carries out of the
+	 * ignored positions would have propagated through more than
+	 * MUL_GUARD_DIGITS digits.)
+	 *
+	 * maxdigits counts NBASE digits, whereas dig[] holds digit pairs, so it
+	 * must be converted before it can limit res_ndigitpairs.  Computing
+	 * maxdigits / 2 + 1 pairs covers at least maxdigits + 1 NBASE digits, so
+	 * the guard digits lie entirely within the computed pairs.
+	 *
+	 * Note: an exact computation could not produce more digit pairs than the
+	 * inputs have in total, but we allocate one extra output pair in case
+	 * rscale-driven rounding produces a carry out of the highest exact digit.
+	 */
+	res_ndigitpairs = var1ndigitpairs + var2ndigitpairs + 1;
+	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
+		MUL_GUARD_DIGITS;
+	res_ndigitpairs = Min(res_ndigitpairs, maxdigits / 2 + 1);
+
+	/*
+	 * Pair i1 of var1 and pair i2 of var2 are multiplied and added to
+	 * dig[i1 + i2 + 2], so with fewer than 3 result pairs every input pair
+	 * would be ignored.
+	 */
+	if (res_ndigitpairs < 3)
+	{
+		/* All input digits will be ignored; so result is zero */
+		zero_var(result);
+		result->dscale = rscale;
+		return;
+	}
+
+	/*
+	 * We do the arithmetic in an array "dig[]" of signed int64's.  Since
+	 * PG_INT64_MAX is noticeably larger than SQUARE_NBASE*SQUARE_NBASE, this
+	 * gives us headroom to avoid normalizing carries immediately.
+	 *
+	 * maxdig tracks the maximum possible value of any dig[] entry; when this
+	 * threatens to exceed PG_INT64_MAX, we take the time to propagate carries.
+	 * Furthermore, we need to ensure that overflow doesn't occur during the
+	 * carry propagation passes either.  The carry values could be as much as
+	 * PG_INT64_MAX/SQUARE_NBASE, so really we must normalize when digits
+	 * threaten to exceed PG_INT64_MAX - PG_INT64_MAX/SQUARE_NBASE.
+	 *
+	 * To avoid overflow in maxdig itself, it actually represents the max
+	 * possible value divided by SQUARE_NBASE-1, ie, at the top of the loop it
+	 * is known that no dig[] entry exceeds maxdig * (SQUARE_NBASE-1).
+	 *
+	 * The same allocation also holds the base-SQUARE_NBASE copies of var1 and
+	 * var2, stored after the res_ndigitpairs accumulator entries.
+	 */
+	dig = (int64 *) palloc0((res_ndigitpairs + var1ndigitpairs +
+							 var2ndigitpairs) * sizeof(int64));
+	var1pairs = dig + res_ndigitpairs;
+	var2pairs = var1pairs + var1ndigitpairs;
+	maxdig = 0;
+
+	/*
+	 * Convert var1 and var2 from base NBASE to base SQUARE_NBASE.  The first
+	 * pair of an odd-length input consists of just its first NBASE digit,
+	 * since the implicit leading zero contributes nothing.
+	 */
+	i1 = 0;
+	i2 = 0;
+	if (var1pad)
+		var1pairs[i1++] = var1->digits[i2++];
+	for (; i1 < var1ndigitpairs; i1++, i2 += 2)
+		var1pairs[i1] = (int64) var1->digits[i2] * NBASE + var1->digits[i2 + 1];
+
+	i1 = 0;
+	i2 = 0;
+	if (var2pad)
+		var2pairs[i1++] = var2->digits[i2++];
+	for (; i1 < var2ndigitpairs; i1++, i2 += 2)
+		var2pairs[i1] = (int64) var2->digits[i2] * NBASE + var2->digits[i2 + 1];
+
+	/*
+	 * The least significant pairs of var1 should be ignored if they don't
+	 * contribute directly to the first res_ndigitpairs pairs of the result
+	 * that we are computing.  Since pair i1 of var1 and pair i2 of var2 are
+	 * added to dig[i1 + i2 + 2], we need only consider pairs of var1 for
+	 * which i1 <= res_ndigitpairs - 3.
+	 */
+	for (i1 = Min(var1ndigitpairs - 1, res_ndigitpairs - 3); i1 >= 0; i1--)
+	{
+		int64		var1digit = var1pairs[i1];
+
+		if (var1digit == 0)
+			continue;
+
+		/* Time to normalize? */
+		maxdig += var1digit;
+		if (maxdig > (PG_INT64_MAX - PG_INT64_MAX / SQUARE_NBASE) /
+			(SQUARE_NBASE - 1))
+		{
+			/* Yes, do it */
+			carry = 0;
+			for (i = res_ndigitpairs - 1; i >= 0; i--)
+			{
+				newdig = dig[i] + carry;
+				if (newdig >= SQUARE_NBASE)
+				{
+					carry = newdig / SQUARE_NBASE;
+					newdig -= carry * SQUARE_NBASE;
+				}
+				else
+					carry = 0;
+				dig[i] = newdig;
+			}
+			Assert(carry == 0);
+			/* Reset maxdig to indicate new worst-case */
+			maxdig = 1 + var1digit;
+		}
+
+		/*
+		 * Add the appropriate multiple of var2 into the accumulator.
+		 *
+		 * As above, pairs of var2 can be ignored if they don't contribute,
+		 * so we only include pairs for which i1+i2+2 < res_ndigitpairs.
+		 *
+		 * This inner loop is the performance bottleneck for multiplication,
+		 * so we want to keep it simple enough so that it can be
+		 * auto-vectorized.  Accordingly, process the digits left-to-right
+		 * even though schoolbook multiplication would suggest right-to-left.
+		 * Since we aren't propagating carries in this loop, the order does
+		 * not matter.
+		 */
+		{
+			int			i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - 2);
+			int64	   *dig_i1_2 = &dig[i1 + 2];
+
+			for (i2 = 0; i2 < i2limit; i2++)
+				dig_i1_2[i2] += var1digit * var2pairs[i2];
+		}
+	}
+
+	/*
+	 * All reads of var1 and var2 are complete, so it is now safe to replace
+	 * result's digit buffer, even if result is the same variable as one of
+	 * the inputs.  Each digit pair yields two NBASE digits.
+	 */
+	alloc_var(result, res_ndigitpairs * 2);
+
+	/*
+	 * Now we do a final carry propagation pass to normalize the result, which
+	 * we combine with splitting each base-SQUARE_NBASE digit into two NBASE
+	 * digits stored directly into the output.  Note that this is still done
+	 * at full precision w/guard digits.
+	 */
+	carry = 0;
+	for (i = res_ndigitpairs - 1; i >= 0; i--)
+	{
+		newdig = dig[i] + carry;
+		if (newdig >= SQUARE_NBASE)
+		{
+			carry = newdig / SQUARE_NBASE;
+			newdig -= carry * SQUARE_NBASE;
+		}
+		else
+			carry = 0;
+		result->digits[2 * i] = (NumericDigit) (newdig / NBASE);
+		result->digits[2 * i + 1] = (NumericDigit) (newdig % NBASE);
+	}
+	Assert(carry == 0);
+
+	pfree(dig);
+
+	result->weight = res_weight;
+	result->sign = res_sign;
+
+	/* Round to target rscale (and set result->dscale) */
+	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+}
 
 /*
  * div_var() -
-- 
2.45.1