From f520c08e1aee5239051e304c8a8faf5cb25bdbf2 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Fri, 25 Mar 2022 16:35:30 -0700
Subject: [PATCH v4 10/10] contrib/oauth: switch to pluggable auth API

Move the core server implementation to contrib/oauth as a pluggable
provider, using the RegisterAuthProvider() API. The core GUC
oauth_validator_command is replaced by the custom GUC
oauth.validator_command, and the HBA options (issuer, scope,
trust_validator_authz) are now validated through the provider's option
hook. The Python tests have been updated for the new HBA syntax
("custom provider=oauth") and the renamed GUC.

One server modification remains: allowing custom SASL mechanisms to
declare their own maximum message length.

This patch is optional; you can apply/revert it to compare the two
approaches.
---
 contrib/oauth/Makefile                        | 16 ++++
 .../auth-oauth.c => contrib/oauth/oauth.c     | 88 ++++++++++++++++---
 src/backend/libpq/Makefile                    |  1 -
 src/backend/libpq/auth.c                      |  7 --
 src/backend/libpq/hba.c                       | 27 +-----
 src/backend/utils/misc/guc.c                  | 12 ---
 src/include/libpq/hba.h                       |  6 +-
 src/include/libpq/oauth.h                     | 24 -----
 src/test/python/README                        |  3 +-
 src/test/python/server/test_oauth.py          | 20 ++---
 10 files changed, 104 insertions(+), 100 deletions(-)
 create mode 100644 contrib/oauth/Makefile
 rename src/backend/libpq/auth-oauth.c => contrib/oauth/oauth.c (90%)
 delete mode 100644 src/include/libpq/oauth.h

diff --git a/contrib/oauth/Makefile b/contrib/oauth/Makefile
new file mode 100644
index 0000000000..880bc1fef3
--- /dev/null
+++ b/contrib/oauth/Makefile
@@ -0,0 +1,19 @@
+# contrib/oauth/Makefile
+
+MODULE_big = oauth
+# $(WIN32RES) builds a Windows version resource from PGFILEDESC.
+OBJS = \
+	$(WIN32RES) \
+	oauth.o
+PGFILEDESC = "oauth - auth provider supporting OAuth 2.0/OIDC"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = contrib/oauth
+top_builddir = ../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/src/backend/libpq/auth-oauth.c b/contrib/oauth/oauth.c
similarity index 90%
rename from src/backend/libpq/auth-oauth.c
rename to contrib/oauth/oauth.c
index c1232a31a0..e83f3c5d99 100644
--- a/src/backend/libpq/auth-oauth.c
+++ b/contrib/oauth/oauth.c
@@ -1,33 +1,39 @@
 /*-------------------------------------------------------------------------
  *
- * auth-oauth.c
+ * oauth.c
  *	  Server-side implementation of the SASL OAUTHBEARER mechanism.
  *
  * See the following RFC for more details:
  * - RFC 7628: https://tools.ietf.org/html/rfc7628
  *
- * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1996-2022, PostgreSQL Global Development Group
  * Portions Copyright (c) 1994, Regents of the University of California
  *
- * src/backend/libpq/auth-oauth.c
+ * contrib/oauth/oauth.c
  *
  *-------------------------------------------------------------------------
  */
+
 #include "postgres.h"
 
 #include <unistd.h>
 #include <fcntl.h>
 
 #include "common/oauth-common.h"
+#include "fmgr.h"
 #include "lib/stringinfo.h"
 #include "libpq/auth.h"
 #include "libpq/hba.h"
-#include "libpq/oauth.h"
 #include "libpq/sasl.h"
 #include "storage/fd.h"
+#include "utils/guc.h"
+
+PG_MODULE_MAGIC;
+
+void _PG_init(void);
 
 /* GUC */
-char *oauth_validator_command;
+static char *oauth_validator_command;
 
 static void  oauth_get_mechanisms(Port *port, StringInfo buf);
 static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
@@ -35,7 +41,7 @@ static int   oauth_exchange(void *opaq, const char *input, int inputlen,
 							char **output, int *outputlen, const char **logdetail);
 
 /* Mechanism declaration */
