From a66444d7e0f59ecc9a9112671a24391344bb5eec Mon Sep 17 00:00:00 2001
From: David Geier <geidav.pg@gmail.com>
Date: Mon, 8 Sep 2025 12:06:44 +0200
Subject: [PATCH] Optimize eqjoinsel_inner() and eqjoinsel_semi()

Previously an O(N^2) algorithm was used to look for matching MCV
values between two tables. Significantly increasing
default_statistics_target could result in planning taking seconds.

The O(N^2) algorithm was replaced with an O(N) algorithm based on a
hash table. If the join operator is not hashable, we fall back to the
O(N^2) algorithm. As follow-on work, we can add a second fast path
based on sorting and merging for types that are not hashable.
---
 src/backend/utils/adt/selfuncs.c | 230 +++++++++++++++++++++++++------
 1 file changed, 190 insertions(+), 40 deletions(-)

diff --git a/src/backend/utils/adt/selfuncs.c b/src/backend/utils/adt/selfuncs.c
index 1c480cfaaf7..8837cda0016 100644
--- a/src/backend/utils/adt/selfuncs.c
+++ b/src/backend/utils/adt/selfuncs.c
@@ -143,6 +143,8 @@
 
 #define DEFAULT_PAGE_CPU_MULTIPLIER 50.0
 
+struct McvHashTable_hash;		/* generated by the simplehash.h instantiation further down */
+
 /* Hooks for plugins to get control when we ask for stats */
 get_relation_stats_hook_type get_relation_stats_hook = NULL;
 get_index_stats_hook_type get_index_stats_hook = NULL;
