From 5b9641514ffa7d545b6932a27d3f193ceecfd564 Mon Sep 17 00:00:00 2001
From: Dean Rasheed <dean.a.rasheed@gmail.com>
Date: Thu, 18 Jul 2024 18:32:56 +0100
Subject: [PATCH v4 2/2] Optimise numeric multiplication using base-NBASE^2
 arithmetic.

Currently mul_var() uses the schoolbook multiplication algorithm,
which is O(n^2) in the number of NBASE digits. To improve performance
for large inputs, convert the inputs to base NBASE^2 before
multiplying, which effectively halves the number of digits in each
input, theoretically speeding up the computation by a factor of 4. In
practice, the actual speedup for large inputs varies between around 3
and 6 times, depending on the system and compiler used. In turn, this
significantly reduces the runtime of the numeric_big regression test.

For this to work, 64-bit integers are required for the products of
base-NBASE^2 digits, so this works best on 64-bit machines, for which
it is faster whenever the shorter input has more than 4 or 5 NBASE
digits. On 32-bit machines, the additional overheads, especially
during carry propagation and the final conversion back to base-NBASE,
are significantly higher, and it is only faster when the shorter input
has more than around 50 NBASE digits. When the shorter input has more
than 6 NBASE digits (so that mul_var_short() cannot be used), but
fewer than around 50 NBASE digits, there may be a noticeable slowdown
on 32-bit machines. That seems to be an acceptable tradeoff, given the
performance gains for other inputs, and the effort that would be
required to maintain code specifically targeting 32-bit machines.
---
 src/backend/utils/adt/numeric.c | 227 +++++++++++++++++++++-----------
 1 file changed, 153 insertions(+), 74 deletions(-)

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index ca28d0e3b3..fc9caaa8c7 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -101,6 +101,8 @@ typedef signed char NumericDigit;
 typedef int16 NumericDigit;
 #endif
 
+#define NBASE_SQR	(NBASE * NBASE)
+
 /*
  * The Numeric type as stored on disk.
  *
@@ -8674,21 +8676,30 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 		int rscale)
 {
 	int			res_ndigits;
+	int			res_ndigitpairs;
 	int			res_sign;
 	int			res_weight;
+	int			pair_offset;
 	int			maxdigits;
-	int		   *dig;
-	int			carry;
-	int			maxdig;
-	int			newdig;
+	int			maxdigitpairs;
+	uint64	   *dig,
+			   *dig_i1_off;
+	uint64		maxdig;
+	uint64		carry;
+	uint64		newdig;
 	int			var1ndigits;
 	int			var2ndigits;
+	int			var1ndigitpairs;
+	int			var2ndigitpairs;
 	NumericDigit *var1digits;
 	NumericDigit *var2digits;
+	uint32		var1digitpair;
+	uint32	   *var2digitpairs;
 	NumericDigit *res_digits;
 	int			i,
 				i1,
-				i2;
+				i2,
+				i2limit;
 
 	/*
 	 * Arrange for var1 to be the shorter of the two numbers.  This improves
@@ -8729,86 +8740,161 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 		return;
 	}
 
-	/* Determine result sign and (maximum possible) weight */
+	/* Determine result sign */
 	if (var1->sign == var2->sign)
 		res_sign = NUMERIC_POS;
 	else
 		res_sign = NUMERIC_NEG;
-	res_weight = var1->weight + var2->weight + 2;
 
 	/*
-	 * Determine the number of result digits to compute.  If the exact result
-	 * would have more than rscale fractional digits, truncate the computation
-	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
-	 * would only contribute to the right of that.  (This will give the exact
+	 * Determine the number of result digits to compute and the (maximum
+	 * possible) result weight.  If the exact result would have more than
+	 * rscale fractional digits, truncate the computation with
+	 * MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that would
+	 * only contribute to the right of that.  (This will give the exact
 	 * rounded-to-rscale answer unless carries out of the ignored positions
 	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
 	 *
 	 * Note: an exact computation could not produce more than var1ndigits +
-	 * var2ndigits digits, but we allocate one extra output digit in case
-	 * rscale-driven rounding produces a carry out of the highest exact digit.
+	 * var2ndigits digits, but we allocate at least one extra output digit in
+	 * case rscale-driven rounding produces a carry out of the highest exact
+	 * digit.
+	 *
+	 * The computation itself is done using base-NBASE^2 arithmetic, so we
+	 * actually process the input digits in pairs, producing a base-NBASE^2
+	 * intermediate result.  This significantly improves performance, since
+	 * schoolbook multiplication is O(N^2) in the number of input digits, and
+	 * working in base NBASE^2 effectively halves "N".
 	 */
