From 2fccf780a517416283ad739acb6df4cbcf2f2b28 Mon Sep 17 00:00:00 2001
From: Andres Freund <andres@anarazel.de>
Date: Mon, 9 Jan 2023 11:25:43 -0800
Subject: [PATCH v2 3/3] wip: stricter validation for 64bit xids

64bit xids can't represent xids before xid 0, but 32bit xids can. This has
led to a number of bugs, some of them data corrupting. To make it easier to
catch such bugs, disallow 64bit xids where the upper 32 bits are all set, which
is what one gets when naively trying to convert such 32bit xids. Additionally,
explicitly disallow 64bit xids that already aren't allowed because they'd map
to special 32bit xids.

This changes the input routines for 64bit xids to error out in more cases.

Discussion: https://postgr.es/m/20230108002923.cyoser3ttmt63bfn@awork3.anarazel.de
---
 src/include/access/transam.h      | 136 ++++++++++++++++++++++++++++--
 src/backend/utils/adt/xid.c       |   7 ++
 src/backend/utils/adt/xid8funcs.c |  17 ++--
 src/test/regress/expected/xid.out |  44 ++++++++--
 src/test/regress/sql/xid.sql      |  12 ++-
 5 files changed, 192 insertions(+), 24 deletions(-)

diff --git a/src/include/access/transam.h b/src/include/access/transam.h
index f5af6d30556..0d4bf57e63a 100644
--- a/src/include/access/transam.h
+++ b/src/include/access/transam.h
@@ -44,15 +44,6 @@
 #define TransactionIdStore(xid, dest)	(*(dest) = (xid))
 #define StoreInvalidTransactionId(dest) (*(dest) = InvalidTransactionId)
 
-#define EpochFromFullTransactionId(x)	((uint32) ((x).value >> 32))
-#define XidFromFullTransactionId(x)		((uint32) (x).value)
-#define U64FromFullTransactionId(x)		((x).value)
-#define FullTransactionIdEquals(a, b)	((a).value == (b).value)
-#define FullTransactionIdPrecedes(a, b)	((a).value < (b).value)
-#define FullTransactionIdPrecedesOrEquals(a, b) ((a).value <= (b).value)
-#define FullTransactionIdFollows(a, b) ((a).value > (b).value)
-#define FullTransactionIdFollowsOrEquals(a, b) ((a).value >= (b).value)
-#define FullTransactionIdIsValid(x)		TransactionIdIsValid(XidFromFullTransactionId(x))
 #define InvalidFullTransactionId		FullTransactionIdFromEpochAndXid(0, InvalidTransactionId)
 #define FirstNormalFullTransactionId	FullTransactionIdFromEpochAndXid(0, FirstNormalTransactionId)
 #define FullTransactionIdIsNormal(x)	FullTransactionIdFollowsOrEquals(x, FirstNormalFullTransactionId)
@@ -61,12 +52,127 @@
  * A 64 bit value that contains an epoch and a TransactionId.  This is
  * wrapped in a struct to prevent implicit conversion to/from TransactionId.
  * Not all values represent valid normal XIDs.
+ *
+ * 64bit xids at or above FirstNormalTransactionId whose lower 32 bits would
+ * map to a non-normal 32bit xid are not allowed.
+ *
+ * In addition, no xid with all of the upper 32 bits set is allowed to exist;
+ * this makes it easier to catch conversion errors from 32bit xids (which can
+ * point to "before" xid 0).
  */
 typedef struct FullTransactionId
 {
 	uint64		value;
 } FullTransactionId;
 