-const pg_be_sasl_mech pg_be_oauth_mech = {
+static const pg_be_sasl_mech oauth_mech = {
 	oauth_get_mechanisms,
 	oauth_init,
 	oauth_exchange,
@@ -57,12 +63,13 @@ struct oauth_ctx
 	Port	   *port;
 	const char *issuer;
 	const char *scope;
+	bool		skip_usermap;
 };
 
 static char *sanitize_char(char c);
 static char *parse_kvpairs_for_auth(char **input);
 static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
-static bool validate(Port *port, const char *auth, const char **logdetail);
+static bool validate(struct oauth_ctx *ctx, const char *auth, const char **logdetail);
 static bool run_validator_command(Port *port, const char *token);
 static bool check_exit(FILE **fh, const char *command);
 static bool unset_cloexec(int fd);
@@ -84,6 +91,7 @@ static void *
 oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
 {
 	struct oauth_ctx *ctx;
+	ListCell   *lc;
 
 	if (strcmp(selected_mech, OAUTHBEARER_NAME))
 		ereport(ERROR,
@@ -96,8 +104,21 @@ oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
 	ctx->port = port;
 
 	Assert(port->hba);
-	ctx->issuer = port->hba->oauth_issuer;
-	ctx->scope = port->hba->oauth_scope;
+
+	foreach(lc, port->hba->custom_auth_options)
+	{
+		CustomOption *option = lfirst(lc);
+
+		if (strcmp(option->name, "issuer") == 0)
+			ctx->issuer = option->value;
+		else if (strcmp(option->name, "scope") == 0)
+			ctx->scope = option->value;
+		else if (strcmp(option->name, "trust_validator_authz") == 0)
+		{
+			/* As with issuer/scope, the last occurrence on the line wins. */
+			ctx->skip_usermap = (strcmp(option->value, "1") == 0);
+		}
+	}
 
 	return ctx;
 }
@@ -248,7 +269,7 @@ oauth_exchange(void *opaq, const char *input, int inputlen,
 				 errmsg("malformed OAUTHBEARER message"),
 				 errdetail("Message contains additional data after the final terminator.")));
 
-	if (!validate(ctx->port, auth, logdetail))
+	if (!validate(ctx, auth, logdetail))
 	{
 		generate_error_response(ctx, output, outputlen);
 
@@ -415,12 +436,13 @@ generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
 }
 
 static bool
-validate(Port *port, const char *auth, const char **logdetail)
+validate(struct oauth_ctx *ctx, const char *auth, const char **logdetail)
 {
 	static const char * const b64_set = "abcdefghijklmnopqrstuvwxyz"
 										"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
 										"0123456789-._~+/";
 
+	Port	   *port = ctx->port;
 	const char *token;
 	size_t		span;
 	int			ret;
@@ -497,7 +519,7 @@ validate(Port *port, const char *auth, const char **logdetail)
 	if (!run_validator_command(port, token))
 		return false;
 
-	if (port->hba->oauth_skip_usermap)
+	if (ctx->skip_usermap)
 	{
 		/*
 		 * If the validator is our authorization authority, we're done.
@@ -795,3 +817,64 @@ username_ok_for_shell(const char *username)
 
 	return true;
 }
+
+/*
+ * Authentication entry point for HBA lines using "custom provider=oauth":
+ * runs a SASL exchange with the OAUTHBEARER mechanism declared in oauth_mech.
+ */
+static int
+CheckOAuth(Port *port)
+{
+	return CheckSASLAuth(&oauth_mech, port, NULL, NULL);
+}
+
+/*
+ * Returns the failure message for this provider.
+ *
+ * This is deliberately an untranslated format string with %s standing for
+ * the user name, like the built-in cases in auth_failed(), rather than text
+ * with port->user_name already interpolated.  The user name comes from the
+ * client; if the caller uses the result as a format string (as it does for
+ * the built-in messages), a '%' in the name would be parsed as a conversion.
+ * NOTE(review): confirm auth_failed() passes user_name for custom providers.
+ */
+static const char *
+OAuthError(Port *port)
+{
+	return gettext_noop("OAuth bearer authentication failed for user \"%s\"");
+}
+
+/*
+ * Validates HBA option names for this provider: accepts exactly the options
+ * that oauth_init() reads back from the HBA line and returns false for any
+ * other name.  Option values are not checked here.
+ */
+static bool
+OAuthCheckOption(char *name, char *val,
+				 struct HbaLine *hbaline, char **errmsg)
+{
+	return strcmp(name, "issuer") == 0 ||
+		strcmp(name, "scope") == 0 ||
+		strcmp(name, "trust_validator_authz") == 0;
+}
+
+/*
+ * Module load callback: registers the "oauth" authentication provider and
+ * defines the oauth.validator_command setting.
+ */
+void
+_PG_init(void)
+{
+	RegisterAuthProvider("oauth", CheckOAuth, OAuthError, OAuthCheckOption);
+
+	DefineCustomStringVariable("oauth.validator_command",
+							   gettext_noop("Command to validate OAuth v2 bearer tokens."),
+							   NULL,
+							   &oauth_validator_command,
+							   "",
+							   PGC_SIGHUP, GUC_SUPERUSER_ONLY,
+							   NULL, NULL, NULL);
+
+	/* Warn about and reject unknown oauth.* settings (e.g. typos). */
+	MarkGUCPrefixReserved("oauth");
+}
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 98eb2a8242..6d385fd6a4 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,7 +15,6 @@ include $(top_builddir)/src/Makefile.global
 # be-fsstubs is here for historical reasons, probably belongs elsewhere
 
 OBJS = \
-	auth-oauth.o \
 	auth-sasl.o \
 	auth-scram.o \
 	auth.o \
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 17042d84ad..4a8a63922a 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -30,7 +30,6 @@
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
 #include "libpq/libpq.h"
-#include "libpq/oauth.h"
 #include "libpq/pqformat.h"
 #include "libpq/sasl.h"
 #include "libpq/scram.h"
@@ -299,9 +298,6 @@ auth_failed(Port *port, int status, const char *logdetail)
 		case uaRADIUS:
 			errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
 			break;
-		case uaOAuth:
-			errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
-			break;
 		case uaCustom:
 			{
 				CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
@@ -630,9 +626,6 @@ ClientAuthentication(Port *port)
 		case uaTrust:
 			status = STATUS_OK;
 			break;
-		case uaOAuth:
-			status = CheckSASLAuth(&pg_be_oauth_mech, port, NULL, NULL);
-			break;
 		case uaCustom:
 			{
 				CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index cd3b1cc140..6bf986d5b3 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -137,7 +137,6 @@ static const char *const UserAuthName[] =
 	"radius",
 	"custom",
 	"peer",
-	"oauth",
 };
 
 
@@ -1402,8 +1401,6 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
 #endif
 	else if (strcmp(token->string, "radius") == 0)
 		parsedline->auth_method = uaRADIUS;
-	else if (strcmp(token->string, "oauth") == 0)
-		parsedline->auth_method = uaOAuth;
 	else if (strcmp(token->string, "custom") == 0)
 		parsedline->auth_method = uaCustom;
 	else
@@ -1733,9 +1730,8 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 			hbaline->auth_method != uaGSS &&
 			hbaline->auth_method != uaSSPI &&
 			hbaline->auth_method != uaCert &&
-			hbaline->auth_method != uaOAuth &&
 			hbaline->auth_method != uaCustom)
-			INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, oauth, and custom"));
+			INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, and custom"));
 		hbaline->usermap = pstrdup(val);
 	}
 	else if (strcmp(name, "clientcert") == 0)
@@ -2119,27 +2115,6 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 		hbaline->radiusidentifiers = parsed_identifiers;
 		hbaline->radiusidentifiers_s = pstrdup(val);
 	}
-	else if (strcmp(name, "issuer") == 0)
-	{
-		if (hbaline->auth_method != uaOAuth)
-			INVALID_AUTH_OPTION("issuer", gettext_noop("oauth"));
-		hbaline->oauth_issuer = pstrdup(val);
-	}
-	else if (strcmp(name, "scope") == 0)
-	{
-		if (hbaline->auth_method != uaOAuth)
-			INVALID_AUTH_OPTION("scope", gettext_noop("oauth"));
-		hbaline->oauth_scope = pstrdup(val);
-	}
-	else if (strcmp(name, "trust_validator_authz") == 0)
-	{
-		if (hbaline->auth_method != uaOAuth)
-			INVALID_AUTH_OPTION("trust_validator_authz", gettext_noop("oauth"));
-		if (strcmp(val, "1") == 0)
-			hbaline->oauth_skip_usermap = true;
-		else
-			hbaline->oauth_skip_usermap = false;
-	}
 	else if (strcmp(name, "provider") == 0)
 	{
 		REQUIRE_AUTH_OPTION(uaCustom, "provider", "custom");
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index 9a5b2aa496..f70f7f5c01 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -59,7 +59,6 @@
 #include "libpq/auth.h"
 #include "libpq/libpq.h"
 #include "libpq/pqformat.h"
-#include "libpq/oauth.h"
 #include "miscadmin.h"
 #include "optimizer/cost.h"
 #include "optimizer/geqo.h"
@@ -4667,17 +4666,6 @@ static struct config_string ConfigureNamesString[] =
 		check_backtrace_functions, assign_backtrace_functions, NULL
 	},
 
-	{
-		{"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH,
-			gettext_noop("Command to validate OAuth v2 bearer tokens."),
-			NULL,
-			GUC_SUPERUSER_ONLY
-		},
-		&oauth_validator_command,
-		"",
-		NULL, NULL, NULL
-	},
-
 	/* End-of-list marker */
 	{
 		{NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index e405103a2e..bbc94363cb 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -40,8 +40,7 @@ typedef enum UserAuth
 	uaRADIUS,
 	uaCustom,
 	uaPeer,
-	uaOAuth
-#define USER_AUTH_LAST uaOAuth	/* Must be last value of this enum */
+#define USER_AUTH_LAST uaPeer	/* Must be last value of this enum */
 } UserAuth;
 
 /*
@@ -129,9 +128,6 @@ typedef struct HbaLine
 	char	   *radiusidentifiers_s;
 	List	   *radiusports;
 	char	   *radiusports_s;
-	char	   *oauth_issuer;
-	char	   *oauth_scope;
-	bool		oauth_skip_usermap;
 	char	   *custom_provider;
 	List	   *custom_auth_options;
 } HbaLine;
diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h
deleted file mode 100644
index 870e426af1..0000000000
--- a/src/include/libpq/oauth.h
+++ /dev/null
@@ -1,24 +0,0 @@
-/*-------------------------------------------------------------------------
- *
- * oauth.h
- *	  Interface to libpq/auth-oauth.c
- *
- * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
- * Portions Copyright (c) 1994, Regents of the University of California
- *
- * src/include/libpq/oauth.h
- *
- *-------------------------------------------------------------------------
- */
-#ifndef PG_OAUTH_H
-#define PG_OAUTH_H
-
-#include "libpq/libpq-be.h"
-#include "libpq/sasl.h"
-
-extern char *oauth_validator_command;
-
-/* Implementation */
-extern const pg_be_sasl_mech pg_be_oauth_mech;
-
-#endif /* PG_OAUTH_H */
diff --git a/src/test/python/README b/src/test/python/README
index 0bda582c4b..0fbc1046cf 100644
--- a/src/test/python/README
+++ b/src/test/python/README
@@ -13,7 +13,8 @@ but you can adjust as needed for your setup.
 
 ## Requirements
 
-A supported version (3.6+) of Python.
+- A supported version (3.6+) of Python.
+- The contrib/oauth module must be installed and listed in shared_preload_libraries.
 
 The first run of
 
diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py
index cb5ca7fa23..07fc25edc2 100644
--- a/src/test/python/server/test_oauth.py
+++ b/src/test/python/server/test_oauth.py
@@ -103,9 +103,9 @@ def oauth_ctx():
 
     ctx = Context()
     hba_lines = (
-        f'host {ctx.dbname} {ctx.map_user}   samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
-        f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
-        f'host {ctx.dbname} all              samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
+        f'host {ctx.dbname} {ctx.map_user}   samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
+        f'host {ctx.dbname} {ctx.authz_user} samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
+        f'host {ctx.dbname} all              samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
     )
     ident_lines = (r"oauth /^(.*)@example\.com$ \1",)
 
@@ -126,12 +126,12 @@ def oauth_ctx():
         c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user))
         c.execute(sql.SQL("CREATE DATABASE {};").format(dbname))
 
