From ad030d722028f90ef2afe6e7d7a696bc1a8ae2eb Mon Sep 17 00:00:00 2001
From: John Naylor <john.naylor@postgresql.org>
Date: Mon, 25 Nov 2024 12:37:33 +0700
Subject: [PATCH v2] Branchless Lomuto partitioning

POC that only works on datatypes that use ssup_datum_signed_cmp.
Dev-only GUC needed because it can't yet handle NULLs or DESC.
---
 src/backend/utils/misc/guc_tables.c       |  12 ++
 src/backend/utils/sort/tuple_partition.h  |  38 +++++
 src/backend/utils/sort/tuple_small_sort.h |  30 ++++
 src/backend/utils/sort/tuplesort.c        | 192 +++++++++++++++++++++-
 src/include/utils/guc.h                   |   1 +
 src/include/utils/tuplesort.h             |   4 +
 6 files changed, 274 insertions(+), 3 deletions(-)
 create mode 100644 src/backend/utils/sort/tuple_partition.h
 create mode 100644 src/backend/utils/sort/tuple_small_sort.h

diff --git a/src/backend/utils/misc/guc_tables.c b/src/backend/utils/misc/guc_tables.c
index 71448bb4fd..b53ebfca09 100644
--- a/src/backend/utils/misc/guc_tables.c
+++ b/src/backend/utils/misc/guc_tables.c
@@ -1720,6 +1720,18 @@ struct config_bool ConfigureNamesBool[] =
 		NULL, NULL, NULL
 	},
 