@@ -217,7 +219,195 @@ static bool get_actual_variable_endpoint(Relation heapRel,
 static RelOptInfo *find_join_input_rel(PlannerInfo *root, Relids relids);
 static double btcost_correlation(IndexOptInfo *index,
 								 VariableStatData *vardata);
+static uint32 hash_mcv(struct McvHashTable_hash *hashTable, Datum key);
+static bool are_mcvs_equal(struct McvHashTable_hash *hashTable,
+						   Datum value1, Datum value2);
+
+/* Hash table entry: one value of the MCV array the hash table is built on */
+typedef struct McvHashEntry
+{
+	Datum		value;			/* MCV value, used as the hash key */
+	uint32		index;			/* position of the value in its MCV array */
+	uint32		hash;			/* cached hash value (SH_STORE_HASH) */
+	char		status;			/* entry status, maintained by simplehash.h */
+} McvHashEntry;
 
+/* Private state passed to the hash table's hash and equality callbacks */
+typedef struct McvHashContext
+{
+	FunctionCallInfo equal_fcinfo;	/* caller's fcinfo for the join operator */
+	FmgrInfo	hash_proc;		/* hash function applied to both inputs */
+	Oid			collation;		/* collation passed to the hash function */
+} McvHashContext;
+
+#define SH_PREFIX                  McvHashTable
+#define SH_ELEMENT_TYPE            McvHashEntry
+#define SH_KEY_TYPE                Datum
+#define SH_KEY                     value
+#define SH_HASH_KEY(mcvs, key)     hash_mcv(mcvs, key)
+#define SH_EQUAL(mcvs, key0, key1) are_mcvs_equal(mcvs, key0, key1)
+#define SH_SCOPE                   static inline
+#define SH_STORE_HASH
+#define SH_GET_HASH(mcvs, key)     key->hash
+#define SH_DEFINE
+#define SH_DECLARE
+#include "lib/simplehash.h"
+
+/* Hash an MCV value with the operator's hash support function. */
+static uint32
+hash_mcv(struct McvHashTable_hash *hashTable, Datum key)
+{
+	McvHashContext *context = (McvHashContext *) hashTable->private_data;
+
+	return DatumGetUInt32(FunctionCall1Coll(&context->hash_proc,
+											context->collation, key));
+}
+
+/*
+ * Compare two MCV values with the join operator.
+ *
+ * This uses the caller's pre-initialized FunctionCallInfo and
+ * FunctionCallInvoke() instead of FunctionCall2Coll(), so that a NULL result
+ * counts as "not equal", exactly as in the nested-loop fallback, rather than
+ * raising an error.  Only the argument values are set here: MCV arrays
+ * contain no NULLs, so the argument isnull flags set up by the caller stay
+ * valid.
+ */
+static bool
+are_mcvs_equal(struct McvHashTable_hash *hashTable, Datum value1, Datum value2)
+{
+	McvHashContext *context = (McvHashContext *) hashTable->private_data;
+	FunctionCallInfo fcinfo = context->equal_fcinfo;
+	Datum		fresult;
+
+	fcinfo->args[0].value = value1;
+	fcinfo->args[1].value = value2;
+	fcinfo->isnull = false;
+	fresult = FunctionCallInvoke(fcinfo);
+	return !fcinfo->isnull && DatumGetBool(fresult);
+}
+
+/*
+ * try_eqjoinsel_with_hashtable
+ *
+ * Find the matching values of two MCV arrays in O(N + M) expected time by
+ * building a hash table on the smaller array and probing it with the values
+ * of the other one.  This replaces the O(N^2) nested loops of
+ * eqjoinsel_inner() and eqjoinsel_semi() and finds the same matches: every
+ * MCV is matched with at most one MCV of the other array.
+ *
+ * Only the first nvaluesSlot2 values of statsSlot2 are considered.
+ * equalFunctionCallInfo must be set up to call the operator's function with
+ * non-null arguments.  For every match, *nmatches is incremented, the
+ * corresponding hasmatch1/hasmatch2 flags are set and, if matchprodfreq is
+ * not NULL, the product of the two MCV frequencies is added to it.
+ *
+ * Returns false, without touching the output parameters, if the caller has
+ * to fall back to the nested loops: when an array has at most one value,
+ * when the operator lacks a hash function common to both inputs, or when
+ * the operator considers two values of the array the table is built on to be
+ * equal (e.g. with a nondeterministic join collation the statistics were not
+ * gathered with).
+ */
+static bool
+try_eqjoinsel_with_hashtable(Oid operatorOid, Oid collation,
+							 AttStatsSlot *statsSlot1, AttStatsSlot *statsSlot2, int nvaluesSlot2,
+							 FunctionCallInfo equalFunctionCallInfo,
+							 /* Output parameters: */
+							 double *matchprodfreq, int *nmatches, bool *hasmatch1, bool *hasmatch2)
+{
+	AttStatsSlot *statsInner = statsSlot2;
+	AttStatsSlot *statsOuter = statsSlot1;
+	bool	   *hasMatchInner = hasmatch2;
+	bool	   *hasMatchOuter = hasmatch1;
+	int			nvaluesInner = nvaluesSlot2;
+	int			nvaluesOuter = statsSlot1->nvalues;
+	McvHashContext hashContext;
+	McvHashTable_hash *hashTable;
+	Oid			hashLeft = InvalidOid;
+	Oid			hashRight = InvalidOid;
+
+	/*
+	 * If one MCV array contains at most a single value, there's no gain in
+	 * using a hash table.  The break-even point is slightly higher than 1,
+	 * but the difference is negligible.
+	 */
+	if (Min(statsSlot1->nvalues, nvaluesSlot2) <= 1)
+		return false;
+
+	/*
+	 * Both arrays are hashed with the same function, so the operator's left
+	 * and right hash functions must be identical.  That effectively limits
+	 * this to same-type equality, which is symmetric, so it does not matter
+	 * in which order the hash table passes the values to are_mcvs_equal().
+	 */
+	if (!get_op_hash_functions(operatorOid, &hashLeft, &hashRight) ||
+		!OidIsValid(hashLeft) || hashLeft != hashRight)
+		return false;
+
+	/* Build the hash table on the smaller array, probe with the other one. */
+	if (nvaluesOuter < nvaluesInner)
+	{
+		statsInner = statsSlot1;
+		statsOuter = statsSlot2;
+		hasMatchInner = hasmatch1;
+		hasMatchOuter = hasmatch2;
+		nvaluesInner = statsSlot1->nvalues;
+		nvaluesOuter = nvaluesSlot2;
+	}
+
+	hashContext.equal_fcinfo = equalFunctionCallInfo;
+	fmgr_info(hashLeft, &hashContext.hash_proc);
+	hashContext.collation = collation;
+
+	hashTable = McvHashTable_create(CurrentMemoryContext, nvaluesInner,
+									&hashContext);
+
+	/* 1. Insert the values of the smaller array: O(n). */
+	for (int i = 0; i < nvaluesInner; i++)
+	{
+		bool		found;
+		McvHashEntry *entry;
+
+		entry = McvHashTable_insert(hashTable, statsInner->values[i], &found);
+		if (found)
+		{
+			/*
+			 * Two values of this array are equal according to the join
+			 * operator.  The nested loops can match both of them, a single
+			 * hash table entry cannot, so let the caller use the loops.
+			 */
+			McvHashTable_destroy(hashTable);
+			return false;
+		}
+		entry->index = i;
+	}
+
+	/* 2. Probe with the values of the other array: O(m). */
+	for (int i = 0; i < nvaluesOuter; i++)
+	{
+		McvHashEntry *entry = McvHashTable_lookup(hashTable,
+												  statsOuter->values[i]);
+
+		/*
+		 * Like the nested loops, match each inner MCV at most once, so that
+		 * nmatches can never exceed the number of MCVs in either array.
+		 */
+		if (entry == NULL || hasMatchInner[entry->index])
+			continue;
+
+		hasMatchInner[entry->index] = true;
+		hasMatchOuter[i] = true;
+		(*nmatches)++;
+
+		/* Not needed by the semi-join selectivity estimation */
+		if (matchprodfreq != NULL)
+			*matchprodfreq += statsInner->numbers[entry->index] *
+				statsOuter->numbers[i];
+	}
+
+	McvHashTable_destroy(hashTable);
+	return true;
+}
 
 /*
  *		eqsel			- Selectivity of "=" for any data types.
@@ -2350,7 +2484,7 @@ eqjoinsel(PG_FUNCTION_ARGS)
 	}
 
 	/* We need to compute the inner-join selectivity in all cases */
-	selec_inner = eqjoinsel_inner(opfuncoid, collation,
+	selec_inner = eqjoinsel_inner(operator, collation,
 								  &vardata1, &vardata2,
 								  nd1, nd2,
 								  isdefault1, isdefault2,
@@ -2377,7 +2511,7 @@ eqjoinsel(PG_FUNCTION_ARGS)
 			inner_rel = find_join_input_rel(root, sjinfo->min_righthand);
 
 			if (!join_is_reversed)
-				selec = eqjoinsel_semi(opfuncoid, collation,
+				selec = eqjoinsel_semi(operator, collation,
 									   &vardata1, &vardata2,
 									   nd1, nd2,
 									   isdefault1, isdefault2,
@@ -2388,9 +2522,8 @@ eqjoinsel(PG_FUNCTION_ARGS)
 			else
 			{
 				Oid			commop = get_commutator(operator);
-				Oid			commopfuncoid = OidIsValid(commop) ? get_opcode(commop) : InvalidOid;
 
-				selec = eqjoinsel_semi(commopfuncoid, collation,
+				selec = eqjoinsel_semi(commop, collation,	/* commop may be InvalidOid */
 									   &vardata2, &vardata1,
 									   nd2, nd1,
 									   isdefault2, isdefault1,
@@ -2438,7 +2571,7 @@ eqjoinsel(PG_FUNCTION_ARGS)
  * that it's worth trying to distinguish them here.
  */
 static double
-eqjoinsel_inner(Oid opfuncoid, Oid collation,
+eqjoinsel_inner(Oid operator, Oid collation,
 				VariableStatData *vardata1, VariableStatData *vardata2,
 				double nd1, double nd2,
 				bool isdefault1, bool isdefault2,
@@ -2480,7 +2613,7 @@ eqjoinsel_inner(Oid opfuncoid, Oid collation,
 		int			i,
 					nmatches;
 
-		fmgr_info(opfuncoid, &eqproc);
+		fmgr_info(get_opcode(operator), &eqproc);
 
 		/*
 		 * Save a few cycles by setting up the fcinfo struct just once. Using
@@ -2504,30 +2637,38 @@ eqjoinsel_inner(Oid opfuncoid, Oid collation,
 		 */
 		matchprodfreq = 0.0;
 		nmatches = 0;
-		for (i = 0; i < sslot1->nvalues; i++)
-		{
-			int			j;
 
-			fcinfo->args[0].value = sslot1->values[i];
-
-			for (j = 0; j < sslot2->nvalues; j++)
+		if (!try_eqjoinsel_with_hashtable(operator, collation, sslot1, sslot2, sslot2->nvalues,
+											fcinfo, &matchprodfreq, &nmatches,
+											hasmatch1, hasmatch2))
+		{
+			/* Fall back to the O(N^2) nested loops if the hash table was not usable. */
+			for (i = 0; i < sslot1->nvalues; i++)
 			{
-				Datum		fresult;
+				int			j;
 
-				if (hasmatch2[j])
-					continue;
-				fcinfo->args[1].value = sslot2->values[j];
-				fcinfo->isnull = false;
-				fresult = FunctionCallInvoke(fcinfo);
-				if (!fcinfo->isnull && DatumGetBool(fresult))
+				fcinfo->args[0].value = sslot1->values[i];
+
+				for (j = 0; j < sslot2->nvalues; j++)
 				{
-					hasmatch1[i] = hasmatch2[j] = true;
-					matchprodfreq += sslot1->numbers[i] * sslot2->numbers[j];
-					nmatches++;
-					break;
+					Datum		fresult;
+
+					if (hasmatch2[j])
+						continue;
+					fcinfo->args[1].value = sslot2->values[j];
+					fcinfo->isnull = false;
+					fresult = FunctionCallInvoke(fcinfo);
+					if (!fcinfo->isnull && DatumGetBool(fresult))
+					{
+						hasmatch1[i] = hasmatch2[j] = true;
+						matchprodfreq += sslot1->numbers[i] * sslot2->numbers[j];
+						nmatches++;
+						break;
+					}
 				}
 			}
 		}
+
 		CLAMP_PROBABILITY(matchprodfreq);
 		/* Sum up frequencies of matched and unmatched MCVs */
 		matchfreq1 = unmatchfreq1 = 0.0;
@@ -2635,7 +2776,7 @@ eqjoinsel_inner(Oid opfuncoid, Oid collation,
- * Unlike eqjoinsel_inner, we have to cope with opfuncoid being InvalidOid.
+ * Unlike eqjoinsel_inner, we have to cope with operator being InvalidOid.
  */
 static double
-eqjoinsel_semi(Oid opfuncoid, Oid collation,
+eqjoinsel_semi(Oid operator, Oid collation,
 			   VariableStatData *vardata1, VariableStatData *vardata2,
 			   double nd1, double nd2,
 			   bool isdefault1, bool isdefault2,
@@ -2645,6 +2786,7 @@ eqjoinsel_semi(Oid opfuncoid, Oid collation,
 			   RelOptInfo *inner_rel)
 {
 	double		selec;
+	Oid			opfuncoid = get_opcode(operator);
 
 	/*
 	 * We clamp nd2 to be not more than what we estimate the inner relation's
@@ -2733,29 +2875,37 @@ eqjoinsel_semi(Oid opfuncoid, Oid collation,
 		 * and because the math wouldn't add up...
 		 */
 		nmatches = 0;
-		for (i = 0; i < sslot1->nvalues; i++)
-		{
-			int			j;
 
-			fcinfo->args[0].value = sslot1->values[i];
-
-			for (j = 0; j < clamped_nvalues2; j++)
+		if (!try_eqjoinsel_with_hashtable(operator, collation, sslot1, sslot2,
+										clamped_nvalues2, fcinfo, NULL,
+										&nmatches, hasmatch1, hasmatch2))
+		{
+			/* Fall back to the O(N^2) nested loops if the hash table was not usable. */
+			for (i = 0; i < sslot1->nvalues; i++)
 			{
-				Datum		fresult;
+				int			j;
 
-				if (hasmatch2[j])
-					continue;
-				fcinfo->args[1].value = sslot2->values[j];
-				fcinfo->isnull = false;
-				fresult = FunctionCallInvoke(fcinfo);
-				if (!fcinfo->isnull && DatumGetBool(fresult))
+				fcinfo->args[0].value = sslot1->values[i];
+
+				for (j = 0; j < clamped_nvalues2; j++)
 				{
-					hasmatch1[i] = hasmatch2[j] = true;
-					nmatches++;
-					break;
+					Datum		fresult;
+
+					if (hasmatch2[j])
+						continue;
+					fcinfo->args[1].value = sslot2->values[j];
+					fcinfo->isnull = false;
+					fresult = FunctionCallInvoke(fcinfo);
+					if (!fcinfo->isnull && DatumGetBool(fresult))
+					{
+						hasmatch1[i] = hasmatch2[j] = true;
+						nmatches++;
+						break;
+					}
 				}
 			}
 		}
+
 		/* Sum up frequencies of matched MCVs */
 		matchfreq1 = 0.0;
 		for (i = 0; i < sslot1->nvalues; i++)
-- 
2.43.0

