From eefd65d1d05111cd12a93902e8acf009d2f4c39f Mon Sep 17 00:00:00 2001
From: Jeff Davis <jeff@j-davis.com>
Date: Thu, 26 Sep 2024 11:27:29 -0700
Subject: [PATCH v8 2/7] Control collation behavior with a method table.

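Previously, each pg_str*() entry point branched on the locale's provider.

Dispatching through a per-locale method table is less error prone and
easier to hook. As a side effect, pg_strnxfrm_prefix() now passes srclen
through to the provider instead of always passing -1.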
---
 src/backend/utils/adt/pg_locale.c      | 123 +++------------------
 src/backend/utils/adt/pg_locale_icu.c  | 147 +++++++++++++++----------
 src/backend/utils/adt/pg_locale_libc.c |  40 +++++--
 src/include/utils/pg_locale.h          |  33 ++++++
 4 files changed, 167 insertions(+), 176 deletions(-)

diff --git a/src/backend/utils/adt/pg_locale.c b/src/backend/utils/adt/pg_locale.c
index ec5f509c4e..00eca68717 100644
--- a/src/backend/utils/adt/pg_locale.c
+++ b/src/backend/utils/adt/pg_locale.c
@@ -92,27 +92,12 @@
 /* pg_locale_icu.c */
 #ifdef USE_ICU
 extern UCollator *pg_ucol_open(const char *loc_str);
-extern int	strncoll_icu(const char *arg1, ssize_t len1,
-						 const char *arg2, ssize_t len2,
-						 pg_locale_t locale);
-extern size_t strnxfrm_icu(char *dest, size_t destsize,
-						   const char *src, ssize_t srclen,
-						   pg_locale_t locale);
-extern size_t strnxfrm_prefix_icu(char *dest, size_t destsize,
-								  const char *src, ssize_t srclen,
-								  pg_locale_t locale);
 #endif
 
 /* pg_locale_libc.c */
 extern pg_locale_t create_pg_locale_builtin(Oid collid, MemoryContext context);
 extern pg_locale_t create_pg_locale_icu(Oid collid, MemoryContext context);
 extern pg_locale_t create_pg_locale_libc(Oid collid, MemoryContext context);
-extern int	strncoll_libc(const char *arg1, ssize_t len1,
-						  const char *arg2, ssize_t len2,
-						  pg_locale_t locale);
-extern size_t strnxfrm_libc(char *dest, size_t destsize,
-							const char *src, ssize_t srclen,
-							pg_locale_t locale);
 
 /* GUC settings */
 char	   *locale_messages;
@@ -1239,6 +1224,9 @@ create_pg_locale(Oid collid, MemoryContext context)
 		/* shouldn't happen */
 		PGLOCALE_SUPPORT_ERROR(collform->collprovider);
 
+	Assert((result->collate_is_c && result->collate == NULL) ||
+		   (!result->collate_is_c && result->collate != NULL));
+
 	datum = SysCacheGetAttr(COLLOID, tp, Anum_pg_collation_collversion,
 							&isnull);
 	if (!isnull)
@@ -1490,19 +1478,7 @@ get_collation_actual_version(char collprovider, const char *collcollate)
 int
 pg_strcoll(const char *arg1, const char *arg2, pg_locale_t locale)
 {
-	int			result;
-
-	if (locale->provider == COLLPROVIDER_LIBC)
-		result = strncoll_libc(arg1, -1, arg2, -1, locale);
-#ifdef USE_ICU
-	else if (locale->provider == COLLPROVIDER_ICU)
-		result = strncoll_icu(arg1, -1, arg2, -1, locale);
-#endif
-	else
-		/* shouldn't happen */
-		PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-	return result;
+	return locale->collate->strncoll(arg1, -1, arg2, -1, locale);
 }
 
 /*
@@ -1523,51 +1499,25 @@ int
 pg_strncoll(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
 			pg_locale_t locale)
 {
-	int			result;
-
-	if (locale->provider == COLLPROVIDER_LIBC)
-		result = strncoll_libc(arg1, len1, arg2, len2, locale);
-#ifdef USE_ICU
-	else if (locale->provider == COLLPROVIDER_ICU)
-		result = strncoll_icu(arg1, len1, arg2, len2, locale);
-#endif
-	else
-		/* shouldn't happen */
-		PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-	return result;
+	return locale->collate->strncoll(arg1, len1, arg2, len2, locale);
 }
 
 /*
  * Return true if the collation provider supports pg_strxfrm() and
  * pg_strnxfrm(); otherwise false.
  *
- * Unfortunately, it seems that strxfrm() for non-C collations is broken on
- * many common platforms; testing of multiple versions of glibc reveals that,
- * for many locales, strcoll() and strxfrm() do not return consistent
- * results. While no other libc other than Cygwin has so far been shown to
- * have a problem, we take the conservative course of action for right now and
- * disable this categorically.  (Users who are certain this isn't a problem on
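- * their system can define TRUST_STRXFRM.)
- *
- * No similar problem is known for the ICU provider.
+ * The answer comes from the provider's method table (strxfrm_is_safe); a
+ * provider reports false when its strnxfrm method cannot be trusted.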
  */
 bool
 pg_strxfrm_enabled(pg_locale_t locale)
 {
-	if (locale->provider == COLLPROVIDER_LIBC)
-#ifdef TRUST_STRXFRM
-		return true;
-#else
-		return false;
-#endif
-	else if (locale->provider == COLLPROVIDER_ICU)
-		return true;
-	else
-		/* shouldn't happen */
-		PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-	return false;				/* keep compiler quiet */
+	/*
+	 * locale->collate->strnxfrm is still a required method, even if it may
+	 * have the wrong behavior, because the planner uses it for estimates in
+	 * some cases.
+	 */
+	return locale->collate->strxfrm_is_safe;
 }
 
 /*
@@ -1578,19 +1528,7 @@ pg_strxfrm_enabled(pg_locale_t locale)
 size_t
 pg_strxfrm(char *dest, const char *src, size_t destsize, pg_locale_t locale)
 {
-	size_t		result = 0;		/* keep compiler quiet */
-
-	if (locale->provider == COLLPROVIDER_LIBC)
-		result = strnxfrm_libc(dest, destsize, src, -1, locale);
-#ifdef USE_ICU
-	else if (locale->provider == COLLPROVIDER_ICU)
-		result = strnxfrm_icu(dest, destsize, src, -1, locale);
-#endif
-	else
-		/* shouldn't happen */
-		PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-	return result;
+	return locale->collate->strnxfrm(dest, destsize, src, -1, locale);
 }
 
 /*
@@ -1616,19 +1554,7 @@ size_t
 pg_strnxfrm(char *dest, size_t destsize, const char *src, ssize_t srclen,
 			pg_locale_t locale)
 {
-	size_t		result = 0;		/* keep compiler quiet */
-
-	if (locale->provider == COLLPROVIDER_LIBC)
-		result = strnxfrm_libc(dest, destsize, src, srclen, locale);
-#ifdef USE_ICU
-	else if (locale->provider == COLLPROVIDER_ICU)
-		result = strnxfrm_icu(dest, destsize, src, srclen, locale);
-#endif
-	else
-		/* shouldn't happen */
-		PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-	return result;
+	return locale->collate->strnxfrm(dest, destsize, src, srclen, locale);
 }
 
 /*
@@ -1638,15 +1564,7 @@ pg_strnxfrm(char *dest, size_t destsize, const char *src, ssize_t srclen,
 bool
 pg_strxfrm_prefix_enabled(pg_locale_t locale)
 {
-	if (locale->provider == COLLPROVIDER_LIBC)
-		return false;
-	else if (locale->provider == COLLPROVIDER_ICU)
-		return true;
-	else
-		/* shouldn't happen */
-		PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-	return false;				/* keep compiler quiet */
+	return (locale->collate->strnxfrm_prefix != NULL);
 }
 
 /*
@@ -1658,7 +1576,7 @@ size_t
 pg_strxfrm_prefix(char *dest, const char *src, size_t destsize,
 				  pg_locale_t locale)
 {
-	return pg_strnxfrm_prefix(dest, destsize, src, -1, locale);
+	return locale->collate->strnxfrm_prefix(dest, destsize, src, -1, locale);
 }
 
 /*
@@ -1683,16 +1601,7 @@ size_t
 pg_strnxfrm_prefix(char *dest, size_t destsize, const char *src,
 				   ssize_t srclen, pg_locale_t locale)
 {
-	size_t		result = 0;		/* keep compiler quiet */
-
-#ifdef USE_ICU
-	if (locale->provider == COLLPROVIDER_ICU)
-		result = strnxfrm_prefix_icu(dest, destsize, src, -1, locale);
-	else
-#endif
-		PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-	return result;
+	return locale->collate->strnxfrm_prefix(dest, destsize, src, srclen, locale);
 }
 
 /*
diff --git a/src/backend/utils/adt/pg_locale_icu.c b/src/backend/utils/adt/pg_locale_icu.c
index 73eb430d75..11ec9d4e4b 100644
--- a/src/backend/utils/adt/pg_locale_icu.c
+++ b/src/backend/utils/adt/pg_locale_icu.c
@@ -40,13 +40,14 @@ extern pg_locale_t create_pg_locale_icu(Oid collid, MemoryContext context);
 #ifdef USE_ICU
 
 extern UCollator *pg_ucol_open(const char *loc_str);
-extern int	strncoll_icu(const char *arg1, ssize_t len1,
+
+static int	strncoll_icu(const char *arg1, ssize_t len1,
 						 const char *arg2, ssize_t len2,
 						 pg_locale_t locale);
-extern size_t strnxfrm_icu(char *dest, size_t destsize,
+static size_t strnxfrm_icu(char *dest, size_t destsize,
 						   const char *src, ssize_t srclen,
 						   pg_locale_t locale);
-extern size_t strnxfrm_prefix_icu(char *dest, size_t destsize,
+static size_t strnxfrm_prefix_icu(char *dest, size_t destsize,
 								  const char *src, ssize_t srclen,
 								  pg_locale_t locale);
 
@@ -59,12 +60,20 @@ static UConverter *icu_converter = NULL;
 
 static UCollator *make_icu_collator(const char *iculocstr,
 									const char *icurules);
-static int	strncoll_icu_no_utf8(const char *arg1, ssize_t len1,
-								 const char *arg2, ssize_t len2,
-								 pg_locale_t locale);
-static size_t strnxfrm_prefix_icu_no_utf8(char *dest, size_t destsize,
-										  const char *src, ssize_t srclen,
-										  pg_locale_t locale);
+static int	strncoll_icu(const char *arg1, ssize_t len1,
+						 const char *arg2, ssize_t len2,
+						 pg_locale_t locale);
+static size_t strnxfrm_prefix_icu(char *dest, size_t destsize,
+								  const char *src, ssize_t srclen,
+								  pg_locale_t locale);
+#ifdef HAVE_UCOL_STRCOLLUTF8
+static int	strncoll_icu_utf8(const char *arg1, ssize_t len1,
+							  const char *arg2, ssize_t len2,
+							  pg_locale_t locale);
+#endif
+static size_t strnxfrm_prefix_icu_utf8(char *dest, size_t destsize,
+									   const char *src, ssize_t srclen,
+									   pg_locale_t locale);
 static void init_icu_converter(void);
 static size_t uchar_length(UConverter *converter,
 						   const char *str, int32_t len);
@@ -73,6 +82,25 @@ static int32_t uchar_convert(UConverter *converter,
 							 const char *src, int32_t srclen);
 static void icu_set_collation_attributes(UCollator *collator, const char *loc,
 										 UErrorCode *status);
+
+static const struct collate_methods collate_methods_icu = {
+	.strncoll = strncoll_icu,
+	.strnxfrm = strnxfrm_icu,
+	.strnxfrm_prefix = strnxfrm_prefix_icu,
+	.strxfrm_is_safe = true,
+};
+
+static const struct collate_methods collate_methods_icu_utf8 = {
+#ifdef HAVE_UCOL_STRCOLLUTF8
+	.strncoll = strncoll_icu_utf8,
+#else
+	.strncoll = strncoll_icu,
+#endif
+	.strnxfrm = strnxfrm_icu,
+	.strnxfrm_prefix = strnxfrm_prefix_icu_utf8,
+	.strxfrm_is_safe = true,
+};
+
 #endif
 
 pg_locale_t
@@ -139,6 +167,10 @@ create_pg_locale_icu(Oid collid, MemoryContext context)
 	result->deterministic = deterministic;
 	result->collate_is_c = false;
 	result->ctype_is_c = false;
+	if (GetDatabaseEncoding() == PG_UTF8)
+		result->collate = &collate_methods_icu_utf8;
+	else
+		result->collate = &collate_methods_icu;
 
 	return result;
 #else
@@ -313,42 +345,36 @@ make_icu_collator(const char *iculocstr, const char *icurules)
 }
 
 /*
- * strncoll_icu
+ * strncoll_icu_utf8
  *
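- * Call ucol_strcollUTF8() or ucol_strcoll() as appropriate for the given
- * database encoding. An argument length of -1 means the string is
+ * Call ucol_strcollUTF8() directly on the arguments; the database encoding
+ * must be UTF8. An argument length of -1 means the string is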
  * NUL-terminated.
  */
+#ifdef HAVE_UCOL_STRCOLLUTF8
 int
-strncoll_icu(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
-			 pg_locale_t locale)
+strncoll_icu_utf8(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
+				  pg_locale_t locale)
 {
 	int			result;
+	UErrorCode	status;
 
 	Assert(locale->provider == COLLPROVIDER_ICU);
 
-#ifdef HAVE_UCOL_STRCOLLUTF8
-	if (GetDatabaseEncoding() == PG_UTF8)
-	{
-		UErrorCode	status;
+	Assert(GetDatabaseEncoding() == PG_UTF8);
 
-		status = U_ZERO_ERROR;
-		result = ucol_strcollUTF8(locale->info.icu.ucol,
-								  arg1, len1,
-								  arg2, len2,
-								  &status);
-		if (U_FAILURE(status))
-			ereport(ERROR,
-					(errmsg("collation failed: %s", u_errorName(status))));
-	}
-	else
-#endif
-	{
-		result = strncoll_icu_no_utf8(arg1, len1, arg2, len2, locale);
-	}
+	status = U_ZERO_ERROR;
+	result = ucol_strcollUTF8(locale->info.icu.ucol,
+							  arg1, len1,
+							  arg2, len2,
+							  &status);
+	if (U_FAILURE(status))
+		ereport(ERROR,
+				(errmsg("collation failed: %s", u_errorName(status))));
 
 	return result;
 }
+#endif
 
 /* 'srclen' of -1 means the strings are NUL-terminated */
 size_t
@@ -399,37 +425,32 @@ strnxfrm_icu(char *dest, size_t destsize, const char *src, ssize_t srclen,
 
 /* 'srclen' of -1 means the strings are NUL-terminated */
 size_t
-strnxfrm_prefix_icu(char *dest, size_t destsize,
-					const char *src, ssize_t srclen,
-					pg_locale_t locale)
+strnxfrm_prefix_icu_utf8(char *dest, size_t destsize,
+						 const char *src, ssize_t srclen,
+						 pg_locale_t locale)
 {
 	size_t		result;
+	UCharIterator iter;
+	uint32_t	state[2];
+	UErrorCode	status;
 
 	Assert(locale->provider == COLLPROVIDER_ICU);
 
-	if (GetDatabaseEncoding() == PG_UTF8)
-	{
-		UCharIterator iter;
-		uint32_t	state[2];
-		UErrorCode	status;
+	Assert(GetDatabaseEncoding() == PG_UTF8);
 
-		uiter_setUTF8(&iter, src, srclen);
-		state[0] = state[1] = 0;	/* won't need that again */
-		status = U_ZERO_ERROR;
-		result = ucol_nextSortKeyPart(locale->info.icu.ucol,
-									  &iter,
-									  state,
-									  (uint8_t *) dest,
-									  destsize,
-									  &status);
-		if (U_FAILURE(status))
-			ereport(ERROR,
-					(errmsg("sort key generation failed: %s",
-							u_errorName(status))));
-	}
-	else
-		result = strnxfrm_prefix_icu_no_utf8(dest, destsize, src, srclen,
-											 locale);
+	uiter_setUTF8(&iter, src, srclen);
+	state[0] = state[1] = 0;	/* won't need that again */
+	status = U_ZERO_ERROR;
+	result = ucol_nextSortKeyPart(locale->info.icu.ucol,
+								  &iter,
+								  state,
+								  (uint8_t *) dest,
+								  destsize,
+								  &status);
+	if (U_FAILURE(status))
+		ereport(ERROR,
+				(errmsg("sort key generation failed: %s",
+						u_errorName(status))));
 
 	return result;
 }
@@ -504,7 +525,7 @@ icu_from_uchar(char **result, const UChar *buff_uchar, int32_t len_uchar)
 }
 
 /*
- * strncoll_icu_no_utf8
+ * strncoll_icu
  *
  * Convert the arguments from the database encoding to UChar strings, then
  * call ucol_strcoll(). An argument length of -1 means that the string is
@@ -514,8 +535,8 @@ icu_from_uchar(char **result, const UChar *buff_uchar, int32_t len_uchar)
  * caller should call that instead.
  */
 static int
-strncoll_icu_no_utf8(const char *arg1, ssize_t len1,
-					 const char *arg2, ssize_t len2, pg_locale_t locale)
+strncoll_icu(const char *arg1, ssize_t len1,
+			 const char *arg2, ssize_t len2, pg_locale_t locale)
 {
 	char		sbuf[TEXTBUFLEN];
 	char	   *buf = sbuf;
@@ -528,6 +549,8 @@ strncoll_icu_no_utf8(const char *arg1, ssize_t len1,
 	int			result;
 
 	Assert(locale->provider == COLLPROVIDER_ICU);
+
+	/* if encoding is UTF8, use more efficient strncoll_icu_utf8 */
 #ifdef HAVE_UCOL_STRCOLLUTF8
 	Assert(GetDatabaseEncoding() != PG_UTF8);
 #endif
@@ -561,9 +584,9 @@ strncoll_icu_no_utf8(const char *arg1, ssize_t len1,
 
 /* 'srclen' of -1 means the strings are NUL-terminated */
 static size_t
-strnxfrm_prefix_icu_no_utf8(char *dest, size_t destsize,
-							const char *src, ssize_t srclen,
-							pg_locale_t locale)
+strnxfrm_prefix_icu(char *dest, size_t destsize,
+					const char *src, ssize_t srclen,
+					pg_locale_t locale)
 {
 	char		sbuf[TEXTBUFLEN];
 	char	   *buf = sbuf;
@@ -576,6 +599,8 @@ strnxfrm_prefix_icu_no_utf8(char *dest, size_t destsize,
 	Size		result_bsize;
 
 	Assert(locale->provider == COLLPROVIDER_ICU);
+
+	/* if encoding is UTF8, use more efficient strnxfrm_prefix_icu_utf8 */
 	Assert(GetDatabaseEncoding() != PG_UTF8);
 
 	init_icu_converter();
diff --git a/src/backend/utils/adt/pg_locale_libc.c b/src/backend/utils/adt/pg_locale_libc.c
index 374ac37ba0..c7be6dd4f9 100644
--- a/src/backend/utils/adt/pg_locale_libc.c
+++ b/src/backend/utils/adt/pg_locale_libc.c
@@ -32,10 +32,10 @@
 
 extern pg_locale_t create_pg_locale_libc(Oid collid, MemoryContext context);
 
-extern int	strncoll_libc(const char *arg1, ssize_t len1,
+static int	strncoll_libc(const char *arg1, ssize_t len1,
 						  const char *arg2, ssize_t len2,
 						  pg_locale_t locale);
-extern size_t strnxfrm_libc(char *dest, size_t destsize,
+static size_t strnxfrm_libc(char *dest, size_t destsize,
 							const char *src, ssize_t srclen,
 							pg_locale_t locale);
 static locale_t make_libc_collator(const char *collate,
@@ -48,6 +48,45 @@ static int	strncoll_libc_win32_utf8(const char *arg1, ssize_t len1,
 									 pg_locale_t locale);
 #endif
 
+static const struct collate_methods collate_methods_libc = {
+	.strncoll = strncoll_libc,
+	.strnxfrm = strnxfrm_libc,
+	.strnxfrm_prefix = NULL,
+
+	/*
+	 * Unfortunately, it seems that strxfrm() for non-C collations is broken
+	 * on many common platforms; testing of multiple versions of glibc reveals
+	 * that, for many locales, strcoll() and strxfrm() do not return
+	 * consistent results. While no other libc other than Cygwin has so far
+	 * been shown to have a problem, we take the conservative course of action
+	 * for right now and disable this categorically.  (Users who are certain
+	 * this isn't a problem on their system can define TRUST_STRXFRM.)
+	 */
+#ifdef TRUST_STRXFRM
+	.strxfrm_is_safe = true,
+#else
+	.strxfrm_is_safe = false,
+#endif
+};
+
+#ifdef WIN32
+/*
+ * Method table selected by create_pg_locale_libc() on Windows when the
+ * database encoding is UTF8: comparisons go through
+ * strncoll_libc_win32_utf8; everything else matches collate_methods_libc.
+ */
+static const struct collate_methods collate_methods_libc_win32_utf8 = {
+	.strncoll = strncoll_libc_win32_utf8,
+	.strnxfrm = strnxfrm_libc,
+	.strnxfrm_prefix = NULL,
+#ifdef TRUST_STRXFRM
+	.strxfrm_is_safe = true,
+#else
+	.strxfrm_is_safe = false,
+#endif
+};
+#endif
+
 pg_locale_t
 create_pg_locale_libc(Oid collid, MemoryContext context)
 {
@@ -103,6 +142,15 @@ create_pg_locale_libc(Oid collid, MemoryContext context)
 	result->ctype_is_c = (strcmp(ctype, "C") == 0) ||
 		(strcmp(ctype, "POSIX") == 0);
 	result->info.lt = loc;
+	if (!result->collate_is_c)
+	{
+#ifdef WIN32
+		if (GetDatabaseEncoding() == PG_UTF8)
+			result->collate = &collate_methods_libc_win32_utf8;
+		else
+#endif
+			result->collate = &collate_methods_libc;
+	}
 
 	return result;
 }
@@ -200,12 +248,6 @@ strncoll_libc(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
 
 	Assert(locale->provider == COLLPROVIDER_LIBC);
 
-#ifdef WIN32
-	/* check for this case before doing the work for nul-termination */
-	if (GetDatabaseEncoding() == PG_UTF8)
-		return strncoll_libc_win32_utf8(arg1, len1, arg2, len2, locale);
-#endif							/* WIN32 */
-
 	if (bufsize1 + bufsize2 > TEXTBUFLEN)
 		buf = palloc(bufsize1 + bufsize2);
 
diff --git a/src/include/utils/pg_locale.h b/src/include/utils/pg_locale.h
index 37ecf95193..2f05dffcdd 100644
--- a/src/include/utils/pg_locale.h
+++ b/src/include/utils/pg_locale.h
@@ -60,6 +60,36 @@ extern struct lconv *PGLC_localeconv(void);
 extern void cache_locale_time(void);
 
 
+struct pg_locale_struct;
+typedef struct pg_locale_struct *pg_locale_t;
+
+/* methods that define collation behavior */
+struct collate_methods
+{
+	/* required */
+	int			(*strncoll) (const char *arg1, ssize_t len1,
+							 const char *arg2, ssize_t len2,
+							 pg_locale_t locale);
+
+	/* required */
+	size_t		(*strnxfrm) (char *dest, size_t destsize,
+							 const char *src, ssize_t srclen,
+							 pg_locale_t locale);
+
+	/* optional */
+	size_t		(*strnxfrm_prefix) (char *dest, size_t destsize,
+									const char *src, ssize_t srclen,
+									pg_locale_t locale);
+
+	/*
+	 * If the strnxfrm method is not trusted to return the correct results,
+	 * set strxfrm_is_safe to false. If set to false, the method will not be
+	 * used in most cases, but the planner still expects it to be there for
+	 * estimation purposes (where incorrect results are acceptable).
+	 */
+	bool		strxfrm_is_safe;
+};
+
 /*
  * We use a discriminated union to hold either a locale_t or an ICU collator.
  * pg_locale_t is occasionally checked for truth, so make it a pointer.
@@ -82,6 +112,9 @@ struct pg_locale_struct
 	bool		deterministic;
 	bool		collate_is_c;
 	bool		ctype_is_c;
+
+	const struct collate_methods *collate;	/* NULL if collate_is_c */
+
 	union
 	{
 		struct
-- 
2.34.1