+#define MAX_FULL_TRANSACTION_ID ((((uint64)~(uint32)0)<<32) - 1)
+
+/*
+ * Is val usable as a FullTransactionId?  Takes a raw uint64 so input routines
+ * can test untrusted values before FullTransactionIdFromU64() asserts on them.
+ */
+static inline bool
+InFullTransactionIdRange(uint64 val)
+{
+	/* nothing above MAX_FULL_TRANSACTION_ID is a valid 64bit xid */
+	if (val > MAX_FULL_TRANSACTION_ID)
+		return false;
+
+	/*
+	 * From FirstNormalTransactionId on, the low 32 bits must be a normal xid.
+	 */
+	if (val >= (uint64) FirstNormalTransactionId &&
+		(uint32) val < FirstNormalTransactionId)
+		return false;
+	return true;
+}
+
+/*
+ * InFullTransactionIdRange() for a FullTransactionId; typically used in Asserts.
+ */
+static inline bool
+FullTransactionIdValidRange(FullTransactionId x)
+{
+	return InFullTransactionIdRange(x.value);
+}
+
+static inline uint64
+U64FromFullTransactionId(FullTransactionId x)
+{
+	Assert(FullTransactionIdValidRange(x));
+
+	return x.value;
+}
+
+static inline TransactionId
+XidFromFullTransactionId(FullTransactionId x)
+{
+	Assert(FullTransactionIdValidRange(x));
+
+	return (TransactionId) x.value;
+}
+
+static inline uint32
+EpochFromFullTransactionId(FullTransactionId x)
+{
+	Assert(FullTransactionIdValidRange(x));
+
+	return (uint32) (x.value >> 32);
+}
+
+static inline bool
+FullTransactionIdEquals(FullTransactionId a, FullTransactionId b)
+{
+	Assert(FullTransactionIdValidRange(a));
+	Assert(FullTransactionIdValidRange(b));
+
+	return a.value == b.value;
+}
+
+static inline bool
+FullTransactionIdPrecedes(FullTransactionId a, FullTransactionId b)
+{
+	Assert(FullTransactionIdValidRange(a));
+	Assert(FullTransactionIdValidRange(b));
+
+	return a.value < b.value;
+}
+
+static inline bool
+FullTransactionIdPrecedesOrEquals(FullTransactionId a, FullTransactionId b)
+{
+	Assert(FullTransactionIdValidRange(a));
+	Assert(FullTransactionIdValidRange(b));
+
+	return a.value <= b.value;
+}
+
+static inline bool
+FullTransactionIdFollows(FullTransactionId a, FullTransactionId b)
+{
+	Assert(FullTransactionIdValidRange(a));
+	Assert(FullTransactionIdValidRange(b));
+
+	return a.value > b.value;
+}
+
+static inline bool
+FullTransactionIdFollowsOrEquals(FullTransactionId a, FullTransactionId b)
+{
+	Assert(FullTransactionIdValidRange(a));
+	Assert(FullTransactionIdValidRange(b));
+
+	return a.value >= b.value;
+}
+
+static inline bool
+FullTransactionIdIsValid(FullTransactionId x)
+{
+	Assert(FullTransactionIdValidRange(x));
+
+	return x.value != (uint64) InvalidTransactionId;
+}
+
 static inline FullTransactionId
 FullTransactionIdFromEpochAndXid(uint32 epoch, TransactionId xid)
 {
@@ -74,6 +180,8 @@ FullTransactionIdFromEpochAndXid(uint32 epoch, TransactionId xid)
 
 	result.value = ((uint64) epoch) << 32 | xid;
 
+	Assert(FullTransactionIdValidRange(result));
+
 	return result;
 }
 
@@ -84,6 +192,8 @@ FullTransactionIdFromU64(uint64 value)
 
 	result.value = value;
 
+	Assert(FullTransactionIdValidRange(result));
+
 	return result;
 }
 
@@ -102,6 +212,8 @@ FullTransactionIdFromU64(uint64 value)
 static inline void
 FullTransactionIdRetreat(FullTransactionId *dest)
 {
+	Assert(FullTransactionIdValidRange(*dest));
+
 	dest->value--;
 
 	/*
@@ -118,6 +230,8 @@ FullTransactionIdRetreat(FullTransactionId *dest)
 	 */
 	while (XidFromFullTransactionId(*dest) < FirstNormalTransactionId)
 		dest->value--;
+
+	Assert(FullTransactionIdValidRange(*dest));
 }
 
 /*
@@ -127,6 +241,8 @@ FullTransactionIdRetreat(FullTransactionId *dest)
 static inline void
 FullTransactionIdAdvance(FullTransactionId *dest)
 {
+	Assert(FullTransactionIdValidRange(*dest));
+
 	dest->value++;
 
 	/* see FullTransactionIdAdvance() */
@@ -135,6 +251,8 @@ FullTransactionIdAdvance(FullTransactionId *dest)
 
 	while (XidFromFullTransactionId(*dest) < FirstNormalTransactionId)
 		dest->value++;