-	res_ndigits = var1ndigits + var2ndigits + 1;
+	/* digit pairs in each input */
+	var1ndigitpairs = (var1ndigits + 1) / 2;
+	var2ndigitpairs = (var2ndigits + 1) / 2;
+
+	/* digits in exact result */
+	res_ndigits = var1ndigits + var2ndigits;
+
+	/* digit pairs in exact result with at least one extra output digit */
+	res_ndigitpairs = res_ndigits / 2 + 1;
+
+	/* pair offset to align result to end of dig[] */
+	pair_offset = res_ndigitpairs - var1ndigitpairs - var2ndigitpairs + 1;
+
+	/* maximum possible result weight */
+	res_weight = var1->weight + var2->weight + 1 + 2 * res_ndigitpairs -
+		res_ndigits;
+
+	/* truncate computation based on requested rscale */
 	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
 		MUL_GUARD_DIGITS;
-	res_ndigits = Min(res_ndigits, maxdigits);
+	maxdigitpairs = (maxdigits + 1) / 2;
 
-	if (res_ndigits < 3)
+	res_ndigitpairs = Min(res_ndigitpairs, maxdigitpairs);
+	res_ndigits = 2 * res_ndigitpairs;
+
+	/*
+	 * In the computation below, digit pair i1 of var1 and digit pair i2 of
+	 * var2 are multiplied and added to digit i1+i2+pair_offset of dig[]. Thus
+	 * input digit pairs with index >= res_ndigitpairs - pair_offset don't
+	 * contribute to the result, and can be ignored.
+	 */
+	if (res_ndigitpairs <= pair_offset)
 	{
 		/* All input digits will be ignored; so result is zero */
 		zero_var(result);
 		result->dscale = rscale;
 		return;
 	}
+	var1ndigitpairs = Min(var1ndigitpairs, res_ndigitpairs - pair_offset);
+	var2ndigitpairs = Min(var2ndigitpairs, res_ndigitpairs - pair_offset);
 
 	/*
-	 * We do the arithmetic in an array "dig[]" of signed int's.  Since
-	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
-	 * to avoid normalizing carries immediately.
+	 * We do the arithmetic in an array "dig[]" of unsigned 64-bit integers.
+	 * Since PG_UINT64_MAX is much larger than NBASE^4, this gives us a lot of
+	 * headroom to avoid normalizing carries immediately.
 	 *
 	 * maxdig tracks the maximum possible value of any dig[] entry; when this
-	 * threatens to exceed INT_MAX, we take the time to propagate carries.
-	 * Furthermore, we need to ensure that overflow doesn't occur during the
-	 * carry propagation passes either.  The carry values could be as much as
-	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
-	 * exceed INT_MAX - INT_MAX/NBASE.
+	 * threatens to exceed PG_UINT64_MAX, we take the time to propagate
+	 * carries.  Furthermore, we need to ensure that overflow doesn't occur
+	 * during the carry propagation passes either.  The carry values could be
+	 * as much as PG_UINT64_MAX / NBASE^2, so really we must normalize when
+	 * digits threaten to exceed PG_UINT64_MAX - PG_UINT64_MAX / NBASE^2.
+	 *
+	 * To avoid overflow in maxdig itself, it actually represents the maximum
+	 * possible value divided by NBASE^2-1, i.e., at the top of the loop it is
+	 * known that no dig[] entry exceeds maxdig * (NBASE^2-1).
 	 *
-	 * To avoid overflow in maxdig itself, it actually represents the max
-	 * possible value divided by NBASE-1, ie, at the top of the loop it is
-	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
+	 * The conversion of var1 to base NBASE^2 is done on the fly, as each new
+	 * digit is required.  The digits of var2 are converted upfront, and
+	 * stored at the end of dig[].  To avoid loss of precision, the input
+	 * digits are aligned with the start of the digit pair array, effectively
+	 * shifting them up (multiplying by NBASE) if the inputs have an odd
+	 * number of NBASE digits.
 	 */
-	dig = (int *) palloc0(res_ndigits * sizeof(int));
-	maxdig = 0;
+	dig = (uint64 *) palloc(res_ndigitpairs * sizeof(uint64) +
+							var2ndigitpairs * sizeof(uint32));
+
+	/* convert var2 to base NBASE^2, shifting up if its length is odd */
+	var2digitpairs = (uint32 *) (dig + res_ndigitpairs);
+
+	for (i2 = 0; i2 < var2ndigitpairs - 1; i2++)
+		var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+
+	if (2 * i2 + 1 < var2ndigits)
+		var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+	else
+		var2digitpairs[i2] = var2digits[2 * i2] * NBASE;
 
 	/*
-	 * The least significant digits of var1 should be ignored if they don't
-	 * contribute directly to the first res_ndigits digits of the result that
-	 * we are computing.
+	 * Start by multiplying var2 by the least significant contributing digit
+	 * pair from var1, storing the results at the end of dig[], and filling
+	 * the leading digits with zeros.
 	 *
-	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
-	 * i1+i2+2 of the accumulator array, so we need only consider digits of
-	 * var1 for which i1 <= res_ndigits - 3.
+	 * The loop here is the same as the inner loop below, except that we set
+	 * the results in dig[], rather than adding to them.  This is the
+	 * performance bottleneck for multiplication, so we want to keep it simple
+	 * enough so that it can be auto-vectorized.  Accordingly, process the
+	 * digits left-to-right even though schoolbook multiplication would
+	 * suggest right-to-left.  Since we aren't propagating carries in this
+	 * loop, the order does not matter.
+	 */
+	i1 = var1ndigitpairs - 1;
+	if (2 * i1 + 1 < var1ndigits)
+		var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+	else
+		var1digitpair = var1digits[2 * i1] * NBASE;
+	maxdig = var1digitpair;
+
+	i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+	dig_i1_off = &dig[i1 + pair_offset];
+
+	memset(dig, 0, (i1 + pair_offset) * sizeof(uint64));
+	for (i2 = 0; i2 < i2limit; i2++)
+		dig_i1_off[i2] = (uint64) var1digitpair * var2digitpairs[i2];
+
+	/*
+	 * Next, multiply var2 by the remaining digit pairs from var1, adding the
+	 * results to dig[] at the appropriate offsets, and normalizing whenever
+	 * there is a risk of any dig[] entry overflowing.
 	 */
