From 4564bbf46e40834368975e0cea528d6077437576 Mon Sep 17 00:00:00 2001
From: John Naylor <john.naylor@postgresql.org>
Date: Fri, 17 Oct 2025 09:57:43 +0700
Subject: [PATCH v4 1/4] Use radix sort when datum1 is an integer type

For now this only works for signed and unsigned integer
types with the usual comparison semantics -- the same
types for which we previously had separate qsort
specializations.

Temporary GUC wip_radix_sort for testing
---
 src/backend/utils/misc/guc_parameters.dat |   7 +
 src/backend/utils/sort/tuplesort.c        | 399 ++++++++++++++++++++--
 src/include/utils/guc.h                   |   1 +
 src/include/utils/tuplesort.h             |   1 +
 4 files changed, 389 insertions(+), 19 deletions(-)

diff --git a/src/backend/utils/misc/guc_parameters.dat b/src/backend/utils/misc/guc_parameters.dat
index 1128167c025..c9167eb4bb4 100644
--- a/src/backend/utils/misc/guc_parameters.dat
+++ b/src/backend/utils/misc/guc_parameters.dat
@@ -3469,6 +3469,13 @@
   max => 'INT_MAX',
 },
 
+{ name => 'wip_radix_sort', type => 'bool', context => 'PGC_USERSET', group => 'DEVELOPER_OPTIONS',
+  short_desc => 'Test radix sort for debugging.',
+  flags => 'GUC_NOT_IN_SAMPLE',
+  variable => 'wip_radix_sort',
+  boot_val => 'true',
+},
+
 { name => 'work_mem', type => 'int', context => 'PGC_USERSET', group => 'RESOURCES_MEM',
   short_desc => 'Sets the maximum memory to be used for query workspaces.',
   long_desc => 'This much memory can be used by each internal sort operation and hash table before switching to temporary disk files.',
diff --git a/src/backend/utils/sort/tuplesort.c b/src/backend/utils/sort/tuplesort.c
index 5d4411dc33f..c2b9625ae88 100644
--- a/src/backend/utils/sort/tuplesort.c
+++ b/src/backend/utils/sort/tuplesort.c
@@ -104,6 +104,7 @@
 #include "commands/tablespace.h"
 #include "miscadmin.h"
 #include "pg_trace.h"
+#include "port/pg_bitutils.h"
 #include "storage/shmem.h"
 #include "utils/guc.h"
 #include "utils/memutils.h"
@@ -122,6 +123,7 @@
 
 /* GUC variables */
 bool		trace_sort = false;
+bool		wip_radix_sort = true;	/* FIXME not for commit */
 
 #ifdef DEBUG_BOUNDED_SORT
 bool		optimize_bounded_sort = true;
@@ -615,6 +617,25 @@ qsort_tuple_int32_compare(SortTuple *a, SortTuple *b, Tuplesortstate *state)
 #define ST_DEFINE
 #include "lib/sort_template.h"
 
+
+#ifdef USE_ASSERT_CHECKING
+/* WIP: for now prefer test coverage of radix sort in Assert builds. */
+#define QSORT_THRESHOLD 0
+#else
+/* WIP: low because qsort_tuple() is slow -- we could raise this with a new specialization */
+#define QSORT_THRESHOLD 40
+#endif
+
+typedef struct RadixPartitionInfo
+{
+	union
+	{
+		size_t		count;
+		size_t		offset;
+	};
+	size_t		next_offset;
+}			RadixPartitionInfo;
+
 /*
  *		tuplesort_begin_xxx
  *
@@ -2663,10 +2684,334 @@ sort_bounded_heap(Tuplesortstate *state)
 	state->boundUsed = true;
 }
 
+static inline uint8_t
+extract_byte(Datum key, int level)
+{
+	return (key >> (((SIZEOF_DATUM - 1) - level) * 8)) & 0xFF;
+}
+
+/*
+ * Normalize a datum so that plain unsigned comparison of the result
+ * yields the correct sort order, honoring ASC/DESC as well.
+ */
+static inline Datum
+normalize_datum(Datum orig, SortSupport ssup)
+{
+	Datum		norm_datum1;
+
+	if (ssup->comparator == ssup_datum_signed_cmp)
+	{
+		norm_datum1 = orig + ((uint64) PG_INT64_MAX) + 1;
+	}
+	else if (ssup->comparator == ssup_datum_int32_cmp)
+	{
+		/*
+		 * First truncate to uint32. Technically, we don't need to do this,
+		 * but it forces the upper bytes to remain the same regardless of
+		 * sign.
+		 */
+		uint32		u32 = DatumGetUInt32(orig) + ((uint32) PG_INT32_MAX) + 1;
+
+		norm_datum1 = UInt32GetDatum(u32);
+	}
+	else
+	{
+		Assert(ssup->comparator == ssup_datum_unsigned_cmp);
+		norm_datum1 = orig;
+	}
+
+	if (ssup->ssup_reverse)
+		norm_datum1 = ~norm_datum1;
+
+	return norm_datum1;
+}
+
+/*
+ * Based on implementation in https://github.com/skarupke/ska_sort (Boost license),
+ * with the following noncosmetic change:
+ *  - count sorted partitions in every pass, rather than maintaining a
+ *    list of unsorted partitions
+ */
+static void
+radix_sort_tuple(SortTuple *begin, size_t n_elems, int level, Tuplesortstate *state)
+{
+	RadixPartitionInfo partitions[256] = {0};
+	uint8_t		remaining_partitions[256] = {0};
+	size_t		total = 0;
+	int			num_partitions = 0;
+	int			num_remaining;
+	SortSupport ssup = &state->base.sortKeys[0];
+	size_t		start_offset = 0;
+	SortTuple  *partition_begin = begin;
+
+	/* count key chunks */
+	for (SortTuple *tup = begin; tup < begin + n_elems; tup++)
+	{
+		uint8		current_byte;
+
+		/* extract the byte for this level from the normalized datum */
+		current_byte = extract_byte(normalize_datum(tup->datum1, ssup),
+									level);
+
+		/* save it for the permutation step */
+		tup->current_byte = current_byte;
+
+		partitions[current_byte].count++;
+	}
+
+	/* compute partition offsets */
+	for (int i = 0; i < 256; i++)
+	{
+		size_t		count = partitions[i].count;
+
+		if (count)
+		{
+			partitions[i].offset = total;
+			total += count;
+			remaining_partitions[num_partitions] = i;
+			num_partitions++;
+		}
+		partitions[i].next_offset = total;
+	}
+
+	num_remaining = num_partitions;
+
+	/*
+	 * Permute tuples to correct partition. If we started with one partition,
+	 * there is nothing to do. If a permutation from a previous iteration
+	 * results in a single partition that hasn't been marked as sorted, we
+	 * know it's actually sorted.
+	 */
+	while (num_remaining > 1)
+	{
+		/*
+		 * We can only exit the loop when all partitions are sorted, so must
+		 * reset every iteration
+		 */
+		num_remaining = num_partitions;
+
+		for (int i = 0; i < num_partitions; i++)
+		{
+			uint8		idx = remaining_partitions[i];
+
+			RadixPartitionInfo part = partitions[idx];
+
+			for (SortTuple *st = begin + part.offset;
+				 st < begin + part.next_offset;
+				 st++)
+			{
+				size_t		offset = partitions[st->current_byte].offset++;
+				SortTuple	tmp;
+
+				/* swap current tuple with destination position */
+				Assert(offset < n_elems);
+				tmp = *st;
+				*st = begin[offset];
+				begin[offset] = tmp;
+			};
+
+			if (part.offset == part.next_offset)
+			{
+				/* partition is sorted */
+				num_remaining--;
+			}
+		}
+	}
+
+	/* recurse */
+	for (uint8_t *rp = remaining_partitions;
+		 rp < remaining_partitions + num_partitions;
+		 rp++)
+	{
+		size_t		end_offset = partitions[*rp].next_offset;
+		SortTuple  *partition_end = begin + end_offset;
+		ptrdiff_t	num_elements = end_offset - start_offset;
+
+		if (num_elements > 1)
+		{
+			if (level < SIZEOF_DATUM - 1)
+			{
+				if (num_elements < QSORT_THRESHOLD)
+				{
+					qsort_tuple(partition_begin,
+								num_elements,
+								state->base.comparetup,
+								state);
+				}
+				else
+				{
+					radix_sort_tuple(partition_begin,
+									 num_elements,
+									 level + 1,
+									 state);
+				}
+			}
+			else if (state->base.onlyKey == NULL)
+			{
+				/*
+				 * We've finished radix sort on all bytes of the pass-by-value
+				 * datum (possibly abbreviated), now qsort with the tiebreak
+				 * comparator.
+				 */
+				qsort_tuple(partition_begin,
+							num_elements,
+							state->base.comparetup_tiebreak,
+							state);
+			}
+		}
+
+		start_offset = end_offset;
+		partition_begin = partition_end;
+	}
+}
+
 /*
- * Sort all memtuples using specialized qsort() routines.
+ * Partition tuples by NULL-ness of the first sort key, then dispatch
+ * each partition to either radix sort or qsort.
+ */
+static void
+sort_byvalue_datum(Tuplesortstate *state)
+{
+	SortSupportData ssup = state->base.sortKeys[0];
+
+	bool		nulls_first = ssup.ssup_nulls_first;
+	SortTuple  *data = state->memtuples;
+	SortTuple  *null_start;
+	SortTuple  *not_null_start;
+	size_t		d1 = 0,
+				d2,
+				null_count,
+				not_null_count;
+
+	/*
+	 * First, partition by NULL-ness of the leading sort key, since we can
+	 * only radix sort on NOT NULL pass-by-value datums.
+	 */
+
+	/*
+	 * Find the first NOT NULL tuple if NULLS FIRST, or first NULL element if
+	 * NULLS LAST. This is a quick check for the common case where all tuples
+	 * are NOT NULL in the first sort key.
+	 */
+	while (d1 < state->memtupcount && data[d1].isnull1 == nulls_first)
+		d1++;
+
+	/*
+	 * If we have more than one tuple left after the quick check, partition
+	 * the remainder using branchless cyclic permutation, based on
+	 * https://orlp.net/blog/branchless-lomuto-partitioning/
+	 */
+	if (d1 < state->memtupcount - 1)
+	{
+		size_t		j = d1;
+		SortTuple	save = data[d1];	/* create gap at front */
+
+		/* Move the gap rightward; d1 advances only past tuples that belong in the left partition */
+		while (j < state->memtupcount - 1)
+		{
+			data[j] = data[d1];
+			j += 1;
+			data[d1] = data[j];
+			d1 += (data[d1].isnull1 == nulls_first);
+		}
+
+		data[j] = data[d1];
+		data[d1] = save;
+		d1 += (data[d1].isnull1 == nulls_first);
+	}
+
+	/* d1 is now the number of elements in the left partition */
+	d2 = state->memtupcount - d1;
+
+	/* set pointers and counts for each partition */
+	if (nulls_first)
+	{
+		null_start = state->memtuples;
+		null_count = d1;
+		not_null_start = state->memtuples + d1;
+		not_null_count = d2;
+	}
+	else
+	{
+		not_null_start = state->memtuples;
+		not_null_count = d1;
+		null_start = state->memtuples + d1;
+		null_count = d2;
+	}
+
+	for (SortTuple *tup = null_start;
+		 tup < null_start + null_count;
+		 tup++)
+		Assert(tup->isnull1 == true);
+	for (SortTuple *tup = not_null_start;
+		 tup < not_null_start + not_null_count;
+		 tup++)
+		Assert(tup->isnull1 == false);
+
+	/*
+	 * Sort the NULL partition using tiebreak comparator, if necessary. XXX
+	 * this will repeat the comparison on isnull1 for abbreviated keys.
+	 */
+	if (state->base.onlyKey == NULL && null_count > 1)
+	{
+		qsort_tuple(null_start,
+					null_count,
+					state->base.comparetup_tiebreak,
+					state);
+	}
+
+	/*
+	 * Sort the NOT NULL partition, using radix sort if large enough,
+	 * otherwise fall back to quicksort.
+	 */
+	if (not_null_count > 1)
+	{
+		if (not_null_count < QSORT_THRESHOLD)
+		{
+			/*
+			 * WIP: We could compute the common prefix, save the following
+			 * byte in current_byte, and use a new qsort specialization for
+			 * that. Same for the diversion to qsort while recursing during
+			 * radix sort.
+			 */
+			qsort_tuple(not_null_start,
+						not_null_count,
+						state->base.comparetup,
+						state);
+		}
+		else
+		{
+			radix_sort_tuple(not_null_start,
+							 not_null_count,
+							 0,
+							 state);
+		}
+	}
+}
+
+/* Verify sort using standard comparator. */
+static void
+verify_sorted_memtuples(Tuplesortstate *state)
+{
+#ifdef USE_ASSERT_CHECKING
+	for (SortTuple *tup = state->memtuples + 1;
+		 tup < state->memtuples + state->memtupcount;
+		 tup++)
+	{
+#if 0
+		Assert(COMPARETUP(state, tup - 1, tup) <= 0);
+#else
+		if (COMPARETUP(state, tup - 1, tup) > 0)
+			elog(ERROR, "SORT FAILED");
+#endif
+	}
+#endif
+}
+
+/*
+ * Sort all memtuples using specialized routines.
  *
- * Quicksort is used for small in-memory sorts, and external sort runs.
+ * Quicksort or radix sort is used for in-memory sorts and for external sort runs.
  */
 static void
 tuplesort_sort_memtuples(Tuplesortstate *state)
@@ -2681,26 +3026,42 @@ tuplesort_sort_memtuples(Tuplesortstate *state)
 		 */
 		if (state->base.haveDatum1 && state->base.sortKeys)
 		{
-			if (state->base.sortKeys[0].comparator == ssup_datum_unsigned_cmp)
-			{
-				qsort_tuple_unsigned(state->memtuples,
-									 state->memtupcount,
-									 state);
-				return;
-			}
-			else if (state->base.sortKeys[0].comparator == ssup_datum_signed_cmp)
+			SortSupportData ssup = state->base.sortKeys[0];
+
+			if (wip_radix_sort)
 			{
-				qsort_tuple_signed(state->memtuples,
-								   state->memtupcount,
-								   state);
-				return;
+				if ((ssup.comparator == ssup_datum_unsigned_cmp ||
+					 ssup.comparator == ssup_datum_signed_cmp ||
+					 ssup.comparator == ssup_datum_int32_cmp))
+				{
+					sort_byvalue_datum(state);
+					verify_sorted_memtuples(state);
+					return;
+				}
 			}
-			else if (state->base.sortKeys[0].comparator == ssup_datum_int32_cmp)
+			else
 			{
-				qsort_tuple_int32(state->memtuples,
-								  state->memtupcount,
-								  state);
-				return;
+				if (state->base.sortKeys[0].comparator == ssup_datum_unsigned_cmp)
+				{
+					qsort_tuple_unsigned(state->memtuples,
+										 state->memtupcount,
+										 state);
+					return;
+				}
+				else if (state->base.sortKeys[0].comparator == ssup_datum_signed_cmp)
+				{
+					qsort_tuple_signed(state->memtuples,
+									   state->memtupcount,
+									   state);
+					return;
+				}
+				else if (state->base.sortKeys[0].comparator == ssup_datum_int32_cmp)
+				{
+					qsort_tuple_int32(state->memtuples,
+									  state->memtupcount,
+									  state);
+					return;
+				}
 			}
 		}
 
diff --git a/src/include/utils/guc.h b/src/include/utils/guc.h
index f21ec37da89..bc6f7fa60f3 100644
--- a/src/include/utils/guc.h
+++ b/src/include/utils/guc.h
@@ -324,6 +324,7 @@ extern PGDLLIMPORT int tcp_user_timeout;
 extern PGDLLIMPORT char *role_string;
 extern PGDLLIMPORT bool in_hot_standby_guc;
 extern PGDLLIMPORT bool trace_sort;
+extern PGDLLIMPORT bool wip_radix_sort;
 
 #ifdef DEBUG_BOUNDED_SORT
 extern PGDLLIMPORT bool optimize_bounded_sort;
diff --git a/src/include/utils/tuplesort.h b/src/include/utils/tuplesort.h
index 0bf55902aa1..e40c6e52f81 100644
--- a/src/include/utils/tuplesort.h
+++ b/src/include/utils/tuplesort.h
@@ -150,6 +150,7 @@ typedef struct
 	void	   *tuple;			/* the tuple itself */
 	Datum		datum1;			/* value of first key column */
 	bool		isnull1;		/* is first key column NULL? */
+	uint8		current_byte;	/* byte of normalized datum1 cached for radix sort */
 	int			srctape;		/* source tape number */
 } SortTuple;
 
-- 
2.51.1