+
+	Assert(FullTransactionIdValidRange(*dest));
 }
 
 /* back up a transaction ID variable, handling wraparound correctly */
diff --git a/src/backend/utils/adt/xid.c b/src/backend/utils/adt/xid.c
index 8ac1679c381..21c0f4c67df 100644
--- a/src/backend/utils/adt/xid.c
+++ b/src/backend/utils/adt/xid.c
@@ -188,6 +188,13 @@ xid8in(PG_FUNCTION_ARGS)
 	uint64		result;
 
 	result = uint64in_subr(str, NULL, "xid8", fcinfo->context);
+
+	if (!InFullTransactionIdRange(result))
+		ereturn(fcinfo->context, 0,
+				(errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
+				 errmsg("value \"%s\" is out of range for type %s",
+						str, "xid8")));
+
 	PG_RETURN_FULLTRANSACTIONID(FullTransactionIdFromU64(result));
 }
 
diff --git a/src/backend/utils/adt/xid8funcs.c b/src/backend/utils/adt/xid8funcs.c
index e616303a292..b8d8d92f793 100644
--- a/src/backend/utils/adt/xid8funcs.c
+++ b/src/backend/utils/adt/xid8funcs.c
@@ -302,15 +302,18 @@ parse_snapshot(const char *str, Node *escontext)
 	const char *str_start = str;
 	char	   *endp;
 	StringInfo	buf;
+	uint64 raw_fxid;
 
-	xmin = FullTransactionIdFromU64(strtou64(str, &endp, 10));
-	if (*endp != ':')
+	raw_fxid = strtou64(str, &endp, 10);
+	if (*endp != ':' || !InFullTransactionIdRange(raw_fxid))
 		goto bad_format;
+	xmin = FullTransactionIdFromU64(raw_fxid);
 	str = endp + 1;
 
-	xmax = FullTransactionIdFromU64(strtou64(str, &endp, 10));
-	if (*endp != ':')
+	raw_fxid = strtou64(str, &endp, 10);
+	if (*endp != ':' || !InFullTransactionIdRange(raw_fxid))
 		goto bad_format;
+	xmax = FullTransactionIdFromU64(raw_fxid);
 	str = endp + 1;
 
 	/* it should look sane */