-	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
+	for (i1 = i1 - 1; i1 >= 0; i1--)
 	{
-		NumericDigit var1digit = var1digits[i1];
-
-		if (var1digit == 0)
+		var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+		if (var1digitpair == 0)
 			continue;
 
 		/* Time to normalize? */
-		maxdig += var1digit;
-		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
+		maxdig += var1digitpair;
+		if (maxdig > (PG_UINT64_MAX - PG_UINT64_MAX / NBASE_SQR) / (NBASE_SQR - 1))
 		{
-			/* Yes, do it */
+			/* Yes, do it (to base NBASE^2) */
 			carry = 0;
-			for (i = res_ndigits - 1; i >= 0; i--)
+			for (i = res_ndigitpairs - 1; i >= 0; i--)
 			{
 				newdig = dig[i] + carry;
-				if (newdig >= NBASE)
+				if (newdig >= NBASE_SQR)
 				{
-					carry = newdig / NBASE;
-					newdig -= carry * NBASE;
+					carry = newdig / NBASE_SQR;
+					newdig -= carry * NBASE_SQR;
 				}
 				else
 					carry = 0;
@@ -8816,55 +8902,48 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 			}
 			Assert(carry == 0);
 			/* Reset maxdig to indicate new worst-case */
-			maxdig = 1 + var1digit;
+			maxdig = 1 + var1digitpair;
 		}
 
-		/*
-		 * Add the appropriate multiple of var2 into the accumulator.
-		 *
-		 * As above, digits of var2 can be ignored if they don't contribute,
-		 * so we only include digits for which i1+i2+2 < res_ndigits.
-		 *
-		 * This inner loop is the performance bottleneck for multiplication,
-		 * so we want to keep it simple enough so that it can be
-		 * auto-vectorized.  Accordingly, process the digits left-to-right
-		 * even though schoolbook multiplication would suggest right-to-left.
-		 * Since we aren't propagating carries in this loop, the order does
-		 * not matter.
-		 */
-		{
-			int			i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
-			int		   *dig_i1_2 = &dig[i1 + 2];
+		/* Multiply and add */
+		i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+		dig_i1_off = &dig[i1 + pair_offset];
 
-			for (i2 = 0; i2 < i2limit; i2++)
-				dig_i1_2[i2] += var1digit * var2digits[i2];
-		}
+		for (i2 = 0; i2 < i2limit; i2++)
+			dig_i1_off[i2] += (uint64) var1digitpair * var2digitpairs[i2];
 	}
 
 	/*
-	 * Now we do a final carry propagation pass to normalize the result, which
-	 * we combine with storing the result digits into the output. Note that
-	 * this is still done at full precision w/guard digits.
+	 * Now we do a final carry propagation pass to normalize back to base
+	 * NBASE^2, and construct the base-NBASE result digits.  Note that this is
+	 * still done at full precision w/guard digits.
 	 */
 	alloc_var(result, res_ndigits);
 	res_digits = result->digits;
 	carry = 0;
-	for (i = res_ndigits - 1; i >= 0; i--)
+	for (i = res_ndigitpairs - 1; i >= 0; i--)
 	{
 		newdig = dig[i] + carry;
-		if (newdig >= NBASE)
+		if (newdig >= NBASE_SQR)
 		{
-			carry = newdig / NBASE;
-			newdig -= carry * NBASE;
+			carry = newdig / NBASE_SQR;
+			newdig -= carry * NBASE_SQR;
 		}
 		else
 			carry = 0;
-		res_digits[i] = newdig;
+		res_digits[2 * i + 1] = (NumericDigit) ((uint32) newdig % NBASE);
+		res_digits[2 * i] = (NumericDigit) ((uint32) newdig / NBASE);
 	}
 	Assert(carry == 0);
 
 	pfree(dig);
 
+	/*
+	 * Adjust the result weight, if the inputs were shifted up during base
+	 * conversion (if they had an odd number of NBASE digits).
+	 */
+	res_weight -= (var1ndigits & 1) + (var2ndigits & 1);
+
 	/*
 	 * Finally, round the result to the requested precision.
 	 */
-- 
2.35.3

