From 2461ef9b3711e5e72368deb18ad92a5989be4f19 Mon Sep 17 00:00:00 2001
From: Kyotaro Horiguchi <horiguchi.kyotaro@lab.ntt.co.jp>
Date: Wed, 5 Apr 2017 20:01:01 +0900
Subject: [PATCH] Fix overflow during interval comparison.

Intervals are compared using the TimeOffset result of
interval_cmp_value, but that type is so narrow that it overflows quite
easily.  This patch widens the output of the function to 128 bits.  For
platforms without 128-bit arithmetic, a pair of 64-bit integers is used
instead.
---
 src/backend/utils/adt/timestamp.c      | 153 +++++++++++++++++++++++++++++++--
 src/include/datatype/timestamp.h       |  13 +++
 src/test/regress/expected/interval.out |  36 ++++++++
 src/test/regress/sql/interval.sql      |  17 ++++
 4 files changed, 210 insertions(+), 9 deletions(-)

diff --git a/src/backend/utils/adt/timestamp.c b/src/backend/utils/adt/timestamp.c
index 4be1999..dfef4d2 100644
--- a/src/backend/utils/adt/timestamp.c
+++ b/src/backend/utils/adt/timestamp.c
@@ -2289,25 +2289,148 @@ timestamptz_cmp_timestamp(PG_FUNCTION_ARGS)
 /*
  *		interval_relop	- is interval1 relop interval2
  */
-static inline TimeOffset
+#ifdef HAVE_INT128
+static inline LinearInterval
 interval_cmp_value(const Interval *interval)
 {
-	TimeOffset	span;
+	LinearInterval	span;
 
-	span = interval->time;
-	span += interval->month * INT64CONST(30) * USECS_PER_DAY;
-	span += interval->day * INT64CONST(24) * USECS_PER_HOUR;
+	span = (int128)interval->time;
+	span += (int128)interval->month * INT64CONST(30) * USECS_PER_DAY;
+	span += (int128)interval->day * INT64CONST(24) * USECS_PER_HOUR;
 
 	return span;
 }
