Some improvements to numeric sqrt() and ln()

Started by Dean Rasheed, almost 6 years ago, 10 messages
#1 Dean Rasheed
dean.a.rasheed@gmail.com
1 attachment(s)

Attached is a WIP patch to improve the performance of numeric sqrt()
and ln(), which also makes a couple of related improvements to
div_var_fast(), all of which have knock-on benefits for other numeric
functions. The actual impact varies greatly depending on the inputs,
but the overall effect is to reduce the run time of the numeric_big
regression test by about 20%.

Additionally it improves the accuracy of sqrt() -- currently sqrt()
sometimes rounds the last digit of the result the wrong way, for
example sqrt(100000000000000010000000000000000) returns
10000000000000001, when the correct answer should be 10000000000000000
to zero decimal places. With this patch, sqrt() guarantees to return
the result correctly rounded to the last digit for all inputs.
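The patch's sqrt_var() comment says that, to get correct rounding, it computes at least one extra decimal digit of the truncated integer square root and then rounds. The approach works because a truncated root never overshoots the true value. Below is a minimal sketch of the same idea on 64-bit integers instead of numerics, rounding to zero decimal places. The names isqrt_u64 and sqrt_round_u64 are hypothetical, and the bit-by-bit integer square root stands in for the patch's Karatsuba code.

```c
#include <stdint.h>

/* Exact floor(sqrt(n)) using the bit-by-bit (base 2 digit-by-digit) method. */
static uint64_t isqrt_u64(uint64_t n)
{
    uint64_t res = 0;
    uint64_t bit = (uint64_t) 1 << 62;  /* largest power of 4 in a uint64 */

    while (bit > n)
        bit >>= 2;
    while (bit != 0)
    {
        if (n >= res + bit)
        {
            n -= res + bit;
            res = (res >> 1) + bit;
        }
        else
            res >>= 1;
        bit >>= 2;
    }
    return res;
}

/*
 * sqrt(n) correctly rounded to the nearest integer, for n <= UINT64_MAX / 100.
 * t = floor(sqrt(100 * n)) is sqrt(n) truncated to one extra decimal digit;
 * since truncation never overshoots, rounding t half-up is always correct.
 */
static uint64_t sqrt_round_u64(uint64_t n)
{
    uint64_t t = isqrt_u64(100 * n);

    return (t + 5) / 10;
}
```

Take a scaled-down version of the example above, n = 10^16 + 10^8, whose square root is just below 10^8 + 0.5. Here t = 1000000004 and the result is 100000000, not 100000001.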

The main change is to sqrt_var(), which now uses a different algorithm,
the Karatsuba Square Root algorithm [1]. I've re-cast the algorithm from
[1] into an iterative form, rather
than doing it recursively, as it's presented in that paper. This
improves performance further, by avoiding overheads from function
calls and copying numeric variables around. Also, IMO, the iterative
form of the algorithm is much more elegant, since it works by making a
single pass over the input digits, consuming them one at a time from
most significant to least, producing a succession of increasingly
accurate approximations to the square root, until the desired
precision is reached.
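Below is a minimal sketch of the SqrtRem recursion from [1], as written out in the sqrt_var() comment in the patch below, applied to 64-bit unsigned integers. It is not the patch's code. It uses a binary base b = 2^k instead of b = NBASE^blen, and it recurses instead of iterating. SqrtRem, sqrt_rem and bit_length are hypothetical names. Choosing k = floor(bitlen(n)/4) guarantees a3 >= b/4, and a3 may exceed b, as the patch's comment allows. That lower bound makes the inner root at least b/2, which is what limits the final fix-up to a single correction.

```c
#include <stdint.h>

/* Hypothetical result type: s = floor(sqrt(n)), r = n - s*s, 0 <= r <= 2s. */
typedef struct
{
    uint64_t s;
    int64_t  r;
} SqrtRem;

static int bit_length(uint64_t n)
{
    int len = 0;

    while (n != 0)
    {
        len++;
        n >>= 1;
    }
    return len;
}

/* Karatsuba-style SqrtRem for 64-bit n, with n = a3*b^3 + a2*b^2 + a1*b + a0. */
static SqrtRem sqrt_rem(uint64_t n)
{
    SqrtRem res;
    int     len = bit_length(n);

    if (len <= 8)
    {
        /* Base case: tiny inputs, brute force */
        uint64_t s = 0;

        while ((s + 1) * (s + 1) <= n)
            s++;
        res.s = s;
        res.r = (int64_t) (n - s * s);
        return res;
    }

    int      k = len / 4;               /* b = 2^k, so a3 = n >> 3k >= b/4 */
    uint64_t b = (uint64_t) 1 << k;
    uint64_t a0 = n & (b - 1);
    uint64_t a1 = (n >> k) & (b - 1);

    /* (s,r) = SqrtRem(a3*b + a2) */
    SqrtRem  inner = sqrt_rem(n >> (2 * k));

    /* (q,u) = DivRem(r*b + a1, 2*s) */
    uint64_t numer = (uint64_t) inner.r * b + a1;
    uint64_t denom = 2 * inner.s;
    uint64_t q = numer / denom;
    uint64_t u = numer % denom;

    /* s = s*b + q, r = u*b + a0 - q^2 */
    res.s = inner.s * b + q;
    res.r = (int64_t) (u * b + a0) - (int64_t) (q * q);

    /* s can be too large by at most 1 */
    if (res.r < 0)
    {
        res.r += 2 * (int64_t) res.s - 1;
        res.s--;
    }
    return res;
}
```

The iterative form in the patch performs the same steps from the innermost root outward. At each step it reads the next 2*blen input digits as a1 and a0, so every input digit is consumed exactly once.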

For inputs with a handful of digits, this is typically 3-5 times
faster, and for inputs with more digits the performance improvement is
larger (e.g. sqrt(2e131071) is around 10 times faster). If the input
is a perfect square, with a result having a lot of trailing zeros, the
new algorithm is much faster because it basically has nothing to do in
later iterations (e.g., sqrt(64e13070) is about 600 times faster).

Another change to sqrt_var() is that it now explicitly supports a
negative rscale, i.e., rounding before the decimal point. This is
exploited by ln_var() in its argument reduction stage -- ln_var()
reduces all inputs to the range (0.9, 1.1) by repeatedly taking the
square root. For very large inputs this can have an enormous impact,
for example log(1e131071) currently takes about 6.5 seconds on my
machine, whereas with this patch I can run it 1000 times in a plpgsql
loop in about 90ms, so it's around 70,000 times faster in that case. Of
course, that's an extreme example, and for most inputs it's a much
more modest difference (e.g., ln(2) is about 1.5 times faster).

In passing, I also made a couple of optimisations to div_var_fast(),
discovered while comparing its performance with div_var() for various
inputs.

It's possible that there are further gains to be had in the sqrt()
algorithm on platforms that support 128-bit integers, but I haven't
had a chance to investigate that yet.

Regards,
Dean

[1]: https://hal.inria.fr/inria-00072854/document

Attachments:

numeric-sqrt-ln.patch (application/octet-stream)
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index c92ad5a..0debcdc
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -393,16 +393,6 @@ static const NumericVar const_ten =
 #endif
 
 #if DEC_DIGITS == 4
-static const NumericDigit const_zero_point_five_data[1] = {5000};
-#elif DEC_DIGITS == 2
-static const NumericDigit const_zero_point_five_data[1] = {50};
-#elif DEC_DIGITS == 1
-static const NumericDigit const_zero_point_five_data[1] = {5};
-#endif
-static const NumericVar const_zero_point_five =
-{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};
-
-#if DEC_DIGITS == 4
 static const NumericDigit const_zero_point_nine_data[1] = {9000};
 #elif DEC_DIGITS == 2
 static const NumericDigit const_zero_point_nine_data[1] = {90};
@@ -518,6 +508,8 @@ static void div_var_fast(const NumericVa
 static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
 static void mod_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result);
+static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
+						NumericVar *quot, NumericVar *rem);
 static void ceil_var(const NumericVar *var, NumericVar *result);
 static void floor_var(const NumericVar *var, NumericVar *result);
 
@@ -7712,6 +7704,7 @@ div_var_fast(const NumericVar *var1, con
 			 NumericVar *result, int rscale, bool round)
 {
 	int			div_ndigits;
+	int			load_ndigits;
 	int			res_sign;
 	int			res_weight;
 	int		   *div;
@@ -7766,9 +7759,6 @@ div_var_fast(const NumericVar *var1, con
 	div_ndigits += DIV_GUARD_DIGITS;
 	if (div_ndigits < DIV_GUARD_DIGITS)
 		div_ndigits = DIV_GUARD_DIGITS;
-	/* Must be at least var1ndigits, too, to simplify data-loading loop */
-	if (div_ndigits < var1ndigits)
-		div_ndigits = var1ndigits;
 
 	/*
 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
@@ -7781,9 +7771,16 @@ div_var_fast(const NumericVar *var1, con
 	 * (approximate) quotient digit and stores it into div[], removing one
 	 * position of dividend space.  A final pass of carry propagation takes
 	 * care of any mistaken quotient digits.
+	 *
+	 * Note that div[] doesn't necessarily contain all of the digits from the
+	 * dividend --- the desired precision plus guard digits might be less than
+	 * the dividend's precision.  This happens, for example, in the square
+	 * root algorithm, where we typically divide a 2N-digit number by an
+	 * N-digit number, and only require a result with N digits of precision.
 	 */
 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
-	for (i = 0; i < var1ndigits; i++)
+	load_ndigits = Min(div_ndigits, var1ndigits);
+	for (i = 0; i < load_ndigits; i++)
 		div[i + 1] = var1digits[i];
 
 	/*
@@ -7844,9 +7841,15 @@ div_var_fast(const NumericVar *var1, con
 			maxdiv += Abs(qdigit);
 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
 			{
-				/* Yes, do it */
+				/*
+				 * Yes, do it.  Note that if var2ndigits is much smaller than
+				 * div_ndigits, we can save a significant amount of effort
+				 * here by noting that we only need to normalise those div[]
+				 * entries touched where prior iterations subtracted multiples
+				 * of the divisor.
+				 */
 				carry = 0;
-				for (i = div_ndigits; i > qi; i--)
+				for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
 				{
 					newdig = div[i] + carry;
 					if (newdig < 0)
@@ -8095,6 +8098,74 @@ mod_var(const NumericVar *var1, const Nu
 
 
 /*
+ * div_mod_var() -
+ *
+ *	Calculate the truncated integer quotient and numeric remainder of two
+ *	numeric variables.
+ */
+static void
+div_mod_var(const NumericVar *var1, const NumericVar *var2,
+			NumericVar *quot, NumericVar *rem)
+{
+	NumericVar	q;
+	NumericVar	r;
+
+	init_var(&q);
+	init_var(&r);
+
+	/*
+	 * Use div_var_fast() to get an initial estimate for the integer quotient.
+	 * In practice, this is almost always correct, but it is occasionally off
+	 * by one, which we can easily correct.
+	 */
+	div_var_fast(var1, var2, &q, 0, false);
+	mul_var(var2, &q, &r, var2->dscale);
+	sub_var(var1, &r, &r);
+
+	/*
+	 * Adjust the results if necessary --- the remainder should have the same
+	 * sign as var1, and its absolute value should be less than the absolute
+	 * value of var2.
+	 */
+	while (r.ndigits != 0 && r.sign != var1->sign)
+	{
+		/* The absolute value of the quotient is too large */
+		if (var1->sign == var2->sign)
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+		else
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+	}
+
+	while (cmp_abs(&r, var2) >= 0)
+	{
+		/* The absolute value of the quotient is too small */
+		if (var1->sign == var2->sign)
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+		else
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+	}
+
+	set_var_from_var(&q, quot);
+	set_var_from_var(&r, rem);
+
+	free_var(&q);
+	free_var(&r);
+}
+
+
+/*
  * ceil_var() -
  *
  *	Return the smallest integer greater than or equal to the argument
@@ -8213,18 +8284,30 @@ gcd_var(const NumericVar *var1, const Nu
 /*
  * sqrt_var() -
  *
- *	Compute the square root of x using Newton's algorithm
+ *	Compute the square root of x using the Karatsuba Square Root algorithm.
+ *	NOTE: we allow rscale < 0 here, implying rounding before the decimal
+ *	point.
  */
 static void
 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
 {
-	NumericVar	tmp_arg;
-	NumericVar	tmp_val;
-	NumericVar	last_val;
-	int			local_rscale;
 	int			stat;
-
-	local_rscale = rscale + 8;
+	int			res_weight;
+	int			res_ndigits;
+	int			src_ndigits;
+	int			step;
+	int			ndigits[32];
+	int			blen;
+	int64		arg_int64;
+	int			src_idx;
+	int64		s_int64;
+	int64		r_int64;
+	NumericVar	s_var;
+	NumericVar	r_var;
+	NumericVar	a0_var;
+	NumericVar	a1_var;
+	NumericVar	q_var;
+	NumericVar	u_var;
 
 	stat = cmp_var(arg, &const_zero);
 	if (stat == 0)
@@ -8243,43 +8326,311 @@ sqrt_var(const NumericVar *arg, NumericV
 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
 				 errmsg("cannot take square root of a negative number")));
 
-	init_var(&tmp_arg);
-	init_var(&tmp_val);
-	init_var(&last_val);
+	init_var(&s_var);
+	init_var(&r_var);
+	init_var(&a0_var);
+	init_var(&a1_var);
+	init_var(&q_var);
+	init_var(&u_var);
 
-	/* Copy arg in case it is the same var as result */
-	set_var_from_var(arg, &tmp_arg);
+	/*
+	 * The result weight is half the input weight, rounded towards minus
+	 * infinity.
+	 */
+	res_weight = (int) floor((double) arg->weight / 2);
 
 	/*
-	 * Initialize the result to the first guess
+	 * Number of NBASE digits to compute.  To ensure correct rounding, compute
+	 * at least 1 extra decimal digit.  We explicitly allow rscale to be
+	 * negative here, but must always compute at least 1 NBASE digit.
 	 */
-	alloc_var(result, 1);
-	result->digits[0] = tmp_arg.digits[0] / 2;
-	if (result->digits[0] == 0)
-		result->digits[0] = 1;
-	result->weight = tmp_arg.weight / 2;
-	result->sign = NUMERIC_POS;
+	res_ndigits = res_weight + 1 + (int) ceil((double) (rscale + 1) / DEC_DIGITS);
+	res_ndigits = Max(res_ndigits, 1);
 
-	set_var_from_var(result, &last_val);
+	/*
+	 * Number of source NBASE digits logically required to produce a result
+	 * with this precision --- every digit before the decimal point, plus 2
+	 * for each result digit after the decimal point (or minus 2 for each
+	 * result digit we round before the decimal point).
+	 */
+	src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
+	src_ndigits = Max(src_ndigits, 1);
 
-	for (;;)
+	/* ----------
+	 * From this point on, we treat the input and the result as integers and
+	 * compute the integer square root and remainder using the Karatsuba
+	 * Square Root algorithm, which may be written recursively as follows:
+	 *
+	 *	SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
+	 *		[ for some base b, and coefficients a0,a1,a2,a3 chosen so that
+	 *		  0 <= a0,a1,a2 < b and a3 >= b/4 ]
+	 *		Let (s,r) = SqrtRem(a3*b + a2)
+	 *		Let (q,u) = DivRem(r*b + a1, 2*s)
+	 *		Let s = s*b + q
+	 *		Let r = u*b + a0 - q^2
+	 *		If r < 0 Then
+	 *			Let r = r + 2*s - 1
+	 *			Let s = s - 1
+	 *		Return (s,r)
+	 *
+	 * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
+	 * 3805, November 1999.
+	 *
+	 * Note that there is no upper bound on a3, and we allow it to be larger
+	 * than b (by choosing a smaller b) if necessary to ensure that the
+	 * condition a3 >= b/4 is met.  For optimal performance, b should have
+	 * approximately a quarter the number of digits in the input, so that the
+	 * outer square root computes roughly twice as many digits as the inner
+	 * one.  For simplicity, we choose b = NBASE^blen, an integer power of
+	 * NBASE.
+	 *
+	 * We implement the algorithm iteratively rather than recursively, to
+	 * allow the working variables to be reused.  With this approach, each
+	 * digit of the input is read precisely once --- src_idx tracks the number
+	 * of input digits used so far.
+	 *
+	 * The array ndigits[] holds the number of NBASE digits of the input that
+	 * will have been used at the end of each iteration, which roughly doubles
+	 * each time.  Note that the array elements are stored in reverse order,
+	 * so if the final iteration requires src_ndigits = 37 input digits, the
+	 * array will contain [37,19,11,7,5,3], and we would start by computing
+	 * the square root of the 3 most significant NBASE digits.
+	 * ----------
+	 */
+	step = 0;
+	while ((ndigits[step] = src_ndigits) > 4)
 	{
-		div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
+		/* Choose b so that a3 >= b/4 */
+		blen = src_ndigits / 4;
+		if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
+			blen--;
 
-		add_var(result, &tmp_val, result);
-		mul_var(result, &const_zero_point_five, result, local_rscale);
+		/* Number of digits in the next step (inner square root) */
+		src_ndigits -= 2 * blen;
+		step++;
+	}
 
-		if (cmp_var(&last_val, result) == 0)
-			break;
-		set_var_from_var(result, &last_val);
+	/*
+	 * First iteration (innermost square root and remainder):
+	 *
+	 * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
+	 * has at most 9 decimal digits, so estimate it using double precision
+	 * arithmetic, which will in fact almost certainly return the correct
+	 * result with no further correction required.
+	 */
+	arg_int64 = arg->digits[0];
+	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
+	{
+		arg_int64 *= NBASE;
+		if (src_idx < arg->ndigits)
+			arg_int64 += arg->digits[src_idx];
 	}
 
-	free_var(&last_val);
-	free_var(&tmp_val);
-	free_var(&tmp_arg);
+	s_int64 = (int64) sqrt((double) arg_int64);
+	r_int64 = arg_int64 - s_int64 * s_int64;
 
-	/* Round to requested precision */
+	/* Use Newton's method to correct the result, if necessary */
+	while (r_int64 < 0 || r_int64 > 2 * s_int64)
+	{
+		s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
+		r_int64 = arg_int64 - s_int64 * s_int64;
+	}
+
+	/*
+	 * Iterations with src_ndigits <= 8:
+	 *
+	 * The next 1 or 2 iterations compute larger (outer) square roots with
+	 * src_ndigits <= 8, so the result still fits in an int64 (even though the
+	 * input no longer does) and we can continue to compute using int64
+	 * variables to avoid more expensive numeric computations.
+	 *
+	 * It is fairly easy to see that there is no risk of the intermediate
+	 * values below overflowing 64-bit integers.  In the worst case, the
+	 * previous iteration will have computed a 3-digit square root (of a
+	 * 6-digit input less than NBASE^6 / 4), so at the start of this
+	 * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
+	 * less than 10^12.  In this case, blen will be 1, so numer will be less
+	 * than 10^17, and denom will be less than 10^12 (and hence u will also be
+	 * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
+	 * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
+	 * in 64-bit integers.
+	 */
+	step--;
+	while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
+	{
+		int			b;
+		int			a0;
+		int			a1;
+		int			i;
+		int64		numer;
+		int64		denom;
+		int64		q;
+		int64		u;
+
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0, and compute b */
+		a0 = 0;
+		a1 = 0;
+		b = 1;
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			b *= NBASE;
+			a1 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a1 += arg->digits[src_idx];
+		}
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			a0 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a0 += arg->digits[src_idx];
+		}
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		numer = r_int64 * b + a1;
+		denom = 2 * s_int64;
+		q = numer / denom;
+		u = numer - q * denom;
+
+		/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+		s_int64 = s_int64 * b + q;
+		r_int64 = u * b + a0 - q * q;
+
+		if (r_int64 < 0)
+		{
+			/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+			r_int64 += 2 * s_int64 - 1;
+			s_int64--;
+		}
+
+		Assert(src_idx == src_ndigits);		/* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * Remaining iterations with src_ndigits > 8:
+	 *
+	 * All remaining iterations require numeric variables.  Convert the int64
+	 * values to NumericVar and continue.  Note that in the final iteration we
+	 * don't need the remainder, so we can save a few cycles there by not
+	 * fully computing it.
+	 */
+	int64_to_numericvar(s_int64, &s_var);
+	if (step >= 0)
+		int64_to_numericvar(r_int64, &r_var);
+
+	while (step >= 0)
+	{
+		int			tmp_len;
+
+		src_ndigits = ndigits[step];
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0 */
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a1_var, tmp_len);
+			memcpy(a1_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a1_var.weight = blen - 1;
+			a1_var.sign = NUMERIC_POS;
+			a1_var.dscale = 0;
+			strip_var(&a1_var);
+		}
+		else
+		{
+			zero_var(&a1_var);
+			a1_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a0_var, tmp_len);
+			memcpy(a0_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a0_var.weight = blen - 1;
+			a0_var.sign = NUMERIC_POS;
+			a0_var.dscale = 0;
+			strip_var(&a0_var);
+		}
+		else
+		{
+			zero_var(&a0_var);
+			a0_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		set_var_from_var(&r_var, &q_var);
+		q_var.weight += blen;
+		add_var(&q_var, &a1_var, &q_var);
+		add_var(&s_var, &s_var, &u_var);
+		div_mod_var(&q_var, &u_var, &q_var, &u_var);
+
+		/* Compute s = s*b + q */
+		s_var.weight += blen;
+		add_var(&s_var, &q_var, &s_var);
+
+		/*
+		 * Compute r = u*b + a0 - q^2.
+		 *
+		 * In the final iteration, we don't actually need r, but we do need to
+		 * know whether it would have been negative, so that we know whether
+		 * to adjust s.
+		 */
+		u_var.weight += blen;
+		add_var(&u_var, &a0_var, &u_var);
+		mul_var(&q_var, &q_var, &q_var, 0);
+
+		if (step > 0)
+		{
+			/* Need r for later iterations */
+			sub_var(&u_var, &q_var, &r_var);
+			if (r_var.sign == NUMERIC_NEG)
+			{
+				/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+				add_var(&s_var, &s_var, &q_var);
+				add_var(&r_var, &q_var, &r_var);
+				sub_var(&r_var, &const_one, &r_var);
+				sub_var(&s_var, &const_one, &s_var);
+			}
+		}
+		else
+		{
+			/* Don't need r anymore, except to test if s is too large by 1 */
+			if (cmp_var(&u_var, &q_var) < 0)
+				sub_var(&s_var, &const_one, &s_var);
+		}
+
+		Assert(src_idx == src_ndigits);		/* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * Construct the final result, rounding it to the requested precision.
+	 */
+	set_var_from_var(&s_var, result);
+	result->weight = res_weight;
+	result->sign = NUMERIC_POS;
+
+	/* Round to target rscale (and set result->dscale) */
 	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+
+	free_var(&s_var);
+	free_var(&r_var);
+	free_var(&a0_var);
+	free_var(&a1_var);
+	free_var(&q_var);
+	free_var(&u_var);
 }
 
 
@@ -8529,18 +8880,23 @@ ln_var(const NumericVar *arg, NumericVar
 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
 	 * rscale as we work so that we keep this many significant digits at each
 	 * step (plus a few more for good measure).
+	 *
+	 * Note that we allow local_rscale < 0 during this input reduction
+	 * process, which implies rounding before the decimal point.  sqrt_var()
+	 * explicitly supports this, and it significantly reduces the work
+	 * required to reduce very large inputs to the required range.  Once the
+	 * input reduction is complete, x.weight will be 0 and its display scale
+	 * will be non-negative again.
 	 */
 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 	}
 	while (cmp_var(&x, &const_one_point_one) >= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 	}
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
new file mode 100644
index 23a4c6d..c7fe63d
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -1580,6 +1580,57 @@ select div(12345678901234567890, 123) *
 (1 row)
 
 --
+-- Test some corner cases for square root
+--
+select sqrt(1.000000000000003::numeric);
+       sqrt        
+-------------------
+ 1.000000000000001
+(1 row)
+
+select sqrt(1.000000000000004::numeric);
+       sqrt        
+-------------------
+ 1.000000000000002
+(1 row)
+
+select sqrt(96627521408608.56340355805::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248648
+(1 row)
+
+select sqrt(96627521408608.56340355806::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248649
+(1 row)
+
+select sqrt(515549506212297735.073688290367::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921184
+(1 row)
+
+select sqrt(515549506212297735.073688290368::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921185
+(1 row)
+
+select sqrt(8015491789940783531003294973900306::numeric);
+       sqrt        
+-------------------
+ 89529278953540017
+(1 row)
+
+select sqrt(8015491789940783531003294973900307::numeric);
+       sqrt        
+-------------------
+ 89529278953540018
+(1 row)
+
+--
 -- Test code path for raising to integer powers
 --
 select 10.0 ^ -2147483648 as rounds_to_zero;
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
new file mode 100644
index c5c8d76..41475a9
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -883,6 +883,19 @@ select div(12345678901234567890, 123);
 select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;
 
 --
+-- Test some corner cases for square root
+--
+
+select sqrt(1.000000000000003::numeric);
+select sqrt(1.000000000000004::numeric);
+select sqrt(96627521408608.56340355805::numeric);
+select sqrt(96627521408608.56340355806::numeric);
+select sqrt(515549506212297735.073688290367::numeric);
+select sqrt(515549506212297735.073688290368::numeric);
+select sqrt(8015491789940783531003294973900306::numeric);
+select sqrt(8015491789940783531003294973900307::numeric);
+
+--
 -- Test code path for raising to integer powers
 --
 
#2 Dean Rasheed
dean.a.rasheed@gmail.com
In reply to: Dean Rasheed (#1)
1 attachment(s)
Re: Some improvements to numeric sqrt() and ln()

On Fri, 28 Feb 2020 at 08:15, Dean Rasheed <dean.a.rasheed@gmail.com> wrote:

> It's possible that there are further gains to be had in the sqrt()
> algorithm on platforms that support 128-bit integers, but I haven't
> had a chance to investigate that yet.

Rebased patch attached, now using 128-bit integers for part of
sqrt_var() on platforms that support them. This turned out to be well
worth it (1.5 to 2 times faster than the previous version if the
result has fewer than 30 or 40 digits).
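The v2 patch text is cut off below, so exactly how it uses 128-bit integers isn't visible here. As a hypothetical illustration of why a 128-bit type helps, consider one outer Karatsuba step with b = 2^32. That step turns a 128-bit square root into a 64-bit square root plus one division. Every intermediate value fits in unsigned __int128, a GCC/Clang extension that this sketch assumes is available. The names isqrt64 and isqrt128 are made up for the sketch. The input is first shifted left by an even number of bits so that a3 >= b/4, and the result is shifted back afterwards.

```c
#include <stdint.h>

typedef unsigned __int128 uint128;      /* GCC/Clang extension (assumed) */

/* Exact floor(sqrt(n)) for 64-bit n (bit-by-bit method). */
static uint64_t isqrt64(uint64_t n)
{
    uint64_t res = 0;
    uint64_t bit = (uint64_t) 1 << 62;

    while (bit > n)
        bit >>= 2;
    while (bit != 0)
    {
        if (n >= res + bit)
        {
            n -= res + bit;
            res = (res >> 1) + bit;
        }
        else
            res >>= 1;
        bit >>= 2;
    }
    return res;
}

/* floor(sqrt(n)) for 128-bit n: one Karatsuba step on top of isqrt64(). */
static uint64_t isqrt128(uint128 n)
{
    if ((n >> 64) == 0)
        return isqrt64((uint64_t) n);

    /* Normalise so that n >= 2^126, i.e. a3 >= b/4 with b = 2^32 */
    int shift = 0;

    while ((n >> 126) == 0)
    {
        n <<= 2;
        shift++;
    }

    const uint128 b = (uint128) 1 << 32;
    uint64_t hi = (uint64_t) (n >> 64);         /* a3*b + a2 */
    uint128  a1 = (n >> 32) & (b - 1);
    uint128  a0 = n & (b - 1);

    /* (s,r) = SqrtRem(a3*b + a2), a 64-bit problem */
    uint64_t s1 = isqrt64(hi);
    uint128  r1 = (uint128) hi - (uint128) s1 * s1;

    /* (q,u) = DivRem(r*b + a1, 2*s) */
    uint128  numer = r1 * b + a1;
    uint128  denom = 2 * (uint128) s1;
    uint128  q = numer / denom;
    uint128  u = numer % denom;

    /* s = s*b + q; it is one too large iff u*b + a0 - q^2 < 0 */
    uint128  s = (uint128) s1 * b + q;

    if (u * b + a0 < q * q)
        s--;

    /* Undo the normalisation: floor(sqrt(n * 4^shift)) >> shift */
    return (uint64_t) (s >> shift);
}
```

With this sketch, floor(sqrt(10^32 + 10^16)) comes out as 10^16, and no numeric arithmetic is needed for inputs of up to 128 bits.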

Regards,
Dean

Attachments:

numeric-sqrt-v2.patch (text/x-patch; charset=US-ASCII)
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 10229eb..9e6bb80
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -393,16 +393,6 @@ static const NumericVar const_ten =
 #endif
 
 #if DEC_DIGITS == 4
-static const NumericDigit const_zero_point_five_data[1] = {5000};
-#elif DEC_DIGITS == 2
-static const NumericDigit const_zero_point_five_data[1] = {50};
-#elif DEC_DIGITS == 1
-static const NumericDigit const_zero_point_five_data[1] = {5};
-#endif
-static const NumericVar const_zero_point_five =
-{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};
-
-#if DEC_DIGITS == 4
 static const NumericDigit const_zero_point_nine_data[1] = {9000};
 #elif DEC_DIGITS == 2
 static const NumericDigit const_zero_point_nine_data[1] = {90};
@@ -518,6 +508,8 @@ static void div_var_fast(const NumericVa
 static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
 static void mod_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result);
+static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
+						NumericVar *quot, NumericVar *rem);
 static void ceil_var(const NumericVar *var, NumericVar *result);
 static void floor_var(const NumericVar *var, NumericVar *result);
 
@@ -7712,6 +7704,7 @@ div_var_fast(const NumericVar *var1, con
 			 NumericVar *result, int rscale, bool round)
 {
 	int			div_ndigits;
+	int			load_ndigits;
 	int			res_sign;
 	int			res_weight;
 	int		   *div;
@@ -7766,9 +7759,6 @@ div_var_fast(const NumericVar *var1, con
 	div_ndigits += DIV_GUARD_DIGITS;
 	if (div_ndigits < DIV_GUARD_DIGITS)
 		div_ndigits = DIV_GUARD_DIGITS;
-	/* Must be at least var1ndigits, too, to simplify data-loading loop */
-	if (div_ndigits < var1ndigits)
-		div_ndigits = var1ndigits;
 
 	/*
 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
@@ -7781,9 +7771,16 @@ div_var_fast(const NumericVar *var1, con
 	 * (approximate) quotient digit and stores it into div[], removing one
 	 * position of dividend space.  A final pass of carry propagation takes
 	 * care of any mistaken quotient digits.
+	 *
+	 * Note that div[] doesn't necessarily contain all of the digits from the
+	 * dividend --- the desired precision plus guard digits might be less than
+	 * the dividend's precision.  This happens, for example, in the square
+	 * root algorithm, where we typically divide a 2N-digit number by an
+	 * N-digit number, and only require a result with N digits of precision.
 	 */
 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
-	for (i = 0; i < var1ndigits; i++)
+	load_ndigits = Min(div_ndigits, var1ndigits);
+	for (i = 0; i < load_ndigits; i++)
 		div[i + 1] = var1digits[i];
 
 	/*
@@ -7844,9 +7841,15 @@ div_var_fast(const NumericVar *var1, con
 			maxdiv += Abs(qdigit);
 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
 			{
-				/* Yes, do it */
+				/*
+				 * Yes, do it.  Note that if var2ndigits is much smaller than
+				 * div_ndigits, we can save a significant amount of effort
+				 * here by noting that we only need to normalise those div[]
+				 * entries touched where prior iterations subtracted multiples
+				 * of the divisor.
+				 */
 				carry = 0;
-				for (i = div_ndigits; i > qi; i--)
+				for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
 				{
 					newdig = div[i] + carry;
 					if (newdig < 0)
@@ -8095,6 +8098,74 @@ mod_var(const NumericVar *var1, const Nu
 
 
 /*
+ * div_mod_var() -
+ *
+ *	Calculate the truncated integer quotient and numeric remainder of two
+ *	numeric variables.
+ */
+static void
+div_mod_var(const NumericVar *var1, const NumericVar *var2,
+			NumericVar *quot, NumericVar *rem)
+{
+	NumericVar	q;
+	NumericVar	r;
+
+	init_var(&q);
+	init_var(&r);
+
+	/*
+	 * Use div_var_fast() to get an initial estimate for the integer quotient.
+	 * In practice, this is almost always correct, but it is occasionally off
+	 * by one, which we can easily correct.
+	 */
+	div_var_fast(var1, var2, &q, 0, false);
+	mul_var(var2, &q, &r, var2->dscale);
+	sub_var(var1, &r, &r);
+
+	/*
+	 * Adjust the results if necessary --- the remainder should have the same
+	 * sign as var1, and its absolute value should be less than the absolute
+	 * value of var2.
+	 */
+	while (r.ndigits != 0 && r.sign != var1->sign)
+	{
+		/* The absolute value of the quotient is too large */
+		if (var1->sign == var2->sign)
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+		else
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+	}
+
+	while (cmp_abs(&r, var2) >= 0)
+	{
+		/* The absolute value of the quotient is too small */
+		if (var1->sign == var2->sign)
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+		else
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+	}
+
+	set_var_from_var(&q, quot);
+	set_var_from_var(&r, rem);
+
+	free_var(&q);
+	free_var(&r);
+}
+
+
+/*
  * ceil_var() -
  *
  *	Return the smallest integer greater than or equal to the argument
@@ -8213,18 +8284,30 @@ gcd_var(const NumericVar *var1, const Nu
 /*
  * sqrt_var() -
  *
- *	Compute the square root of x using Newton's algorithm
+ *	Compute the square root of x using the Karatsuba Square Root algorithm.
+ *	NOTE: we allow rscale < 0 here, implying rounding before the decimal
+ *	point.
  */
 static void
 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
 {
-	NumericVar	tmp_arg;
-	NumericVar	tmp_val;
-	NumericVar	last_val;
-	int			local_rscale;
 	int			stat;
-
-	local_rscale = rscale + 8;
+	int			res_weight;
+	int			res_ndigits;
+	int			src_ndigits;
+	int			step;
+	int			ndigits[32];
+	int			blen;
+	int64		arg_int64;
+	int			src_idx;
+	int64		s_int64;
+	int64		r_int64;
+	NumericVar	s_var;
+	NumericVar	r_var;
+	NumericVar	a0_var;
+	NumericVar	a1_var;
+	NumericVar	q_var;
+	NumericVar	u_var;
 
 	stat = cmp_var(arg, &const_zero);
 	if (stat == 0)
@@ -8243,43 +8326,398 @@ sqrt_var(const NumericVar *arg, NumericV
 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
 				 errmsg("cannot take square root of a negative number")));
 
-	init_var(&tmp_arg);
-	init_var(&tmp_val);
-	init_var(&last_val);
+	init_var(&s_var);
+	init_var(&r_var);
+	init_var(&a0_var);
+	init_var(&a1_var);
+	init_var(&q_var);
+	init_var(&u_var);
 
-	/* Copy arg in case it is the same var as result */
-	set_var_from_var(arg, &tmp_arg);
+	/*
+	 * The result weight is half the input weight, rounded towards minus
+	 * infinity.
+	 */
+	res_weight = (int) floor((double) arg->weight / 2);
 
 	/*
-	 * Initialize the result to the first guess
+	 * Number of NBASE digits to compute.  To ensure correct rounding, compute
+	 * at least 1 extra decimal digit.  We explicitly allow rscale to be
+	 * negative here, but must always compute at least 1 NBASE digit.
 	 */
-	alloc_var(result, 1);
-	result->digits[0] = tmp_arg.digits[0] / 2;
-	if (result->digits[0] == 0)
-		result->digits[0] = 1;
-	result->weight = tmp_arg.weight / 2;
-	result->sign = NUMERIC_POS;
+	res_ndigits = res_weight + 1 + (int) ceil((double) (rscale + 1) / DEC_DIGITS);
+	res_ndigits = Max(res_ndigits, 1);
 
-	set_var_from_var(result, &last_val);
+	/*
+	 * Number of source NBASE digits logically required to produce a result
+	 * with this precision --- every digit before the decimal point, plus 2
+	 * for each result digit after the decimal point (or minus 2 for each
+	 * result digit we round before the decimal point).
+	 */
+	src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
+	src_ndigits = Max(src_ndigits, 1);
 
-	for (;;)
+	/* ----------
+	 * From this point on, we treat the input and the result as integers and
+	 * compute the integer square root and remainder using the Karatsuba
+	 * Square Root algorithm, which may be written recursively as follows:
+	 *
+	 *	SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
+	 *		[ for some base b, and coefficients a0,a1,a2,a3 chosen so that
+	 *		  0 <= a0,a1,a2 < b and a3 >= b/4 ]
+	 *		Let (s,r) = SqrtRem(a3*b + a2)
+	 *		Let (q,u) = DivRem(r*b + a1, 2*s)
+	 *		Let s = s*b + q
+	 *		Let r = u*b + a0 - q^2
+	 *		If r < 0 Then
+	 *			Let r = r + 2*s - 1
+	 *			Let s = s - 1
+	 *		Return (s,r)
+	 *
+	 * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
+	 * 3805, November 1999.
+	 *
+	 * Note that there is no upper bound on a3, and we allow it to be larger
+	 * than b (by choosing a smaller b) if necessary to ensure that the
+	 * condition a3 >= b/4 is met.  For optimal performance, b should have
+	 * approximately a quarter of the number of digits in the input, so that
+	 * the outer square root computes roughly twice as many digits as the inner
+	 * one.  For simplicity, we choose b = NBASE^blen, an integer power of
+	 * NBASE.
+	 *
+	 * We implement the algorithm iteratively rather than recursively, to
+	 * allow the working variables to be reused.  With this approach, each
+	 * digit of the input is read precisely once --- src_idx tracks the number
+	 * of input digits used so far.
+	 *
+	 * The array ndigits[] holds the number of NBASE digits of the input that
+	 * will have been used at the end of each iteration, which roughly doubles
+	 * each time.  Note that the array elements are stored in reverse order,
+	 * so if the final iteration requires src_ndigits = 37 input digits, the
+	 * array will contain [37,19,11,7,5,3], and we would start by computing
+	 * the square root of the 3 most significant NBASE digits.
+	 * ----------
+	 */
+	step = 0;
+	while ((ndigits[step] = src_ndigits) > 4)
 	{
-		div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
+		/* Choose b so that a3 >= b/4 */
+		blen = src_ndigits / 4;
+		if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
+			blen--;
 
-		add_var(result, &tmp_val, result);
-		mul_var(result, &const_zero_point_five, result, local_rscale);
+		/* Number of digits in the next step (inner square root) */
+		src_ndigits -= 2 * blen;
+		step++;
+	}
 
-		if (cmp_var(&last_val, result) == 0)
-			break;
-		set_var_from_var(result, &last_val);
+	/*
+	 * First iteration (innermost square root and remainder):
+	 *
+	 * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
+	 * has at most 9 decimal digits, so estimate it using double precision
+	 * arithmetic, which will in fact almost certainly return the correct
+	 * result with no further correction required.
+	 */
+	arg_int64 = arg->digits[0];
+	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
+	{
+		arg_int64 *= NBASE;
+		if (src_idx < arg->ndigits)
+			arg_int64 += arg->digits[src_idx];
 	}
 
-	free_var(&last_val);
-	free_var(&tmp_val);
-	free_var(&tmp_arg);
+	s_int64 = (int64) sqrt((double) arg_int64);
+	r_int64 = arg_int64 - s_int64 * s_int64;
 
-	/* Round to requested precision */
+	/* Use Newton's method to correct the result, if necessary */
+	while (r_int64 < 0 || r_int64 > 2 * s_int64)
+	{
+		s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
+		r_int64 = arg_int64 - s_int64 * s_int64;
+	}
+
+	/*
+	 * Iterations with src_ndigits <= 8:
+	 *
+	 * The next 1 or 2 iterations compute larger (outer) square roots with
+	 * src_ndigits <= 8, so the result still fits in an int64 (even though the
+	 * input no longer does) and we can continue to compute using int64
+	 * variables to avoid more expensive numeric computations.
+	 *
+	 * It is fairly easy to see that there is no risk of the intermediate
+	 * values below overflowing 64-bit integers.  In the worst case, the
+	 * previous iteration will have computed a 3-digit square root (of a
+	 * 6-digit input less than NBASE^6 / 4), so at the start of this
+	 * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
+	 * less than 10^12.  In this case, blen will be 1, so numer will be less
+	 * than 10^17, and denom will be less than 10^12 (and hence u will also be
+	 * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
+	 * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
+	 * in 64-bit integers.
+	 */
+	step--;
+	while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
+	{
+		int			b;
+		int			a0;
+		int			a1;
+		int			i;
+		int64		numer;
+		int64		denom;
+		int64		q;
+		int64		u;
+
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0, and compute b */
+		a0 = 0;
+		a1 = 0;
+		b = 1;
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			b *= NBASE;
+			a1 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a1 += arg->digits[src_idx];
+		}
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			a0 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a0 += arg->digits[src_idx];
+		}
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		numer = r_int64 * b + a1;
+		denom = 2 * s_int64;
+		q = numer / denom;
+		u = numer - q * denom;
+
+		/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+		s_int64 = s_int64 * b + q;
+		r_int64 = u * b + a0 - q * q;
+
+		if (r_int64 < 0)
+		{
+			/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+			r_int64 += 2 * s_int64 - 1;
+			s_int64--;
+		}
+
+		Assert(src_idx == src_ndigits);		/* All input digits consumed */
+		step--;
+	}
+
+#ifdef HAVE_INT128
+	/*
+	 * On platforms with 128-bit integer support, we can further delay the
+	 * need to use numeric variables.
+	 */
+	if (step >= 0)
+	{
+		int128		s_int128;
+		int128		r_int128;
+
+		s_int128 = s_int64;
+		r_int128 = r_int64;
+
+		/*
+		 * Iterations with src_ndigits <= 16:
+		 *
+		 * The result fits in an int128 (even though the input doesn't) so we
+		 * use int128 variables to avoid more expensive numeric computations.
+		 */
+		while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
+		{
+			int64		b;
+			int64		a0;
+			int64		a1;
+			int64		i;
+			int128		numer;
+			int128		denom;
+			int128		q;
+			int128		u;
+
+			blen = (src_ndigits - src_idx) / 2;
+
+			/* Extract a1 and a0, and compute b */
+			a0 = 0;
+			a1 = 0;
+			b = 1;
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				b *= NBASE;
+				a1 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a1 += arg->digits[src_idx];
+			}
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				a0 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a0 += arg->digits[src_idx];
+			}
+
+			/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+			numer = r_int128 * b + a1;
+			denom = 2 * s_int128;
+			q = numer / denom;
+			u = numer - q * denom;
+
+			/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+			s_int128 = s_int128 * b + q;
+			r_int128 = u * b + a0 - q * q;
+
+			if (r_int128 < 0)
+			{
+				/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+				r_int128 += 2 * s_int128 - 1;
+				s_int128--;
+			}
+
+			Assert(src_idx == src_ndigits);		/* All input digits consumed */
+			step--;
+		}
+
+		/*
+		 * All remaining iterations require numeric variables.  Convert the
+		 * integer values to NumericVar and continue.  Note that in the final
+		 * iteration we don't need the remainder, so we can save a few cycles
+		 * there by not fully computing it.
+		 */
+		int128_to_numericvar(s_int128, &s_var);
+		if (step >= 0)
+			int128_to_numericvar(r_int128, &r_var);
+	}
+	else
+	{
+		int64_to_numericvar(s_int64, &s_var);
+		if (step >= 0)
+			int64_to_numericvar(r_int64, &r_var);
+	}
+#else
+	int64_to_numericvar(s_int64, &s_var);
+	if (step >= 0)
+		int64_to_numericvar(r_int64, &r_var);
+#endif
+
+	/*
+	 * The remaining iterations with src_ndigits > 8 (or 16 with int128)
+	 * use numeric variables.
+	 */
+	while (step >= 0)
+	{
+		int			tmp_len;
+
+		src_ndigits = ndigits[step];
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0 */
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a1_var, tmp_len);
+			memcpy(a1_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a1_var.weight = blen - 1;
+			a1_var.sign = NUMERIC_POS;
+			a1_var.dscale = 0;
+			strip_var(&a1_var);
+		}
+		else
+		{
+			zero_var(&a1_var);
+			a1_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a0_var, tmp_len);
+			memcpy(a0_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a0_var.weight = blen - 1;
+			a0_var.sign = NUMERIC_POS;
+			a0_var.dscale = 0;
+			strip_var(&a0_var);
+		}
+		else
+		{
+			zero_var(&a0_var);
+			a0_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		set_var_from_var(&r_var, &q_var);
+		q_var.weight += blen;
+		add_var(&q_var, &a1_var, &q_var);
+		add_var(&s_var, &s_var, &u_var);
+		div_mod_var(&q_var, &u_var, &q_var, &u_var);
+
+		/* Compute s = s*b + q */
+		s_var.weight += blen;
+		add_var(&s_var, &q_var, &s_var);
+
+		/*
+		 * Compute r = u*b + a0 - q^2.
+		 *
+		 * In the final iteration, we don't actually need r, but we do need to
+		 * know whether it would have been negative, so that we know whether
+		 * to adjust s.
+		 */
+		u_var.weight += blen;
+		add_var(&u_var, &a0_var, &u_var);
+		mul_var(&q_var, &q_var, &q_var, 0);
+
+		if (step > 0)
+		{
+			/* Need r for later iterations */
+			sub_var(&u_var, &q_var, &r_var);
+			if (r_var.sign == NUMERIC_NEG)
+			{
+				/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+				add_var(&s_var, &s_var, &q_var);
+				add_var(&r_var, &q_var, &r_var);
+				sub_var(&r_var, &const_one, &r_var);
+				sub_var(&s_var, &const_one, &s_var);
+			}
+		}
+		else
+		{
+			/* Don't need r anymore, except to test if s is too large by 1 */
+			if (cmp_var(&u_var, &q_var) < 0)
+				sub_var(&s_var, &const_one, &s_var);
+		}
+
+		Assert(src_idx == src_ndigits);		/* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * Construct the final result, rounding it to the requested precision.
+	 */
+	set_var_from_var(&s_var, result);
+	result->weight = res_weight;
+	result->sign = NUMERIC_POS;
+
+	/* Round to target rscale (and set result->dscale) */
 	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+
+	free_var(&s_var);
+	free_var(&r_var);
+	free_var(&a0_var);
+	free_var(&a1_var);
+	free_var(&q_var);
+	free_var(&u_var);
 }
 
 
@@ -8530,12 +8968,18 @@ ln_var(const NumericVar *arg, NumericVar
 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
 	 * rscale as we work so that we keep this many significant digits at each
 	 * step (plus a few more for good measure).
+	 *
+	 * Note that we allow local_rscale < 0 during this input reduction
+	 * process, which implies rounding before the decimal point.  sqrt_var()
+	 * explicitly supports this, and it significantly reduces the work
+	 * required to reduce very large inputs to the required range.  Once the
+	 * input reduction is complete, x.weight will be 0 and its display scale
+	 * will be non-negative again.
 	 */
 	nsqrt = 0;
 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
@@ -8543,7 +8987,6 @@ ln_var(const NumericVar *arg, NumericVar
 	while (cmp_var(&x, &const_one_point_one) >= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
new file mode 100644
index 23a4c6d..c7fe63d
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -1580,6 +1580,57 @@ select div(12345678901234567890, 123) *
 (1 row)
 
 --
+-- Test some corner cases for square root
+--
+select sqrt(1.000000000000003::numeric);
+       sqrt        
+-------------------
+ 1.000000000000001
+(1 row)
+
+select sqrt(1.000000000000004::numeric);
+       sqrt        
+-------------------
+ 1.000000000000002
+(1 row)
+
+select sqrt(96627521408608.56340355805::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248648
+(1 row)
+
+select sqrt(96627521408608.56340355806::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248649
+(1 row)
+
+select sqrt(515549506212297735.073688290367::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921184
+(1 row)
+
+select sqrt(515549506212297735.073688290368::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921185
+(1 row)
+
+select sqrt(8015491789940783531003294973900306::numeric);
+       sqrt        
+-------------------
+ 89529278953540017
+(1 row)
+
+select sqrt(8015491789940783531003294973900307::numeric);
+       sqrt        
+-------------------
+ 89529278953540018
+(1 row)
+
+--
 -- Test code path for raising to integer powers
 --
 select 10.0 ^ -2147483648 as rounds_to_zero;
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
new file mode 100644
index c5c8d76..41475a9
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -883,6 +883,19 @@ select div(12345678901234567890, 123);
 select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;
 
 --
+-- Test some corner cases for square root
+--
+
+select sqrt(1.000000000000003::numeric);
+select sqrt(1.000000000000004::numeric);
+select sqrt(96627521408608.56340355805::numeric);
+select sqrt(96627521408608.56340355806::numeric);
+select sqrt(515549506212297735.073688290367::numeric);
+select sqrt(515549506212297735.073688290368::numeric);
+select sqrt(8015491789940783531003294973900306::numeric);
+select sqrt(8015491789940783531003294973900307::numeric);
+
+--
 -- Test code path for raising to integer powers
 --
 
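For readers who want to experiment with the recursion quoted in the sqrt_var() comment, here is a minimal, illustrative sketch of SqrtRem on 64-bit unsigned integers. It is not code from the patch: it assumes base b = 2^k rather than NBASE^blen, recurses instead of iterating, and the names sqrt_rem() and check_sqrt_rem() are hypothetical. Choosing k = (bits + 1) / 4, where bits is the bit length of n, keeps a3 >= b/4 while allowing a3 to exceed b, mirroring the relaxation described in the comment.

```c
#include <assert.h>
#include <stdint.h>

/*
 * Illustrative recursive Karatsuba SqrtRem with base b = 2^k (assumption).
 * Returns s = floor(sqrt(n)) and stores r = n - s*s in *rem.
 */
static uint64_t
sqrt_rem(uint64_t n, uint64_t *rem)
{
    uint64_t s;

    if (n < 16)
    {
        /* Base case: tiny inputs by brute force */
        s = 0;
        while ((s + 1) * (s + 1) <= n)
            s++;
        *rem = n - s * s;
        return s;
    }

    /* Bit length of n (at least 5 here) */
    int bits = 0;
    for (uint64_t t = n; t != 0; t >>= 1)
        bits++;

    /*
     * Write n = a3*b^3 + a2*b^2 + a1*b + a0 with b = 2^k.  This k makes
     * a3 >= b/4; a3 may be larger than b.
     */
    int k = (bits + 1) / 4;
    uint64_t mask = ((uint64_t) 1 << k) - 1;
    uint64_t a0 = n & mask;
    uint64_t a1 = (n >> k) & mask;

    /* (s', r') = SqrtRem(a3*b + a2), i.e. of the top bits of n */
    uint64_t r1;
    uint64_t s1 = sqrt_rem(n >> (2 * k), &r1);

    /* (q, u) = DivRem(r'*b + a1, 2*s') */
    uint64_t numer = (r1 << k) + a1;
    uint64_t q = numer / (2 * s1);
    uint64_t u = numer % (2 * s1);

    /* s = s'*b + q and r = u*b + a0 - q^2 (r may be negative) */
    s = (s1 << k) + q;
    int64_t r = (int64_t) ((u << k) + a0) - (int64_t) (q * q);

    if (r < 0)
    {
        /* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
        r += (int64_t) (2 * s - 1);
        s--;
    }
    *rem = (uint64_t) r;
    return s;
}

/* Check s^2 <= n < (s+1)^2 without computing (s+1)^2, which may overflow */
static void
check_sqrt_rem(uint64_t n)
{
    uint64_t r;
    uint64_t s = sqrt_rem(n, &r);

    assert(s <= UINT32_MAX);
    assert(s * s <= n);
    assert(n - s * s <= 2 * s);
    assert(r == n - s * s);
}
```

A single correction step suffices because a3 >= b/4 forces s' >= b/2, which bounds q by b; from that, the candidate s is never too small and is at most one too large, whether or not a3 < b. The same bounds keep every intermediate value well inside 64 bits for any uint64_t input.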
#3 Tels
nospam-pg-abuse@bloodgate.com
In reply to: Dean Rasheed (#2)
1 attachment(s)
Re: Some improvements to numeric sqrt() and ln()

Dear Dean,

On 2020-03-01 20:47, Dean Rasheed wrote:

> On Fri, 28 Feb 2020 at 08:15, Dean Rasheed <dean.a.rasheed@gmail.com>
> wrote:
>
>> It's possible that there are further gains to be had in the sqrt()
>> algorithm on platforms that support 128-bit integers, but I haven't
>> had a chance to investigate that yet.
>
> Rebased patch attached, now using 128-bit integers for part of
> sqrt_var() on platforms that support them. This turned out to be well
> worth it (1.5 to 2 times faster than the previous version if the
> result has less than 30 or 40 digits).

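The quoted speed-up comes from keeping more of the early iterations in machine integers. As a hedged illustration (not code from the patch), the ndigits[] loop from sqrt_var() can be lifted into a standalone function to see which iterations run in which phase. NBASE = 10000 (the default DEC_DIGITS == 4 build) and the first_digit parameter, standing in for arg->digits[0], are assumptions of this sketch; sqrt_schedule() is a hypothetical name.

```c
#define NBASE 10000     /* assumption: DEC_DIGITS == 4 build */

/*
 * Standalone copy of the ndigits[] loop in sqrt_var().  first_digit
 * stands in for arg->digits[0].  Fills ndigits[] (final iteration first)
 * and returns the number of iterations.
 */
static int
sqrt_schedule(int src_ndigits, int first_digit, int ndigits[32])
{
    int step = 0;

    while ((ndigits[step] = src_ndigits) > 4)
    {
        /* Choose b = NBASE^blen so that a3 >= b/4 */
        int blen = src_ndigits / 4;

        /*
         * a3 has src_ndigits - 3*blen digits.  Only when that is exactly
         * blen digits can a3 < b/4, and then it depends on the leading digit.
         */
        if (blen * 4 == src_ndigits && first_digit < NBASE / 4)
            blen--;

        /* Number of digits used by the next (inner) square root */
        src_ndigits -= 2 * blen;
        step++;
    }
    return step + 1;
}
```

For src_ndigits = 37 it returns 6 and fills ndigits[] with {37, 19, 11, 7, 5, 3}, matching the patch comment. Reading the patch's loops, the 3-digit iteration uses double precision plus an int64 correction, the 5- and 7-digit iterations stay in int64, the 11-digit iteration uses int128 when HAVE_INT128 is defined (NumericVar otherwise), and the 19- and 37-digit iterations use NumericVar. Since an iteration with src_ndigits <= 16 yields a result of at most about 8 NBASE digits (roughly 32 decimal digits), this appears consistent with the quoted 30 to 40 digit range.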
Thank you for these patches, these sound like really nice improvements.
One thing came to mind while reading the patch:

+	 *		If r < 0 Then
+	 *			Let r = r + 2*s - 1
+	 *			Let s = s - 1
+			/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+			r_int64 += 2 * s_int64 - 1;
+			s_int64--;

This can be reformulated as:

+	 *		If r < 0 Then
+	 *			Let r = r + s
+	 *			Let s = s - 1
+	 *			Let r = r + s
+			/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+			r_int64 += s_int64;
+			s_int64--;
+			r_int64 += s_int64;

which would remove one mul/shift and the temp. variable. Mind you, I have
not benchmarked this, so it might make little difference, but maybe it is
worth trying it.
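The two forms agree because, once s has been decremented, r + s + (s - 1) = r + 2*s - 1. A tiny standalone check of that identity follows; the helper names are hypothetical, int64_t stands in for PostgreSQL's int64, and neither function is taken from the patch.

```c
#include <stdint.h>

/* Correction as written in the patch: r = r + 2*s - 1, then s = s - 1 */
static void
correct_patch(int64_t *s, int64_t *r)
{
    *r += 2 * *s - 1;
    (*s)--;
}

/* Suggested reformulation: r = r + s, s = s - 1, r = r + s */
static void
correct_suggested(int64_t *s, int64_t *r)
{
    *r += *s;
    (*s)--;
    *r += *s;
}
```

This only shows that the results are identical; whether the three-step form is measurably faster would still need benchmarking, as noted above.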

Best regards,

Tels

Attachments:

numeric-sqrt-v2.patch (text/x-patch; charset=US-ASCII; name=numeric-sqrt-v2.patch)
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 10229eb..9e6bb80
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -393,16 +393,6 @@ static const NumericVar const_ten =
 #endif
 
 #if DEC_DIGITS == 4
-static const NumericDigit const_zero_point_five_data[1] = {5000};
-#elif DEC_DIGITS == 2
-static const NumericDigit const_zero_point_five_data[1] = {50};
-#elif DEC_DIGITS == 1
-static const NumericDigit const_zero_point_five_data[1] = {5};
-#endif
-static const NumericVar const_zero_point_five =
-{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};
-
-#if DEC_DIGITS == 4
 static const NumericDigit const_zero_point_nine_data[1] = {9000};
 #elif DEC_DIGITS == 2
 static const NumericDigit const_zero_point_nine_data[1] = {90};
@@ -518,6 +508,8 @@ static void div_var_fast(const NumericVa
 static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
 static void mod_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result);
+static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
+						NumericVar *quot, NumericVar *rem);
 static void ceil_var(const NumericVar *var, NumericVar *result);
 static void floor_var(const NumericVar *var, NumericVar *result);
 
@@ -7712,6 +7704,7 @@ div_var_fast(const NumericVar *var1, con
 			 NumericVar *result, int rscale, bool round)
 {
 	int			div_ndigits;
+	int			load_ndigits;
 	int			res_sign;
 	int			res_weight;
 	int		   *div;
@@ -7766,9 +7759,6 @@ div_var_fast(const NumericVar *var1, con
 	div_ndigits += DIV_GUARD_DIGITS;
 	if (div_ndigits < DIV_GUARD_DIGITS)
 		div_ndigits = DIV_GUARD_DIGITS;
-	/* Must be at least var1ndigits, too, to simplify data-loading loop */
-	if (div_ndigits < var1ndigits)
-		div_ndigits = var1ndigits;
 
 	/*
 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
@@ -7781,9 +7771,16 @@ div_var_fast(const NumericVar *var1, con
 	 * (approximate) quotient digit and stores it into div[], removing one
 	 * position of dividend space.  A final pass of carry propagation takes
 	 * care of any mistaken quotient digits.
+	 *
+	 * Note that div[] doesn't necessarily contain all of the digits from the
+	 * dividend --- the desired precision plus guard digits might be less than
+	 * the dividend's precision.  This happens, for example, in the square
+	 * root algorithm, where we typically divide a 2N-digit number by an
+	 * N-digit number, and only require a result with N digits of precision.
 	 */
 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
-	for (i = 0; i < var1ndigits; i++)
+	load_ndigits = Min(div_ndigits, var1ndigits);
+	for (i = 0; i < load_ndigits; i++)
 		div[i + 1] = var1digits[i];
 
 	/*
@@ -7844,9 +7841,15 @@ div_var_fast(const NumericVar *var1, con
 			maxdiv += Abs(qdigit);
 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
 			{
-				/* Yes, do it */
+				/*
+				 * Yes, do it.  Note that if var2ndigits is much smaller than
+				 * div_ndigits, we can save a significant amount of effort
+				 * here by noting that we only need to normalise the div[]
+				 * entries that prior iterations touched when subtracting
+				 * multiples of the divisor.
+				 */
 				carry = 0;
-				for (i = div_ndigits; i > qi; i--)
+				for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
 				{
 					newdig = div[i] + carry;
 					if (newdig < 0)
@@ -8095,6 +8098,74 @@ mod_var(const NumericVar *var1, const Nu
 
 
 /*
+ * div_mod_var() -
+ *
+ *	Calculate the truncated integer quotient and numeric remainder of two
+ *	numeric variables.
+ */
+static void
+div_mod_var(const NumericVar *var1, const NumericVar *var2,
+			NumericVar *quot, NumericVar *rem)
+{
+	NumericVar	q;
+	NumericVar	r;
+
+	init_var(&q);
+	init_var(&r);
+
+	/*
+	 * Use div_var_fast() to get an initial estimate for the integer quotient.
+	 * In practice, this is almost always correct, but it is occasionally off
+	 * by one, which we can easily correct.
+	 */
+	div_var_fast(var1, var2, &q, 0, false);
+	mul_var(var2, &q, &r, var2->dscale);
+	sub_var(var1, &r, &r);
+
+	/*
+	 * Adjust the results if necessary --- the remainder should have the same
+	 * sign as var1, and its absolute value should be less than the absolute
+	 * value of var2.
+	 */
+	while (r.ndigits != 0 && r.sign != var1->sign)
+	{
+		/* The absolute value of the quotient is too large */
+		if (var1->sign == var2->sign)
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+		else
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+	}
+
+	while (cmp_abs(&r, var2) >= 0)
+	{
+		/* The absolute value of the quotient is too small */
+		if (var1->sign == var2->sign)
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+		else
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+	}
+
+	set_var_from_var(&q, quot);
+	set_var_from_var(&r, rem);
+
+	free_var(&q);
+	free_var(&r);
+}
+
+
+/*
  * ceil_var() -
  *
  *	Return the smallest integer greater than or equal to the argument
@@ -8213,18 +8284,30 @@ gcd_var(const NumericVar *var1, const Nu
 /*
  * sqrt_var() -
  *
- *	Compute the square root of x using Newton's algorithm
+ *	Compute the square root of x using the Karatsuba Square Root algorithm.
+ *	NOTE: we allow rscale < 0 here, implying rounding before the decimal
+ *	point.
  */
 static void
 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
 {
-	NumericVar	tmp_arg;
-	NumericVar	tmp_val;
-	NumericVar	last_val;
-	int			local_rscale;
 	int			stat;
-
-	local_rscale = rscale + 8;
+	int			res_weight;
+	int			res_ndigits;
+	int			src_ndigits;
+	int			step;
+	int			ndigits[32];
+	int			blen;
+	int64		arg_int64;
+	int			src_idx;
+	int64		s_int64;
+	int64		r_int64;
+	NumericVar	s_var;
+	NumericVar	r_var;
+	NumericVar	a0_var;
+	NumericVar	a1_var;
+	NumericVar	q_var;
+	NumericVar	u_var;
 
 	stat = cmp_var(arg, &const_zero);
 	if (stat == 0)
@@ -8243,43 +8326,398 @@ sqrt_var(const NumericVar *arg, NumericV
 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
 				 errmsg("cannot take square root of a negative number")));
 
-	init_var(&tmp_arg);
-	init_var(&tmp_val);
-	init_var(&last_val);
+	init_var(&s_var);
+	init_var(&r_var);
+	init_var(&a0_var);
+	init_var(&a1_var);
+	init_var(&q_var);
+	init_var(&u_var);
 
-	/* Copy arg in case it is the same var as result */
-	set_var_from_var(arg, &tmp_arg);
+	/*
+	 * The result weight is half the input weight, rounded towards minus
+	 * infinity.
+	 */
+	res_weight = (int) floor((double) arg->weight / 2);
 
 	/*
-	 * Initialize the result to the first guess
+	 * Number of NBASE digits to compute.  To ensure correct rounding, compute
+	 * at least 1 extra decimal digit.  We explicitly allow rscale to be
+	 * negative here, but must always compute at least 1 NBASE digit.
 	 */
-	alloc_var(result, 1);
-	result->digits[0] = tmp_arg.digits[0] / 2;
-	if (result->digits[0] == 0)
-		result->digits[0] = 1;
-	result->weight = tmp_arg.weight / 2;
-	result->sign = NUMERIC_POS;
+	res_ndigits = res_weight + 1 + (int) ceil((double) (rscale + 1) / DEC_DIGITS);
+	res_ndigits = Max(res_ndigits, 1);
 
-	set_var_from_var(result, &last_val);
+	/*
+	 * Number of source NBASE digits logically required to produce a result
+	 * with this precision --- every digit before the decimal point, plus 2
+	 * for each result digit after the decimal point (or minus 2 for each
+	 * result digit we round before the decimal point).
+	 */
+	src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
+	src_ndigits = Max(src_ndigits, 1);
 
-	for (;;)
+	/* ----------
+	 * From this point on, we treat the input and the result as integers and
+	 * compute the integer square root and remainder using the Karatsuba
+	 * Square Root algorithm, which may be written recursively as follows:
+	 *
+	 *	SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
+	 *		[ for some base b, and coefficients a0,a1,a2,a3 chosen so that
+	 *		  0 <= a0,a1,a2 < b and a3 >= b/4 ]
+	 *		Let (s,r) = SqrtRem(a3*b + a2)
+	 *		Let (q,u) = DivRem(r*b + a1, 2*s)
+	 *		Let s = s*b + q
+	 *		Let r = u*b + a0 - q^2
+	 *		If r < 0 Then
+	 *			Let r = r + 2*s - 1
+	 *			Let s = s - 1
+	 *		Return (s,r)
+	 *
+	 * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
+	 * 3805, November 1999.
+	 *
+	 * Note that there is no upper bound on a3, and we allow it to be larger
+	 * than b (by choosing a smaller b) if necessary to ensure that the
+	 * condition a3 >= b/4 is met.  For optimal performance, b should have
+	 * approximately a quarter of the number of digits in the input, so that
+	 * the outer square root computes roughly twice as many digits as the inner
+	 * one.  For simplicity, we choose b = NBASE^blen, an integer power of
+	 * NBASE.
+	 *
+	 * We implement the algorithm iteratively rather than recursively, to
+	 * allow the working variables to be reused.  With this approach, each
+	 * digit of the input is read precisely once --- src_idx tracks the number
+	 * of input digits used so far.
+	 *
+	 * The array ndigits[] holds the number of NBASE digits of the input that
+	 * will have been used at the end of each iteration, which roughly doubles
+	 * each time.  Note that the array elements are stored in reverse order,
+	 * so if the final iteration requires src_ndigits = 37 input digits, the
+	 * array will contain [37,19,11,7,5,3], and we would start by computing
+	 * the square root of the 3 most significant NBASE digits.
+	 * ----------
+	 */
+	step = 0;
+	while ((ndigits[step] = src_ndigits) > 4)
 	{
-		div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
+		/* Choose b so that a3 >= b/4 */
+		blen = src_ndigits / 4;
+		if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
+			blen--;
 
-		add_var(result, &tmp_val, result);
-		mul_var(result, &const_zero_point_five, result, local_rscale);
+		/* Number of digits in the next step (inner square root) */
+		src_ndigits -= 2 * blen;
+		step++;
+	}
 
-		if (cmp_var(&last_val, result) == 0)
-			break;
-		set_var_from_var(result, &last_val);
+	/*
+	 * First iteration (innermost square root and remainder):
+	 *
+	 * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
+	 * has at most 9 decimal digits, so estimate it using double precision
+	 * arithmetic, which will in fact almost certainly return the correct
+	 * result with no further correction required.
+	 */
+	arg_int64 = arg->digits[0];
+	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
+	{
+		arg_int64 *= NBASE;
+		if (src_idx < arg->ndigits)
+			arg_int64 += arg->digits[src_idx];
 	}
 
-	free_var(&last_val);
-	free_var(&tmp_val);
-	free_var(&tmp_arg);
+	s_int64 = (int64) sqrt((double) arg_int64);
+	r_int64 = arg_int64 - s_int64 * s_int64;
 
-	/* Round to requested precision */
+	/* Use Newton's method to correct the result, if necessary */
+	while (r_int64 < 0 || r_int64 > 2 * s_int64)
+	{
+		s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
+		r_int64 = arg_int64 - s_int64 * s_int64;
+	}
+
+	/*
+	 * Iterations with src_ndigits <= 8:
+	 *
+	 * The next 1 or 2 iterations compute larger (outer) square roots with
+	 * src_ndigits <= 8, so the result still fits in an int64 (even though the
+	 * input no longer does) and we can continue to compute using int64
+	 * variables to avoid more expensive numeric computations.
+	 *
+	 * It is fairly easy to see that there is no risk of the intermediate
+	 * values below overflowing 64-bit integers.  In the worst case, the
+	 * previous iteration will have computed a 3-digit square root (of a
+	 * 6-digit input less than NBASE^6 / 4), so at the start of this
+	 * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
+	 * less than 10^12.  In this case, blen will be 1, so numer will be less
+	 * than 10^17, and denom will be less than 10^12 (and hence u will also be
+	 * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
+	 * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
+	 * in 64-bit integers.
+	 */
+	step--;
+	while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
+	{
+		int			b;
+		int			a0;
+		int			a1;
+		int			i;
+		int64		numer;
+		int64		denom;
+		int64		q;
+		int64		u;
+
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0, and compute b */
+		a0 = 0;
+		a1 = 0;
+		b = 1;
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			b *= NBASE;
+			a1 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a1 += arg->digits[src_idx];
+		}
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			a0 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a0 += arg->digits[src_idx];
+		}
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		numer = r_int64 * b + a1;
+		denom = 2 * s_int64;
+		q = numer / denom;
+		u = numer - q * denom;
+
+		/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+		s_int64 = s_int64 * b + q;
+		r_int64 = u * b + a0 - q * q;
+
+		if (r_int64 < 0)
+		{
+			/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+			r_int64 += 2 * s_int64 - 1;
+			s_int64--;
+		}
+
+		Assert(src_idx == src_ndigits);		/* All input digits consumed */
+		step--;
+	}
+
+#ifdef HAVE_INT128
+	/*
+	 * On platforms with 128-bit integer support, we can further delay the
+	 * need to use numeric variables.
+	 */
+	if (step >= 0)
+	{
+		int128		s_int128;
+		int128		r_int128;
+
+		s_int128 = s_int64;
+		r_int128 = r_int64;
+
+		/*
+		 * Iterations with src_ndigits <= 16:
+		 *
+		 * The result fits in an int128 (even though the input doesn't) so we
+		 * use int128 variables to avoid more expensive numeric computations.
+		 */
+		while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
+		{
+			int64		b;
+			int64		a0;
+			int64		a1;
+			int64		i;
+			int128		numer;
+			int128		denom;
+			int128		q;
+			int128		u;
+
+			blen = (src_ndigits - src_idx) / 2;
+
+			/* Extract a1 and a0, and compute b */
+			a0 = 0;
+			a1 = 0;
+			b = 1;
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				b *= NBASE;
+				a1 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a1 += arg->digits[src_idx];
+			}
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				a0 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a0 += arg->digits[src_idx];
+			}
+
+			/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+			numer = r_int128 * b + a1;
+			denom = 2 * s_int128;
+			q = numer / denom;
+			u = numer - q * denom;
+
+			/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+			s_int128 = s_int128 * b + q;
+			r_int128 = u * b + a0 - q * q;
+
+			if (r_int128 < 0)
+			{
+				/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+				r_int128 += 2 * s_int128 - 1;
+				s_int128--;
+			}
+
+			Assert(src_idx == src_ndigits);		/* All input digits consumed */
+			step--;
+		}
+
+		/*
+		 * All remaining iterations require numeric variables.  Convert the
+		 * integer values to NumericVar and continue.  Note that in the final
+		 * iteration we don't need the remainder, so we can save a few cycles
+		 * there by not fully computing it.
+		 */
+		int128_to_numericvar(s_int128, &s_var);
+		if (step >= 0)
+			int128_to_numericvar(r_int128, &r_var);
+	}
+	else
+	{
+		int64_to_numericvar(s_int64, &s_var);
+		if (step >= 0)
+			int64_to_numericvar(r_int64, &r_var);
+	}
+#else
+	int64_to_numericvar(s_int64, &s_var);
+	if (step >= 0)
+		int64_to_numericvar(r_int64, &r_var);
+#endif
+
+	/*
+	 * The remaining iterations with src_ndigits > 8 (or 16 with int128)
+	 * use numeric variables.
+	 */
+	while (step >= 0)
+	{
+		int			tmp_len;
+
+		src_ndigits = ndigits[step];
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0 */
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a1_var, tmp_len);
+			memcpy(a1_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a1_var.weight = blen - 1;
+			a1_var.sign = NUMERIC_POS;
+			a1_var.dscale = 0;
+			strip_var(&a1_var);
+		}
+		else
+		{
+			zero_var(&a1_var);
+			a1_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a0_var, tmp_len);
+			memcpy(a0_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a0_var.weight = blen - 1;
+			a0_var.sign = NUMERIC_POS;
+			a0_var.dscale = 0;
+			strip_var(&a0_var);
+		}
+		else
+		{
+			zero_var(&a0_var);
+			a0_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		set_var_from_var(&r_var, &q_var);
+		q_var.weight += blen;
+		add_var(&q_var, &a1_var, &q_var);
+		add_var(&s_var, &s_var, &u_var);
+		div_mod_var(&q_var, &u_var, &q_var, &u_var);
+
+		/* Compute s = s*b + q */
+		s_var.weight += blen;
+		add_var(&s_var, &q_var, &s_var);
+
+		/*
+		 * Compute r = u*b + a0 - q^2.
+		 *
+		 * In the final iteration, we don't actually need r, but we do need to
+		 * know whether it would have been negative, so that we know whether
+		 * to adjust s.
+		 */
+		u_var.weight += blen;
+		add_var(&u_var, &a0_var, &u_var);
+		mul_var(&q_var, &q_var, &q_var, 0);
+
+		if (step > 0)
+		{
+			/* Need r for later iterations */
+			sub_var(&u_var, &q_var, &r_var);
+			if (r_var.sign == NUMERIC_NEG)
+			{
+				/* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+				add_var(&s_var, &s_var, &q_var);
+				add_var(&r_var, &q_var, &r_var);
+				sub_var(&r_var, &const_one, &r_var);
+				sub_var(&s_var, &const_one, &s_var);
+			}
+		}
+		else
+		{
+			/* Don't need r anymore, except to test if s is too large by 1 */
+			if (cmp_var(&u_var, &q_var) < 0)
+				sub_var(&s_var, &const_one, &s_var);
+		}
+
+		Assert(src_idx == src_ndigits);		/* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * Construct the final result, rounding it to the requested precision.
+	 */
+	set_var_from_var(&s_var, result);
+	result->weight = res_weight;
+	result->sign = NUMERIC_POS;
+
+	/* Round to target rscale (and set result->dscale) */
 	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+
+	free_var(&s_var);
+	free_var(&r_var);
+	free_var(&a0_var);
+	free_var(&a1_var);
+	free_var(&q_var);
+	free_var(&u_var);
 }
 
 
@@ -8530,12 +8968,18 @@ ln_var(const NumericVar *arg, NumericVar
 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
 	 * rscale as we work so that we keep this many significant digits at each
 	 * step (plus a few more for good measure).
+	 *
+	 * Note that we allow local_rscale < 0 during this input reduction
+	 * process, which implies rounding before the decimal point.  sqrt_var()
+	 * explicitly supports this, and it significantly reduces the work
+	 * required to reduce very large inputs to the required range.  Once the
+	 * input reduction is complete, x.weight will be 0 and its display scale
+	 * will be non-negative again.
 	 */
 	nsqrt = 0;
 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
@@ -8543,7 +8987,6 @@ ln_var(const NumericVar *arg, NumericVar
 	while (cmp_var(&x, &const_one_point_one) >= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
new file mode 100644
index 23a4c6d..c7fe63d
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -1580,6 +1580,57 @@ select div(12345678901234567890, 123) *
 (1 row)
 
 --
+-- Test some corner cases for square root
+--
+select sqrt(1.000000000000003::numeric);
+       sqrt        
+-------------------
+ 1.000000000000001
+(1 row)
+
+select sqrt(1.000000000000004::numeric);
+       sqrt        
+-------------------
+ 1.000000000000002
+(1 row)
+
+select sqrt(96627521408608.56340355805::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248648
+(1 row)
+
+select sqrt(96627521408608.56340355806::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248649
+(1 row)
+
+select sqrt(515549506212297735.073688290367::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921184
+(1 row)
+
+select sqrt(515549506212297735.073688290368::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921185
+(1 row)
+
+select sqrt(8015491789940783531003294973900306::numeric);
+       sqrt        
+-------------------
+ 89529278953540017
+(1 row)
+
+select sqrt(8015491789940783531003294973900307::numeric);
+       sqrt        
+-------------------
+ 89529278953540018
+(1 row)
+
+--
 -- Test code path for raising to integer powers
 --
 select 10.0 ^ -2147483648 as rounds_to_zero;
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
new file mode 100644
index c5c8d76..41475a9
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -883,6 +883,19 @@ select div(12345678901234567890, 123);
 select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;
 
 --
+-- Test some corner cases for square root
+--
+
+select sqrt(1.000000000000003::numeric);
+select sqrt(1.000000000000004::numeric);
+select sqrt(96627521408608.56340355805::numeric);
+select sqrt(96627521408608.56340355806::numeric);
+select sqrt(515549506212297735.073688290367::numeric);
+select sqrt(515549506212297735.073688290368::numeric);
+select sqrt(8015491789940783531003294973900306::numeric);
+select sqrt(8015491789940783531003294973900307::numeric);
+
+--
 -- Test code path for raising to integer powers
 --
 
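The corner cases in the regression tests above come in pairs whose correctly rounded square roots differ only in the last digit. The standalone sketch below is hedged and is not numeric.c code. It treats each input as an integer: for example, 1.000000000000003 at 15 decimal places is scaled by 10^30. It then computes floor(sqrt(N)) and the remainder with the iterative Karatsuba scheme this patch implements (NBASE = 10000, with blen chosen so that a3 >= b/4), and rounds to nearest using the remainder: round up iff r > s. sqrt_var() instead computes an extra digit and calls round_var(). The sketch uses int128 arithmetic in place of NumericVar, so it is limited to 16 base-10000 digits. isqrt_base(), to_nbase(), sqrtrem() and sqrt_round() are hypothetical helpers.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define NBASE 10000             /* base-10^4 digits, as in numeric.c */
#define DEC_DIGITS 4
typedef __int128 int128;        /* assumes GCC/Clang 128-bit integer support */

/* floor(sqrt(n)) for small n, by Newton's method from above; stands in for
 * the patch's double-precision estimate plus correction loop. */
static int64_t isqrt_base(int64_t n)
{
    int64_t x = n, y = (n + 1) / 2;

    while (y < x) {
        x = y;
        y = (x + n / x) / 2;
    }
    return x;
}

/* Split a decimal integer string (no leading zeros, at most 64 digits)
 * into base-NBASE digits, most significant first; returns the count. */
static int to_nbase(const char *dec, int *digits)
{
    int len = (int) strlen(dec);
    int nd = (len + DEC_DIGITS - 1) / DEC_DIGITS;
    int pos = 0;

    assert(len >= 1 && nd <= 16 && dec[0] != '0');
    for (int i = 0; i < nd; i++) {
        int w = (i == 0) ? len - (nd - 1) * DEC_DIGITS : DEC_DIGITS;

        digits[i] = 0;
        for (int k = 0; k < w; k++)
            digits[i] = digits[i] * 10 + (dec[pos++] - '0');
    }
    return nd;
}

/* s = floor(sqrt(N)) and r = N - s*s, using the iterative Karatsuba scheme */
static void sqrtrem(const int *digits, int ndig, int128 *s_out, int128 *r_out)
{
    int ndigits[32], step = 0, src_ndigits = ndig, src_idx, blen;

    /* Plan steps outside-in; b = NBASE^blen is chosen so that a3 >= b/4 */
    while ((ndigits[step] = src_ndigits) > 4) {
        blen = src_ndigits / 4;
        if (blen * 4 == src_ndigits && digits[0] < NBASE / 4)
            blen--;
        src_ndigits -= 2 * blen;
        step++;
    }

    /* Innermost square root of at most 4 digits, which fits in an int64 */
    int64_t n0 = 0;
    for (src_idx = 0; src_idx < src_ndigits; src_idx++)
        n0 = n0 * NBASE + digits[src_idx];
    int128 s = isqrt_base(n0);
    int128 r = n0 - s * s;

    /* Each outer step consumes blen digits of a1, then blen digits of a0 */
    for (step--; step >= 0; step--) {
        int128 b = 1, a1 = 0, a0 = 0;

        blen = (ndigits[step] - src_idx) / 2;
        for (int i = 0; i < blen; i++) {
            b *= NBASE;
            a1 = a1 * NBASE + digits[src_idx++];
        }
        for (int i = 0; i < blen; i++)
            a0 = a0 * NBASE + digits[src_idx++];

        int128 numer = r * b + a1;        /* (q,u) = DivRem(r*b + a1, 2*s) */
        int128 q = numer / (2 * s);
        int128 u = numer % (2 * s);

        s = s * b + q;                    /* s = s*b + q */
        r = u * b + a0 - q * q;           /* r = u*b + a0 - q^2 */
        if (r < 0) {                      /* s one too large: r += s, s--, r += s */
            r += s;
            s--;
            r += s;
        }
    }
    *s_out = s;
    *r_out = r;
}

/* sqrt(N) rounded to the nearest integer.  (s + 1/2)^2 = s^2 + s + 1/4 is
 * never an integer, so round up exactly when r > s. */
static int128 sqrt_round(const char *dec)
{
    int digits[16];
    int ndig = to_nbase(dec, digits);
    int128 s, r;

    sqrtrem(digits, ndig, &s, &r);
    return r > s ? s + 1 : s;
}
```

Worked by hand for the first pair: with N = 10^30 + 3*10^15, s = 10^15 + 1 and r = 10^15 - 1 < s, so the result rounds down to 1.000000000000001. With N = 10^30 + 4*10^15, s is unchanged but r = 2*10^15 - 1 > s, so it rounds up to 1.000000000000002.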
#4Dean Rasheed
dean.a.rasheed@gmail.com
In reply to: Tels (#3)
Re: Some improvements to numeric sqrt() and ln()

On Tue, 3 Mar 2020 at 00:17, Tels <nospam-pg-abuse@bloodgate.com> wrote:

Thank you for these patches, these sound like really nice improvements.

Thanks for looking!

One thing came to my mind while reading the patch:

+        *              If r < 0 Then
+        *                      Let r = r + 2*s - 1
+        *                      Let s = s - 1

This can be reformulated as:

+        *              If r < 0 Then
+        *                      Let r = r + s
+        *                      Let s = s - 1
+        *                      Let r = r + s

which would remove one mul/shift and the temp. variable.

Good point, that's a neat little optimisation.

I wasn't able to detect any difference in performance, because those
corrections are only triggered about 1 time in every 50 or so, but it
looks neater to me, especially in the numeric iterations, where it
saves a sub_var() by const_one as well as not using the temporary
variable.
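Tels' rewrite works because r + 2*s - 1 = (r + s) + (s - 1). The following minimal sketch uses int64_t in place of NumericVar, and fix_original and fix_tels are hypothetical names. It checks that both versions produce the same corrected pair and preserve n = s*s + r. For example, n = 97 with tentative s = 10 gives r = -3, and both versions yield s = 9, r = 16.

```c
#include <assert.h>
#include <stdint.h>

/*
 * When the tentative remainder r = n - s*s is negative, s is one too large.
 * Both functions replace (s, r) by (s - 1, r + 2*s - 1), which keeps
 * n == s*s + r, since (s - 1)^2 = s^2 - 2*s + 1.
 */

/* Original form: r = r + 2*s - 1; s = s - 1 */
static void fix_original(int64_t *s, int64_t *r)
{
    assert(*r < 0);
    *r = *r + 2 * *s - 1;
    *s = *s - 1;
}

/* Tels' form: r = r + s; s = s - 1; r = r + s (no doubling, no temporary) */
static void fix_tels(int64_t *s, int64_t *r)
{
    assert(*r < 0);
    *r += *s;
    *s -= 1;
    *r += *s;
}
```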

Regards,
Dean

#5David Steele
david@pgmasters.net
In reply to: Dean Rasheed (#1)
Re: Some improvements to numeric sqrt() and ln()

Hi Dean,

On 2/28/20 3:15 AM, Dean Rasheed wrote:

Attached is a WIP patch to improve the performance of numeric sqrt()
and ln(), which also makes a couple of related improvements to
div_var_fast(), all of which have knock-on benefits for other numeric
functions. The actual impact varies greatly depending on the inputs,
but the overall effect is to reduce the run time of the numeric_big
regression test by about 20%.

Are these improvements targeted at PG13 or PG14? This seems a pretty
big change for the last CF of PG13.

Regards,
--
-David
david@pgmasters.net

#6Dean Rasheed
dean.a.rasheed@gmail.com
In reply to: David Steele (#5)
Re: Some improvements to numeric sqrt() and ln()

On Wed, 4 Mar 2020 at 14:41, David Steele <david@pgmasters.net> wrote:

Are these improvements targeted at PG13 or PG14? This seems a pretty
big change for the last CF of PG13.

Well of course that's not entirely up to me, but I was hoping to
commit it for PG13.

It's very well covered by a large number of regression tests in both
numeric.sql and numeric_big.sql, since nearly anything that calls
ln(), log() or pow() ends up going through sqrt_var(). Also, the
changes are local to functions in numeric.c, which makes them easy to
revert if something proves to be wrong.

Regards,
Dean

#7Tom Lane
tgl@sss.pgh.pa.us
In reply to: Dean Rasheed (#6)
Re: Some improvements to numeric sqrt() and ln()

Dean Rasheed <dean.a.rasheed@gmail.com> writes:

On Wed, 4 Mar 2020 at 14:41, David Steele <david@pgmasters.net> wrote:

Are these improvements targeted at PG13 or PG14? This seems a pretty
big change for the last CF of PG13.

Well of course that's not entirely up to me, but I was hoping to
commit it for PG13.

It's very well covered by a large number of regression tests in both
numeric.sql and numeric_big.sql, since nearly anything that calls
ln(), log() or pow() ends up going through sqrt_var(). Also, the
changes are local to functions in numeric.c, which makes them easy to
revert if something proves to be wrong.

FWIW, I agree that this is a reasonable thing to consider committing
for v13. It's not adding any new user-visible behavior, so there are
no definitional issues to quibble over, which is usually what I worry
about regretting after an overly-hasty commit. And it's only touching
a few functions in one file, so even if the patch is a bit long, the
complexity seems pretty well controlled.

I've not read the patch in detail so this isn't meant as a review,
but from a process standpoint I see no reason not to go forward.

regards, tom lane

#8Tom Lane
tgl@sss.pgh.pa.us
In reply to: Tels (#3)
1 attachment(s)
Re: Some improvements to numeric sqrt() and ln()

Tels <nospam-pg-abuse@bloodgate.com> writes:

This can be reformulated as:
+	 *		If r < 0 Then
+	 *			Let r = r + s
+	 *			Let s = s - 1
+	 *			Let r = r + s

Here's a v3 that

* incorporates Tels' idea;

* improves some of the comments (IMO anyway, though some are just fixes for clear typos);

* adds some XXX comments about things that could be further improved
and/or need better explanations.

I also ran it through pgindent, just cause I'm like that.

With resolutions of the XXX items, I think this'd be committable.

regards, tom lane

Attachments:

numeric-sqrt-v3.patch (text/x-diff; charset=us-ascii; name=numeric-sqrt-v3.patch)
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index 10229eb..afbc2b0 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -393,16 +393,6 @@ static const NumericVar const_ten =
 #endif
 
 #if DEC_DIGITS == 4
-static const NumericDigit const_zero_point_five_data[1] = {5000};
-#elif DEC_DIGITS == 2
-static const NumericDigit const_zero_point_five_data[1] = {50};
-#elif DEC_DIGITS == 1
-static const NumericDigit const_zero_point_five_data[1] = {5};
-#endif
-static const NumericVar const_zero_point_five =
-{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};
-
-#if DEC_DIGITS == 4
 static const NumericDigit const_zero_point_nine_data[1] = {9000};
 #elif DEC_DIGITS == 2
 static const NumericDigit const_zero_point_nine_data[1] = {90};
@@ -518,6 +508,8 @@ static void div_var_fast(const NumericVar *var1, const NumericVar *var2,
 static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
 static void mod_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result);
+static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
+						NumericVar *quot, NumericVar *rem);
 static void ceil_var(const NumericVar *var, NumericVar *result);
 static void floor_var(const NumericVar *var, NumericVar *result);
 
@@ -7712,6 +7704,7 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
 			 NumericVar *result, int rscale, bool round)
 {
 	int			div_ndigits;
+	int			load_ndigits;
 	int			res_sign;
 	int			res_weight;
 	int		   *div;
@@ -7766,9 +7759,6 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
 	div_ndigits += DIV_GUARD_DIGITS;
 	if (div_ndigits < DIV_GUARD_DIGITS)
 		div_ndigits = DIV_GUARD_DIGITS;
-	/* Must be at least var1ndigits, too, to simplify data-loading loop */
-	if (div_ndigits < var1ndigits)
-		div_ndigits = var1ndigits;
 
 	/*
 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
@@ -7781,9 +7771,16 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
 	 * (approximate) quotient digit and stores it into div[], removing one
 	 * position of dividend space.  A final pass of carry propagation takes
 	 * care of any mistaken quotient digits.
+	 *
+	 * Note that div[] doesn't necessarily contain all of the digits from the
+	 * dividend --- the desired precision plus guard digits might be less than
+	 * the dividend's precision.  This happens, for example, in the square
+	 * root algorithm, where we typically divide a 2N-digit number by an
+	 * N-digit number, and only require a result with N digits of precision.
 	 */
 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
-	for (i = 0; i < var1ndigits; i++)
+	load_ndigits = Min(div_ndigits, var1ndigits);
+	for (i = 0; i < load_ndigits; i++)
 		div[i + 1] = var1digits[i];
 
 	/*
@@ -7844,9 +7841,15 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
 			maxdiv += Abs(qdigit);
 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
 			{
-				/* Yes, do it */
+				/*
+				 * Yes, do it.  Note that if var2ndigits is much smaller than
+				 * div_ndigits, we can save a significant amount of effort
+				 * here by noting that we only need to normalise those div[]
+				 * entries touched where prior iterations subtracted multiples
+				 * of the divisor.
+				 */
 				carry = 0;
-				for (i = div_ndigits; i > qi; i--)
+				for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
 				{
 					newdig = div[i] + carry;
 					if (newdig < 0)
@@ -8095,6 +8098,76 @@ mod_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
 
 
 /*
+ * div_mod_var() -
+ *
+ *	Calculate the truncated integer quotient and numeric remainder of two
+ *	numeric variables.  The remainder is precise to var2's dscale.
+ */
+static void
+div_mod_var(const NumericVar *var1, const NumericVar *var2,
+			NumericVar *quot, NumericVar *rem)
+{
+	NumericVar	q;
+	NumericVar	r;
+
+	init_var(&q);
+	init_var(&r);
+
+	/*
+	 * Use div_var_fast() to get an initial estimate for the integer quotient.
+	 * This might be inaccurate (per the warning in div_var_fast's comments),
+	 * but we can correct it below.
+	 */
+	div_var_fast(var1, var2, &q, 0, false);
+
+	/* Compute initial estimate of remainder using the quotient estimate. */
+	mul_var(var2, &q, &r, var2->dscale);
+	sub_var(var1, &r, &r);
+
+	/*
+	 * Adjust the results if necessary --- the remainder should have the same
+	 * sign as var1, and its absolute value should be less than the absolute
+	 * value of var2.
+	 */
+	while (r.ndigits != 0 && r.sign != var1->sign)
+	{
+		/* The absolute value of the quotient is too large */
+		if (var1->sign == var2->sign)
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+		else
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+	}
+
+	while (cmp_abs(&r, var2) >= 0)
+	{
+		/* The absolute value of the quotient is too small */
+		if (var1->sign == var2->sign)
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+		else
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+	}
+
+	set_var_from_var(&q, quot);
+	set_var_from_var(&r, rem);
+
+	free_var(&q);
+	free_var(&r);
+}
+
+
+/*
  * ceil_var() -
  *
  *	Return the smallest integer greater than or equal to the argument
@@ -8213,18 +8286,30 @@ gcd_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
 /*
  * sqrt_var() -
  *
- *	Compute the square root of x using Newton's algorithm
+ *	Compute the square root of x using the Karatsuba Square Root algorithm.
+ *	NOTE: we allow rscale < 0 here, implying rounding before the decimal
+ *	point.
  */
 static void
 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
 {
-	NumericVar	tmp_arg;
-	NumericVar	tmp_val;
-	NumericVar	last_val;
-	int			local_rscale;
 	int			stat;
-
-	local_rscale = rscale + 8;
+	int			res_weight;
+	int			res_ndigits;
+	int			src_ndigits;
+	int			step;
+	int			ndigits[32];
+	int			blen;
+	int64		arg_int64;
+	int			src_idx;
+	int64		s_int64;
+	int64		r_int64;
+	NumericVar	s_var;
+	NumericVar	r_var;
+	NumericVar	a0_var;
+	NumericVar	a1_var;
+	NumericVar	q_var;
+	NumericVar	u_var;
 
 	stat = cmp_var(arg, &const_zero);
 	if (stat == 0)
@@ -8243,43 +8328,412 @@ sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
 				 errmsg("cannot take square root of a negative number")));
 
-	init_var(&tmp_arg);
-	init_var(&tmp_val);
-	init_var(&last_val);
+	init_var(&s_var);
+	init_var(&r_var);
+	init_var(&a0_var);
+	init_var(&a1_var);
+	init_var(&q_var);
+	init_var(&u_var);
 
-	/* Copy arg in case it is the same var as result */
-	set_var_from_var(arg, &tmp_arg);
+	/*
+	 * The result weight is half the input weight, rounded towards minus
+	 * infinity.
+	 *
+	 * XXX do we really need floor(double) for that, rather than plain integer
+	 * math?
+	 */
+	res_weight = (int) floor((double) arg->weight / 2);
 
 	/*
-	 * Initialize the result to the first guess
+	 * Number of NBASE digits to compute.  To ensure correct rounding, compute
+	 * at least 1 extra decimal digit.  We explicitly allow rscale to be
+	 * negative here, but must always compute at least 1 NBASE digit.
+	 *
+	 * XXX likewise seems like ceil(double) is unnecessary expense.
 	 */
-	alloc_var(result, 1);
-	result->digits[0] = tmp_arg.digits[0] / 2;
-	if (result->digits[0] == 0)
-		result->digits[0] = 1;
-	result->weight = tmp_arg.weight / 2;
-	result->sign = NUMERIC_POS;
+	res_ndigits = res_weight + 1 + (int) ceil((double) (rscale + 1) / DEC_DIGITS);
+	res_ndigits = Max(res_ndigits, 1);
 
-	set_var_from_var(result, &last_val);
+	/*
+	 * Number of source NBASE digits logically required to produce a result
+	 * with this precision --- every digit before the decimal point, plus 2
+	 * for each result digit after the decimal point (or minus 2 for each
+	 * result digit we round before the decimal point).
+	 */
+	src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
+	src_ndigits = Max(src_ndigits, 1);
 
-	for (;;)
+	/* ----------
+	 * From this point on, we treat the input and the result as integers and
+	 * compute the integer square root and remainder using the Karatsuba
+	 * Square Root algorithm, which may be written recursively as follows:
+	 *
+	 *	SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
+	 *		[ for some base b, and coefficients a0,a1,a2,a3 chosen so that
+	 *		  0 <= a0,a1,a2 < b and a3 >= b/4 ]
+	 *		Let (s,r) = SqrtRem(a3*b + a2)
+	 *		Let (q,u) = DivRem(r*b + a1, 2*s)
+	 *		Let s = s*b + q
+	 *		Let r = u*b + a0 - q^2
+	 *		If r < 0 Then
+	 *			Let r = r + s
+	 *			Let s = s - 1
+	 *			Let r = r + s
+	 *		Return (s,r)
+	 *
+	 * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
+	 * RR-3805, November 1999.  At the time of writing this was available
+	 * on the net at <https://hal.inria.fr/inria-00072854>.
+	 *
+	 * The way to read the assumption "n = a3*b^3 + a2*b^2 + a1*b + a0" is
+	 * "choose a base b such that n requires at least four base-b digits to
+	 * express; then those digits are a3,a2,a1,a0, with a3 possibly larger
+	 * than b".  For optimal performance, b should have approximately a
+	 * quarter the number of digits in the input, so that the outer square
+	 * root computes roughly twice as many digits as the inner one.  For
+	 * simplicity, we choose b = NBASE^blen, an integer power of NBASE.
+	 *
+	 * We implement the algorithm iteratively rather than recursively, to
+	 * allow the working variables to be reused.  With this approach, each
+	 * digit of the input is read precisely once --- src_idx tracks the number
+	 * of input digits used so far.
+	 *
+	 * The array ndigits[] holds the number of NBASE digits of the input that
+	 * will have been used at the end of each iteration, which roughly doubles
+	 * each time.  Note that the array elements are stored in reverse order,
+	 * so if the final iteration requires src_ndigits = 37 input digits, the
+	 * array will contain [37,19,11,7,5,3], and we would start by computing
+	 * the square root of the 3 most significant NBASE digits.
+	 *
+	 * XXX I don't understand how this works.  Why is it correct to consider
+	 * arg->digits[0] at every step?  Can we prove rigorously that the ndigits
+	 * array won't be overrun?  (I can see that src_ndigits is roughly halved
+	 * by each iteration, but only roughly, so it's not entirely clear that
+	 * the worst-case situation couldn't involve more than 31 steps.)
+	 * ----------
+	 */
+	step = 0;
+	while ((ndigits[step] = src_ndigits) > 4)
 	{
-		div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
+		/* Choose b so that a3 >= b/4 */
+		blen = src_ndigits / 4;
+		if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
+			blen--;
 
-		add_var(result, &tmp_val, result);
-		mul_var(result, &const_zero_point_five, result, local_rscale);
+		/* Number of digits in the next step (inner square root) */
+		src_ndigits -= 2 * blen;
+		step++;
+	}
 
-		if (cmp_var(&last_val, result) == 0)
-			break;
-		set_var_from_var(result, &last_val);
+	/*
+	 * First iteration (innermost square root and remainder):
+	 *
+	 * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
+	 * has at most 9 decimal digits, so estimate it using double precision
+	 * arithmetic, which will in fact almost certainly return the correct
+	 * result with no further correction required.
+	 */
+	arg_int64 = arg->digits[0];
+	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
+	{
+		arg_int64 *= NBASE;
+		if (src_idx < arg->ndigits)
+			arg_int64 += arg->digits[src_idx];
 	}
 
-	free_var(&last_val);
-	free_var(&tmp_val);
-	free_var(&tmp_arg);
+	s_int64 = (int64) sqrt((double) arg_int64);
+	r_int64 = arg_int64 - s_int64 * s_int64;
+
+	/* Use Newton's method to correct the result, if necessary */
+	/* XXX is this guaranteed to converge?  integer division truncates... */
+	while (r_int64 < 0 || r_int64 > 2 * s_int64)
+	{
+		s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
+		r_int64 = arg_int64 - s_int64 * s_int64;
+	}
+
+	/*
+	 * Iterations with src_ndigits <= 8:
+	 *
+	 * The next 1 or 2 iterations compute larger (outer) square roots with
+	 * src_ndigits <= 8, so the result still fits in an int64 (even though the
+	 * input no longer does) and we can continue to compute using int64
+	 * variables to avoid more expensive numeric computations.
+	 *
+	 * It is fairly easy to see that there is no risk of the intermediate
+	 * values below overflowing 64-bit integers.  In the worst case, the
+	 * previous iteration will have computed a 3-digit square root (of a
+	 * 6-digit input less than NBASE^6 / 4), so at the start of this
+	 * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
+	 * less than 10^12.  In this case, blen will be 1, so numer will be less
+	 * than 10^17, and denom will be less than 10^12 (and hence u will also be
+	 * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
+	 * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
+	 * in 64-bit integers.
+	 */
+	step--;
+	while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
+	{
+		int			b;
+		int			a0;
+		int			a1;
+		int			i;
+		int64		numer;
+		int64		denom;
+		int64		q;
+		int64		u;
+
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0, and compute b */
+		a0 = 0;
+		a1 = 0;
+		b = 1;
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			b *= NBASE;
+			a1 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a1 += arg->digits[src_idx];
+		}
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			a0 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a0 += arg->digits[src_idx];
+		}
 
-	/* Round to requested precision */
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		numer = r_int64 * b + a1;
+		denom = 2 * s_int64;
+		q = numer / denom;
+		u = numer - q * denom;
+
+		/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+		s_int64 = s_int64 * b + q;
+		r_int64 = u * b + a0 - q * q;
+
+		if (r_int64 < 0)
+		{
+			/* s is too large by 1; set r += s, s--, r += s */
+			r_int64 += s_int64;
+			s_int64--;
+			r_int64 += s_int64;
+		}
+
+		Assert(src_idx == src_ndigits); /* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * On platforms with 128-bit integer support, we can further delay the
+	 * need to use numeric variables.
+	 */
+#ifdef HAVE_INT128
+	if (step >= 0)
+	{
+		int128		s_int128;
+		int128		r_int128;
+
+		s_int128 = s_int64;
+		r_int128 = r_int64;
+
+		/*
+		 * Iterations with src_ndigits <= 16:
+		 *
+		 * The result fits in an int128 (even though the input doesn't) so we
+		 * use int128 variables to avoid more expensive numeric computations.
+		 */
+		while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
+		{
+			int64		b;
+			int64		a0;
+			int64		a1;
+			int64		i;
+			int128		numer;
+			int128		denom;
+			int128		q;
+			int128		u;
+
+			blen = (src_ndigits - src_idx) / 2;
+
+			/* Extract a1 and a0, and compute b */
+			a0 = 0;
+			a1 = 0;
+			b = 1;
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				b *= NBASE;
+				a1 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a1 += arg->digits[src_idx];
+			}
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				a0 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a0 += arg->digits[src_idx];
+			}
+
+			/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+			numer = r_int128 * b + a1;
+			denom = 2 * s_int128;
+			q = numer / denom;
+			u = numer - q * denom;
+
+			/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+			s_int128 = s_int128 * b + q;
+			r_int128 = u * b + a0 - q * q;
+
+			if (r_int128 < 0)
+			{
+				/* s is too large by 1; set r += s, s--, r += s */
+				r_int128 += s_int128;
+				s_int128--;
+				r_int128 += s_int128;
+			}
+
+			Assert(src_idx == src_ndigits); /* All input digits consumed */
+			step--;
+		}
+
+		/*
+		 * All remaining iterations require numeric variables.  Convert the
+		 * integer values to NumericVar and continue.  Note that in the final
+		 * iteration we don't need the remainder, so we can save a few cycles
+		 * there by not fully computing it.
+		 */
+		int128_to_numericvar(s_int128, &s_var);
+		if (step >= 0)
+			int128_to_numericvar(r_int128, &r_var);
+	}
+	else
+	{
+		int64_to_numericvar(s_int64, &s_var);
+		/* step < 0, so we certainly don't need r */
+	}
+#else							/* !HAVE_INT128 */
+	int64_to_numericvar(s_int64, &s_var);
+	if (step >= 0)
+		int64_to_numericvar(r_int64, &r_var);
+#endif							/* HAVE_INT128 */
+
+	/*
+	 * The remaining iterations with src_ndigits > 8 (or 16, if have int128)
+	 * use numeric variables.
+	 */
+	while (step >= 0)
+	{
+		int			tmp_len;
+
+		src_ndigits = ndigits[step];
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0 */
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a1_var, tmp_len);
+			memcpy(a1_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a1_var.weight = blen - 1;
+			a1_var.sign = NUMERIC_POS;
+			a1_var.dscale = 0;
+			strip_var(&a1_var);
+		}
+		else
+		{
+			zero_var(&a1_var);
+			a1_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a0_var, tmp_len);
+			memcpy(a0_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a0_var.weight = blen - 1;
+			a0_var.sign = NUMERIC_POS;
+			a0_var.dscale = 0;
+			strip_var(&a0_var);
+		}
+		else
+		{
+			zero_var(&a0_var);
+			a0_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		set_var_from_var(&r_var, &q_var);
+		q_var.weight += blen;
+		add_var(&q_var, &a1_var, &q_var);
+		add_var(&s_var, &s_var, &u_var);
+		div_mod_var(&q_var, &u_var, &q_var, &u_var);
+
+		/* Compute s = s*b + q */
+		s_var.weight += blen;
+		add_var(&s_var, &q_var, &s_var);
+
+		/*
+		 * Compute r = u*b + a0 - q^2.
+		 *
+		 * In the final iteration, we don't actually need r; we just need to
+		 * know whether it is negative, so that we know whether to adjust s.
+		 * So instead of the final subtraction we can just compare.
+		 */
+		u_var.weight += blen;
+		add_var(&u_var, &a0_var, &u_var);
+		mul_var(&q_var, &q_var, &q_var, 0);
+
+		if (step > 0)
+		{
+			/* Need r for later iterations */
+			sub_var(&u_var, &q_var, &r_var);
+			if (r_var.sign == NUMERIC_NEG)
+			{
+				/* s is too large by 1; set r += s, s--, r += s */
+				add_var(&r_var, &s_var, &r_var);
+				sub_var(&s_var, &const_one, &s_var);
+				add_var(&r_var, &s_var, &r_var);
+			}
+		}
+		else
+		{
+			/* Don't need r anymore, except to test if s is too large by 1 */
+			if (cmp_var(&u_var, &q_var) < 0)
+				sub_var(&s_var, &const_one, &s_var);
+		}
+
+		Assert(src_idx == src_ndigits); /* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * Construct the final result, rounding it to the requested precision.
+	 */
+	set_var_from_var(&s_var, result);
+	result->weight = res_weight;
+	result->sign = NUMERIC_POS;
+
+	/* Round to target rscale (and set result->dscale) */
 	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+
+	free_var(&s_var);
+	free_var(&r_var);
+	free_var(&a0_var);
+	free_var(&a1_var);
+	free_var(&q_var);
+	free_var(&u_var);
 }
 
 
@@ -8530,12 +8984,18 @@ ln_var(const NumericVar *arg, NumericVar *result, int rscale)
 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
 	 * rscale as we work so that we keep this many significant digits at each
 	 * step (plus a few more for good measure).
+	 *
+	 * Note that we allow local_rscale < 0 during this input reduction
+	 * process, which implies rounding before the decimal point.  sqrt_var()
+	 * explicitly supports this, and it significantly reduces the work
+	 * required to reduce very large inputs to the required range.  Once the
+	 * input reduction is complete, x.weight will be 0 and its display scale
+	 * will be non-negative again.
 	 */
 	nsqrt = 0;
 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
@@ -8543,7 +9003,6 @@ ln_var(const NumericVar *arg, NumericVar *result, int rscale)
 	while (cmp_var(&x, &const_one_point_one) >= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
index 23a4c6d..c7fe63d 100644
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -1580,6 +1580,57 @@ select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;
 (1 row)
 
 --
+-- Test some corner cases for square root
+--
+select sqrt(1.000000000000003::numeric);
+       sqrt        
+-------------------
+ 1.000000000000001
+(1 row)
+
+select sqrt(1.000000000000004::numeric);
+       sqrt        
+-------------------
+ 1.000000000000002
+(1 row)
+
+select sqrt(96627521408608.56340355805::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248648
+(1 row)
+
+select sqrt(96627521408608.56340355806::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248649
+(1 row)
+
+select sqrt(515549506212297735.073688290367::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921184
+(1 row)
+
+select sqrt(515549506212297735.073688290368::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921185
+(1 row)
+
+select sqrt(8015491789940783531003294973900306::numeric);
+       sqrt        
+-------------------
+ 89529278953540017
+(1 row)
+
+select sqrt(8015491789940783531003294973900307::numeric);
+       sqrt        
+-------------------
+ 89529278953540018
+(1 row)
+
+--
 -- Test code path for raising to integer powers
 --
 select 10.0 ^ -2147483648 as rounds_to_zero;
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
index c5c8d76..41475a9 100644
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -883,6 +883,19 @@ select div(12345678901234567890, 123);
 select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;
 
 --
+-- Test some corner cases for square root
+--
+
+select sqrt(1.000000000000003::numeric);
+select sqrt(1.000000000000004::numeric);
+select sqrt(96627521408608.56340355805::numeric);
+select sqrt(96627521408608.56340355806::numeric);
+select sqrt(515549506212297735.073688290367::numeric);
+select sqrt(515549506212297735.073688290368::numeric);
+select sqrt(8015491789940783531003294973900306::numeric);
+select sqrt(8015491789940783531003294973900307::numeric);
+
+--
 -- Test code path for raising to integer powers
 --
 
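The last two test cases sit exactly on a rounding boundary. With s = 89529278953540017, the first input equals s*(s+1) = s^2 + s, which is just below (s + 1/2)^2 = s^2 + s + 1/4, so the correctly rounded root is s. Adding 1 puts the input just above that point, so it must round up to s + 1. More generally, if s = isqrt(n) and r = n - s^2, the nearest-integer square root is s + 1 exactly when r > s, and ties cannot occur for integer n. The following is a minimal sketch of that rule. It is independent of the patch's NumericVar code and assumes a compiler that provides unsigned __int128 (a GCC/Clang extension). The helper names are hypothetical.

```c
#include <assert.h>

typedef unsigned __int128 u128;	/* GCC/Clang extension (assumption) */

/* Parse a string of decimal digits into a u128 (no sign, no overflow check). */
static u128
parse_u128(const char *str)
{
	u128		v = 0;

	for (; *str != '\0'; str++)
	{
		assert(*str >= '0' && *str <= '9');
		v = v * 10 + (u128) (*str - '0');
	}
	return v;
}

/*
 * Truncated integer square root (largest s with s*s <= n) by Newton's
 * method, iterating x -> (x + n/x) / 2 downwards from x = n.
 * Assumes n < 2^128 - 1 so that the first x + n/x cannot overflow.
 */
static u128
isqrt_u128(u128 n)
{
	u128		x;
	u128		y;

	if (n < 2)
		return n;
	x = n;
	y = (x + n / x) / 2;
	while (y < x)
	{
		x = y;
		y = (x + n / x) / 2;
	}
	return x;
}

/* Square root of n rounded to the nearest integer: s + 1 iff n - s^2 > s. */
static u128
round_sqrt_u128(u128 n)
{
	u128		s = isqrt_u128(n);
	u128		r = n - s * s;

	return r > s ? s + 1 : s;
}
```

The same rule covers the example from the first message in this thread: 10^32 + 10^16 = s^2 + s with s = 10^16, so its square root must round down to 10000000000000000.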
#9Dean Rasheed
dean.a.rasheed@gmail.com
In reply to: Tom Lane (#8)
1 attachment(s)
Re: Some improvements to numeric sqrt() and ln()

On Sun, 22 Mar 2020 at 22:16, Tom Lane <tgl@sss.pgh.pa.us> wrote:

> With resolutions of the XXX items, I think this'd be committable.

Thanks for looking at this!

Here is an updated patch with the following updates based on your comments:

* Now uses integer arithmetic to compute res_weight and res_ndigits,
instead of floor() and ceil().

* New comment giving a more detailed explanation of how blen is
chosen, and why it must sometimes examine the first digit of the input
and reduce blen by 1 (which can occur at any step, as shown in the
example given).

* New comment giving a proof that the number of steps required is
guaranteed to be less than 32.

* New comment explaining why the initial integer square root using
Newton's method is guaranteed to converge. I couldn't find a formal
reference for this, but there's a Wikipedia article on it
(https://en.wikipedia.org/wiki/Integer_square_root), and I think it's a
well-known result in the field.
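The first item replaces floor() and ceil() with two identities for C's truncating integer division. For w < 0, floor(w/2) = -((-w - 1)/2 + 1). For x = rscale + 1, ceil(x/DEC_DIGITS) is (x + DEC_DIGITS - 1)/DEC_DIGITS when x >= 0, and -((-x)/DEC_DIGITS) when x < 0. The following is a small standalone check of the two formulas as they appear in the v4 patch, assuming DEC_DIGITS = 4. The names sqrt_res_weight(), sqrt_res_ndigits(), floor_div() and ceil_div() are hypothetical and do not exist in numeric.c.

```c
#include <assert.h>

#define DEC_DIGITS 4			/* assumption: default build, NBASE = 10000 */

/* Reference floor(a / b) and ceil(a / b) for b > 0, built on truncating '/'. */
static int
floor_div(int a, int b)
{
	int			q = a / b;

	if (q * b > a)
		q--;					/* truncation rounded a negative quotient up */
	return q;
}

static int
ceil_div(int a, int b)
{
	int			q = a / b;

	if (q * b < a)
		q++;					/* truncation rounded a positive quotient down */
	return q;
}

/* res_weight = floor(weight / 2), written as in the v4 patch */
static int
sqrt_res_weight(int weight)
{
	if (weight >= 0)
		return weight / 2;
	return -((-weight - 1) / 2 + 1);
}

/* res_ndigits = max(res_weight + 1 + ceil((rscale + 1) / DEC_DIGITS), 1) */
static int
sqrt_res_ndigits(int res_weight, int rscale)
{
	int			res_ndigits;

	if (rscale + 1 >= 0)
		res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS) / DEC_DIGITS;
	else
		res_ndigits = res_weight + 1 - (-rscale - 1) / DEC_DIGITS;
	return res_ndigits > 1 ? res_ndigits : 1;
}
```

Negative rscale values are included on purpose, since the patch now lets ln_var() call sqrt_var() with rounding before the decimal point.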

Regards,
Dean

Attachments:

numeric-sqrt-v4.patch (text/x-patch; charset=US-ASCII)
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 10229eb..9986132
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -393,16 +393,6 @@ static const NumericVar const_ten =
 #endif
 
 #if DEC_DIGITS == 4
-static const NumericDigit const_zero_point_five_data[1] = {5000};
-#elif DEC_DIGITS == 2
-static const NumericDigit const_zero_point_five_data[1] = {50};
-#elif DEC_DIGITS == 1
-static const NumericDigit const_zero_point_five_data[1] = {5};
-#endif
-static const NumericVar const_zero_point_five =
-{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};
-
-#if DEC_DIGITS == 4
 static const NumericDigit const_zero_point_nine_data[1] = {9000};
 #elif DEC_DIGITS == 2
 static const NumericDigit const_zero_point_nine_data[1] = {90};
@@ -518,6 +508,8 @@ static void div_var_fast(const NumericVa
 static int	select_div_scale(const NumericVar *var1, const NumericVar *var2);
 static void mod_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result);
+static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
+						NumericVar *quot, NumericVar *rem);
 static void ceil_var(const NumericVar *var, NumericVar *result);
 static void floor_var(const NumericVar *var, NumericVar *result);
 
@@ -7712,6 +7704,7 @@ div_var_fast(const NumericVar *var1, con
 			 NumericVar *result, int rscale, bool round)
 {
 	int			div_ndigits;
+	int			load_ndigits;
 	int			res_sign;
 	int			res_weight;
 	int		   *div;
@@ -7766,9 +7759,6 @@ div_var_fast(const NumericVar *var1, con
 	div_ndigits += DIV_GUARD_DIGITS;
 	if (div_ndigits < DIV_GUARD_DIGITS)
 		div_ndigits = DIV_GUARD_DIGITS;
-	/* Must be at least var1ndigits, too, to simplify data-loading loop */
-	if (div_ndigits < var1ndigits)
-		div_ndigits = var1ndigits;
 
 	/*
 	 * We do the arithmetic in an array "div[]" of signed int's.  Since
@@ -7781,9 +7771,16 @@ div_var_fast(const NumericVar *var1, con
 	 * (approximate) quotient digit and stores it into div[], removing one
 	 * position of dividend space.  A final pass of carry propagation takes
 	 * care of any mistaken quotient digits.
+	 *
+	 * Note that div[] doesn't necessarily contain all of the digits from the
+	 * dividend --- the desired precision plus guard digits might be less than
+	 * the dividend's precision.  This happens, for example, in the square
+	 * root algorithm, where we typically divide a 2N-digit number by an
+	 * N-digit number, and only require a result with N digits of precision.
 	 */
 	div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
-	for (i = 0; i < var1ndigits; i++)
+	load_ndigits = Min(div_ndigits, var1ndigits);
+	for (i = 0; i < load_ndigits; i++)
 		div[i + 1] = var1digits[i];
 
 	/*
@@ -7844,9 +7841,15 @@ div_var_fast(const NumericVar *var1, con
 			maxdiv += Abs(qdigit);
 			if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
 			{
-				/* Yes, do it */
+				/*
+				 * Yes, do it.  Note that if var2ndigits is much smaller than
+				 * div_ndigits, we can save a significant amount of effort
+				 * here by noting that we only need to normalise those div[]
+				 * entries touched where prior iterations subtracted multiples
+				 * of the divisor.
+				 */
 				carry = 0;
-				for (i = div_ndigits; i > qi; i--)
+				for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
 				{
 					newdig = div[i] + carry;
 					if (newdig < 0)
@@ -8095,6 +8098,76 @@ mod_var(const NumericVar *var1, const Nu
 
 
 /*
+ * div_mod_var() -
+ *
+ *	Calculate the truncated integer quotient and numeric remainder of two
+ *	numeric variables.  The remainder is precise to var2's dscale.
+ */
+static void
+div_mod_var(const NumericVar *var1, const NumericVar *var2,
+			NumericVar *quot, NumericVar *rem)
+{
+	NumericVar	q;
+	NumericVar	r;
+
+	init_var(&q);
+	init_var(&r);
+
+	/*
+	 * Use div_var_fast() to get an initial estimate for the integer quotient.
+	 * This might be inaccurate (per the warning in div_var_fast's comments),
+	 * but we can correct it below.
+	 */
+	div_var_fast(var1, var2, &q, 0, false);
+
+	/* Compute initial estimate of remainder using the quotient estimate. */
+	mul_var(var2, &q, &r, var2->dscale);
+	sub_var(var1, &r, &r);
+
+	/*
+	 * Adjust the results if necessary --- the remainder should have the same
+	 * sign as var1, and its absolute value should be less than the absolute
+	 * value of var2.
+	 */
+	while (r.ndigits != 0 && r.sign != var1->sign)
+	{
+		/* The absolute value of the quotient is too large */
+		if (var1->sign == var2->sign)
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+		else
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+	}
+
+	while (cmp_abs(&r, var2) >= 0)
+	{
+		/* The absolute value of the quotient is too small */
+		if (var1->sign == var2->sign)
+		{
+			add_var(&q, &const_one, &q);
+			sub_var(&r, var2, &r);
+		}
+		else
+		{
+			sub_var(&q, &const_one, &q);
+			add_var(&r, var2, &r);
+		}
+	}
+
+	set_var_from_var(&q, quot);
+	set_var_from_var(&r, rem);
+
+	free_var(&q);
+	free_var(&r);
+}
+
+
+/*
  * ceil_var() -
  *
  *	Return the smallest integer greater than or equal to the argument
@@ -8213,18 +8286,30 @@ gcd_var(const NumericVar *var1, const Nu
 /*
  * sqrt_var() -
  *
- *	Compute the square root of x using Newton's algorithm
+ *	Compute the square root of x using the Karatsuba Square Root algorithm.
+ *	NOTE: we allow rscale < 0 here, implying rounding before the decimal
+ *	point.
  */
 static void
 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
 {
-	NumericVar	tmp_arg;
-	NumericVar	tmp_val;
-	NumericVar	last_val;
-	int			local_rscale;
 	int			stat;
-
-	local_rscale = rscale + 8;
+	int			res_weight;
+	int			res_ndigits;
+	int			src_ndigits;
+	int			step;
+	int			ndigits[32];
+	int			blen;
+	int64		arg_int64;
+	int			src_idx;
+	int64		s_int64;
+	int64		r_int64;
+	NumericVar	s_var;
+	NumericVar	r_var;
+	NumericVar	a0_var;
+	NumericVar	a1_var;
+	NumericVar	q_var;
+	NumericVar	u_var;
 
 	stat = cmp_var(arg, &const_zero);
 	if (stat == 0)
@@ -8243,43 +8328,440 @@ sqrt_var(const NumericVar *arg, NumericV
 				(errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
 				 errmsg("cannot take square root of a negative number")));
 
-	init_var(&tmp_arg);
-	init_var(&tmp_val);
-	init_var(&last_val);
+	init_var(&s_var);
+	init_var(&r_var);
+	init_var(&a0_var);
+	init_var(&a1_var);
+	init_var(&q_var);
+	init_var(&u_var);
 
-	/* Copy arg in case it is the same var as result */
-	set_var_from_var(arg, &tmp_arg);
+	/*
+	 * The result weight is half the input weight, rounded towards minus
+	 * infinity --- res_weight = floor(arg->weight / 2).
+	 */
+	if (arg->weight >= 0)
+		res_weight = arg->weight / 2;
+	else
+		res_weight = -((-arg->weight - 1) / 2 + 1);
 
 	/*
-	 * Initialize the result to the first guess
+	 * Number of NBASE digits to compute.  To ensure correct rounding, compute
+	 * at least 1 extra decimal digit.  We explicitly allow rscale to be
+	 * negative here, but must always compute at least 1 NBASE digit.  Thus
+	 * res_ndigits = res_weight + 1 + ceil((rscale + 1) / DEC_DIGITS) or 1.
 	 */
-	alloc_var(result, 1);
-	result->digits[0] = tmp_arg.digits[0] / 2;
-	if (result->digits[0] == 0)
-		result->digits[0] = 1;
-	result->weight = tmp_arg.weight / 2;
-	result->sign = NUMERIC_POS;
+	if (rscale + 1 >= 0)
+		res_ndigits = res_weight + 1 + (rscale + DEC_DIGITS) / DEC_DIGITS;
+	else
+		res_ndigits = res_weight + 1 - (-rscale - 1) / DEC_DIGITS;
+	res_ndigits = Max(res_ndigits, 1);
 
-	set_var_from_var(result, &last_val);
+	/*
+	 * Number of source NBASE digits logically required to produce a result
+	 * with this precision --- every digit before the decimal point, plus 2
+	 * for each result digit after the decimal point (or minus 2 for each
+	 * result digit we round before the decimal point).
+	 */
+	src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
+	src_ndigits = Max(src_ndigits, 1);
 
-	for (;;)
+	/* ----------
+	 * From this point on, we treat the input and the result as integers and
+	 * compute the integer square root and remainder using the Karatsuba
+	 * Square Root algorithm, which may be written recursively as follows:
+	 *
+	 *	SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
+	 *		[ for some base b, and coefficients a0,a1,a2,a3 chosen so that
+	 *		  0 <= a0,a1,a2 < b and a3 >= b/4 ]
+	 *		Let (s,r) = SqrtRem(a3*b + a2)
+	 *		Let (q,u) = DivRem(r*b + a1, 2*s)
+	 *		Let s = s*b + q
+	 *		Let r = u*b + a0 - q^2
+	 *		If r < 0 Then
+	 *			Let r = r + s
+	 *			Let s = s - 1
+	 *			Let r = r + s
+	 *		Return (s,r)
+	 *
+	 * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
+	 * RR-3805, November 1999.  At the time of writing this was available
+	 * on the net at <https://hal.inria.fr/inria-00072854>.
+	 *
+	 * The way to read the assumption "n = a3*b^3 + a2*b^2 + a1*b + a0" is
+	 * "choose a base b such that n requires at least four base-b digits to
+	 * express; then those digits are a3,a2,a1,a0, with a3 possibly larger
+	 * than b".  For optimal performance, b should have approximately a
+	 * quarter the number of digits in the input, so that the outer square
+	 * root computes roughly twice as many digits as the inner one.  For
+	 * simplicity, we choose b = NBASE^blen, an integer power of NBASE.
+	 *
+	 * We implement the algorithm iteratively rather than recursively, to
+	 * allow the working variables to be reused.  With this approach, each
+	 * digit of the input is read precisely once --- src_idx tracks the number
+	 * of input digits used so far.
+	 *
+	 * The array ndigits[] holds the number of NBASE digits of the input that
+	 * will have been used at the end of each iteration, which roughly doubles
+	 * each time.  Note that the array elements are stored in reverse order,
+	 * so if the final iteration requires src_ndigits = 37 input digits, the
+	 * array will contain [37,19,11,7,5,3], and we would start by computing
+	 * the square root of the 3 most significant NBASE digits.
+	 *
+	 * In each iteration, we choose blen to be the largest integer for which
+	 * the input number has a3 >= b/4, when written in the form above.  In
+	 * general, this means blen = src_ndigits / 4 (truncated), but if
+	 * src_ndigits is a multiple of 4, that might lead to the coefficient a3
+	 * being less than b/4 (if the first input digit is less than NBASE/4), in
+	 * which case we choose blen = src_ndigits / 4 - 1.  The number of digits
+	 * in the inner square root is then src_ndigits - 2*blen.  So, for
+	 * example, if we have src_ndigits = 26 initially, the array ndigits[]
+	 * will be either [26,14,8,4] or [26,14,8,6,4], depending on the size of
+	 * the first input digit.
+	 *
+	 * Additionally, we can put an upper bound on the number of steps required
+	 * as follows --- suppose that the number of source digits is an n-bit
+	 * number in the range [2^(n-1), 2^n-1], then blen will be in the range
+	 * [2^(n-3)-1, 2^(n-2)-1] and the number of digits in the inner square
+	 * root will be in the range [2^(n-2), 2^(n-1)+1].  In the next step, blen
+	 * will be in the range [2^(n-4)-1, 2^(n-3)] and the number of digits in
+	 * the next inner square root will be in the range [2^(n-3), 2^(n-2)+1].
+	 * This pattern repeats, and in the worst case the array ndigits[] will
+	 * contain [2^n-1, 2^(n-1)+1, 2^(n-2)+1, ... 9, 5, 3], and the computation
+	 * will require n steps.  Therefore, since all digit array sizes are
+	 * signed 32-bit integers, the number of steps required is guaranteed to
+	 * be less than 32.
+	 * ----------
+	 */
+	step = 0;
+	while ((ndigits[step] = src_ndigits) > 4)
 	{
-		div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
+		/* Choose b so that a3 >= b/4, as described above */
+		blen = src_ndigits / 4;
+		if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
+			blen--;
 
-		add_var(result, &tmp_val, result);
-		mul_var(result, &const_zero_point_five, result, local_rscale);
+		/* Number of digits in the next step (inner square root) */
+		src_ndigits -= 2 * blen;
+		step++;
+	}
 
-		if (cmp_var(&last_val, result) == 0)
-			break;
-		set_var_from_var(result, &last_val);
+	/*
+	 * First iteration (innermost square root and remainder):
+	 *
+	 * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
+	 * has at most 9 decimal digits, so estimate it using double precision
+	 * arithmetic, which will in fact almost certainly return the correct
+	 * result with no further correction required.
+	 */
+	arg_int64 = arg->digits[0];
+	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
+	{
+		arg_int64 *= NBASE;
+		if (src_idx < arg->ndigits)
+			arg_int64 += arg->digits[src_idx];
 	}
 
-	free_var(&last_val);
-	free_var(&tmp_val);
-	free_var(&tmp_arg);
+	s_int64 = (int64) sqrt((double) arg_int64);
+	r_int64 = arg_int64 - s_int64 * s_int64;
 
-	/* Round to requested precision */
+	/*
+	 * Use Newton's method to correct the result, if necessary.
+	 *
+	 * This uses integer division with truncation to compute the truncated
+	 * integer square root by iterating using the formula x -> (x + n/x) / 2.
+	 * This is known to converge to isqrt(n), unless n+1 is a perfect square.
+	 * If n+1 is a perfect square, the sequence will oscillate between the two
+	 * values isqrt(n) and isqrt(n)+1, so we can be assured of convergence by
+	 * checking the remainder.
+	 */
+	while (r_int64 < 0 || r_int64 > 2 * s_int64)
+	{
+		s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
+		r_int64 = arg_int64 - s_int64 * s_int64;
+	}
+
+	/*
+	 * Iterations with src_ndigits <= 8:
+	 *
+	 * The next 1 or 2 iterations compute larger (outer) square roots with
+	 * src_ndigits <= 8, so the result still fits in an int64 (even though the
+	 * input no longer does) and we can continue to compute using int64
+	 * variables to avoid more expensive numeric computations.
+	 *
+	 * It is fairly easy to see that there is no risk of the intermediate
+	 * values below overflowing 64-bit integers.  In the worst case, the
+	 * previous iteration will have computed a 3-digit square root (of a
+	 * 6-digit input less than NBASE^6 / 4), so at the start of this
+	 * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
+	 * less than 10^12.  In this case, blen will be 1, so numer will be less
+	 * than 10^17, and denom will be less than 10^12 (and hence u will also be
+	 * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
+	 * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
+	 * in 64-bit integers.
+	 */
+	step--;
+	while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
+	{
+		int			b;
+		int			a0;
+		int			a1;
+		int			i;
+		int64		numer;
+		int64		denom;
+		int64		q;
+		int64		u;
+
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0, and compute b */
+		a0 = 0;
+		a1 = 0;
+		b = 1;
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			b *= NBASE;
+			a1 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a1 += arg->digits[src_idx];
+		}
+
+		for (i = 0; i < blen; i++, src_idx++)
+		{
+			a0 *= NBASE;
+			if (src_idx < arg->ndigits)
+				a0 += arg->digits[src_idx];
+		}
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		numer = r_int64 * b + a1;
+		denom = 2 * s_int64;
+		q = numer / denom;
+		u = numer - q * denom;
+
+		/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+		s_int64 = s_int64 * b + q;
+		r_int64 = u * b + a0 - q * q;
+
+		if (r_int64 < 0)
+		{
+			/* s is too large by 1; set r += s, s--, r += s */
+			r_int64 += s_int64;
+			s_int64--;
+			r_int64 += s_int64;
+		}
+
+		Assert(src_idx == src_ndigits); /* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * On platforms with 128-bit integer support, we can further delay the
+	 * need to use numeric variables.
+	 */
+#ifdef HAVE_INT128
+	if (step >= 0)
+	{
+		int128		s_int128;
+		int128		r_int128;
+
+		s_int128 = s_int64;
+		r_int128 = r_int64;
+
+		/*
+		 * Iterations with src_ndigits <= 16:
+		 *
+		 * The result fits in an int128 (even though the input doesn't) so we
+		 * use int128 variables to avoid more expensive numeric computations.
+		 */
+		while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
+		{
+			int64		b;
+			int64		a0;
+			int64		a1;
+			int64		i;
+			int128		numer;
+			int128		denom;
+			int128		q;
+			int128		u;
+
+			blen = (src_ndigits - src_idx) / 2;
+
+			/* Extract a1 and a0, and compute b */
+			a0 = 0;
+			a1 = 0;
+			b = 1;
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				b *= NBASE;
+				a1 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a1 += arg->digits[src_idx];
+			}
+
+			for (i = 0; i < blen; i++, src_idx++)
+			{
+				a0 *= NBASE;
+				if (src_idx < arg->ndigits)
+					a0 += arg->digits[src_idx];
+			}
+
+			/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+			numer = r_int128 * b + a1;
+			denom = 2 * s_int128;
+			q = numer / denom;
+			u = numer - q * denom;
+
+			/* Compute s = s*b + q and r = u*b + a0 - q^2 */
+			s_int128 = s_int128 * b + q;
+			r_int128 = u * b + a0 - q * q;
+
+			if (r_int128 < 0)
+			{
+				/* s is too large by 1; set r += s, s--, r += s */
+				r_int128 += s_int128;
+				s_int128--;
+				r_int128 += s_int128;
+			}
+
+			Assert(src_idx == src_ndigits); /* All input digits consumed */
+			step--;
+		}
+
+		/*
+		 * All remaining iterations require numeric variables.  Convert the
+		 * integer values to NumericVar and continue.  Note that in the final
+		 * iteration we don't need the remainder, so we can save a few cycles
+		 * there by not fully computing it.
+		 */
+		int128_to_numericvar(s_int128, &s_var);
+		if (step >= 0)
+			int128_to_numericvar(r_int128, &r_var);
+	}
+	else
+	{
+		int64_to_numericvar(s_int64, &s_var);
+		/* step < 0, so we certainly don't need r */
+	}
+#else							/* !HAVE_INT128 */
+	int64_to_numericvar(s_int64, &s_var);
+	if (step >= 0)
+		int64_to_numericvar(r_int64, &r_var);
+#endif							/* HAVE_INT128 */
+
+	/*
+	 * The remaining iterations with src_ndigits > 8 (or > 16, if we have int128)
+	 * use numeric variables.
+	 */
+	while (step >= 0)
+	{
+		int			tmp_len;
+
+		src_ndigits = ndigits[step];
+		blen = (src_ndigits - src_idx) / 2;
+
+		/* Extract a1 and a0 */
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a1_var, tmp_len);
+			memcpy(a1_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a1_var.weight = blen - 1;
+			a1_var.sign = NUMERIC_POS;
+			a1_var.dscale = 0;
+			strip_var(&a1_var);
+		}
+		else
+		{
+			zero_var(&a1_var);
+			a1_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		if (src_idx < arg->ndigits)
+		{
+			tmp_len = Min(blen, arg->ndigits - src_idx);
+			alloc_var(&a0_var, tmp_len);
+			memcpy(a0_var.digits, arg->digits + src_idx,
+				   tmp_len * sizeof(NumericDigit));
+			a0_var.weight = blen - 1;
+			a0_var.sign = NUMERIC_POS;
+			a0_var.dscale = 0;
+			strip_var(&a0_var);
+		}
+		else
+		{
+			zero_var(&a0_var);
+			a0_var.dscale = 0;
+		}
+		src_idx += blen;
+
+		/* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+		set_var_from_var(&r_var, &q_var);
+		q_var.weight += blen;
+		add_var(&q_var, &a1_var, &q_var);
+		add_var(&s_var, &s_var, &u_var);
+		div_mod_var(&q_var, &u_var, &q_var, &u_var);
+
+		/* Compute s = s*b + q */
+		s_var.weight += blen;
+		add_var(&s_var, &q_var, &s_var);
+
+		/*
+		 * Compute r = u*b + a0 - q^2.
+		 *
+		 * In the final iteration, we don't actually need r; we just need to
+		 * know whether it is negative, so that we know whether to adjust s.
+		 * So instead of the final subtraction we can just compare.
+		 */
+		u_var.weight += blen;
+		add_var(&u_var, &a0_var, &u_var);
+		mul_var(&q_var, &q_var, &q_var, 0);
+
+		if (step > 0)
+		{
+			/* Need r for later iterations */
+			sub_var(&u_var, &q_var, &r_var);
+			if (r_var.sign == NUMERIC_NEG)
+			{
+				/* s is too large by 1; set r += s, s--, r += s */
+				add_var(&r_var, &s_var, &r_var);
+				sub_var(&s_var, &const_one, &s_var);
+				add_var(&r_var, &s_var, &r_var);
+			}
+		}
+		else
+		{
+			/* Don't need r anymore, except to test if s is too large by 1 */
+			if (cmp_var(&u_var, &q_var) < 0)
+				sub_var(&s_var, &const_one, &s_var);
+		}
+
+		Assert(src_idx == src_ndigits); /* All input digits consumed */
+		step--;
+	}
+
+	/*
+	 * Construct the final result, rounding it to the requested precision.
+	 */
+	set_var_from_var(&s_var, result);
+	result->weight = res_weight;
+	result->sign = NUMERIC_POS;
+
+	/* Round to target rscale (and set result->dscale) */
 	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+
+	free_var(&s_var);
+	free_var(&r_var);
+	free_var(&a0_var);
+	free_var(&a1_var);
+	free_var(&q_var);
+	free_var(&u_var);
 }
 
 
@@ -8530,12 +9012,18 @@ ln_var(const NumericVar *arg, NumericVar
 	 * Each sqrt() will roughly halve the weight of x, so adjust the local
 	 * rscale as we work so that we keep this many significant digits at each
 	 * step (plus a few more for good measure).
+	 *
+	 * Note that we allow local_rscale < 0 during this input reduction
+	 * process, which implies rounding before the decimal point.  sqrt_var()
+	 * explicitly supports this, and it significantly reduces the work
+	 * required to reduce very large inputs to the required range.  Once the
+	 * input reduction is complete, x.weight will be 0 and its display scale
+	 * will be non-negative again.
 	 */
 	nsqrt = 0;
 	while (cmp_var(&x, &const_zero_point_nine) <= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
@@ -8543,7 +9031,6 @@ ln_var(const NumericVar *arg, NumericVar
 	while (cmp_var(&x, &const_one_point_one) >= 0)
 	{
 		local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-		local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
 		sqrt_var(&x, &x, local_rscale);
 		mul_var(&fact, &const_two, &fact, 0);
 		nsqrt++;
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
new file mode 100644
index 23a4c6d..c7fe63d
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -1580,6 +1580,57 @@ select div(12345678901234567890, 123) *
 (1 row)
 
 --
+-- Test some corner cases for square root
+--
+select sqrt(1.000000000000003::numeric);
+       sqrt        
+-------------------
+ 1.000000000000001
+(1 row)
+
+select sqrt(1.000000000000004::numeric);
+       sqrt        
+-------------------
+ 1.000000000000002
+(1 row)
+
+select sqrt(96627521408608.56340355805::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248648
+(1 row)
+
+select sqrt(96627521408608.56340355806::numeric);
+        sqrt         
+---------------------
+ 9829929.87811248649
+(1 row)
+
+select sqrt(515549506212297735.073688290367::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921184
+(1 row)
+
+select sqrt(515549506212297735.073688290368::numeric);
+          sqrt          
+------------------------
+ 718017761.766585921185
+(1 row)
+
+select sqrt(8015491789940783531003294973900306::numeric);
+       sqrt        
+-------------------
+ 89529278953540017
+(1 row)
+
+select sqrt(8015491789940783531003294973900307::numeric);
+       sqrt        
+-------------------
+ 89529278953540018
+(1 row)
+
+--
 -- Test code path for raising to integer powers
 --
 select 10.0 ^ -2147483648 as rounds_to_zero;
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
new file mode 100644
index c5c8d76..41475a9
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -883,6 +883,19 @@ select div(12345678901234567890, 123);
 select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;
 
 --
+-- Test some corner cases for square root
+--
+
+select sqrt(1.000000000000003::numeric);
+select sqrt(1.000000000000004::numeric);
+select sqrt(96627521408608.56340355805::numeric);
+select sqrt(96627521408608.56340355806::numeric);
+select sqrt(515549506212297735.073688290367::numeric);
+select sqrt(515549506212297735.073688290368::numeric);
+select sqrt(8015491789940783531003294973900306::numeric);
+select sqrt(8015491789940783531003294973900307::numeric);
+
+--
 -- Test code path for raising to integer powers
 --
 
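The iterative scheme described in the new sqrt_var() comments can be tried out in isolation. The following is a minimal sketch, not the patch's code. It handles only inputs of at most 8 NBASE digits, so that (as in the patch's int64 code path) every intermediate value fits in an int64. It also assumes NBASE = 10000 and takes the digits as a plain int array instead of a NumericVar. To avoid a libm dependency, it seeds the innermost root with a power of two rather than (int64) sqrt((double) n). The name sqrtrem_small() is hypothetical.

```c
#include <assert.h>
#include <stdint.h>

#define NBASE 10000				/* assumption: DEC_DIGITS = 4 */

/*
 * Hypothetical helper: integer square root s and remainder r = n - s*s of
 * the number n whose base-NBASE digits are digits[0..ndigits-1], most
 * significant first, for 1 <= ndigits <= 8 and digits[0] != 0.
 */
static void
sqrtrem_small(const int *digits, int ndigits, int64_t *s_out, int64_t *r_out)
{
	int			nd[32];			/* ndigits[] schedule, outermost first */
	int			step = 0;
	int			src_ndigits = ndigits;
	int			src_idx;
	int64_t		n;
	int64_t		s;
	int64_t		r;

	assert(ndigits >= 1 && ndigits <= 8 && digits[0] != 0);

	/* Choose blen so that a3 >= b/4, shrinking the input at each step */
	while ((nd[step] = src_ndigits) > 4)
	{
		int			blen = src_ndigits / 4;

		if (blen * 4 == src_ndigits && digits[0] < NBASE / 4)
			blen--;
		src_ndigits -= 2 * blen;
		step++;
	}

	/* Innermost root: the leading src_ndigits (<= 4) digits fit in an int64 */
	n = digits[0];
	for (src_idx = 1; src_idx < src_ndigits; src_idx++)
		n = n * NBASE + digits[src_idx];

	/* Seed above sqrt(n), then Newton x -> (x + n/x) / 2 down to isqrt(n) */
	s = 1;
	while (s * s <= n)
		s *= 2;
	r = n - s * s;
	while (r < 0 || r > 2 * s)
	{
		s = (s + n / s) / 2;
		r = n - s * s;
	}

	/* Outer iterations: each consumes 2*blen more digits (a1, then a0) */
	for (step--; step >= 0; step--)
	{
		int			blen = (nd[step] - src_idx) / 2;
		int64_t		b = 1;
		int64_t		a1 = 0;
		int64_t		a0 = 0;
		int64_t		q;
		int64_t		u;

		for (int i = 0; i < blen; i++, src_idx++)
		{
			b *= NBASE;
			a1 = a1 * NBASE + digits[src_idx];
		}
		for (int i = 0; i < blen; i++, src_idx++)
			a0 = a0 * NBASE + digits[src_idx];

		/* (q,u) = DivRem(r*b + a1, 2*s) */
		q = (r * b + a1) / (2 * s);
		u = (r * b + a1) - q * (2 * s);

		/* s = s*b + q;  r = u*b + a0 - q^2 */
		s = s * b + q;
		r = u * b + a0 - q * q;
		if (r < 0)
		{
			/* s is too large by 1: r += s, s--, r += s */
			r += s;
			s--;
			r += s;
		}
		assert(src_idx == nd[step]);	/* all digits so far consumed */
	}

	*s_out = s;
	*r_out = r;
}
```

For example, the digits {1, 0, 0, 0, 0, 0, 0, 0} (that is, 10^28) have a leading digit below NBASE/4, so the schedule is [8,6,4] rather than [8,4], and the function returns s = 10^14, r = 0.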
#10Tom Lane
tgl@sss.pgh.pa.us
In reply to: Dean Rasheed (#9)
Re: Some improvements to numeric sqrt() and ln()

Dean Rasheed <dean.a.rasheed@gmail.com> writes:

> Here is an updated patch with the following updates based on your comments:

This resolves all my concerns. I've marked it RFC in the CF app.

regards, tom lane