From d275b329aaed6c6ef2403e4c313725a1ae88fa40 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Fri, 4 Mar 2022 08:48:47 -0800
Subject: [PATCH v3 9/9] contrib/oauth: switch to pluggable auth API

Move the core server implementation to contrib/oauth as a pluggable
provider, using the RegisterAuthProvider() API. oauth_validator_command
has been moved from core to the custom GUC oauth.validator_command.
Tests have been updated to use the new HBA syntax (custom provider=oauth)
and the renamed GUC.

Some server modifications remain:
- Adding new HBA options for custom providers
- Registering support for usermaps
- Allowing custom SASL mechanisms to declare their own maximum message
  length

This patch is optional; you can apply/revert it to compare the two
approaches.
---
 contrib/oauth/Makefile                        | 16 +++++++
 .../auth-oauth.c => contrib/oauth/oauth.c     | 47 +++++++++++++++----
 src/backend/libpq/Makefile                    |  1 -
 src/backend/libpq/auth.c                      |  7 ---
 src/backend/libpq/hba.c                       | 21 +++++----
 src/backend/utils/misc/guc.c                  | 12 -----
 src/include/libpq/hba.h                       |  3 +-
 src/include/libpq/oauth.h                     | 24 ----------
 src/test/python/server/test_oauth.py          | 20 ++++----
 9 files changed, 78 insertions(+), 73 deletions(-)
 create mode 100644 contrib/oauth/Makefile
 rename src/backend/libpq/auth-oauth.c => contrib/oauth/oauth.c (95%)
 delete mode 100644 src/include/libpq/oauth.h

diff --git a/contrib/oauth/Makefile b/contrib/oauth/Makefile
new file mode 100644
index 0000000000..880bc1fef3
--- /dev/null
+++ b/contrib/oauth/Makefile
@@ -0,0 +1,16 @@
+# contrib/oauth/Makefile
+
+MODULE_big = oauth
+OBJS = oauth.o
+PGFILEDESC = "oauth - auth provider supporting OAuth 2.0/OIDC"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = contrib/oauth
+top_builddir = ../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/src/backend/libpq/auth-oauth.c b/contrib/oauth/oauth.c
similarity index 95%
rename from src/backend/libpq/auth-oauth.c
rename to contrib/oauth/oauth.c
index c1232a31a0..3a6dab19d9 100644
--- a/src/backend/libpq/auth-oauth.c
+++ b/contrib/oauth/oauth.c
@@ -1,33 +1,39 @@
 /*-------------------------------------------------------------------------
  *
- * auth-oauth.c
+ * oauth.c
  *	  Server-side implementation of the SASL OAUTHBEARER mechanism.
  *
  * See the following RFC for more details:
  * - RFC 7628: https://tools.ietf.org/html/rfc7628
  *
- * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1996-2022, PostgreSQL Global Development Group
  * Portions Copyright (c) 1994, Regents of the University of California
  *
- * src/backend/libpq/auth-oauth.c
+ * contrib/oauth/oauth.c
  *
  *-------------------------------------------------------------------------
  */
+
 #include "postgres.h"
 
 #include <unistd.h>
 #include <fcntl.h>
 
 #include "common/oauth-common.h"
+#include "fmgr.h"
 #include "lib/stringinfo.h"
 #include "libpq/auth.h"
 #include "libpq/hba.h"
-#include "libpq/oauth.h"
 #include "libpq/sasl.h"
 #include "storage/fd.h"
+#include "utils/guc.h"
+
+PG_MODULE_MAGIC;
+
+void _PG_init(void);
 
 /* GUC */
-char *oauth_validator_command;
+static char *oauth_validator_command;
 
 static void  oauth_get_mechanisms(Port *port, StringInfo buf);
 static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
@@ -35,7 +41,7 @@ static int   oauth_exchange(void *opaq, const char *input, int inputlen,
 							char **output, int *outputlen, const char **logdetail);
 
 /* Mechanism declaration */