+	{
+		/* XXX not for commit; developer-only, so not in the sample file */
+		{"debug_branchless_sort", PGC_USERSET, DEVELOPER_OPTIONS,
+			gettext_noop("WIP testing of branchless sort techniques"),
+			NULL,
+			GUC_NOT_IN_SAMPLE | GUC_EXPLAIN
+		},
+		&debug_branchless_sort,
+		false,
+		NULL, NULL, NULL
+	},
+
 #ifdef TRACE_SYNCSCAN
 	/* this is undocumented because not exposed in a standard build */
 	{
diff --git a/src/backend/utils/sort/tuple_partition.h b/src/backend/utils/sort/tuple_partition.h
new file mode 100644
index 0000000000..d84c3ae005
--- /dev/null
+++ b/src/backend/utils/sort/tuple_partition.h
@@ -0,0 +1,38 @@
+/*
+ * Branchless Lomuto partitioning, based on pseudocode from
+ * https://orlp.net/blog/branchless-lomuto-partitioning/
+ * Needs v, n (> 0), pivot_pos and CMP_2WAY(a, b) ("a sorts before b") in
+ * scope; define PARTITION_LEFT to also put pivot-equal elements on the left.
+ * There is deliberately no include guard here.
+ */
+
+size_t i = 0;
+size_t j = 0;
+SortTuple pivot = *pivot_pos;
+
+Assert(n > 0);
+/* set the pivot aside and open a gap at v[0] */
+*pivot_pos = v[0];
+
+while (j < n - 1)
+{
+	v[j] = v[i];
+	j += 1;
+	v[i] = v[j];
+#ifdef PARTITION_LEFT
+	i += !CMP_2WAY(pivot, v[i]);
+#else
+	i += CMP_2WAY(v[i], pivot);
+#endif
+}
+
+v[j] = v[i];
+v[i] = pivot;
+#ifdef PARTITION_LEFT
+i += !CMP_2WAY(pivot, v[i]);
+#else
+i += CMP_2WAY(v[i], pivot);
+#endif
+
+/* i = size of left partition; without PARTITION_LEFT, the pivot is v[i] */
+return i;
diff --git a/src/backend/utils/sort/tuple_small_sort.h b/src/backend/utils/sort/tuple_small_sort.h
new file mode 100644
index 0000000000..30737a286c
--- /dev/null
+++ b/src/backend/utils/sort/tuple_small_sort.h
@@ -0,0 +1,30 @@
+/*
+ * Insertion sort of the n tuples starting at begin, for small inputs.
+ * The includer supplies begin, n and CMP_3WAY(a, b), which compares two
+ * SortTuple pointers and returns <0, 0 or >0.
+ *
+ * There is deliberately no include guard here.
+ */
+
+SortTuple  *cur,
+		   *hole;
+
+for (cur = begin + 1; cur < begin + n; cur++)
+{
+	SortTuple	moving;
+
+	/* Nothing to move if already in order relative to its left neighbor */
+	if (CMP_3WAY(cur - 1, cur) <= 0)
+		continue;
+
+	/* Shift larger predecessors right until the insertion point is found */
+	moving = *cur;
+	hole = cur;
+	do
+	{
+		*hole = *(hole - 1);
+		hole--;
+	} while (hole > begin && CMP_3WAY(hole - 1, &moving) > 0);
+
+	*hole = moving;
+}
diff --git a/src/backend/utils/sort/tuplesort.c b/src/backend/utils/sort/tuplesort.c
index 2ef32d53a4..623d278f4b 100644
--- a/src/backend/utils/sort/tuplesort.c
+++ b/src/backend/utils/sort/tuplesort.c
@@ -122,6 +122,7 @@
 
 /* GUC variables */
 bool		trace_sort = false;
+bool		debug_branchless_sort = false;	/* XXX not for commit */
 
 #ifdef DEBUG_BOUNDED_SORT
 bool		optimize_bounded_sort = true;
@@ -619,6 +620,155 @@ qsort_tuple_int32_compare(SortTuple *a, SortTuple *b, Tuplesortstate *state)
 #define ST_DEFINE
 #include "lib/sort_template.h"
 
+
+/*
+ * WIP: compares only the first sort key's datum1 (possibly abbreviated) and
+ * passes isNull = false, since branchless partitioning assumes NULLs have
+ * been handled already.  Expects a SortSupport "ssup" in scope.
+ */
+#define LEADING_DATUM_CMP(a, b) \
+	ApplySortComparator((a)->datum1, false, (b)->datum1, false, ssup)
+
+/* Return whichever of a, b and c has the median leading datum */
+static pg_noinline SortTuple *
+datum_med3(SortTuple *a, SortTuple *b, SortTuple *c, SortSupport ssup)
+{
+	if (LEADING_DATUM_CMP(a, b) < 0)
+		return LEADING_DATUM_CMP(b, c) < 0 ? b :
+			(LEADING_DATUM_CMP(a, c) < 0 ? c : a);
+	return LEADING_DATUM_CMP(b, c) > 0 ? b :
+		(LEADING_DATUM_CMP(a, c) < 0 ? a : c);
+}
+
+/* Specializations for ssup_datum_signed_cmp; the partitions are ASC-only */
+#define CMP_2WAY(a, b) (DatumGetInt64((a).datum1) < DatumGetInt64((b).datum1))
+#define CMP_3WAY(a,b) qsort_tuple_signed_compare(a,b,state)
+static size_t
+part_right_datum_signed_asc(SortTuple *v, size_t n, SortTuple *pivot_pos)
+{
+#include "tuple_partition.h"
+}
+
+static size_t
+part_left_datum_signed_asc(SortTuple *v, size_t n, SortTuple *pivot_pos)
+{
+#define PARTITION_LEFT
+#include "tuple_partition.h"
+#undef PARTITION_LEFT
+}
+
+static inline void
+small_sort_datum_signed(SortTuple *begin, size_t n, Tuplesortstate *state)
+{
+#include "tuple_small_sort.h"
+}
+#undef CMP_2WAY
+#undef CMP_3WAY
+
+/*
+ * Quicksort on datum1 using the branchless partition functions set in
+ * state->base.  ancestor_pivot is NULL or the pivot placed just before
+ * data; it may join the final sorts, as partitioning compared only datum1.
+ */
+static void
+qsort_tuple_datum(SortTuple *data, size_t n, Tuplesortstate *state, SortTuple *ancestor_pivot)
+{
+	SortTuple  *a = data,
+			   *pl,
+			   *pm,
+			   *pn;
+	size_t		n_left_part;
+	size_t		n_right_part;
+	SortSupportData *ssup = state->base.sortKeys;
+
+loop:
+	CHECK_FOR_INTERRUPTS();
+
+	if (n < 7)
+	{
+		SortTuple  *begin = a;
+
+		/* Partitioning compared only datum1, so include the ancestor pivot */
+		if (ancestor_pivot != NULL && state->base.onlyKey == NULL)
+		{
+			begin = ancestor_pivot;
+			n++;
+		}
+
+		/* Use the generic sort for other comparators; never skip sorting */
+		if (state->base.sortKeys[0].comparator == ssup_datum_signed_cmp)
+			small_sort_datum_signed(begin, n, state);
+		else
+			qsort_tuple(begin, n, state->base.comparetup, state);
+		return;
+	}
+
+	pm = a + (n / 2);
+	if (n > 7)
+	{
+		pl = a;
+		pn = a + (n - 1);
+		if (n > 40)
+		{
+			size_t		d = (n / 8);
+
+			pl = datum_med3(pl, pl + d, pl + 2 * d, ssup);
+			pm = datum_med3(pm - d, pm, pm + d, ssup);
+			pn = datum_med3(pn - 2 * d, pn - d, pn, ssup);
+		}
+		pm = datum_med3(pl, pm, pn, ssup);
+	}
+
+	/*
+	 * Heuristic for when to bucket duplicates: If pivot compares equal to the
+	 * ancestor pivot, there are likely many duplicates in this partition, so
+	 * we "partition left": elements equal to the pivot go into the left
+	 * partition and greater elements into the right partition.
+	 */
+	if (ancestor_pivot != NULL && LEADING_DATUM_CMP(ancestor_pivot, pm) == 0)
+	{
+		n_left_part = state->base.partition_left(a, n, pm);
+
+		/*
+		 * If the leading datum is not authoritative, finish the left
+		 * partition with the tiebreak comparator, including both pivots.
+		 */
+		if (state->base.onlyKey == NULL)
+			qsort_tuple(ancestor_pivot, n_left_part + 1,
+						state->base.comparetup_tiebreak, state);
+
+		/* Everything to the right is strictly greater: no ancestor needed */
+		ancestor_pivot = NULL;
+		a += n_left_part;
+		n -= n_left_part;
+		goto loop;
+	}
+
+	n_left_part = state->base.partition_right(a, n, pm);
+	n_right_part = n - n_left_part - 1;
+
+	/*
+	 * The pivot is now at a[n_left_part].  Recurse on the smaller partition
+	 * and iterate on the larger one, so stack depth stays O(log n).  The right
+	 * partition's ancestor is that pivot; the left one keeps the current one.
+	 */
+	if (n_left_part <= n_right_part)
+	{
+		qsort_tuple_datum(a, n_left_part, state, ancestor_pivot);
+		ancestor_pivot = a + n_left_part;
+		a = ancestor_pivot + 1;
+		n = n_right_part;
+	}
+	else
+	{
+		qsort_tuple_datum(a + n_left_part + 1, n_right_part, state,
+						  a + n_left_part);
+		n = n_left_part;
+	}
+	goto loop;
+}
+
+
 /*
  *		tuplesort_begin_xxx
  *
@@ -2679,6 +2829,23 @@ tuplesort_sort_memtuples(Tuplesortstate *state)
 
 	if (state->memtupcount > 1)
 	{
+		bool		presorted = true;
+
+		/* one-time precheck: nothing to do if the input is already sorted */
+		for (SortTuple *pm = state->memtuples + 1;
+			 pm < state->memtuples + state->memtupcount;
+			 pm++)
+		{
+			CHECK_FOR_INTERRUPTS();
+			if (COMPARETUP(state, pm - 1, pm) > 0)
+			{
+				presorted = false;
+				break;
+			}
+		}
+		if (presorted)
+			return;
+
 		/*
 		 * Do we have the leading column's value or abbreviation in datum1,
 		 * and is there a specialization for its comparator?
@@ -2695,9 +2862,28 @@ tuplesort_sort_memtuples(Tuplesortstate *state)
 #if SIZEOF_DATUM >= 8
 			else if (state->base.sortKeys[0].comparator == ssup_datum_signed_cmp)
 			{
-				qsort_tuple_signed(state->memtuples,
-								   state->memtupcount,
-								   state);
+				/* WIP: POC for one datum type, ASC only; no NULL support yet */
+				SortTuple  *pm;
+
+				if (debug_branchless_sort &&
+					!state->base.sortKeys[0].ssup_reverse)
+				{
+					state->base.partition_right = part_right_datum_signed_asc;
+					state->base.partition_left = part_left_datum_signed_asc;
+					qsort_tuple_datum(state->memtuples, state->memtupcount,
+									  state, NULL);
+				}
+				else
+					qsort_tuple_signed(state->memtuples,
+									   state->memtupcount,
+									   state);
+
+				/* WIP: correctness check */
+				for (pm = state->memtuples + 1;
+					 pm < state->memtuples + state->memtupcount;
+					 pm++)
+					Assert(COMPARETUP(state, pm - 1, pm) <= 0);
+
 				return;
 			}
 #endif