@@ -326,7 +329,11 @@ parse_snapshot(const char *str, Node *escontext)
 	while (*str != '\0')
 	{
 		/* read next value */
-		val = FullTransactionIdFromU64(strtou64(str, &endp, 10));
+		raw_fxid = strtou64(str, &endp, 10);
+		/* check the range first: FullTransactionIdFromU64() asserts on it */
+		if (!InFullTransactionIdRange(raw_fxid))
+			goto bad_format;
+		val = FullTransactionIdFromU64(raw_fxid);
 		str = endp;
 
 		/* require the input to be in order */
diff --git a/src/test/regress/expected/xid.out b/src/test/regress/expected/xid.out
index e62f7019434..4348eb66bd5 100644
--- a/src/test/regress/expected/xid.out
+++ b/src/test/regress/expected/xid.out
@@ -6,11 +6,11 @@ select '010'::xid,
        '-1'::xid,
 	   '010'::xid8,
 	   '42'::xid8,
-	   '0xffffffffffffffff'::xid8,
-	   '-1'::xid8;
- xid | xid |    xid     |    xid     | xid8 | xid8 |         xid8         |         xid8         
------+-----+------------+------------+------+------+----------------------+----------------------
-   8 |  42 | 4294967295 | 4294967295 |    8 |   42 | 18446744073709551615 | 18446744073709551615
+	   '0xefffffffffffffff'::xid8,
+	   '0'::xid8;
+ xid | xid |    xid     |    xid     | xid8 | xid8 |         xid8         | xid8 
+-----+-----+------------+------------+------+------+----------------------+------
+   8 |  42 | 4294967295 | 4294967295 |    8 |   42 | 17293822569102704639 |    0
 (1 row)
 
 -- garbage values
@@ -30,6 +30,18 @@ select 'asdf'::xid8;
 ERROR:  invalid input syntax for type xid8: "asdf"
 LINE 1: select 'asdf'::xid8;
                ^
+select '-1'::xid8;
+ERROR:  value "-1" is out of range for type xid8
+LINE 1: select '-1'::xid8;
+               ^
+select '0xffffffffffffffff'::xid8;
+ERROR:  value "0xffffffffffffffff" is out of range for type xid8
+LINE 1: select '0xffffffffffffffff'::xid8;
+               ^
+select '0x0000300000000001'::xid8;
+ERROR:  value "0x0000300000000001" is out of range for type xid8
+LINE 1: select '0x0000300000000001'::xid8;
+               ^
 -- Also try it with non-error-throwing API
 SELECT pg_input_is_valid('42', 'xid');
  pg_input_is_valid 
@@ -67,6 +79,24 @@ SELECT pg_input_error_message('0xffffffffffffffffffff', 'xid8');
  value "0xffffffffffffffffffff" is out of range for type xid8
 (1 row)
 
+SELECT pg_input_is_valid('-1', 'xid8');
+ pg_input_is_valid 
+-------------------
+ f
+(1 row)
+
+SELECT pg_input_is_valid('0xffffffffffffffff', 'xid8');
+ pg_input_is_valid 
+-------------------
+ f
+(1 row)
+
+SELECT pg_input_is_valid('0x0000300000000001', 'xid8');
+ pg_input_is_valid 
+-------------------
+ f
+(1 row)
+
 -- equality
 select '1'::xid = '1'::xid;
  ?column? 
@@ -160,11 +190,11 @@ select xid8cmp('1', '2'), xid8cmp('2', '2'), xid8cmp('2', '1');
 
 -- min() and max() for xid8
 create table xid8_t1 (x xid8);
-insert into xid8_t1 values ('0'), ('010'), ('42'), ('0xffffffffffffffff'), ('-1');
+insert into xid8_t1 values ('0'), ('010'), ('42'), ('0xefffffffffffffff');
 select min(x), max(x) from xid8_t1;
  min |         max          
 -----+----------------------
-   0 | 18446744073709551615
+   0 | 17293822569102704639
 (1 row)
 
 -- xid8 has btree and hash opclasses
diff --git a/src/test/regress/sql/xid.sql b/src/test/regress/sql/xid.sql
index b6996588ef6..b2104e4ce20 100644
--- a/src/test/regress/sql/xid.sql
+++ b/src/test/regress/sql/xid.sql
@@ -7,14 +7,17 @@ select '010'::xid,
        '-1'::xid,
 	   '010'::xid8,
 	   '42'::xid8,
-	   '0xffffffffffffffff'::xid8,
-	   '-1'::xid8;
+	   '0xefffffffffffffff'::xid8,
+	   '0'::xid8;
 
 -- garbage values
 select ''::xid;
 select 'asdf'::xid;
 select ''::xid8;
 select 'asdf'::xid8;
+select '-1'::xid8;
+select '0xffffffffffffffff'::xid8;
+select '0x0000300000000001'::xid8;
 
 -- Also try it with non-error-throwing API
 SELECT pg_input_is_valid('42', 'xid');
@@ -23,6 +26,9 @@ SELECT pg_input_error_message('0xffffffffff', 'xid');
 SELECT pg_input_is_valid('42', 'xid8');
 SELECT pg_input_is_valid('asdf', 'xid8');
 SELECT pg_input_error_message('0xffffffffffffffffffff', 'xid8');
+SELECT pg_input_is_valid('-1', 'xid8');
+SELECT pg_input_is_valid('0xffffffffffffffff', 'xid8');
+SELECT pg_input_is_valid('0x0000300000000001', 'xid8');
 
 -- equality
 select '1'::xid = '1'::xid;
@@ -51,7 +57,7 @@ select xid8cmp('1', '2'), xid8cmp('2', '2'), xid8cmp('2', '1');
 
 -- min() and max() for xid8
 create table xid8_t1 (x xid8);
-insert into xid8_t1 values ('0'), ('010'), ('42'), ('0xffffffffffffffff'), ('-1');
+insert into xid8_t1 values ('0'), ('010'), ('42'), ('0xefffffffffffffff');
 select min(x), max(x) from xid8_t1;
 
 -- xid8 has btree and hash opclasses
-- 
2.38.0