-const pg_be_sasl_mech pg_be_oauth_mech = {
+static const pg_be_sasl_mech oauth_mech = {
 	oauth_get_mechanisms,
 	oauth_init,
 	oauth_exchange,
@@ -795,3 +801,50 @@ username_ok_for_shell(const char *username)
 
 	return true;
 }
+
+/*
+ * Authentication check callback for RegisterAuthProvider(): delegates to
+ * CheckSASLAuth() with the OAUTHBEARER mechanism declared above (oauth_mech).
+ */
+static int
+CheckOAuth(Port *port)
+{
+	return CheckSASLAuth(&oauth_mech, port, NULL, NULL);
+}
+
+/*
+ * Error message callback for RegisterAuthProvider().
+ *
+ * Returns an untranslated message template, like the gettext_noop() strings
+ * auth_failed() uses for the built-in methods, and leaves substitution of the
+ * user name for %s to the caller.  Formatting port->user_name into the result
+ * here would let a '%' in the client-supplied user name be interpreted as a
+ * format directive, and would keep the message from being translated.
+ *
+ * NOTE(review): this assumes auth_failed() formats the hook's result with
+ * port->user_name, as it does for the built-in methods; confirm against the
+ * CustomAuthenticationError_hook contract.
+ */
+static const char *
+OAuthError(Port *port)
+{
+	return gettext_noop("OAuth bearer authentication failed for user \"%s\"");
+}
+
+/*
+ * Module load callback: registers the "oauth" authentication provider and
+ * defines the oauth.validator_command GUC.
+ */
+void
+_PG_init(void)
+{
+	RegisterAuthProvider("oauth", CheckOAuth, OAuthError);
+
+	DefineCustomStringVariable("oauth.validator_command",
+							   gettext_noop("Command to validate OAuth v2 bearer tokens."),
+							   NULL,
+							   &oauth_validator_command,
+							   "",
+							   PGC_SIGHUP, GUC_SUPERUSER_ONLY,
+							   NULL, NULL, NULL);
+}
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 98eb2a8242..6d385fd6a4 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,7 +15,6 @@ include $(top_builddir)/src/Makefile.global
 # be-fsstubs is here for historical reasons, probably belongs elsewhere
 
 OBJS = \
-	auth-oauth.o \
 	auth-sasl.o \
 	auth-scram.o \
 	auth.o \
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 5c30904e2b..3533b0bc50 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -30,7 +30,6 @@
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
 #include "libpq/libpq.h"
-#include "libpq/oauth.h"
 #include "libpq/pqformat.h"
 #include "libpq/sasl.h"
 #include "libpq/scram.h"
@@ -303,9 +302,6 @@ auth_failed(Port *port, int status, const char *logdetail)
 		case uaRADIUS:
 			errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
 			break;
-		case uaOAuth:
-			errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
-			break;
 		case uaCustom:
 			if (CustomAuthenticationError_hook)
 				errstr = CustomAuthenticationError_hook(port);
@@ -631,9 +627,6 @@ ClientAuthentication(Port *port)
 		case uaTrust:
 			status = STATUS_OK;
 			break;
-		case uaOAuth:
-			status = CheckSASLAuth(&pg_be_oauth_mech, port, NULL, NULL);
-			break;
 		case uaCustom:
 			if (CustomAuthenticationCheck_hook)
 				status = CustomAuthenticationCheck_hook(port);
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index f7f3059927..fb51c53cc0 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -136,7 +136,6 @@ static const char *const UserAuthName[] =
 	"radius",
 	"custom",
 	"peer",
-	"oauth",
 };
 
 
@@ -1401,8 +1400,6 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
 #endif
 	else if (strcmp(token->string, "radius") == 0)
 		parsedline->auth_method = uaRADIUS;
-	else if (strcmp(token->string, "oauth") == 0)
-		parsedline->auth_method = uaOAuth;
 	else if (strcmp(token->string, "custom") == 0)
 		parsedline->auth_method = uaCustom;
 	else
@@ -1731,9 +1728,9 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 			hbaline->auth_method != uaPeer &&
 			hbaline->auth_method != uaGSS &&
 			hbaline->auth_method != uaSSPI &&
-			hbaline->auth_method != uaOAuth &&
-			hbaline->auth_method != uaCert)
-			INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, oauth, and cert"));
+			hbaline->auth_method != uaCert &&
+			hbaline->auth_method != uaCustom)
+			INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, and custom"));
 		hbaline->usermap = pstrdup(val);
 	}
 	else if (strcmp(name, "clientcert") == 0)
@@ -2119,19 +2116,30 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 	}
 	else if (strcmp(name, "issuer") == 0)
 	{
-		if (hbaline->auth_method != uaOAuth)
+		/*
+		 * The OAuth options are only valid for the "oauth" custom provider.
+		 * custom_provider_name may be NULL here, in which case the provider
+		 * is not checked; only a provider known not to be "oauth" is rejected.
+		 */
+		if (hbaline->auth_method != uaCustom
+			|| (custom_provider_name != NULL
+				&& strcmp(custom_provider_name, "oauth") != 0))
 			INVALID_AUTH_OPTION("issuer", gettext_noop("oauth"));
 		hbaline->oauth_issuer = pstrdup(val);
 	}
 	else if (strcmp(name, "scope") == 0)
 	{
-		if (hbaline->auth_method != uaOAuth)
+		if (hbaline->auth_method != uaCustom
+			|| (custom_provider_name != NULL
+				&& strcmp(custom_provider_name, "oauth") != 0))
 			INVALID_AUTH_OPTION("scope", gettext_noop("oauth"));
 		hbaline->oauth_scope = pstrdup(val);
 	}
 	else if (strcmp(name, "trust_validator_authz") == 0)
 	{
-		if (hbaline->auth_method != uaOAuth)
+		if (hbaline->auth_method != uaCustom
+			|| (custom_provider_name != NULL
+				&& strcmp(custom_provider_name, "oauth") != 0))
 			INVALID_AUTH_OPTION("trust_validator_authz", gettext_noop("oauth"));
 		if (strcmp(val, "1") == 0)
 			hbaline->oauth_skip_usermap = true;
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index 791c7c83df..1e3650184b 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -58,7 +58,6 @@
 #include "libpq/auth.h"
 #include "libpq/libpq.h"
 #include "libpq/pqformat.h"
-#include "libpq/oauth.h"
 #include "miscadmin.h"
 #include "optimizer/cost.h"
 #include "optimizer/geqo.h"
@@ -4663,17 +4662,6 @@ static struct config_string ConfigureNamesString[] =
 		check_backtrace_functions, assign_backtrace_functions, NULL
 	},
 