diff --git a/src/include/utils/guc.h b/src/include/utils/guc.h
index 1233e07d7d..b049db493d 100644
--- a/src/include/utils/guc.h
+++ b/src/include/utils/guc.h
@@ -300,6 +300,7 @@ extern PGDLLIMPORT int tcp_user_timeout;
 extern PGDLLIMPORT char *role_string;
 extern PGDLLIMPORT bool in_hot_standby_guc;
 extern PGDLLIMPORT bool trace_sort;
+extern PGDLLIMPORT bool debug_branchless_sort; /* XXX not for commit */
 
 #ifdef DEBUG_BOUNDED_SORT
 extern PGDLLIMPORT bool optimize_bounded_sort;
diff --git a/src/include/utils/tuplesort.h b/src/include/utils/tuplesort.h
index c63f1e5d6d..85a511fe7b 100644
--- a/src/include/utils/tuplesort.h
+++ b/src/include/utils/tuplesort.h
@@ -180,6 +180,10 @@ typedef struct
 	 */
 	SortTupleComparator comparetup_tiebreak;
 
+	/* Optional partition functions for single-instruction comparators */
+	size_t		(*partition_right) (SortTuple *begin, size_t n, SortTuple *pivot_pos);
+	size_t		(*partition_left) (SortTuple *begin, size_t n, SortTuple *pivot_pos);
+
 	/*
 	 * Alter datum1 representation in the SortTuple's array back from the
 	 * abbreviated key to the first column value.
-- 
2.48.1