+#else
+/*
+ * INT64_AU32 extracts the most significant 32 bits of int64 as int64, while
+ * INT64_AL32 extracts the least significant 32 bits as uint64. The input
+ * value is restored by (INT64_AU32 << 32) + INT64_AL32 for both positive and
+ * negative cases.
+ */
+#define INT64_AU32(i64) ((i64) >> 32)
+#define INT64_AL32(i64) ((uint64)(i64) & UINT64CONST(0xFFFFFFFF))
+
+/*
+ * Add an unsigned 64-bit value "v" into the LinearInterval variable "li".
+ * The value is added to the .lo part, and a carry is propagated into the
+ * .hi part if both inputs have their high bit set, or if just one of them
+ * does but the new .lo part doesn't.  The .lo part is unsigned; the casts
+ * to signed are just a cheap way to test the high bit.  Both arguments are
+ * evaluated more than once, so they must not have side effects.
+ */
+#define LINEARINTERVAL_ADD_UINT64(li, v) \
+do { \
+	bool	msb_1 = (int64) (li).lo < 0; \
+	bool	msb_2 = (int64) (v) < 0; \
+	(li).lo += (uint64) (v); \
+	if ((msb_1 && msb_2) ||	(msb_1 != msb_2 && (int64) (li).lo >= 0)) \
+		(li).hi++; \
+} while (0)
+
+static inline LinearInterval
+interval_cmp_value(const Interval *interval)
+{
+	LinearInterval	span = {0, 0};
+	int64	dayfraction;
+	int64	days;
+
+	/* INT64_AU32 must use arithmetic right shift */
+	StaticAssertStmt(((int64) -1 >> 1) == (int64) -1,
+					 "arithmetic right shift is needed");
+
+	/*
+	 * days cannot overflow. dayfraction should be non-negative since it is
+	 * added later using LINEARINTERVAL_ADD_UINT64.
+	 */
+	dayfraction = interval->time % USECS_PER_DAY;
+	days = interval->time / USECS_PER_DAY;
+	if (dayfraction < 0)
+	{
+		days--;
+		dayfraction += USECS_PER_DAY;
+	}
+	days += interval->month * INT64CONST(30);
+	days += interval->day;
+
+	/*
+	 * At this point days holds the whole-day count (months counted as 30
+	 * days each) and dayfraction is in the range [0, USECS_PER_DAY).
+	 */
+
+	/*----------
+	 * Form the 128-bit product x * y using 64-bit arithmetic.
+	 * Considering each 64-bit input as having 32-bit high and low parts,
+	 * we can compute
+	 *
+	 *	 x * y = ((x.hi << 32) + x.lo) * ((y.hi << 32) + y.lo)
+	 *		   = (x.hi * y.hi) << 64 +
+	 *			 (x.hi * y.lo) << 32 +
+	 *			 (x.lo * y.hi) << 32 +
+	 *			 x.lo * y.lo
+	 *
+	 * Each individual product is of 32-bit terms so it won't overflow when
+	 * computed in 64-bit arithmetic.  Then we just have to shift it to the
+	 * correct position while adding into the 128-bit result.  We must also
+	 * keep in mind that the "lo" parts must be treated as unsigned.
+	 *----------
+	 */
+
+	/* No need to work hard if product must be zero */
+	if (days != 0)
+	{
+		int64	x_u32 = INT64_AU32(days);
+		uint64	x_l32 = INT64_AL32(days);
+		int64	y_u32 = INT64_AU32(USECS_PER_DAY);
+		uint64	y_l32 = INT64_AL32(USECS_PER_DAY);
+		int64	tmp;
+
+		/* the first term */
+		span.hi = x_u32 * y_u32;
+
+		/* the second term: sign-extend it only if x is negative */
+		tmp = x_u32 * y_l32;
+		span.hi += INT64_AU32(tmp);
+		LINEARINTERVAL_ADD_UINT64(span, INT64_AL32(tmp) << 32);
+
+		/* the third term: sign-extend it only if y is negative */
+		tmp = x_l32 * y_u32;
+		span.hi += INT64_AU32(tmp);
+		LINEARINTERVAL_ADD_UINT64(span, INT64_AL32(tmp) << 32);
+
+		/* the fourth term: always unsigned */
+		LINEARINTERVAL_ADD_UINT64(span, x_l32 * y_l32);
+	}
+
+	/* add dayfraction */
+	LINEARINTERVAL_ADD_UINT64(span, dayfraction);
+
+	return span;
+}
+#endif
 
 static int
 interval_cmp_internal(Interval *interval1, Interval *interval2)
 {
-	TimeOffset	span1 = interval_cmp_value(interval1);
-	TimeOffset	span2 = interval_cmp_value(interval2);
+	LinearInterval	span1 = interval_cmp_value(interval1);
+	LinearInterval	span2 = interval_cmp_value(interval2);
 
+#ifdef HAVE_INT128
 	return ((span1 < span2) ? -1 : (span1 > span2) ? 1 : 0);
+#else
+	if (span1.hi < span2.hi)
+		return -1;
+	if (span1.hi > span2.hi)
+		return 1;
+
+	/*
+	 * The hi parts are equal.  (hi, lo) is a two's-complement value, so
+	 * the unsigned lo parts order the same way whatever the sign of hi.
+	 */
+	if (span1.lo < span2.lo)
+		return -1;
+	return (span1.lo > span2.lo) ? 1 : 0;
+#endif
 }
 
 Datum
@@ -2384,9 +2507,21 @@ Datum
 interval_hash(PG_FUNCTION_ARGS)
 {
 	Interval   *interval = PG_GETARG_INTERVAL_P(0);
-	TimeOffset	span = interval_cmp_value(interval);
+	LinearInterval	span = interval_cmp_value(interval);
+	int64 span64;
+
+	/*
+	 * Take the least significant 64 bits for hashing. It is not necessary
+	 * to fold the most significant half of the bits into the rest for this
+	 * usage; in most cases that half is merely sign extension anyway.
+	 */
+#ifdef HAVE_INT128
+	span64 = span & 0xffffffffffffffff;
+#else
+	span64 = span.lo;
+#endif
 
-	return DirectFunctionCall1(hashint8, Int64GetDatumFast(span));
+	return DirectFunctionCall1(hashint8, Int64GetDatumFast(span64));
 }
 
 /* overlaps_timestamp() --- implements the SQL OVERLAPS operator.
diff --git a/src/include/datatype/timestamp.h b/src/include/datatype/timestamp.h
index 16e1b4c..fd24987 100644
--- a/src/include/datatype/timestamp.h
+++ b/src/include/datatype/timestamp.h
@@ -48,6 +48,19 @@ typedef struct
 	int32		month;			/* months and years, after time for alignment */
 } Interval;
 
+#ifdef HAVE_INT128
+typedef int128	LinearInterval;
+#else
+/*
+ * Alternative definition of LinearInterval for environments without
+ * int128 arithmetic. See interval_cmp_value for details.
+ */
+typedef struct
+{
+	int64	hi;	/* holds significant 64 bits including a sign bit */
+	uint64	lo; /* holds the lower 64 bits without sign */
+} LinearInterval;
+#endif
 
 /* Limits on the "precision" option (typmod) for these data types */
 #define MAX_TIMESTAMP_PRECISION 6