-	{
-		{"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH,
-			gettext_noop("Command to validate OAuth v2 bearer tokens."),
-			NULL,
-			GUC_SUPERUSER_ONLY
-		},
-		&oauth_validator_command,
-		"",
-		NULL, NULL, NULL
-	},
-
 	/* End-of-list marker */
 	{
 		{NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index d46c2108eb..0c6a7dd823 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -40,8 +40,7 @@ typedef enum UserAuth
 	uaRADIUS,
 	uaCustom,
 	uaPeer,
-	uaOAuth
-#define USER_AUTH_LAST uaOAuth	/* Must be last value of this enum */
+#define USER_AUTH_LAST uaPeer	/* Must be last value of this enum */
 } UserAuth;
 
 /*
diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h
deleted file mode 100644
index 870e426af1..0000000000
--- a/src/include/libpq/oauth.h
+++ /dev/null
@@ -1,24 +0,0 @@
-/*-------------------------------------------------------------------------
- *
- * oauth.h
- *	  Interface to libpq/auth-oauth.c
- *
- * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
- * Portions Copyright (c) 1994, Regents of the University of California
- *
- * src/include/libpq/oauth.h
- *
- *-------------------------------------------------------------------------
- */
-#ifndef PG_OAUTH_H
-#define PG_OAUTH_H
-
-#include "libpq/libpq-be.h"
-#include "libpq/sasl.h"
-
-extern char *oauth_validator_command;
-
-/* Implementation */
-extern const pg_be_sasl_mech pg_be_oauth_mech;
-
-#endif /* PG_OAUTH_H */
diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py
index cb5ca7fa23..07fc25edc2 100644
--- a/src/test/python/server/test_oauth.py
+++ b/src/test/python/server/test_oauth.py
@@ -103,9 +103,9 @@ def oauth_ctx():
 
     ctx = Context()
     hba_lines = (
-        f'host {ctx.dbname} {ctx.map_user}   samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
-        f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
-        f'host {ctx.dbname} all              samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
+        f'host {ctx.dbname} {ctx.map_user}   samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
+        f'host {ctx.dbname} {ctx.authz_user} samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
+        f'host {ctx.dbname} all              samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
     )
     ident_lines = (r"oauth /^(.*)@example\.com$ \1",)
 
@@ -126,12 +126,12 @@ def oauth_ctx():
         c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user))
         c.execute(sql.SQL("CREATE DATABASE {};").format(dbname))
 
-        # Make this test script the server's oauth_validator.
+        # Install validate_bearer.py as the server's oauth.validator_command.
         path = pathlib.Path(__file__).parent / "validate_bearer.py"
         path = str(path.absolute())
 
         cmd = f"{shlex.quote(path)} {SHARED_MEM_NAME} <&%f"
-        c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+        c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (cmd,))
 
         # Replace pg_hba and pg_ident.
         c.execute("SHOW hba_file;")
@@ -149,7 +149,7 @@ def oauth_ctx():
         # Put things back the way they were.
         c.execute("SELECT pg_reload_conf();")
 
-        c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+        c.execute("ALTER SYSTEM RESET oauth.validator_command;")
         c.execute(sql.SQL("DROP DATABASE {};").format(dbname))
         c.execute(sql.SQL("DROP ROLE {};").format(authz_user))
         c.execute(sql.SQL("DROP ROLE {};").format(map_user))
@@ -930,7 +930,7 @@ def test_oauth_empty_initial_response(conn, oauth_ctx, bearer_token):
 def set_validator():
     """
     A per-test fixture that allows a test to override the setting of
-    oauth_validator_command for the cluster. The setting will be reverted during
+    oauth.validator_command for the cluster. The setting will be reverted during
     teardown.
 
     Passing None will perform an ALTER SYSTEM RESET.
@@ -942,17 +942,17 @@ def set_validator():
         c = conn.cursor()
 
         # Save the previous value.
-        c.execute("SHOW oauth_validator_command;")
+        c.execute("SHOW oauth.validator_command;")
         prev_cmd = c.fetchone()[0]
 
         def setter(cmd):
-            c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+            c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (cmd,))
             c.execute("SELECT pg_reload_conf();")
 
         yield setter
 
         # Restore the previous value.
-        c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (prev_cmd,))
+        c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (prev_cmd,))
         c.execute("SELECT pg_reload_conf();")
 
 
 
-- 
2.25.1