-        # Make this test script the server's oauth_validator.
+        # Make this test script the server's oauth validator.
         path = pathlib.Path(__file__).parent / "validate_bearer.py"
         path = str(path.absolute())
 
         cmd = f"{shlex.quote(path)} {SHARED_MEM_NAME} <&%f"
-        c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+        c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (cmd,))
 
         # Replace pg_hba and pg_ident.
         c.execute("SHOW hba_file;")
@@ -149,7 +149,7 @@ def oauth_ctx():
         # Put things back the way they were.
         c.execute("SELECT pg_reload_conf();")
 
-        c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+        c.execute("ALTER SYSTEM RESET oauth.validator_command;")
         c.execute(sql.SQL("DROP DATABASE {};").format(dbname))
         c.execute(sql.SQL("DROP ROLE {};").format(authz_user))
         c.execute(sql.SQL("DROP ROLE {};").format(map_user))
@@ -930,7 +930,7 @@ def test_oauth_empty_initial_response(conn, oauth_ctx, bearer_token):
 def set_validator():
     """
     A per-test fixture that allows a test to override the setting of
-    oauth_validator_command for the cluster. The setting will be reverted during
+    oauth.validator_command for the cluster. The setting will be reverted during
     teardown.
 
     Passing None will perform an ALTER SYSTEM RESET.
@@ -942,17 +942,17 @@ def set_validator():
         c = conn.cursor()
 
         # Save the previous value.
-        c.execute("SHOW oauth_validator_command;")
+        c.execute("SHOW oauth.validator_command;")
         prev_cmd = c.fetchone()[0]
 
         def setter(cmd):
-            c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+            c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (cmd,))
             c.execute("SELECT pg_reload_conf();")
 
         yield setter
 
         # Restore the previous value.
-        c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (prev_cmd,))
+        c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (prev_cmd,))
         c.execute("SELECT pg_reload_conf();")
 
 
-- 
2.25.1