diff --git a/src/test/regress/expected/interval.out b/src/test/regress/expected/interval.out
index 946c97a..d4e6a00 100644
--- a/src/test/regress/expected/interval.out
+++ b/src/test/regress/expected/interval.out
@@ -207,6 +207,42 @@ SELECT '' AS fortyfive, r1.*, r2.*
            | 34 years        | 6 years
 (45 rows)
 
+-- intervals out of 64bit linear value in microseconds
+CREATE TABLE INTERVAL_TBL_OF (f1 interval);
+INSERT INTO INTERVAL_TBL_OF (f1) VALUES ('2147483647 days 2147483647 months');
+INSERT INTO INTERVAL_TBL_OF (f1) VALUES ('1 year');
+INSERT INTO INTERVAL_TBL_OF (f1) VALUES ('-2147483648 days -2147483648 months');
+SELECT '' AS "64bits", r1.*, r2.*
+   FROM INTERVAL_TBL_OF r1, INTERVAL_TBL_OF r2
+   WHERE r1.f1 > r2.f1
+   ORDER BY r1.f1, r2.f1;
+ 64bits |                   f1                   |                    f1                     
+--------+----------------------------------------+-------------------------------------------
+        | 1 year                                 | -178956970 years -8 mons -2147483648 days
+        | 178956970 years 7 mons 2147483647 days | -178956970 years -8 mons -2147483648 days
+        | 178956970 years 7 mons 2147483647 days | 1 year
+(3 rows)
+
+CREATE INDEX ON INTERVAL_TBL_OF USING btree (f1);
+SET enable_seqscan to false;
+EXPLAIN SELECT '' AS "64bits_2", f1
+  FROM INTERVAL_TBL_OF r1 ORDER BY f1;
+                                               QUERY PLAN                                               
+--------------------------------------------------------------------------------------------------------
+ Index Only Scan using interval_tbl_of_f1_idx on interval_tbl_of r1  (cost=0.13..12.18 rows=3 width=48)
+(1 row)
+
+SELECT '' AS "64bits_2", f1
+  FROM INTERVAL_TBL_OF r1 ORDER BY f1;
+ 64bits_2 |                    f1                     
+----------+-------------------------------------------
+          | -178956970 years -8 mons -2147483648 days
+          | 1 year
+          | 178956970 years 7 mons 2147483647 days
+(3 rows)
+
+SET enable_seqscan to default;
+DROP TABLE INTERVAL_TBL_OF;
 -- Test multiplication and division with intervals.
 -- Floating point arithmetic rounding errors can lead to unexpected results,
 -- though the code attempts to do the right thing and round up to days and
diff --git a/src/test/regress/sql/interval.sql b/src/test/regress/sql/interval.sql
index cff9ada..aa21673 100644
--- a/src/test/regress/sql/interval.sql
+++ b/src/test/regress/sql/interval.sql
@@ -59,6 +59,23 @@ SELECT '' AS fortyfive, r1.*, r2.*
    WHERE r1.f1 > r2.f1
    ORDER BY r1.f1, r2.f1;
 
+-- intervals out of 64bit linear value in microseconds
+CREATE TABLE INTERVAL_TBL_OF (f1 interval);
+INSERT INTO INTERVAL_TBL_OF (f1) VALUES ('2147483647 days 2147483647 months');
+INSERT INTO INTERVAL_TBL_OF (f1) VALUES ('1 year');
+INSERT INTO INTERVAL_TBL_OF (f1) VALUES ('-2147483648 days -2147483648 months');
+SELECT '' AS "64bits", r1.*, r2.*
+   FROM INTERVAL_TBL_OF r1, INTERVAL_TBL_OF r2
+   WHERE r1.f1 > r2.f1
+   ORDER BY r1.f1, r2.f1;
+CREATE INDEX ON INTERVAL_TBL_OF USING btree (f1);
+SET enable_seqscan to false;
+EXPLAIN SELECT '' AS "64bits_2", f1
+  FROM INTERVAL_TBL_OF r1 ORDER BY f1;
+SELECT '' AS "64bits_2", f1
+  FROM INTERVAL_TBL_OF r1 ORDER BY f1;
+SET enable_seqscan to default;
+DROP TABLE INTERVAL_TBL_OF;
 
 -- Test multiplication and division with intervals.
 -- Floating point arithmetic rounding errors can lead to unexpected results,
-- 
2.9.2

