[PoC] Federated Authn/z with OAUTHBEARER
Hi all,
We've been working on ways to expand the list of third-party auth
methods that Postgres provides. Some example use cases might be "I want
to let anyone with a Google account read this table" or "let anyone who
belongs to this GitHub organization connect as a superuser".
Attached is a proof of concept that implements pieces of OAuth 2.0
federated authorization, via the OAUTHBEARER SASL mechanism from RFC
7628 [1]. Currently, only Linux is supported due to some ugly hacks in
the backend.
The architecture can support the following use cases, as long as your
OAuth issuer of choice implements the necessary specs, and you know how
to write a validator for your issuer's bearer tokens:
- Authentication only, where an external validator uses the bearer
token to determine the end user's identity, and Postgres decides
whether that user ID is authorized to connect via the standard pg_ident
user mapping.
- Authorization only, where the validator uses the bearer token to
determine the allowed roles for the end user, and then checks to make
sure that the connection's role is one of those. This bypasses pg_ident
and allows pseudonymous connections, where Postgres doesn't care who
you are as long as the token proves you're allowed to assume the role
you want.
- A combination, where the validator provides both an authn_id (for
later audits of database access) and an authorization decision based on
the bearer token and role provided.
It looks kinda like this during use:
$ psql 'host=example.org oauth_client_id=f02c6361-0635-...'
Visit https://oauth.example.org/login and enter the code: FPQ2-M4BG
= Quickstart =
For anyone who likes building and seeing green tests ASAP.
Prerequisite software:
- iddawc v0.9.9 [2], library and dev headers, for client support
- Python 3, for the test suite only
(Some newer distributions have dev packages for iddawc, but mine did
not.)
Configure using --with-oauth. (If you've installed iddawc into a
non-standard location, be sure to use --with-includes and
--with-libraries, and make sure either rpath or LD_LIBRARY_PATH will
get you what you need.) Install as usual.
To run the test suite, make sure the contrib/authn_id extension is
installed, then init and start your dev cluster. No other configuration
is required; the test will do it for you. Switch to the src/test/python
directory, point your PG* envvars to a superuser connection on the
cluster (so that a "bare" psql will connect automatically), and run
`make installcheck`.
= Production Setup =
(but don't use this in production, please)
Actually setting up a "real" system requires knowing the specifics of
your third-party issuer of choice. Your issuer MUST implement OpenID
Discovery and the OAuth Device Authorization flow! Seriously, check
this before spending a lot of time writing a validator against an
issuer that can't actually talk to libpq.
The broad strokes are as follows:
1. Register a new public client with your issuer to get an OAuth client
ID for libpq. You'll use this as the oauth_client_id in the connection
string. (If your issuer doesn't support public clients and gives you a
client secret, you can use the oauth_client_secret connection parameter
to provide that too.)
The client you register must be able to use a device authorization
flow; some issuers require additional setup for that.
2. Set up your HBA with the 'oauth' auth method, and set the 'issuer'
and 'scope' options. 'issuer' is the base URL identifying your third-
party issuer (for example, https://accounts.google.com), and 'scope' is
the set of OAuth scopes that the client and server will need to
authenticate and/or authorize the user (e.g. "openid email").
So a sample HBA line might look like
host all all samehost oauth issuer="https://accounts.google.com" scope="openid email"
3. In postgresql.conf, set up an oauth_validator_command that's capable
of verifying bearer tokens and implements the validator protocol. This
is the hardest part. See below.
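For example, in postgresql.conf (the path here is purely hypothetical;
see the commit message in 0005 for how the command is actually
invoked):

```
oauth_validator_command = '/usr/local/bin/validate_bearer_token'
```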
= Design =
On the client side, I've implemented the Device Authorization flow (RFC
8628 [3]). What this means in practice is that libpq reaches out to a
third-party issuer (e.g. Google, Azure, etc.), identifies itself with a
client ID, and requests permission to act on behalf of the end user.
The issuer responds with a login URL and a one-time code, which libpq
presents to the user using the notice hook. The end user then navigates
to that URL, presents their code, authenticates to the issuer, and
grants permission for libpq to retrieve a bearer token. libpq grabs a
token and sends it to the server for verification.
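The exchange libpq drives can be sketched roughly like this (a
simplified illustration of RFC 8628, not the patch's actual
iddawc-based code; the `post` callback and the endpoint parameters are
stand-ins for a real HTTP client and the issuer's discovered
endpoints):

```python
import time

def request_token_device_flow(post, client_id, device_endpoint,
                              token_endpoint, scope):
    """Obtain a bearer token via the RFC 8628 device flow.

    post(url, data) performs an HTTP POST and returns the decoded JSON
    response as a dict; it's injected so the transport can be swapped.
    """
    # Step 1: ask the issuer for a device code and a user code.
    grant = post(device_endpoint, {"client_id": client_id, "scope": scope})

    # Step 2: show the user where to go and what to enter
    # (libpq does this via the notice hook).
    print("Visit %s and enter the code: %s"
          % (grant["verification_uri"], grant["user_code"]))

    # Step 3: poll the token endpoint until the user finishes logging in.
    interval = grant.get("interval", 5)
    while True:
        resp = post(token_endpoint, {
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
            "device_code": grant["device_code"],
            "client_id": client_id,
        })
        if "access_token" in resp:
            return resp["access_token"]     # the bearer token
        if resp.get("error") == "authorization_pending":
            time.sleep(interval)            # user hasn't finished yet
            continue
        if resp.get("error") == "slow_down":
            interval += 5                   # issuer asked us to back off
            time.sleep(interval)
            continue
        raise RuntimeError(resp.get("error", "unknown error"))
```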
(The bearer token, in this setup, is essentially a plaintext password,
and you must secure it like you would a plaintext password. The token
has an expiration date and can be explicitly revoked, which makes it
slightly better than a password, but this is still a step backwards
from something like SCRAM with channel binding. There are ways to bind
a bearer token to a client certificate [4], which would mitigate the
risk of token theft -- but your issuer has to support that, and I
haven't found much support in the wild.)
The server side is where things get more difficult for the DBA. The
OAUTHBEARER spec has this to say about the server side implementation:
The server validates the response according to the specification for
the OAuth Access Token Types used.
And here's what the Bearer Token specification [5] says:
This document does not specify the encoding or the contents of the
token; hence, detailed recommendations about the means of
guaranteeing token integrity protection are outside the scope of
this document.
It's the Wild West. Every issuer does their own thing in their own
special way. Some don't really give you a way to introspect information
about a bearer token at all, because they assume that the issuer of the
token and the consumer of the token are essentially the same service.
Some major players provide their own custom libraries, implemented in
your-language-of-choice, to deal with their particular brand of magic.
So I punted and added the oauth_validator_command GUC. A token
validator command reads the bearer token from a file descriptor that's
passed to it, then does whatever magic is necessary to validate that
token and find out who owns it. Optionally, it can look at the role
that's being connected and make sure that the token authorizes the user
to actually use that role. Then it says yea or nay to Postgres, and
optionally tells the server who the user is so that their ID can be
logged and mapped through pg_ident.
(See the commit message in 0005 for a full description of the protocol.
The test suite also has two toy implementations that illustrate the
protocol, but they provide zero security.)
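As a rough illustration of the decision a validator has to make (the
actual fd-based wire protocol is described in 0005's commit message;
the `introspect` callback and its response shape here are hypothetical
stand-ins for whatever issuer-specific token introspection you have
available):

```python
def validate_token(token, role, introspect):
    """Decide whether `token` authorizes a connection as `role`.

    introspect(token) stands in for the issuer-specific magic that maps
    a bearer token to something like {"active": bool, "sub": ...,
    "roles": [...]}; every issuer does this differently.

    Returns (authorized, authn_id).
    """
    info = introspect(token)
    if not info.get("active"):
        return (False, None)            # expired or revoked token
    if role not in info.get("roles", ()):
        return (False, None)            # token doesn't cover this role
    # authn_id can be logged and mapped through pg_ident.
    return (True, info.get("sub"))
```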
This is easily the worst part of the patch, not only because my
implementation is a bad hack on OpenPipeStream(), but because it
balances the security of the entire system on the shoulders of a DBA
who does not have time to read umpteen OAuth specifications cover to
cover. More thought and coding effort is needed here, but I didn't want
to gold-plate a bad design. I'm not sure what alternatives there are
within the rules laid out by OAUTHBEARER. And the system is _extremely_
flexible, in the way that only code that's maintained by somebody else
can be.
= Patchset Roadmap =
The seven patches can be grouped into three:
1. Prep
0001 decouples the SASL code from the SCRAM implementation.
0002 makes it possible to use common/jsonapi from the frontend.
0003 lets the json_errdetail() result be freed, to avoid leaks.
2. OAUTHBEARER Implementation
0004 implements the client with libiddawc.
0005 implements server HBA support and oauth_validator_command.
3. Testing
0006 adds a simple test extension to retrieve the authn_id.
0007 adds the Python test suite I've been developing against.
The first three patches are, hopefully, generally useful outside of
this implementation, and I'll plan to register them in the next
commitfest. The middle two patches are the "interesting" pieces, and
I've split them into client and server for ease of understanding,
though neither is particularly useful without the other.
The last two patches grew out of a test suite that I originally built
to be able to exercise NSS corner cases at the protocol/byte level. It
was incredibly helpful during implementation of this new SASL
mechanism, since I could write the client and server independently of
each other and get high coverage of broken/malicious implementations.
It's based on pytest and Construct, and the Python 3 requirement might
turn some away, but I wanted to include it in case anyone else wanted
to hack on the code. src/test/python/README explains more.
= Thoughts/Reflections =
...in no particular order.
I picked OAuth 2.0 as my first experiment in federated auth mostly
because I was already familiar with pieces of it. I think SAML (via the
SAML20 mechanism, RFC 6595) would be a good companion to this proof of
concept, if there is general interest in federated deployments.
I don't really like the OAUTHBEARER spec, but I'm not sure there's a
better alternative. Everything is left as an exercise for the reader.
It's not particularly extensible. Standard OAuth is built for
authorization, not authentication, and from reading the RFC's history,
it feels like it was a hack to just get something working. New
standards like OpenID Connect have begun to fill in the gaps, but the
SASL mechanisms have not kept up. (The OPENID20 mechanism is, to my
understanding, unrelated/obsolete.) And support for helpful OIDC
features seems to be spotty in the real world.
The iddawc dependency for client-side OAuth was extremely helpful to
develop this proof of concept quickly, but I don't think it would be an
appropriate component to build a real feature on. It's extremely
heavyweight -- it incorporates a huge stack of dependencies, including
a logging framework and a web server, to implement features we would
probably never use -- and it's fairly difficult to debug in practice.
If a device authorization flow were the only thing that libpq needed to
support natively, I think we should just depend on a widely used HTTP
client, like libcurl or neon, and implement the minimum spec directly
against the existing test suite.
There are a huge number of other authorization flows besides Device
Authorization; most would involve libpq automatically opening a web
browser for you. I felt like that wasn't an appropriate thing for a
library to do by default, especially when one of the most important
clients is a command-line application. Perhaps there could be a hook
for applications to be able to override the builtin flow and substitute
their own.
Since bearer tokens are essentially plaintext passwords, the relevant
specs require the use of transport-level protection, and I think it'd
be wise for the client to require TLS to be in place before performing
the initial handshake or sending a token.
Not every OAuth issuer is also an OpenID Discovery provider, so it's
frustrating that OAUTHBEARER (which is purportedly an OAuth 2.0
feature) requires OIDD for real-world implementations. Perhaps we could
hack around this with a data: URI or something.
The client currently performs the OAuth login dance every single time a
connection is made, but a proper OAuth client would cache its tokens to
reuse later, and keep an eye on their expiration times. This would make
daily use a little more like that of Kerberos, but we would have to
design a way to create and secure a token cache on disk.
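A cache along those lines might look roughly like this (purely
illustrative; the file layout and field names are assumptions, not
anything the patch implements):

```python
import json
import os
import time

class TokenCache:
    """Minimal on-disk bearer-token cache, keyed by issuer.

    A cached token is refused once it's within `slack` seconds of its
    expiration time, so we never hand out a token about to expire.
    """
    def __init__(self, path, slack=60):
        self.path = path
        self.slack = slack

    def store(self, issuer, token, expires_at):
        data = self._load()
        data[issuer] = {"token": token, "expires_at": expires_at}
        # 0600: a bearer token must be secured like a plaintext password.
        fd = os.open(self.path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            json.dump(data, f)

    def lookup(self, issuer, now=None):
        entry = self._load().get(issuer)
        now = time.time() if now is None else now
        if entry and entry["expires_at"] - self.slack > now:
            return entry["token"]
        return None                     # missing or (nearly) expired

    def _load(self):
        try:
            with open(self.path) as f:
                return json.load(f)
        except (FileNotFoundError, ValueError):
            return {}
```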
If you've read this far, thank you for your interest, and I hope you
enjoy playing with it!
--Jacob
[1]: https://datatracker.ietf.org/doc/html/rfc7628
[2]: https://github.com/babelouest/iddawc
[3]: https://datatracker.ietf.org/doc/html/rfc8628
[4]: https://datatracker.ietf.org/doc/html/rfc8705
[5]: https://datatracker.ietf.org/doc/html/rfc6750#section-5.2
Attachments:
0001-auth-generalize-SASL-mechanisms.patch
From a6a65b66cc3dc5da7219378dbadb090ff10fd42b Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 13 Apr 2021 10:25:48 -0700
Subject: [PATCH 1/7] auth: generalize SASL mechanisms
Split the SASL logic out from the SCRAM implementation, so that it can
be reused by other mechanisms. New implementations will implement both
a pg_sasl_mech and a pg_be_sasl_mech.
---
src/backend/libpq/auth-scram.c | 34 ++++++++++++++---------
src/backend/libpq/auth.c | 34 ++++++++++++++++-------
src/include/libpq/sasl.h | 34 +++++++++++++++++++++++
src/include/libpq/scram.h | 13 +++------
src/interfaces/libpq/fe-auth-scram.c | 40 +++++++++++++++++++---------
src/interfaces/libpq/fe-auth.c | 16 ++++++++---
src/interfaces/libpq/fe-auth.h | 11 ++------
src/interfaces/libpq/fe-connect.c | 6 +----
src/interfaces/libpq/libpq-int.h | 14 ++++++++++
9 files changed, 139 insertions(+), 63 deletions(-)
create mode 100644 src/include/libpq/sasl.h
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index f9e1026a12..db3ca75a60 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -101,11 +101,25 @@
#include "common/sha2.h"
#include "libpq/auth.h"
#include "libpq/crypt.h"
+#include "libpq/sasl.h"
#include "libpq/scram.h"
#include "miscadmin.h"
#include "utils/builtins.h"
#include "utils/timestamp.h"
+static void scram_get_mechanisms(Port *port, StringInfo buf);
+static void *scram_init(Port *port, const char *selected_mech,
+ const char *shadow_pass);
+static int scram_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, char **logdetail);
+
+/* Mechanism declaration */
+const pg_be_sasl_mech pg_be_scram_mech = {
+ scram_get_mechanisms,
+ scram_init,
+ scram_exchange,
+};
+
/*
* Status data for a SCRAM authentication exchange. This should be kept
* internal to this file.
@@ -170,16 +184,14 @@ static char *sanitize_str(const char *s);
static char *scram_mock_salt(const char *username);
/*
- * pg_be_scram_get_mechanisms
- *
* Get a list of SASL mechanisms that this module supports.
*
* For the convenience of building the FE/BE packet that lists the
* mechanisms, the names are appended to the given StringInfo buffer,
* separated by '\0' bytes.
*/
-void
-pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
+static void
+scram_get_mechanisms(Port *port, StringInfo buf)
{
/*
* Advertise the mechanisms in decreasing order of importance. So the
@@ -199,8 +211,6 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
}
/*
- * pg_be_scram_init
- *
* Initialize a new SCRAM authentication exchange status tracker. This
* needs to be called before doing any exchange. It will be filled later
* after the beginning of the exchange with authentication information.
@@ -215,10 +225,8 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
* an authentication exchange, but it will fail, as if an incorrect password
* was given.
*/
-void *
-pg_be_scram_init(Port *port,
- const char *selected_mech,
- const char *shadow_pass)
+static void *
+scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
{
scram_state *state;
bool got_secret;
@@ -325,9 +333,9 @@ pg_be_scram_init(Port *port,
* string at *logdetail that will be sent to the postmaster log (but not
* the client).
*/
-int
-pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
- char **output, int *outputlen, char **logdetail)
+static int
+scram_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, char **logdetail)
{
scram_state *state = (scram_state *) opaq;
int result;
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 68372fcea8..e20740a7c5 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -26,11 +26,11 @@
#include "commands/user.h"
#include "common/ip.h"
#include "common/md5.h"
-#include "common/scram-common.h"
#include "libpq/auth.h"
#include "libpq/crypt.h"
#include "libpq/libpq.h"
#include "libpq/pqformat.h"
+#include "libpq/sasl.h"
#include "libpq/scram.h"
#include "miscadmin.h"
#include "port/pg_bswap.h"
@@ -51,6 +51,13 @@ static void auth_failed(Port *port, int status, char *logdetail);
static char *recv_password_packet(Port *port);
static void set_authn_id(Port *port, const char *id);
+/*----------------------------------------------------------------
+ * SASL common authentication
+ *----------------------------------------------------------------
+ */
+static int SASL_exchange(const pg_be_sasl_mech *mech, Port *port,
+ char *shadow_pass, char **logdetail);
+
/*----------------------------------------------------------------
* Password-based authentication methods (password, md5, and scram-sha-256)
@@ -912,12 +919,13 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
}
static int
-CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
+SASL_exchange(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
+ char **logdetail)
{
StringInfoData sasl_mechs;
int mtype;
StringInfoData buf;
- void *scram_opaq = NULL;
+ void *opaq = NULL;
char *output = NULL;
int outputlen = 0;
const char *input;
@@ -931,7 +939,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
*/
initStringInfo(&sasl_mechs);
- pg_be_scram_get_mechanisms(port, &sasl_mechs);
+ mech->get_mechanisms(port, &sasl_mechs);
/* Put another '\0' to mark that list is finished. */
appendStringInfoChar(&sasl_mechs, '\0');
@@ -998,7 +1006,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
* This is because we don't want to reveal to an attacker what
* usernames are valid, nor which users have a valid password.
*/
- scram_opaq = pg_be_scram_init(port, selected_mech, shadow_pass);
+ opaq = mech->init(port, selected_mech, shadow_pass);
inputlen = pq_getmsgint(&buf, 4);
if (inputlen == -1)
@@ -1022,12 +1030,11 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
Assert(input == NULL || input[inputlen] == '\0');
/*
- * we pass 'logdetail' as NULL when doing a mock authentication,
- * because we should already have a better error message in that case
+ * Hand the incoming message to the mechanism implementation.
*/
- result = pg_be_scram_exchange(scram_opaq, input, inputlen,
- &output, &outputlen,
- logdetail);
+ result = mech->exchange(opaq, input, inputlen,
+ &output, &outputlen,
+ logdetail);
/* input buffer no longer used */
pfree(buf.data);
@@ -1039,6 +1046,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
*/
elog(DEBUG4, "sending SASL challenge of length %u", outputlen);
+ /* TODO: SASL_EXCHANGE_FAILURE with output is forbidden in SASL */
if (result == SASL_EXCHANGE_SUCCESS)
sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
else
@@ -1057,6 +1065,12 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
return STATUS_OK;
}
+static int
+CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
+{
+ return SASL_exchange(&pg_be_scram_mech, port, shadow_pass, logdetail);
+}
+
/*----------------------------------------------------------------
* GSSAPI authentication system
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
new file mode 100644
index 0000000000..8c9c9983d4
--- /dev/null
+++ b/src/include/libpq/sasl.h
@@ -0,0 +1,34 @@
+/*-------------------------------------------------------------------------
+ *
+ * sasl.h
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/libpq/sasl.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef PG_SASL_H
+#define PG_SASL_H
+
+#include "libpq/libpq-be.h"
+
+/* Status codes for message exchange */
+#define SASL_EXCHANGE_CONTINUE 0
+#define SASL_EXCHANGE_SUCCESS 1
+#define SASL_EXCHANGE_FAILURE 2
+
+/* Backend mechanism API */
+typedef void (*pg_be_sasl_mechanism_func)(Port *, StringInfo);
+typedef void *(*pg_be_sasl_init_func)(Port *, const char *, const char *);
+typedef int (*pg_be_sasl_exchange_func)(void *, const char *, int, char **, int *, char **);
+
+typedef struct
+{
+ pg_be_sasl_mechanism_func get_mechanisms;
+ pg_be_sasl_init_func init;
+ pg_be_sasl_exchange_func exchange;
+} pg_be_sasl_mech;
+
+#endif /* PG_SASL_H */
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index 2c879150da..9e4540bde3 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -15,17 +15,10 @@
#include "lib/stringinfo.h"
#include "libpq/libpq-be.h"
+#include "libpq/sasl.h"
-/* Status codes for message exchange */
-#define SASL_EXCHANGE_CONTINUE 0
-#define SASL_EXCHANGE_SUCCESS 1
-#define SASL_EXCHANGE_FAILURE 2
-
-/* Routines dedicated to authentication */
-extern void pg_be_scram_get_mechanisms(Port *port, StringInfo buf);
-extern void *pg_be_scram_init(Port *port, const char *selected_mech, const char *shadow_pass);
-extern int pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
- char **output, int *outputlen, char **logdetail);
+/* Implementation */
+extern const pg_be_sasl_mech pg_be_scram_mech;
/* Routines to handle and check SCRAM-SHA-256 secret */
extern char *pg_be_scram_build_secret(const char *password);
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index 5881386e37..04d5703d89 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -21,6 +21,22 @@
#include "fe-auth.h"
+/* The exported SCRAM callback mechanism. */
+static void *scram_init(PGconn *conn, const char *password,
+ const char *sasl_mechanism);
+static void scram_exchange(void *opaq, char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success);
+static bool scram_channel_bound(void *opaq);
+static void scram_free(void *opaq);
+
+const pg_sasl_mech pg_scram_mech = {
+ scram_init,
+ scram_exchange,
+ scram_channel_bound,
+ scram_free,
+};
+
/*
* Status of exchange messages used for SCRAM authentication via the
* SASL protocol.
@@ -72,10 +88,10 @@ static bool calculate_client_proof(fe_scram_state *state,
/*
* Initialize SCRAM exchange status.
*/
-void *
-pg_fe_scram_init(PGconn *conn,
- const char *password,
- const char *sasl_mechanism)
+static void *
+scram_init(PGconn *conn,
+ const char *password,
+ const char *sasl_mechanism)
{
fe_scram_state *state;
char *prep_password;
@@ -128,8 +144,8 @@ pg_fe_scram_init(PGconn *conn,
* Note that the caller must also ensure that the exchange was actually
* successful.
*/
-bool
-pg_fe_scram_channel_bound(void *opaq)
+static bool
+scram_channel_bound(void *opaq)
{
fe_scram_state *state = (fe_scram_state *) opaq;
@@ -152,8 +168,8 @@ pg_fe_scram_channel_bound(void *opaq)
/*
* Free SCRAM exchange status
*/
-void
-pg_fe_scram_free(void *opaq)
+static void
+scram_free(void *opaq)
{
fe_scram_state *state = (fe_scram_state *) opaq;
@@ -188,10 +204,10 @@ pg_fe_scram_free(void *opaq)
/*
* Exchange a SCRAM message with backend.
*/
-void
-pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
- char **output, int *outputlen,
- bool *done, bool *success)
+static void
+scram_exchange(void *opaq, char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success)
{
fe_scram_state *state = (fe_scram_state *) opaq;
PGconn *conn = state->conn;
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index e8062647e6..d5cbac108e 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -482,7 +482,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
* channel_binding is not disabled.
*/
if (conn->channel_binding[0] != 'd') /* disable */
+ {
selected_mechanism = SCRAM_SHA_256_PLUS_NAME;
+ conn->sasl = &pg_scram_mech;
+ }
#else
/*
* The client does not support channel binding. If it is
@@ -516,7 +519,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
}
else if (strcmp(mechanism_buf.data, SCRAM_SHA_256_NAME) == 0 &&
!selected_mechanism)
+ {
selected_mechanism = SCRAM_SHA_256_NAME;
+ conn->sasl = &pg_scram_mech;
+ }
}
if (!selected_mechanism)
@@ -555,20 +561,22 @@ pg_SASL_init(PGconn *conn, int payloadlen)
goto error;
}
+ Assert(conn->sasl);
+
/*
* Initialize the SASL state information with all the information gathered
* during the initial exchange.
*
* Note: Only tls-unique is supported for the moment.
*/
- conn->sasl_state = pg_fe_scram_init(conn,
+ conn->sasl_state = conn->sasl->init(conn,
password,
selected_mechanism);
if (!conn->sasl_state)
goto oom_error;
/* Get the mechanism-specific Initial Client Response, if any */
- pg_fe_scram_exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state,
NULL, -1,
&initialresponse, &initialresponselen,
&done, &success);
@@ -649,7 +657,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
/* For safety and convenience, ensure the buffer is NULL-terminated. */
challenge[payloadlen] = '\0';
- pg_fe_scram_exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state,
challenge, payloadlen,
&output, &outputlen,
&done, &success);
@@ -830,7 +838,7 @@ check_expected_areq(AuthRequest areq, PGconn *conn)
case AUTH_REQ_SASL_FIN:
break;
case AUTH_REQ_OK:
- if (!pg_fe_scram_channel_bound(conn->sasl_state))
+ if (!conn->sasl || !conn->sasl->channel_bound(conn->sasl_state))
{
appendPQExpBufferStr(&conn->errorMessage,
libpq_gettext("channel binding required, but server authenticated client without channel binding\n"));
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 7877dcbd09..1e4fcbff62 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -22,15 +22,8 @@
extern int pg_fe_sendauth(AuthRequest areq, int payloadlen, PGconn *conn);
extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
-/* Prototypes for functions in fe-auth-scram.c */
-extern void *pg_fe_scram_init(PGconn *conn,
- const char *password,
- const char *sasl_mechanism);
-extern bool pg_fe_scram_channel_bound(void *opaq);
-extern void pg_fe_scram_free(void *opaq);
-extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
- char **output, int *outputlen,
- bool *done, bool *success);
+/* Mechanisms in fe-auth-scram.c */
+extern const pg_sasl_mech pg_scram_mech;
extern char *pg_fe_scram_build_secret(const char *password);
#endif /* FE_AUTH_H */
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 80703698b8..10d007582c 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -517,11 +517,7 @@ pqDropConnection(PGconn *conn, bool flushInput)
#endif
if (conn->sasl_state)
{
- /*
- * XXX: if support for more authentication mechanisms is added, this
- * needs to call the right 'free' function.
- */
- pg_fe_scram_free(conn->sasl_state);
+ conn->sasl->free(conn->sasl_state);
conn->sasl_state = NULL;
}
}
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index e81dc37906..25eaa231c5 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -339,6 +339,19 @@ typedef struct pg_conn_host
* found in password file. */
} pg_conn_host;
+typedef void *(*pg_sasl_init_func)(PGconn *, const char *, const char *);
+typedef void (*pg_sasl_exchange_func)(void *, char *, int, char **, int *, bool *, bool *);
+typedef bool (*pg_sasl_channel_bound_func)(void *);
+typedef void (*pg_sasl_free_func)(void *);
+
+typedef struct
+{
+ pg_sasl_init_func init;
+ pg_sasl_exchange_func exchange;
+ pg_sasl_channel_bound_func channel_bound;
+ pg_sasl_free_func free;
+} pg_sasl_mech;
+
/*
* PGconn stores all the state data associated with a single connection
* to a backend.
@@ -500,6 +513,7 @@ struct pg_conn
PGresult *next_result; /* next result (used in single-row mode) */
/* Assorted state for SASL, SSL, GSS, etc */
+ const pg_sasl_mech *sasl;
void *sasl_state;
/* SSL structures */
--
2.25.1
0002-src-common-remove-logging-from-jsonapi-for-shlib.patch
From 0541598e4f0bad1b9ff41a4640ec69491b393d54 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Mon, 3 May 2021 11:15:15 -0700
Subject: [PATCH 2/7] src/common: remove logging from jsonapi for shlib
The can't-happen code in jsonapi was pulling in logging code, which for
libpq is not included.
---
src/common/Makefile | 4 ++++
src/common/jsonapi.c | 11 ++++++++---
2 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/src/common/Makefile b/src/common/Makefile
index 38a8599337..6f1039bc78 100644
--- a/src/common/Makefile
+++ b/src/common/Makefile
@@ -28,6 +28,10 @@ subdir = src/common
top_builddir = ../..
include $(top_builddir)/src/Makefile.global
+# For use in shared libraries, jsonapi needs to not link in any logging
+# functions.
+override CFLAGS_SL += -DJSONAPI_NO_LOG
+
# don't include subdirectory-path-dependent -I and -L switches
STD_CPPFLAGS := $(filter-out -I$(top_srcdir)/src/include -I$(top_builddir)/src/include,$(CPPFLAGS))
STD_LDFLAGS := $(filter-out -L$(top_builddir)/src/common -L$(top_builddir)/src/port,$(LDFLAGS))
diff --git a/src/common/jsonapi.c b/src/common/jsonapi.c
index 1bf38d7b42..6b6001b118 100644
--- a/src/common/jsonapi.c
+++ b/src/common/jsonapi.c
@@ -27,11 +27,16 @@
#endif
#ifdef FRONTEND
-#define check_stack_depth()
-#define json_log_and_abort(...) \
+# define check_stack_depth()
+# ifdef JSONAPI_NO_LOG
+# define json_log_and_abort(...) \
+ do { fprintf(stderr, __VA_ARGS__); exit(1); } while(0)
+# else
+# define json_log_and_abort(...) \
do { pg_log_fatal(__VA_ARGS__); exit(1); } while(0)
+# endif
#else
-#define json_log_and_abort(...) elog(ERROR, __VA_ARGS__)
+# define json_log_and_abort(...) elog(ERROR, __VA_ARGS__)
#endif
/*
--
2.25.1
0003-common-jsonapi-always-palloc-the-error-strings.patch
From 5ad4b3c7835fe9e0f284702ec7b827c27770854e Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Mon, 3 May 2021 15:38:26 -0700
Subject: [PATCH 3/7] common/jsonapi: always palloc the error strings
...so that client code can pfree() to avoid memory leaks in long-running
operations.
---
src/common/jsonapi.c | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/src/common/jsonapi.c b/src/common/jsonapi.c
index 6b6001b118..f7304f584f 100644
--- a/src/common/jsonapi.c
+++ b/src/common/jsonapi.c
@@ -1089,7 +1089,7 @@ json_errdetail(JsonParseErrorType error, JsonLexContext *lex)
return psprintf(_("Expected JSON value, but found \"%s\"."),
extract_token(lex));
case JSON_EXPECTED_MORE:
- return _("The input string ended unexpectedly.");
+ return pstrdup(_("The input string ended unexpectedly."));
case JSON_EXPECTED_OBJECT_FIRST:
return psprintf(_("Expected string or \"}\", but found \"%s\"."),
extract_token(lex));
@@ -1103,16 +1103,16 @@ json_errdetail(JsonParseErrorType error, JsonLexContext *lex)
return psprintf(_("Token \"%s\" is invalid."),
extract_token(lex));
case JSON_UNICODE_CODE_POINT_ZERO:
- return _("\\u0000 cannot be converted to text.");
+ return pstrdup(_("\\u0000 cannot be converted to text."));
case JSON_UNICODE_ESCAPE_FORMAT:
- return _("\"\\u\" must be followed by four hexadecimal digits.");
+ return pstrdup(_("\"\\u\" must be followed by four hexadecimal digits."));
case JSON_UNICODE_HIGH_ESCAPE:
/* note: this case is only reachable in frontend not backend */
- return _("Unicode escape values cannot be used for code point values above 007F when the encoding is not UTF8.");
+ return pstrdup(_("Unicode escape values cannot be used for code point values above 007F when the encoding is not UTF8."));
case JSON_UNICODE_HIGH_SURROGATE:
- return _("Unicode high surrogate must not follow a high surrogate.");
+ return pstrdup(_("Unicode high surrogate must not follow a high surrogate."));
case JSON_UNICODE_LOW_SURROGATE:
- return _("Unicode low surrogate must follow a high surrogate.");
+ return pstrdup(_("Unicode low surrogate must follow a high surrogate."));
}
/*
--
2.25.1
Attachment: 0004-libpq-add-OAUTHBEARER-SASL-mechanism.patch (text/x-patch)
From e3d95709e147ae3670bd8acd0c265493a6116b9a Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 13 Apr 2021 10:27:27 -0700
Subject: [PATCH 4/7] libpq: add OAUTHBEARER SASL mechanism
DO NOT USE THIS PROOF OF CONCEPT IN PRODUCTION.
Implement OAUTHBEARER (RFC 7628) and OAuth 2.0 Device Authorization
Grants (RFC 8628) on the client side. When speaking to an OAuth-enabled
server, it looks a bit like this:
$ psql 'host=example.org oauth_client_id=f02c6361-0635-...'
Visit https://oauth.example.org/login and enter the code: FPQ2-M4BG
The OAuth issuer must support device authorization. No other OAuth flows
are currently implemented.
The client implementation requires libiddawc and its development
headers. Configure --with-oauth (and --with-includes/--with-libraries to
point at the iddawc installation, if it's in a custom location).
Several TODOs:
- don't retry forever if the server won't accept our token
- perform several sanity checks on the OAuth issuer's responses
- handle cases where the client has been set up with an issuer and
scope, but the Postgres server wants to use something different
- improve error debuggability during the OAuth handshake
- ...and more.
---
configure | 100 ++++
configure.ac | 19 +
src/Makefile.global.in | 1 +
src/include/common/oauth-common.h | 19 +
src/include/pg_config.h.in | 6 +
src/interfaces/libpq/Makefile | 7 +-
src/interfaces/libpq/fe-auth-oauth.c | 724 +++++++++++++++++++++++++++
src/interfaces/libpq/fe-auth-scram.c | 6 +-
src/interfaces/libpq/fe-auth.c | 42 +-
src/interfaces/libpq/fe-auth.h | 3 +
src/interfaces/libpq/fe-connect.c | 38 ++
src/interfaces/libpq/libpq-int.h | 10 +-
12 files changed, 956 insertions(+), 19 deletions(-)
create mode 100644 src/include/common/oauth-common.h
create mode 100644 src/interfaces/libpq/fe-auth-oauth.c
diff --git a/configure b/configure
index e9b98f442f..c3b7a89bf0 100755
--- a/configure
+++ b/configure
@@ -713,6 +713,7 @@ with_uuid
with_readline
with_systemd
with_selinux
+with_oauth
with_ldap
with_krb_srvnam
krb_srvtab
@@ -856,6 +857,7 @@ with_krb_srvnam
with_pam
with_bsd_auth
with_ldap
+with_oauth
with_bonjour
with_selinux
with_systemd
@@ -1562,6 +1564,7 @@ Optional Packages:
--with-pam build with PAM support
--with-bsd-auth build with BSD Authentication support
--with-ldap build with LDAP support
+ --with-oauth build with OAuth 2.0 support
--with-bonjour build with Bonjour support
--with-selinux build with SELinux support
--with-systemd build with systemd support
@@ -8046,6 +8049,42 @@ $as_echo "$with_ldap" >&6; }
+#
+# OAuth 2.0
+#
+{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether to build with OAuth support" >&5
+$as_echo_n "checking whether to build with OAuth support... " >&6; }
+
+
+
+# Check whether --with-oauth was given.
+if test "${with_oauth+set}" = set; then :
+ withval=$with_oauth;
+ case $withval in
+ yes)
+
+$as_echo "#define USE_OAUTH 1" >>confdefs.h
+
+ ;;
+ no)
+ :
+ ;;
+ *)
+ as_fn_error $? "no argument expected for --with-oauth option" "$LINENO" 5
+ ;;
+ esac
+
+else
+ with_oauth=no
+
+fi
+
+
+{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $with_oauth" >&5
+$as_echo "$with_oauth" >&6; }
+
+
+
#
# Bonjour
#
@@ -13048,6 +13087,56 @@ fi
+if test "$with_oauth" = yes ; then
+ { $as_echo "$as_me:${as_lineno-$LINENO}: checking for i_init_session in -liddawc" >&5
+$as_echo_n "checking for i_init_session in -liddawc... " >&6; }
+if ${ac_cv_lib_iddawc_i_init_session+:} false; then :
+ $as_echo_n "(cached) " >&6
+else
+ ac_check_lib_save_LIBS=$LIBS
+LIBS="-liddawc $LIBS"
+cat confdefs.h - <<_ACEOF >conftest.$ac_ext
+/* end confdefs.h. */
+
+/* Override any GCC internal prototype to avoid an error.
+ Use char because int might match the return type of a GCC
+ builtin and then its argument prototype would still apply. */
+#ifdef __cplusplus
+extern "C"
+#endif
+char i_init_session ();
+int
+main ()
+{
+return i_init_session ();
+ ;
+ return 0;
+}
+_ACEOF
+if ac_fn_c_try_link "$LINENO"; then :
+ ac_cv_lib_iddawc_i_init_session=yes
+else
+ ac_cv_lib_iddawc_i_init_session=no
+fi
+rm -f core conftest.err conftest.$ac_objext \
+ conftest$ac_exeext conftest.$ac_ext
+LIBS=$ac_check_lib_save_LIBS
+fi
+{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_iddawc_i_init_session" >&5
+$as_echo "$ac_cv_lib_iddawc_i_init_session" >&6; }
+if test "x$ac_cv_lib_iddawc_i_init_session" = xyes; then :
+ cat >>confdefs.h <<_ACEOF
+#define HAVE_LIBIDDAWC 1
+_ACEOF
+
+ LIBS="-liddawc $LIBS"
+
+else
+ as_fn_error $? "library 'iddawc' is required for OAuth support" "$LINENO" 5
+fi
+
+fi
+
# for contrib/sepgsql
if test "$with_selinux" = yes; then
{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for security_compute_create_name in -lselinux" >&5
@@ -13942,6 +14031,17 @@ fi
done
+fi
+
+if test "$with_oauth" != no; then
+ ac_fn_c_check_header_mongrel "$LINENO" "iddawc.h" "ac_cv_header_iddawc_h" "$ac_includes_default"
+if test "x$ac_cv_header_iddawc_h" = xyes; then :
+
+else
+ as_fn_error $? "header file <iddawc.h> is required for OAuth" "$LINENO" 5
+fi
+
+
fi
if test "$PORTNAME" = "win32" ; then
diff --git a/configure.ac b/configure.ac
index 3b42d8bdc9..f15f6f64d5 100644
--- a/configure.ac
+++ b/configure.ac
@@ -842,6 +842,17 @@ AC_MSG_RESULT([$with_ldap])
AC_SUBST(with_ldap)
+#
+# OAuth 2.0
+#
+AC_MSG_CHECKING([whether to build with OAuth support])
+PGAC_ARG_BOOL(with, oauth, no,
+ [build with OAuth 2.0 support],
+ [AC_DEFINE([USE_OAUTH], 1, [Define to 1 to build with OAuth 2.0 support. (--with-oauth)])])
+AC_MSG_RESULT([$with_oauth])
+AC_SUBST(with_oauth)
+
+
#
# Bonjour
#
@@ -1313,6 +1324,10 @@ fi
AC_SUBST(LDAP_LIBS_FE)
AC_SUBST(LDAP_LIBS_BE)
+if test "$with_oauth" = yes ; then
+ AC_CHECK_LIB(iddawc, i_init_session, [], [AC_MSG_ERROR([library 'iddawc' is required for OAuth support])])
+fi
+
# for contrib/sepgsql
if test "$with_selinux" = yes; then
AC_CHECK_LIB(selinux, security_compute_create_name, [],
@@ -1523,6 +1538,10 @@ elif test "$with_uuid" = ossp ; then
[AC_MSG_ERROR([header file <ossp/uuid.h> or <uuid.h> is required for OSSP UUID])])])
fi
+if test "$with_oauth" != no; then
+ AC_CHECK_HEADER(iddawc.h, [], [AC_MSG_ERROR([header file <iddawc.h> is required for OAuth])])
+fi
+
if test "$PORTNAME" = "win32" ; then
AC_CHECK_HEADERS(crtdefs.h)
fi
diff --git a/src/Makefile.global.in b/src/Makefile.global.in
index 8f05840821..3a61dd46d3 100644
--- a/src/Makefile.global.in
+++ b/src/Makefile.global.in
@@ -193,6 +193,7 @@ with_ldap = @with_ldap@
with_libxml = @with_libxml@
with_libxslt = @with_libxslt@
with_llvm = @with_llvm@
+with_oauth = @with_oauth@
with_system_tzdata = @with_system_tzdata@
with_uuid = @with_uuid@
with_zlib = @with_zlib@
diff --git a/src/include/common/oauth-common.h b/src/include/common/oauth-common.h
new file mode 100644
index 0000000000..3fa95ac7e8
--- /dev/null
+++ b/src/include/common/oauth-common.h
@@ -0,0 +1,19 @@
+/*-------------------------------------------------------------------------
+ *
+ * oauth-common.h
+ * Declarations for helper functions used for OAuth/OIDC authentication
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/common/oauth-common.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef OAUTH_COMMON_H
+#define OAUTH_COMMON_H
+
+/* Name of SASL mechanism per IANA */
+#define OAUTHBEARER_NAME "OAUTHBEARER"
+
+#endif /* OAUTH_COMMON_H */
diff --git a/src/include/pg_config.h.in b/src/include/pg_config.h.in
index 783b8fc1ba..db5bc56ac5 100644
--- a/src/include/pg_config.h.in
+++ b/src/include/pg_config.h.in
@@ -319,6 +319,9 @@
/* Define to 1 if you have the `crypto' library (-lcrypto). */
#undef HAVE_LIBCRYPTO
+/* Define to 1 if you have the `iddawc' library (-liddawc). */
+#undef HAVE_LIBIDDAWC
+
/* Define to 1 if you have the `ldap' library (-lldap). */
#undef HAVE_LIBLDAP
@@ -920,6 +923,9 @@
/* Define to select named POSIX semaphores. */
#undef USE_NAMED_POSIX_SEMAPHORES
+/* Define to 1 to build with OAuth 2.0 support. (--with-oauth) */
+#undef USE_OAUTH
+
/* Define to 1 to build with OpenSSL support. (--with-ssl=openssl) */
#undef USE_OPENSSL
diff --git a/src/interfaces/libpq/Makefile b/src/interfaces/libpq/Makefile
index 0c4e55b6ad..8e89d50900 100644
--- a/src/interfaces/libpq/Makefile
+++ b/src/interfaces/libpq/Makefile
@@ -62,6 +62,11 @@ OBJS += \
fe-secure-gssapi.o
endif
+ifeq ($(with_oauth),yes)
+OBJS += \
+ fe-auth-oauth.o
+endif
+
ifeq ($(PORTNAME), cygwin)
override shlib = cyg$(NAME)$(DLSUFFIX)
endif
@@ -83,7 +88,7 @@ endif
# that are built correctly for use in a shlib.
SHLIB_LINK_INTERNAL = -lpgcommon_shlib -lpgport_shlib
ifneq ($(PORTNAME), win32)
-SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi_krb5 -lgss -lgssapi -lssl -lsocket -lnsl -lresolv -lintl -lm, $(LIBS)) $(LDAP_LIBS_FE) $(PTHREAD_LIBS)
+SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi_krb5 -lgss -lgssapi -lssl -liddawc -lsocket -lnsl -lresolv -lintl -lm, $(LIBS)) $(LDAP_LIBS_FE) $(PTHREAD_LIBS)
else
SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi32 -lssl -lsocket -lnsl -lresolv -lintl -lm $(PTHREAD_LIBS), $(LIBS)) $(LDAP_LIBS_FE)
endif
diff --git a/src/interfaces/libpq/fe-auth-oauth.c b/src/interfaces/libpq/fe-auth-oauth.c
new file mode 100644
index 0000000000..a27f974369
--- /dev/null
+++ b/src/interfaces/libpq/fe-auth-oauth.c
@@ -0,0 +1,724 @@
+/*-------------------------------------------------------------------------
+ *
+ * fe-auth-oauth.c
+ * The front-end (client) implementation of OAuth/OIDC authentication.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * IDENTIFICATION
+ * src/interfaces/libpq/fe-auth-oauth.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include <iddawc.h>
+
+#include "postgres_fe.h"
+
+#include "common/base64.h"
+#include "common/hmac.h"
+#include "common/jsonapi.h"
+#include "common/oauth-common.h"
+#include "fe-auth.h"
+#include "mb/pg_wchar.h"
+
+/* The exported OAuth callback mechanism. */
+static void *oauth_init(PGconn *conn, const char *password,
+ const char *sasl_mechanism);
+static void oauth_exchange(void *opaq, bool final,
+ char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success);
+static bool oauth_channel_bound(void *opaq);
+static void oauth_free(void *opaq);
+
+const pg_sasl_mech pg_oauth_mech = {
+ oauth_init,
+ oauth_exchange,
+ oauth_channel_bound,
+ oauth_free,
+};
+
+typedef enum
+{
+ FE_OAUTH_INIT,
+ FE_OAUTH_BEARER_SENT,
+ FE_OAUTH_SERVER_ERROR,
+} fe_oauth_state_enum;
+
+typedef struct
+{
+ fe_oauth_state_enum state;
+
+ PGconn *conn;
+} fe_oauth_state;
+
+static void *
+oauth_init(PGconn *conn, const char *password,
+ const char *sasl_mechanism)
+{
+ fe_oauth_state *state;
+
+ /*
+ * We only support one SASL mechanism here; anything else is programmer
+ * error.
+ */
+ Assert(sasl_mechanism != NULL);
+ Assert(!strcmp(sasl_mechanism, OAUTHBEARER_NAME));
+
+ state = malloc(sizeof(*state));
+ if (!state)
+ return NULL;
+
+ state->state = FE_OAUTH_INIT;
+ state->conn = conn;
+
+ return state;
+}
+
+static const char *
+iddawc_error_string(int errcode)
+{
+ switch (errcode)
+ {
+ case I_OK:
+ return "I_OK";
+
+ case I_ERROR:
+ return "I_ERROR";
+
+ case I_ERROR_PARAM:
+ return "I_ERROR_PARAM";
+
+ case I_ERROR_MEMORY:
+ return "I_ERROR_MEMORY";
+
+ case I_ERROR_UNAUTHORIZED:
+ return "I_ERROR_UNAUTHORIZED";
+
+ case I_ERROR_SERVER:
+ return "I_ERROR_SERVER";
+ }
+
+ return "<unknown>";
+}
+
+static void
+iddawc_error(PGconn *conn, int errcode, const char *msg)
+{
+ appendPQExpBufferStr(&conn->errorMessage, libpq_gettext(msg));
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext(" (iddawc error %s)\n"),
+ iddawc_error_string(errcode));
+}
+
+static void
+iddawc_request_error(PGconn *conn, struct _i_session *i, int err, const char *msg)
+{
+ const char *error_code;
+ const char *desc;
+
+ appendPQExpBuffer(&conn->errorMessage, "%s: ", libpq_gettext(msg));
+
+ error_code = i_get_str_parameter(i, I_OPT_ERROR);
+ if (!error_code)
+ {
+ /*
+ * The server didn't give us any useful information, so just print the
+ * error code.
+ */
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("(iddawc error %s)\n"),
+ iddawc_error_string(err));
+ return;
+ }
+
+ /* If the server gave a string description, print that too. */
+ desc = i_get_str_parameter(i, I_OPT_ERROR_DESCRIPTION);
+ if (desc)
+ appendPQExpBuffer(&conn->errorMessage, "%s ", desc);
+
+ appendPQExpBuffer(&conn->errorMessage, "(%s)\n", error_code);
+}
+
+static char *
+get_auth_token(PGconn *conn)
+{
+ PQExpBuffer token_buf = NULL;
+ struct _i_session session;
+ int err;
+ int auth_method;
+ bool user_prompted = false;
+ const char *verification_uri;
+ const char *user_code;
+ const char *access_token;
+ const char *token_type;
+ char *token = NULL;
+
+ if (!conn->oauth_discovery_uri)
+ return strdup(""); /* ask the server for one */
+
+ i_init_session(&session);
+
+ if (!conn->oauth_client_id)
+ {
+ /* We can't talk to a server without a client identifier. */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("no oauth_client_id is set for the connection"));
+ goto cleanup;
+ }
+
+ token_buf = createPQExpBuffer();
+
+ if (!token_buf)
+ goto cleanup;
+
+ err = i_set_str_parameter(&session, I_OPT_OPENID_CONFIG_ENDPOINT, conn->oauth_discovery_uri);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set OpenID config endpoint");
+ goto cleanup;
+ }
+
+ err = i_get_openid_config(&session);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to fetch OpenID discovery document");
+ goto cleanup;
+ }
+
+ if (!i_get_str_parameter(&session, I_OPT_TOKEN_ENDPOINT))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer has no token endpoint"));
+ goto cleanup;
+ }
+
+ if (!i_get_str_parameter(&session, I_OPT_DEVICE_AUTHORIZATION_ENDPOINT))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer does not support device authorization"));
+ goto cleanup;
+ }
+
+ err = i_set_response_type(&session, I_RESPONSE_TYPE_DEVICE_CODE);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set device code response type");
+ goto cleanup;
+ }
+
+ auth_method = I_TOKEN_AUTH_METHOD_NONE;
+ if (conn->oauth_client_secret && *conn->oauth_client_secret)
+ auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
+
+ err = i_set_parameter_list(&session,
+ I_OPT_CLIENT_ID, conn->oauth_client_id,
+ I_OPT_CLIENT_SECRET, conn->oauth_client_secret,
+ I_OPT_TOKEN_METHOD, auth_method,
+ I_OPT_SCOPE, conn->oauth_scope,
+ I_OPT_NONE
+ );
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set client identifier");
+ goto cleanup;
+ }
+
+ err = i_run_device_auth_request(&session);
+ if (err)
+ {
+ iddawc_request_error(conn, &session, err,
+ "failed to obtain device authorization");
+ goto cleanup;
+ }
+
+ verification_uri = i_get_str_parameter(&session, I_OPT_DEVICE_AUTH_VERIFICATION_URI);
+ if (!verification_uri)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a verification URI"));
+ goto cleanup;
+ }
+
+ user_code = i_get_str_parameter(&session, I_OPT_DEVICE_AUTH_USER_CODE);
+ if (!user_code)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a user code"));
+ goto cleanup;
+ }
+
+ /*
+ * Poll the token endpoint until either the user logs in and authorizes the
+ * use of a token, or a hard failure occurs. We perform one ping _before_
+ * prompting the user, so that we don't make them do the work of logging in
+ * only to find that the token endpoint is completely unreachable.
+ */
+ err = i_run_token_request(&session);
+ while (err)
+ {
+ const char *error_code;
+ uint interval;
+
+ error_code = i_get_str_parameter(&session, I_OPT_ERROR);
+
+ /*
+ * authorization_pending and slow_down are the only acceptable errors;
+ * anything else and we bail.
+ */
+ if (!error_code || (strcmp(error_code, "authorization_pending")
+ && strcmp(error_code, "slow_down")))
+ {
+ iddawc_request_error(conn, &session, err,
+ "OAuth token retrieval failed");
+ goto cleanup;
+ }
+
+ if (!user_prompted)
+ {
+ /*
+ * Now that we know the token endpoint isn't broken, give the user
+ * the login instructions.
+ */
+ pqInternalNotice(&conn->noticeHooks,
+ "Visit %s and enter the code: %s",
+ verification_uri, user_code);
+
+ user_prompted = true;
+ }
+
+ /*
+ * We are required to wait between polls; the server tells us how long.
+ * TODO: if interval's not set, we need to default to five seconds
+ * TODO: sanity check the interval
+ */
+ interval = i_get_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL);
+
+ /*
+ * A slow_down error requires us to permanently increase our retry
+ * interval by five seconds. RFC 8628, Sec. 3.5.
+ */
+ if (!strcmp(error_code, "slow_down"))
+ {
+ interval += 5;
+ i_set_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL, interval);
+ }
+
+ sleep(interval);
+
+ /*
+ * XXX Reset the error code before every call, because iddawc won't do
+ * that for us. This matters if the server first sends a "pending" error
+ * code, then later hard-fails without sending an error code to
+ * overwrite the first one.
+ *
+ * That we have to do this at all seems like a bug in iddawc.
+ */
+ i_set_str_parameter(&session, I_OPT_ERROR, NULL);
+
+ err = i_run_token_request(&session);
+ }
+
+ access_token = i_get_str_parameter(&session, I_OPT_ACCESS_TOKEN);
+ token_type = i_get_str_parameter(&session, I_OPT_TOKEN_TYPE);
+
+ if (!access_token || !token_type || strcasecmp(token_type, "Bearer"))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a bearer token"));
+ goto cleanup;
+ }
+
+ appendPQExpBufferStr(token_buf, "Bearer ");
+ appendPQExpBufferStr(token_buf, access_token);
+
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+
+ token = strdup(token_buf->data);
+
+cleanup:
+ if (token_buf)
+ destroyPQExpBuffer(token_buf);
+ i_clean_session(&session);
+
+ return token;
+}
+
+#define kvsep "\x01"
+
+static char *
+client_initial_response(PGconn *conn)
+{
+ static const char * const resp_format = "n,," kvsep "auth=%s" kvsep kvsep;
+
+ PQExpBuffer token_buf;
+ PQExpBuffer discovery_buf = NULL;
+ char *token = NULL;
+ char *response = NULL;
+
+ token_buf = createPQExpBuffer();
+ if (!token_buf)
+ goto cleanup;
+
+ /*
+ * If we don't yet have a discovery URI, but the user gave us an explicit
+ * issuer, use the .well-known discovery URI for that issuer.
+ */
+ if (!conn->oauth_discovery_uri && conn->oauth_issuer)
+ {
+ discovery_buf = createPQExpBuffer();
+ if (!discovery_buf)
+ goto cleanup;
+
+ appendPQExpBufferStr(discovery_buf, conn->oauth_issuer);
+ appendPQExpBufferStr(discovery_buf, "/.well-known/openid-configuration");
+
+ if (PQExpBufferBroken(discovery_buf))
+ goto cleanup;
+
+ conn->oauth_discovery_uri = strdup(discovery_buf->data);
+ }
+
+ token = get_auth_token(conn);
+ if (!token)
+ goto cleanup;
+
+ appendPQExpBuffer(token_buf, resp_format, token);
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+
+ response = strdup(token_buf->data);
+
+cleanup:
+ if (token)
+ free(token);
+ if (discovery_buf)
+ destroyPQExpBuffer(discovery_buf);
+ if (token_buf)
+ destroyPQExpBuffer(token_buf);
+
+ return response;
+}
+
+#define ERROR_STATUS_FIELD "status"
+#define ERROR_SCOPE_FIELD "scope"
+#define ERROR_OPENID_CONFIGURATION_FIELD "openid-configuration"
+
+struct json_ctx
+{
+ char *errmsg; /* any non-NULL value stops all processing */
+ int nested; /* nesting level (zero is the top) */
+
+ const char *target_field_name; /* points to a static allocation */
+ char **target_field; /* see below */
+
+ /* target_field, if set, points to one of the following: */
+ char *status;
+ char *scope;
+ char *discovery_uri;
+};
+
+static void
+oauth_json_object_start(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (ctx->errmsg)
+ return; /* short-circuit */
+
+ if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ ctx->errmsg = psprintf(libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+
+ ++ctx->nested;
+}
+
+static void
+oauth_json_object_end(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (ctx->errmsg)
+ return; /* short-circuit */
+
+ --ctx->nested;
+}
+
+static void
+oauth_json_object_field_start(void *state, char *name, bool isnull)
+{
+ struct json_ctx *ctx = state;
+
+ if (ctx->errmsg)
+ {
+ /* short-circuit */
+ pfree(name);
+ return;
+ }
+
+ if (ctx->nested == 1)
+ {
+ if (!strcmp(name, ERROR_STATUS_FIELD))
+ {
+ ctx->target_field_name = ERROR_STATUS_FIELD;
+ ctx->target_field = &ctx->status;
+ }
+ else if (!strcmp(name, ERROR_SCOPE_FIELD))
+ {
+ ctx->target_field_name = ERROR_SCOPE_FIELD;
+ ctx->target_field = &ctx->scope;
+ }
+ else if (!strcmp(name, ERROR_OPENID_CONFIGURATION_FIELD))
+ {
+ ctx->target_field_name = ERROR_OPENID_CONFIGURATION_FIELD;
+ ctx->target_field = &ctx->discovery_uri;
+ }
+ }
+
+ pfree(name);
+}
+
+static void
+oauth_json_array_start(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (ctx->errmsg)
+ return; /* short-circuit */
+
+ if (!ctx->nested)
+ {
+ ctx->errmsg = pstrdup(libpq_gettext("top-level element must be an object"));
+ }
+ else if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ ctx->errmsg = psprintf(libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+}
+
+static void
+oauth_json_scalar(void *state, char *token, JsonTokenType type)
+{
+ struct json_ctx *ctx = state;
+
+ if (ctx->errmsg)
+ {
+ /* short-circuit */
+ pfree(token);
+ return;
+ }
+
+ if (!ctx->nested)
+ {
+ ctx->errmsg = pstrdup(libpq_gettext("top-level element must be an object"));
+ }
+ else if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ if (type == JSON_TOKEN_STRING)
+ {
+ *ctx->target_field = token;
+
+ ctx->target_field = NULL;
+ ctx->target_field_name = NULL;
+
+ return; /* don't pfree the token we're using */
+ }
+
+ ctx->errmsg = psprintf(libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+
+ pfree(token);
+}
+
+static bool
+handle_oauth_sasl_error(PGconn *conn, char *msg, int msglen)
+{
+ JsonLexContext *lex;
+ JsonSemAction sem = {0};
+ JsonParseErrorType err;
+ struct json_ctx ctx = {0};
+ char *errmsg = NULL;
+
+ /* Sanity check. */
+ if (strlen(msg) != msglen)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server's error message contained an embedded NULL"));
+ return false;
+ }
+
+ lex = makeJsonLexContextCstringLen(msg, msglen, PG_UTF8, true);
+
+ sem.semstate = &ctx;
+
+ sem.object_start = oauth_json_object_start;
+ sem.object_end = oauth_json_object_end;
+ sem.object_field_start = oauth_json_object_field_start;
+ sem.array_start = oauth_json_array_start;
+ sem.scalar = oauth_json_scalar;
+
+ err = pg_parse_json(lex, &sem);
+
+ if (err != JSON_SUCCESS)
+ {
+ errmsg = json_errdetail(err, lex);
+ }
+ else if (ctx.errmsg)
+ {
+ errmsg = ctx.errmsg;
+ }
+
+ if (errmsg)
+ {
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("failed to parse server's error response: %s"),
+ errmsg);
+ pfree(errmsg);
+ return false;
+ }
+
+ /* TODO: what if these override what the user already specified? */
+ if (ctx.discovery_uri)
+ {
+ if (conn->oauth_discovery_uri)
+ free(conn->oauth_discovery_uri);
+
+ conn->oauth_discovery_uri = ctx.discovery_uri;
+ }
+
+ if (ctx.scope)
+ {
+ if (conn->oauth_scope)
+ free(conn->oauth_scope);
+
+ conn->oauth_scope = ctx.scope;
+ }
+ /* TODO: missing error scope should clear any existing connection scope */
+
+ if (!ctx.status)
+ {
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("server sent error response without a status"));
+ return false;
+ }
+
+ if (!strcmp(ctx.status, "invalid_token"))
+ {
+ /*
+ * invalid_token is the only error code we'll automatically retry for,
+ * but only if we have enough information to do so.
+ */
+ if (conn->oauth_discovery_uri)
+ conn->oauth_want_retry = true;
+ }
+ /* TODO: include status in hard failure message */
+
+ return true;
+}
+
+static void
+oauth_exchange(void *opaq, bool final,
+ char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success)
+{
+ fe_oauth_state *state = opaq;
+ PGconn *conn = state->conn;
+
+ *done = false;
+ *success = false;
+ *output = NULL;
+ *outputlen = 0;
+
+ switch (state->state)
+ {
+ case FE_OAUTH_INIT:
+ Assert(inputlen == -1);
+
+ *output = client_initial_response(conn);
+ if (!*output)
+ goto error;
+
+ *outputlen = strlen(*output);
+ state->state = FE_OAUTH_BEARER_SENT;
+
+ break;
+
+ case FE_OAUTH_BEARER_SENT:
+ if (final)
+ {
+ /* TODO: ensure there is no message content here. */
+ *done = true;
+ *success = true;
+
+ break;
+ }
+
+ /*
+ * Error message sent by the server.
+ */
+ if (!handle_oauth_sasl_error(conn, input, inputlen))
+ goto error;
+
+ /*
+ * Respond with the required dummy message (RFC 7628, sec. 3.2.3).
+ */
+ *output = strdup(kvsep);
+ *outputlen = strlen(*output); /* == 1 */
+
+ state->state = FE_OAUTH_SERVER_ERROR;
+ break;
+
+ case FE_OAUTH_SERVER_ERROR:
+ /*
+ * After an error, the server should send an error response to fail
+ * the SASL handshake, which is handled in higher layers.
+ *
+ * If we get here, the server either sent *another* challenge which
+ * isn't defined in the RFC, or completed the handshake successfully
+ * after telling us it was going to fail. Neither is acceptable.
+ */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server sent additional OAuth data after error\n"));
+ goto error;
+
+ default:
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("invalid OAuth exchange state\n"));
+ goto error;
+ }
+
+ return;
+
+error:
+ *done = true;
+ *success = false;
+}
+
+static bool
+oauth_channel_bound(void *opaq)
+{
+ /* This mechanism does not support channel binding. */
+ return false;
+}
+
+static void
+oauth_free(void *opaq)
+{
+ fe_oauth_state *state = opaq;
+
+ free(state);
+}
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index 04d5703d89..f2ba3bca37 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -24,7 +24,8 @@
/* The exported SCRAM callback mechanism. */
static void *scram_init(PGconn *conn, const char *password,
const char *sasl_mechanism);
-static void scram_exchange(void *opaq, char *input, int inputlen,
+static void scram_exchange(void *opaq, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success);
static bool scram_channel_bound(void *opaq);
@@ -205,7 +206,8 @@ scram_free(void *opaq)
* Exchange a SCRAM message with backend.
*/
static void
-scram_exchange(void *opaq, char *input, int inputlen,
+scram_exchange(void *opaq, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success)
{
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index d5cbac108e..690b23b9d9 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -39,6 +39,7 @@
#endif
#include "common/md5.h"
+#include "common/oauth-common.h"
#include "common/scram-common.h"
#include "fe-auth.h"
#include "libpq-fe.h"
@@ -422,7 +423,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
bool success;
const char *selected_mechanism;
PQExpBufferData mechanism_buf;
- char *password;
+ char *password = NULL;
initPQExpBuffer(&mechanism_buf);
@@ -444,8 +445,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
/*
* Parse the list of SASL authentication mechanisms in the
* AuthenticationSASL message, and select the best mechanism that we
- * support. SCRAM-SHA-256-PLUS and SCRAM-SHA-256 are the only ones
- * supported at the moment, listed by order of decreasing importance.
+ * support. Mechanisms are listed by order of decreasing importance.
*/
selected_mechanism = NULL;
for (;;)
@@ -485,6 +485,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
{
selected_mechanism = SCRAM_SHA_256_PLUS_NAME;
conn->sasl = &pg_scram_mech;
+ conn->password_needed = true;
}
#else
/*
@@ -522,7 +523,17 @@ pg_SASL_init(PGconn *conn, int payloadlen)
{
selected_mechanism = SCRAM_SHA_256_NAME;
conn->sasl = &pg_scram_mech;
+ conn->password_needed = true;
}
+#ifdef USE_OAUTH
+ else if (strcmp(mechanism_buf.data, OAUTHBEARER_NAME) == 0 &&
+ !selected_mechanism)
+ {
+ selected_mechanism = OAUTHBEARER_NAME;
+ conn->sasl = &pg_oauth_mech;
+ conn->password_needed = false;
+ }
+#endif
}
if (!selected_mechanism)
@@ -547,18 +558,19 @@ pg_SASL_init(PGconn *conn, int payloadlen)
/*
* First, select the password to use for the exchange, complaining if
- * there isn't one. Currently, all supported SASL mechanisms require a
- * password, so we can just go ahead here without further distinction.
+ * there isn't one and the SASL mechanism needs it.
*/
- conn->password_needed = true;
- password = conn->connhost[conn->whichhost].password;
- if (password == NULL)
- password = conn->pgpass;
- if (password == NULL || password[0] == '\0')
+ if (conn->password_needed)
{
- appendPQExpBufferStr(&conn->errorMessage,
- PQnoPasswordSupplied);
- goto error;
+ password = conn->connhost[conn->whichhost].password;
+ if (password == NULL)
+ password = conn->pgpass;
+ if (password == NULL || password[0] == '\0')
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ PQnoPasswordSupplied);
+ goto error;
+ }
}
Assert(conn->sasl);
@@ -576,7 +588,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
goto oom_error;
/* Get the mechanism-specific Initial Client Response, if any */
- conn->sasl->exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state, false,
NULL, -1,
&initialresponse, &initialresponselen,
&done, &success);
@@ -657,7 +669,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
/* For safety and convenience, ensure the buffer is NULL-terminated. */
challenge[payloadlen] = '\0';
- conn->sasl->exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state, final,
challenge, payloadlen,
&output, &outputlen,
&done, &success);
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 1e4fcbff62..edc748fd3a 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -26,4 +26,7 @@ extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
extern const pg_sasl_mech pg_scram_mech;
extern char *pg_fe_scram_build_secret(const char *password);
+/* Mechanisms in fe-auth-oauth.c */
+extern const pg_sasl_mech pg_oauth_mech;
+
#endif /* FE_AUTH_H */
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 10d007582c..1d4bca9194 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -345,6 +345,23 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
"Target-Session-Attrs", "", 15, /* sizeof("prefer-standby") = 15 */
offsetof(struct pg_conn, target_session_attrs)},
+ /* OAuth v2 */
+ {"oauth_issuer", NULL, NULL, NULL,
+ "OAuth-Issuer", "", 40,
+ offsetof(struct pg_conn, oauth_issuer)},
+
+ {"oauth_client_id", NULL, NULL, NULL,
+ "OAuth-Client-ID", "", 40,
+ offsetof(struct pg_conn, oauth_client_id)},
+
+ {"oauth_client_secret", NULL, NULL, NULL,
+ "OAuth-Client-Secret", "", 40,
+ offsetof(struct pg_conn, oauth_client_secret)},
+
+ {"oauth_scope", NULL, NULL, NULL,
+ "OAuth-Scope", "", 15,
+ offsetof(struct pg_conn, oauth_scope)},
+
/* Terminating entry --- MUST BE LAST */
{NULL, NULL, NULL, NULL,
NULL, NULL, 0}
@@ -607,6 +624,7 @@ pqDropServerData(PGconn *conn)
conn->write_err_msg = NULL;
conn->be_pid = 0;
conn->be_key = 0;
+ /* conn->oauth_want_retry = false; TODO */
}
@@ -3356,6 +3374,16 @@ keep_going: /* We will come back to here until there is
/* Check to see if we should mention pgpassfile */
pgpassfileWarning(conn);
+#ifdef USE_OAUTH
+ if (conn->sasl == &pg_oauth_mech
+ && conn->oauth_want_retry)
+ {
+ /* TODO: only allow retry once */
+ need_new_connection = true;
+ goto keep_going;
+ }
+#endif
+
#ifdef ENABLE_GSS
/*
@@ -4130,6 +4158,16 @@ freePGconn(PGconn *conn)
free(conn->rowBuf);
if (conn->target_session_attrs)
free(conn->target_session_attrs);
+ if (conn->oauth_issuer)
+ free(conn->oauth_issuer);
+ if (conn->oauth_discovery_uri)
+ free(conn->oauth_discovery_uri);
+ if (conn->oauth_client_id)
+ free(conn->oauth_client_id);
+ if (conn->oauth_client_secret)
+ free(conn->oauth_client_secret);
+ if (conn->oauth_scope)
+ free(conn->oauth_scope);
termPQExpBuffer(&conn->errorMessage);
termPQExpBuffer(&conn->workBuffer);
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 25eaa231c5..b749c6c05d 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -340,7 +340,7 @@ typedef struct pg_conn_host
} pg_conn_host;
typedef void *(*pg_sasl_init_func)(PGconn *, const char *, const char *);
-typedef void (*pg_sasl_exchange_func)(void *, char *, int, char **, int *, bool *, bool *);
+typedef void (*pg_sasl_exchange_func)(void *, bool, char *, int, char **, int *, bool *, bool *);
typedef bool (*pg_sasl_channel_bound_func)(void *);
typedef void (*pg_sasl_free_func)(void *);
@@ -406,6 +406,14 @@ struct pg_conn
char *ssl_max_protocol_version; /* maximum TLS protocol version */
char *target_session_attrs; /* desired session properties */
+ /* OAuth v2 */
+ char *oauth_issuer; /* token issuer URL */
+ char *oauth_discovery_uri; /* URI of the issuer's discovery document */
+ char *oauth_client_id; /* client identifier */
+ char *oauth_client_secret; /* client secret */
+ char *oauth_scope; /* access token scope */
+ bool oauth_want_retry; /* should we retry on failure? */
+
/* Optional file to write trace info to */
FILE *Pfdebug;
int traceFlags;
--
2.25.1
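For context, the libpq options added in the patch above surface directly in the connection string. A hypothetical example (the issuer, client ID, and scope here are placeholders, not values from the patch):

```
psql 'host=example.org dbname=postgres \
      oauth_issuer=https://issuer.example.org \
      oauth_client_id=your-client-id \
      oauth_scope="openid email"'
```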
Attachment: 0005-backend-add-OAUTHBEARER-SASL-mechanism.patch (text/x-patch)
From ee8e85d3416f381ba9d44f8d4a681e5006bd5b82 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 4 May 2021 16:21:11 -0700
Subject: [PATCH 5/7] backend: add OAUTHBEARER SASL mechanism
DO NOT USE THIS PROOF OF CONCEPT IN PRODUCTION.
Implement OAUTHBEARER (RFC 7628) on the server side. This adds a new
auth method, oauth, to pg_hba.
Because OAuth implementations vary so wildly, and bearer token
validation is heavily dependent on the issuing party, authn/z is done by
communicating with an external program: the oauth_validator_command.
This command must do the following:
1. Receive the bearer token by reading its contents from a file
descriptor passed from the server. (The numeric value of this
descriptor may be inserted into the oauth_validator_command using the
%f specifier.)
This MUST be the first action the command performs. The server will
not begin reading stdout from the command until the token has been
read in full, so if the command tries to print anything and hits a
buffer limit, the backend will deadlock and time out.
2. Validate the bearer token. The correct way to do this depends on the
issuer, but it generally involves either cryptographic operations to
prove that the token was issued by a trusted party, or the
presentation of the bearer token to some other party so that _it_ can
perform validation.
The command MUST maintain confidentiality of the bearer token, since
in most cases it can be used just like a password. (There are ways to
cryptographically bind tokens to client certificates, but they are
way beyond the scope of this commit message.)
If the token cannot be validated, the command must exit with a
non-zero status. Further authentication/authorization is pointless if
the bearer token wasn't issued by someone you trust.
3. Authenticate the user, authorize the user, or both:
a. To authenticate the user, use the bearer token to retrieve some
trusted identifier string for the end user. The exact process for
this is, again, issuer-dependent. The command should print the
authenticated identity string to stdout, followed by a newline.
If the user cannot be authenticated, the validator should not
print anything to stdout. It should also exit with a non-zero
status, unless the token may be used to authorize the connection
through some other means (see below).
On a success, the command may then exit with a zero success code.
By default, the server will then check to make sure the identity
string matches the role that is being used (or matches a usermap
entry, if one is in use).
b. To optionally authorize the user, in combination with the HBA
option trust_validator_authz=1 (see below), the validator simply
returns a zero exit code if the client should be allowed to
connect with its presented role (which can be passed to the
command using the %r specifier), or a non-zero code otherwise.
The hard part is in determining whether the given token truly
authorizes the client to use the given role, which must
unfortunately be left as an exercise for the reader.
This obviously requires some care, as a poorly implemented token
validator may silently open the entire database to anyone with a
bearer token. But it may be a more portable approach, since OAuth
is designed as an authorization framework, not an authentication
framework. For example, the user's bearer token could carry an
"allow_superuser_access" claim, which would authorize pseudonymous
database access as any role. It's then up to the OAuth system
administrators to ensure that allow_superuser_access is doled out
only to the proper users.
c. It's possible that the user can be successfully authenticated but
isn't authorized to connect. In this case, the command may print
the authenticated ID and then fail with a non-zero exit code.
(This makes it easier to see what's going on in the Postgres
logs.)
4. Token validators may optionally log to stderr. This will be printed
verbatim into the Postgres server logs.
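The four-step contract above can be sketched as a toy validator. This is an assumption-laden illustration, not part of the patch: the token check is a stubbed equality test and the identity lookup is hard-coded, where a real validator would verify the token with its issuer. An in-process pipe stands in for the descriptor passed via %f.

```python
import os
import sys

def read_token(fd):
    # Step 1: drain the token pipe completely before writing anything,
    # to avoid deadlocking against the server.
    chunks = []
    while True:
        data = os.read(fd, 4096)
        if not data:
            break
        chunks.append(data)
    return b"".join(chunks).decode("ascii")

def validate(token):
    # Step 2 (stub): a real validator would check a signature or call
    # the issuer's introspection endpoint here.
    return token == "secret-token"

def identity_for(token):
    # Step 3a (stub): map the token to a trusted identifier string.
    return "alice@example.org"

# Demonstration: an in-process pipe plays the role of the %f descriptor.
rfd, wfd = os.pipe()
os.write(wfd, b"secret-token")
os.close(wfd)

token = read_token(rfd)
if not validate(token):
    sys.exit(1)             # step 2: untrusted token, non-zero exit
print(identity_for(token))  # step 3a: authn_id on stdout, newline included
os.close(rfd)
```

A real deployment would invoke something like this via oauth_validator_command, with the fd number substituted for %f on the command line.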
The oauth method supports the following HBA options (but note that two
of them are not optional, since we have no way of choosing sensible
defaults):
issuer: Required. The URL of the OAuth issuing party, which the client
must contact to receive a bearer token.
Some real-world examples as of the time of writing:
- https://accounts.google.com
- https://login.microsoftonline.com/[tenant-id]/v2.0
scope: Required. The OAuth scope(s) required for the server to
authenticate and/or authorize the user. This is heavily
deployment-specific, but a simple example is "openid email".
map: Optional. Specify a standard PostgreSQL user map; this works
the same as with other auth methods such as peer. If a map is
not specified, the user ID returned by the token validator
must exactly match the role that's being requested (but see
trust_validator_authz, below).
trust_validator_authz:
Optional. When set to 1, this allows the token validator to
take full control of the authorization process. Standard user
mapping is skipped: if the validator command succeeds, the
client is allowed to connect under its desired role and no
further checks are done.
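Concretely, the options above might be combined as follows. This is a hypothetical configuration sketch (paths, issuer, and addresses are placeholders):

```
# pg_hba.conf: authentication via the standard user map machinery
host  all  all  samenet  oauth  issuer="https://accounts.google.com" scope="openid email"

# pg_hba.conf: authorization delegated entirely to the validator
host  all  all  samenet  oauth  issuer="https://accounts.google.com" scope="openid email" trust_validator_authz=1

# postgresql.conf: %f is replaced with the token pipe fd, %r with the role
oauth_validator_command = '/usr/local/bin/validate-token %f %r'
```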
Unlike the client, servers support OAuth without needing to be built
against libiddawc (since the responsibility for "speaking" OAuth/OIDC
correctly is delegated entirely to the oauth_validator_command).
Several TODOs:
- port to platforms other than "modern Linux"
- overhaul the communication with oauth_validator_command, which is
currently a bad hack on OpenPipeStream()
- implement more sanity checks on the OAUTHBEARER message format and
tokens sent by the client
- implement more helpful handling of HBA misconfigurations
- properly interpolate JSON when generating error responses
- use logdetail during auth failures
- deal with role names that can't be safely passed to system() without
shell-escaping
- allow passing the configured issuer to the oauth_validator_command, to
deal with multi-issuer setups
- ...and more.
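For reference, the client message that the server-side parser below consumes follows the grammar in RFC 7628 Sec. 3.1: a gs2-header, then kvpairs separated by \x01, with an empty kvpair terminating the list. A sketch of building and parsing it (illustration only, not patch code):

```python
KVSEP = "\x01"

def build_client_resp(token):
    # client-resp = gs2-header kvsep *kvpair kvsep   (RFC 7628 Sec. 3.1)
    return "n,," + KVSEP + "auth=Bearer " + token + KVSEP + KVSEP

def parse_auth(resp):
    # Split off the gs2-header, then walk kvpairs until the empty
    # terminator, returning the value of the "auth" key if present.
    header, sep, rest = resp.partition(KVSEP)
    assert header in ("n,,", "y,,") and sep
    auth = None
    for pair in rest.split(KVSEP):
        if not pair:                  # empty kvpair ends the list
            break
        key, _, value = pair.partition("=")
        if key == "auth":
            auth = value
    return auth

msg = build_client_resp("abc123")
```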
---
src/backend/libpq/Makefile | 1 +
src/backend/libpq/auth-oauth.c | 797 +++++++++++++++++++++++++++++++++
src/backend/libpq/auth-scram.c | 2 +
src/backend/libpq/auth.c | 43 +-
src/backend/libpq/hba.c | 29 +-
src/backend/utils/misc/guc.c | 12 +
src/include/libpq/auth.h | 1 +
src/include/libpq/hba.h | 8 +-
src/include/libpq/oauth.h | 24 +
src/include/libpq/sasl.h | 26 ++
10 files changed, 915 insertions(+), 28 deletions(-)
create mode 100644 src/backend/libpq/auth-oauth.c
create mode 100644 src/include/libpq/oauth.h
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 8d1d16b0fc..40f2c50c3c 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,6 +15,7 @@ include $(top_builddir)/src/Makefile.global
# be-fsstubs is here for historical reasons, probably belongs elsewhere
OBJS = \
+ auth-oauth.o \
auth-scram.o \
auth.o \
be-fsstubs.o \
diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c
new file mode 100644
index 0000000000..b2b9d56e7c
--- /dev/null
+++ b/src/backend/libpq/auth-oauth.c
@@ -0,0 +1,797 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth-oauth.c
+ * Server-side implementation of the SASL OAUTHBEARER mechanism.
+ *
+ * See the following RFC for more details:
+ * - RFC 7628: https://tools.ietf.org/html/rfc7628
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/backend/libpq/auth-oauth.c
+ *
+ *-------------------------------------------------------------------------
+ */
+#include "postgres.h"
+
+#include <unistd.h>
+#include <fcntl.h>
+
+#include "common/oauth-common.h"
+#include "lib/stringinfo.h"
+#include "libpq/auth.h"
+#include "libpq/hba.h"
+#include "libpq/oauth.h"
+#include "libpq/sasl.h"
+#include "storage/fd.h"
+
+/* GUC */
+char *oauth_validator_command;
+
+static void oauth_get_mechanisms(Port *port, StringInfo buf);
+static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
+static int oauth_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, char **logdetail);
+
+/* Mechanism declaration */
+const pg_be_sasl_mech pg_be_oauth_mech = {
+ oauth_get_mechanisms,
+ oauth_init,
+ oauth_exchange,
+
+ PG_MAX_AUTH_TOKEN_LENGTH,
+};
+
+
+typedef enum
+{
+ OAUTH_STATE_INIT = 0,
+ OAUTH_STATE_ERROR,
+ OAUTH_STATE_FINISHED,
+} oauth_state;
+
+struct oauth_ctx
+{
+ oauth_state state;
+ Port *port;
+ const char *issuer;
+ const char *scope;
+};
+
+static char *sanitize_char(char c);
+static char *parse_kvpairs_for_auth(char **input);
+static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
+static bool validate(Port *port, const char *auth, char **logdetail);
+static bool run_validator_command(Port *port, const char *token);
+static bool check_exit(FILE **fh, const char *command);
+static bool unset_cloexec(int fd);
+static bool username_ok_for_shell(const char *username);
+
+#define KVSEP 0x01
+#define AUTH_KEY "auth"
+#define BEARER_SCHEME "Bearer "
+
+static void
+oauth_get_mechanisms(Port *port, StringInfo buf)
+{
+ /* Only OAUTHBEARER is supported. */
+ appendStringInfoString(buf, OAUTHBEARER_NAME);
+ appendStringInfoChar(buf, '\0');
+}
+
+static void *
+oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
+{
+ struct oauth_ctx *ctx;
+
+ if (strcmp(selected_mech, OAUTHBEARER_NAME))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("client selected an invalid SASL authentication mechanism")));
+
+ ctx = palloc0(sizeof(*ctx));
+
+ ctx->state = OAUTH_STATE_INIT;
+ ctx->port = port;
+
+ Assert(port->hba);
+ ctx->issuer = port->hba->oauth_issuer;
+ ctx->scope = port->hba->oauth_scope;
+
+ return ctx;
+}
+
+static int
+oauth_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, char **logdetail)
+{
+ char *p;
+ char cbind_flag;
+ char *auth;
+
+ struct oauth_ctx *ctx = opaq;
+
+ *output = NULL;
+ *outputlen = -1;
+
+ /*
+ * If the client didn't include an "Initial Client Response" in the
+ * SASLInitialResponse message, send an empty challenge, to which the
+ * client will respond with the same data that usually comes in the
+ * Initial Client Response.
+ */
+ if (input == NULL)
+ {
+ Assert(ctx->state == OAUTH_STATE_INIT);
+
+ *output = pstrdup("");
+ *outputlen = 0;
+ return SASL_EXCHANGE_CONTINUE;
+ }
+
+ /*
+ * Check that the input length agrees with the string length of the input.
+ */
+ if (inputlen == 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("The message is empty.")));
+ if (inputlen != strlen(input))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message length does not match input length.")));
+
+ switch (ctx->state)
+ {
+ case OAUTH_STATE_INIT:
+ /* Handle this case below. */
+ break;
+
+ case OAUTH_STATE_ERROR:
+ /*
+ * Only one response is valid for the client during authentication
+ * failure: a single kvsep.
+ */
+ if (inputlen != 1 || *input != KVSEP)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Client did not send a kvsep response.")));
+
+ /* The (failed) handshake is now complete. */
+ ctx->state = OAUTH_STATE_FINISHED;
+ return SASL_EXCHANGE_FAILURE;
+
+ default:
+ elog(ERROR, "invalid OAUTHBEARER exchange state");
+ return SASL_EXCHANGE_FAILURE;
+ }
+
+ /* Handle the client's initial message. */
+ p = strdup(input);
+
+ /*
+ * OAUTHBEARER does not currently define a channel binding (so there is no
+ * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a 'y'
+ * specifier purely for the remote chance that a future specification could
+ * define one; then future clients can still interoperate with this server
+ * implementation. 'n' is the expected case.
+ */
+ cbind_flag = *p;
+ switch (cbind_flag)
+ {
+ case 'p':
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data.")));
+ break;
+
+ case 'y': /* fall through */
+ case 'n':
+ p++;
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Comma expected, but found character \"%s\".",
+ sanitize_char(*p))));
+ p++;
+ break;
+
+ default:
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Unexpected channel-binding flag \"%s\".",
+ sanitize_char(cbind_flag))));
+ }
+
+ /*
+ * Forbid optional authzid (authorization identity). We don't support it.
+ */
+ if (*p == 'a')
+ ereport(ERROR,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("client uses authorization identity, but it is not supported")));
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Unexpected attribute \"%s\" in client-first-message.",
+ sanitize_char(*p))));
+ p++;
+
+ /* All remaining fields are separated by the RFC's kvsep (\x01). */
+ if (*p != KVSEP)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Key-value separator expected, but found character \"%s\".",
+ sanitize_char(*p))));
+ p++;
+
+ auth = parse_kvpairs_for_auth(&p);
+ if (!auth)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message does not contain an auth value.")));
+
+ /* We should be at the end of our message. */
+ if (*p)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains additional data after the final terminator.")));
+
+ if (!validate(ctx->port, auth, logdetail))
+ {
+ generate_error_response(ctx, output, outputlen);
+
+ ctx->state = OAUTH_STATE_ERROR;
+ return SASL_EXCHANGE_CONTINUE;
+ }
+
+ ctx->state = OAUTH_STATE_FINISHED;
+ return SASL_EXCHANGE_SUCCESS;
+}
+
+/*
+ * Convert an arbitrary byte to printable form. For error messages.
+ *
+ * If it's a printable ASCII character, print it as a single character.
+ * Otherwise, print it in hex.
+ *
+ * The returned pointer points to a static buffer.
+ */
+static char *
+sanitize_char(char c)
+{
+ static char buf[5];
+
+ if (c >= 0x21 && c <= 0x7E)
+ snprintf(buf, sizeof(buf), "'%c'", c);
+ else
+ snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
+ return buf;
+}
+
+/*
+ * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is
+ * found, its value is returned.
+ */
+static char *
+parse_kvpairs_for_auth(char **input)
+{
+ char *pos = *input;
+ char *auth = NULL;
+
+ /*
+ * The relevant ABNF, from Sec. 3.1:
+ *
+ * kvsep = %x01
+ * key = 1*(ALPHA)
+ * value = *(VCHAR / SP / HTAB / CR / LF )
+ * kvpair = key "=" value kvsep
+ * ;;gs2-header = See RFC 5801
+ * client-resp = (gs2-header kvsep *kvpair kvsep) / kvsep
+ *
+ * By the time we reach this code, the gs2-header and initial kvsep have
+ * already been validated. We start at the beginning of the first kvpair.
+ */
+
+ while (*pos)
+ {
+ char *end;
+ char *sep;
+ char *key;
+ char *value;
+
+ /*
+ * Find the end of this kvpair. Note that input is null-terminated by
+ * the SASL code, so the strchr() is bounded.
+ */
+ end = strchr(pos, KVSEP);
+ if (!end)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains an unterminated key/value pair.")));
+ *end = '\0';
+
+ if (pos == end)
+ {
+ /* Empty kvpair, signifying the end of the list. */
+ *input = pos + 1;
+ return auth;
+ }
+
+ /*
+ * Find the end of the key name.
+ *
+ * TODO further validate the key/value grammar? empty keys, bad chars...
+ */
+ sep = strchr(pos, '=');
+ if (!sep)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains a key without a value.")));
+ *sep = '\0';
+
+ /* Both key and value are now safely terminated. */
+ key = pos;
+ value = sep + 1;
+
+ if (!strcmp(key, AUTH_KEY))
+ {
+ if (auth)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains multiple auth values.")));
+
+ auth = value;
+ }
+ else
+ {
+ /*
+ * The RFC also defines the host and port keys, but they are not
+ * required for OAUTHBEARER and we do not use them. Also, per
+ * Sec. 3.1, any key/value pairs we don't recognize must be ignored.
+ */
+ }
+
+ /* Move to the next pair. */
+ pos = end + 1;
+ }
+
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message did not contain a final terminator.")));
+
+ return NULL; /* unreachable */
+}
+
+static void
+generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
+{
+ StringInfoData buf;
+
+ /*
+ * The admin needs to set an issuer and scope for OAuth to work. There's not
+ * really a way to hide this from the user, either, because we can't choose
+ * a "default" issuer, so be honest in the failure message.
+ *
+ * TODO: see if there's a better place to fail, earlier than this.
+ */
+ if (!ctx->issuer || !ctx->scope)
+ ereport(FATAL,
+ (errcode(ERRCODE_INTERNAL_ERROR),
+ errmsg("OAuth is not properly configured for this user"),
+ errdetail_log("The issuer and scope parameters must be set in pg_hba.conf.")));
+
+ initStringInfo(&buf);
+
+ /*
+ * TODO: JSON escaping
+ */
+ appendStringInfo(&buf,
+ "{ "
+ "\"status\": \"invalid_token\", "
+ "\"openid-configuration\": \"%s/.well-known/openid-configuration\","
+ "\"scope\": \"%s\" "
+ "}",
+ ctx->issuer, ctx->scope);
+
+ *output = buf.data;
+ *outputlen = buf.len;
+}
+
+static bool
+validate(Port *port, const char *auth, char **logdetail)
+{
+ static const char * const b64_set = "abcdefghijklmnopqrstuvwxyz"
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "0123456789-._~+/";
+
+ const char *token;
+ size_t span;
+ int ret;
+
+ /* TODO: handle logdetail when the test framework can check it */
+
+ /*
+ * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec.
+ * 2.1:
+ *
+ * b64token = 1*( ALPHA / DIGIT /
+ * "-" / "." / "_" / "~" / "+" / "/" ) *"="
+ * credentials = "Bearer" 1*SP b64token
+ *
+ * The "credentials" construction is what we receive in our auth value.
+ *
+ * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization
+ * header format; RFC 7235 Sec. 2), the "Bearer" scheme string must be
+ * compared case-insensitively. (This is not mentioned in RFC 6750, but it's
+ * pointed out in RFC 7628 Sec. 4.)
+ *
+ * TODO: handle the Authorization spec, RFC 7235 Sec. 2.1.
+ */
+ if (strncasecmp(auth, BEARER_SCHEME, strlen(BEARER_SCHEME)))
+ return false;
+
+ /* Pull the bearer token out of the auth value. */
+ token = auth + strlen(BEARER_SCHEME);
+
+ /* Swallow any additional spaces. */
+ while (*token == ' ')
+ token++;
+
+ /*
+ * Before invoking the validator command, sanity-check the token format to
+ * avoid any injection attacks later in the chain. Invalid formats are
+ * technically a protocol violation, but don't reflect any information about
+ * the sensitive Bearer token back to the client; log at COMMERROR instead.
+ */
+
+ /* Tokens must not be empty. */
+ if (!*token)
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Bearer token is empty.")));
+ return false;
+ }
+
+ /*
+ * Make sure the token contains only allowed characters. Tokens may end with
+ * any number of '=' characters.
+ */
+ span = strspn(token, b64_set);
+ while (token[span] == '=')
+ span++;
+
+ if (token[span] != '\0')
+ {
+ /*
+ * This error message could be more helpful by printing the problematic
+ * character(s), but that'd be a bit like printing a piece of someone's
+ * password into the logs.
+ */
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Bearer token is not in the correct format.")));
+ return false;
+ }
+
+ /* Have the validator check the token. */
+ if (!run_validator_command(port, token))
+ return false;
+
+ if (port->hba->oauth_skip_usermap)
+ {
+ /*
+ * If the validator is our authorization authority, we're done.
+ * Authentication may or may not have been performed depending on the
+ * validator implementation; all that matters is that the validator says
+ * the user can log in with the target role.
+ */
+ return true;
+ }
+
+ /* Make sure the validator authenticated the user. */
+ if (!port->authn_id)
+ {
+ /* TODO: use logdetail; reduce message duplication */
+ ereport(LOG,
+ (errmsg("OAuth bearer authentication failed for user \"%s\": validator provided no identity",
+ port->user_name)));
+ return false;
+ }
+
+ /* Finally, check the user map. */
+ ret = check_usermap(port->hba->usermap, port->user_name, port->authn_id,
+ false);
+ return (ret == STATUS_OK);
+}
+
+static bool
+run_validator_command(Port *port, const char *token)
+{
+ bool success = false;
+ int rc;
+ int pipefd[2];
+ int rfd = -1;
+ int wfd = -1;
+
+ StringInfoData command = { 0 };
+ char *p;
+ FILE *fh = NULL;
+
+ ssize_t written;
+ char *line = NULL;
+ size_t size = 0;
+ ssize_t len;
+
+ Assert(oauth_validator_command);
+
+ if (!oauth_validator_command[0])
+ {
+ ereport(COMMERROR,
+ (errmsg("oauth_validator_command is not set"),
+ errhint("To allow OAuth authenticated connections, set "
+ "oauth_validator_command in postgresql.conf.")));
+ return false;
+ }
+
+ /*
+ * Since popen() is unidirectional, open up a pipe for the other direction.
+ * Use CLOEXEC to ensure that our write end doesn't accidentally get copied
+ * into child processes, which would prevent us from closing it cleanly.
+ *
+ * XXX this is ugly. We should just read from the child process's stdout,
+ * but that's a lot more code.
+ * XXX by bypassing the popen API, we open the potential of process
+ * deadlock. Clearly document child process requirements (i.e. the child
+ * MUST read all data off of the pipe before writing anything).
+ * TODO: port to Windows using _pipe().
+ */
+ rc = pipe2(pipefd, O_CLOEXEC);
+ if (rc < 0)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not create child pipe: %m")));
+ return false;
+ }
+
+ rfd = pipefd[0];
+ wfd = pipefd[1];
+
+ /* Allow the read pipe to be passed to the child. */
+ if (!unset_cloexec(rfd))
+ {
+ /* error message was already logged */
+ goto cleanup;
+ }
+
+ /*
+ * Construct the command, substituting any recognized %-specifiers:
+ *
+ * %f: the file descriptor of the input pipe
+ * %r: the role that the client wants to assume (port->user_name)
+ * %%: a literal '%'
+ */
+ initStringInfo(&command);
+
+ for (p = oauth_validator_command; *p; p++)
+ {
+ if (p[0] == '%')
+ {
+ switch (p[1])
+ {
+ case 'f':
+ appendStringInfo(&command, "%d", rfd);
+ p++;
+ break;
+ case 'r':
+ /*
+ * TODO: decide how this string should be escaped. The role
+ * is controlled by the client, so if we don't escape it,
+ * command injections are inevitable.
+ *
+ * This is probably an indication that the role name needs
+ * to be communicated to the validator process in some other
+ * way. For this proof of concept, just be incredibly strict
+ * about the characters that are allowed in user names.
+ */
+ if (!username_ok_for_shell(port->user_name))
+ goto cleanup;
+
+ appendStringInfoString(&command, port->user_name);
+ p++;
+ break;
+ case '%':
+ appendStringInfoChar(&command, '%');
+ p++;
+ break;
+ default:
+ appendStringInfoChar(&command, p[0]);
+ }
+ }
+ else
+ appendStringInfoChar(&command, p[0]);
+ }
+
+ /* Execute the command. */
+ fh = OpenPipeStream(command.data, "re");
+ /* TODO: handle failures */
+
+ /* We don't need the read end of the pipe anymore. */
+ close(rfd);
+ rfd = -1;
+
+ /* Give the command the token to validate. */
+ written = write(wfd, token, strlen(token));
+ if (written != strlen(token))
+ {
+ /* TODO must loop for short writes, EINTR et al */
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not write token to child pipe: %m")));
+ goto cleanup;
+ }
+
+ close(wfd);
+ wfd = -1;
+
+ /*
+ * Read the command's response.
+ *
+ * TODO: getline() is probably too new to use, unfortunately.
+ * TODO: loop over all lines
+ */
+ if ((len = getline(&line, &size, fh)) >= 0)
+ {
+ /* TODO: fail if the authn_id doesn't end with a newline */
+ if (len > 0)
+ line[len - 1] = '\0';
+
+ set_authn_id(port, line);
+ }
+ else if (ferror(fh))
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not read from command \"%s\": %m",
+ command.data)));
+ goto cleanup;
+ }
+
+ /* Make sure the command exits cleanly. */
+ if (!check_exit(&fh, command.data))
+ {
+ /* error message already logged */
+ goto cleanup;
+ }
+
+ /* Done. */
+ success = true;
+
+cleanup:
+ if (line)
+ free(line);
+
+ /*
+ * In the successful case, the pipe fds are already closed. For the error
+ * case, always close out the pipe before waiting for the command, to
+ * prevent deadlock.
+ */
+ if (rfd >= 0)
+ close(rfd);
+ if (wfd >= 0)
+ close(wfd);
+
+ if (fh)
+ {
+ Assert(!success);
+ check_exit(&fh, command.data);
+ }
+
+ if (command.data)
+ pfree(command.data);
+
+ return success;
+}
+
+static bool
+check_exit(FILE **fh, const char *command)
+{
+ int rc;
+
+ rc = ClosePipeStream(*fh);
+ *fh = NULL;
+
+ if (rc == -1)
+ {
+ /* pclose() itself failed. */
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not close pipe to command \"%s\": %m",
+ command)));
+ }
+ else if (rc != 0)
+ {
+ char *reason = wait_result_to_str(rc);
+
+ ereport(COMMERROR,
+ (errmsg("failed to execute command \"%s\": %s",
+ command, reason)));
+
+ pfree(reason);
+ }
+
+ return (rc == 0);
+}
+
+static bool
+unset_cloexec(int fd)
+{
+ int flags;
+ int rc;
+
+ flags = fcntl(fd, F_GETFD);
+ if (flags == -1)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not get fd flags for child pipe: %m")));
+ return false;
+ }
+
+ rc = fcntl(fd, F_SETFD, flags & ~FD_CLOEXEC);
+ if (rc < 0)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not unset FD_CLOEXEC for child pipe: %m")));
+ return false;
+ }
+
+ return true;
+}
+
+/*
+ * XXX This should go away eventually and be replaced with either a proper
+ * escape or a different strategy for communication with the validator command.
+ */
+static bool
+username_ok_for_shell(const char *username)
+{
+ /* This set is borrowed from fe_utils' appendShellStringNoError(). */
+ static const char * const allowed = "abcdefghijklmnopqrstuvwxyz"
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "0123456789-_./:";
+ size_t span;
+
+ Assert(username && username[0]); /* should have already been checked */
+
+ span = strspn(username, allowed);
+ if (username[span] != '\0')
+ {
+ ereport(COMMERROR,
+ (errmsg("PostgreSQL user name contains unsafe characters and cannot be passed to the OAuth validator")));
+ return false;
+ }
+
+ return true;
+}
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index db3ca75a60..9e4482dc27 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -118,6 +118,8 @@ const pg_be_sasl_mech pg_be_scram_mech = {
scram_get_mechanisms,
scram_init,
scram_exchange,
+
+ PG_MAX_SASL_MESSAGE_LENGTH,
};
/*
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index e20740a7c5..354c7b0fc8 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -29,6 +29,7 @@
#include "libpq/auth.h"
#include "libpq/crypt.h"
#include "libpq/libpq.h"
+#include "libpq/oauth.h"
#include "libpq/pqformat.h"
#include "libpq/sasl.h"
#include "libpq/scram.h"
@@ -49,7 +50,6 @@ static void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
int extralen);
static void auth_failed(Port *port, int status, char *logdetail);
static char *recv_password_packet(Port *port);
-static void set_authn_id(Port *port, const char *id);
/*----------------------------------------------------------------
* SASL common authentication
@@ -215,29 +215,12 @@ static int CheckRADIUSAuth(Port *port);
static int PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd);
-/*
- * Maximum accepted size of GSS and SSPI authentication tokens.
- * We also use this as a limit on ordinary password packet lengths.
- *
- * Kerberos tickets are usually quite small, but the TGTs issued by Windows
- * domain controllers include an authorization field known as the Privilege
- * Attribute Certificate (PAC), which contains the user's Windows permissions
- * (group memberships etc.). The PAC is copied into all tickets obtained on
- * the basis of this TGT (even those issued by Unix realms which the Windows
- * realm trusts), and can be several kB in size. The maximum token size
- * accepted by Windows systems is determined by the MaxAuthToken Windows
- * registry setting. Microsoft recommends that it is not set higher than
- * 65535 bytes, so that seems like a reasonable limit for us as well.
+/*----------------------------------------------------------------
+ * OAuth v2 Bearer Authentication
+ *----------------------------------------------------------------
*/
-#define PG_MAX_AUTH_TOKEN_LENGTH 65535
+static int CheckOAuthBearer(Port *port);
-/*
- * Maximum accepted size of SASL messages.
- *
- * The messages that the server or libpq generate are much smaller than this,
- * but have some headroom.
- */
-#define PG_MAX_SASL_MESSAGE_LENGTH 1024
/*----------------------------------------------------------------
* Global authentication functions
@@ -327,6 +310,9 @@ auth_failed(Port *port, int status, char *logdetail)
case uaRADIUS:
errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
break;
+ case uaOAuth:
+ errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
+ break;
default:
errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method");
break;
@@ -361,7 +347,7 @@ auth_failed(Port *port, int status, char *logdetail)
* lifetime of the Port, so it is safe to pass a string that is managed by an
* external library.
*/
-static void
+void
set_authn_id(Port *port, const char *id)
{
Assert(id);
@@ -646,6 +632,9 @@ ClientAuthentication(Port *port)
case uaTrust:
status = STATUS_OK;
break;
+ case uaOAuth:
+ status = CheckOAuthBearer(port);
+ break;
}
if ((status == STATUS_OK && port->hba->clientcert == clientCertFull)
@@ -973,7 +962,7 @@ SASL_exchange(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
/* Get the actual SASL message */
initStringInfo(&buf);
- if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
+ if (pq_getmessage(&buf, mech->max_message_length))
{
/* EOF - pq_getmessage already logged error */
pfree(buf.data);
@@ -3495,3 +3484,9 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
}
} /* while (true) */
}
+
+static int
+CheckOAuthBearer(Port *port)
+{
+ return SASL_exchange(&pg_be_oauth_mech, port, NULL, NULL);
+}
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 3be8778d21..98147700dd 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -134,7 +134,8 @@ static const char *const UserAuthName[] =
"ldap",
"cert",
"radius",
- "peer"
+ "peer",
+ "oauth",
};
@@ -1399,6 +1400,8 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
#endif
else if (strcmp(token->string, "radius") == 0)
parsedline->auth_method = uaRADIUS;
+ else if (strcmp(token->string, "oauth") == 0)
+ parsedline->auth_method = uaOAuth;
else
{
ereport(elevel,
@@ -1713,8 +1716,9 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->auth_method != uaPeer &&
hbaline->auth_method != uaGSS &&
hbaline->auth_method != uaSSPI &&
+ hbaline->auth_method != uaOAuth &&
hbaline->auth_method != uaCert)
- INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, and cert"));
+ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, oauth, and cert"));
hbaline->usermap = pstrdup(val);
}
else if (strcmp(name, "clientcert") == 0)
@@ -2098,6 +2102,27 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->radiusidentifiers = parsed_identifiers;
hbaline->radiusidentifiers_s = pstrdup(val);
}
+ else if (strcmp(name, "issuer") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("issuer", gettext_noop("oauth"));
+ hbaline->oauth_issuer = pstrdup(val);
+ }
+ else if (strcmp(name, "scope") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("scope", gettext_noop("oauth"));
+ hbaline->oauth_scope = pstrdup(val);
+ }
+ else if (strcmp(name, "trust_validator_authz") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("trust_validator_authz", gettext_noop("oauth"));
+ if (strcmp(val, "1") == 0)
+ hbaline->oauth_skip_usermap = true;
+ else
+ hbaline->oauth_skip_usermap = false;
+ }
else
{
ereport(elevel,
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index 68b62d523d..1ef6b3c41e 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -56,6 +56,7 @@
#include "libpq/auth.h"
#include "libpq/libpq.h"
#include "libpq/pqformat.h"
+#include "libpq/oauth.h"
#include "miscadmin.h"
#include "optimizer/cost.h"
#include "optimizer/geqo.h"
@@ -4587,6 +4588,17 @@ static struct config_string ConfigureNamesString[] =
check_backtrace_functions, assign_backtrace_functions, NULL
},
+ {
+ {"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH,
+ gettext_noop("Command to validate OAuth v2 bearer tokens."),
+ NULL,
+ GUC_SUPERUSER_ONLY
+ },
+ &oauth_validator_command,
+ "",
+ NULL, NULL, NULL
+ },
+
/* End-of-list marker */
{
{NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 3610fae3ff..785cc5d16f 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -21,6 +21,7 @@ extern bool pg_krb_caseins_users;
extern char *pg_krb_realm;
extern void ClientAuthentication(Port *port);
+extern void set_authn_id(Port *port, const char *id);
/* Hook for plugins to get control in ClientAuthentication() */
typedef void (*ClientAuthentication_hook_type) (Port *, int);
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 8d9f3821b1..441dd5623e 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -38,8 +38,9 @@ typedef enum UserAuth
uaLDAP,
uaCert,
uaRADIUS,
- uaPeer
-#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */
+ uaPeer,
+ uaOAuth
+#define USER_AUTH_LAST uaOAuth /* Must be last value of this enum */
} UserAuth;
/*
@@ -120,6 +121,9 @@ typedef struct HbaLine
char *radiusidentifiers_s;
List *radiusports;
char *radiusports_s;
+ char *oauth_issuer;
+ char *oauth_scope;
+ bool oauth_skip_usermap;
} HbaLine;
typedef struct IdentLine
diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h
new file mode 100644
index 0000000000..870e426af1
--- /dev/null
+++ b/src/include/libpq/oauth.h
@@ -0,0 +1,24 @@
+/*-------------------------------------------------------------------------
+ *
+ * oauth.h
+ * Interface to libpq/auth-oauth.c
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/libpq/oauth.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef PG_OAUTH_H
+#define PG_OAUTH_H
+
+#include "libpq/libpq-be.h"
+#include "libpq/sasl.h"
+
+extern char *oauth_validator_command;
+
+/* Implementation */
+extern const pg_be_sasl_mech pg_be_oauth_mech;
+
+#endif /* PG_OAUTH_H */
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
index 8c9c9983d4..f1341d0c54 100644
--- a/src/include/libpq/sasl.h
+++ b/src/include/libpq/sasl.h
@@ -19,6 +19,30 @@
#define SASL_EXCHANGE_SUCCESS 1
#define SASL_EXCHANGE_FAILURE 2
+/*
+ * Maximum accepted size of GSS and SSPI authentication tokens.
+ * We also use this as a limit on ordinary password packet lengths.
+ *
+ * Kerberos tickets are usually quite small, but the TGTs issued by Windows
+ * domain controllers include an authorization field known as the Privilege
+ * Attribute Certificate (PAC), which contains the user's Windows permissions
+ * (group memberships etc.). The PAC is copied into all tickets obtained on
+ * the basis of this TGT (even those issued by Unix realms which the Windows
+ * realm trusts), and can be several kB in size. The maximum token size
+ * accepted by Windows systems is determined by the MaxAuthToken Windows
+ * registry setting. Microsoft recommends that it is not set higher than
+ * 65535 bytes, so that seems like a reasonable limit for us as well.
+ */
+#define PG_MAX_AUTH_TOKEN_LENGTH 65535
+
+/*
+ * Maximum accepted size of SASL messages.
+ *
+ * The messages that the server or libpq generate are much smaller than this,
+ * but have some headroom.
+ */
+#define PG_MAX_SASL_MESSAGE_LENGTH 1024
+
/* Backend mechanism API */
typedef void (*pg_be_sasl_mechanism_func)(Port *, StringInfo);
typedef void *(*pg_be_sasl_init_func)(Port *, const char *, const char *);
@@ -29,6 +53,8 @@ typedef struct
pg_be_sasl_mechanism_func get_mechanisms;
pg_be_sasl_init_func init;
pg_be_sasl_exchange_func exchange;
+
+ int max_message_length;
} pg_be_sasl_mech;
#endif /* PG_SASL_H */
--
2.25.1
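To tie the server-side pieces of the patch above together, here is a sketch of what the resulting configuration might look like. Only the option names (`oauth`, `issuer`, `scope`, `trust_validator_authz`, `map`, and the `oauth_validator_command` GUC) come from the patch; the issuer URL, validator path, and usermap name are illustrative placeholders:

```
# pg_hba.conf -- authentication via OAuth, with the validated identity
# mapped through pg_ident.conf (usermap "oidmap" is a placeholder):
host  all  all  0.0.0.0/0  oauth  issuer="https://issuer.example.org" scope="openid" map=oidmap

# ...or authorization-only, trusting the validator's role decision and
# bypassing pg_ident entirely:
host  all  all  0.0.0.0/0  oauth  issuer="https://issuer.example.org" trust_validator_authz=1

# postgresql.conf -- external command used to validate bearer tokens
# (GUC_SUPERUSER_ONLY, reloadable at SIGHUP per the guc.c hunk):
oauth_validator_command = '/usr/local/bin/validate-bearer'
```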
Attachment: 0006-Add-a-very-simple-authn_id-extension.patch (text/x-patch)
From e468be7ff7d19645aeb77bef21a383960a47731e Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 18 May 2021 15:01:29 -0700
Subject: [PATCH 6/7] Add a very simple authn_id extension
...for retrieving the authn_id from the server in tests.
---
contrib/authn_id/Makefile | 19 +++++++++++++++++++
contrib/authn_id/authn_id--1.0.sql | 8 ++++++++
contrib/authn_id/authn_id.c | 28 ++++++++++++++++++++++++++++
contrib/authn_id/authn_id.control | 5 +++++
4 files changed, 60 insertions(+)
create mode 100644 contrib/authn_id/Makefile
create mode 100644 contrib/authn_id/authn_id--1.0.sql
create mode 100644 contrib/authn_id/authn_id.c
create mode 100644 contrib/authn_id/authn_id.control
diff --git a/contrib/authn_id/Makefile b/contrib/authn_id/Makefile
new file mode 100644
index 0000000000..46026358e0
--- /dev/null
+++ b/contrib/authn_id/Makefile
@@ -0,0 +1,19 @@
+# contrib/authn_id/Makefile
+
+MODULE_big = authn_id
+OBJS = authn_id.o
+
+EXTENSION = authn_id
+DATA = authn_id--1.0.sql
+PGFILEDESC = "authn_id - information about the authenticated user"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = contrib/authn_id
+top_builddir = ../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/contrib/authn_id/authn_id--1.0.sql b/contrib/authn_id/authn_id--1.0.sql
new file mode 100644
index 0000000000..af2a4d3991
--- /dev/null
+++ b/contrib/authn_id/authn_id--1.0.sql
@@ -0,0 +1,8 @@
+/* contrib/authn_id/authn_id--1.0.sql */
+
+-- complain if script is sourced in psql, rather than via CREATE EXTENSION
+\echo Use "CREATE EXTENSION authn_id" to load this file. \quit
+
+CREATE FUNCTION authn_id() RETURNS text
+AS 'MODULE_PATHNAME', 'authn_id'
+LANGUAGE C IMMUTABLE;
diff --git a/contrib/authn_id/authn_id.c b/contrib/authn_id/authn_id.c
new file mode 100644
index 0000000000..0fecac36a8
--- /dev/null
+++ b/contrib/authn_id/authn_id.c
@@ -0,0 +1,28 @@
+/*
+ * Extension to expose the current user's authn_id.
+ *
+ * contrib/authn_id/authn_id.c
+ */
+
+#include "postgres.h"
+
+#include "fmgr.h"
+#include "libpq/libpq-be.h"
+#include "miscadmin.h"
+#include "utils/builtins.h"
+
+PG_MODULE_MAGIC;
+
+PG_FUNCTION_INFO_V1(authn_id);
+
+/*
+ * Returns the current user's authenticated identity.
+ */
+Datum
+authn_id(PG_FUNCTION_ARGS)
+{
+ if (!MyProcPort->authn_id)
+ PG_RETURN_NULL();
+
+ PG_RETURN_TEXT_P(cstring_to_text(MyProcPort->authn_id));
+}
diff --git a/contrib/authn_id/authn_id.control b/contrib/authn_id/authn_id.control
new file mode 100644
index 0000000000..e0f9e06bed
--- /dev/null
+++ b/contrib/authn_id/authn_id.control
@@ -0,0 +1,5 @@
+# authn_id extension
+comment = 'current user identity'
+default_version = '1.0'
+module_pathname = '$libdir/authn_id'
+relocatable = true
--
2.25.1
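As a quick usage sketch for the extension above (intended behavior per the C function, not a captured transcript):

```sql
-- Install the test extension, then read back the authenticated identity
-- recorded for this connection. Returns NULL when the auth method did
-- not set an authn_id (e.g. trust).
CREATE EXTENSION authn_id;
SELECT authn_id();
```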
Attachment: 0007-Add-pytest-suite-for-OAuth.patch (text/x-patch)
From 896da918cfcd16bcb119090914f687b3e905d865 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Fri, 4 Jun 2021 09:06:38 -0700
Subject: [PATCH 7/7] Add pytest suite for OAuth
Requires Python 3; on the first run of `make installcheck` the
dependencies will be installed into ./venv for you. See the README for
more details.
---
src/test/python/.gitignore | 2 +
src/test/python/Makefile | 33 +
src/test/python/README | 49 +
src/test/python/client/__init__.py | 0
src/test/python/client/conftest.py | 126 +++
src/test/python/client/test_client.py | 180 ++++
src/test/python/client/test_oauth.py | 936 ++++++++++++++++++
src/test/python/pq3.py | 727 ++++++++++++++
src/test/python/pytest.ini | 4 +
src/test/python/requirements.txt | 7 +
src/test/python/server/__init__.py | 0
src/test/python/server/conftest.py | 45 +
src/test/python/server/test_oauth.py | 1012 ++++++++++++++++++++
src/test/python/server/test_server.py | 21 +
src/test/python/server/validate_bearer.py | 101 ++
src/test/python/server/validate_reflect.py | 34 +
src/test/python/test_internals.py | 138 +++
src/test/python/test_pq3.py | 558 +++++++++++
src/test/python/tls.py | 195 ++++
19 files changed, 4168 insertions(+)
create mode 100644 src/test/python/.gitignore
create mode 100644 src/test/python/Makefile
create mode 100644 src/test/python/README
create mode 100644 src/test/python/client/__init__.py
create mode 100644 src/test/python/client/conftest.py
create mode 100644 src/test/python/client/test_client.py
create mode 100644 src/test/python/client/test_oauth.py
create mode 100644 src/test/python/pq3.py
create mode 100644 src/test/python/pytest.ini
create mode 100644 src/test/python/requirements.txt
create mode 100644 src/test/python/server/__init__.py
create mode 100644 src/test/python/server/conftest.py
create mode 100644 src/test/python/server/test_oauth.py
create mode 100644 src/test/python/server/test_server.py
create mode 100755 src/test/python/server/validate_bearer.py
create mode 100755 src/test/python/server/validate_reflect.py
create mode 100644 src/test/python/test_internals.py
create mode 100644 src/test/python/test_pq3.py
create mode 100644 src/test/python/tls.py
diff --git a/src/test/python/.gitignore b/src/test/python/.gitignore
new file mode 100644
index 0000000000..0e8f027b2e
--- /dev/null
+++ b/src/test/python/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+/venv/
diff --git a/src/test/python/Makefile b/src/test/python/Makefile
new file mode 100644
index 0000000000..515a995106
--- /dev/null
+++ b/src/test/python/Makefile
@@ -0,0 +1,33 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+# Only Python 3 is supported, but if it's named something different on your
+# system you can override it with the PYTHON3 variable.
+PYTHON3 := python3
+
+# All dependencies are placed into this directory. The default is .gitignored
+# for you, but you can override it if you'd like.
+VENV := ./venv
+
+override VBIN := $(VENV)/bin
+override PIP := $(VBIN)/pip
+override PYTEST := $(VBIN)/py.test
+override ISORT := $(VBIN)/isort
+override BLACK := $(VBIN)/black
+
+.PHONY: installcheck indent
+
+installcheck: $(PYTEST)
+ $(PYTEST) -v -rs
+
+indent: $(ISORT) $(BLACK)
+ $(ISORT) --profile black *.py client/*.py server/*.py
+ $(BLACK) *.py client/*.py server/*.py
+
+$(PYTEST) $(ISORT) $(BLACK): requirements.txt | $(PIP)
+ $(PIP) install -r $<
+
+$(PIP):
+ $(PYTHON3) -m venv $(VENV)
diff --git a/src/test/python/README b/src/test/python/README
new file mode 100644
index 0000000000..ceae364e81
--- /dev/null
+++ b/src/test/python/README
@@ -0,0 +1,49 @@
+A test suite for exercising both the libpq client and the server backend at the
+protocol level, based on pytest and Construct.
+
+The test suite currently assumes that the standard PG* environment variables
+point to the database under test and are sufficient to log in a superuser on
+that system. In other words, a bare `psql` needs to Just Work before the test
+suite can do its thing. For a newly built dev cluster, typically all that's
+needed is a
+
+ export PGDATABASE=postgres
+
+but you can adjust as needed for your setup.
+
+## Requirements
+
+A supported version (3.6+) of Python.
+
+The first run of
+
+ make installcheck
+
+will install a local virtual environment and all needed dependencies.
+
+## Hacking
+
+The code style is enforced by a _very_ opinionated autoformatter. Running the
+
+ make indent
+
+recipe will invoke it for you automatically. Don't fight the tool; part of the
+zen is in knowing that if the formatter makes your code ugly, there's probably
+a cleaner way to write it.
+
+## Advanced Usage
+
+The Makefile is there for convenience, but you don't have to use it. Activate
+the virtualenv to be able to use pytest directly:
+
+ $ source venv/bin/activate
+ $ py.test -k oauth
+ ...
+ $ py.test ./server/test_server.py
+ ...
+ $ deactivate # puts the PATH et al back the way it was before
+
+To make quick smoke tests possible, slow tests have been marked explicitly. You
+can skip them by saying e.g.
+
+ $ py.test -m 'not slow'
diff --git a/src/test/python/client/__init__.py b/src/test/python/client/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/client/conftest.py b/src/test/python/client/conftest.py
new file mode 100644
index 0000000000..f38da7a138
--- /dev/null
+++ b/src/test/python/client/conftest.py
@@ -0,0 +1,126 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import socket
+import sys
+import threading
+
+import psycopg2
+import pytest
+
+import pq3
+
+BLOCKING_TIMEOUT = 2 # the number of seconds to wait for blocking calls
+
+
+@pytest.fixture
+def server_socket(unused_tcp_port_factory):
+ """
+ Returns a listening socket bound to an ephemeral port.
+ """
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("127.0.0.1", unused_tcp_port_factory()))
+ s.listen(1)
+ s.settimeout(BLOCKING_TIMEOUT)
+ yield s
+
+
+class ClientHandshake(threading.Thread):
+ """
+ A thread that connects to a local Postgres server using psycopg2. Once the
+ opening handshake completes, the connection will be immediately closed.
+ """
+
+ def __init__(self, *, port, **kwargs):
+ super().__init__()
+
+ kwargs["port"] = port
+ self._kwargs = kwargs
+
+ self.exception = None
+
+ def run(self):
+ try:
+ conn = psycopg2.connect(host="127.0.0.1", **self._kwargs)
+ conn.close()
+ except Exception as e:
+ self.exception = e
+
+ def check_completed(self, timeout=BLOCKING_TIMEOUT):
+ """
+ Joins the client thread. Raises an exception if the thread could not be
+ joined, or if it threw an exception itself. (The exception will be
+ cleared, so future calls to check_completed will succeed.)
+ """
+ self.join(timeout)
+
+ if self.is_alive():
+ raise TimeoutError("client thread did not handshake within the timeout")
+ elif self.exception:
+ e = self.exception
+ self.exception = None
+ raise e
+
+
+@pytest.fixture
+def accept(server_socket):
+ """
+ Returns a factory function that, when called, returns a pair (sock, client)
+ where sock is a server socket that has accepted a connection from client,
+ and client is an instance of ClientHandshake. Clients will complete their
+ handshakes and cleanly disconnect.
+
+ The default connstring options may be extended or overridden by passing
+ arbitrary keyword arguments. Keep in mind that you generally should not
+ override the host or port, since they point to the local test server.
+
+ For situations where a client needs to connect more than once to complete a
+ handshake, the accept function may be called more than once. (The client
+ returned for subsequent calls will always be the same client that was
+ returned for the first call.)
+
+ Tests must either complete the handshake so that the client thread can be
+ automatically joined during teardown, or else call client.check_completed()
+ and manually handle any expected errors.
+ """
+ _, port = server_socket.getsockname()
+
+ client = None
+ default_opts = dict(
+ port=port,
+ user=pq3.pguser(),
+ sslmode="disable",
+ )
+
+ def factory(**kwargs):
+ nonlocal client
+
+ if client is None:
+ opts = dict(default_opts)
+ opts.update(kwargs)
+
+ # The server_socket is already listening, so the client thread can
+ # be safely started; it'll block on the connection until we accept.
+ client = ClientHandshake(**opts)
+ client.start()
+
+ sock, _ = server_socket.accept()
+ return sock, client
+
+ yield factory
+ client.check_completed()
+
+
+@pytest.fixture
+def conn(accept):
+ """
+ Returns an accepted, wrapped pq3 connection to a psycopg2 client. The socket
+ will be closed when the test finishes, and the client will be checked for a
+ cleanly completed handshake.
+ """
+ sock, client = accept()
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ yield conn
diff --git a/src/test/python/client/test_client.py b/src/test/python/client/test_client.py
new file mode 100644
index 0000000000..c4c946fda4
--- /dev/null
+++ b/src/test/python/client/test_client.py
@@ -0,0 +1,180 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import sys
+
+import psycopg2
+import pytest
+from cryptography.hazmat.primitives import hashes, hmac
+
+import pq3
+
+
+def finish_handshake(conn):
+ """
+ Sends the AuthenticationOK message and the standard opening salvo of server
+ messages, then asserts that the client immediately sends a Terminate message
+ to close the connection cleanly.
+ """
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+ pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.Terminate
+
+
+def test_handshake(conn):
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ finish_handshake(conn)
+
+
+def test_aborted_connection(accept):
+ """
+ Make sure the client correctly reports an early close during handshakes.
+ """
+ sock, client = accept()
+ sock.close()
+
+ expected = "server closed the connection unexpectedly"
+ with pytest.raises(psycopg2.OperationalError, match=expected):
+ client.check_completed()
+
+
+#
+# SCRAM-SHA-256 (see RFC 5802: https://tools.ietf.org/html/rfc5802)
+#
+
+
+@pytest.fixture
+def password():
+ """
+ Returns a password for use by both client and server.
+ """
+ # TODO: parameterize this with passwords that require SASLprep.
+ return "secret"
+
+
+@pytest.fixture
+def pwconn(accept, password):
+ """
+ Like the conn fixture, but uses a password in the connection.
+ """
+ sock, client = accept(password=password)
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ yield conn
+
+
+def sha256(data):
+ """The H(str) function from Section 2.2."""
+ digest = hashes.Hash(hashes.SHA256())
+ digest.update(data)
+ return digest.finalize()
+
+
+def hmac_256(key, data):
+ """The HMAC(key, str) function from Section 2.2."""
+ h = hmac.HMAC(key, hashes.SHA256())
+ h.update(data)
+ return h.finalize()
+
+
+def xor(a, b):
+ """The XOR operation from Section 2.2."""
+ res = bytearray(a)
+ for i, byte in enumerate(b):
+ res[i] ^= byte
+ return bytes(res)
+
+
+def h_i(data, salt, i):
+ """The Hi(str, salt, i) function from Section 2.2."""
+ assert i > 0
+
+ acc = hmac_256(data, salt + b"\x00\x00\x00\x01")
+ last = acc
+ i -= 1
+
+ while i:
+ u = hmac_256(data, last)
+ acc = xor(acc, u)
+
+ last = u
+ i -= 1
+
+ return acc
+
+
+def test_scram(pwconn, password):
+ startup = pq3.recv1(pwconn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ pwconn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"SCRAM-SHA-256", b""],
+ )
+
+ # Get the client-first-message.
+ pkt = pq3.recv1(pwconn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ initial = pq3.SASLInitialResponse.parse(pkt.payload)
+ assert initial.name == b"SCRAM-SHA-256"
+
+ c_bind, authzid, c_name, c_nonce = initial.data.split(b",")
+ assert c_bind == b"n" # no channel bindings on a plaintext connection
+ assert authzid == b"" # we don't support authzid currently
+ assert c_name == b"n=" # libpq doesn't honor the GS2 username
+ assert c_nonce.startswith(b"r=")
+
+ # Send the server-first-message.
+ salt = b"12345"
+ iterations = 2
+
+ s_nonce = c_nonce + b"somenonce"
+ s_salt = b"s=" + base64.b64encode(salt)
+ s_iterations = b"i=%d" % iterations
+
+ msg = b",".join([s_nonce, s_salt, s_iterations])
+ pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=msg)
+
+ # Get the client-final-message.
+ pkt = pq3.recv1(pwconn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ c_bind_final, c_nonce_final, c_proof = pkt.payload.split(b",")
+ assert c_bind_final == b"c=" + base64.b64encode(c_bind + b"," + authzid + b",")
+ assert c_nonce_final == s_nonce
+
+ # Calculate what the client proof should be.
+ salted_password = h_i(password.encode("ascii"), salt, iterations)
+ client_key = hmac_256(salted_password, b"Client Key")
+ stored_key = sha256(client_key)
+
+ auth_message = b",".join(
+ [c_name, c_nonce, s_nonce, s_salt, s_iterations, c_bind_final, c_nonce_final]
+ )
+ client_signature = hmac_256(stored_key, auth_message)
+ client_proof = xor(client_key, client_signature)
+
+ expected = b"p=" + base64.b64encode(client_proof)
+ assert c_proof == expected
+
+ # Send the correct server signature.
+ server_key = hmac_256(salted_password, b"Server Key")
+ server_signature = hmac_256(server_key, auth_message)
+
+ s_verify = b"v=" + base64.b64encode(server_signature)
+ pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal, body=s_verify)
+
+ # Done!
+ finish_handshake(pwconn)
diff --git a/src/test/python/client/test_oauth.py b/src/test/python/client/test_oauth.py
new file mode 100644
index 0000000000..a754a9c0b6
--- /dev/null
+++ b/src/test/python/client/test_oauth.py
@@ -0,0 +1,936 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import http.server
+import json
+import secrets
+import sys
+import threading
+import time
+import urllib.parse
+
+import psycopg2
+import pytest
+
+import pq3
+
+from .conftest import BLOCKING_TIMEOUT
+
+
+def finish_handshake(conn):
+ """
+ Sends the AuthenticationOK message and the standard opening salvo of server
+ messages, then asserts that the client immediately sends a Terminate message
+ to close the connection cleanly.
+ """
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+ pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.Terminate
+
+
+#
+# OAUTHBEARER (see RFC 7628: https://tools.ietf.org/html/rfc7628)
+#
+
+
+def start_oauth_handshake(conn):
+ """
+ Negotiates an OAUTHBEARER SASL challenge. Returns the client's initial
+ response data.
+ """
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn, pq3.types.AuthnRequest, type=pq3.authn.SASL, body=[b"OAUTHBEARER", b""]
+ )
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ initial = pq3.SASLInitialResponse.parse(pkt.payload)
+ assert initial.name == b"OAUTHBEARER"
+
+ return initial.data
+
+
+def get_auth_value(initial):
+ """
+ Finds the auth value (e.g. "Bearer somedata...") in the client's initial SASL
+ response.
+ """
+ kvpairs = initial.split(b"\x01")
+ assert kvpairs[0] == b"n,," # no channel binding or authzid
+ assert kvpairs[2] == b"" # ends with an empty kvpair
+ assert kvpairs[3] == b"" # ...and there's nothing after it
+ assert len(kvpairs) == 4
+
+ key, value = kvpairs[1].split(b"=", 1)  # split on the first "=" only
+ assert key == b"auth"
+
+ return value
+
+
+def xtest_oauth_success(conn): # TODO
+ initial = start_oauth_handshake(conn)
+
+ auth = get_auth_value(initial)
+ assert auth.startswith(b"Bearer ")
+
+ # Accept the token. TODO actually validate
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+class OpenIDProvider(threading.Thread):
+ """
+ A thread that runs a mock OpenID provider server.
+ """
+
+ def __init__(self, *, port):
+ super().__init__()
+
+ self.exception = None
+
+ addr = ("", port)
+ self.server = self._Server(addr, self._Handler)
+
+ # TODO: allow HTTPS only, somehow
+ oauth = self._OAuthState()
+ oauth.host = f"localhost:{port}"
+ oauth.issuer = f"http://localhost:{port}"
+
+ # The following endpoints are required to be advertised by providers,
+ # even though our chosen client implementation does not actually make
+ # use of them.
+ oauth.register_endpoint(
+ "authorization_endpoint", "POST", "/authorize", self._authorization_handler
+ )
+ oauth.register_endpoint("jwks_uri", "GET", "/keys", self._jwks_handler)
+
+ self.server.oauth = oauth
+
+ def run(self):
+ try:
+ self.server.serve_forever()
+ except Exception as e:
+ self.exception = e
+
+ def stop(self, timeout=BLOCKING_TIMEOUT):
+ """
+ Shuts down the server and joins its thread. Raises an exception if the
+ thread could not be joined, or if it threw an exception itself. Must
+ only be called once, after start().
+ """
+ self.server.shutdown()
+ self.join(timeout)
+
+ if self.is_alive():
+ raise TimeoutError("server thread did not shut down within the timeout")
+ elif self.exception:
+ e = self.exception
+ raise e
+
+ class _OAuthState(object):
+ def __init__(self):
+ self.endpoint_paths = {}
+ self._endpoints = {}
+
+ def register_endpoint(self, name, method, path, func):
+ if method not in self._endpoints:
+ self._endpoints[method] = {}
+
+ self._endpoints[method][path] = func
+ self.endpoint_paths[name] = path
+
+ def endpoint(self, method, path):
+ if method not in self._endpoints:
+ return None
+
+ return self._endpoints[method].get(path)
+
+ class _Server(http.server.HTTPServer):
+ def handle_error(self, request, addr):
+ self.shutdown_request(request)
+ raise
+
+ @staticmethod
+ def _jwks_handler(headers, params):
+ return 200, {"keys": []}
+
+ @staticmethod
+ def _authorization_handler(headers, params):
+ # We don't actually want this to be called during these tests -- we
+ # should be using the device authorization endpoint instead.
+ assert (
+ False
+ ), "authorization handler called instead of device authorization handler"
+
+ class _Handler(http.server.BaseHTTPRequestHandler):
+ timeout = BLOCKING_TIMEOUT
+
+ def _discovery_handler(self, headers, params):
+ oauth = self.server.oauth
+
+ doc = {
+ "issuer": oauth.issuer,
+ "response_types_supported": ["token"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+
+ for name, path in oauth.endpoint_paths.items():
+ doc[name] = oauth.issuer + path
+
+ return 200, doc
+
+ def _handle(self, *, params=None, handler=None):
+ oauth = self.server.oauth
+ assert self.headers["Host"] == oauth.host
+
+ if handler is None:
+ handler = oauth.endpoint(self.command, self.path)
+ assert (
+ handler is not None
+ ), f"no registered endpoint for {self.command} {self.path}"
+
+ code, resp = handler(self.headers, params)
+
+ self.send_response(code)
+ self.send_header("Content-Type", "application/json")
+ self.end_headers()
+
+ resp = json.dumps(resp)
+ resp = resp.encode("utf-8")
+ self.wfile.write(resp)
+
+ self.close_connection = True
+
+ def do_GET(self):
+ if self.path == "/.well-known/openid-configuration":
+ self._handle(handler=self._discovery_handler)
+ return
+
+ self._handle()
+
+ def _request_body(self):
+ length = self.headers["Content-Length"]
+
+ # Handle only an explicit content-length.
+ assert length is not None
+ length = int(length)
+
+ return self.rfile.read(length).decode("utf-8")
+
+ def do_POST(self):
+ assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
+
+ body = self._request_body()
+ params = urllib.parse.parse_qs(body)
+
+ self._handle(params=params)
+
+
+@pytest.fixture
+def openid_provider(unused_tcp_port_factory):
+ """
+ A fixture that returns the OAuth state of a running OpenID provider server. The
+ server will be stopped when the fixture is torn down.
+ """
+ thread = OpenIDProvider(port=unused_tcp_port_factory())
+ thread.start()
+
+ try:
+ yield thread.server.oauth
+ finally:
+ thread.stop()
+
+
+@pytest.mark.parametrize("secret", [None, "", "hunter2"])
+@pytest.mark.parametrize("scope", [None, "", "openid email"])
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_with_explicit_issuer(
+ capfd, accept, openid_provider, retries, scope, secret
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ oauth_client_secret=secret,
+ oauth_scope=scope,
+ )
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+ verification_url = "https://example.com/device"
+
+ access_token = secrets.token_urlsafe()
+
+ def check_client_authn(headers, params):
+ if not secret:
+ assert params["client_id"] == [client_id]
+ return
+
+ # Require the client to use Basic authn; request-body credentials are
+ # NOT RECOMMENDED (RFC 6749, Sec. 2.3.1).
+ assert "Authorization" in headers
+
+ method, creds = headers["Authorization"].split()
+ assert method == "Basic"
+
+ expected = f"{client_id}:{secret}"
+ assert base64.b64decode(creds) == expected.encode("ascii")
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ check_client_authn(headers, params)
+
+ if scope:
+ assert params["scope"] == [scope]
+ else:
+ assert "scope" not in params
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": verification_url,
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ attempts = 0
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ check_client_authn(headers, params)
+
+ assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+ assert params["device_code"] == [device_code]
+
+ with retry_lock:
+ nonlocal attempts
+
+ # If the test wants to force the client to retry, return an
+ # authorization_pending response and count the attempt.
+ if attempts < retries:
+ attempts += 1
+ return 400, {"error": "authorization_pending"}
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+ if retries:
+ # Finally, make sure that the client prompted the user with the expected
+ # authorization URL and user code.
+ expected = f"Visit {verification_url} and enter the code: {user_code}"
+ _, stderr = capfd.readouterr()
+ assert expected in stderr
+
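The retry behavior this test forces (one or more authorization_pending responses before a token) follows the RFC 8628 device-flow polling loop. A minimal sketch, where `fetch` is a hypothetical callable standing in for the HTTP POST to the token endpoint:

```python
import time

def poll_for_token(fetch, device_code, interval):
    # Minimal RFC 8628 polling loop. `fetch` is a hypothetical callable that
    # POSTs to the token endpoint and returns (http_status, parsed_json).
    while True:
        status, resp = fetch(device_code)
        if status == 200:
            return resp["access_token"]
        err = resp.get("error")
        if err == "slow_down":
            # The provider wants the client to back off by an extra 5 seconds.
            interval += 5
        elif err != "authorization_pending":
            raise RuntimeError(resp.get("error_description", err))
        time.sleep(interval)
```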
+
+def test_oauth_requires_client_id(accept, openid_provider):
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ # Do not set a client ID; this should cause a client error after the
+ # server asks for OAUTHBEARER and the client tries to contact the
+ # issuer.
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should disconnect at this point.
+ assert not conn.read()
+
+ expected_error = "no oauth_client_id is set"
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("error_code", ["authorization_pending", "slow_down"])
+@pytest.mark.parametrize("retries", [1, 2])
+def test_oauth_retry_interval(accept, openid_provider, retries, error_code):
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id="some-id",
+ )
+
+ expected_retry_interval = 1
+ access_token = secrets.token_urlsafe()
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ resp = {
+ "device_code": "my-device-code",
+ "user_code": "my-user-code",
+ "interval": expected_retry_interval,
+ "verification_uri": "https://example.com",
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ attempts = 0
+ last_retry = None
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ now = time.monotonic()
+
+ with retry_lock:
+ nonlocal attempts, last_retry, expected_retry_interval
+
+ # Make sure the retry interval is being respected by the client.
+ if last_retry is not None:
+ interval = now - last_retry
+ assert interval >= expected_retry_interval
+
+ last_retry = now
+
+ # If the test wants to force the client to retry, return the desired
+ # error response and count the attempt.
+ if attempts < retries:
+ attempts += 1
+
+ # A slow_down code requires the client to additionally increase
+ # its interval by five seconds.
+ if error_code == "slow_down":
+ expected_retry_interval += 5
+
+ return 400, {"error": error_code}
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+@pytest.mark.parametrize(
+ "failure_mode, error_pattern",
+ [
+ pytest.param(
+ {
+ "error": "invalid_client",
+ "error_description": "client authentication failed",
+ },
+ r"client authentication failed \(invalid_client\)",
+ id="authentication failure with description",
+ ),
+ pytest.param(
+ {"error": "invalid_request"},
+ r"\(invalid_request\)",
+ id="invalid request without description",
+ ),
+ pytest.param(
+ {},
+ r"failed to obtain device authorization",
+ id="broken error response",
+ ),
+ ],
+)
+def test_oauth_device_authorization_failures(
+ accept, openid_provider, failure_mode, error_pattern
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ )
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ return 400, failure_mode
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ def token_endpoint(headers, params):
+ assert False, "token endpoint was invoked unexpectedly"
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should not continue the connection due to the hardcoded
+ # provider failure; we disconnect here.
+
+ # Now make sure the client correctly failed.
+ with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+ client.check_completed()
+
+
+@pytest.mark.parametrize(
+ "failure_mode, error_pattern",
+ [
+ pytest.param(
+ {
+ "error": "expired_token",
+ "error_description": "the device code has expired",
+ },
+ r"the device code has expired \(expired_token\)",
+ id="expired token with description",
+ ),
+ pytest.param(
+ {"error": "access_denied"},
+ r"\(access_denied\)",
+ id="access denied without description",
+ ),
+ pytest.param(
+ {},
+ r"OAuth token retrieval failed",
+ id="broken error response",
+ ),
+ ],
+)
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_token_failures(
+ accept, openid_provider, retries, failure_mode, error_pattern
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ )
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ assert params["client_id"] == [client_id]
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": "https://example.com/device",
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ with retry_lock:
+ nonlocal retries
+
+ # If the test wants to force the client to retry, return an
+ # authorization_pending response and decrement the retry count.
+ if retries > 0:
+ retries -= 1
+ return 400, {"error": "authorization_pending"}
+
+ return 400, failure_mode
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should not continue the connection due to the hardcoded
+ # provider failure; we disconnect here.
+
+ # Now make sure the client correctly failed.
+ with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+ client.check_completed()
+
+
+@pytest.mark.parametrize("scope", [None, "openid email"])
+@pytest.mark.parametrize(
+ "base_response",
+ [
+ {"status": "invalid_token"},
+ {"extra_object": {"key": "value"}, "status": "invalid_token"},
+ {"extra_object": {"status": 1}, "status": "invalid_token"},
+ ],
+)
+def test_oauth_discovery(accept, openid_provider, base_response, scope):
+ sock, client = accept(oauth_client_id=secrets.token_hex())
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+ verification_url = "https://example.com/device"
+
+ access_token = secrets.token_urlsafe()
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ if scope:
+ assert params["scope"] == [scope]
+ else:
+ assert "scope" not in params
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": verification_url,
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ def token_endpoint(headers, params):
+ assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+ assert params["device_code"] == [device_code]
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # For discovery, the client should send an empty auth header. See
+ # RFC 7628, Sec. 4.3.
+ auth = get_auth_value(initial)
+ assert auth == b""
+
+ # We will fail the first SASL exchange. First return a link to the
+ # discovery document, pointing to the test provider server.
+ resp = dict(base_response)
+
+ discovery_uri = f"{openid_provider.issuer}/.well-known/openid-configuration"
+ resp["openid-configuration"] = discovery_uri
+
+ if scope:
+ resp["scope"] = scope
+
+ resp = json.dumps(resp)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASLContinue,
+ body=resp.encode("ascii"),
+ )
+
+ # Per RFC, the client is required to send a dummy ^A response.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+ assert pkt.payload == b"\x01"
+
+ # Now fail the SASL exchange.
+ pq3.send(
+ conn,
+ pq3.types.ErrorResponse,
+ fields=[
+ b"SFATAL",
+ b"C28000",
+ b"Mdoesn't matter",
+ b"",
+ ],
+ )
+
+ # The client will connect to us a second time, using the parameters we sent
+ # it.
+ sock, _ = accept()
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
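For reference, the `auth` value asserted above travels inside an RFC 7628 initial client response, which frames the gs2 header and the key/value pairs with ^A (0x01) separators. A sketch with a made-up token:

```python
KVSEP = "\x01"
token = "sometoken"  # placeholder bearer token

# gs2 header "n,," followed by the auth key/value pair, each terminated by
# ^A, plus a final ^A closing the message (RFC 7628, Sec. 3.1).
initial_response = "n,," + KVSEP + f"auth=Bearer {token}" + KVSEP + KVSEP
```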
+
+@pytest.mark.parametrize(
+ "response,expected_error",
+ [
+ pytest.param(
+ "abcde",
+ 'Token "abcde" is invalid',
+ id="bad JSON: invalid syntax",
+ ),
+ pytest.param(
+ '"abcde"',
+ "top-level element must be an object",
+ id="bad JSON: top-level element is a string",
+ ),
+ pytest.param(
+ "[]",
+ "top-level element must be an object",
+ id="bad JSON: top-level element is an array",
+ ),
+ pytest.param(
+ "{}",
+ "server sent error response without a status",
+ id="bad JSON: no status member",
+ ),
+ pytest.param(
+ '{ "status": null }',
+ 'field "status" must be a string',
+ id="bad JSON: null status member",
+ ),
+ pytest.param(
+ '{ "status": 0 }',
+ 'field "status" must be a string',
+ id="bad JSON: int status member",
+ ),
+ pytest.param(
+ '{ "status": [ "bad" ] }',
+ 'field "status" must be a string',
+ id="bad JSON: array status member",
+ ),
+ pytest.param(
+ '{ "status": { "bad": "bad" } }',
+ 'field "status" must be a string',
+ id="bad JSON: object status member",
+ ),
+ pytest.param(
+ '{ "nested": { "status": "bad" } }',
+ "server sent error response without a status",
+ id="bad JSON: nested status",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token" ',
+ "The input string ended unexpectedly",
+ id="bad JSON: unterminated object",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token" } { }',
+ 'Expected end of input, but found "{"',
+ id="bad JSON: trailing data",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token", "openid-configuration": 1 }',
+ 'field "openid-configuration" must be a string',
+ id="bad JSON: int openid-configuration member",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token", "scope": 1 }',
+ 'field "scope" must be a string',
+ id="bad JSON: int scope member",
+ ),
+ ],
+)
+def test_oauth_discovery_server_error(accept, response, expected_error):
+ sock, client = accept(oauth_client_id=secrets.token_hex())
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # Fail the SASL exchange with an invalid JSON response.
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASLContinue,
+ body=response.encode("utf-8"),
+ )
+
+ # The client should disconnect, so the socket is closed here. (If
+ # the client doesn't disconnect, it will report a different error
+ # below and the test will fail.)
+
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
+
+
+@pytest.mark.parametrize(
+ "sasl_err,resp_type,resp_payload,expected_error",
+ [
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.ErrorResponse,
+ dict(
+ fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+ ),
+ "expected error message",
+ id="standard server error: invalid_request",
+ ),
+ pytest.param(
+ {"status": "invalid_token"},
+ pq3.types.ErrorResponse,
+ dict(
+ fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+ ),
+ "expected error message",
+ id="standard server error: invalid_token without discovery URI",
+ ),
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.AuthnRequest,
+ dict(type=pq3.authn.SASLContinue, body=b""),
+ "server sent additional OAuth data",
+ id="broken server: additional challenge after error",
+ ),
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.AuthnRequest,
+ dict(type=pq3.authn.SASLFinal),
+ "server sent additional OAuth data",
+ id="broken server: SASL success after error",
+ ),
+ ],
+)
+def test_oauth_server_error(accept, sasl_err, resp_type, resp_payload, expected_error):
+ sock, client = accept()
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ start_oauth_handshake(conn)
+
+ # Ignore the client data. Return an error "challenge".
+ resp = json.dumps(sasl_err)
+ resp = resp.encode("utf-8")
+
+ pq3.send(
+ conn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=resp
+ )
+
+ # Per RFC, the client is required to send a dummy ^A response.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+ assert pkt.payload == b"\x01"
+
+ # Now fail the SASL exchange (in either a valid way, or an invalid
+ # one, depending on the test).
+ pq3.send(conn, resp_type, **resp_payload)
+
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
diff --git a/src/test/python/pq3.py b/src/test/python/pq3.py
new file mode 100644
index 0000000000..3a22dad0b6
--- /dev/null
+++ b/src/test/python/pq3.py
@@ -0,0 +1,727 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import os
+import ssl
+import sys
+import textwrap
+
+from construct import *
+
+import tls
+
+
+def protocol(major, minor):
+ """
+ Returns the protocol version, in integer format, corresponding to the given
+ major and minor version numbers.
+ """
+ return (major << 16) | minor
+
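As a quick sanity check on this packing scheme: version 3.0 packs to 0x00030000, and libpq's well-known SSLRequest magic number, 80877103, is just the fake version 1234.5679 in the same encoding.

```python
def protocol(major, minor):
    # Same packing as the helper above: major version in the high 16 bits.
    return (major << 16) | minor

assert protocol(3, 0) == 0x00030000     # protocol version 3.0
assert protocol(1234, 5679) == 80877103  # libpq's SSLRequest code
```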
+
+# Startup
+
+StringList = GreedyRange(NullTerminated(GreedyBytes))
+
+
+class KeyValueAdapter(Adapter):
+ """
+ Turns a key-value store into a null-terminated list of null-terminated
+ strings, as presented on the wire in the startup packet.
+ """
+
+ def _encode(self, obj, context, path):
+ if isinstance(obj, list):
+ return obj
+
+ l = []
+
+ for k, v in obj.items():
+ if isinstance(k, str):
+ k = k.encode("utf-8")
+ l.append(k)
+
+ if isinstance(v, str):
+ v = v.encode("utf-8")
+ l.append(v)
+
+ l.append(b"")
+ return l
+
+ def _decode(self, obj, context, path):
+ # TODO: turn a list back into a dict
+ return obj
+
+
+KeyValues = KeyValueAdapter(StringList)
+
+_startup_payload = Switch(
+ this.proto,
+ {
+ protocol(3, 0): KeyValues,
+ },
+ default=GreedyBytes,
+)
+
+
+def _default_protocol(this):
+ try:
+ if isinstance(this.payload, (list, dict)):
+ return protocol(3, 0)
+ except AttributeError:
+ pass # no payload passed during build
+
+ return 0
+
+
+def _startup_payload_len(this):
+ """
+ The payload field has a fixed size based on the length of the packet. But
+ if the caller hasn't supplied an explicit length at build time, we have to
+ build the payload to figure out how long it is, which requires us to know
+ the length first... This function exists solely to break the cycle.
+ """
+ assert this._building, "_startup_payload_len() cannot be called during parsing"
+
+ try:
+ payload = this.payload
+ except AttributeError:
+ return 0 # no payload
+
+ if isinstance(payload, bytes):
+ # already serialized; just use the given length
+ return len(payload)
+
+ try:
+ proto = this.proto
+ except AttributeError:
+ proto = _default_protocol(this)
+
+ data = _startup_payload.build(payload, proto=proto)
+ return len(data)
+
+
+Startup = Struct(
+ "len" / Default(Int32sb, lambda this: _startup_payload_len(this) + 8),
+ "proto" / Default(Hex(Int32sb), _default_protocol),
+ "payload" / FixedSized(this.len - 8, Default(_startup_payload, b"")),
+)
+
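The same framing can be sketched without construct: a startup packet is a 32-bit length (covering itself and the protocol word), the protocol version, and the null-terminated key/value payload. `build_startup` here is a hypothetical helper, not part of the patch:

```python
import struct

def build_startup(params):
    # Null-terminated key/value strings, with a trailing empty string.
    payload = b""
    for k, v in params.items():
        payload += k.encode() + b"\x00" + v.encode() + b"\x00"
    payload += b"\x00"
    # The length word covers itself (4), the protocol word (4), and the payload.
    return struct.pack("!ii", len(payload) + 8, (3 << 16) | 0) + payload

pkt = build_startup({"user": "alice"})
assert pkt[:8] == struct.pack("!ii", 20, 196608)  # 12-byte payload + 8
```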
+# Pq3
+
+# Adapted from construct.core.EnumIntegerString
+class EnumNamedByte:
+ def __init__(self, val, name):
+ self._val = val
+ self._name = name
+
+ def __int__(self):
+ return ord(self._val)
+
+ def __str__(self):
+ return "(enum) %s %r" % (self._name, self._val)
+
+ def __repr__(self):
+ return "EnumNamedByte(%r)" % self._val
+
+ def __eq__(self, other):
+ if isinstance(other, EnumNamedByte):
+ other = other._val
+ if not isinstance(other, bytes):
+ return NotImplemented
+
+ return self._val == other
+
+ def __hash__(self):
+ return hash(self._val)
+
+
+# Adapted from construct.core.Enum
+class ByteEnum(Adapter):
+ def __init__(self, **mapping):
+ super(ByteEnum, self).__init__(Byte)
+ self.namemapping = {k: EnumNamedByte(v, k) for k, v in mapping.items()}
+ self.decmapping = {v: EnumNamedByte(v, k) for k, v in mapping.items()}
+
+ def __getattr__(self, name):
+ if name in self.namemapping:
+ return self.decmapping[self.namemapping[name]]
+ raise AttributeError
+
+ def _decode(self, obj, context, path):
+ b = bytes([obj])
+ try:
+ return self.decmapping[b]
+ except KeyError:
+ return EnumNamedByte(b, "(unknown)")
+
+ def _encode(self, obj, context, path):
+ if isinstance(obj, int):
+ return obj
+ elif isinstance(obj, bytes):
+ return ord(obj)
+ return int(obj)
+
+
+types = ByteEnum(
+ ErrorResponse=b"E",
+ ReadyForQuery=b"Z",
+ Query=b"Q",
+ EmptyQueryResponse=b"I",
+ AuthnRequest=b"R",
+ PasswordMessage=b"p",
+ BackendKeyData=b"K",
+ CommandComplete=b"C",
+ ParameterStatus=b"S",
+ DataRow=b"D",
+ Terminate=b"X",
+)
+
+
+authn = Enum(
+ Int32ub,
+ OK=0,
+ SASL=10,
+ SASLContinue=11,
+ SASLFinal=12,
+)
+
+
+_authn_body = Switch(
+ this.type,
+ {
+ authn.OK: Terminated,
+ authn.SASL: StringList,
+ },
+ default=GreedyBytes,
+)
+
+
+def _data_len(this):
+ assert this._building, "_data_len() cannot be called during parsing"
+
+ if not hasattr(this, "data") or this.data is None:
+ return -1
+
+ return len(this.data)
+
+
+# The protocol reuses the PasswordMessage for several authentication response
+# types, and there's no good way to figure out which is which without keeping
+# state for the entire stream. So this is a separate Construct that can be
+# explicitly parsed/built by code that knows it's needed.
+SASLInitialResponse = Struct(
+ "name" / NullTerminated(GreedyBytes),
+ "len" / Default(Int32sb, lambda this: _data_len(this)),
+ "data"
+ / IfThenElse(
+ # Allow tests to explicitly pass an incorrect length during testing, by
+ # not enforcing a FixedSized during build. (The len calculation above
+ # defaults to the correct size.)
+ this._building,
+ Optional(GreedyBytes),
+ If(this.len != -1, Default(FixedSized(this.len, GreedyBytes), b"")),
+ ),
+ Terminated, # make sure the entire response is consumed
+)
+
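On the wire, per the struct above, a SASLInitialResponse is the mechanism name (null-terminated), a signed 32-bit data length, and the data itself. A sketch with a made-up OAUTHBEARER payload:

```python
import struct

mechanism = b"OAUTHBEARER"
data = b"n,,\x01auth=Bearer sometoken\x01\x01"

# name\0 + int32 length + data, mirroring the SASLInitialResponse layout.
msg = mechanism + b"\x00" + struct.pack("!i", len(data)) + data

# The length field sits right after the terminator and frames the data.
name, rest = msg.split(b"\x00", 1)
(length,) = struct.unpack("!i", rest[:4])
assert name == b"OAUTHBEARER" and rest[4 : 4 + length] == data
```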
+
+_column = FocusedSeq(
+ "data",
+ "len" / Default(Int32sb, lambda this: _data_len(this)),
+ "data" / If(this.len != -1, FixedSized(this.len, GreedyBytes)),
+)
+
+
+_payload_map = {
+ types.ErrorResponse: Struct("fields" / StringList),
+ types.ReadyForQuery: Struct("status" / Bytes(1)),
+ types.Query: Struct("query" / NullTerminated(GreedyBytes)),
+ types.EmptyQueryResponse: Terminated,
+ types.AuthnRequest: Struct("type" / authn, "body" / Default(_authn_body, b"")),
+ types.BackendKeyData: Struct("pid" / Int32ub, "key" / Hex(Int32ub)),
+ types.CommandComplete: Struct("tag" / NullTerminated(GreedyBytes)),
+ types.ParameterStatus: Struct(
+ "name" / NullTerminated(GreedyBytes), "value" / NullTerminated(GreedyBytes)
+ ),
+ types.DataRow: Struct("columns" / Default(PrefixedArray(Int16sb, _column), b"")),
+ types.Terminate: Terminated,
+}
+
+
+_payload = FocusedSeq(
+ "_payload",
+ "_payload"
+ / Switch(
+ this._.type,
+ _payload_map,
+ default=GreedyBytes,
+ ),
+ Terminated, # make sure every payload consumes the entire packet
+)
+
+
+def _payload_len(this):
+ """
+ See _startup_payload_len() for an explanation.
+ """
+ assert this._building, "_payload_len() cannot be called during parsing"
+
+ try:
+ payload = this.payload
+ except AttributeError:
+ return 0 # no payload
+
+ if isinstance(payload, bytes):
+ # already serialized; just use the given length
+ return len(payload)
+
+ data = _payload.build(payload, type=this.type)
+ return len(data)
+
+
+Pq3 = Struct(
+ "type" / types,
+ "len" / Default(Int32ub, lambda this: _payload_len(this) + 4),
+ "payload" / FixedSized(this.len - 4, Default(_payload, b"")),
+)
+
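Mirroring the Pq3 struct above: a regular v3 packet is a one-byte type, a 32-bit length counting itself plus the payload (but not the type byte), then the payload. A minimal sketch:

```python
import struct

def frame(ptype, payload):
    # len counts itself (4 bytes) plus the payload, not the type byte.
    return ptype + struct.pack("!I", len(payload) + 4) + payload

pkt = frame(b"Q", b"SELECT 1\x00")
assert pkt[0:1] == b"Q"
assert struct.unpack("!I", pkt[1:5])[0] == 13  # 4 + 9-byte payload
```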
+
+# Environment
+
+
+def pghost():
+ return os.environ.get("PGHOST", default="localhost")
+
+
+def pgport():
+ return int(os.environ.get("PGPORT", default=5432))
+
+
+def pguser():
+ try:
+ return os.environ["PGUSER"]
+ except KeyError:
+ return getpass.getuser()
+
+
+def pgdatabase():
+ return os.environ.get("PGDATABASE", default="postgres")
+
+
+# Connections
+
+
+def _hexdump_translation_map():
+ """
+ For hexdumps. Translates any unprintable or non-ASCII bytes into '.'.
+ """
+ input = bytearray()
+
+ for i in range(128):
+ c = chr(i)
+
+ if not c.isprintable():
+ input += bytes([i])
+
+ input += bytes(range(128, 256))
+
+ return bytes.maketrans(input, b"." * len(input))
+
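The resulting table behaves like the right-hand column of a classic hexdump: printable ASCII passes through, everything else becomes '.'. Rebuilding the same map standalone:

```python
# Collect every byte that should be masked: unprintable ASCII plus all of
# the non-ASCII range.
unprintable = bytearray(i for i in range(128) if not chr(i).isprintable())
unprintable += bytes(range(128, 256))
table = bytes.maketrans(bytes(unprintable), b"." * len(unprintable))

assert b"ok\x01\xff!".translate(table) == b"ok..!"
```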
+
+class _DebugStream(object):
+ """
+ Wraps a file-like object and adds hexdumps of the read and write data. Call
+ end_packet() on a _DebugStream to write the accumulated hexdumps to the
+ output stream, along with the packet that was sent.
+ """
+
+ _translation_map = _hexdump_translation_map()
+
+ def __init__(self, stream, out=sys.stdout):
+ """
+ Creates a new _DebugStream wrapping the given stream (which must have
+ been created by wrap()). All attributes not provided by the _DebugStream
+ are delegated to the wrapped stream. out is the text stream to which
+ hexdumps are written.
+ """
+ self.raw = stream
+ self._out = out
+ self._rbuf = io.BytesIO()
+ self._wbuf = io.BytesIO()
+
+ def __getattr__(self, name):
+ return getattr(self.raw, name)
+
+ def __setattr__(self, name, value):
+ if name in ("raw", "_out", "_rbuf", "_wbuf"):
+ return object.__setattr__(self, name, value)
+
+ setattr(self.raw, name, value)
+
+ def read(self, *args, **kwargs):
+ buf = self.raw.read(*args, **kwargs)
+
+ self._rbuf.write(buf)
+ return buf
+
+ def write(self, b):
+ self._wbuf.write(b)
+ return self.raw.write(b)
+
+ def recv(self, *args):
+ buf = self.raw.recv(*args)
+
+ self._rbuf.write(buf)
+ return buf
+
+ def _flush(self, buf, prefix):
+ width = 16
+ hexwidth = width * 3 - 1
+
+ count = 0
+ buf.seek(0)
+
+ while True:
+ line = buf.read(width)
+
+ if not line:
+ if count:
+ self._out.write("\n") # separate the output block with a newline
+ return
+
+ self._out.write("%s %04X:\t" % (prefix, count))
+ self._out.write("%*s\t" % (-hexwidth, line.hex(" ")))
+ self._out.write(line.translate(self._translation_map).decode("ascii"))
+ self._out.write("\n")
+
+ count += width
+
+ def print_debug(self, obj, *, prefix=""):
+ contents = ""
+ if obj is not None:
+ contents = str(obj)
+
+ for line in contents.splitlines():
+ self._out.write("%s%s\n" % (prefix, line))
+
+ self._out.write("\n")
+
+ def flush_debug(self, *, prefix=""):
+ self._flush(self._rbuf, prefix + "<")
+ self._rbuf = io.BytesIO()
+
+ self._flush(self._wbuf, prefix + ">")
+ self._wbuf = io.BytesIO()
+
+ def end_packet(self, pkt, *, read=False, prefix="", indent=" "):
+ """
+ Marks the end of a logical "packet" of data. A string representation of
+ pkt will be printed, and the debug buffers will be flushed with an
+ indent. All lines can be optionally prefixed.
+
+ If read is True, the packet representation is written after the debug
+ buffers; otherwise the default of False (meaning write) causes the
+ packet representation to be dumped first. This is meant to capture the
+ logical flow of layer translation.
+ """
+ write = not read
+
+ if write:
+ self.print_debug(pkt, prefix=prefix + "> ")
+
+ self.flush_debug(prefix=prefix + indent)
+
+ if read:
+ self.print_debug(pkt, prefix=prefix + "< ")
+
+
+@contextlib.contextmanager
+def wrap(socket, *, debug_stream=None):
+ """
+ Transforms a raw socket into a connection that can be used for Construct
+ building and parsing. The return value is a context manager and can be used
+ in a with statement.
+ """
+ # It is critical that buffering be disabled here, so that we can still
+ # manipulate the raw socket without desyncing the stream.
+ with socket.makefile("rwb", buffering=0) as sfile:
+ # Expose the original socket's recv() on the SocketIO object we return.
+ def recv(self, *args):
+ return socket.recv(*args)
+
+ sfile.recv = recv.__get__(sfile)
+
+ conn = sfile
+ if debug_stream:
+ conn = _DebugStream(conn, debug_stream)
+
+ try:
+ yield conn
+ finally:
+ if debug_stream:
+ conn.flush_debug(prefix="? ")
+
+
+def _send(stream, cls, obj):
+ debugging = hasattr(stream, "flush_debug")
+ out = io.BytesIO()
+
+ # Ideally we would build directly to the passed stream, but because we need
+ # to reparse the generated output for the debugging case, build to an
+ # intermediate BytesIO and send it instead.
+ cls.build_stream(obj, out)
+ buf = out.getvalue()
+
+ stream.write(buf)
+ if debugging:
+ pkt = cls.parse(buf)
+ stream.end_packet(pkt)
+
+ stream.flush()
+
+
+def send(stream, packet_type, payload_data=None, **payloadkw):
+ """
+ Sends a packet on the given pq3 connection. packet_type is the pq3.types
+ member that should be assigned to the packet. If payload_data is given, it
+ will be used as the packet payload; otherwise the key/value pairs in
+ payloadkw will be the payload contents.
+ """
+ data = payloadkw
+
+ if payload_data is not None:
+ if payloadkw:
+ raise ValueError(
+ "payload_data and payload keywords may not be used simultaneously"
+ )
+
+ data = payload_data
+
+ _send(stream, Pq3, dict(type=packet_type, payload=data))
+
+
+def send_startup(stream, proto=None, **kwargs):
+ """
+ Sends a startup packet on the given pq3 connection. In most cases you should
+ use the handshake functions instead, which will do this for you.
+
+ By default, a protocol version 3 packet will be sent. This can be overridden
+ with the proto parameter.
+ """
+ pkt = {}
+
+ if proto is not None:
+ pkt["proto"] = proto
+ if kwargs:
+ pkt["payload"] = kwargs
+
+ _send(stream, Startup, pkt)
+
+
+def recv1(stream, *, cls=Pq3):
+ """
+ Receives a single pq3 packet from the given stream and returns it.
+ """
+ resp = cls.parse_stream(stream)
+
+ debugging = hasattr(stream, "flush_debug")
+ if debugging:
+ stream.end_packet(resp, read=True)
+
+ return resp
+
+
+def handshake(stream, **kwargs):
+ """
+ Performs a libpq v3 startup handshake. kwargs should contain the key/value
+ parameters to send to the server in the startup packet.
+ """
+ # Send our startup parameters.
+ send_startup(stream, **kwargs)
+
+ # Receive and dump packets until the server indicates it's ready for our
+ # first query.
+ while True:
+ resp = recv1(stream)
+ if resp is None:
+ raise RuntimeError("server closed connection during handshake")
+
+ if resp.type == types.ReadyForQuery:
+ return
+ elif resp.type == types.ErrorResponse:
+ raise RuntimeError(
+ f"received error response from peer: {resp.payload.fields!r}"
+ )
+
+
+# TLS
+
+
+class _TLSStream(object):
+ """
+ A file-like object that performs TLS encryption/decryption on a wrapped
+ stream. Differs from ssl.SSLSocket in that we have full visibility and
+ control over the TLS layer.
+ """
+
+ def __init__(self, stream, context):
+ self._stream = stream
+ self._debugging = hasattr(stream, "flush_debug")
+
+ self._in = ssl.MemoryBIO()
+ self._out = ssl.MemoryBIO()
+ self._ssl = context.wrap_bio(self._in, self._out)
+
+ def handshake(self):
+ try:
+ self._pump(lambda: self._ssl.do_handshake())
+ finally:
+ self._flush_debug(prefix="? ")
+
+ def read(self, *args):
+ return self._pump(lambda: self._ssl.read(*args))
+
+ def write(self, *args):
+ return self._pump(lambda: self._ssl.write(*args))
+
+ def _decode(self, buf):
+ """
+ Attempts to decode a buffer of TLS data into a packet representation
+ that can be printed.
+
+ TODO: handle buffers (and record fragments) that don't align with packet
+ boundaries.
+ """
+ end = len(buf)
+ bio = io.BytesIO(buf)
+
+ ret = io.StringIO()
+
+ while bio.tell() < end:
+ record = tls.Plaintext.parse_stream(bio)
+
+ if ret.tell() > 0:
+ ret.write("\n")
+ ret.write("[Record] ")
+ ret.write(str(record))
+ ret.write("\n")
+
+ if record.type == tls.ContentType.handshake:
+ record_cls = tls.Handshake
+ else:
+ continue
+
+ innerlen = len(record.fragment)
+ inner = io.BytesIO(record.fragment)
+
+ while inner.tell() < innerlen:
+ msg = record_cls.parse_stream(inner)
+
+ indented = "[Message] " + str(msg)
+ indented = textwrap.indent(indented, " ")
+
+ ret.write("\n")
+ ret.write(indented)
+ ret.write("\n")
+
+ return ret.getvalue()
+
+ def flush(self):
+ if not self._out.pending:
+ self._stream.flush()
+ return
+
+ buf = self._out.read()
+ self._stream.write(buf)
+
+ if self._debugging:
+ pkt = self._decode(buf)
+ self._stream.end_packet(pkt, prefix=" ")
+
+ self._stream.flush()
+
+ def _pump(self, operation):
+ while True:
+ try:
+ return operation()
+ except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
+ self._read_write(e)
+
+ def _recv(self, maxsize):
+ buf = self._stream.recv(maxsize)
+ if not buf:
+ self._in.write_eof()
+ return
+
+ self._in.write(buf)
+
+ if not self._debugging:
+ return
+
+ pkt = self._decode(buf)
+ self._stream.end_packet(pkt, read=True, prefix=" ")
+
+ def _read_write(self, want):
+ # XXX This needs work. So many corner cases yet to handle. For one,
+ # doing blocking writes in flush may lead to distributed deadlock if the
+ # peer is already blocking on its writes.
+
+ if isinstance(want, ssl.SSLWantWriteError):
+ assert self._out.pending, "SSL backend wants write without data"
+
+ self.flush()
+
+ if isinstance(want, ssl.SSLWantReadError):
+ self._recv(4096)
+
+ def _flush_debug(self, prefix):
+ if not self._debugging:
+ return
+
+ self._stream.flush_debug(prefix=prefix)
+
+
+@contextlib.contextmanager
+def tls_handshake(stream, context):
+ """
+ Performs a TLS handshake over the given stream (which must have been created
+ via a call to wrap()), and returns a new stream which transparently tunnels
+ data over the TLS connection.
+
+ If the passed stream has debugging enabled, the returned stream will also
+ have debugging, using the same output IO.
+ """
+ debugging = hasattr(stream, "flush_debug")
+
+ # Send an SSLRequest packet (pseudo-protocol version 1234,5679).
+ send_startup(stream, proto=protocol(1234, 5679))
+
+ # Look at the SSL response.
+ resp = stream.read(1)
+ if debugging:
+ stream.flush_debug(prefix=" ")
+
+ if resp == b"N":
+ raise RuntimeError("server does not support SSLRequest")
+ if resp != b"S":
+ raise RuntimeError(f"unexpected response of type {resp!r} during TLS startup")
+
+ tls = _TLSStream(stream, context)
+ tls.handshake()
+
+ if debugging:
+ tls = _DebugStream(tls, stream._out)
+
+ try:
+ yield tls
+ # TODO: teardown/unwrap the connection?
+ finally:
+ if debugging:
+ tls.flush_debug(prefix="? ")
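The `_pump` helper above implements the usual memory-BIO pattern: retry the SSL operation, shuttling ciphertext between the socket and the BIOs whenever the backend reports want-read or want-write. A minimal sketch of that retry loop, with the socket I/O stubbed out (the PumpDemo class and fake_handshake callable are illustrative, not part of the patch):

```python
import ssl


class PumpDemo:
    """Toy version of _TLSStream._pump with the socket I/O stubbed out."""

    def __init__(self):
        self.transfers = 0

    def _read_write(self, want):
        # The real class feeds the incoming BIO on want-read and flushes the
        # outgoing BIO on want-write; here we only count the retries.
        self.transfers += 1

    def pump(self, operation):
        # Retry the TLS operation until the backend stops asking for I/O.
        while True:
            try:
                return operation()
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
                self._read_write(e)


demo = PumpDemo()
attempts = []


def fake_handshake():
    # Fail with want-read twice before succeeding, as do_handshake() might
    # while the peer's records trickle in.
    attempts.append(None)
    if len(attempts) < 3:
        raise ssl.SSLWantReadError()
    return "handshake complete"


result = demo.pump(fake_handshake)
```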
diff --git a/src/test/python/pytest.ini b/src/test/python/pytest.ini
new file mode 100644
index 0000000000..ab7a6e7fb9
--- /dev/null
+++ b/src/test/python/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+
+markers =
+ slow: mark test as slow
diff --git a/src/test/python/requirements.txt b/src/test/python/requirements.txt
new file mode 100644
index 0000000000..32f105ea84
--- /dev/null
+++ b/src/test/python/requirements.txt
@@ -0,0 +1,7 @@
+black
+cryptography~=3.4.6
+construct~=2.10.61
+isort~=5.6
+psycopg2~=2.8.6
+pytest~=6.1
+pytest-asyncio~=0.14.0
diff --git a/src/test/python/server/__init__.py b/src/test/python/server/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/server/conftest.py b/src/test/python/server/conftest.py
new file mode 100644
index 0000000000..ba7342a453
--- /dev/null
+++ b/src/test/python/server/conftest.py
@@ -0,0 +1,45 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import socket
+import sys
+
+import pytest
+
+import pq3
+
+
+@pytest.fixture
+def connect():
+ """
+ A fixture that returns a connection factory. Each call to the factory
+ returns a socket connected to a Postgres server, wrapped in a pq3
+ connection. The calling test will be skipped automatically if a server is
+ not running at PGHOST:PGPORT, so it's best to connect as soon as possible
+ after the test case begins, to avoid doing unnecessary work.
+ """
+ # Set up an ExitStack to handle safe cleanup of all of the moving pieces.
+ with contextlib.ExitStack() as stack:
+
+ def conn_factory():
+ addr = (pq3.pghost(), pq3.pgport())
+
+ try:
+ sock = socket.create_connection(addr, timeout=2)
+ except ConnectionError as e:
+ pytest.skip(f"unable to connect to {addr}: {e}")
+
+ # Have ExitStack close our socket.
+ stack.enter_context(sock)
+
+ # Wrap the connection in a pq3 layer and have ExitStack clean it up
+ # too.
+ wrap_ctx = pq3.wrap(sock, debug_stream=sys.stdout)
+ conn = stack.enter_context(wrap_ctx)
+
+ return conn
+
+ yield conn_factory
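The connect fixture above pairs a factory function with a contextlib.ExitStack so that every socket handed out during a single test is closed on teardown, in LIFO order. The same pattern, reduced to a toy resource (the names here are illustrative only):

```python
import contextlib


def make_factory(stack, log):
    """Return a factory whose resources are all tied to the given stack."""

    @contextlib.contextmanager
    def resource(name):
        log.append(f"open {name}")
        try:
            yield name
        finally:
            log.append(f"close {name}")

    def factory(name):
        # enter_context ties the new resource's cleanup to the enclosing
        # ExitStack, just as the fixture does for each socket and pq3 wrapper.
        return stack.enter_context(resource(name))

    return factory


log = []
with contextlib.ExitStack() as stack:
    factory = make_factory(stack, log)
    factory("a")
    factory("b")
# On exit the stack unwinds in LIFO order: "b" closes before "a".
```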
diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py
new file mode 100644
index 0000000000..355ef8e4bd
--- /dev/null
+++ b/src/test/python/server/test_oauth.py
@@ -0,0 +1,1012 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import contextlib
+import json
+import os
+import pathlib
+import secrets
+import shlex
+import shutil
+import socket
+import struct
+from multiprocessing import shared_memory
+
+import psycopg2
+import pytest
+from psycopg2 import sql
+
+import pq3
+
+MAX_SASL_MESSAGE_LENGTH = 65535
+
+INVALID_AUTHORIZATION_ERRCODE = b"28000"
+PROTOCOL_VIOLATION_ERRCODE = b"08P01"
+FEATURE_NOT_SUPPORTED_ERRCODE = b"0A000"
+
+SHARED_MEM_NAME = "oauth-pytest"
+MAX_TOKEN_SIZE = 4096
+MAX_UINT16 = 2 ** 16 - 1
+
+
+def skip_if_no_postgres():
+ """
+ Used by the oauth_ctx fixture to skip this test module if no Postgres server
+ is running.
+
+ This logic nearly duplicates the connect fixture. Ideally oauth_ctx
+ would depend on that, but a module-scope fixture can't depend on a
+ test-scope fixture, and we haven't reached the rule of three yet.
+ """
+ addr = (pq3.pghost(), pq3.pgport())
+
+ try:
+ with socket.create_connection(addr, timeout=2):
+ pass
+ except ConnectionError as e:
+ pytest.skip(f"unable to connect to {addr}: {e}")
+
+
+@contextlib.contextmanager
+def prepend_file(path, lines):
+ """
+ A context manager that prepends a file on disk with the desired lines of
+ text. When the context manager is exited, the file will be restored to its
+ original contents.
+ """
+ # First make a backup of the original file.
+ bak = path + ".bak"
+ shutil.copy2(path, bak)
+
+ try:
+ # Write the new lines, followed by the original file content.
+ with open(path, "w") as new, open(bak, "r") as orig:
+ new.writelines(lines)
+ shutil.copyfileobj(orig, new)
+
+ # Return control to the calling code.
+ yield
+
+ finally:
+ # Put the backup back into place.
+ os.replace(bak, path)
+
+
+@pytest.fixture(scope="module")
+def oauth_ctx():
+ """
+ Creates a database and user that use the oauth auth method. The context
+ object contains the dbname and user attributes as strings to be used during
+ connection, as well as the issuer and scope that have been set in the HBA
+ configuration.
+
+ This fixture assumes that the standard PG* environment variables point to a
+ server running on a local machine, and that the PGUSER has rights to create
+ databases and roles.
+ """
+ skip_if_no_postgres() # don't bother running these tests without a server
+
+ id = secrets.token_hex(4)
+
+ class Context:
+ dbname = "oauth_test_" + id
+
+ user = "oauth_user_" + id
+ map_user = "oauth_map_user_" + id
+ authz_user = "oauth_authz_user_" + id
+
+ issuer = "https://example.com/" + id
+ scope = "openid " + id
+
+ ctx = Context()
+ hba_lines = (
+ f'host {ctx.dbname} {ctx.map_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
+ f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
+ f'host {ctx.dbname} all samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
+ )
+ ident_lines = (r"oauth /^(.*)@example\.com$ \1",)
+
+ conn = psycopg2.connect("")
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+
+ # Create our roles and database.
+ user = sql.Identifier(ctx.user)
+ map_user = sql.Identifier(ctx.map_user)
+ authz_user = sql.Identifier(ctx.authz_user)
+ dbname = sql.Identifier(ctx.dbname)
+
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(user))
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(map_user))
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user))
+ c.execute(sql.SQL("CREATE DATABASE {};").format(dbname))
+
+ # Make this test script the server's oauth_validator.
+ path = pathlib.Path(__file__).parent / "validate_bearer.py"
+ path = str(path.absolute())
+
+ cmd = f"{shlex.quote(path)} {SHARED_MEM_NAME} <&%f"
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+
+ # Replace pg_hba and pg_ident.
+ c.execute("SHOW hba_file;")
+ hba = c.fetchone()[0]
+
+ c.execute("SHOW ident_file;")
+ ident = c.fetchone()[0]
+
+ with prepend_file(hba, hba_lines), prepend_file(ident, ident_lines):
+ c.execute("SELECT pg_reload_conf();")
+
+ # Use the new database and user.
+ yield ctx
+
+ # Put things back the way they were.
+ c.execute("SELECT pg_reload_conf();")
+
+ c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+ c.execute(sql.SQL("DROP DATABASE {};").format(dbname))
+ c.execute(sql.SQL("DROP ROLE {};").format(authz_user))
+ c.execute(sql.SQL("DROP ROLE {};").format(map_user))
+ c.execute(sql.SQL("DROP ROLE {};").format(user))
+
+
+@pytest.fixture()
+def conn(oauth_ctx, connect):
+ """
+ A convenience wrapper for connect(). The main purpose of this fixture is to
+ make sure oauth_ctx runs its setup code before the connection is made.
+ """
+ return connect()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def authn_id_extension(oauth_ctx):
+ """
+ Performs a `CREATE EXTENSION authn_id` in the test database. This fixture is
+ autoused, so tests don't need to rely on it.
+ """
+ conn = psycopg2.connect(database=oauth_ctx.dbname)
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+ c.execute("CREATE EXTENSION authn_id;")
+
+
+@pytest.fixture(scope="session")
+def shared_mem():
+ """
+ Yields a shared memory segment that can be used for communication between
+ the bearer_token fixture and ./validate_bearer.py.
+ """
+ size = MAX_TOKEN_SIZE + 2 # two byte length prefix
+ mem = shared_memory.SharedMemory(SHARED_MEM_NAME, create=True, size=size)
+
+ try:
+ with contextlib.closing(mem):
+ yield mem
+ finally:
+ mem.unlink()
+
+
+@pytest.fixture()
+def bearer_token(shared_mem):
+ """
+ Returns a factory function that, when called, will store a Bearer token in
+ shared_mem. If token is None (the default), a new token will be generated
+ using secrets.token_urlsafe() and returned; otherwise the passed token will
+ be used as-is.
+
+ When token is None, the generated token size in bytes may be specified as an
+ argument; if unset, a small 16-byte token will be generated. The token size
+ may not exceed MAX_TOKEN_SIZE in any case.
+
+ The return value is the token, converted to a bytes object.
+
+ As a special case for testing failure modes, accept_any may be set to True.
+ This signals to the validator command that any bearer token should be
+ accepted. The returned token in this case may be used or discarded as needed
+ by the test.
+ """
+
+ def set_token(token=None, *, size=16, accept_any=False):
+ if token is not None:
+ size = len(token)
+
+ if size > MAX_TOKEN_SIZE:
+ raise ValueError(f"token size {size} exceeds maximum size {MAX_TOKEN_SIZE}")
+
+ if token is None:
+ if size % 4:
+ raise ValueError(f"requested token size {size} is not a multiple of 4")
+
+ token = secrets.token_urlsafe(size // 4 * 3)
+ assert len(token) == size
+
+ try:
+ token = token.encode("ascii")
+ except AttributeError:
+ pass # already encoded
+
+ if accept_any:
+ # Two-byte magic value.
+ shared_mem.buf[:2] = struct.pack("H", MAX_UINT16)
+ else:
+ # Two-byte length prefix, then the token data.
+ shared_mem.buf[:2] = struct.pack("H", len(token))
+ shared_mem.buf[2 : size + 2] = token
+
+ return token
+
+ return set_token
+
+
+def begin_oauth_handshake(conn, oauth_ctx, *, user=None):
+ if user is None:
+ user = oauth_ctx.authz_user
+
+ pq3.send_startup(conn, user=user, database=oauth_ctx.dbname)
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.AuthnRequest
+
+ # The server should advertise exactly one mechanism.
+ assert resp.payload.type == pq3.authn.SASL
+ assert resp.payload.body == [b"OAUTHBEARER", b""]
+
+
+def send_initial_response(conn, *, auth=None, bearer=None):
+ """
+ Sends the OAUTHBEARER initial response on the connection, using the given
+ bearer token. Instead of a bearer token, the initial response's auth field
+ may be specified explicitly, to exercise corner cases.
+ """
+ if (bearer is None) == (auth is None):
+ raise ValueError("exactly one of the auth and bearer kwargs must be set")
+
+ if bearer is not None:
+ auth = b"Bearer " + bearer
+
+ initial = pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"n,,\x01auth=" + auth + b"\x01\x01",
+ )
+ )
+ pq3.send(conn, pq3.types.PasswordMessage, initial)
+
+
+def expect_handshake_success(conn):
+ """
+ Validates that the server responds with an AuthnOK message, and then drains
+ the connection until a ReadyForQuery message is received.
+ """
+ resp = pq3.recv1(conn)
+
+ assert resp.type == pq3.types.AuthnRequest
+ assert resp.payload.type == pq3.authn.OK
+ assert not resp.payload.body
+
+ receive_until(conn, pq3.types.ReadyForQuery)
+
+
+def expect_handshake_failure(conn, oauth_ctx):
+ """
+ Performs the OAUTHBEARER SASL failure "handshake" and validates the server's
+ side of the conversation, including the final ErrorResponse.
+ """
+
+ # We expect a discovery "challenge" back from the server before the authn
+ # failure message.
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.AuthnRequest
+
+ req = resp.payload
+ assert req.type == pq3.authn.SASLContinue
+
+ body = json.loads(req.body)
+ assert body["status"] == "invalid_token"
+ assert body["scope"] == oauth_ctx.scope
+
+ expected_config = oauth_ctx.issuer + "/.well-known/openid-configuration"
+ assert body["openid-configuration"] == expected_config
+
+ # Send the dummy response to complete the failed handshake.
+ pq3.send(conn, pq3.types.PasswordMessage, b"\x01")
+ resp = pq3.recv1(conn)
+
+ err = ExpectedError(INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed")
+ err.match(resp)
+
+
+def receive_until(conn, type):
+ """
+ receive_until pulls packets off the pq3 connection until a packet with the
+ desired type is found, or an error response is received.
+ """
+ while True:
+ pkt = pq3.recv1(conn)
+
+ if pkt.type == type:
+ return pkt
+ elif pkt.type == pq3.types.ErrorResponse:
+ raise RuntimeError(
+ f"received error response from peer: {pkt.payload.fields!r}"
+ )
+
+
+@pytest.mark.parametrize("token_len", [16, 1024, 4096])
+@pytest.mark.parametrize(
+ "auth_prefix",
+ [
+ b"Bearer ",
+ b"bearer ",
+ b"Bearer  ",  # extra whitespace before the token is allowed
+ ],
+)
+def test_oauth(conn, oauth_ctx, bearer_token, auth_prefix, token_len):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Generate our bearer token with the desired length.
+ token = bearer_token(size=token_len)
+ auth = auth_prefix + token
+
+ send_initial_response(conn, auth=auth)
+ expect_handshake_success(conn)
+
+ # Make sure that the server has not set an authenticated ID.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ assert row.columns == [None]
+
+
+@pytest.mark.parametrize(
+ "token_value",
+ [
+ "abcdzA==",
+ "123456M=",
+ "x-._~+/x",
+ ],
+)
+def test_oauth_bearer_corner_cases(conn, oauth_ctx, bearer_token, token_value):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ send_initial_response(conn, bearer=bearer_token(token_value))
+
+ expect_handshake_success(conn)
+
+
+@pytest.mark.parametrize(
+ "user,authn_id,should_succeed",
+ [
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: ctx.user,
+ True,
+ id="validator authn: succeeds when authn_id == username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: None,
+ False,
+ id="validator authn: fails when authn_id is not set",
+ ),
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: ctx.authz_user,
+ False,
+ id="validator authn: fails when authn_id != username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: ctx.map_user + "@example.com",
+ True,
+ id="validator with map: succeeds when authn_id matches map",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: None,
+ False,
+ id="validator with map: fails when authn_id is not set",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: ctx.map_user + "@example.net",
+ False,
+ id="validator with map: fails when authn_id doesn't match map",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: None,
+ True,
+ id="validator authz: succeeds with no authn_id",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "",
+ True,
+ id="validator authz: succeeds with empty authn_id",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "postgres",
+ True,
+ id="validator authz: succeeds with basic username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "me@example.com",
+ True,
+ id="validator authz: succeeds with email address",
+ ),
+ ],
+)
+def test_oauth_authn_id(conn, oauth_ctx, bearer_token, user, authn_id, should_succeed):
+ token = None
+
+ authn_id = authn_id(oauth_ctx)
+ if authn_id is not None:
+ authn_id = authn_id.encode("ascii")
+
+ # As a hack to get the validator to reflect arbitrary output from this
+ # test, encode the desired output as a base64 token. The validator will
+ # key on the leading "output=" to differentiate this from the random
+ # tokens generated by secrets.token_urlsafe().
+ output = b"output=" + authn_id + b"\n"
+ token = base64.urlsafe_b64encode(output)
+
+ token = bearer_token(token)
+ username = user(oauth_ctx)
+
+ begin_oauth_handshake(conn, oauth_ctx, user=username)
+ send_initial_response(conn, bearer=token)
+
+ if not should_succeed:
+ expect_handshake_failure(conn, oauth_ctx)
+ return
+
+ expect_handshake_success(conn)
+
+ # Check the reported authn_id.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ assert row.columns == [authn_id]
+
+
+class ExpectedError(object):
+ def __init__(self, code, msg=None, detail=None):
+ self.code = code
+ self.msg = msg
+ self.detail = detail
+
+ # Protect against the footgun of an accidental empty string, which will
+ # "match" anything. If you don't want to match message or detail, just
+ # don't pass them.
+ if self.msg == "":
+ raise ValueError("msg must be non-empty or None")
+ if self.detail == "":
+ raise ValueError("detail must be non-empty or None")
+
+ def _getfield(self, resp, type):
+ """
+ Searches an ErrorResponse for a single field of the given type (e.g.
+ "M", "C", "D") and returns its value. Asserts if it doesn't find exactly
+ one field.
+ """
+ prefix = type.encode("ascii")
+ fields = [f for f in resp.payload.fields if f.startswith(prefix)]
+
+ assert len(fields) == 1
+ return fields[0][1:] # strip off the type byte
+
+ def match(self, resp):
+ """
+ Checks that the given response matches the expected code, message, and
+ detail (if given). The error code must match exactly. The expected
+ message and detail must be contained within the actual strings.
+ """
+ assert resp.type == pq3.types.ErrorResponse
+
+ code = self._getfield(resp, "C")
+ assert code == self.code
+
+ if self.msg:
+ msg = self._getfield(resp, "M")
+ expected = self.msg.encode("utf-8")
+ assert expected in msg
+
+ if self.detail:
+ detail = self._getfield(resp, "D")
+ expected = self.detail.encode("utf-8")
+ assert expected in detail
+
+
+def test_oauth_rejected_bearer(conn, oauth_ctx, bearer_token):
+ # Generate a new bearer token, which we will proceed not to use.
+ _ = bearer_token()
+
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send a bearer token that doesn't match what the validator expects. It
+ # should fail the connection.
+ send_initial_response(conn, bearer=b"xxxxxx")
+
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.parametrize(
+ "bad_bearer",
+ [
+ b"Bearer ",
+ b"Bearer a===b",
+ b"Bearer hello!",
+ b"Bearer me@example.com",
+ b'OAuth realm="Example"',
+ b"",
+ ],
+)
+def test_oauth_invalid_bearer(conn, oauth_ctx, bearer_token, bad_bearer):
+ # Tell the validator to accept any token. This ensures that the invalid
+ # bearer tokens are rejected before the validation step.
+ _ = bearer_token(accept_any=True)
+
+ begin_oauth_handshake(conn, oauth_ctx)
+ send_initial_response(conn, auth=bad_bearer)
+
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.parametrize(
+ "resp_type,resp,err",
+ [
+ pytest.param(
+ None,
+ None,
+ None,
+ marks=pytest.mark.slow,
+ id="no response (expect timeout)",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"hello",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="bad dummy response",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"\x01\x01",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="multiple kvseps",
+ ),
+ pytest.param(
+ pq3.types.Query,
+ dict(query=b""),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="bad response message type",
+ ),
+ ],
+)
+def test_oauth_bad_response_to_error_challenge(conn, oauth_ctx, resp_type, resp, err):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send an empty auth initial response, which will force an authn failure.
+ send_initial_response(conn, auth=b"")
+
+ # We expect a discovery "challenge" back from the server before the authn
+ # failure message.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.AuthnRequest
+
+ req = pkt.payload
+ assert req.type == pq3.authn.SASLContinue
+
+ body = json.loads(req.body)
+ assert body["status"] == "invalid_token"
+
+ if resp_type is None:
+ # Do not send the dummy response. We should time out and not get a
+ # response from the server.
+ with pytest.raises(socket.timeout):
+ conn.read(1)
+
+ # Done with the test.
+ return
+
+ # Send the bad response.
+ pq3.send(conn, resp_type, resp)
+
+ # Make sure the server fails the connection correctly.
+ pkt = pq3.recv1(conn)
+ err.match(pkt)
+
+
+@pytest.mark.parametrize(
+ "type,payload,err",
+ [
+ pytest.param(
+ pq3.types.ErrorResponse,
+ dict(fields=[b""]),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="error response in initial message",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"x" * (MAX_SASL_MESSAGE_LENGTH + 1),
+ ExpectedError(
+ INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed"
+ ),
+ id="overlong initial response data",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"SCRAM-SHA-256")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE, "invalid SASL authentication mechanism"
+ ),
+ id="bad SASL mechanism selection",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=2, data=b"x")),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "insufficient data"),
+ id="SASL data underflow",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=0, data=b"x")),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "invalid message format"),
+ id="SASL data overflow",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "message is empty",
+ ),
+ id="empty",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"n,,\x01auth=\x01\x01\0")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "length does not match input length",
+ ),
+ id="contains null byte",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected channel-binding flag", # XXX this is a bit strange
+ ),
+ id="initial error response",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"p=tls-server-end-point,,\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "server does not support channel binding",
+ ),
+ id="uses channel binding",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"x,,\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected channel-binding flag",
+ ),
+ id="invalid channel binding specifier",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Comma expected",
+ ),
+ id="bad GS2 header: missing channel binding terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,a")),
+ ExpectedError(
+ FEATURE_NOT_SUPPORTED_ERRCODE,
+ "client uses authorization identity",
+ ),
+ id="bad GS2 header: authzid in use",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,b,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected attribute",
+ ),
+ id="bad GS2 header: extra attribute",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ 'Unexpected attribute "0x00"', # XXX this is a bit strange
+ ),
+ id="bad GS2 header: missing authzid terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Key-value separator expected",
+ ),
+ id="missing initial kvsep",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: empty key-value list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: other keys present",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "unterminated key/value pair",
+ ),
+ id="missing value terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: empty list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: with auth value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01\x01blah")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "additional data after the final terminator",
+ ),
+ id="additional key after terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01key\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "key without a value",
+ ),
+ id="key without value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"y,,\x01auth=Bearer 0\x01auth=Bearer 1\x01\x01",
+ )
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "contains multiple auth values",
+ ),
+ id="multiple auth values",
+ ),
+ ],
+)
+def test_oauth_bad_initial_response(conn, oauth_ctx, type, payload, err):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # The server expects a SASL response; give it something else instead.
+ if not isinstance(payload, dict):
+ payload = dict(payload_data=payload)
+ pq3.send(conn, type, **payload)
+
+ resp = pq3.recv1(conn)
+ err.match(resp)
+
+
+def test_oauth_empty_initial_response(conn, oauth_ctx, bearer_token):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send an initial response without data.
+ initial = pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER"))
+ pq3.send(conn, pq3.types.PasswordMessage, initial)
+
+ # The server should respond with an empty challenge so we can send the data
+ # it wants.
+ pkt = pq3.recv1(conn)
+
+ assert pkt.type == pq3.types.AuthnRequest
+ assert pkt.payload.type == pq3.authn.SASLContinue
+ assert not pkt.payload.body
+
+ # Now send the initial data.
+ data = b"n,,\x01auth=Bearer " + bearer_token() + b"\x01\x01"
+ pq3.send(conn, pq3.types.PasswordMessage, data)
+
+ # Server should now complete the handshake.
+ expect_handshake_success(conn)
+
+
+@pytest.fixture()
+def set_validator():
+ """
+ A per-test fixture that allows a test to override the setting of
+ oauth_validator_command for the cluster. The setting will be reverted during
+ teardown.
+
+ Passing None will perform an ALTER SYSTEM RESET.
+ """
+ conn = psycopg2.connect("")
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+
+ # Save the previous value.
+ c.execute("SHOW oauth_validator_command;")
+ prev_cmd = c.fetchone()[0]
+
+ def setter(cmd):
+ if cmd is None:
+ c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+ else:
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+ c.execute("SELECT pg_reload_conf();")
+
+ yield setter
+
+ # Restore the previous value.
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (prev_cmd,))
+ c.execute("SELECT pg_reload_conf();")
+
+
+def test_oauth_no_validator(oauth_ctx, set_validator, connect, bearer_token):
+ # Clear out our validator command, then establish a new connection.
+ set_validator("")
+ conn = connect()
+
+ begin_oauth_handshake(conn, oauth_ctx)
+ send_initial_response(conn, bearer=bearer_token())
+
+ # The server should fail the connection.
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+def test_oauth_validator_role(oauth_ctx, set_validator, connect):
+ # Switch the validator implementation. This validator will reflect the
+ # PGUSER as the authenticated identity.
+ path = pathlib.Path(__file__).parent / "validate_reflect.py"
+ path = str(path.absolute())
+
+ set_validator(f"{shlex.quote(path)} '%r' <&%f")
+ conn = connect()
+
+ # Log in. Note that the reflection validator ignores the bearer token.
+ begin_oauth_handshake(conn, oauth_ctx, user=oauth_ctx.user)
+ send_initial_response(conn, bearer=b"dontcare")
+ expect_handshake_success(conn)
+
+ # Check the user identity.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ expected = oauth_ctx.user.encode("utf-8")
+ assert row.columns == [expected]
+
+
+def test_oauth_role_with_shell_unsafe_characters(oauth_ctx, set_validator, connect):
+ """
+ XXX This test pins undesirable behavior. We should be able to handle any
+ valid Postgres role name.
+ """
+ # Switch the validator implementation. This validator will reflect the
+ # PGUSER as the authenticated identity.
+ path = pathlib.Path(__file__).parent / "validate_reflect.py"
+ path = str(path.absolute())
+
+ set_validator(f"{shlex.quote(path)} '%r' <&%f")
+ conn = connect()
+
+ unsafe_username = "hello'there"
+ begin_oauth_handshake(conn, oauth_ctx, user=unsafe_username)
+
+ # The server should reject the handshake.
+ send_initial_response(conn, bearer=b"dontcare")
+ expect_handshake_failure(conn, oauth_ctx)
diff --git a/src/test/python/server/test_server.py b/src/test/python/server/test_server.py
new file mode 100644
index 0000000000..02126dba79
--- /dev/null
+++ b/src/test/python/server/test_server.py
@@ -0,0 +1,21 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import pq3
+
+
+def test_handshake(connect):
+ """Basic sanity check."""
+ conn = connect()
+
+ pq3.handshake(conn, user=pq3.pguser(), database=pq3.pgdatabase())
+
+ pq3.send(conn, pq3.types.Query, query=b"")
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.EmptyQueryResponse
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.ReadyForQuery
diff --git a/src/test/python/server/validate_bearer.py b/src/test/python/server/validate_bearer.py
new file mode 100755
index 0000000000..2cc73ff154
--- /dev/null
+++ b/src/test/python/server/validate_bearer.py
@@ -0,0 +1,101 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It doesn't actually validate
+# anything, and it logs the bearer token data, which is sensitive.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. Memory is shared and communicated from that test module's
+# bearer_token() fixture.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
+
+import base64
+import binascii
+import contextlib
+import struct
+import sys
+from multiprocessing import shared_memory
+
+MAX_UINT16 = 2 ** 16 - 1
+
+
+def remove_shm_from_resource_tracker():
+ """
+ Monkey-patch multiprocessing.resource_tracker so SharedMemory won't be
+ tracked. Pulled from this thread, where there are more details:
+
+ https://bugs.python.org/issue38119
+
+ TL;DR: all clients of shared memory segments automatically destroy them on
+ process exit, which makes shared memory segments much less useful. This
+ monkeypatch removes that behavior so that we can defer to the test to manage
+ the segment lifetime.
+
+ Ideally a future Python patch will pull in this fix and then the entire
+ function can go away.
+ """
+ from multiprocessing import resource_tracker
+
+ def fix_register(name, rtype):
+ if rtype == "shared_memory":
+ return
+ return resource_tracker._resource_tracker.register(name, rtype)
+
+ resource_tracker.register = fix_register
+
+ def fix_unregister(name, rtype):
+ if rtype == "shared_memory":
+ return
+ return resource_tracker._resource_tracker.unregister(name, rtype)
+
+ resource_tracker.unregister = fix_unregister
+
+ if "shared_memory" in resource_tracker._CLEANUP_FUNCS:
+ del resource_tracker._CLEANUP_FUNCS["shared_memory"]
+
+
+def main(args):
+ remove_shm_from_resource_tracker() # XXX remove some day
+
+ # Get the expected token from the currently running test.
+ shared_mem_name = args[0]
+
+ mem = shared_memory.SharedMemory(shared_mem_name)
+ with contextlib.closing(mem):
+ # First two bytes are the token length.
+ size = struct.unpack("H", mem.buf[:2])[0]
+
+ if size == MAX_UINT16:
+ # Special case: the test wants us to accept any token.
+ sys.stderr.write("accepting token without validation\n")
+ return
+
+ # The remainder of the buffer contains the expected token.
+ assert size <= (mem.size - 2)
+ expected_token = mem.buf[2 : size + 2].tobytes()
+
+ mem.buf[:] = b"\0" * mem.size # scribble over the token
+
+ token = sys.stdin.buffer.read()
+ if token != expected_token:
+ sys.exit(f"failed to match Bearer token ({token!r} != {expected_token!r})")
+
+ # See if the test wants us to print anything. If so, it will have encoded
+ # the desired output in the token with an "output=" prefix.
+ try:
+ # altchars="-_" corresponds to the urlsafe alphabet.
+ data = base64.b64decode(token, altchars="-_", validate=True)
+
+ if data.startswith(b"output="):
+ sys.stdout.buffer.write(data[7:])
+
+ except binascii.Error:
+ pass
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/src/test/python/server/validate_reflect.py b/src/test/python/server/validate_reflect.py
new file mode 100755
index 0000000000..24c3a7e715
--- /dev/null
+++ b/src/test/python/server/validate_reflect.py
@@ -0,0 +1,34 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It ignores the bearer token
+# entirely and automatically logs the user in.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. It expects the user's desired role name as an argument; the
+# actual token will be discarded and the user will be logged in with the role
+# name as the authenticated identity.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
+
+import sys
+
+
+def main(args):
+ # We have to read the entire token as our first action to unblock the
+ # server, but we won't actually use it.
+ _ = sys.stdin.buffer.read()
+
+ if len(args) != 1:
+ sys.exit("usage: ./validate_reflect.py ROLE")
+
+ # Log the user in as the provided role.
+ role = args[0]
+ print(role)
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/src/test/python/test_internals.py b/src/test/python/test_internals.py
new file mode 100644
index 0000000000..dee4855fc0
--- /dev/null
+++ b/src/test/python/test_internals.py
@@ -0,0 +1,138 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import io
+
+from pq3 import _DebugStream
+
+
+def test_DebugStream_read():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+
+ stream = _DebugStream(under, out)
+
+ res = stream.read(5)
+ assert res == b"abcde"
+
+ res = stream.read(16)
+ assert res == b"fghijklmnopqrstu"
+
+ stream.flush_debug()
+
+ res = stream.read()
+ assert res == b"vwxyz"
+
+ stream.flush_debug()
+
+ expected = (
+ "< 0000:\t61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70\tabcdefghijklmnop\n"
+ "< 0010:\t71 72 73 74 75 \tqrstu\n"
+ "\n"
+ "< 0000:\t76 77 78 79 7a \tvwxyz\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_write():
+ under = io.BytesIO()
+ out = io.StringIO()
+
+ stream = _DebugStream(under, out)
+
+ stream.write(b"\x00\x01\x02")
+ stream.flush()
+
+ assert under.getvalue() == b"\x00\x01\x02"
+
+ stream.write(b"\xc0\xc1\xc2")
+ stream.flush()
+
+ assert under.getvalue() == b"\x00\x01\x02\xc0\xc1\xc2"
+
+ stream.flush_debug()
+
+ expected = "> 0000:\t00 01 02 c0 c1 c2 \t......\n\n"
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_read_write():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+ stream = _DebugStream(under, out)
+
+ res = stream.read(5)
+ assert res == b"abcde"
+
+ stream.write(b"xxxxx")
+ stream.flush()
+
+ assert under.getvalue() == b"abcdexxxxxklmnopqrstuvwxyz"
+
+ res = stream.read(5)
+ assert res == b"klmno"
+
+ stream.write(b"xxxxx")
+ stream.flush()
+
+ assert under.getvalue() == b"abcdexxxxxklmnoxxxxxuvwxyz"
+
+ stream.flush_debug()
+
+ expected = (
+ "< 0000:\t61 62 63 64 65 6b 6c 6d 6e 6f \tabcdeklmno\n"
+ "\n"
+ "> 0000:\t78 78 78 78 78 78 78 78 78 78 \txxxxxxxxxx\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_end_packet():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+ stream = _DebugStream(under, out)
+
+ stream.read(5)
+ stream.end_packet("read description", read=True, indent=" ")
+
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("write description", indent=" ")
+
+ stream.read(5)
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("read/write combo for read", read=True, indent=" ")
+
+ stream.read(5)
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("read/write combo for write", indent=" ")
+
+ expected = (
+ " < 0000:\t61 62 63 64 65 \tabcde\n"
+ "\n"
+ "< read description\n"
+ "\n"
+ "> write description\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ " < 0000:\t6b 6c 6d 6e 6f \tklmno\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ "< read/write combo for read\n"
+ "\n"
+ "> read/write combo for write\n"
+ "\n"
+ " < 0000:\t75 76 77 78 79 \tuvwxy\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
diff --git a/src/test/python/test_pq3.py b/src/test/python/test_pq3.py
new file mode 100644
index 0000000000..e0c0e0568d
--- /dev/null
+++ b/src/test/python/test_pq3.py
@@ -0,0 +1,558 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import struct
+import sys
+
+import pytest
+from construct import Container, PaddingError, StreamError, TerminatedError
+
+import pq3
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"\x00\x00\x00\x10\x00\x04\x00\x00abcdefgh",
+ Container(len=16, proto=0x40000, payload=b"abcdefgh"),
+ b"",
+ id="8-byte payload",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x08\x00\x04\x00\x00",
+ Container(len=8, proto=0x40000, payload=b""),
+ b"",
+ id="no payload",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x09\x00\x04\x00\x00abcde",
+ Container(len=9, proto=0x40000, payload=b"a"),
+ b"bcde",
+ id="1-byte payload and extra padding",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x0B\x00\x03\x00\x00hi\x00",
+ Container(len=11, proto=pq3.protocol(3, 0), payload=[b"hi"]),
+ b"",
+ id="implied parameter list when using proto version 3.0",
+ ),
+ ],
+)
+def test_Startup_parse(raw, expected, extra):
+ with io.BytesIO(raw) as stream:
+ actual = pq3.Startup.parse_stream(stream)
+
+ assert actual == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "packet,expected_bytes",
+ [
+ pytest.param(
+ dict(),
+ b"\x00\x00\x00\x08\x00\x00\x00\x00",
+ id="nothing set",
+ ),
+ pytest.param(
+ dict(len=10, proto=0x12345678),
+ b"\x00\x00\x00\x0A\x12\x34\x56\x78\x00\x00",
+ id="len and proto set explicitly",
+ ),
+ pytest.param(
+ dict(proto=0x12345678),
+ b"\x00\x00\x00\x08\x12\x34\x56\x78",
+ id="implied len with no payload",
+ ),
+ pytest.param(
+ dict(proto=0x12345678, payload=b"abcd"),
+ b"\x00\x00\x00\x0C\x12\x34\x56\x78abcd",
+ id="implied len with payload",
+ ),
+ pytest.param(
+ dict(payload=[b""]),
+ b"\x00\x00\x00\x09\x00\x03\x00\x00\x00",
+ id="implied proto version 3 when sending parameters",
+ ),
+ pytest.param(
+ dict(payload=[b"hi", b""]),
+ b"\x00\x00\x00\x0C\x00\x03\x00\x00hi\x00\x00",
+ id="implied proto version 3 and len when sending more than one parameter",
+ ),
+ pytest.param(
+ dict(payload=dict(user="jsmith", database="postgres")),
+ b"\x00\x00\x00\x27\x00\x03\x00\x00user\x00jsmith\x00database\x00postgres\x00\x00",
+ id="auto-serialization of dict parameters",
+ ),
+ ],
+)
+def test_Startup_build(packet, expected_bytes):
+ actual = pq3.Startup.build(packet)
+ assert actual == expected_bytes
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"*\x00\x00\x00\x08abcd",
+ dict(type=b"*", len=8, payload=b"abcd"),
+ b"",
+ id="4-byte payload",
+ ),
+ pytest.param(
+ b"*\x00\x00\x00\x04",
+ dict(type=b"*", len=4, payload=b""),
+ b"",
+ id="no payload",
+ ),
+ pytest.param(
+ b"*\x00\x00\x00\x05xabcd",
+ dict(type=b"*", len=5, payload=b"x"),
+ b"abcd",
+ id="1-byte payload with extra padding",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x08\x00\x00\x00\x00",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=8,
+ payload=dict(type=pq3.authn.OK, body=None),
+ ),
+ b"",
+ id="AuthenticationOk",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x12\x00\x00\x00\x0AEXTERNAL\x00\x00",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=18,
+ payload=dict(type=pq3.authn.SASL, body=[b"EXTERNAL", b""]),
+ ),
+ b"",
+ id="AuthenticationSASL",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=13,
+ payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
+ ),
+ b"",
+ id="AuthenticationSASLContinue",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=13,
+ payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
+ ),
+ b"",
+ id="AuthenticationSASLFinal",
+ ),
+ pytest.param(
+ b"p\x00\x00\x00\x0Bhunter2",
+ dict(
+ type=pq3.types.PasswordMessage,
+ len=11,
+ payload=b"hunter2",
+ ),
+ b"",
+ id="PasswordMessage",
+ ),
+ pytest.param(
+ b"K\x00\x00\x00\x0C\x00\x00\x00\x00\x12\x34\x56\x78",
+ dict(
+ type=pq3.types.BackendKeyData,
+ len=12,
+ payload=dict(pid=0, key=0x12345678),
+ ),
+ b"",
+ id="BackendKeyData",
+ ),
+ pytest.param(
+ b"C\x00\x00\x00\x08SET\x00",
+ dict(
+ type=pq3.types.CommandComplete,
+ len=8,
+ payload=dict(tag=b"SET"),
+ ),
+ b"",
+ id="CommandComplete",
+ ),
+ pytest.param(
+ b"E\x00\x00\x00\x11Mbad!\x00Mdog!\x00\x00",
+ dict(type=b"E", len=17, payload=dict(fields=[b"Mbad!", b"Mdog!", b""])),
+ b"",
+ id="ErrorResponse",
+ ),
+ pytest.param(
+ b"S\x00\x00\x00\x08a\x00b\x00",
+ dict(
+ type=pq3.types.ParameterStatus,
+ len=8,
+ payload=dict(name=b"a", value=b"b"),
+ ),
+ b"",
+ id="ParameterStatus",
+ ),
+ pytest.param(
+ b"Z\x00\x00\x00\x05x",
+ dict(type=b"Z", len=5, payload=dict(status=b"x")),
+ b"",
+ id="ReadyForQuery",
+ ),
+ pytest.param(
+ b"Q\x00\x00\x00\x06!\x00",
+ dict(type=pq3.types.Query, len=6, payload=dict(query=b"!")),
+ b"",
+ id="Query",
+ ),
+ pytest.param(
+ b"D\x00\x00\x00\x0B\x00\x01\x00\x00\x00\x01!",
+ dict(type=pq3.types.DataRow, len=11, payload=dict(columns=[b"!"])),
+ b"",
+ id="DataRow",
+ ),
+ pytest.param(
+ b"D\x00\x00\x00\x06\x00\x00extra",
+ dict(type=pq3.types.DataRow, len=6, payload=dict(columns=[])),
+ b"extra",
+ id="DataRow with extra data",
+ ),
+ pytest.param(
+ b"I\x00\x00\x00\x04",
+ dict(type=pq3.types.EmptyQueryResponse, len=4, payload=None),
+ b"",
+ id="EmptyQueryResponse",
+ ),
+ pytest.param(
+ b"I\x00\x00\x00\x04\xFF",
+ dict(type=b"I", len=4, payload=None),
+ b"\xFF",
+ id="EmptyQueryResponse with extra bytes",
+ ),
+ pytest.param(
+ b"X\x00\x00\x00\x04",
+ dict(type=pq3.types.Terminate, len=4, payload=None),
+ b"",
+ id="Terminate",
+ ),
+ ],
+)
+def test_Pq3_parse(raw, expected, extra):
+ with io.BytesIO(raw) as stream:
+ actual = pq3.Pq3.parse_stream(stream)
+
+ assert actual == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(type=b"*", len=5),
+ b"*\x00\x00\x00\x05\x00",
+ id="type and len set explicitly",
+ ),
+ pytest.param(
+ dict(type=b"*"),
+ b"*\x00\x00\x00\x04",
+ id="implied len with no payload",
+ ),
+ pytest.param(
+ dict(type=b"*", payload=b"1234"),
+ b"*\x00\x00\x00\x081234",
+ id="implied len with payload",
+ ),
+ pytest.param(
+ dict(type=pq3.types.AuthnRequest, payload=dict(type=pq3.authn.OK)),
+ b"R\x00\x00\x00\x08\x00\x00\x00\x00",
+ id="implied len/type for AuthenticationOK",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(
+ type=pq3.authn.SASL,
+ body=[b"SCRAM-SHA-256-PLUS", b"SCRAM-SHA-256", b""],
+ ),
+ ),
+ b"R\x00\x00\x00\x2A\x00\x00\x00\x0ASCRAM-SHA-256-PLUS\x00SCRAM-SHA-256\x00\x00",
+ id="implied len/type for AuthenticationSASL",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
+ ),
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
+ id="implied len/type for AuthenticationSASLContinue",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
+ ),
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
+ id="implied len/type for AuthenticationSASLFinal",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.PasswordMessage,
+ payload=b"hunter2",
+ ),
+ b"p\x00\x00\x00\x0Bhunter2",
+ id="implied len/type for PasswordMessage",
+ ),
+ pytest.param(
+ dict(type=pq3.types.BackendKeyData, payload=dict(pid=1, key=7)),
+ b"K\x00\x00\x00\x0C\x00\x00\x00\x01\x00\x00\x00\x07",
+ id="implied len/type for BackendKeyData",
+ ),
+ pytest.param(
+ dict(type=pq3.types.CommandComplete, payload=dict(tag=b"SET")),
+ b"C\x00\x00\x00\x08SET\x00",
+ id="implied len/type for CommandComplete",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ErrorResponse, payload=dict(fields=[b"error", b""])),
+ b"E\x00\x00\x00\x0Berror\x00\x00",
+ id="implied len/type for ErrorResponse",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ParameterStatus, payload=dict(name=b"a", value=b"b")),
+ b"S\x00\x00\x00\x08a\x00b\x00",
+ id="implied len/type for ParameterStatus",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ReadyForQuery, payload=dict(status=b"I")),
+ b"Z\x00\x00\x00\x05I",
+ id="implied len/type for ReadyForQuery",
+ ),
+ pytest.param(
+ dict(type=pq3.types.Query, payload=dict(query=b"SELECT 1;")),
+ b"Q\x00\x00\x00\x0eSELECT 1;\x00",
+ id="implied len/type for Query",
+ ),
+ pytest.param(
+ dict(type=pq3.types.DataRow, payload=dict(columns=[b"abcd"])),
+ b"D\x00\x00\x00\x0E\x00\x01\x00\x00\x00\x04abcd",
+ id="implied len/type for DataRow",
+ ),
+ pytest.param(
+ dict(type=pq3.types.EmptyQueryResponse),
+ b"I\x00\x00\x00\x04",
+ id="implied len for EmptyQueryResponse",
+ ),
+ pytest.param(
+ dict(type=pq3.types.Terminate),
+ b"X\x00\x00\x00\x04",
+ id="implied len for Terminate",
+ ),
+ ],
+)
+def test_Pq3_build(fields, expected):
+ actual = pq3.Pq3.build(fields)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"\x00\x00",
+ dict(columns=[]),
+ b"",
+ id="no columns",
+ ),
+ pytest.param(
+ b"\x00\x01\x00\x00\x00\x04abcd",
+ dict(columns=[b"abcd"]),
+ b"",
+ id="one column",
+ ),
+ pytest.param(
+ b"\x00\x02\x00\x00\x00\x04abcd\x00\x00\x00\x01x",
+ dict(columns=[b"abcd", b"x"]),
+ b"",
+ id="multiple columns",
+ ),
+ pytest.param(
+ b"\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01x",
+ dict(columns=[b"", b"x"]),
+ b"",
+ id="empty column value",
+ ),
+ pytest.param(
+ b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
+ dict(columns=[None, None]),
+ b"",
+ id="null columns",
+ ),
+ ],
+)
+def test_DataRow_parse(raw, expected, extra):
+ pkt = b"D" + struct.pack("!i", len(raw) + 4) + raw
+ with io.BytesIO(pkt) as stream:
+ actual = pq3.Pq3.parse_stream(stream)
+
+ assert actual.type == pq3.types.DataRow
+ assert actual.payload == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(),
+ b"\x00\x00",
+ id="no columns",
+ ),
+ pytest.param(
+ dict(columns=[None, None]),
+ b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
+ id="null columns",
+ ),
+ ],
+)
+def test_DataRow_build(fields, expected):
+ actual = pq3.Pq3.build(dict(type=pq3.types.DataRow, payload=fields))
+
+ expected = b"D" + struct.pack("!i", len(expected) + 4) + expected
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "raw,expected,exception",
+ [
+ pytest.param(
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ dict(name=b"EXTERNAL", len=-1, data=None),
+ None,
+ id="no initial response",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\x02me",
+ dict(name=b"EXTERNAL", len=2, data=b"me"),
+ None,
+ id="initial response",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\x02meextra",
+ None,
+ TerminatedError,
+ id="extra data",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\xFFme",
+ None,
+ StreamError,
+ id="underflow",
+ ),
+ ],
+)
+def test_SASLInitialResponse_parse(raw, expected, exception):
+ ctx = contextlib.nullcontext()
+ if exception:
+ ctx = pytest.raises(exception)
+
+ with ctx:
+ actual = pq3.SASLInitialResponse.parse(raw)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(name=b"EXTERNAL"),
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ id="no initial response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=None),
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ id="no initial response (explicit None)",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=b""),
+ b"EXTERNAL\x00\x00\x00\x00\x00",
+ id="empty response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=b"me@example.com"),
+ b"EXTERNAL\x00\x00\x00\x00\x0Eme@example.com",
+ id="initial response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", len=2, data=b"me@example.com"),
+ b"EXTERNAL\x00\x00\x00\x00\x02me@example.com",
+ id="data overflow",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", len=14, data=b"me"),
+ b"EXTERNAL\x00\x00\x00\x00\x0Eme",
+ id="data underflow",
+ ),
+ ],
+)
+def test_SASLInitialResponse_build(fields, expected):
+ actual = pq3.SASLInitialResponse.build(fields)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "version,expected_bytes",
+ [
+ pytest.param((3, 0), b"\x00\x03\x00\x00", id="version 3"),
+ pytest.param((1234, 5679), b"\x04\xd2\x16\x2f", id="SSLRequest"),
+ ],
+)
+def test_protocol(version, expected_bytes):
+ # Make sure the integer returned by protocol is correctly serialized on the
+ # wire.
+ assert struct.pack("!i", pq3.protocol(*version)) == expected_bytes
+
+
+@pytest.mark.parametrize(
+ "envvar,func,expected",
+ [
+ ("PGHOST", pq3.pghost, "localhost"),
+ ("PGPORT", pq3.pgport, 5432),
+ ("PGUSER", pq3.pguser, getpass.getuser()),
+ ("PGDATABASE", pq3.pgdatabase, "postgres"),
+ ],
+)
+def test_env_defaults(monkeypatch, envvar, func, expected):
+ monkeypatch.delenv(envvar, raising=False)
+
+ actual = func()
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "envvars,func,expected",
+ [
+ (dict(PGHOST="otherhost"), pq3.pghost, "otherhost"),
+ (dict(PGPORT="6789"), pq3.pgport, 6789),
+ (dict(PGUSER="postgres"), pq3.pguser, "postgres"),
+ (dict(PGDATABASE="template1"), pq3.pgdatabase, "template1"),
+ ],
+)
+def test_env(monkeypatch, envvars, func, expected):
+ for k, v in envvars.items():
+ monkeypatch.setenv(k, v)
+
+ actual = func()
+ assert actual == expected
diff --git a/src/test/python/tls.py b/src/test/python/tls.py
new file mode 100644
index 0000000000..075c02c1ca
--- /dev/null
+++ b/src/test/python/tls.py
@@ -0,0 +1,195 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+from construct import *
+
+#
+# TLS 1.3
+#
+# Most of the types below are transcribed from RFC 8446:
+#
+# https://tools.ietf.org/html/rfc8446
+#
+
+
+def _Vector(size_field, element):
+ return Prefixed(size_field, GreedyRange(element))
+
+
+# Alerts
+
+AlertLevel = Enum(
+ Byte,
+ warning=1,
+ fatal=2,
+)
+
+AlertDescription = Enum(
+ Byte,
+ close_notify=0,
+ unexpected_message=10,
+ bad_record_mac=20,
+ decryption_failed_RESERVED=21,
+ record_overflow=22,
+ decompression_failure=30,
+ handshake_failure=40,
+ no_certificate_RESERVED=41,
+ bad_certificate=42,
+ unsupported_certificate=43,
+ certificate_revoked=44,
+ certificate_expired=45,
+ certificate_unknown=46,
+ illegal_parameter=47,
+ unknown_ca=48,
+ access_denied=49,
+ decode_error=50,
+ decrypt_error=51,
+ export_restriction_RESERVED=60,
+ protocol_version=70,
+ insufficient_security=71,
+ internal_error=80,
+ user_canceled=90,
+ no_renegotiation=100,
+ unsupported_extension=110,
+)
+
+Alert = Struct(
+ "level" / AlertLevel,
+ "description" / AlertDescription,
+)
+
+
+# Extensions
+
+ExtensionType = Enum(
+ Int16ub,
+ server_name=0,
+ max_fragment_length=1,
+ status_request=5,
+ supported_groups=10,
+ signature_algorithms=13,
+ use_srtp=14,
+ heartbeat=15,
+ application_layer_protocol_negotiation=16,
+ signed_certificate_timestamp=18,
+ client_certificate_type=19,
+ server_certificate_type=20,
+ padding=21,
+ pre_shared_key=41,
+ early_data=42,
+ supported_versions=43,
+ cookie=44,
+ psk_key_exchange_modes=45,
+ certificate_authorities=47,
+ oid_filters=48,
+ post_handshake_auth=49,
+ signature_algorithms_cert=50,
+ key_share=51,
+)
+
+Extension = Struct(
+ "extension_type" / ExtensionType,
+ "extension_data" / Prefixed(Int16ub, GreedyBytes),
+)
+
+
+# ClientHello
+
+
+class _CipherSuiteAdapter(Adapter):
+ class _hextuple(tuple):
+ def __repr__(self):
+ return f"(0x{self[0]:02X}, 0x{self[1]:02X})"
+
+ def _encode(self, obj, context, path):
+ return bytes(obj)
+
+ def _decode(self, obj, context, path):
+ assert len(obj) == 2
+ return self._hextuple(obj)
+
+
+ProtocolVersion = Hex(Int16ub)
+
+Random = Hex(Bytes(32))
+
+CipherSuite = _CipherSuiteAdapter(Byte[2])
+
+ClientHello = Struct(
+ "legacy_version" / ProtocolVersion,
+ "random" / Random,
+ "legacy_session_id" / Prefixed(Byte, Hex(GreedyBytes)),
+ "cipher_suites" / _Vector(Int16ub, CipherSuite),
+ "legacy_compression_methods" / Prefixed(Byte, GreedyBytes),
+ "extensions" / _Vector(Int16ub, Extension),
+)
+
+# ServerHello
+
+ServerHello = Struct(
+ "legacy_version" / ProtocolVersion,
+ "random" / Random,
+ "legacy_session_id_echo" / Prefixed(Byte, Hex(GreedyBytes)),
+ "cipher_suite" / CipherSuite,
+ "legacy_compression_method" / Hex(Byte),
+ "extensions" / _Vector(Int16ub, Extension),
+)
+
+# Handshake
+
+HandshakeType = Enum(
+ Byte,
+ client_hello=1,
+ server_hello=2,
+ new_session_ticket=4,
+ end_of_early_data=5,
+ encrypted_extensions=8,
+ certificate=11,
+ certificate_request=13,
+ certificate_verify=15,
+ finished=20,
+ key_update=24,
+ message_hash=254,
+)
+
+Handshake = Struct(
+ "msg_type" / HandshakeType,
+ "length" / Int24ub,
+ "payload"
+ / Switch(
+ this.msg_type,
+ {
+ HandshakeType.client_hello: ClientHello,
+ HandshakeType.server_hello: ServerHello,
+ # HandshakeType.end_of_early_data: EndOfEarlyData,
+ # HandshakeType.encrypted_extensions: EncryptedExtensions,
+ # HandshakeType.certificate_request: CertificateRequest,
+ # HandshakeType.certificate: Certificate,
+ # HandshakeType.certificate_verify: CertificateVerify,
+ # HandshakeType.finished: Finished,
+ # HandshakeType.new_session_ticket: NewSessionTicket,
+ # HandshakeType.key_update: KeyUpdate,
+ },
+ default=FixedSized(this.length, GreedyBytes),
+ ),
+)
+
+# Records
+
+ContentType = Enum(
+ Byte,
+ invalid=0,
+ change_cipher_spec=20,
+ alert=21,
+ handshake=22,
+ application_data=23,
+)
+
+Plaintext = Struct(
+ "type" / ContentType,
+ "legacy_record_version" / ProtocolVersion,
+ "length" / Int16ub,
+ "fragment" / FixedSized(this.length, GreedyBytes),
+)
--
2.25.1
On Tue, Jun 08, 2021 at 04:37:46PM +0000, Jacob Champion wrote:
1. Prep
0001 decouples the SASL code from the SCRAM implementation.
0002 makes it possible to use common/jsonapi from the frontend.
0003 lets the json_errdetail() result be freed, to avoid leaks.
The first three patches are, hopefully, generally useful outside of
this implementation, and I'll plan to register them in the next
commitfest. The middle two patches are the "interesting" pieces, and
I've split them into client and server for ease of understanding,
though neither is particularly useful without the other.
Beginning with the beginning, could you spawn two threads for the
jsonapi rework and the SASL/SCRAM business? I agree that these look
independently useful. Glad to see someone improving the code with
SASL and SCRAM which are too inter-dependent now. I saw in the RFCs
dedicated to OAUTH the need for the JSON part as well.
+# define check_stack_depth()
+# ifdef JSONAPI_NO_LOG
+# define json_log_and_abort(...) \
+ do { fprintf(stderr, __VA_ARGS__); exit(1); } while(0)
+# else
In patch 0002, this is the wrong approach. libpq will not be able to
feed on such reports, and you cannot use any of the APIs from the
palloc() family either as these just fail on OOM. libpq should be
able to know about the error, and would fill in the error back to the
application. This abstraction is not necessary on HEAD as
pg_verifybackup is fine with this level of reporting. My rough guess
is that we will need to split the existing jsonapi.c into two files,
one that can be used in shared libraries and a second that handles the
errors.
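The split being suggested can be sketched in miniature (Python here purely for illustration; the real work would be in C): the shared-library variant of the parser hands the error detail back to its caller instead of logging and aborting, so libpq could copy it into the connection's error buffer.

```python
import json


def parse_json_field(data):
    """Parse `data`, returning (value, None) on success or (None, errdetail)
    on failure. The caller decides how to report the error; nothing here
    logs to stderr or aborts the process."""
    try:
        return json.loads(data), None
    except ValueError as e:
        # In the libpq-friendly variant, this string would be surfaced
        # through the connection's error buffer rather than printed.
        return None, str(e)
```

The pg_verifybackup-style consumer could then be a thin wrapper that calls this and aborts whenever the error slot comes back non-empty.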
+ /* TODO: SASL_EXCHANGE_FAILURE with output is forbidden in SASL */
if (result == SASL_EXCHANGE_SUCCESS)
sendAuthRequest(port,
AUTH_REQ_SASL_FIN,
output,
outputlen);
Perhaps that's an issue we need to worry on its own? I didn't recall
this part..
--
Michael
On 08/06/2021 19:37, Jacob Champion wrote:
We've been working on ways to expand the list of third-party auth
methods that Postgres provides. Some example use cases might be "I want
to let anyone with a Google account read this table" or "let anyone who
belongs to this GitHub organization connect as a superuser".
Cool!
The iddawc dependency for client-side OAuth was extremely helpful to
develop this proof of concept quickly, but I don't think it would be an
appropriate component to build a real feature on. It's extremely
heavyweight -- it incorporates a huge stack of dependencies, including
a logging framework and a web server, to implement features we would
probably never use -- and it's fairly difficult to debug in practice.
If a device authorization flow were the only thing that libpq needed to
support natively, I think we should just depend on a widely used HTTP
client, like libcurl or neon, and implement the minimum spec directly
against the existing test suite.
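For reference, the "minimum spec" here is the Device Authorization Grant (RFC 8628). A rough sketch of the control flow, with the HTTP transport injected so the logic is visible (endpoint paths and discovery are elided; the field names follow the RFC):

```python
import time


def device_flow(http_post, issuer, client_id):
    """Run the RFC 8628 device flow against `issuer`, using the injected
    `http_post(url, params) -> dict` transport. Returns a bearer token."""
    # Step 1: ask the issuer for a device code and a user code.
    resp = http_post(issuer + "/device/code", dict(client_id=client_id))
    print("Visit %s and enter the code: %s"
          % (resp["verification_uri"], resp["user_code"]))

    interval = resp.get("interval", 5)
    # Step 2: poll the token endpoint until the user approves (or we fail).
    while True:
        tok = http_post(issuer + "/token", dict(
            grant_type="urn:ietf:params:oauth:grant-type:device_code",
            device_code=resp["device_code"],
            client_id=client_id,
        ))
        if "access_token" in tok:
            return tok["access_token"]
        if tok.get("error") != "authorization_pending":
            raise RuntimeError(tok.get("error", "unknown failure"))
        time.sleep(interval)
```

Here `http_post` would be whatever thin wrapper over libcurl or neon ends up in libpq; keeping the transport injectable also makes the polling logic easy to exercise from a test suite like the one in this patch.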
You could punt and let the application implement that stuff. I'm
imagining that the application code would look something like this:
conn = PQconnectStartParams(...);
for (;;)
{
status = PQconnectPoll(conn)
switch (status)
{
case CONNECTION_SASL_TOKEN_REQUIRED:
/* open a browser for the user, get token */
token = open_browser()
PQauthResponse(token);
break;
...
}
}
It would be nice to have a simple default implementation, though, for
psql and all the other client applications that come with PostgreSQL itself.
If you've read this far, thank you for your interest, and I hope you
enjoy playing with it!
A few small things caught my eye in the backend oauth_exchange function:
+ /* Handle the client's initial message. */
+ p = strdup(input);
this strdup() should be pstrdup().
In the same function, there are a bunch of reports like this:
ereport(ERROR,
+        (errcode(ERRCODE_PROTOCOL_VIOLATION),
+         errmsg("malformed OAUTHBEARER message"),
+         errdetail("Comma expected, but found character \"%s\".",
+                   sanitize_char(*p))));
I don't think the double quotes are needed here, because sanitize_char
will return quotes if it's a single character. So it would end up
looking like this: ... found character "'x'".
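A quick stand-alone illustration of the point, using a Python stand-in for sanitize_char (hypothetical; the real function lives in the backend and quotes printable characters itself):

```python
def sanitize_char(c):
    """Stand-in for the backend helper: printable characters come back
    already wrapped in single quotes; others become a hex escape."""
    return "'%s'" % c if c.isprintable() else "0x%02x" % ord(c)


# The %s value is already quoted, so the errdetail's outer double quotes
# double up around it:
msg = 'Comma expected, but found character "%s".' % sanitize_char("x")
# msg is now: Comma expected, but found character "'x'".
```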
- Heikki
On Fri, 2021-06-18 at 11:31 +0300, Heikki Linnakangas wrote:
On 08/06/2021 19:37, Jacob Champion wrote:
We've been working on ways to expand the list of third-party auth
methods that Postgres provides. Some example use cases might be "I want
to let anyone with a Google account read this table" or "let anyone who
belongs to this GitHub organization connect as a superuser".
Cool!
Glad you think so! :D
The iddawc dependency for client-side OAuth was extremely helpful to
develop this proof of concept quickly, but I don't think it would be an
appropriate component to build a real feature on. It's extremely
heavyweight -- it incorporates a huge stack of dependencies, including
a logging framework and a web server, to implement features we would
probably never use -- and it's fairly difficult to debug in practice.
If a device authorization flow were the only thing that libpq needed to
support natively, I think we should just depend on a widely used HTTP
client, like libcurl or neon, and implement the minimum spec directly
against the existing test suite.

You could punt and let the application implement that stuff. I'm
imagining that the application code would look something like this:

conn = PQconnectStartParams(...);
for (;;)
{
status = PQconnectPoll(conn)
switch (status)
{
case CONNECTION_SASL_TOKEN_REQUIRED:
/* open a browser for the user, get token */
token = open_browser()
PQauthResponse(token);
break;
...
}
}
I was toying with the idea of having a callback for libpq clients,
where they could take full control of the OAuth flow if they wanted to.
Doing it inline with PQconnectPoll seems like it would work too. It has
a couple of drawbacks that I can see:
- If a client isn't currently using a poll loop, they'd have to switch
to one to be able to use OAuth connections. Not a difficult change, but
considering all the other hurdles to making this work, I'm hoping to
minimize the hoop-jumping.
- A client would still have to receive a bunch of OAuth parameters from
some new libpq API in order to construct the correct URL to visit, so
the overall complexity for implementers might be higher than if we just
passed those params directly in a callback.
It would be nice to have a simple default implementation, though, for
psql and all the other client applications that come with PostgreSQL itself.
I agree. I think having a bare-bones implementation in libpq itself
would make initial adoption *much* easier, and then if specific
applications wanted to have richer control over an authorization flow,
then they could implement that themselves with the aforementioned
callback.
The Device Authorization flow was the most minimal working
implementation I could find, since it doesn't require a web browser on
the system, just the ability to print a prompt to the console. But if
anyone knows of a better flow for this use case, I'm all ears.
If you've read this far, thank you for your interest, and I hope you
enjoy playing with it!

A few small things caught my eye in the backend oauth_exchange function:

+       /* Handle the client's initial message. */
+       p = strdup(input);

this strdup() should be pstrdup().
Thanks, I'll fix that in the next re-roll.
In the same function, there are a bunch of reports like this:
ereport(ERROR,
+               (errcode(ERRCODE_PROTOCOL_VIOLATION),
+                errmsg("malformed OAUTHBEARER message"),
+                errdetail("Comma expected, but found character \"%s\".",
+                          sanitize_char(*p))));

I don't think the double quotes are needed here, because sanitize_char
will return quotes if it's a single character. So it would end up
looking like this: ... found character "'x'".
I'll fix this too. Thanks!
--Jacob
On Fri, 2021-06-18 at 13:07 +0900, Michael Paquier wrote:
On Tue, Jun 08, 2021 at 04:37:46PM +0000, Jacob Champion wrote:
1. Prep
0001 decouples the SASL code from the SCRAM implementation.
0002 makes it possible to use common/jsonapi from the frontend.
0003 lets the json_errdetail() result be freed, to avoid leaks.

The first three patches are, hopefully, generally useful outside of
this implementation, and I'll plan to register them in the next
commitfest. The middle two patches are the "interesting" pieces, and
I've split them into client and server for ease of understanding,
though neither is particularly useful without the other.

Beginning with the beginning, could you spawn two threads for the
jsonapi rework and the SASL/SCRAM business?
Done [1, 2]. I've copied your comments into those threads with my
responses, and I'll have them registered in commitfest shortly.
Thanks!
--Jacob
[1]: /messages/by-id/3d2a6f5d50e741117d6baf83eb67ebf1a8a35a11.camel@vmware.com
[2]: /messages/by-id/a250d475ba1c0cc0efb7dfec8e538fcc77cdcb8e.camel@vmware.com
On Tue, Jun 22, 2021 at 11:26:03PM +0000, Jacob Champion wrote:
Done [1, 2]. I've copied your comments into those threads with my
responses, and I'll have them registered in commitfest shortly.
Thanks!
--
Michael
On Tue, 2021-06-22 at 23:22 +0000, Jacob Champion wrote:
On Fri, 2021-06-18 at 11:31 +0300, Heikki Linnakangas wrote:
A few small things caught my eye in the backend oauth_exchange function:
+       /* Handle the client's initial message. */
+       p = strdup(input);

this strdup() should be pstrdup().
Thanks, I'll fix that in the next re-roll.
In the same function, there are a bunch of reports like this:
ereport(ERROR,
+               (errcode(ERRCODE_PROTOCOL_VIOLATION),
+                errmsg("malformed OAUTHBEARER message"),
+                errdetail("Comma expected, but found character \"%s\".",
+                          sanitize_char(*p))));

I don't think the double quotes are needed here, because sanitize_char
will return quotes if it's a single character. So it would end up
looking like this: ... found character "'x'".

I'll fix this too. Thanks!
v2, attached, incorporates Heikki's suggested fixes and also rebases on
top of latest HEAD, which had the SASL refactoring changes committed
last month.
The biggest change from the last patchset is 0001, an attempt at
enabling jsonapi in the frontend without the use of palloc(), based on
suggestions by Michael and Tom from last commitfest. I've also made
some improvements to the pytest suite. No major changes to the OAuth
implementation yet.
--Jacob
Attachments:
Attachment: v2-0001-common-jsonapi-support-FRONTEND-clients.patch (text/x-patch)
From 8c4b82940efb7e0f0f33ac915d5f7969a36e3644 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Mon, 3 May 2021 15:38:26 -0700
Subject: [PATCH v2 1/5] common/jsonapi: support FRONTEND clients
Based on a patch by Michael Paquier.
For frontend code, use PQExpBuffer instead of StringInfo. This requires
us to track allocation failures so that we can return JSON_OUT_OF_MEMORY
as needed. json_errdetail() now allocates its error message inside
memory owned by the JsonLexContext, so clients don't need to worry about
freeing it.
For convenience, the backend now has destroyJsonLexContext() to mirror
other create/destroy APIs. The frontend has init/term versions of the
API to handle stack-allocated JsonLexContexts.
We can now partially revert b44669b2ca, now that json_errdetail() works
correctly.
---
src/backend/utils/adt/jsonfuncs.c | 4 +-
src/bin/pg_verifybackup/parse_manifest.c | 13 +-
src/bin/pg_verifybackup/t/005_bad_manifest.pl | 2 +-
src/common/Makefile | 2 +-
src/common/jsonapi.c | 290 +++++++++++++-----
src/include/common/jsonapi.h | 47 ++-
6 files changed, 270 insertions(+), 88 deletions(-)
diff --git a/src/backend/utils/adt/jsonfuncs.c b/src/backend/utils/adt/jsonfuncs.c
index 5fd54b64b5..fa39751188 100644
--- a/src/backend/utils/adt/jsonfuncs.c
+++ b/src/backend/utils/adt/jsonfuncs.c
@@ -723,9 +723,7 @@ json_object_keys(PG_FUNCTION_ARGS)
pg_parse_json_or_ereport(lex, sem);
/* keys are now in state->result */
- pfree(lex->strval->data);
- pfree(lex->strval);
- pfree(lex);
+ destroyJsonLexContext(lex);
pfree(sem);
MemoryContextSwitchTo(oldcontext);
diff --git a/src/bin/pg_verifybackup/parse_manifest.c b/src/bin/pg_verifybackup/parse_manifest.c
index c7ccc78c70..6cedb7435f 100644
--- a/src/bin/pg_verifybackup/parse_manifest.c
+++ b/src/bin/pg_verifybackup/parse_manifest.c
@@ -119,7 +119,7 @@ void
json_parse_manifest(JsonManifestParseContext *context, char *buffer,
size_t size)
{
- JsonLexContext *lex;
+ JsonLexContext lex = {0};
JsonParseErrorType json_error;
JsonSemAction sem;
JsonManifestParseState parse;
@@ -129,8 +129,8 @@ json_parse_manifest(JsonManifestParseContext *context, char *buffer,
parse.state = JM_EXPECT_TOPLEVEL_START;
parse.saw_version_field = false;
- /* Create a JSON lexing context. */
- lex = makeJsonLexContextCstringLen(buffer, size, PG_UTF8, true);
+ /* Initialize a JSON lexing context. */
+ initJsonLexContextCstringLen(&lex, buffer, size, PG_UTF8, true);
/* Set up semantic actions. */
sem.semstate = &parse;
@@ -145,14 +145,17 @@ json_parse_manifest(JsonManifestParseContext *context, char *buffer,
sem.scalar = json_manifest_scalar;
/* Run the actual JSON parser. */
- json_error = pg_parse_json(lex, &sem);
+ json_error = pg_parse_json(&lex, &sem);
if (json_error != JSON_SUCCESS)
- json_manifest_parse_failure(context, "parsing failed");
+ json_manifest_parse_failure(context, json_errdetail(json_error, &lex));
if (parse.state != JM_EXPECT_EOF)
json_manifest_parse_failure(context, "manifest ended unexpectedly");
/* Verify the manifest checksum. */
verify_manifest_checksum(&parse, buffer, size);
+
+ /* Clean up. */
+ termJsonLexContext(&lex);
}
/*
diff --git a/src/bin/pg_verifybackup/t/005_bad_manifest.pl b/src/bin/pg_verifybackup/t/005_bad_manifest.pl
index 4f5b8f5a49..9f8a100a71 100644
--- a/src/bin/pg_verifybackup/t/005_bad_manifest.pl
+++ b/src/bin/pg_verifybackup/t/005_bad_manifest.pl
@@ -16,7 +16,7 @@ my $tempdir = TestLib::tempdir;
test_bad_manifest(
'input string ended unexpectedly',
- qr/could not parse backup manifest: parsing failed/,
+ qr/could not parse backup manifest: The input string ended unexpectedly/,
<<EOM);
{
EOM
diff --git a/src/common/Makefile b/src/common/Makefile
index 880722fcf5..5ecb09a8c4 100644
--- a/src/common/Makefile
+++ b/src/common/Makefile
@@ -40,7 +40,7 @@ override CPPFLAGS += -DVAL_LDFLAGS_EX="\"$(LDFLAGS_EX)\""
override CPPFLAGS += -DVAL_LDFLAGS_SL="\"$(LDFLAGS_SL)\""
override CPPFLAGS += -DVAL_LIBS="\"$(LIBS)\""
-override CPPFLAGS := -DFRONTEND -I. -I$(top_srcdir)/src/common $(CPPFLAGS)
+override CPPFLAGS := -DFRONTEND -I. -I$(top_srcdir)/src/common -I$(libpq_srcdir) $(CPPFLAGS)
LIBS += $(PTHREAD_LIBS)
# If you add objects here, see also src/tools/msvc/Mkvcbuild.pm
diff --git a/src/common/jsonapi.c b/src/common/jsonapi.c
index 5504072b4f..3a9620f739 100644
--- a/src/common/jsonapi.c
+++ b/src/common/jsonapi.c
@@ -20,10 +20,39 @@
#include "common/jsonapi.h"
#include "mb/pg_wchar.h"
-#ifndef FRONTEND
+#ifdef FRONTEND
+#include "pqexpbuffer.h"
+#else
+#include "lib/stringinfo.h"
#include "miscadmin.h"
#endif
+/*
+ * In backend, we will use palloc/pfree along with StringInfo. In frontend, use
+ * malloc and PQExpBuffer, and return JSON_OUT_OF_MEMORY on out-of-memory.
+ */
+#ifdef FRONTEND
+
+#define STRDUP(s) strdup(s)
+#define ALLOC(size) malloc(size)
+
+#define appendStrVal appendPQExpBuffer
+#define appendStrValChar appendPQExpBufferChar
+#define createStrVal createPQExpBuffer
+#define resetStrVal resetPQExpBuffer
+
+#else /* !FRONTEND */
+
+#define STRDUP(s) pstrdup(s)
+#define ALLOC(size) palloc(size)
+
+#define appendStrVal appendStringInfo
+#define appendStrValChar appendStringInfoChar
+#define createStrVal makeStringInfo
+#define resetStrVal resetStringInfo
+
+#endif
+
/*
* The context of the parser is maintained by the recursive descent
* mechanism, but is passed explicitly to the error reporting routine
@@ -132,10 +161,12 @@ IsValidJsonNumber(const char *str, int len)
return (!numeric_error) && (total_len == dummy_lex.input_length);
}
+#ifndef FRONTEND
+
/*
* makeJsonLexContextCstringLen
*
- * lex constructor, with or without StringInfo object for de-escaped lexemes.
+ * lex constructor, with or without a string object for de-escaped lexemes.
*
* Without is better as it makes the processing faster, so only make one
* if really required.
@@ -145,13 +176,66 @@ makeJsonLexContextCstringLen(char *json, int len, int encoding, bool need_escape
{
JsonLexContext *lex = palloc0(sizeof(JsonLexContext));
+ initJsonLexContextCstringLen(lex, json, len, encoding, need_escapes);
+
+ return lex;
+}
+
+void
+destroyJsonLexContext(JsonLexContext *lex)
+{
+ termJsonLexContext(lex);
+ pfree(lex);
+}
+
+#endif /* !FRONTEND */
+
+void
+initJsonLexContextCstringLen(JsonLexContext *lex, char *json, int len, int encoding, bool need_escapes)
+{
lex->input = lex->token_terminator = lex->line_start = json;
lex->line_number = 1;
lex->input_length = len;
lex->input_encoding = encoding;
- if (need_escapes)
- lex->strval = makeStringInfo();
- return lex;
+ lex->parse_strval = need_escapes;
+ if (lex->parse_strval)
+ {
+ /*
+ * This call can fail in FRONTEND code. We defer error handling to time
+ * of use (json_lex_string()) since there's no way to signal failure
+ * here, and we might not need to parse any strings anyway.
+ */
+ lex->strval = createStrVal();
+ }
+ lex->errormsg = NULL;
+}
+
+void
+termJsonLexContext(JsonLexContext *lex)
+{
+ static const JsonLexContext empty = {0};
+
+ if (lex->strval)
+ {
+#ifdef FRONTEND
+ destroyPQExpBuffer(lex->strval);
+#else
+ pfree(lex->strval->data);
+ pfree(lex->strval);
+#endif
+ }
+
+ if (lex->errormsg)
+ {
+#ifdef FRONTEND
+ destroyPQExpBuffer(lex->errormsg);
+#else
+ pfree(lex->errormsg->data);
+ pfree(lex->errormsg);
+#endif
+ }
+
+ *lex = empty;
}
/*
@@ -217,7 +301,7 @@ json_count_array_elements(JsonLexContext *lex, int *elements)
* etc, so doing this with a copy makes that safe.
*/
 	memcpy(&copylex, lex, sizeof(JsonLexContext));
- copylex.strval = NULL; /* not interested in values here */
+ copylex.parse_strval = false; /* not interested in values here */
copylex.lex_level++;
count = 0;
@@ -279,14 +363,21 @@ parse_scalar(JsonLexContext *lex, JsonSemAction *sem)
/* extract the de-escaped string value, or the raw lexeme */
if (lex_peek(lex) == JSON_TOKEN_STRING)
{
- if (lex->strval != NULL)
- val = pstrdup(lex->strval->data);
+ if (lex->parse_strval)
+ {
+ val = STRDUP(lex->strval->data);
+ if (val == NULL)
+ return JSON_OUT_OF_MEMORY;
+ }
}
else
{
int len = (lex->token_terminator - lex->token_start);
- val = palloc(len + 1);
+ val = ALLOC(len + 1);
+ if (val == NULL)
+ return JSON_OUT_OF_MEMORY;
+
memcpy(val, lex->token_start, len);
val[len] = '\0';
}
@@ -320,8 +411,12 @@ parse_object_field(JsonLexContext *lex, JsonSemAction *sem)
if (lex_peek(lex) != JSON_TOKEN_STRING)
return report_parse_error(JSON_PARSE_STRING, lex);
- if ((ostart != NULL || oend != NULL) && lex->strval != NULL)
- fname = pstrdup(lex->strval->data);
+ if ((ostart != NULL || oend != NULL) && lex->parse_strval)
+ {
+ fname = STRDUP(lex->strval->data);
+ if (fname == NULL)
+ return JSON_OUT_OF_MEMORY;
+ }
result = json_lex(lex);
if (result != JSON_SUCCESS)
return result;
@@ -368,6 +463,10 @@ parse_object(JsonLexContext *lex, JsonSemAction *sem)
JsonParseErrorType result;
#ifndef FRONTEND
+ /*
+ * TODO: clients need some way to put a bound on stack growth. Parse level
+ * limits maybe?
+ */
check_stack_depth();
#endif
@@ -676,8 +775,15 @@ json_lex_string(JsonLexContext *lex)
int len;
int hi_surrogate = -1;
- if (lex->strval != NULL)
- resetStringInfo(lex->strval);
+ if (lex->parse_strval)
+ {
+#ifdef FRONTEND
+ /* make sure initialization succeeded */
+ if (lex->strval == NULL)
+ return JSON_OUT_OF_MEMORY;
+#endif
+ resetStrVal(lex->strval);
+ }
Assert(lex->input_length > 0);
s = lex->token_start;
@@ -737,7 +843,7 @@ json_lex_string(JsonLexContext *lex)
return JSON_UNICODE_ESCAPE_FORMAT;
}
}
- if (lex->strval != NULL)
+ if (lex->parse_strval)
{
/*
* Combine surrogate pairs.
@@ -797,19 +903,19 @@ json_lex_string(JsonLexContext *lex)
unicode_to_utf8(ch, (unsigned char *) utf8str);
utf8len = pg_utf_mblen((unsigned char *) utf8str);
- appendBinaryStringInfo(lex->strval, utf8str, utf8len);
+ appendBinaryPQExpBuffer(lex->strval, utf8str, utf8len);
}
else if (ch <= 0x007f)
{
/* The ASCII range is the same in all encodings */
- appendStringInfoChar(lex->strval, (char) ch);
+ appendPQExpBufferChar(lex->strval, (char) ch);
}
else
return JSON_UNICODE_HIGH_ESCAPE;
#endif /* FRONTEND */
}
}
- else if (lex->strval != NULL)
+ else if (lex->parse_strval)
{
if (hi_surrogate != -1)
return JSON_UNICODE_LOW_SURROGATE;
@@ -819,22 +925,22 @@ json_lex_string(JsonLexContext *lex)
case '"':
case '\\':
case '/':
- appendStringInfoChar(lex->strval, *s);
+ appendStrValChar(lex->strval, *s);
break;
case 'b':
- appendStringInfoChar(lex->strval, '\b');
+ appendStrValChar(lex->strval, '\b');
break;
case 'f':
- appendStringInfoChar(lex->strval, '\f');
+ appendStrValChar(lex->strval, '\f');
break;
case 'n':
- appendStringInfoChar(lex->strval, '\n');
+ appendStrValChar(lex->strval, '\n');
break;
case 'r':
- appendStringInfoChar(lex->strval, '\r');
+ appendStrValChar(lex->strval, '\r');
break;
case 't':
- appendStringInfoChar(lex->strval, '\t');
+ appendStrValChar(lex->strval, '\t');
break;
default:
/* Not a valid string escape, so signal error. */
@@ -858,12 +964,12 @@ json_lex_string(JsonLexContext *lex)
}
}
- else if (lex->strval != NULL)
+ else if (lex->parse_strval)
{
if (hi_surrogate != -1)
return JSON_UNICODE_LOW_SURROGATE;
- appendStringInfoChar(lex->strval, *s);
+ appendStrValChar(lex->strval, *s);
}
}
@@ -871,6 +977,11 @@ json_lex_string(JsonLexContext *lex)
if (hi_surrogate != -1)
return JSON_UNICODE_LOW_SURROGATE;
+#ifdef FRONTEND
+ if (lex->parse_strval && PQExpBufferBroken(lex->strval))
+ return JSON_OUT_OF_MEMORY;
+#endif
+
/* Hooray, we found the end of the string! */
lex->prev_token_terminator = lex->token_terminator;
lex->token_terminator = s + 1;
@@ -1043,72 +1154,93 @@ report_parse_error(JsonParseContext ctx, JsonLexContext *lex)
return JSON_SUCCESS; /* silence stupider compilers */
}
-
-#ifndef FRONTEND
-/*
- * Extract the current token from a lexing context, for error reporting.
- */
-static char *
-extract_token(JsonLexContext *lex)
-{
- int toklen = lex->token_terminator - lex->token_start;
- char *token = palloc(toklen + 1);
-
- memcpy(token, lex->token_start, toklen);
- token[toklen] = '\0';
- return token;
-}
-
/*
* Construct a detail message for a JSON error.
*
- * Note that the error message generated by this routine may not be
- * palloc'd, making it unsafe for frontend code as there is no way to
- * know if this can be safery pfree'd or not.
+ * The returned allocation is either static or owned by the JsonLexContext and
+ * should not be freed.
*/
char *
json_errdetail(JsonParseErrorType error, JsonLexContext *lex)
{
+ int toklen = lex->token_terminator - lex->token_start;
+
+ if (error == JSON_OUT_OF_MEMORY)
+ {
+ /* Short circuit. Allocating anything for this case is unhelpful. */
+ return _("out of memory");
+ }
+
+ if (lex->errormsg)
+ resetStrVal(lex->errormsg);
+ else
+ lex->errormsg = createStrVal();
+
switch (error)
{
case JSON_SUCCESS:
/* fall through to the error code after switch */
break;
case JSON_ESCAPING_INVALID:
- return psprintf(_("Escape sequence \"\\%s\" is invalid."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Escape sequence \"\\%.*s\" is invalid."),
+ toklen, lex->token_start);
+ break;
case JSON_ESCAPING_REQUIRED:
- return psprintf(_("Character with value 0x%02x must be escaped."),
- (unsigned char) *(lex->token_terminator));
+ appendStrVal(lex->errormsg,
+ _("Character with value 0x%02x must be escaped."),
+ (unsigned char) *(lex->token_terminator));
+ break;
case JSON_EXPECTED_END:
- return psprintf(_("Expected end of input, but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected end of input, but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_ARRAY_FIRST:
- return psprintf(_("Expected array element or \"]\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected array element or \"]\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_ARRAY_NEXT:
- return psprintf(_("Expected \",\" or \"]\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected \",\" or \"]\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_COLON:
- return psprintf(_("Expected \":\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected \":\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_JSON:
- return psprintf(_("Expected JSON value, but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected JSON value, but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_MORE:
return _("The input string ended unexpectedly.");
case JSON_EXPECTED_OBJECT_FIRST:
- return psprintf(_("Expected string or \"}\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected string or \"}\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_OBJECT_NEXT:
- return psprintf(_("Expected \",\" or \"}\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected \",\" or \"}\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_STRING:
- return psprintf(_("Expected string, but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected string, but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_INVALID_TOKEN:
- return psprintf(_("Token \"%s\" is invalid."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Token \"%.*s\" is invalid."),
+ toklen, lex->token_start);
+ break;
+ case JSON_OUT_OF_MEMORY:
+ /* should have been handled above; use the error path */
+ break;
case JSON_UNICODE_CODE_POINT_ZERO:
return _("\\u0000 cannot be converted to text.");
case JSON_UNICODE_ESCAPE_FORMAT:
@@ -1122,12 +1254,22 @@ json_errdetail(JsonParseErrorType error, JsonLexContext *lex)
return _("Unicode low surrogate must follow a high surrogate.");
}
- /*
- * We don't use a default: case, so that the compiler will warn about
- * unhandled enum values. But this needs to be here anyway to cover the
- * possibility of an incorrect input.
- */
- elog(ERROR, "unexpected json parse error type: %d", (int) error);
- return NULL;
-}
+ /* Note that lex->errormsg can be NULL in FRONTEND code. */
+ if (lex->errormsg && !lex->errormsg->data[0])
+ {
+ /*
+ * We don't use a default: case, so that the compiler will warn about
+ * unhandled enum values. But this needs to be here anyway to cover the
+ * possibility of an incorrect input.
+ */
+ appendStrVal(lex->errormsg,
+ "unexpected json parse error type: %d", (int) error);
+ }
+
+#ifdef FRONTEND
+ if (PQExpBufferBroken(lex->errormsg))
+ return _("out of memory while constructing error description");
#endif
+
+ return lex->errormsg->data;
+}
diff --git a/src/include/common/jsonapi.h b/src/include/common/jsonapi.h
index ec3dfce9c3..dc71ab2cd3 100644
--- a/src/include/common/jsonapi.h
+++ b/src/include/common/jsonapi.h
@@ -14,8 +14,6 @@
#ifndef JSONAPI_H
#define JSONAPI_H
-#include "lib/stringinfo.h"
-
typedef enum
{
JSON_TOKEN_INVALID,
@@ -48,6 +46,7 @@ typedef enum
JSON_EXPECTED_OBJECT_NEXT,
JSON_EXPECTED_STRING,
JSON_INVALID_TOKEN,
+ JSON_OUT_OF_MEMORY,
JSON_UNICODE_CODE_POINT_ZERO,
JSON_UNICODE_ESCAPE_FORMAT,
JSON_UNICODE_HIGH_ESCAPE,
@@ -55,6 +54,17 @@ typedef enum
JSON_UNICODE_LOW_SURROGATE
} JsonParseErrorType;
+/*
+ * Don't depend on the internal type header for strval; if callers need access
+ * then they can include the appropriate header themselves.
+ */
+#ifdef FRONTEND
+#define StrValType PQExpBufferData
+#else
+#define StrValType StringInfoData
+#endif
+
+typedef struct StrValType StrValType;
/*
* All the fields in this structure should be treated as read-only.
@@ -81,7 +91,9 @@ typedef struct JsonLexContext
int lex_level;
int line_number; /* line number, starting from 1 */
char *line_start; /* where that line starts within input */
- StringInfo strval;
+ bool parse_strval;
+ StrValType *strval; /* only used if parse_strval == true */
+ StrValType *errormsg;
} JsonLexContext;
typedef void (*json_struct_action) (void *state);
@@ -141,9 +153,10 @@ extern JsonSemAction nullSemAction;
*/
extern JsonParseErrorType json_count_array_elements(JsonLexContext *lex,
int *elements);
+#ifndef FRONTEND
/*
- * constructor for JsonLexContext, with or without strval element.
+ * allocating constructor for JsonLexContext, with or without strval element.
* If supplied, the strval element will contain a de-escaped version of
* the lexeme. However, doing this imposes a performance penalty, so
* it should be avoided if the de-escaped lexeme is not required.
@@ -153,6 +166,32 @@ extern JsonLexContext *makeJsonLexContextCstringLen(char *json,
int encoding,
bool need_escapes);
+/*
+ * Counterpart to makeJsonLexContextCstringLen(): clears and deallocates lex.
+ * The context pointer should not be used after this call.
+ */
+extern void destroyJsonLexContext(JsonLexContext *lex);
+
+#endif /* !FRONTEND */
+
+/*
+ * stack constructor for JsonLexContext, with or without strval element.
+ * If supplied, the strval element will contain a de-escaped version of
+ * the lexeme. However, doing this imposes a performance penalty, so
+ * it should be avoided if the de-escaped lexeme is not required.
+ */
+extern void initJsonLexContextCstringLen(JsonLexContext *lex,
+ char *json,
+ int len,
+ int encoding,
+ bool need_escapes);
+
+/*
+ * Counterpart to initJsonLexContextCstringLen(): clears the contents of lex,
+ * but does not deallocate lex itself.
+ */
+extern void termJsonLexContext(JsonLexContext *lex);
+
/* lex one token */
extern JsonParseErrorType json_lex(JsonLexContext *lex);
--
2.25.1
Attachment: v2-0002-libpq-add-OAUTHBEARER-SASL-mechanism.patch (text/x-patch)
From 52ac4bd25ca19735eb2bd863e8b1549ccbe6560a Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 13 Apr 2021 10:27:27 -0700
Subject: [PATCH v2 2/5] libpq: add OAUTHBEARER SASL mechanism
DO NOT USE THIS PROOF OF CONCEPT IN PRODUCTION.
Implement OAUTHBEARER (RFC 7628) and OAuth 2.0 Device Authorization
Grants (RFC 8628) on the client side. When speaking to a OAuth-enabled
server, it looks a bit like this:
$ psql 'host=example.org oauth_client_id=f02c6361-0635-...'
Visit https://oauth.example.org/login and enter the code: FPQ2-M4BG
The OAuth issuer must support device authorization. No other OAuth flows
are currently implemented.
The client implementation requires libiddawc and its development
headers. Configure --with-oauth (and --with-includes/--with-libraries to
point at the iddawc installation, if it's in a custom location).
Several TODOs:
- don't retry forever if the server won't accept our token
- perform several sanity checks on the OAuth issuer's responses
- handle cases where the client has been set up with an issuer and
scope, but the Postgres server wants to use something different
- improve error debuggability during the OAuth handshake
- ...and more.
---
configure | 100 ++++
configure.ac | 19 +
src/Makefile.global.in | 1 +
src/include/common/oauth-common.h | 19 +
src/include/pg_config.h.in | 6 +
src/interfaces/libpq/Makefile | 7 +-
src/interfaces/libpq/fe-auth-oauth.c | 745 +++++++++++++++++++++++++++
src/interfaces/libpq/fe-auth-sasl.h | 5 +-
src/interfaces/libpq/fe-auth-scram.c | 6 +-
src/interfaces/libpq/fe-auth.c | 42 +-
src/interfaces/libpq/fe-auth.h | 3 +
src/interfaces/libpq/fe-connect.c | 38 ++
src/interfaces/libpq/libpq-int.h | 8 +
13 files changed, 980 insertions(+), 19 deletions(-)
create mode 100644 src/include/common/oauth-common.h
create mode 100644 src/interfaces/libpq/fe-auth-oauth.c
diff --git a/configure b/configure
index 7542fe30a1..2ddbe9a1d9 100755
--- a/configure
+++ b/configure
@@ -713,6 +713,7 @@ with_uuid
with_readline
with_systemd
with_selinux
+with_oauth
with_ldap
with_krb_srvnam
krb_srvtab
@@ -856,6 +857,7 @@ with_krb_srvnam
with_pam
with_bsd_auth
with_ldap
+with_oauth
with_bonjour
with_selinux
with_systemd
@@ -1562,6 +1564,7 @@ Optional Packages:
--with-pam build with PAM support
--with-bsd-auth build with BSD Authentication support
--with-ldap build with LDAP support
+ --with-oauth build with OAuth 2.0 support
--with-bonjour build with Bonjour support
--with-selinux build with SELinux support
--with-systemd build with systemd support
@@ -8144,6 +8147,42 @@ $as_echo "$with_ldap" >&6; }
+#
+# OAuth 2.0
+#
+{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether to build with OAuth support" >&5
+$as_echo_n "checking whether to build with OAuth support... " >&6; }
+
+
+
+# Check whether --with-oauth was given.
+if test "${with_oauth+set}" = set; then :
+ withval=$with_oauth;
+ case $withval in
+ yes)
+
+$as_echo "#define USE_OAUTH 1" >>confdefs.h
+
+ ;;
+ no)
+ :
+ ;;
+ *)
+ as_fn_error $? "no argument expected for --with-oauth option" "$LINENO" 5
+ ;;
+ esac
+
+else
+ with_oauth=no
+
+fi
+
+
+{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $with_oauth" >&5
+$as_echo "$with_oauth" >&6; }
+
+
+
#
# Bonjour
#
@@ -13084,6 +13123,56 @@ fi
+if test "$with_oauth" = yes ; then
+ { $as_echo "$as_me:${as_lineno-$LINENO}: checking for i_init_session in -liddawc" >&5
+$as_echo_n "checking for i_init_session in -liddawc... " >&6; }
+if ${ac_cv_lib_iddawc_i_init_session+:} false; then :
+ $as_echo_n "(cached) " >&6
+else
+ ac_check_lib_save_LIBS=$LIBS
+LIBS="-liddawc $LIBS"
+cat confdefs.h - <<_ACEOF >conftest.$ac_ext
+/* end confdefs.h. */
+
+/* Override any GCC internal prototype to avoid an error.
+ Use char because int might match the return type of a GCC
+ builtin and then its argument prototype would still apply. */
+#ifdef __cplusplus
+extern "C"
+#endif
+char i_init_session ();
+int
+main ()
+{
+return i_init_session ();
+ ;
+ return 0;
+}
+_ACEOF
+if ac_fn_c_try_link "$LINENO"; then :
+ ac_cv_lib_iddawc_i_init_session=yes
+else
+ ac_cv_lib_iddawc_i_init_session=no
+fi
+rm -f core conftest.err conftest.$ac_objext \
+ conftest$ac_exeext conftest.$ac_ext
+LIBS=$ac_check_lib_save_LIBS
+fi
+{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_iddawc_i_init_session" >&5
+$as_echo "$ac_cv_lib_iddawc_i_init_session" >&6; }
+if test "x$ac_cv_lib_iddawc_i_init_session" = xyes; then :
+ cat >>confdefs.h <<_ACEOF
+#define HAVE_LIBIDDAWC 1
+_ACEOF
+
+ LIBS="-liddawc $LIBS"
+
+else
+ as_fn_error $? "library 'iddawc' is required for OAuth support" "$LINENO" 5
+fi
+
+fi
+
# for contrib/sepgsql
if test "$with_selinux" = yes; then
{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for security_compute_create_name in -lselinux" >&5
@@ -13978,6 +14067,17 @@ fi
done
+fi
+
+if test "$with_oauth" != no; then
+ ac_fn_c_check_header_mongrel "$LINENO" "iddawc.h" "ac_cv_header_iddawc_h" "$ac_includes_default"
+if test "x$ac_cv_header_iddawc_h" = xyes; then :
+
+else
+ as_fn_error $? "header file <iddawc.h> is required for OAuth" "$LINENO" 5
+fi
+
+
fi
if test "$PORTNAME" = "win32" ; then
diff --git a/configure.ac b/configure.ac
index ed3cdb9a8e..22026476d9 100644
--- a/configure.ac
+++ b/configure.ac
@@ -851,6 +851,17 @@ AC_MSG_RESULT([$with_ldap])
AC_SUBST(with_ldap)
+#
+# OAuth 2.0
+#
+AC_MSG_CHECKING([whether to build with OAuth support])
+PGAC_ARG_BOOL(with, oauth, no,
+ [build with OAuth 2.0 support],
+ [AC_DEFINE([USE_OAUTH], 1, [Define to 1 to build with OAuth 2.0 support. (--with-oauth)])])
+AC_MSG_RESULT([$with_oauth])
+AC_SUBST(with_oauth)
+
+
#
# Bonjour
#
@@ -1321,6 +1332,10 @@ fi
AC_SUBST(LDAP_LIBS_FE)
AC_SUBST(LDAP_LIBS_BE)
+if test "$with_oauth" = yes ; then
+ AC_CHECK_LIB(iddawc, i_init_session, [], [AC_MSG_ERROR([library 'iddawc' is required for OAuth support])])
+fi
+
# for contrib/sepgsql
if test "$with_selinux" = yes; then
AC_CHECK_LIB(selinux, security_compute_create_name, [],
@@ -1531,6 +1546,10 @@ elif test "$with_uuid" = ossp ; then
[AC_MSG_ERROR([header file <ossp/uuid.h> or <uuid.h> is required for OSSP UUID])])])
fi
+if test "$with_oauth" != no; then
+ AC_CHECK_HEADER(iddawc.h, [], [AC_MSG_ERROR([header file <iddawc.h> is required for OAuth])])
+fi
+
if test "$PORTNAME" = "win32" ; then
AC_CHECK_HEADERS(crtdefs.h)
fi
diff --git a/src/Makefile.global.in b/src/Makefile.global.in
index 6e2f224cc4..d67912711e 100644
--- a/src/Makefile.global.in
+++ b/src/Makefile.global.in
@@ -193,6 +193,7 @@ with_ldap = @with_ldap@
with_libxml = @with_libxml@
with_libxslt = @with_libxslt@
with_llvm = @with_llvm@
+with_oauth = @with_oauth@
with_system_tzdata = @with_system_tzdata@
with_uuid = @with_uuid@
with_zlib = @with_zlib@
diff --git a/src/include/common/oauth-common.h b/src/include/common/oauth-common.h
new file mode 100644
index 0000000000..3fa95ac7e8
--- /dev/null
+++ b/src/include/common/oauth-common.h
@@ -0,0 +1,19 @@
+/*-------------------------------------------------------------------------
+ *
+ * oauth-common.h
+ * Declarations for helper functions used for OAuth/OIDC authentication
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/common/oauth-common.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef OAUTH_COMMON_H
+#define OAUTH_COMMON_H
+
+/* Name of SASL mechanism per IANA */
+#define OAUTHBEARER_NAME "OAUTHBEARER"
+
+#endif /* OAUTH_COMMON_H */
diff --git a/src/include/pg_config.h.in b/src/include/pg_config.h.in
index 15ffdd895a..f82ab38536 100644
--- a/src/include/pg_config.h.in
+++ b/src/include/pg_config.h.in
@@ -331,6 +331,9 @@
/* Define to 1 if you have the `crypto' library (-lcrypto). */
#undef HAVE_LIBCRYPTO
+/* Define to 1 if you have the `iddawc' library (-liddawc). */
+#undef HAVE_LIBIDDAWC
+
/* Define to 1 if you have the `ldap' library (-lldap). */
#undef HAVE_LIBLDAP
@@ -926,6 +929,9 @@
/* Define to select named POSIX semaphores. */
#undef USE_NAMED_POSIX_SEMAPHORES
+/* Define to 1 to build with OAuth 2.0 support. (--with-oauth) */
+#undef USE_OAUTH
+
/* Define to 1 to build with OpenSSL support. (--with-ssl=openssl) */
#undef USE_OPENSSL
diff --git a/src/interfaces/libpq/Makefile b/src/interfaces/libpq/Makefile
index 7cbdeb589b..3cdf19294b 100644
--- a/src/interfaces/libpq/Makefile
+++ b/src/interfaces/libpq/Makefile
@@ -62,6 +62,11 @@ OBJS += \
fe-secure-gssapi.o
endif
+ifeq ($(with_oauth),yes)
+OBJS += \
+ fe-auth-oauth.o
+endif
+
ifeq ($(PORTNAME), cygwin)
override shlib = cyg$(NAME)$(DLSUFFIX)
endif
@@ -83,7 +88,7 @@ endif
# that are built correctly for use in a shlib.
SHLIB_LINK_INTERNAL = -lpgcommon_shlib -lpgport_shlib
ifneq ($(PORTNAME), win32)
-SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi_krb5 -lgss -lgssapi -lssl -lsocket -lnsl -lresolv -lintl -lm, $(LIBS)) $(LDAP_LIBS_FE) $(PTHREAD_LIBS)
+SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi_krb5 -lgss -lgssapi -lssl -liddawc -lsocket -lnsl -lresolv -lintl -lm, $(LIBS)) $(LDAP_LIBS_FE) $(PTHREAD_LIBS)
else
SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi32 -lssl -lsocket -lnsl -lresolv -lintl -lm $(PTHREAD_LIBS), $(LIBS)) $(LDAP_LIBS_FE)
endif
diff --git a/src/interfaces/libpq/fe-auth-oauth.c b/src/interfaces/libpq/fe-auth-oauth.c
new file mode 100644
index 0000000000..91d2c69f16
--- /dev/null
+++ b/src/interfaces/libpq/fe-auth-oauth.c
@@ -0,0 +1,745 @@
+/*-------------------------------------------------------------------------
+ *
+ * fe-auth-oauth.c
+ * The front-end (client) implementation of OAuth/OIDC authentication.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * IDENTIFICATION
+ * src/interfaces/libpq/fe-auth-oauth.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include <iddawc.h>
+
+#include "postgres_fe.h"
+
+#include "common/base64.h"
+#include "common/hmac.h"
+#include "common/jsonapi.h"
+#include "common/oauth-common.h"
+#include "fe-auth.h"
+#include "mb/pg_wchar.h"
+
+/* The exported OAuth callback mechanism. */
+static void *oauth_init(PGconn *conn, const char *password,
+ const char *sasl_mechanism);
+static void oauth_exchange(void *opaq, bool final,
+ char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success);
+static bool oauth_channel_bound(void *opaq);
+static void oauth_free(void *opaq);
+
+const pg_fe_sasl_mech pg_oauth_mech = {
+ oauth_init,
+ oauth_exchange,
+ oauth_channel_bound,
+ oauth_free,
+};
+
+typedef enum
+{
+ FE_OAUTH_INIT,
+ FE_OAUTH_BEARER_SENT,
+ FE_OAUTH_SERVER_ERROR,
+} fe_oauth_state_enum;
+
+typedef struct
+{
+ fe_oauth_state_enum state;
+
+ PGconn *conn;
+} fe_oauth_state;
+
+static void *
+oauth_init(PGconn *conn, const char *password,
+ const char *sasl_mechanism)
+{
+ fe_oauth_state *state;
+
+ /*
+ * We only support one SASL mechanism here; anything else is programmer
+ * error.
+ */
+ Assert(sasl_mechanism != NULL);
+ Assert(!strcmp(sasl_mechanism, OAUTHBEARER_NAME));
+
+ state = malloc(sizeof(*state));
+ if (!state)
+ return NULL;
+
+ state->state = FE_OAUTH_INIT;
+ state->conn = conn;
+
+ return state;
+}
+
+static const char *
+iddawc_error_string(int errcode)
+{
+ switch (errcode)
+ {
+ case I_OK:
+ return "I_OK";
+
+ case I_ERROR:
+ return "I_ERROR";
+
+ case I_ERROR_PARAM:
+ return "I_ERROR_PARAM";
+
+ case I_ERROR_MEMORY:
+ return "I_ERROR_MEMORY";
+
+ case I_ERROR_UNAUTHORIZED:
+ return "I_ERROR_UNAUTHORIZED";
+
+ case I_ERROR_SERVER:
+ return "I_ERROR_SERVER";
+ }
+
+ return "<unknown>";
+}
+
+static void
+iddawc_error(PGconn *conn, int errcode, const char *msg)
+{
+ appendPQExpBufferStr(&conn->errorMessage, libpq_gettext(msg));
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext(" (iddawc error %s)\n"),
+ iddawc_error_string(errcode));
+}
+
+static void
+iddawc_request_error(PGconn *conn, struct _i_session *i, int err, const char *msg)
+{
+ const char *error_code;
+ const char *desc;
+
+ appendPQExpBuffer(&conn->errorMessage, "%s: ", libpq_gettext(msg));
+
+ error_code = i_get_str_parameter(i, I_OPT_ERROR);
+ if (!error_code)
+ {
+ /*
+ * The server didn't give us any useful information, so just print the
+ * error code.
+ */
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("(iddawc error %s)\n"),
+ iddawc_error_string(err));
+ return;
+ }
+
+ /* If the server gave a string description, print that too. */
+ desc = i_get_str_parameter(i, I_OPT_ERROR_DESCRIPTION);
+ if (desc)
+ appendPQExpBuffer(&conn->errorMessage, "%s ", desc);
+
+ appendPQExpBuffer(&conn->errorMessage, "(%s)\n", error_code);
+}
+
+static char *
+get_auth_token(PGconn *conn)
+{
+ PQExpBuffer token_buf = NULL;
+ struct _i_session session;
+ int err;
+ int auth_method;
+ bool user_prompted = false;
+ const char *verification_uri;
+ const char *user_code;
+ const char *access_token;
+ const char *token_type;
+ char *token = NULL;
+
+ if (!conn->oauth_discovery_uri)
+ return strdup(""); /* ask the server for one */
+
+ i_init_session(&session);
+
+ if (!conn->oauth_client_id)
+ {
+ /* We can't talk to a server without a client identifier. */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("no oauth_client_id is set for the connection"));
+ goto cleanup;
+ }
+
+ token_buf = createPQExpBuffer();
+
+ if (!token_buf)
+ goto cleanup;
+
+ err = i_set_str_parameter(&session, I_OPT_OPENID_CONFIG_ENDPOINT, conn->oauth_discovery_uri);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set OpenID config endpoint");
+ goto cleanup;
+ }
+
+ err = i_get_openid_config(&session);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to fetch OpenID discovery document");
+ goto cleanup;
+ }
+
+ if (!i_get_str_parameter(&session, I_OPT_TOKEN_ENDPOINT))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer has no token endpoint"));
+ goto cleanup;
+ }
+
+ if (!i_get_str_parameter(&session, I_OPT_DEVICE_AUTHORIZATION_ENDPOINT))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer does not support device authorization"));
+ goto cleanup;
+ }
+
+ err = i_set_response_type(&session, I_RESPONSE_TYPE_DEVICE_CODE);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set device code response type");
+ goto cleanup;
+ }
+
+ auth_method = I_TOKEN_AUTH_METHOD_NONE;
+ if (conn->oauth_client_secret && *conn->oauth_client_secret)
+ auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
+
+ err = i_set_parameter_list(&session,
+ I_OPT_CLIENT_ID, conn->oauth_client_id,
+ I_OPT_CLIENT_SECRET, conn->oauth_client_secret,
+ I_OPT_TOKEN_METHOD, auth_method,
+ I_OPT_SCOPE, conn->oauth_scope,
+ I_OPT_NONE
+ );
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set client identifier");
+ goto cleanup;
+ }
+
+ err = i_run_device_auth_request(&session);
+ if (err)
+ {
+ iddawc_request_error(conn, &session, err,
+ "failed to obtain device authorization");
+ goto cleanup;
+ }
+
+ verification_uri = i_get_str_parameter(&session, I_OPT_DEVICE_AUTH_VERIFICATION_URI);
+ if (!verification_uri)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a verification URI"));
+ goto cleanup;
+ }
+
+ user_code = i_get_str_parameter(&session, I_OPT_DEVICE_AUTH_USER_CODE);
+ if (!user_code)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a user code"));
+ goto cleanup;
+ }
+
+ /*
+ * Poll the token endpoint until either the user logs in and authorizes the
+ * use of a token, or a hard failure occurs. We perform one ping _before_
+ * prompting the user, so that we don't make them do the work of logging in
+ * only to find that the token endpoint is completely unreachable.
+ */
+ err = i_run_token_request(&session);
+ while (err)
+ {
+ const char *error_code;
+ uint interval;
+
+ error_code = i_get_str_parameter(&session, I_OPT_ERROR);
+
+ /*
+ * authorization_pending and slow_down are the only acceptable errors;
+ * anything else and we bail.
+ */
+ if (!error_code || (strcmp(error_code, "authorization_pending")
+ && strcmp(error_code, "slow_down")))
+ {
+ iddawc_request_error(conn, &session, err,
+ "OAuth token retrieval failed");
+ goto cleanup;
+ }
+
+ if (!user_prompted)
+ {
+ /*
+ * Now that we know the token endpoint isn't broken, give the user
+ * the login instructions.
+ */
+ pqInternalNotice(&conn->noticeHooks,
+ "Visit %s and enter the code: %s",
+ verification_uri, user_code);
+
+ user_prompted = true;
+ }
+
+ /*
+ * We are required to wait between polls; the server tells us how long.
+ * TODO: if interval's not set, we need to default to five seconds
+ * TODO: sanity check the interval
+ */
+ interval = i_get_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL);
+
+ /*
+ * A slow_down error requires us to permanently increase our retry
+ * interval by five seconds. RFC 8628, Sec. 3.5.
+ */
+ if (!strcmp(error_code, "slow_down"))
+ {
+ interval += 5;
+ i_set_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL, interval);
+ }
+
+ sleep(interval);
+
+ /*
+ * XXX Reset the error code before every call, because iddawc won't do
+ * that for us. This matters if the server first sends a "pending" error
+ * code, then later hard-fails without sending an error code to
+ * overwrite the first one.
+ *
+ * That we have to do this at all seems like a bug in iddawc.
+ */
+ i_set_str_parameter(&session, I_OPT_ERROR, NULL);
+
+ err = i_run_token_request(&session);
+ }
+
+ access_token = i_get_str_parameter(&session, I_OPT_ACCESS_TOKEN);
+ token_type = i_get_str_parameter(&session, I_OPT_TOKEN_TYPE);
+
+ if (!access_token || !token_type || strcasecmp(token_type, "Bearer"))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a bearer token"));
+ goto cleanup;
+ }
+
+ appendPQExpBufferStr(token_buf, "Bearer ");
+ appendPQExpBufferStr(token_buf, access_token);
+
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+
+ token = strdup(token_buf->data);
+
+cleanup:
+ if (token_buf)
+ destroyPQExpBuffer(token_buf);
+ i_clean_session(&session);
+
+ return token;
+}
+
+#define kvsep "\x01"
+
+static char *
+client_initial_response(PGconn *conn)
+{
+ static const char * const resp_format = "n,," kvsep "auth=%s" kvsep kvsep;
+
+ PQExpBuffer token_buf;
+ PQExpBuffer discovery_buf = NULL;
+ char *token = NULL;
+ char *response = NULL;
+
+ token_buf = createPQExpBuffer();
+ if (!token_buf)
+ goto cleanup;
+
+ /*
+ * If we don't yet have a discovery URI, but the user gave us an explicit
+ * issuer, use the .well-known discovery URI for that issuer.
+ */
+ if (!conn->oauth_discovery_uri && conn->oauth_issuer)
+ {
+ discovery_buf = createPQExpBuffer();
+ if (!discovery_buf)
+ goto cleanup;
+
+ appendPQExpBufferStr(discovery_buf, conn->oauth_issuer);
+ appendPQExpBufferStr(discovery_buf, "/.well-known/openid-configuration");
+
+ if (PQExpBufferBroken(discovery_buf))
+ goto cleanup;
+
+ conn->oauth_discovery_uri = strdup(discovery_buf->data);
+ }
+
+ token = get_auth_token(conn);
+ if (!token)
+ goto cleanup;
+
+ appendPQExpBuffer(token_buf, resp_format, token);
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+
+ response = strdup(token_buf->data);
+
+cleanup:
+ if (token)
+ free(token);
+ if (discovery_buf)
+ destroyPQExpBuffer(discovery_buf);
+ if (token_buf)
+ destroyPQExpBuffer(token_buf);
+
+ return response;
+}
+
+#define ERROR_STATUS_FIELD "status"
+#define ERROR_SCOPE_FIELD "scope"
+#define ERROR_OPENID_CONFIGURATION_FIELD "openid-configuration"
+
+struct json_ctx
+{
+ char *errmsg; /* any non-NULL value stops all processing */
+ PQExpBufferData errbuf; /* backing memory for errmsg */
+ int nested; /* nesting level (zero is the top) */
+
+ const char *target_field_name; /* points to a static allocation */
+ char **target_field; /* see below */
+
+ /* target_field, if set, points to one of the following: */
+ char *status;
+ char *scope;
+ char *discovery_uri;
+};
+
+#define oauth_json_has_error(ctx) \
+ (PQExpBufferDataBroken((ctx)->errbuf) || (ctx)->errmsg)
+
+#define oauth_json_set_error(ctx, ...) \
+ do { \
+ appendPQExpBuffer(&(ctx)->errbuf, __VA_ARGS__); \
+ (ctx)->errmsg = (ctx)->errbuf.data; \
+ } while (0)
+
+static void
+oauth_json_object_start(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ return; /* short-circuit */
+
+ if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ oauth_json_set_error(ctx,
+ libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+
+ ++ctx->nested;
+}
+
+static void
+oauth_json_object_end(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ return; /* short-circuit */
+
+ --ctx->nested;
+}
+
+static void
+oauth_json_object_field_start(void *state, char *name, bool isnull)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ {
+ /* short-circuit */
+ free(name);
+ return;
+ }
+
+ if (ctx->nested == 1)
+ {
+ if (!strcmp(name, ERROR_STATUS_FIELD))
+ {
+ ctx->target_field_name = ERROR_STATUS_FIELD;
+ ctx->target_field = &ctx->status;
+ }
+ else if (!strcmp(name, ERROR_SCOPE_FIELD))
+ {
+ ctx->target_field_name = ERROR_SCOPE_FIELD;
+ ctx->target_field = &ctx->scope;
+ }
+ else if (!strcmp(name, ERROR_OPENID_CONFIGURATION_FIELD))
+ {
+ ctx->target_field_name = ERROR_OPENID_CONFIGURATION_FIELD;
+ ctx->target_field = &ctx->discovery_uri;
+ }
+ }
+
+ free(name);
+}
+
+static void
+oauth_json_array_start(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ return; /* short-circuit */
+
+ if (!ctx->nested)
+ {
+ ctx->errmsg = libpq_gettext("top-level element must be an object");
+ }
+ else if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ oauth_json_set_error(ctx,
+ libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+}
+
+static void
+oauth_json_scalar(void *state, char *token, JsonTokenType type)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ {
+ /* short-circuit */
+ free(token);
+ return;
+ }
+
+ if (!ctx->nested)
+ {
+ ctx->errmsg = libpq_gettext("top-level element must be an object");
+ }
+ else if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ if (type == JSON_TOKEN_STRING)
+ {
+ *ctx->target_field = token;
+
+ ctx->target_field = NULL;
+ ctx->target_field_name = NULL;
+
+ return; /* don't free the token we're using */
+ }
+
+ oauth_json_set_error(ctx,
+ libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+
+ free(token);
+}
+
+static bool
+handle_oauth_sasl_error(PGconn *conn, char *msg, int msglen)
+{
+ JsonLexContext lex = {0};
+ JsonSemAction sem = {0};
+ JsonParseErrorType err;
+ struct json_ctx ctx = {0};
+ char *errmsg = NULL;
+
+ /* Sanity check. */
+ if (strlen(msg) != msglen)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server's error message contained an embedded NULL"));
+ return false;
+ }
+
+ initJsonLexContextCstringLen(&lex, msg, msglen, PG_UTF8, true);
+
+ initPQExpBuffer(&ctx.errbuf);
+ sem.semstate = &ctx;
+
+ sem.object_start = oauth_json_object_start;
+ sem.object_end = oauth_json_object_end;
+ sem.object_field_start = oauth_json_object_field_start;
+ sem.array_start = oauth_json_array_start;
+ sem.scalar = oauth_json_scalar;
+
+ err = pg_parse_json(&lex, &sem);
+
+ if (err != JSON_SUCCESS)
+ {
+ errmsg = json_errdetail(err, &lex);
+ }
+ else if (PQExpBufferDataBroken(ctx.errbuf))
+ {
+ errmsg = libpq_gettext("out of memory");
+ }
+ else if (ctx.errmsg)
+ {
+ errmsg = ctx.errmsg;
+ }
+
+ if (errmsg)
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("failed to parse server's error response: %s"),
+ errmsg);
+
+ /* Don't need the error buffer or the JSON lexer anymore. */
+ termPQExpBuffer(&ctx.errbuf);
+ termJsonLexContext(&lex);
+
+ if (errmsg)
+ return false;
+
+ /* TODO: what if these override what the user already specified? */
+ if (ctx.discovery_uri)
+ {
+ if (conn->oauth_discovery_uri)
+ free(conn->oauth_discovery_uri);
+
+ conn->oauth_discovery_uri = ctx.discovery_uri;
+ }
+
+ if (ctx.scope)
+ {
+ if (conn->oauth_scope)
+ free(conn->oauth_scope);
+
+ conn->oauth_scope = ctx.scope;
+ }
+ /* TODO: missing error scope should clear any existing connection scope */
+
+ if (!ctx.status)
+ {
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("server sent error response without a status"));
+ return false;
+ }
+
+ if (!strcmp(ctx.status, "invalid_token"))
+ {
+ /*
+ * invalid_token is the only error code we'll automatically retry for,
+ * but only if we have enough information to do so.
+ */
+ if (conn->oauth_discovery_uri)
+ conn->oauth_want_retry = true;
+ }
+ /* TODO: include status in hard failure message */
+
+ return true;
+}
+
+static void
+oauth_exchange(void *opaq, bool final,
+ char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success)
+{
+ fe_oauth_state *state = opaq;
+ PGconn *conn = state->conn;
+
+ *done = false;
+ *success = false;
+ *output = NULL;
+ *outputlen = 0;
+
+ switch (state->state)
+ {
+ case FE_OAUTH_INIT:
+ Assert(inputlen == -1);
+
+ *output = client_initial_response(conn);
+ if (!*output)
+ goto error;
+
+ *outputlen = strlen(*output);
+ state->state = FE_OAUTH_BEARER_SENT;
+
+ break;
+
+ case FE_OAUTH_BEARER_SENT:
+ if (final)
+ {
+ /* TODO: ensure there is no message content here. */
+ *done = true;
+ *success = true;
+
+ break;
+ }
+
+ /*
+ * Error message sent by the server.
+ */
+ if (!handle_oauth_sasl_error(conn, input, inputlen))
+ goto error;
+
+ /*
+ * Respond with the required dummy message (RFC 7628, sec. 3.2.3).
+ */
+ *output = strdup(kvsep);
+ *outputlen = strlen(*output); /* == 1 */
+
+ state->state = FE_OAUTH_SERVER_ERROR;
+ break;
+
+ case FE_OAUTH_SERVER_ERROR:
+ /*
+ * After an error, the server should send an error response to fail
+ * the SASL handshake, which is handled in higher layers.
+ *
+ * If we get here, the server either sent *another* challenge which
+ * isn't defined in the RFC, or completed the handshake successfully
+ * after telling us it was going to fail. Neither is acceptable.
+ */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server sent additional OAuth data after error\n"));
+ goto error;
+
+ default:
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("invalid OAuth exchange state\n"));
+ goto error;
+ }
+
+ return;
+
+error:
+ *done = true;
+ *success = false;
+}
+
+static bool
+oauth_channel_bound(void *opaq)
+{
+ /* This mechanism does not support channel binding. */
+ return false;
+}
+
+static void
+oauth_free(void *opaq)
+{
+ fe_oauth_state *state = opaq;
+
+ free(state);
+}
diff --git a/src/interfaces/libpq/fe-auth-sasl.h b/src/interfaces/libpq/fe-auth-sasl.h
index 3d7ee576f2..0920102908 100644
--- a/src/interfaces/libpq/fe-auth-sasl.h
+++ b/src/interfaces/libpq/fe-auth-sasl.h
@@ -65,6 +65,8 @@ typedef struct pg_fe_sasl_mech
*
* state: The opaque mechanism state returned by init()
*
+ * final: true if the server has sent a final exchange outcome
+ *
* input: The challenge data sent by the server, or NULL when
* generating a client-first initial response (that is, when
* the server expects the client to send a message to start
@@ -92,7 +94,8 @@ typedef struct pg_fe_sasl_mech
* Ignored if *done is false.
*--------
*/
- void (*exchange) (void *state, char *input, int inputlen,
+ void (*exchange) (void *state, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success);
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index 4337e89ce9..489cbeda50 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -24,7 +24,8 @@
/* The exported SCRAM callback mechanism. */
static void *scram_init(PGconn *conn, const char *password,
const char *sasl_mechanism);
-static void scram_exchange(void *opaq, char *input, int inputlen,
+static void scram_exchange(void *opaq, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success);
static bool scram_channel_bound(void *opaq);
@@ -205,7 +206,8 @@ scram_free(void *opaq)
* Exchange a SCRAM message with backend.
*/
static void
-scram_exchange(void *opaq, char *input, int inputlen,
+scram_exchange(void *opaq, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success)
{
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 3421ed4685..0b5b91962a 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -39,6 +39,7 @@
#endif
#include "common/md5.h"
+#include "common/oauth-common.h"
#include "common/scram-common.h"
#include "fe-auth.h"
#include "fe-auth-sasl.h"
@@ -423,7 +424,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
bool success;
const char *selected_mechanism;
PQExpBufferData mechanism_buf;
- char *password;
+ char *password = NULL;
initPQExpBuffer(&mechanism_buf);
@@ -445,8 +446,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
/*
* Parse the list of SASL authentication mechanisms in the
* AuthenticationSASL message, and select the best mechanism that we
- * support. SCRAM-SHA-256-PLUS and SCRAM-SHA-256 are the only ones
- * supported at the moment, listed by order of decreasing importance.
+ * support. Mechanisms are listed by order of decreasing importance.
*/
selected_mechanism = NULL;
for (;;)
@@ -486,6 +486,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
{
selected_mechanism = SCRAM_SHA_256_PLUS_NAME;
conn->sasl = &pg_scram_mech;
+ conn->password_needed = true;
}
#else
/*
@@ -523,7 +524,17 @@ pg_SASL_init(PGconn *conn, int payloadlen)
{
selected_mechanism = SCRAM_SHA_256_NAME;
conn->sasl = &pg_scram_mech;
+ conn->password_needed = true;
}
+#ifdef USE_OAUTH
+ else if (strcmp(mechanism_buf.data, OAUTHBEARER_NAME) == 0 &&
+ !selected_mechanism)
+ {
+ selected_mechanism = OAUTHBEARER_NAME;
+ conn->sasl = &pg_oauth_mech;
+ conn->password_needed = false;
+ }
+#endif
}
if (!selected_mechanism)
@@ -548,18 +559,19 @@ pg_SASL_init(PGconn *conn, int payloadlen)
/*
* First, select the password to use for the exchange, complaining if
- * there isn't one. Currently, all supported SASL mechanisms require a
- * password, so we can just go ahead here without further distinction.
+ * there isn't one and the SASL mechanism needs it.
*/
- conn->password_needed = true;
- password = conn->connhost[conn->whichhost].password;
- if (password == NULL)
- password = conn->pgpass;
- if (password == NULL || password[0] == '\0')
+ if (conn->password_needed)
{
- appendPQExpBufferStr(&conn->errorMessage,
- PQnoPasswordSupplied);
- goto error;
+ password = conn->connhost[conn->whichhost].password;
+ if (password == NULL)
+ password = conn->pgpass;
+ if (password == NULL || password[0] == '\0')
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ PQnoPasswordSupplied);
+ goto error;
+ }
}
Assert(conn->sasl);
@@ -577,7 +589,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
goto oom_error;
/* Get the mechanism-specific Initial Client Response, if any */
- conn->sasl->exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state, false,
NULL, -1,
&initialresponse, &initialresponselen,
&done, &success);
@@ -658,7 +670,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
/* For safety and convenience, ensure the buffer is NULL-terminated. */
challenge[payloadlen] = '\0';
- conn->sasl->exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state, final,
challenge, payloadlen,
&output, &outputlen,
&done, &success);
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 63927480ee..03bea124a6 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -26,4 +26,7 @@ extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
extern const pg_fe_sasl_mech pg_scram_mech;
extern char *pg_fe_scram_build_secret(const char *password);
+/* Mechanisms in fe-auth-oauth.c */
+extern const pg_fe_sasl_mech pg_oauth_mech;
+
#endif /* FE_AUTH_H */
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 49eec3e835..ba9c097060 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -344,6 +344,23 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
"Target-Session-Attrs", "", 15, /* sizeof("prefer-standby") = 15 */
offsetof(struct pg_conn, target_session_attrs)},
+ /* OAuth v2 */
+ {"oauth_issuer", NULL, NULL, NULL,
+ "OAuth-Issuer", "", 40,
+ offsetof(struct pg_conn, oauth_issuer)},
+
+ {"oauth_client_id", NULL, NULL, NULL,
+ "OAuth-Client-ID", "", 40,
+ offsetof(struct pg_conn, oauth_client_id)},
+
+ {"oauth_client_secret", NULL, NULL, NULL,
+ "OAuth-Client-Secret", "", 40,
+ offsetof(struct pg_conn, oauth_client_secret)},
+
+ {"oauth_scope", NULL, NULL, NULL,
+ "OAuth-Scope", "", 15,
+ offsetof(struct pg_conn, oauth_scope)},
+
/* Terminating entry --- MUST BE LAST */
{NULL, NULL, NULL, NULL,
NULL, NULL, 0}
@@ -606,6 +623,7 @@ pqDropServerData(PGconn *conn)
conn->write_err_msg = NULL;
conn->be_pid = 0;
conn->be_key = 0;
+ /* conn->oauth_want_retry = false; TODO */
}
@@ -3355,6 +3373,16 @@ keep_going: /* We will come back to here until there is
/* Check to see if we should mention pgpassfile */
pgpassfileWarning(conn);
+#ifdef USE_OAUTH
+ if (conn->sasl == &pg_oauth_mech
+ && conn->oauth_want_retry)
+ {
+ /* TODO: only allow retry once */
+ need_new_connection = true;
+ goto keep_going;
+ }
+#endif
+
#ifdef ENABLE_GSS
/*
@@ -4129,6 +4157,16 @@ freePGconn(PGconn *conn)
free(conn->rowBuf);
if (conn->target_session_attrs)
free(conn->target_session_attrs);
+ if (conn->oauth_issuer)
+ free(conn->oauth_issuer);
+ if (conn->oauth_discovery_uri)
+ free(conn->oauth_discovery_uri);
+ if (conn->oauth_client_id)
+ free(conn->oauth_client_id);
+ if (conn->oauth_client_secret)
+ free(conn->oauth_client_secret);
+ if (conn->oauth_scope)
+ free(conn->oauth_scope);
termPQExpBuffer(&conn->errorMessage);
termPQExpBuffer(&conn->workBuffer);
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 490458adef..3d20482550 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -394,6 +394,14 @@ struct pg_conn
char *ssl_max_protocol_version; /* maximum TLS protocol version */
char *target_session_attrs; /* desired session properties */
+ /* OAuth v2 */
+ char *oauth_issuer; /* token issuer URL */
+ char *oauth_discovery_uri; /* URI of the issuer's discovery document */
+ char *oauth_client_id; /* client identifier */
+ char *oauth_client_secret; /* client secret */
+ char *oauth_scope; /* access token scope */
+ bool oauth_want_retry; /* should we retry on failure? */
+
/* Optional file to write trace info to */
FILE *Pfdebug;
int traceFlags;
--
2.25.1
[Attachment: v2-0003-backend-add-OAUTHBEARER-SASL-mechanism.patch (text/x-patch)]
From d09a00b4d52a5ed578ee5cd7623108ebdd12f202 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 4 May 2021 16:21:11 -0700
Subject: [PATCH v2 3/5] backend: add OAUTHBEARER SASL mechanism
DO NOT USE THIS PROOF OF CONCEPT IN PRODUCTION.
Implement OAUTHBEARER (RFC 7628) on the server side. This adds a new
auth method, oauth, to pg_hba.
Because OAuth implementations vary so wildly, and bearer token
validation is heavily dependent on the issuing party, authn/z is done by
communicating with an external program: the oauth_validator_command.
This command must do the following:
1. Receive the bearer token by reading its contents from a file
descriptor passed from the server. (The numeric value of this
descriptor may be inserted into the oauth_validator_command using the
%f specifier.)
This MUST be the first action the command performs. The server will
not begin reading stdout from the command until the token has been
read in full, so if the command tries to print anything and hits a
buffer limit, the backend will deadlock and time out.
2. Validate the bearer token. The correct way to do this depends on the
issuer, but it generally involves either cryptographic operations to
prove that the token was issued by a trusted party, or the
presentation of the bearer token to some other party so that _it_ can
perform validation.
The command MUST maintain confidentiality of the bearer token, since
in most cases it can be used just like a password. (There are ways to
cryptographically bind tokens to client certificates, but they are
way beyond the scope of this commit message.)
If the token cannot be validated, the command must exit with a
non-zero status. Further authentication/authorization is pointless if
the bearer token wasn't issued by someone you trust.
3. Authenticate the user, authorize the user, or both:
a. To authenticate the user, use the bearer token to retrieve some
trusted identifier string for the end user. The exact process for
this is, again, issuer-dependent. The command should print the
authenticated identity string to stdout, followed by a newline.
If the user cannot be authenticated, the validator should not
print anything to stdout. It should also exit with a non-zero
status, unless the token may be used to authorize the connection
through some other means (see below).
On a success, the command may then exit with a zero success code.
By default, the server will then check to make sure the identity
string matches the role that is being used (or matches a usermap
entry, if one is in use).
   b. To optionally authorize the user, in combination with the HBA
      option trust_validator_authz=1 (see below), the validator simply
      returns a zero exit code if the client should be allowed to
      connect with its presented role (which can be passed to the
      command using the %r specifier), or a non-zero code otherwise.
      The hard part is determining whether the given token truly
      authorizes the client to use the given role, which must
      unfortunately be left as an exercise for the reader.

      This obviously requires some care, as a poorly implemented token
      validator may silently open the entire database to anyone with a
      bearer token. But it may be a more portable approach, since OAuth
      is designed as an authorization framework, not an authentication
      framework. For example, the user's bearer token could carry an
      "allow_superuser_access" claim, which would authorize pseudonymous
      database access as any role. It's then up to the OAuth system
      administrators to ensure that allow_superuser_access is doled out
      only to the proper users.
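
      To make the allow_superuser_access idea slightly more concrete,
      here's a sketch of the policy side of such a validator. All claim
      names here are hypothetical, and the sketch deliberately decodes
      the JWT payload without verifying it; a real validator absolutely
      must verify the token's signature against the issuer's published
      keys before trusting any claim.

      ```python
      import base64
      import json

      def decode_jwt_payload(token):
          """Decode the (unverified!) payload segment of a JWT bearer token.

          A real validator MUST verify the token's signature before
          trusting any of these claims.
          """
          payload_b64 = token.split(".")[1]
          # JWT segments use unpadded base64url; restore the padding.
          payload_b64 += "=" * (-len(payload_b64) % 4)
          return json.loads(base64.urlsafe_b64decode(payload_b64))

      def authorize(token, role):
          """Hypothetical policy: a superuser claim authorizes any role;
          otherwise the role must be listed explicitly in the token."""
          claims = decode_jwt_payload(token)
          if claims.get("allow_superuser_access"):
              return True
          return role in claims.get("allowed_roles", [])
      ```

      A validator using this would exit with status 0 when
      authorize(token, role) is true, and non-zero otherwise.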
   c. It's possible that the user can be successfully authenticated but
      isn't authorized to connect. In this case, the command may print
      the authenticated ID and then fail with a non-zero exit code.
      (This makes it easier to see what's going on in the Postgres
      logs.)

4. Token validators may optionally log to stderr. This will be printed
   verbatim into the Postgres server logs.
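
Pulling the steps above together, here is a minimal validator sketch in
Python. Everything specific to it is hypothetical: the introspection
step is a stub (token validation is entirely issuer-dependent), and the
script assumes it's invoked via oauth_validator_command with the token
pipe's file descriptor as its first argument (the %f specifier).

```python
#!/usr/bin/env python3
"""Hypothetical OAuth token validator, illustrating the contract above."""
import os
import sys

def introspect(token):
    """Stub: ask the issuer whether this token is valid (issuer-dependent).

    Returns the end user's identity string, or None if the token is
    invalid. A real implementation might POST the token to the issuer's
    introspection endpoint (RFC 7662) and check the 'active' flag.
    """
    if token == "valid-token":        # placeholder logic only
        return "alice@example.org"
    return None

def main(argv):
    fd = int(argv[1])                  # pipe fd substituted via %f
    with os.fdopen(fd, "r") as f:
        token = f.read()               # read the whole token before writing
    identity = introspect(token)
    if identity is None:
        print("token rejected", file=sys.stderr)  # goes to the server log
        return 1                       # non-zero: authn/authz failed
    print(identity)                    # authn_id to stdout, newline-terminated
    return 0

if __name__ == "__main__":
    sys.exit(main(sys.argv))
```

Note that the sketch drains the pipe completely before producing any
output, matching the pipe-deadlock caveat in the patch.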
The oauth method supports the following HBA options (but note that two
of them are not optional, since we have no way of choosing sensible
defaults):

issuer: Required. The URL of the OAuth issuing party, which the client
        must contact to receive a bearer token. Some real-world
        examples as of the time of writing:
        - https://accounts.google.com
        - https://login.microsoftonline.com/[tenant-id]/v2.0

scope:  Required. The OAuth scope(s) required for the server to
        authenticate and/or authorize the user. This is heavily
        deployment-specific, but a simple example is "openid email".

map:    Optional. Specify a standard PostgreSQL user map; this works
        the same as with other auth methods such as peer. If a map is
        not specified, the user ID returned by the token validator
        must exactly match the role that's being requested (but see
        trust_validator_authz, below).

trust_validator_authz:
        Optional. When set to 1, this allows the token validator to
        take full control of the authorization process. Standard user
        mapping is skipped: if the validator command succeeds, the
        client is allowed to connect under its desired role and no
        further checks are done.
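
For illustration, a hypothetical configuration tying these options
together (every value below is made up):

```
# pg_hba.conf
host  all  all  0.0.0.0/0  oauth  issuer="https://accounts.google.com" scope="openid email"

# postgresql.conf: %f is replaced with the token pipe's file
# descriptor, %r with the role the client wants to assume
oauth_validator_command = '/usr/local/bin/pg-oauth-validator %f %r'
```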
Unlike the client, servers support OAuth without needing to be built
against libiddawc (since the responsibility for "speaking" OAuth/OIDC
correctly is delegated entirely to the oauth_validator_command).
Several TODOs:
- port to platforms other than "modern Linux"
- overhaul the communication with oauth_validator_command, which is
currently a bad hack on OpenPipeStream()
- implement more sanity checks on the OAUTHBEARER message format and
tokens sent by the client
- implement more helpful handling of HBA misconfigurations
- properly interpolate JSON when generating error responses
- use logdetail during auth failures
- deal with role names that can't be safely passed to system() without
shell-escaping
- allow passing the configured issuer to the oauth_validator_command, to
deal with multi-issuer setups
- ...and more.
---
src/backend/libpq/Makefile | 1 +
src/backend/libpq/auth-oauth.c | 797 +++++++++++++++++++++++++++++++++
src/backend/libpq/auth-sasl.c | 10 +-
src/backend/libpq/auth-scram.c | 4 +-
src/backend/libpq/auth.c | 26 +-
src/backend/libpq/hba.c | 29 +-
src/backend/utils/misc/guc.c | 12 +
src/include/libpq/auth.h | 17 +
src/include/libpq/hba.h | 8 +-
src/include/libpq/oauth.h | 24 +
src/include/libpq/sasl.h | 11 +
11 files changed, 907 insertions(+), 32 deletions(-)
create mode 100644 src/backend/libpq/auth-oauth.c
create mode 100644 src/include/libpq/oauth.h
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 6d385fd6a4..98eb2a8242 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,6 +15,7 @@ include $(top_builddir)/src/Makefile.global
# be-fsstubs is here for historical reasons, probably belongs elsewhere
OBJS = \
+ auth-oauth.o \
auth-sasl.o \
auth-scram.o \
auth.o \
diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c
new file mode 100644
index 0000000000..c47211132c
--- /dev/null
+++ b/src/backend/libpq/auth-oauth.c
@@ -0,0 +1,797 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth-oauth.c
+ * Server-side implementation of the SASL OAUTHBEARER mechanism.
+ *
+ * See the following RFC for more details:
+ * - RFC 7628: https://tools.ietf.org/html/rfc7628
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/backend/libpq/auth-oauth.c
+ *
+ *-------------------------------------------------------------------------
+ */
+#include "postgres.h"
+
+#include <unistd.h>
+#include <fcntl.h>
+
+#include "common/oauth-common.h"
+#include "lib/stringinfo.h"
+#include "libpq/auth.h"
+#include "libpq/hba.h"
+#include "libpq/oauth.h"
+#include "libpq/sasl.h"
+#include "storage/fd.h"
+
+/* GUC */
+char *oauth_validator_command;
+
+static void oauth_get_mechanisms(Port *port, StringInfo buf);
+static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
+static int oauth_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, char **logdetail);
+
+/* Mechanism declaration */
+const pg_be_sasl_mech pg_be_oauth_mech = {
+ oauth_get_mechanisms,
+ oauth_init,
+ oauth_exchange,
+
+ PG_MAX_AUTH_TOKEN_LENGTH,
+};
+
+
+typedef enum
+{
+ OAUTH_STATE_INIT = 0,
+ OAUTH_STATE_ERROR,
+ OAUTH_STATE_FINISHED,
+} oauth_state;
+
+struct oauth_ctx
+{
+ oauth_state state;
+ Port *port;
+ const char *issuer;
+ const char *scope;
+};
+
+static char *sanitize_char(char c);
+static char *parse_kvpairs_for_auth(char **input);
+static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
+static bool validate(Port *port, const char *auth, char **logdetail);
+static bool run_validator_command(Port *port, const char *token);
+static bool check_exit(FILE **fh, const char *command);
+static bool unset_cloexec(int fd);
+static bool username_ok_for_shell(const char *username);
+
+#define KVSEP 0x01
+#define AUTH_KEY "auth"
+#define BEARER_SCHEME "Bearer "
+
+static void
+oauth_get_mechanisms(Port *port, StringInfo buf)
+{
+ /* Only OAUTHBEARER is supported. */
+ appendStringInfoString(buf, OAUTHBEARER_NAME);
+ appendStringInfoChar(buf, '\0');
+}
+
+static void *
+oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
+{
+ struct oauth_ctx *ctx;
+
+ if (strcmp(selected_mech, OAUTHBEARER_NAME))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("client selected an invalid SASL authentication mechanism")));
+
+ ctx = palloc0(sizeof(*ctx));
+
+ ctx->state = OAUTH_STATE_INIT;
+ ctx->port = port;
+
+ Assert(port->hba);
+ ctx->issuer = port->hba->oauth_issuer;
+ ctx->scope = port->hba->oauth_scope;
+
+ return ctx;
+}
+
+static int
+oauth_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, char **logdetail)
+{
+ char *p;
+ char cbind_flag;
+ char *auth;
+
+ struct oauth_ctx *ctx = opaq;
+
+ *output = NULL;
+ *outputlen = -1;
+
+ /*
+ * If the client didn't include an "Initial Client Response" in the
+ * SASLInitialResponse message, send an empty challenge, to which the
+ * client will respond with the same data that usually comes in the
+ * Initial Client Response.
+ */
+ if (input == NULL)
+ {
+ Assert(ctx->state == OAUTH_STATE_INIT);
+
+ *output = pstrdup("");
+ *outputlen = 0;
+ return PG_SASL_EXCHANGE_CONTINUE;
+ }
+
+ /*
+ * Check that the input length agrees with the string length of the input.
+ */
+ if (inputlen == 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("The message is empty.")));
+ if (inputlen != strlen(input))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message length does not match input length.")));
+
+ switch (ctx->state)
+ {
+ case OAUTH_STATE_INIT:
+ /* Handle this case below. */
+ break;
+
+ case OAUTH_STATE_ERROR:
+ /*
+ * Only one response is valid for the client during authentication
+ * failure: a single kvsep.
+ */
+ if (inputlen != 1 || *input != KVSEP)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Client did not send a kvsep response.")));
+
+ /* The (failed) handshake is now complete. */
+ ctx->state = OAUTH_STATE_FINISHED;
+ return PG_SASL_EXCHANGE_FAILURE;
+
+ default:
+ elog(ERROR, "invalid OAUTHBEARER exchange state");
+ return PG_SASL_EXCHANGE_FAILURE;
+ }
+
+ /* Handle the client's initial message. */
+ p = pstrdup(input);
+
+ /*
+ * OAUTHBEARER does not currently define a channel binding (so there is no
+ * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a 'y'
+ * specifier purely for the remote chance that a future specification could
+ * define one; then future clients can still interoperate with this server
+ * implementation. 'n' is the expected case.
+ */
+ cbind_flag = *p;
+ switch (cbind_flag)
+ {
+ case 'p':
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data.")));
+ break;
+
+ case 'y': /* fall through */
+ case 'n':
+ p++;
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Comma expected, but found character %s.",
+ sanitize_char(*p))));
+ p++;
+ break;
+
+ default:
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Unexpected channel-binding flag %s.",
+ sanitize_char(cbind_flag))));
+ }
+
+ /*
+ * Forbid optional authzid (authorization identity). We don't support it.
+ */
+ if (*p == 'a')
+ ereport(ERROR,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("client uses authorization identity, but it is not supported")));
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Unexpected attribute %s in client-first-message.",
+ sanitize_char(*p))));
+ p++;
+
+ /* All remaining fields are separated by the RFC's kvsep (\x01). */
+ if (*p != KVSEP)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Key-value separator expected, but found character %s.",
+ sanitize_char(*p))));
+ p++;
+
+ auth = parse_kvpairs_for_auth(&p);
+ if (!auth)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message does not contain an auth value.")));
+
+ /* We should be at the end of our message. */
+ if (*p)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains additional data after the final terminator.")));
+
+ if (!validate(ctx->port, auth, logdetail))
+ {
+ generate_error_response(ctx, output, outputlen);
+
+ ctx->state = OAUTH_STATE_ERROR;
+ return PG_SASL_EXCHANGE_CONTINUE;
+ }
+
+ ctx->state = OAUTH_STATE_FINISHED;
+ return PG_SASL_EXCHANGE_SUCCESS;
+}
+
+/*
+ * Convert an arbitrary byte to printable form. For error messages.
+ *
+ * If it's a printable ASCII character, print it as a single character;
+ * otherwise, print it in hex.
+ *
+ * The returned pointer points to a static buffer.
+ */
+static char *
+sanitize_char(char c)
+{
+ static char buf[5];
+
+ if (c >= 0x21 && c <= 0x7E)
+ snprintf(buf, sizeof(buf), "'%c'", c);
+ else
+ snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
+ return buf;
+}
+
+/*
+ * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is
+ * found, its value is returned.
+ */
+static char *
+parse_kvpairs_for_auth(char **input)
+{
+ char *pos = *input;
+ char *auth = NULL;
+
+ /*
+ * The relevant ABNF, from Sec. 3.1:
+ *
+ * kvsep = %x01
+ * key = 1*(ALPHA)
+ * value = *(VCHAR / SP / HTAB / CR / LF )
+ * kvpair = key "=" value kvsep
+ * ;;gs2-header = See RFC 5801
+ * client-resp = (gs2-header kvsep *kvpair kvsep) / kvsep
+ *
+ * By the time we reach this code, the gs2-header and initial kvsep have
+ * already been validated. We start at the beginning of the first kvpair.
+ */
+
+ while (*pos)
+ {
+ char *end;
+ char *sep;
+ char *key;
+ char *value;
+
+ /*
+ * Find the end of this kvpair. Note that input is null-terminated by
+ * the SASL code, so the strchr() is bounded.
+ */
+ end = strchr(pos, KVSEP);
+ if (!end)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains an unterminated key/value pair.")));
+ *end = '\0';
+
+ if (pos == end)
+ {
+ /* Empty kvpair, signifying the end of the list. */
+ *input = pos + 1;
+ return auth;
+ }
+
+ /*
+ * Find the end of the key name.
+ *
+ * TODO further validate the key/value grammar? empty keys, bad chars...
+ */
+ sep = strchr(pos, '=');
+ if (!sep)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains a key without a value.")));
+ *sep = '\0';
+
+ /* Both key and value are now safely terminated. */
+ key = pos;
+ value = sep + 1;
+
+ if (!strcmp(key, AUTH_KEY))
+ {
+ if (auth)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains multiple auth values.")));
+
+ auth = value;
+ }
+ else
+ {
+ /*
+ * The RFC also defines the host and port keys, but they are not
+ * required for OAUTHBEARER and we do not use them. Also, per
+ * Sec. 3.1, any key/value pairs we don't recognize must be ignored.
+ */
+ }
+
+ /* Move to the next pair. */
+ pos = end + 1;
+ }
+
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message did not contain a final terminator.")));
+
+ return NULL; /* unreachable */
+}
+
+static void
+generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
+{
+ StringInfoData buf;
+
+ /*
+ * The admin needs to set an issuer and scope for OAuth to work. There's not
+ * really a way to hide this from the user, either, because we can't choose
+ * a "default" issuer, so be honest in the failure message.
+ *
+ * TODO: see if there's a better place to fail, earlier than this.
+ */
+ if (!ctx->issuer || !ctx->scope)
+ ereport(FATAL,
+ (errcode(ERRCODE_INTERNAL_ERROR),
+ errmsg("OAuth is not properly configured for this user"),
+ errdetail_log("The issuer and scope parameters must be set in pg_hba.conf.")));
+
+
+ initStringInfo(&buf);
+
+ /*
+ * TODO: JSON escaping
+ */
+ appendStringInfo(&buf,
+ "{ "
+ "\"status\": \"invalid_token\", "
+ "\"openid-configuration\": \"%s/.well-known/openid-configuration\","
+ "\"scope\": \"%s\" "
+ "}",
+ ctx->issuer, ctx->scope);
+
+ *output = buf.data;
+ *outputlen = buf.len;
+}
+
+static bool
+validate(Port *port, const char *auth, char **logdetail)
+{
+ static const char * const b64_set = "abcdefghijklmnopqrstuvwxyz"
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "0123456789-._~+/";
+
+ const char *token;
+ size_t span;
+ int ret;
+
+ /* TODO: handle logdetail when the test framework can check it */
+
+ /*
+ * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec.
+ * 2.1:
+ *
+ * b64token = 1*( ALPHA / DIGIT /
+ * "-" / "." / "_" / "~" / "+" / "/" ) *"="
+ * credentials = "Bearer" 1*SP b64token
+ *
+ * The "credentials" construction is what we receive in our auth value.
+ *
+ * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization
+ * header format; RFC 7235 Sec. 2), the "Bearer" scheme string must be
+ * compared case-insensitively. (This is not mentioned in RFC 6750, but it's
+ * pointed out in RFC 7628 Sec. 4.)
+ *
+ * TODO: handle the Authorization spec, RFC 7235 Sec. 2.1.
+ */
+ if (strncasecmp(auth, BEARER_SCHEME, strlen(BEARER_SCHEME)))
+ return false;
+
+ /* Pull the bearer token out of the auth value. */
+ token = auth + strlen(BEARER_SCHEME);
+
+ /* Swallow any additional spaces. */
+ while (*token == ' ')
+ token++;
+
+ /*
+ * Before invoking the validator command, sanity-check the token format to
+ * avoid any injection attacks later in the chain. Invalid formats are
+ * technically a protocol violation, but don't reflect any information about
+ * the sensitive Bearer token back to the client; log at COMMERROR instead.
+ */
+
+ /* Tokens must not be empty. */
+ if (!*token)
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Bearer token is empty.")));
+ return false;
+ }
+
+ /*
+ * Make sure the token contains only allowed characters. Tokens may end with
+ * any number of '=' characters.
+ */
+ span = strspn(token, b64_set);
+ while (token[span] == '=')
+ span++;
+
+ if (token[span] != '\0')
+ {
+ /*
+ * This error message could be more helpful by printing the problematic
+ * character(s), but that'd be a bit like printing a piece of someone's
+ * password into the logs.
+ */
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Bearer token is not in the correct format.")));
+ return false;
+ }
+
+ /* Have the validator check the token. */
+ if (!run_validator_command(port, token))
+ return false;
+
+ if (port->hba->oauth_skip_usermap)
+ {
+ /*
+ * If the validator is our authorization authority, we're done.
+ * Authentication may or may not have been performed depending on the
+ * validator implementation; all that matters is that the validator says
+ * the user can log in with the target role.
+ */
+ return true;
+ }
+
+ /* Make sure the validator authenticated the user. */
+ if (!port->authn_id)
+ {
+ /* TODO: use logdetail; reduce message duplication */
+ ereport(LOG,
+ (errmsg("OAuth bearer authentication failed for user \"%s\": validator provided no identity",
+ port->user_name)));
+ return false;
+ }
+
+ /* Finally, check the user map. */
+ ret = check_usermap(port->hba->usermap, port->user_name, port->authn_id,
+ false);
+ return (ret == STATUS_OK);
+}
+
+static bool
+run_validator_command(Port *port, const char *token)
+{
+ bool success = false;
+ int rc;
+ int pipefd[2];
+ int rfd = -1;
+ int wfd = -1;
+
+ StringInfoData command = { 0 };
+ char *p;
+ FILE *fh = NULL;
+
+ ssize_t written;
+ char *line = NULL;
+ size_t size = 0;
+ ssize_t len;
+
+ Assert(oauth_validator_command);
+
+ if (!oauth_validator_command[0])
+ {
+ ereport(COMMERROR,
+ (errmsg("oauth_validator_command is not set"),
+ errhint("To allow OAuth authenticated connections, set "
+ "oauth_validator_command in postgresql.conf.")));
+ return false;
+ }
+
+ /*
+ * Since popen() is unidirectional, open up a pipe for the other direction.
+ * Use CLOEXEC to ensure that our write end doesn't accidentally get copied
+ * into child processes, which would prevent us from closing it cleanly.
+ *
+ * XXX this is ugly. We should just read from the child process's stdout,
+ * but that's a lot more code.
+ * XXX by bypassing the popen API, we open the potential of process
+ * deadlock. Clearly document child process requirements (i.e. the child
+ * MUST read all data off of the pipe before writing anything).
+ * TODO: port to Windows using _pipe().
+ */
+ rc = pipe2(pipefd, O_CLOEXEC);
+ if (rc < 0)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not create child pipe: %m")));
+ return false;
+ }
+
+ rfd = pipefd[0];
+ wfd = pipefd[1];
+
+ /* Allow the read pipe to be passed to the child. */
+ if (!unset_cloexec(rfd))
+ {
+ /* error message was already logged */
+ goto cleanup;
+ }
+
+ /*
+ * Construct the command, substituting any recognized %-specifiers:
+ *
+ * %f: the file descriptor of the input pipe
+ * %r: the role that the client wants to assume (port->user_name)
+ * %%: a literal '%'
+ */
+ initStringInfo(&command);
+
+ for (p = oauth_validator_command; *p; p++)
+ {
+ if (p[0] == '%')
+ {
+ switch (p[1])
+ {
+ case 'f':
+ appendStringInfo(&command, "%d", rfd);
+ p++;
+ break;
+ case 'r':
+ /*
+ * TODO: decide how this string should be escaped. The role
+ * is controlled by the client, so if we don't escape it,
+ * command injections are inevitable.
+ *
+ * This is probably an indication that the role name needs
+ * to be communicated to the validator process in some other
+ * way. For this proof of concept, just be incredibly strict
+ * about the characters that are allowed in user names.
+ */
+ if (!username_ok_for_shell(port->user_name))
+ goto cleanup;
+
+ appendStringInfoString(&command, port->user_name);
+ p++;
+ break;
+ case '%':
+ appendStringInfoChar(&command, '%');
+ p++;
+ break;
+ default:
+ appendStringInfoChar(&command, p[0]);
+ }
+ }
+ else
+ appendStringInfoChar(&command, p[0]);
+ }
+
+ /* Execute the command. */
+ fh = OpenPipeStream(command.data, "re");
+ /* TODO: handle failures */
+
+ /* We don't need the read end of the pipe anymore. */
+ close(rfd);
+ rfd = -1;
+
+ /* Give the command the token to validate. */
+ written = write(wfd, token, strlen(token));
+ if (written != strlen(token))
+ {
+ /* TODO must loop for short writes, EINTR et al */
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not write token to child pipe: %m")));
+ goto cleanup;
+ }
+
+ close(wfd);
+ wfd = -1;
+
+ /*
+ * Read the command's response.
+ *
+ * TODO: getline() is probably too new to use, unfortunately.
+ * TODO: loop over all lines
+ */
+ if ((len = getline(&line, &size, fh)) >= 0)
+ {
+ /* TODO: fail if the authn_id doesn't end with a newline */
+ if (len > 0)
+ line[len - 1] = '\0';
+
+ set_authn_id(port, line);
+ }
+ else if (ferror(fh))
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not read from command \"%s\": %m",
+ command.data)));
+ goto cleanup;
+ }
+
+ /* Make sure the command exits cleanly. */
+ if (!check_exit(&fh, command.data))
+ {
+ /* error message already logged */
+ goto cleanup;
+ }
+
+ /* Done. */
+ success = true;
+
+cleanup:
+ if (line)
+ free(line);
+
+ /*
+ * In the successful case, the pipe fds are already closed. For the error
+ * case, always close out the pipe before waiting for the command, to
+ * prevent deadlock.
+ */
+ if (rfd >= 0)
+ close(rfd);
+ if (wfd >= 0)
+ close(wfd);
+
+ if (fh)
+ {
+ Assert(!success);
+ check_exit(&fh, command.data);
+ }
+
+ if (command.data)
+ pfree(command.data);
+
+ return success;
+}
+
+static bool
+check_exit(FILE **fh, const char *command)
+{
+ int rc;
+
+ rc = ClosePipeStream(*fh);
+ *fh = NULL;
+
+ if (rc == -1)
+ {
+ /* pclose() itself failed. */
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not close pipe to command \"%s\": %m",
+ command)));
+ }
+ else if (rc != 0)
+ {
+ char *reason = wait_result_to_str(rc);
+
+ ereport(COMMERROR,
+ (errmsg("failed to execute command \"%s\": %s",
+ command, reason)));
+
+ pfree(reason);
+ }
+
+ return (rc == 0);
+}
+
+static bool
+unset_cloexec(int fd)
+{
+ int flags;
+ int rc;
+
+ flags = fcntl(fd, F_GETFD);
+ if (flags == -1)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not get fd flags for child pipe: %m")));
+ return false;
+ }
+
+ rc = fcntl(fd, F_SETFD, flags & ~FD_CLOEXEC);
+ if (rc < 0)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not unset FD_CLOEXEC for child pipe: %m")));
+ return false;
+ }
+
+ return true;
+}
+
+/*
+ * XXX This should go away eventually and be replaced with either a proper
+ * escape or a different strategy for communication with the validator command.
+ */
+static bool
+username_ok_for_shell(const char *username)
+{
+ /* This set is borrowed from fe_utils' appendShellStringNoError(). */
+ static const char * const allowed = "abcdefghijklmnopqrstuvwxyz"
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "0123456789-_./:";
+ size_t span;
+
+ Assert(username && username[0]); /* should have already been checked */
+
+ span = strspn(username, allowed);
+ if (username[span] != '\0')
+ {
+ ereport(COMMERROR,
+ (errmsg("PostgreSQL user name contains unsafe characters and cannot be passed to the OAuth validator")));
+ return false;
+ }
+
+ return true;
+}
diff --git a/src/backend/libpq/auth-sasl.c b/src/backend/libpq/auth-sasl.c
index 6cfd90fa21..f6c49a4de5 100644
--- a/src/backend/libpq/auth-sasl.c
+++ b/src/backend/libpq/auth-sasl.c
@@ -20,14 +20,6 @@
#include "libpq/pqformat.h"
#include "libpq/sasl.h"
-/*
- * Maximum accepted size of SASL messages.
- *
- * The messages that the server or libpq generate are much smaller than this,
- * but have some headroom.
- */
-#define PG_MAX_SASL_MESSAGE_LENGTH 1024
-
/*
* Perform a SASL exchange with a libpq client, using a specific mechanism
* implementation.
@@ -103,7 +95,7 @@ CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
/* Get the actual SASL message */
initStringInfo(&buf);
- if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
+ if (pq_getmessage(&buf, mech->max_message_length))
{
/* EOF - pq_getmessage already logged error */
pfree(buf.data);
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 9df8f17837..5bb0388c01 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -117,7 +117,9 @@ static int scram_exchange(void *opaq, const char *input, int inputlen,
const pg_be_sasl_mech pg_be_scram_mech = {
scram_get_mechanisms,
scram_init,
- scram_exchange
+ scram_exchange,
+
+ PG_MAX_SASL_MESSAGE_LENGTH
};
/*
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 8cc23ef7fb..fbcc2c55b4 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -29,6 +29,7 @@
#include "libpq/auth.h"
#include "libpq/crypt.h"
#include "libpq/libpq.h"
+#include "libpq/oauth.h"
#include "libpq/pqformat.h"
#include "libpq/sasl.h"
#include "libpq/scram.h"
@@ -47,7 +48,6 @@
*/
static void auth_failed(Port *port, int status, char *logdetail);
static char *recv_password_packet(Port *port);
-static void set_authn_id(Port *port, const char *id);
/*----------------------------------------------------------------
@@ -205,22 +205,6 @@ static int CheckRADIUSAuth(Port *port);
static int PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd);
-/*
- * Maximum accepted size of GSS and SSPI authentication tokens.
- * We also use this as a limit on ordinary password packet lengths.
- *
- * Kerberos tickets are usually quite small, but the TGTs issued by Windows
- * domain controllers include an authorization field known as the Privilege
- * Attribute Certificate (PAC), which contains the user's Windows permissions
- * (group memberships etc.). The PAC is copied into all tickets obtained on
- * the basis of this TGT (even those issued by Unix realms which the Windows
- * realm trusts), and can be several kB in size. The maximum token size
- * accepted by Windows systems is determined by the MaxAuthToken Windows
- * registry setting. Microsoft recommends that it is not set higher than
- * 65535 bytes, so that seems like a reasonable limit for us as well.
- */
-#define PG_MAX_AUTH_TOKEN_LENGTH 65535
-
/*----------------------------------------------------------------
* Global authentication functions
*----------------------------------------------------------------
@@ -309,6 +293,9 @@ auth_failed(Port *port, int status, char *logdetail)
case uaRADIUS:
errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
break;
+ case uaOAuth:
+ errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
+ break;
default:
errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method");
break;
@@ -343,7 +330,7 @@ auth_failed(Port *port, int status, char *logdetail)
* lifetime of the Port, so it is safe to pass a string that is managed by an
* external library.
*/
-static void
+void
set_authn_id(Port *port, const char *id)
{
Assert(id);
@@ -628,6 +615,9 @@ ClientAuthentication(Port *port)
case uaTrust:
status = STATUS_OK;
break;
+ case uaOAuth:
+ status = CheckSASLAuth(&pg_be_oauth_mech, port, NULL, NULL);
+ break;
}
if ((status == STATUS_OK && port->hba->clientcert == clientCertFull)
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 3be8778d21..98147700dd 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -134,7 +134,8 @@ static const char *const UserAuthName[] =
"ldap",
"cert",
"radius",
- "peer"
+ "peer",
+ "oauth",
};
@@ -1399,6 +1400,8 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
#endif
else if (strcmp(token->string, "radius") == 0)
parsedline->auth_method = uaRADIUS;
+ else if (strcmp(token->string, "oauth") == 0)
+ parsedline->auth_method = uaOAuth;
else
{
ereport(elevel,
@@ -1713,8 +1716,9 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->auth_method != uaPeer &&
hbaline->auth_method != uaGSS &&
hbaline->auth_method != uaSSPI &&
+ hbaline->auth_method != uaOAuth &&
hbaline->auth_method != uaCert)
- INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, and cert"));
+ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, oauth, and cert"));
hbaline->usermap = pstrdup(val);
}
else if (strcmp(name, "clientcert") == 0)
@@ -2098,6 +2102,27 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->radiusidentifiers = parsed_identifiers;
hbaline->radiusidentifiers_s = pstrdup(val);
}
+ else if (strcmp(name, "issuer") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("issuer", gettext_noop("oauth"));
+ hbaline->oauth_issuer = pstrdup(val);
+ }
+ else if (strcmp(name, "scope") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("scope", gettext_noop("oauth"));
+ hbaline->oauth_scope = pstrdup(val);
+ }
+ else if (strcmp(name, "trust_validator_authz") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("trust_validator_authz", gettext_noop("oauth"));
+ if (strcmp(val, "1") == 0)
+ hbaline->oauth_skip_usermap = true;
+ else
+ hbaline->oauth_skip_usermap = false;
+ }
else
{
ereport(elevel,
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index 467b0fd6fe..2b42862f71 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -56,6 +56,7 @@
#include "libpq/auth.h"
#include "libpq/libpq.h"
#include "libpq/pqformat.h"
+#include "libpq/oauth.h"
#include "miscadmin.h"
#include "optimizer/cost.h"
#include "optimizer/geqo.h"
@@ -4594,6 +4595,17 @@ static struct config_string ConfigureNamesString[] =
check_backtrace_functions, assign_backtrace_functions, NULL
},
+ {
+ {"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH,
+ gettext_noop("Command to validate OAuth v2 bearer tokens."),
+ NULL,
+ GUC_SUPERUSER_ONLY
+ },
+ &oauth_validator_command,
+ "",
+ NULL, NULL, NULL
+ },
+
/* End-of-list marker */
{
{NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 3d6734f253..1c77dcb0c1 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -16,6 +16,22 @@
#include "libpq/libpq-be.h"
+/*
+ * Maximum accepted size of GSS and SSPI authentication tokens.
+ * We also use this as a limit on ordinary password packet lengths.
+ *
+ * Kerberos tickets are usually quite small, but the TGTs issued by Windows
+ * domain controllers include an authorization field known as the Privilege
+ * Attribute Certificate (PAC), which contains the user's Windows permissions
+ * (group memberships etc.). The PAC is copied into all tickets obtained on
+ * the basis of this TGT (even those issued by Unix realms which the Windows
+ * realm trusts), and can be several kB in size. The maximum token size
+ * accepted by Windows systems is determined by the MaxAuthToken Windows
+ * registry setting. Microsoft recommends that it is not set higher than
+ * 65535 bytes, so that seems like a reasonable limit for us as well.
+ */
+#define PG_MAX_AUTH_TOKEN_LENGTH 65535
+
extern char *pg_krb_server_keyfile;
extern bool pg_krb_caseins_users;
extern char *pg_krb_realm;
@@ -23,6 +39,7 @@ extern char *pg_krb_realm;
extern void ClientAuthentication(Port *port);
extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
int extralen);
+extern void set_authn_id(Port *port, const char *id);
/* Hook for plugins to get control in ClientAuthentication() */
typedef void (*ClientAuthentication_hook_type) (Port *, int);
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 8d9f3821b1..441dd5623e 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -38,8 +38,9 @@ typedef enum UserAuth
uaLDAP,
uaCert,
uaRADIUS,
- uaPeer
-#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */
+ uaPeer,
+ uaOAuth
+#define USER_AUTH_LAST uaOAuth /* Must be last value of this enum */
} UserAuth;
/*
@@ -120,6 +121,9 @@ typedef struct HbaLine
char *radiusidentifiers_s;
List *radiusports;
char *radiusports_s;
+ char *oauth_issuer;
+ char *oauth_scope;
+ bool oauth_skip_usermap;
} HbaLine;
typedef struct IdentLine
diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h
new file mode 100644
index 0000000000..870e426af1
--- /dev/null
+++ b/src/include/libpq/oauth.h
@@ -0,0 +1,24 @@
+/*-------------------------------------------------------------------------
+ *
+ * oauth.h
+ * Interface to libpq/auth-oauth.c
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/libpq/oauth.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef PG_OAUTH_H
+#define PG_OAUTH_H
+
+#include "libpq/libpq-be.h"
+#include "libpq/sasl.h"
+
+extern char *oauth_validator_command;
+
+/* Implementation */
+extern const pg_be_sasl_mech pg_be_oauth_mech;
+
+#endif /* PG_OAUTH_H */
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
index 4c611bab6b..c0a88430d5 100644
--- a/src/include/libpq/sasl.h
+++ b/src/include/libpq/sasl.h
@@ -26,6 +26,14 @@
#define PG_SASL_EXCHANGE_SUCCESS 1
#define PG_SASL_EXCHANGE_FAILURE 2
+/*
+ * Maximum accepted size of SASL messages.
+ *
+ * The messages that the server or libpq generate are much smaller than this,
+ * but this limit leaves some headroom.
+ */
+#define PG_MAX_SASL_MESSAGE_LENGTH 1024
+
/*
* Backend SASL mechanism callbacks.
*
@@ -127,6 +135,9 @@ typedef struct pg_be_sasl_mech
const char *input, int inputlen,
char **output, int *outputlen,
char **logdetail);
+
+ /* The maximum size allowed for client SASLResponses. */
+ int max_message_length;
} pg_be_sasl_mech;
/* Common implementation for auth.c */
--
2.25.1
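For context, the OAUTHBEARER initial client response that the server-side mechanism in the patch above has to parse is a small GS2-framed key/value message (RFC 7628, Sec. 3.1). A rough sketch of the framing; the helper names are illustrative, not part of the patch:

```python
def build_initial_response(token: str) -> bytes:
    # GS2 header: no channel binding ("n"), empty authzid; then a single
    # "auth" key/value pair, terminated by a double 0x01 separator.
    gs2_header = b"n,,"
    auth = b"auth=Bearer " + token.encode("ascii")
    return gs2_header + b"\x01" + auth + b"\x01\x01"

def parse_auth_value(initial: bytes) -> bytes:
    kvpairs = initial.split(b"\x01")
    assert kvpairs[0] == b"n,,"         # framing header
    assert kvpairs[-2:] == [b"", b""]   # double-kvsep terminator
    key, value = kvpairs[1].split(b"=", 1)
    assert key == b"auth"
    return value                        # e.g. b"Bearer <token>"
```

Note that the whole exchange fits comfortably under the new PG_MAX_SASL_MESSAGE_LENGTH for typical tokens.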
Attachment: v2-0004-Add-a-very-simple-authn_id-extension.patch (text/x-patch)
From 7c4175f9ad87141d40dd44d6c9fe9312ce8e5b88 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 18 May 2021 15:01:29 -0700
Subject: [PATCH v2 4/5] Add a very simple authn_id extension
...for retrieving the authn_id from the server in tests.
---
contrib/authn_id/Makefile | 19 +++++++++++++++++++
contrib/authn_id/authn_id--1.0.sql | 8 ++++++++
contrib/authn_id/authn_id.c | 28 ++++++++++++++++++++++++++++
contrib/authn_id/authn_id.control | 5 +++++
4 files changed, 60 insertions(+)
create mode 100644 contrib/authn_id/Makefile
create mode 100644 contrib/authn_id/authn_id--1.0.sql
create mode 100644 contrib/authn_id/authn_id.c
create mode 100644 contrib/authn_id/authn_id.control
diff --git a/contrib/authn_id/Makefile b/contrib/authn_id/Makefile
new file mode 100644
index 0000000000..46026358e0
--- /dev/null
+++ b/contrib/authn_id/Makefile
@@ -0,0 +1,19 @@
+# contrib/authn_id/Makefile
+
+MODULE_big = authn_id
+OBJS = authn_id.o
+
+EXTENSION = authn_id
+DATA = authn_id--1.0.sql
+PGFILEDESC = "authn_id - information about the authenticated user"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = contrib/authn_id
+top_builddir = ../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/contrib/authn_id/authn_id--1.0.sql b/contrib/authn_id/authn_id--1.0.sql
new file mode 100644
index 0000000000..af2a4d3991
--- /dev/null
+++ b/contrib/authn_id/authn_id--1.0.sql
@@ -0,0 +1,8 @@
+/* contrib/authn_id/authn_id--1.0.sql */
+
+-- complain if script is sourced in psql, rather than via CREATE EXTENSION
+\echo Use "CREATE EXTENSION authn_id" to load this file. \quit
+
+CREATE FUNCTION authn_id() RETURNS text
+AS 'MODULE_PATHNAME', 'authn_id'
+LANGUAGE C STABLE PARALLEL RESTRICTED;
diff --git a/contrib/authn_id/authn_id.c b/contrib/authn_id/authn_id.c
new file mode 100644
index 0000000000..0fecac36a8
--- /dev/null
+++ b/contrib/authn_id/authn_id.c
@@ -0,0 +1,28 @@
+/*
+ * Extension to expose the current user's authn_id.
+ *
+ * contrib/authn_id/authn_id.c
+ */
+
+#include "postgres.h"
+
+#include "fmgr.h"
+#include "libpq/libpq-be.h"
+#include "miscadmin.h"
+#include "utils/builtins.h"
+
+PG_MODULE_MAGIC;
+
+PG_FUNCTION_INFO_V1(authn_id);
+
+/*
+ * Returns the current user's authenticated identity.
+ */
+Datum
+authn_id(PG_FUNCTION_ARGS)
+{
+ if (!MyProcPort->authn_id)
+ PG_RETURN_NULL();
+
+ PG_RETURN_TEXT_P(cstring_to_text(MyProcPort->authn_id));
+}
diff --git a/contrib/authn_id/authn_id.control b/contrib/authn_id/authn_id.control
new file mode 100644
index 0000000000..e0f9e06bed
--- /dev/null
+++ b/contrib/authn_id/authn_id.control
@@ -0,0 +1,5 @@
+# authn_id extension
+comment = 'current user identity'
+default_version = '1.0'
+module_pathname = '$libdir/authn_id'
+relocatable = true
--
2.25.1
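The pytest suite in the next patch checks the client's device-flow polling behavior. The interval rule being exercised comes from RFC 8628, Sec. 3.5, and can be sketched as follows; the function name is illustrative, not part of the patch:

```python
def next_poll_interval(interval: int, error: str) -> int:
    if error == "authorization_pending":
        return interval        # keep polling at the same interval
    if error == "slow_down":
        return interval + 5    # client MUST add 5 seconds to its interval
    # any other error code is terminal for the device flow
    raise RuntimeError(f"device authorization failed: {error}")
```

This is why the slow_down test bumps expected_retry_interval by five seconds on each retry, while authorization_pending leaves it alone.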
Attachment: v2-0005-Add-pytest-suite-for-OAuth.patch (text/x-patch)
From 0281635f35a44e0fdfd4369423f98ebe5b467ce3 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Fri, 4 Jun 2021 09:06:38 -0700
Subject: [PATCH v2 5/5] Add pytest suite for OAuth
Requires Python 3; on the first run of `make installcheck` the
dependencies will be installed into ./venv for you. See the README for
more details.
---
src/test/python/.gitignore | 2 +
src/test/python/Makefile | 38 +
src/test/python/README | 54 ++
src/test/python/client/__init__.py | 0
src/test/python/client/conftest.py | 126 +++
src/test/python/client/test_client.py | 180 ++++
src/test/python/client/test_oauth.py | 936 ++++++++++++++++++
src/test/python/pq3.py | 727 ++++++++++++++
src/test/python/pytest.ini | 4 +
src/test/python/requirements.txt | 7 +
src/test/python/server/__init__.py | 0
src/test/python/server/conftest.py | 45 +
src/test/python/server/test_oauth.py | 1012 ++++++++++++++++++++
src/test/python/server/test_server.py | 21 +
src/test/python/server/validate_bearer.py | 101 ++
src/test/python/server/validate_reflect.py | 34 +
src/test/python/test_internals.py | 138 +++
src/test/python/test_pq3.py | 558 +++++++++++
src/test/python/tls.py | 195 ++++
19 files changed, 4178 insertions(+)
create mode 100644 src/test/python/.gitignore
create mode 100644 src/test/python/Makefile
create mode 100644 src/test/python/README
create mode 100644 src/test/python/client/__init__.py
create mode 100644 src/test/python/client/conftest.py
create mode 100644 src/test/python/client/test_client.py
create mode 100644 src/test/python/client/test_oauth.py
create mode 100644 src/test/python/pq3.py
create mode 100644 src/test/python/pytest.ini
create mode 100644 src/test/python/requirements.txt
create mode 100644 src/test/python/server/__init__.py
create mode 100644 src/test/python/server/conftest.py
create mode 100644 src/test/python/server/test_oauth.py
create mode 100644 src/test/python/server/test_server.py
create mode 100755 src/test/python/server/validate_bearer.py
create mode 100755 src/test/python/server/validate_reflect.py
create mode 100644 src/test/python/test_internals.py
create mode 100644 src/test/python/test_pq3.py
create mode 100644 src/test/python/tls.py
diff --git a/src/test/python/.gitignore b/src/test/python/.gitignore
new file mode 100644
index 0000000000..0e8f027b2e
--- /dev/null
+++ b/src/test/python/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+/venv/
diff --git a/src/test/python/Makefile b/src/test/python/Makefile
new file mode 100644
index 0000000000..b0695b6287
--- /dev/null
+++ b/src/test/python/Makefile
@@ -0,0 +1,38 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+# Only Python 3 is supported, but if it's named something different on your
+# system you can override it with the PYTHON3 variable.
+PYTHON3 := python3
+
+# All dependencies are placed into this directory. The default is .gitignored
+# for you, but you can override it if you'd like.
+VENV := ./venv
+
+override VBIN := $(VENV)/bin
+override PIP := $(VBIN)/pip
+override PYTEST := $(VBIN)/py.test
+override ISORT := $(VBIN)/isort
+override BLACK := $(VBIN)/black
+
+.PHONY: installcheck indent
+
+installcheck: $(PYTEST)
+ $(PYTEST) -v -rs
+
+indent: $(ISORT) $(BLACK)
+ $(ISORT) --profile black *.py client/*.py server/*.py
+ $(BLACK) *.py client/*.py server/*.py
+
+$(PYTEST) $(ISORT) $(BLACK) &: requirements.txt | $(PIP)
+ $(PIP) install --force-reinstall -r $<
+
+$(PIP):
+ $(PYTHON3) -m venv $(VENV)
+
+# A convenience recipe to rebuild psycopg2 against the local libpq.
+.PHONY: rebuild-psycopg2
+rebuild-psycopg2: | $(PIP)
+ $(PIP) install --force-reinstall --no-binary :all: $(shell grep psycopg2 requirements.txt)
diff --git a/src/test/python/README b/src/test/python/README
new file mode 100644
index 0000000000..0bda582c4b
--- /dev/null
+++ b/src/test/python/README
@@ -0,0 +1,54 @@
+A test suite for exercising both the libpq client and the server backend at the
+protocol level, based on pytest and Construct.
+
+The test suite currently assumes that the standard PG* environment variables
+point to the database under test and are sufficient to log in as a superuser on
+that system. In other words, a bare `psql` needs to Just Work before the test
+suite can do its thing. For a newly built dev cluster, typically all I need to
+do is
+
+ export PGDATABASE=postgres
+
+but you can adjust as needed for your setup.
+
+## Requirements
+
+A supported version (3.6+) of Python.
+
+The first run of
+
+ make installcheck
+
+will create a local virtual environment and install all needed dependencies. During
+development, if libpq changes incompatibly, you can issue
+
+ $ make rebuild-psycopg2
+
+to force a rebuild of the client library.
+
+## Hacking
+
+The code style is enforced by a _very_ opinionated autoformatter. Running the
+
+ make indent
+
+recipe will invoke it for you automatically. Don't fight the tool; part of the
+zen is in knowing that if the formatter makes your code ugly, there's probably a
+cleaner way to write your code.
+
+## Advanced Usage
+
+The Makefile is there for convenience, but you don't have to use it. Activate
+the virtualenv to be able to use pytest directly:
+
+ $ source venv/bin/activate
+ $ py.test -k oauth
+ ...
+ $ py.test ./server/test_server.py
+ ...
+ $ deactivate # puts the PATH et al back the way it was before
+
+To make quick smoke tests possible, slow tests have been marked explicitly. You
+can skip them by saying e.g.
+
+ $ py.test -m 'not slow'
diff --git a/src/test/python/client/__init__.py b/src/test/python/client/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/client/conftest.py b/src/test/python/client/conftest.py
new file mode 100644
index 0000000000..f38da7a138
--- /dev/null
+++ b/src/test/python/client/conftest.py
@@ -0,0 +1,126 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import socket
+import sys
+import threading
+
+import psycopg2
+import pytest
+
+import pq3
+
+BLOCKING_TIMEOUT = 2 # the number of seconds to wait for blocking calls
+
+
+@pytest.fixture
+def server_socket(unused_tcp_port_factory):
+ """
+ Returns a listening socket bound to an ephemeral port.
+ """
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("127.0.0.1", unused_tcp_port_factory()))
+ s.listen(1)
+ s.settimeout(BLOCKING_TIMEOUT)
+ yield s
+
+
+class ClientHandshake(threading.Thread):
+ """
+ A thread that connects to a local Postgres server using psycopg2. Once the
+ opening handshake completes, the connection will be immediately closed.
+ """
+
+ def __init__(self, *, port, **kwargs):
+ super().__init__()
+
+ kwargs["port"] = port
+ self._kwargs = kwargs
+
+ self.exception = None
+
+ def run(self):
+ try:
+ conn = psycopg2.connect(host="127.0.0.1", **self._kwargs)
+ conn.close()
+ except Exception as e:
+ self.exception = e
+
+ def check_completed(self, timeout=BLOCKING_TIMEOUT):
+ """
+ Joins the client thread. Raises an exception if the thread could not be
+ joined, or if it threw an exception itself. (The exception will be
+ cleared, so future calls to check_completed will succeed.)
+ """
+ self.join(timeout)
+
+ if self.is_alive():
+ raise TimeoutError("client thread did not handshake within the timeout")
+ elif self.exception:
+ e = self.exception
+ self.exception = None
+ raise e
+
+
+@pytest.fixture
+def accept(server_socket):
+ """
+ Returns a factory function that, when called, returns a pair (sock, client)
+ where sock is a server socket that has accepted a connection from client,
+ and client is an instance of ClientHandshake. Clients will complete their
+ handshakes and cleanly disconnect.
+
+ The default connstring options may be extended or overridden by passing
+ arbitrary keyword arguments. Keep in mind that you generally should not
+ override the host or port, since they point to the local test server.
+
+ For situations where a client needs to connect more than once to complete a
+ handshake, the accept function may be called more than once. (The client
+ returned for subsequent calls will always be the same client that was
+ returned for the first call.)
+
+ Tests must either complete the handshake so that the client thread can be
+ automatically joined during teardown, or else call client.check_completed()
+ and manually handle any expected errors.
+ """
+ _, port = server_socket.getsockname()
+
+ client = None
+ default_opts = dict(
+ port=port,
+ user=pq3.pguser(),
+ sslmode="disable",
+ )
+
+ def factory(**kwargs):
+ nonlocal client
+
+ if client is None:
+ opts = dict(default_opts)
+ opts.update(kwargs)
+
+ # The server_socket is already listening, so the client thread can
+ # be safely started; it'll block on the connection until we accept.
+ client = ClientHandshake(**opts)
+ client.start()
+
+ sock, _ = server_socket.accept()
+ return sock, client
+
+ yield factory
+ client.check_completed()
+
+
+@pytest.fixture
+def conn(accept):
+ """
+ Returns an accepted, wrapped pq3 connection to a psycopg2 client. The socket
+ will be closed when the test finishes, and the client will be checked for a
+ cleanly completed handshake.
+ """
+ sock, client = accept()
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ yield conn
diff --git a/src/test/python/client/test_client.py b/src/test/python/client/test_client.py
new file mode 100644
index 0000000000..c4c946fda4
--- /dev/null
+++ b/src/test/python/client/test_client.py
@@ -0,0 +1,180 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import sys
+
+import psycopg2
+import pytest
+from cryptography.hazmat.primitives import hashes, hmac
+
+import pq3
+
+
+def finish_handshake(conn):
+ """
+ Sends the AuthenticationOK message and the standard opening salvo of server
+ messages, then asserts that the client immediately sends a Terminate message
+ to close the connection cleanly.
+ """
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+ pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.Terminate
+
+
+def test_handshake(conn):
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ finish_handshake(conn)
+
+
+def test_aborted_connection(accept):
+ """
+ Make sure the client correctly reports an early close during handshakes.
+ """
+ sock, client = accept()
+ sock.close()
+
+ expected = "server closed the connection unexpectedly"
+ with pytest.raises(psycopg2.OperationalError, match=expected):
+ client.check_completed()
+
+
+#
+# SCRAM-SHA-256 (see RFC 5802: https://tools.ietf.org/html/rfc5802)
+#
+
+
+@pytest.fixture
+def password():
+ """
+ Returns a password for use by both client and server.
+ """
+ # TODO: parameterize this with passwords that require SASLprep.
+ return "secret"
+
+
+@pytest.fixture
+def pwconn(accept, password):
+ """
+ Like the conn fixture, but uses a password in the connection.
+ """
+ sock, client = accept(password=password)
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ yield conn
+
+
+def sha256(data):
+ """The H(str) function from Section 2.2."""
+ digest = hashes.Hash(hashes.SHA256())
+ digest.update(data)
+ return digest.finalize()
+
+
+def hmac_256(key, data):
+ """The HMAC(key, str) function from Section 2.2."""
+ h = hmac.HMAC(key, hashes.SHA256())
+ h.update(data)
+ return h.finalize()
+
+
+def xor(a, b):
+ """The XOR operation from Section 2.2."""
+ res = bytearray(a)
+ for i, byte in enumerate(b):
+ res[i] ^= byte
+ return bytes(res)
+
+
+def h_i(data, salt, i):
+ """The Hi(str, salt, i) function from Section 2.2."""
+ assert i > 0
+
+ acc = hmac_256(data, salt + b"\x00\x00\x00\x01")
+ last = acc
+ i -= 1
+
+ while i:
+ u = hmac_256(data, last)
+ acc = xor(acc, u)
+
+ last = u
+ i -= 1
+
+ return acc
+
+
+def test_scram(pwconn, password):
+ startup = pq3.recv1(pwconn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ pwconn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"SCRAM-SHA-256", b""],
+ )
+
+ # Get the client-first-message.
+ pkt = pq3.recv1(pwconn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ initial = pq3.SASLInitialResponse.parse(pkt.payload)
+ assert initial.name == b"SCRAM-SHA-256"
+
+ c_bind, authzid, c_name, c_nonce = initial.data.split(b",")
+ assert c_bind == b"n" # no channel bindings on a plaintext connection
+ assert authzid == b"" # we don't support authzid currently
+ assert c_name == b"n=" # libpq doesn't honor the GS2 username
+ assert c_nonce.startswith(b"r=")
+
+ # Send the server-first-message.
+ salt = b"12345"
+ iterations = 2
+
+ s_nonce = c_nonce + b"somenonce"
+ s_salt = b"s=" + base64.b64encode(salt)
+ s_iterations = b"i=%d" % iterations
+
+ msg = b",".join([s_nonce, s_salt, s_iterations])
+ pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=msg)
+
+ # Get the client-final-message.
+ pkt = pq3.recv1(pwconn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ c_bind_final, c_nonce_final, c_proof = pkt.payload.split(b",")
+ assert c_bind_final == b"c=" + base64.b64encode(c_bind + b"," + authzid + b",")
+ assert c_nonce_final == s_nonce
+
+ # Calculate what the client proof should be.
+ salted_password = h_i(password.encode("ascii"), salt, iterations)
+ client_key = hmac_256(salted_password, b"Client Key")
+ stored_key = sha256(client_key)
+
+ auth_message = b",".join(
+ [c_name, c_nonce, s_nonce, s_salt, s_iterations, c_bind_final, c_nonce_final]
+ )
+ client_signature = hmac_256(stored_key, auth_message)
+ client_proof = xor(client_key, client_signature)
+
+ expected = b"p=" + base64.b64encode(client_proof)
+ assert c_proof == expected
+
+ # Send the correct server signature.
+ server_key = hmac_256(salted_password, b"Server Key")
+ server_signature = hmac_256(server_key, auth_message)
+
+ s_verify = b"v=" + base64.b64encode(server_signature)
+ pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal, body=s_verify)
+
+ # Done!
+ finish_handshake(pwconn)
diff --git a/src/test/python/client/test_oauth.py b/src/test/python/client/test_oauth.py
new file mode 100644
index 0000000000..a754a9c0b6
--- /dev/null
+++ b/src/test/python/client/test_oauth.py
@@ -0,0 +1,936 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import http.server
+import json
+import secrets
+import sys
+import threading
+import time
+import urllib.parse
+
+import psycopg2
+import pytest
+
+import pq3
+
+from .conftest import BLOCKING_TIMEOUT
+
+
+def finish_handshake(conn):
+ """
+ Sends the AuthenticationOK message and the standard opening salvo of server
+ messages, then asserts that the client immediately sends a Terminate message
+ to close the connection cleanly.
+ """
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+ pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.Terminate
+
+
+#
+# OAUTHBEARER (see RFC 7628: https://tools.ietf.org/html/rfc7628)
+#
+
+
+def start_oauth_handshake(conn):
+ """
+ Negotiates an OAUTHBEARER SASL challenge. Returns the client's initial
+ response data.
+ """
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn, pq3.types.AuthnRequest, type=pq3.authn.SASL, body=[b"OAUTHBEARER", b""]
+ )
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ initial = pq3.SASLInitialResponse.parse(pkt.payload)
+ assert initial.name == b"OAUTHBEARER"
+
+ return initial.data
+
+
+def get_auth_value(initial):
+ """
+ Finds the auth value (e.g. "Bearer somedata...") in the client's initial SASL
+ response.
+ """
+ kvpairs = initial.split(b"\x01")
+ assert kvpairs[0] == b"n,," # no channel binding or authzid
+ assert kvpairs[2] == b"" # ends with an empty kvpair
+ assert kvpairs[3] == b"" # ...and there's nothing after it
+ assert len(kvpairs) == 4
+
+ key, value = kvpairs[1].split(b"=", 1)
+ assert key == b"auth"
+
+ return value
+
+
+def xtest_oauth_success(conn): # TODO
+ initial = start_oauth_handshake(conn)
+
+ auth = get_auth_value(initial)
+ assert auth.startswith(b"Bearer ")
+
+ # Accept the token. TODO actually validate
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+class OpenIDProvider(threading.Thread):
+ """
+ A thread that runs a mock OpenID provider server.
+ """
+
+ def __init__(self, *, port):
+ super().__init__()
+
+ self.exception = None
+
+ addr = ("", port)
+ self.server = self._Server(addr, self._Handler)
+
+ # TODO: allow HTTPS only, somehow
+ oauth = self._OAuthState()
+ oauth.host = f"localhost:{port}"
+ oauth.issuer = f"http://localhost:{port}"
+
+ # The following endpoints are required to be advertised by providers,
+ # even though our chosen client implementation does not actually make
+ # use of them.
+ oauth.register_endpoint(
+ "authorization_endpoint", "POST", "/authorize", self._authorization_handler
+ )
+ oauth.register_endpoint("jwks_uri", "GET", "/keys", self._jwks_handler)
+
+ self.server.oauth = oauth
+
+ def run(self):
+ try:
+ self.server.serve_forever()
+ except Exception as e:
+ self.exception = e
+
+ def stop(self, timeout=BLOCKING_TIMEOUT):
+ """
+ Shuts down the server and joins its thread. Raises an exception if the
+ thread could not be joined, or if it threw an exception itself. Must
+ only be called once, after start().
+ """
+ self.server.shutdown()
+ self.join(timeout)
+
+ if self.is_alive():
+ raise TimeoutError("provider thread did not shut down within the timeout")
+ elif self.exception:
+ e = self.exception
+ raise e
+
+ class _OAuthState(object):
+ def __init__(self):
+ self.endpoint_paths = {}
+ self._endpoints = {}
+
+ def register_endpoint(self, name, method, path, func):
+ if method not in self._endpoints:
+ self._endpoints[method] = {}
+
+ self._endpoints[method][path] = func
+ self.endpoint_paths[name] = path
+
+ def endpoint(self, method, path):
+ if method not in self._endpoints:
+ return None
+
+ return self._endpoints[method].get(path)
+
+ class _Server(http.server.HTTPServer):
+ def handle_error(self, request, addr):
+ self.shutdown_request(request)
+ raise
+
+ @staticmethod
+ def _jwks_handler(headers, params):
+ return 200, {"keys": []}
+
+ @staticmethod
+ def _authorization_handler(headers, params):
+ # We don't actually want this to be called during these tests -- we
+ # should be using the device authorization endpoint instead.
+ assert (
+ False
+ ), "authorization handler called instead of device authorization handler"
+
+ class _Handler(http.server.BaseHTTPRequestHandler):
+ timeout = BLOCKING_TIMEOUT
+
+ def _discovery_handler(self, headers, params):
+ oauth = self.server.oauth
+
+ doc = {
+ "issuer": oauth.issuer,
+ "response_types_supported": ["token"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+
+ for name, path in oauth.endpoint_paths.items():
+ doc[name] = oauth.issuer + path
+
+ return 200, doc
+
+ def _handle(self, *, params=None, handler=None):
+ oauth = self.server.oauth
+ assert self.headers["Host"] == oauth.host
+
+ if handler is None:
+ handler = oauth.endpoint(self.command, self.path)
+ assert (
+ handler is not None
+ ), f"no registered endpoint for {self.command} {self.path}"
+
+ code, resp = handler(self.headers, params)
+
+ self.send_response(code)
+ self.send_header("Content-Type", "application/json")
+ self.end_headers()
+
+ resp = json.dumps(resp)
+ resp = resp.encode("utf-8")
+ self.wfile.write(resp)
+
+ self.close_connection = True
+
+ def do_GET(self):
+ if self.path == "/.well-known/openid-configuration":
+ self._handle(handler=self._discovery_handler)
+ return
+
+ self._handle()
+
+ def _request_body(self):
+ length = self.headers["Content-Length"]
+
+ # Handle only an explicit content-length.
+ assert length is not None
+ length = int(length)
+
+ return self.rfile.read(length).decode("utf-8")
+
+ def do_POST(self):
+ assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
+
+ body = self._request_body()
+ params = urllib.parse.parse_qs(body)
+
+ self._handle(params=params)
+
+
+@pytest.fixture
+def openid_provider(unused_tcp_port_factory):
+ """
+ A fixture that returns the OAuth state of a running OpenID provider server. The
+ server will be stopped when the fixture is torn down.
+ """
+ thread = OpenIDProvider(port=unused_tcp_port_factory())
+ thread.start()
+
+ try:
+ yield thread.server.oauth
+ finally:
+ thread.stop()
+
+
+@pytest.mark.parametrize("secret", [None, "", "hunter2"])
+@pytest.mark.parametrize("scope", [None, "", "openid email"])
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_with_explicit_issuer(
+ capfd, accept, openid_provider, retries, scope, secret
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ oauth_client_secret=secret,
+ oauth_scope=scope,
+ )
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+ verification_url = "https://example.com/device"
+
+ access_token = secrets.token_urlsafe()
+
+ def check_client_authn(headers, params):
+ if not secret:
+ assert params["client_id"] == [client_id]
+ return
+
+ # Require the client to use Basic authn; request-body credentials are
+ # NOT RECOMMENDED (RFC 6749, Sec. 2.3.1).
+ assert "Authorization" in headers
+
+ method, creds = headers["Authorization"].split()
+ assert method == "Basic"
+
+ expected = f"{client_id}:{secret}"
+ assert base64.b64decode(creds) == expected.encode("ascii")
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ check_client_authn(headers, params)
+
+ if scope:
+ assert params["scope"] == [scope]
+ else:
+ assert "scope" not in params
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": verification_url,
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ attempts = 0
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ check_client_authn(headers, params)
+
+ assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+ assert params["device_code"] == [device_code]
+
+ now = time.monotonic()
+
+ with retry_lock:
+ nonlocal attempts
+
+ # If the test wants to force the client to retry, return an
+ # authorization_pending response and decrement the retry count.
+ if attempts < retries:
+ attempts += 1
+ return 400, {"error": "authorization_pending"}
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+ if retries:
+ # Finally, make sure that the client prompted the user with the expected
+ # authorization URL and user code.
+ expected = f"Visit {verification_url} and enter the code: {user_code}"
+ _, stderr = capfd.readouterr()
+ assert expected in stderr
+
+
+def test_oauth_requires_client_id(accept, openid_provider):
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ # Do not set a client ID; this should cause a client error after the
+ # server asks for OAUTHBEARER and the client tries to contact the
+ # issuer.
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should disconnect at this point.
+ assert not conn.read()
+
+ expected_error = "no oauth_client_id is set"
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("error_code", ["authorization_pending", "slow_down"])
+@pytest.mark.parametrize("retries", [1, 2])
+def test_oauth_retry_interval(accept, openid_provider, retries, error_code):
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id="some-id",
+ )
+
+ expected_retry_interval = 1
+ access_token = secrets.token_urlsafe()
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ resp = {
+ "device_code": "my-device-code",
+ "user_code": "my-user-code",
+ "interval": expected_retry_interval,
+ "verification_uri": "https://example.com",
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ attempts = 0
+ last_retry = None
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ now = time.monotonic()
+
+ with retry_lock:
+ nonlocal attempts, last_retry, expected_retry_interval
+
+ # Make sure the retry interval is being respected by the client.
+ if last_retry is not None:
+ interval = now - last_retry
+ assert interval >= expected_retry_interval
+
+ last_retry = now
+
+ # If the test wants to force the client to retry, return the desired
+ # error response and decrement the retry count.
+ if attempts < retries:
+ attempts += 1
+
+ # A slow_down code requires the client to additionally increase
+ # its interval by five seconds.
+ if error_code == "slow_down":
+ expected_retry_interval += 5
+
+ return 400, {"error": error_code}
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+@pytest.mark.parametrize(
+ "failure_mode, error_pattern",
+ [
+ pytest.param(
+ {
+ "error": "invalid_client",
+ "error_description": "client authentication failed",
+ },
+ r"client authentication failed \(invalid_client\)",
+ id="authentication failure with description",
+ ),
+ pytest.param(
+ {"error": "invalid_request"},
+ r"\(invalid_request\)",
+ id="invalid request without description",
+ ),
+ pytest.param(
+ {},
+ r"failed to obtain device authorization",
+ id="broken error response",
+ ),
+ ],
+)
+def test_oauth_device_authorization_failures(
+ accept, openid_provider, failure_mode, error_pattern
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ )
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ return 400, failure_mode
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ def token_endpoint(headers, params):
+ assert False, "token endpoint was invoked unexpectedly"
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should not continue the connection due to the hardcoded
+ # provider failure; we disconnect here.
+
+ # Now make sure the client correctly failed.
+ with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+ client.check_completed()
+
+
+@pytest.mark.parametrize(
+ "failure_mode, error_pattern",
+ [
+ pytest.param(
+ {
+ "error": "expired_token",
+ "error_description": "the device code has expired",
+ },
+ r"the device code has expired \(expired_token\)",
+ id="expired token with description",
+ ),
+ pytest.param(
+ {"error": "access_denied"},
+ r"\(access_denied\)",
+ id="access denied without description",
+ ),
+ pytest.param(
+ {},
+ r"OAuth token retrieval failed",
+ id="broken error response",
+ ),
+ ],
+)
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_token_failures(
+ accept, openid_provider, retries, failure_mode, error_pattern
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ )
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ assert params["client_id"] == [client_id]
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": "https://example.com/device",
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ with retry_lock:
+ nonlocal retries
+
+ # If the test wants to force the client to retry, return an
+ # authorization_pending response and decrement the retry count.
+ if retries > 0:
+ retries -= 1
+ return 400, {"error": "authorization_pending"}
+
+ return 400, failure_mode
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should not continue the connection due to the hardcoded
+ # provider failure; we disconnect here.
+
+ # Now make sure the client correctly failed.
+ with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+ client.check_completed()
+
+
+@pytest.mark.parametrize("scope", [None, "openid email"])
+@pytest.mark.parametrize(
+ "base_response",
+ [
+ {"status": "invalid_token"},
+ {"extra_object": {"key": "value"}, "status": "invalid_token"},
+ {"extra_object": {"status": 1}, "status": "invalid_token"},
+ ],
+)
+def test_oauth_discovery(accept, openid_provider, base_response, scope):
+ sock, client = accept(oauth_client_id=secrets.token_hex())
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+ verification_url = "https://example.com/device"
+
+ access_token = secrets.token_urlsafe()
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ if scope:
+ assert params["scope"] == [scope]
+ else:
+ assert "scope" not in params
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": verification_url,
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ def token_endpoint(headers, params):
+ assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+ assert params["device_code"] == [device_code]
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # For discovery, the client should send an empty auth header. See
+ # RFC 7628, Sec. 4.3.
+ auth = get_auth_value(initial)
+ assert auth == b""
+
+ # We will fail the first SASL exchange. First return a link to the
+ # discovery document, pointing to the test provider server.
+ resp = dict(base_response)
+
+ discovery_uri = f"{openid_provider.issuer}/.well-known/openid-configuration"
+ resp["openid-configuration"] = discovery_uri
+
+ if scope:
+ resp["scope"] = scope
+
+ resp = json.dumps(resp)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASLContinue,
+ body=resp.encode("ascii"),
+ )
+
+ # Per RFC, the client is required to send a dummy ^A response.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+ assert pkt.payload == b"\x01"
+
+ # Now fail the SASL exchange.
+ pq3.send(
+ conn,
+ pq3.types.ErrorResponse,
+ fields=[
+ b"SFATAL",
+ b"C28000",
+ b"Mdoesn't matter",
+ b"",
+ ],
+ )
+
+ # The client will connect to us a second time, using the parameters we sent
+ # it.
+ sock, _ = accept()
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+@pytest.mark.parametrize(
+ "response,expected_error",
+ [
+ pytest.param(
+ "abcde",
+ 'Token "abcde" is invalid',
+ id="bad JSON: invalid syntax",
+ ),
+ pytest.param(
+ '"abcde"',
+ "top-level element must be an object",
+ id="bad JSON: top-level element is a string",
+ ),
+ pytest.param(
+ "[]",
+ "top-level element must be an object",
+ id="bad JSON: top-level element is an array",
+ ),
+ pytest.param(
+ "{}",
+ "server sent error response without a status",
+ id="bad JSON: no status member",
+ ),
+ pytest.param(
+ '{ "status": null }',
+ 'field "status" must be a string',
+ id="bad JSON: null status member",
+ ),
+ pytest.param(
+ '{ "status": 0 }',
+ 'field "status" must be a string',
+ id="bad JSON: int status member",
+ ),
+ pytest.param(
+ '{ "status": [ "bad" ] }',
+ 'field "status" must be a string',
+ id="bad JSON: array status member",
+ ),
+ pytest.param(
+ '{ "status": { "bad": "bad" } }',
+ 'field "status" must be a string',
+ id="bad JSON: object status member",
+ ),
+ pytest.param(
+ '{ "nested": { "status": "bad" } }',
+ "server sent error response without a status",
+ id="bad JSON: nested status",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token" ',
+ "The input string ended unexpectedly",
+ id="bad JSON: unterminated object",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token" } { }',
+ 'Expected end of input, but found "{"',
+ id="bad JSON: trailing data",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token", "openid-configuration": 1 }',
+ 'field "openid-configuration" must be a string',
+ id="bad JSON: int openid-configuration member",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token", "scope": 1 }',
+ 'field "scope" must be a string',
+ id="bad JSON: int scope member",
+ ),
+ ],
+)
+def test_oauth_discovery_server_error(accept, response, expected_error):
+ sock, client = accept(oauth_client_id=secrets.token_hex())
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # Fail the SASL exchange with an invalid JSON response.
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASLContinue,
+ body=response.encode("utf-8"),
+ )
+
+ # The client should disconnect, so the socket is closed here. (If
+ # the client doesn't disconnect, it will report a different error
+ # below and the test will fail.)
+
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
+
+
+@pytest.mark.parametrize(
+ "sasl_err,resp_type,resp_payload,expected_error",
+ [
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.ErrorResponse,
+ dict(
+ fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+ ),
+ "expected error message",
+ id="standard server error: invalid_request",
+ ),
+ pytest.param(
+ {"status": "invalid_token"},
+ pq3.types.ErrorResponse,
+ dict(
+ fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+ ),
+ "expected error message",
+ id="standard server error: invalid_token without discovery URI",
+ ),
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.AuthnRequest,
+ dict(type=pq3.authn.SASLContinue, body=b""),
+ "server sent additional OAuth data",
+ id="broken server: additional challenge after error",
+ ),
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.AuthnRequest,
+ dict(type=pq3.authn.SASLFinal),
+ "server sent additional OAuth data",
+ id="broken server: SASL success after error",
+ ),
+ ],
+)
+def test_oauth_server_error(accept, sasl_err, resp_type, resp_payload, expected_error):
+ sock, client = accept()
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ start_oauth_handshake(conn)
+
+ # Ignore the client data. Return an error "challenge".
+ resp = json.dumps(sasl_err)
+ resp = resp.encode("utf-8")
+
+ pq3.send(
+ conn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=resp
+ )
+
+ # Per RFC, the client is required to send a dummy ^A response.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+ assert pkt.payload == b"\x01"
+
+ # Now fail the SASL exchange (in either a valid way, or an invalid
+ # one, depending on the test).
+ pq3.send(conn, resp_type, **resp_payload)
+
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
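As an aside for reviewers, the retry behavior exercised above follows the RFC 8628 device-flow polling rules: `authorization_pending` keeps the current interval, while `slow_down` requires the client to add five seconds before polling again. A rough client-side sketch of that loop (the helper names here are illustrative only, not part of the patch):

```python
import time


def next_retry_interval(current, error_code):
    # Per RFC 8628 Sec. 3.5: slow_down adds 5 seconds to the polling
    # interval; authorization_pending leaves it unchanged.
    if error_code == "slow_down":
        return current + 5
    return current


def poll_for_token(poll_token, interval, max_attempts=10, sleep=time.sleep):
    # poll_token is a hypothetical callable returning (status, json_body),
    # shaped like the token_endpoint stubs in the tests above.
    for _ in range(max_attempts):
        status, body = poll_token()
        if status == 200:
            return body["access_token"]
        error = body.get("error")
        if error not in ("authorization_pending", "slow_down"):
            raise RuntimeError(
                body.get("error_description") or error or "token request failed"
            )
        interval = next_retry_interval(interval, error)
        sleep(interval)
    raise TimeoutError("device authorization timed out")
```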
diff --git a/src/test/python/pq3.py b/src/test/python/pq3.py
new file mode 100644
index 0000000000..3a22dad0b6
--- /dev/null
+++ b/src/test/python/pq3.py
@@ -0,0 +1,727 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import os
+import ssl
+import sys
+import textwrap
+
+from construct import *
+
+import tls
+
+
+def protocol(major, minor):
+ """
+ Returns the protocol version, in integer format, corresponding to the given
+ major and minor version numbers.
+ """
+ return (major << 16) | minor
+
+
+# Startup
+
+StringList = GreedyRange(NullTerminated(GreedyBytes))
+
+
+class KeyValueAdapter(Adapter):
+ """
+ Turns a key-value store into a null-terminated list of null-terminated
+ strings, as presented on the wire in the startup packet.
+ """
+
+ def _encode(self, obj, context, path):
+ if isinstance(obj, list):
+ return obj
+
+ l = []
+
+ for k, v in obj.items():
+ if isinstance(k, str):
+ k = k.encode("utf-8")
+ l.append(k)
+
+ if isinstance(v, str):
+ v = v.encode("utf-8")
+ l.append(v)
+
+ l.append(b"")
+ return l
+
+ def _decode(self, obj, context, path):
+ # TODO: turn a list back into a dict
+ return obj
+
+
+KeyValues = KeyValueAdapter(StringList)
+
+_startup_payload = Switch(
+ this.proto,
+ {
+ protocol(3, 0): KeyValues,
+ },
+ default=GreedyBytes,
+)
+
+
+def _default_protocol(this):
+ try:
+ if isinstance(this.payload, (list, dict)):
+ return protocol(3, 0)
+ except AttributeError:
+ pass # no payload passed during build
+
+ return 0
+
+
+def _startup_payload_len(this):
+ """
+ The payload field has a fixed size based on the length of the packet. But
+ if the caller hasn't supplied an explicit length at build time, we have to
+ build the payload to figure out how long it is, which requires us to know
+ the length first... This function exists solely to break the cycle.
+ """
+ assert this._building, "_startup_payload_len() cannot be called during parsing"
+
+ try:
+ payload = this.payload
+ except AttributeError:
+ return 0 # no payload
+
+ if isinstance(payload, bytes):
+ # already serialized; just use the given length
+ return len(payload)
+
+ try:
+ proto = this.proto
+ except AttributeError:
+ proto = _default_protocol(this)
+
+ data = _startup_payload.build(payload, proto=proto)
+ return len(data)
+
+
+Startup = Struct(
+ "len" / Default(Int32sb, lambda this: _startup_payload_len(this) + 8),
+ "proto" / Default(Hex(Int32sb), _default_protocol),
+ "payload" / FixedSized(this.len - 8, Default(_startup_payload, b"")),
+)
+
+# Pq3
+
+# Adapted from construct.core.EnumIntegerString
+class EnumNamedByte:
+ def __init__(self, val, name):
+ self._val = val
+ self._name = name
+
+ def __int__(self):
+ return ord(self._val)
+
+ def __str__(self):
+ return "(enum) %s %r" % (self._name, self._val)
+
+ def __repr__(self):
+ return "EnumNamedByte(%r)" % self._val
+
+ def __eq__(self, other):
+ if isinstance(other, EnumNamedByte):
+ other = other._val
+ if not isinstance(other, bytes):
+ return NotImplemented
+
+ return self._val == other
+
+ def __hash__(self):
+ return hash(self._val)
+
+
+# Adapted from construct.core.Enum
+class ByteEnum(Adapter):
+ def __init__(self, **mapping):
+ super(ByteEnum, self).__init__(Byte)
+ self.namemapping = {k: EnumNamedByte(v, k) for k, v in mapping.items()}
+ self.decmapping = {v: EnumNamedByte(v, k) for k, v in mapping.items()}
+
+ def __getattr__(self, name):
+ if name in self.namemapping:
+ return self.decmapping[self.namemapping[name]]
+ raise AttributeError
+
+ def _decode(self, obj, context, path):
+ b = bytes([obj])
+ try:
+ return self.decmapping[b]
+ except KeyError:
+ return EnumNamedByte(b, "(unknown)")
+
+ def _encode(self, obj, context, path):
+ if isinstance(obj, int):
+ return obj
+ elif isinstance(obj, bytes):
+ return ord(obj)
+ return int(obj)
+
+
+types = ByteEnum(
+ ErrorResponse=b"E",
+ ReadyForQuery=b"Z",
+ Query=b"Q",
+ EmptyQueryResponse=b"I",
+ AuthnRequest=b"R",
+ PasswordMessage=b"p",
+ BackendKeyData=b"K",
+ CommandComplete=b"C",
+ ParameterStatus=b"S",
+ DataRow=b"D",
+ Terminate=b"X",
+)
+
+
+authn = Enum(
+ Int32ub,
+ OK=0,
+ SASL=10,
+ SASLContinue=11,
+ SASLFinal=12,
+)
+
+
+_authn_body = Switch(
+ this.type,
+ {
+ authn.OK: Terminated,
+ authn.SASL: StringList,
+ },
+ default=GreedyBytes,
+)
+
+
+def _data_len(this):
+ assert this._building, "_data_len() cannot be called during parsing"
+
+ if not hasattr(this, "data") or this.data is None:
+ return -1
+
+ return len(this.data)
+
+
+# The protocol reuses the PasswordMessage for several authentication response
+# types, and there's no good way to figure out which is which without keeping
+# state for the entire stream. So this is a separate Construct that can be
+# explicitly parsed/built by code that knows it's needed.
+SASLInitialResponse = Struct(
+ "name" / NullTerminated(GreedyBytes),
+ "len" / Default(Int32sb, lambda this: _data_len(this)),
+ "data"
+ / IfThenElse(
+ # Allow tests to explicitly pass an incorrect length during testing, by
+ # not enforcing a FixedSized during build. (The len calculation above
+ # defaults to the correct size.)
+ this._building,
+ Optional(GreedyBytes),
+ If(this.len != -1, Default(FixedSized(this.len, GreedyBytes), b"")),
+ ),
+ Terminated, # make sure the entire response is consumed
+)
+
+
+_column = FocusedSeq(
+ "data",
+ "len" / Default(Int32sb, lambda this: _data_len(this)),
+ "data" / If(this.len != -1, FixedSized(this.len, GreedyBytes)),
+)
+
+
+_payload_map = {
+ types.ErrorResponse: Struct("fields" / StringList),
+ types.ReadyForQuery: Struct("status" / Bytes(1)),
+ types.Query: Struct("query" / NullTerminated(GreedyBytes)),
+ types.EmptyQueryResponse: Terminated,
+ types.AuthnRequest: Struct("type" / authn, "body" / Default(_authn_body, b"")),
+ types.BackendKeyData: Struct("pid" / Int32ub, "key" / Hex(Int32ub)),
+ types.CommandComplete: Struct("tag" / NullTerminated(GreedyBytes)),
+ types.ParameterStatus: Struct(
+ "name" / NullTerminated(GreedyBytes), "value" / NullTerminated(GreedyBytes)
+ ),
+ types.DataRow: Struct("columns" / Default(PrefixedArray(Int16sb, _column), b"")),
+ types.Terminate: Terminated,
+}
+
+
+_payload = FocusedSeq(
+ "_payload",
+ "_payload"
+ / Switch(
+ this._.type,
+ _payload_map,
+ default=GreedyBytes,
+ ),
+ Terminated, # make sure every payload consumes the entire packet
+)
+
+
+def _payload_len(this):
+ """
+ See _startup_payload_len() for an explanation.
+ """
+ assert this._building, "_payload_len() cannot be called during parsing"
+
+ try:
+ payload = this.payload
+ except AttributeError:
+ return 0 # no payload
+
+ if isinstance(payload, bytes):
+ # already serialized; just use the given length
+ return len(payload)
+
+ data = _payload.build(payload, type=this.type)
+ return len(data)
+
+
+Pq3 = Struct(
+ "type" / types,
+ "len" / Default(Int32ub, lambda this: _payload_len(this) + 4),
+ "payload" / FixedSized(this.len - 4, Default(_payload, b"")),
+)
+
+
+# Environment
+
+
+def pghost():
+ return os.environ.get("PGHOST", default="localhost")
+
+
+def pgport():
+ return int(os.environ.get("PGPORT", default=5432))
+
+
+def pguser():
+ try:
+ return os.environ["PGUSER"]
+ except KeyError:
+ return getpass.getuser()
+
+
+def pgdatabase():
+ return os.environ.get("PGDATABASE", default="postgres")
+
+
+# Connections
+
+
+def _hexdump_translation_map():
+ """
+ For hexdumps. Translates any unprintable or non-ASCII bytes into '.'.
+ """
+ input = bytearray()
+
+ for i in range(128):
+ c = chr(i)
+
+ if not c.isprintable():
+ input += bytes([i])
+
+ input += bytes(range(128, 256))
+
+ return bytes.maketrans(input, b"." * len(input))
+
+
+class _DebugStream(object):
+ """
+ Wraps a file-like object and adds hexdumps of the read and write data. Call
+ end_packet() on a _DebugStream to write the accumulated hexdumps to the
+ output stream, along with the packet that was sent.
+ """
+
+ _translation_map = _hexdump_translation_map()
+
+ def __init__(self, stream, out=sys.stdout):
+ """
+ Creates a new _DebugStream wrapping the given stream (which must have
+ been created by wrap()). All attributes not provided by the _DebugStream
+ are delegated to the wrapped stream. out is the text stream to which
+ hexdumps are written.
+ """
+ self.raw = stream
+ self._out = out
+ self._rbuf = io.BytesIO()
+ self._wbuf = io.BytesIO()
+
+ def __getattr__(self, name):
+ return getattr(self.raw, name)
+
+ def __setattr__(self, name, value):
+ if name in ("raw", "_out", "_rbuf", "_wbuf"):
+ return object.__setattr__(self, name, value)
+
+ setattr(self.raw, name, value)
+
+ def read(self, *args, **kwargs):
+ buf = self.raw.read(*args, **kwargs)
+
+ self._rbuf.write(buf)
+ return buf
+
+ def write(self, b):
+ self._wbuf.write(b)
+ return self.raw.write(b)
+
+ def recv(self, *args):
+ buf = self.raw.recv(*args)
+
+ self._rbuf.write(buf)
+ return buf
+
+ def _flush(self, buf, prefix):
+ width = 16
+ hexwidth = width * 3 - 1
+
+ count = 0
+ buf.seek(0)
+
+ while True:
+ line = buf.read(width)
+
+ if not line:
+ if count:
+ self._out.write("\n") # separate the output block with a newline
+ return
+
+ self._out.write("%s %04X:\t" % (prefix, count))
+ self._out.write("%*s\t" % (-hexwidth, line.hex(" ")))
+ self._out.write(line.translate(self._translation_map).decode("ascii"))
+ self._out.write("\n")
+
+ count += width
+
+ def print_debug(self, obj, *, prefix=""):
+ contents = ""
+ if obj is not None:
+ contents = str(obj)
+
+ for line in contents.splitlines():
+ self._out.write("%s%s\n" % (prefix, line))
+
+ self._out.write("\n")
+
+ def flush_debug(self, *, prefix=""):
+ self._flush(self._rbuf, prefix + "<")
+ self._rbuf = io.BytesIO()
+
+ self._flush(self._wbuf, prefix + ">")
+ self._wbuf = io.BytesIO()
+
+ def end_packet(self, pkt, *, read=False, prefix="", indent=" "):
+ """
+ Marks the end of a logical "packet" of data. A string representation of
+ pkt will be printed, and the debug buffers will be flushed with an
+ indent. All lines can be optionally prefixed.
+
+ If read is True, the packet representation is written after the debug
+ buffers; otherwise the default of False (meaning write) causes the
+ packet representation to be dumped first. This is meant to capture the
+ logical flow of layer translation.
+ """
+ write = not read
+
+ if write:
+ self.print_debug(pkt, prefix=prefix + "> ")
+
+ self.flush_debug(prefix=prefix + indent)
+
+ if read:
+ self.print_debug(pkt, prefix=prefix + "< ")
+
+
+@contextlib.contextmanager
+def wrap(socket, *, debug_stream=None):
+ """
+ Transforms a raw socket into a connection that can be used for Construct
+ building and parsing. The return value is a context manager and can be used
+ in a with statement.
+ """
+ # It is critical that buffering be disabled here, so that we can still
+ # manipulate the raw socket without desyncing the stream.
+ with socket.makefile("rwb", buffering=0) as sfile:
+ # Expose the original socket's recv() on the SocketIO object we return.
+ def recv(self, *args):
+ return socket.recv(*args)
+
+ sfile.recv = recv.__get__(sfile)
+
+ conn = sfile
+ if debug_stream:
+ conn = _DebugStream(conn, debug_stream)
+
+ try:
+ yield conn
+ finally:
+ if debug_stream:
+ conn.flush_debug(prefix="? ")
+
+
+def _send(stream, cls, obj):
+ debugging = hasattr(stream, "flush_debug")
+ out = io.BytesIO()
+
+ # Ideally we would build directly to the passed stream, but because we need
+ # to reparse the generated output for the debugging case, build to an
+ # intermediate BytesIO and send it instead.
+ cls.build_stream(obj, out)
+ buf = out.getvalue()
+
+ stream.write(buf)
+ if debugging:
+ pkt = cls.parse(buf)
+ stream.end_packet(pkt)
+
+ stream.flush()
+
+
+def send(stream, packet_type, payload_data=None, **payloadkw):
+ """
+ Sends a packet on the given pq3 connection. packet_type is the pq3.types member
+ that should be assigned to the packet. If payload_data is given, it will be
+ used as the packet payload; otherwise the key/value pairs in payloadkw will
+ be the payload contents.
+ """
+ data = payloadkw
+
+ if payload_data is not None:
+ if payloadkw:
+ raise ValueError(
+ "payload_data and payload keywords may not be used simultaneously"
+ )
+
+ data = payload_data
+
+ _send(stream, Pq3, dict(type=packet_type, payload=data))
+
+
+def send_startup(stream, proto=None, **kwargs):
+ """
+ Sends a startup packet on the given pq3 connection. In most cases you should
+ use the handshake functions instead, which will do this for you.
+
+ By default, a protocol version 3 packet will be sent. This can be overridden
+ with the proto parameter.
+ """
+ pkt = {}
+
+ if proto is not None:
+ pkt["proto"] = proto
+ if kwargs:
+ pkt["payload"] = kwargs
+
+ _send(stream, Startup, pkt)
+
+
+def recv1(stream, *, cls=Pq3):
+ """
+ Receives a single pq3 packet from the given stream and returns it.
+ """
+ resp = cls.parse_stream(stream)
+
+ debugging = hasattr(stream, "flush_debug")
+ if debugging:
+ stream.end_packet(resp, read=True)
+
+ return resp
+
+
+def handshake(stream, **kwargs):
+ """
+ Performs a libpq v3 startup handshake. kwargs should contain the key/value
+ parameters to send to the server in the startup packet.
+ """
+ # Send our startup parameters.
+ send_startup(stream, **kwargs)
+
+ # Receive and dump packets until the server indicates it's ready for our
+ # first query.
+ while True:
+ resp = recv1(stream)
+ if resp is None:
+ raise RuntimeError("server closed connection during handshake")
+
+ if resp.type == types.ReadyForQuery:
+ return
+ elif resp.type == types.ErrorResponse:
+ raise RuntimeError(
+ f"received error response from peer: {resp.payload.fields!r}"
+ )
+
+
+# TLS
+
+
+class _TLSStream(object):
+ """
+ A file-like object that performs TLS encryption/decryption on a wrapped
+ stream. Differs from ssl.SSLSocket in that we have full visibility and
+ control over the TLS layer.
+ """
+
+ def __init__(self, stream, context):
+ self._stream = stream
+ self._debugging = hasattr(stream, "flush_debug")
+
+ self._in = ssl.MemoryBIO()
+ self._out = ssl.MemoryBIO()
+ self._ssl = context.wrap_bio(self._in, self._out)
+
+ def handshake(self):
+ try:
+ self._pump(lambda: self._ssl.do_handshake())
+ finally:
+ self._flush_debug(prefix="? ")
+
+ def read(self, *args):
+ return self._pump(lambda: self._ssl.read(*args))
+
+ def write(self, *args):
+ return self._pump(lambda: self._ssl.write(*args))
+
+ def _decode(self, buf):
+ """
+ Attempts to decode a buffer of TLS data into a packet representation
+ that can be printed.
+
+ TODO: handle buffers (and record fragments) that don't align with packet
+ boundaries.
+ """
+ end = len(buf)
+ bio = io.BytesIO(buf)
+
+ ret = io.StringIO()
+
+ while bio.tell() < end:
+ record = tls.Plaintext.parse_stream(bio)
+
+ if ret.tell() > 0:
+ ret.write("\n")
+ ret.write("[Record] ")
+ ret.write(str(record))
+ ret.write("\n")
+
+ if record.type == tls.ContentType.handshake:
+ record_cls = tls.Handshake
+ else:
+ continue
+
+ innerlen = len(record.fragment)
+ inner = io.BytesIO(record.fragment)
+
+ while inner.tell() < innerlen:
+ msg = record_cls.parse_stream(inner)
+
+ indented = "[Message] " + str(msg)
+ indented = textwrap.indent(indented, " ")
+
+ ret.write("\n")
+ ret.write(indented)
+ ret.write("\n")
+
+ return ret.getvalue()
+
+ def flush(self):
+ if not self._out.pending:
+ self._stream.flush()
+ return
+
+ buf = self._out.read()
+ self._stream.write(buf)
+
+ if self._debugging:
+ pkt = self._decode(buf)
+ self._stream.end_packet(pkt, prefix=" ")
+
+ self._stream.flush()
+
+ def _pump(self, operation):
+ while True:
+ try:
+ return operation()
+ except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
+ want = e
+ self._read_write(want)
+
+ def _recv(self, maxsize):
+ buf = self._stream.recv(maxsize)
+ if not buf:
+ self._in.write_eof()
+ return
+
+ self._in.write(buf)
+
+ if not self._debugging:
+ return
+
+ pkt = self._decode(buf)
+ self._stream.end_packet(pkt, read=True, prefix=" ")
+
+ def _read_write(self, want):
+ # XXX This needs work. So many corner cases yet to handle. For one,
+ # doing blocking writes in flush may lead to distributed deadlock if the
+ # peer is already blocking on its writes.
+
+ if isinstance(want, ssl.SSLWantWriteError):
+ assert self._out.pending, "SSL backend wants write without data"
+
+ self.flush()
+
+ if isinstance(want, ssl.SSLWantReadError):
+ self._recv(4096)
+
+ def _flush_debug(self, prefix):
+ if not self._debugging:
+ return
+
+ self._stream.flush_debug(prefix=prefix)
+
+
+@contextlib.contextmanager
+def tls_handshake(stream, context):
+ """
+ Performs a TLS handshake over the given stream (which must have been created
+ via a call to wrap()), and returns a new stream which transparently tunnels
+ data over the TLS connection.
+
+ If the passed stream has debugging enabled, the returned stream will also
+ have debugging, using the same output IO.
+ """
+ debugging = hasattr(stream, "flush_debug")
+
+ # Request TLS by sending an SSLRequest packet (pseudo-version 1234,5679).
+ send_startup(stream, proto=protocol(1234, 5679))
+
+ # Look at the SSL response.
+ resp = stream.read(1)
+ if debugging:
+ stream.flush_debug(prefix=" ")
+
+ if resp == b"N":
+ raise RuntimeError("server does not support SSLRequest")
+ if resp != b"S":
+ raise RuntimeError(f"unexpected response of type {resp!r} during TLS startup")
+
+ tls = _TLSStream(stream, context)
+ tls.handshake()
+
+ if debugging:
+ tls = _DebugStream(tls, stream._out)
+
+ try:
+ yield tls
+ # TODO: teardown/unwrap the connection?
+ finally:
+ if debugging:
+ tls.flush_debug(prefix="? ")
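For reference, the Startup construct in pq3.py serializes the same bytes as this hand-rolled sketch of a v3.0 startup packet: a self-inclusive big-endian int32 length, an int32 protocol version, then NUL-terminated key/value strings and a trailing NUL. (This is just a stdlib illustration of the wire format, not code from the patch.)

```python
import struct


def build_startup(params):
    # Payload: each key and value is a NUL-terminated string, with one
    # final NUL terminating the whole list.
    payload = b""
    for k, v in params.items():
        payload += k.encode() + b"\x00" + v.encode() + b"\x00"
    payload += b"\x00"

    # Header: self-inclusive length, then the protocol version computed
    # the same way as protocol(3, 0) above, i.e. (major << 16) | minor.
    return struct.pack("!ii", len(payload) + 8, (3 << 16) | 0) + payload
```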
diff --git a/src/test/python/pytest.ini b/src/test/python/pytest.ini
new file mode 100644
index 0000000000..ab7a6e7fb9
--- /dev/null
+++ b/src/test/python/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+
+markers =
+ slow: mark test as slow
diff --git a/src/test/python/requirements.txt b/src/test/python/requirements.txt
new file mode 100644
index 0000000000..32f105ea84
--- /dev/null
+++ b/src/test/python/requirements.txt
@@ -0,0 +1,7 @@
+black
+cryptography~=3.4.6
+construct~=2.10.61
+isort~=5.6
+psycopg2~=2.8.6
+pytest~=6.1
+pytest-asyncio~=0.14.0
diff --git a/src/test/python/server/__init__.py b/src/test/python/server/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/server/conftest.py b/src/test/python/server/conftest.py
new file mode 100644
index 0000000000..ba7342a453
--- /dev/null
+++ b/src/test/python/server/conftest.py
@@ -0,0 +1,45 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import socket
+import sys
+
+import pytest
+
+import pq3
+
+
+@pytest.fixture
+def connect():
+ """
+ A factory fixture that, when called, returns a socket connected to a
+ Postgres server, wrapped in a pq3 connection. The calling test will be
+ skipped automatically if a server is not running at PGHOST:PGPORT, so it's
+ best to connect as soon as possible after the test case begins, to avoid
+ doing unnecessary work.
+ """
+ # Set up an ExitStack to handle safe cleanup of all of the moving pieces.
+ with contextlib.ExitStack() as stack:
+
+ def conn_factory():
+ addr = (pq3.pghost(), pq3.pgport())
+
+ try:
+ sock = socket.create_connection(addr, timeout=2)
+ except ConnectionError as e:
+ pytest.skip(f"unable to connect to {addr}: {e}")
+
+ # Have ExitStack close our socket.
+ stack.enter_context(sock)
+
+ # Wrap the connection in a pq3 layer and have ExitStack clean it up
+ # too.
+ wrap_ctx = pq3.wrap(sock, debug_stream=sys.stdout)
+ conn = stack.enter_context(wrap_ctx)
+
+ return conn
+
+ yield conn_factory
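The fixture above leans on ExitStack so that every socket a single test opens through the factory gets closed when the test ends: resources registered via enter_context() are torn down in LIFO order when the stack exits. A minimal stdlib illustration of that cleanup ordering:

```python
from contextlib import ExitStack, closing

order = []


class Resource:
    # Stand-in for a socket or pq3 connection opened by the factory.
    def __init__(self, name):
        self.name = name

    def close(self):
        order.append(self.name)


with ExitStack() as stack:
    for name in ("first", "second"):
        # closing() turns anything with a close() method into a context
        # manager, mirroring how the fixture registers each socket.
        stack.enter_context(closing(Resource(name)))

# Cleanup ran in reverse registration order.
assert order == ["second", "first"]
```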
diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py
new file mode 100644
index 0000000000..cb5ca7fa23
--- /dev/null
+++ b/src/test/python/server/test_oauth.py
@@ -0,0 +1,1012 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import contextlib
+import json
+import os
+import pathlib
+import secrets
+import shlex
+import shutil
+import socket
+import struct
+from multiprocessing import shared_memory
+
+import psycopg2
+import pytest
+from psycopg2 import sql
+
+import pq3
+
+MAX_SASL_MESSAGE_LENGTH = 65535
+
+INVALID_AUTHORIZATION_ERRCODE = b"28000"
+PROTOCOL_VIOLATION_ERRCODE = b"08P01"
+FEATURE_NOT_SUPPORTED_ERRCODE = b"0A000"
+
+SHARED_MEM_NAME = "oauth-pytest"
+MAX_TOKEN_SIZE = 4096
+MAX_UINT16 = 2 ** 16 - 1
+
+
+def skip_if_no_postgres():
+ """
+ Used by the oauth_ctx fixture to skip this test module if no Postgres server
+ is running.
+
+ This logic nearly duplicates the connect fixture. Ideally oauth_ctx would
+ depend on that fixture instead, but a module-scoped fixture can't depend on a
+ test-scoped fixture, and we haven't reached the rule of three yet.
+ """
+ addr = (pq3.pghost(), pq3.pgport())
+
+ try:
+ with socket.create_connection(addr, timeout=2):
+ pass
+ except ConnectionError as e:
+ pytest.skip(f"unable to connect to {addr}: {e}")
+
+
+@contextlib.contextmanager
+def prepend_file(path, lines):
+ """
+ A context manager that prepends a file on disk with the desired lines of
+ text. When the context manager is exited, the file will be restored to its
+ original contents.
+ """
+ # First make a backup of the original file.
+ bak = path + ".bak"
+ shutil.copy2(path, bak)
+
+ try:
+ # Write the new lines, followed by the original file content.
+ with open(path, "w") as new, open(bak, "r") as orig:
+ new.writelines(lines)
+ shutil.copyfileobj(orig, new)
+
+ # Return control to the calling code.
+ yield
+
+ finally:
+ # Put the backup back into place.
+ os.replace(bak, path)
+
+
+@pytest.fixture(scope="module")
+def oauth_ctx():
+ """
+ Creates a database and user that use the oauth auth method. The context
+ object contains the dbname and user attributes as strings to be used during
+ connection, as well as the issuer and scope that have been set in the HBA
+ configuration.
+
+ This fixture assumes that the standard PG* environment variables point to a
+ server running on a local machine, and that the PGUSER has rights to create
+ databases and roles.
+ """
+ skip_if_no_postgres() # don't bother running these tests without a server
+
+ id = secrets.token_hex(4)
+
+ class Context:
+ dbname = "oauth_test_" + id
+
+ user = "oauth_user_" + id
+ map_user = "oauth_map_user_" + id
+ authz_user = "oauth_authz_user_" + id
+
+ issuer = "https://example.com/" + id
+ scope = "openid " + id
+
+ ctx = Context()
+ hba_lines = (
+ f'host {ctx.dbname} {ctx.map_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
+ f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
+ f'host {ctx.dbname} all samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
+ )
+ ident_lines = (r"oauth /^(.*)@example\.com$ \1",)
+
+ conn = psycopg2.connect("")
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+
+ # Create our roles and database.
+ user = sql.Identifier(ctx.user)
+ map_user = sql.Identifier(ctx.map_user)
+ authz_user = sql.Identifier(ctx.authz_user)
+ dbname = sql.Identifier(ctx.dbname)
+
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(user))
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(map_user))
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user))
+ c.execute(sql.SQL("CREATE DATABASE {};").format(dbname))
+
+ # Make this test script the server's oauth_validator.
+ path = pathlib.Path(__file__).parent / "validate_bearer.py"
+ path = str(path.absolute())
+
+ cmd = f"{shlex.quote(path)} {SHARED_MEM_NAME} <&%f"
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+
+ # Replace pg_hba and pg_ident.
+ c.execute("SHOW hba_file;")
+ hba = c.fetchone()[0]
+
+ c.execute("SHOW ident_file;")
+ ident = c.fetchone()[0]
+
+ with prepend_file(hba, hba_lines), prepend_file(ident, ident_lines):
+ c.execute("SELECT pg_reload_conf();")
+
+ # Use the new database and user.
+ yield ctx
+
+ # Put things back the way they were.
+ c.execute("SELECT pg_reload_conf();")
+
+ c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+ c.execute(sql.SQL("DROP DATABASE {};").format(dbname))
+ c.execute(sql.SQL("DROP ROLE {};").format(authz_user))
+ c.execute(sql.SQL("DROP ROLE {};").format(map_user))
+ c.execute(sql.SQL("DROP ROLE {};").format(user))
+
+
+@pytest.fixture()
+def conn(oauth_ctx, connect):
+ """
+ A convenience wrapper for connect(). The main purpose of this fixture is to
+ make sure oauth_ctx runs its setup code before the connection is made.
+ """
+ return connect()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def authn_id_extension(oauth_ctx):
+ """
+ Performs a `CREATE EXTENSION authn_id` in the test database. This fixture is
+ autoused, so tests don't need to request it explicitly.
+ """
+ conn = psycopg2.connect(database=oauth_ctx.dbname)
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+ c.execute("CREATE EXTENSION authn_id;")
+
+
+@pytest.fixture(scope="session")
+def shared_mem():
+ """
+ Yields a shared memory segment that can be used for communication between
+ the bearer_token fixture and ./validate_bearer.py.
+ """
+ size = MAX_TOKEN_SIZE + 2 # two byte length prefix
+ mem = shared_memory.SharedMemory(SHARED_MEM_NAME, create=True, size=size)
+
+ try:
+ with contextlib.closing(mem):
+ yield mem
+ finally:
+ mem.unlink()
+
+
+@pytest.fixture()
+def bearer_token(shared_mem):
+ """
+ Returns a factory function that, when called, will store a Bearer token in
+ shared_mem. If token is None (the default), a new token will be generated
+ using secrets.token_urlsafe() and returned; otherwise the passed token will
+ be used as-is.
+
+ When token is None, the generated token size in bytes may be specified as an
+ argument; if unset, a small 16-byte token will be generated. The token size
+ may not exceed MAX_TOKEN_SIZE in any case.
+
+ The factory's return value is the token, as a bytes object.
+
+ As a special case for testing failure modes, accept_any may be set to True.
+ This signals to the validator command that any bearer token should be
+ accepted. The returned token in this case may be used or discarded as needed
+ by the test.
+ """
+
+ def set_token(token=None, *, size=16, accept_any=False):
+ if token is not None:
+ size = len(token)
+
+ if size > MAX_TOKEN_SIZE:
+ raise ValueError(f"token size {size} exceeds maximum size {MAX_TOKEN_SIZE}")
+
+ if token is None:
+ if size % 4:
+ raise ValueError(f"requested token size {size} is not a multiple of 4")
+
+ token = secrets.token_urlsafe(size // 4 * 3)
+ assert len(token) == size
+
+ try:
+ token = token.encode("ascii")
+ except AttributeError:
+ pass # already encoded
+
+ if accept_any:
+ # Two-byte magic value.
+ shared_mem.buf[:2] = struct.pack("H", MAX_UINT16)
+ else:
+ # Two-byte length prefix, then the token data.
+ shared_mem.buf[:2] = struct.pack("H", len(token))
+ shared_mem.buf[2 : size + 2] = token
+
+ return token
+
+ return set_token
+
+
+def begin_oauth_handshake(conn, oauth_ctx, *, user=None):
+ if user is None:
+ user = oauth_ctx.authz_user
+
+ pq3.send_startup(conn, user=user, database=oauth_ctx.dbname)
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.AuthnRequest
+
+ # The server should advertise exactly one mechanism.
+ assert resp.payload.type == pq3.authn.SASL
+ assert resp.payload.body == [b"OAUTHBEARER", b""]
+
+
+def send_initial_response(conn, *, auth=None, bearer=None):
+ """
+ Sends the OAUTHBEARER initial response on the connection, using the given
+ bearer token. Alternatively to a bearer token, the initial response's auth
+ field may be explicitly specified to test corner cases.
+ """
+ if bearer is not None and auth is not None:
+ raise ValueError("exactly one of the auth and bearer kwargs must be set")
+
+ if bearer is not None:
+ auth = b"Bearer " + bearer
+
+ if auth is None:
+ raise ValueError("exactly one of the auth and bearer kwargs must be set")
+
+ initial = pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"n,,\x01auth=" + auth + b"\x01\x01",
+ )
+ )
+ pq3.send(conn, pq3.types.PasswordMessage, initial)
+
+
+def expect_handshake_success(conn):
+ """
+ Validates that the server responds with an AuthnOK message, and then drains
+ the connection until a ReadyForQuery message is received.
+ """
+ resp = pq3.recv1(conn)
+
+ assert resp.type == pq3.types.AuthnRequest
+ assert resp.payload.type == pq3.authn.OK
+ assert not resp.payload.body
+
+ receive_until(conn, pq3.types.ReadyForQuery)
+
+
+def expect_handshake_failure(conn, oauth_ctx):
+ """
+ Performs the OAUTHBEARER SASL failure "handshake" and validates the server's
+ side of the conversation, including the final ErrorResponse.
+ """
+
+ # We expect a discovery "challenge" back from the server before the authn
+ # failure message.
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.AuthnRequest
+
+ req = resp.payload
+ assert req.type == pq3.authn.SASLContinue
+
+ body = json.loads(req.body)
+ assert body["status"] == "invalid_token"
+ assert body["scope"] == oauth_ctx.scope
+
+ expected_config = oauth_ctx.issuer + "/.well-known/openid-configuration"
+ assert body["openid-configuration"] == expected_config
+
+ # Send the dummy response to complete the failed handshake.
+ pq3.send(conn, pq3.types.PasswordMessage, b"\x01")
+ resp = pq3.recv1(conn)
+
+ err = ExpectedError(INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed")
+ err.match(resp)
+
+
+def receive_until(conn, type):
+ """
+ receive_until pulls packets off the pq3 connection until a packet with the
+ desired type is found, or an error response is received.
+ """
+ while True:
+ pkt = pq3.recv1(conn)
+
+ if pkt.type == type:
+ return pkt
+ elif pkt.type == pq3.types.ErrorResponse:
+ raise RuntimeError(
+ f"received error response from peer: {pkt.payload.fields!r}"
+ )
+
+
+@pytest.mark.parametrize("token_len", [16, 1024, 4096])
+@pytest.mark.parametrize(
+ "auth_prefix",
+ [
+ b"Bearer ",
+ b"bearer ",
+ b"Bearer ",
+ ],
+)
+def test_oauth(conn, oauth_ctx, bearer_token, auth_prefix, token_len):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Generate our bearer token with the desired length.
+ token = bearer_token(size=token_len)
+ auth = auth_prefix + token
+
+ send_initial_response(conn, auth=auth)
+ expect_handshake_success(conn)
+
+ # Make sure that the server has not set an authenticated ID.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ assert row.columns == [None]
+
+
+@pytest.mark.parametrize(
+ "token_value",
+ [
+ "abcdzA==",
+ "123456M=",
+ "x-._~+/x",
+ ],
+)
+def test_oauth_bearer_corner_cases(conn, oauth_ctx, bearer_token, token_value):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ send_initial_response(conn, bearer=bearer_token(token_value))
+
+ expect_handshake_success(conn)
+
+
+@pytest.mark.parametrize(
+ "user,authn_id,should_succeed",
+ [
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: ctx.user,
+ True,
+ id="validator authn: succeeds when authn_id == username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: None,
+ False,
+ id="validator authn: fails when authn_id is not set",
+ ),
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: ctx.authz_user,
+ False,
+ id="validator authn: fails when authn_id != username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: ctx.map_user + "@example.com",
+ True,
+ id="validator with map: succeeds when authn_id matches map",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: None,
+ False,
+ id="validator with map: fails when authn_id is not set",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: ctx.map_user + "@example.net",
+ False,
+ id="validator with map: fails when authn_id doesn't match map",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: None,
+ True,
+ id="validator authz: succeeds with no authn_id",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "",
+ True,
+ id="validator authz: succeeds with empty authn_id",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "postgres",
+ True,
+ id="validator authz: succeeds with basic username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "me@example.com",
+ True,
+ id="validator authz: succeeds with email address",
+ ),
+ ],
+)
+def test_oauth_authn_id(conn, oauth_ctx, bearer_token, user, authn_id, should_succeed):
+ token = None
+
+ authn_id = authn_id(oauth_ctx)
+ if authn_id is not None:
+ authn_id = authn_id.encode("ascii")
+
+ # As a hack to get the validator to reflect arbitrary output from this
+ # test, encode the desired output as a base64 token. The validator will
+ # key on the leading "output=" to differentiate this from the random
+ # tokens generated by secrets.token_urlsafe().
+ output = b"output=" + authn_id + b"\n"
+ token = base64.urlsafe_b64encode(output)
+
+ token = bearer_token(token)
+ username = user(oauth_ctx)
+
+ begin_oauth_handshake(conn, oauth_ctx, user=username)
+ send_initial_response(conn, bearer=token)
+
+ if not should_succeed:
+ expect_handshake_failure(conn, oauth_ctx)
+ return
+
+ expect_handshake_success(conn)
+
+ # Check the reported authn_id.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ assert row.columns == [authn_id]
+
+
+class ExpectedError(object):
+ def __init__(self, code, msg=None, detail=None):
+ self.code = code
+ self.msg = msg
+ self.detail = detail
+
+ # Protect against the footgun of an accidental empty string, which will
+ # "match" anything. If you don't want to match message or detail, just
+ # don't pass them.
+ if self.msg == "":
+ raise ValueError("msg must be non-empty or None")
+ if self.detail == "":
+ raise ValueError("detail must be non-empty or None")
+
+ def _getfield(self, resp, type):
+ """
+ Searches an ErrorResponse for a single field of the given type (e.g.
+ "M", "C", "D") and returns its value. Asserts if it doesn't find exactly
+ one field.
+ """
+ prefix = type.encode("ascii")
+ fields = [f for f in resp.payload.fields if f.startswith(prefix)]
+
+ assert len(fields) == 1
+ return fields[0][1:] # strip off the type byte
+
+ def match(self, resp):
+ """
+ Checks that the given response matches the expected code, message, and
+ detail (if given). The error code must match exactly. The expected
+ message and detail must be contained within the actual strings.
+ """
+ assert resp.type == pq3.types.ErrorResponse
+
+ code = self._getfield(resp, "C")
+ assert code == self.code
+
+ if self.msg:
+ msg = self._getfield(resp, "M")
+ expected = self.msg.encode("utf-8")
+ assert expected in msg
+
+ if self.detail:
+ detail = self._getfield(resp, "D")
+ expected = self.detail.encode("utf-8")
+ assert expected in detail
+
+
+def test_oauth_rejected_bearer(conn, oauth_ctx, bearer_token):
+ # Generate a new bearer token, which we will proceed not to use.
+ _ = bearer_token()
+
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send a bearer token that doesn't match what the validator expects. It
+ # should fail the connection.
+ send_initial_response(conn, bearer=b"xxxxxx")
+
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.parametrize(
+ "bad_bearer",
+ [
+ b"Bearer ",
+ b"Bearer a===b",
+ b"Bearer hello!",
+ b"Bearer me@example.com",
+ b'OAuth realm="Example"',
+ b"",
+ ],
+)
+def test_oauth_invalid_bearer(conn, oauth_ctx, bearer_token, bad_bearer):
+ # Tell the validator to accept any token. This ensures that the invalid
+ # bearer tokens are rejected before the validation step.
+ _ = bearer_token(accept_any=True)
+
+ begin_oauth_handshake(conn, oauth_ctx)
+ send_initial_response(conn, auth=bad_bearer)
+
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "resp_type,resp,err",
+ [
+ pytest.param(
+ None,
+ None,
+ None,
+ marks=pytest.mark.slow,
+ id="no response (expect timeout)",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"hello",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="bad dummy response",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"\x01\x01",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="multiple kvseps",
+ ),
+ pytest.param(
+ pq3.types.Query,
+ dict(query=b""),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="bad response message type",
+ ),
+ ],
+)
+def test_oauth_bad_response_to_error_challenge(conn, oauth_ctx, resp_type, resp, err):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send an empty auth initial response, which will force an authn failure.
+ send_initial_response(conn, auth=b"")
+
+ # We expect a discovery "challenge" back from the server before the authn
+ # failure message.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.AuthnRequest
+
+ req = pkt.payload
+ assert req.type == pq3.authn.SASLContinue
+
+ body = json.loads(req.body)
+ assert body["status"] == "invalid_token"
+
+ if resp_type is None:
+ # Do not send the dummy response. We should time out and not get a
+ # response from the server.
+ with pytest.raises(socket.timeout):
+ conn.read(1)
+
+ # Done with the test.
+ return
+
+ # Send the bad response.
+ pq3.send(conn, resp_type, resp)
+
+ # Make sure the server fails the connection correctly.
+ pkt = pq3.recv1(conn)
+ err.match(pkt)
+
+
+@pytest.mark.parametrize(
+ "type,payload,err",
+ [
+ pytest.param(
+ pq3.types.ErrorResponse,
+ dict(fields=[b""]),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="error response in initial message",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"x" * (MAX_SASL_MESSAGE_LENGTH + 1),
+ ExpectedError(
+ INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed"
+ ),
+ id="overlong initial response data",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"SCRAM-SHA-256")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE, "invalid SASL authentication mechanism"
+ ),
+ id="bad SASL mechanism selection",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=2, data=b"x")),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "insufficient data"),
+ id="SASL data underflow",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=0, data=b"x")),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "invalid message format"),
+ id="SASL data overflow",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "message is empty",
+ ),
+ id="empty",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"n,,\x01auth=\x01\x01\0")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "length does not match input length",
+ ),
+ id="contains null byte",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected channel-binding flag", # XXX this is a bit strange
+ ),
+ id="initial error response",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"p=tls-server-end-point,,\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "server does not support channel binding",
+ ),
+ id="uses channel binding",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"x,,\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected channel-binding flag",
+ ),
+ id="invalid channel binding specifier",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Comma expected",
+ ),
+ id="bad GS2 header: missing channel binding terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,a")),
+ ExpectedError(
+ FEATURE_NOT_SUPPORTED_ERRCODE,
+ "client uses authorization identity",
+ ),
+ id="bad GS2 header: authzid in use",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,b,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected attribute",
+ ),
+ id="bad GS2 header: extra attribute",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected attribute 0x00", # XXX this is a bit strange
+ ),
+ id="bad GS2 header: missing authzid terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Key-value separator expected",
+ ),
+ id="missing initial kvsep",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: empty key-value list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: other keys present",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "unterminated key/value pair",
+ ),
+ id="missing value terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: empty list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: with auth value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01\x01blah")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "additional data after the final terminator",
+ ),
+ id="additional key after terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01key\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "key without a value",
+ ),
+ id="key without value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"y,,\x01auth=Bearer 0\x01auth=Bearer 1\x01\x01",
+ )
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "contains multiple auth values",
+ ),
+ id="multiple auth values",
+ ),
+ ],
+)
+def test_oauth_bad_initial_response(conn, oauth_ctx, type, payload, err):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # The server expects a SASL response; give it something else instead.
+ if not isinstance(payload, dict):
+ payload = dict(payload_data=payload)
+ pq3.send(conn, type, **payload)
+
+ resp = pq3.recv1(conn)
+ err.match(resp)
+
+
+def test_oauth_empty_initial_response(conn, oauth_ctx, bearer_token):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send an initial response without data.
+ initial = pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER"))
+ pq3.send(conn, pq3.types.PasswordMessage, initial)
+
+ # The server should respond with an empty challenge so we can send the data
+ # it wants.
+ pkt = pq3.recv1(conn)
+
+ assert pkt.type == pq3.types.AuthnRequest
+ assert pkt.payload.type == pq3.authn.SASLContinue
+ assert not pkt.payload.body
+
+ # Now send the initial data.
+ data = b"n,,\x01auth=Bearer " + bearer_token() + b"\x01\x01"
+ pq3.send(conn, pq3.types.PasswordMessage, data)
+
+ # Server should now complete the handshake.
+ expect_handshake_success(conn)
+
+
+@pytest.fixture()
+def set_validator():
+ """
+ A per-test fixture that allows a test to override the setting of
+ oauth_validator_command for the cluster. The setting will be reverted during
+ teardown.
+
+ Passing an empty string effectively disables the validator.
+ """
+ conn = psycopg2.connect("")
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+
+ # Save the previous value.
+ c.execute("SHOW oauth_validator_command;")
+ prev_cmd = c.fetchone()[0]
+
+ def setter(cmd):
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+ c.execute("SELECT pg_reload_conf();")
+
+ yield setter
+
+ # Restore the previous value.
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (prev_cmd,))
+ c.execute("SELECT pg_reload_conf();")
+
+
+def test_oauth_no_validator(oauth_ctx, set_validator, connect, bearer_token):
+ # Clear out our validator command, then establish a new connection.
+ set_validator("")
+ conn = connect()
+
+ begin_oauth_handshake(conn, oauth_ctx)
+ send_initial_response(conn, bearer=bearer_token())
+
+ # The server should fail the connection.
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+def test_oauth_validator_role(oauth_ctx, set_validator, connect):
+ # Switch the validator implementation. This validator will reflect the
+ # PGUSER as the authenticated identity.
+ path = pathlib.Path(__file__).parent / "validate_reflect.py"
+ path = str(path.absolute())
+
+ set_validator(f"{shlex.quote(path)} '%r' <&%f")
+ conn = connect()
+
+ # Log in. Note that the reflection validator ignores the bearer token.
+ begin_oauth_handshake(conn, oauth_ctx, user=oauth_ctx.user)
+ send_initial_response(conn, bearer=b"dontcare")
+ expect_handshake_success(conn)
+
+ # Check the user identity.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ expected = oauth_ctx.user.encode("utf-8")
+ assert row.columns == [expected]
+
+
+def test_oauth_role_with_shell_unsafe_characters(oauth_ctx, set_validator, connect):
+ """
+ XXX This test pins undesirable behavior. We should be able to handle any
+ valid Postgres role name.
+ """
+ # Switch the validator implementation. This validator will reflect the
+ # PGUSER as the authenticated identity.
+ path = pathlib.Path(__file__).parent / "validate_reflect.py"
+ path = str(path.absolute())
+
+ set_validator(f"{shlex.quote(path)} '%r' <&%f")
+ conn = connect()
+
+ unsafe_username = "hello'there"
+ begin_oauth_handshake(conn, oauth_ctx, user=unsafe_username)
+
+ # The server should reject the handshake.
+ send_initial_response(conn, bearer=b"dontcare")
+ expect_handshake_failure(conn, oauth_ctx)
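The tests above build OAUTHBEARER initial responses by hand as `gs2-header \x01 auth=... \x01\x01`. The framing from RFC 7628 can be sketched on its own; `build_initial_response` and `parse_auth_value` are hypothetical helpers for illustration, not part of pq3:

```python
KVSEP = b"\x01"


def build_initial_response(token, gs2_header=b"n,,"):
    """Assemble an OAUTHBEARER initial client response (RFC 7628, sec. 3.1):
    a GS2 header, then kvsep-delimited key=value pairs, then a double kvsep."""
    return gs2_header + KVSEP + b"auth=Bearer " + token + KVSEP + KVSEP


def parse_auth_value(msg):
    """Extract the auth value from an initial response, or None if absent."""
    try:
        _, rest = msg.split(b",,", 1)  # strip the GS2 header
    except ValueError:
        return None

    auth = None
    for pair in rest.split(KVSEP):
        if pair.startswith(b"auth="):
            if auth is not None:
                raise ValueError("contains multiple auth values")
            auth = pair[len(b"auth="):]
    return auth


msg = build_initial_response(b"abcd1234")
assert msg == b"n,,\x01auth=Bearer abcd1234\x01\x01"
assert parse_auth_value(msg) == b"Bearer abcd1234"
```

Most of the malformed-message cases in test_oauth_bad_initial_response are exactly the inputs a parser like this must reject: missing kvseps, missing or duplicated auth values, and data after the final terminator.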
diff --git a/src/test/python/server/test_server.py b/src/test/python/server/test_server.py
new file mode 100644
index 0000000000..02126dba79
--- /dev/null
+++ b/src/test/python/server/test_server.py
@@ -0,0 +1,21 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import pq3
+
+
+def test_handshake(connect):
+ """Basic sanity check."""
+ conn = connect()
+
+ pq3.handshake(conn, user=pq3.pguser(), database=pq3.pgdatabase())
+
+ pq3.send(conn, pq3.types.Query, query=b"")
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.EmptyQueryResponse
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.ReadyForQuery
diff --git a/src/test/python/server/validate_bearer.py b/src/test/python/server/validate_bearer.py
new file mode 100755
index 0000000000..2cc73ff154
--- /dev/null
+++ b/src/test/python/server/validate_bearer.py
@@ -0,0 +1,101 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It doesn't actually validate
+# anything, and it logs the bearer token data, which is sensitive.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. Memory is shared and communicated from that test module's
+# bearer_token() fixture.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
+
+import base64
+import binascii
+import contextlib
+import struct
+import sys
+from multiprocessing import shared_memory
+
+MAX_UINT16 = 2 ** 16 - 1
+
+
+def remove_shm_from_resource_tracker():
+ """
+ Monkey-patch multiprocessing.resource_tracker so SharedMemory won't be
+ tracked. Pulled from this thread, where there are more details:
+
+ https://bugs.python.org/issue38119
+
+ TL;DR: all clients of shared memory segments automatically destroy them on
+ process exit, which makes shared memory segments much less useful. This
+ monkeypatch removes that behavior so that we can defer to the test to manage
+ the segment lifetime.
+
+ Ideally a future Python patch will pull in this fix and then the entire
+ function can go away.
+ """
+ from multiprocessing import resource_tracker
+
+ def fix_register(name, rtype):
+ if rtype == "shared_memory":
+ return
+ return resource_tracker._resource_tracker.register(name, rtype)
+
+ resource_tracker.register = fix_register
+
+ def fix_unregister(name, rtype):
+ if rtype == "shared_memory":
+ return
+ return resource_tracker._resource_tracker.unregister(name, rtype)
+
+ resource_tracker.unregister = fix_unregister
+
+ if "shared_memory" in resource_tracker._CLEANUP_FUNCS:
+ del resource_tracker._CLEANUP_FUNCS["shared_memory"]
+
+
+def main(args):
+ remove_shm_from_resource_tracker() # XXX remove some day
+
+ # Get the expected token from the currently running test.
+ shared_mem_name = args[0]
+
+ mem = shared_memory.SharedMemory(shared_mem_name)
+ with contextlib.closing(mem):
+ # First two bytes are the token length.
+ size = struct.unpack("H", mem.buf[:2])[0]
+
+ if size == MAX_UINT16:
+ # Special case: the test wants us to accept any token.
+ sys.stderr.write("accepting token without validation\n")
+ return
+
+ # The remainder of the buffer contains the expected token.
+ assert size <= (mem.size - 2)
+ expected_token = mem.buf[2 : size + 2].tobytes()
+
+ mem.buf[:] = b"\0" * mem.size # scribble over the token
+
+ token = sys.stdin.buffer.read()
+ if token != expected_token:
+ sys.exit(f"failed to match Bearer token ({token!r} != {expected_token!r})")
+
+ # See if the test wants us to print anything. If so, it will have encoded
+ # the desired output in the token with an "output=" prefix.
+ try:
+ # altchars="-_" corresponds to the urlsafe alphabet.
+ data = base64.b64decode(token, altchars="-_", validate=True)
+
+ if data.startswith(b"output="):
+ sys.stdout.buffer.write(data[7:])
+
+ except binascii.Error:
+ pass
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
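For reference, the test side of this segment layout isn't shown in this file; a minimal sketch of how a test might populate the shared memory that the validator reads (names and the segment size here are illustrative, not from the patch) looks like:

```python
import struct
from multiprocessing import shared_memory

MAX_UINT16 = 2**16 - 1


def write_expected_token(mem, token):
    # Layout matches the validator's main(): a native-endian uint16 length
    # in the first two bytes, followed by the token bytes. MAX_UINT16 is
    # reserved to mean "accept any token".
    assert len(token) < MAX_UINT16
    mem.buf[:2] = struct.pack("H", len(token))
    mem.buf[2 : 2 + len(token)] = token


mem = shared_memory.SharedMemory(create=True, size=512)
try:
    write_expected_token(mem, b"my-bearer-token")

    # Read it back the same way the validator does.
    size = struct.unpack("H", mem.buf[:2])[0]
    assert mem.buf[2 : size + 2].tobytes() == b"my-bearer-token"
finally:
    mem.close()
    mem.unlink()
```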
diff --git a/src/test/python/server/validate_reflect.py b/src/test/python/server/validate_reflect.py
new file mode 100755
index 0000000000..24c3a7e715
--- /dev/null
+++ b/src/test/python/server/validate_reflect.py
@@ -0,0 +1,34 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It ignores the bearer token
+# entirely and automatically logs the user in.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. It expects the user's desired role name as an argument; the
+# actual token will be discarded and the user will be logged in with the role
+# name as the authenticated identity.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
+
+import sys
+
+
+def main(args):
+ # We have to read the entire token as our first action to unblock the
+ # server, but we won't actually use it.
+ _ = sys.stdin.buffer.read()
+
+ if len(args) != 1:
+ sys.exit("usage: ./validate_reflect.py ROLE")
+
+ # Log the user in as the provided role.
+ role = args[0]
+ print(role)
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
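For context, a validator like this gets wired up server-side through the oauth_validator_command setting mentioned in the header comment; a hypothetical configuration line might look like the following (the path and role name are made up, and the exact syntax is whatever the patch defines):

```
# postgresql.conf (illustrative only)
oauth_validator_command = '/path/to/validate_reflect.py testrole'
```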
diff --git a/src/test/python/test_internals.py b/src/test/python/test_internals.py
new file mode 100644
index 0000000000..dee4855fc0
--- /dev/null
+++ b/src/test/python/test_internals.py
@@ -0,0 +1,138 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import io
+
+from pq3 import _DebugStream
+
+
+def test_DebugStream_read():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+
+ stream = _DebugStream(under, out)
+
+ res = stream.read(5)
+ assert res == b"abcde"
+
+ res = stream.read(16)
+ assert res == b"fghijklmnopqrstu"
+
+ stream.flush_debug()
+
+ res = stream.read()
+ assert res == b"vwxyz"
+
+ stream.flush_debug()
+
+ expected = (
+ "< 0000:\t61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70\tabcdefghijklmnop\n"
+ "< 0010:\t71 72 73 74 75 \tqrstu\n"
+ "\n"
+ "< 0000:\t76 77 78 79 7a \tvwxyz\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_write():
+ under = io.BytesIO()
+ out = io.StringIO()
+
+ stream = _DebugStream(under, out)
+
+ stream.write(b"\x00\x01\x02")
+ stream.flush()
+
+ assert under.getvalue() == b"\x00\x01\x02"
+
+ stream.write(b"\xc0\xc1\xc2")
+ stream.flush()
+
+ assert under.getvalue() == b"\x00\x01\x02\xc0\xc1\xc2"
+
+ stream.flush_debug()
+
+ expected = "> 0000:\t00 01 02 c0 c1 c2 \t......\n\n"
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_read_write():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+ stream = _DebugStream(under, out)
+
+ res = stream.read(5)
+ assert res == b"abcde"
+
+ stream.write(b"xxxxx")
+ stream.flush()
+
+ assert under.getvalue() == b"abcdexxxxxklmnopqrstuvwxyz"
+
+ res = stream.read(5)
+ assert res == b"klmno"
+
+ stream.write(b"xxxxx")
+ stream.flush()
+
+ assert under.getvalue() == b"abcdexxxxxklmnoxxxxxuvwxyz"
+
+ stream.flush_debug()
+
+ expected = (
+ "< 0000:\t61 62 63 64 65 6b 6c 6d 6e 6f \tabcdeklmno\n"
+ "\n"
+ "> 0000:\t78 78 78 78 78 78 78 78 78 78 \txxxxxxxxxx\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_end_packet():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+ stream = _DebugStream(under, out)
+
+ stream.read(5)
+ stream.end_packet("read description", read=True, indent=" ")
+
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("write description", indent=" ")
+
+ stream.read(5)
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("read/write combo for read", read=True, indent=" ")
+
+ stream.read(5)
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("read/write combo for write", indent=" ")
+
+ expected = (
+ " < 0000:\t61 62 63 64 65 \tabcde\n"
+ "\n"
+ "< read description\n"
+ "\n"
+ "> write description\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ " < 0000:\t6b 6c 6d 6e 6f \tklmno\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ "< read/write combo for read\n"
+ "\n"
+ "> read/write combo for write\n"
+ "\n"
+ " < 0000:\t75 76 77 78 79 \tuvwxy\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
diff --git a/src/test/python/test_pq3.py b/src/test/python/test_pq3.py
new file mode 100644
index 0000000000..e0c0e0568d
--- /dev/null
+++ b/src/test/python/test_pq3.py
@@ -0,0 +1,558 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import struct
+import sys
+
+import pytest
+from construct import Container, PaddingError, StreamError, TerminatedError
+
+import pq3
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"\x00\x00\x00\x10\x00\x04\x00\x00abcdefgh",
+ Container(len=16, proto=0x40000, payload=b"abcdefgh"),
+ b"",
+ id="8-byte payload",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x08\x00\x04\x00\x00",
+ Container(len=8, proto=0x40000, payload=b""),
+ b"",
+ id="no payload",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x09\x00\x04\x00\x00abcde",
+ Container(len=9, proto=0x40000, payload=b"a"),
+ b"bcde",
+ id="1-byte payload and extra padding",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x0B\x00\x03\x00\x00hi\x00",
+ Container(len=11, proto=pq3.protocol(3, 0), payload=[b"hi"]),
+ b"",
+ id="implied parameter list when using proto version 3.0",
+ ),
+ ],
+)
+def test_Startup_parse(raw, expected, extra):
+ with io.BytesIO(raw) as stream:
+ actual = pq3.Startup.parse_stream(stream)
+
+ assert actual == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "packet,expected_bytes",
+ [
+ pytest.param(
+ dict(),
+ b"\x00\x00\x00\x08\x00\x00\x00\x00",
+ id="nothing set",
+ ),
+ pytest.param(
+ dict(len=10, proto=0x12345678),
+ b"\x00\x00\x00\x0A\x12\x34\x56\x78\x00\x00",
+ id="len and proto set explicitly",
+ ),
+ pytest.param(
+ dict(proto=0x12345678),
+ b"\x00\x00\x00\x08\x12\x34\x56\x78",
+ id="implied len with no payload",
+ ),
+ pytest.param(
+ dict(proto=0x12345678, payload=b"abcd"),
+ b"\x00\x00\x00\x0C\x12\x34\x56\x78abcd",
+ id="implied len with payload",
+ ),
+ pytest.param(
+ dict(payload=[b""]),
+ b"\x00\x00\x00\x09\x00\x03\x00\x00\x00",
+ id="implied proto version 3 when sending parameters",
+ ),
+ pytest.param(
+ dict(payload=[b"hi", b""]),
+ b"\x00\x00\x00\x0C\x00\x03\x00\x00hi\x00\x00",
+ id="implied proto version 3 and len when sending more than one parameter",
+ ),
+ pytest.param(
+ dict(payload=dict(user="jsmith", database="postgres")),
+ b"\x00\x00\x00\x27\x00\x03\x00\x00user\x00jsmith\x00database\x00postgres\x00\x00",
+ id="auto-serialization of dict parameters",
+ ),
+ ],
+)
+def test_Startup_build(packet, expected_bytes):
+ actual = pq3.Startup.build(packet)
+ assert actual == expected_bytes
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"*\x00\x00\x00\x08abcd",
+ dict(type=b"*", len=8, payload=b"abcd"),
+ b"",
+ id="4-byte payload",
+ ),
+ pytest.param(
+ b"*\x00\x00\x00\x04",
+ dict(type=b"*", len=4, payload=b""),
+ b"",
+ id="no payload",
+ ),
+ pytest.param(
+ b"*\x00\x00\x00\x05xabcd",
+ dict(type=b"*", len=5, payload=b"x"),
+ b"abcd",
+ id="1-byte payload with extra padding",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x08\x00\x00\x00\x00",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=8,
+ payload=dict(type=pq3.authn.OK, body=None),
+ ),
+ b"",
+ id="AuthenticationOk",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x12\x00\x00\x00\x0AEXTERNAL\x00\x00",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=18,
+ payload=dict(type=pq3.authn.SASL, body=[b"EXTERNAL", b""]),
+ ),
+ b"",
+ id="AuthenticationSASL",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=13,
+ payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
+ ),
+ b"",
+ id="AuthenticationSASLContinue",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=13,
+ payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
+ ),
+ b"",
+ id="AuthenticationSASLFinal",
+ ),
+ pytest.param(
+ b"p\x00\x00\x00\x0Bhunter2",
+ dict(
+ type=pq3.types.PasswordMessage,
+ len=11,
+ payload=b"hunter2",
+ ),
+ b"",
+ id="PasswordMessage",
+ ),
+ pytest.param(
+ b"K\x00\x00\x00\x0C\x00\x00\x00\x00\x12\x34\x56\x78",
+ dict(
+ type=pq3.types.BackendKeyData,
+ len=12,
+ payload=dict(pid=0, key=0x12345678),
+ ),
+ b"",
+ id="BackendKeyData",
+ ),
+ pytest.param(
+ b"C\x00\x00\x00\x08SET\x00",
+ dict(
+ type=pq3.types.CommandComplete,
+ len=8,
+ payload=dict(tag=b"SET"),
+ ),
+ b"",
+ id="CommandComplete",
+ ),
+ pytest.param(
+ b"E\x00\x00\x00\x11Mbad!\x00Mdog!\x00\x00",
+ dict(type=b"E", len=17, payload=dict(fields=[b"Mbad!", b"Mdog!", b""])),
+ b"",
+ id="ErrorResponse",
+ ),
+ pytest.param(
+ b"S\x00\x00\x00\x08a\x00b\x00",
+ dict(
+ type=pq3.types.ParameterStatus,
+ len=8,
+ payload=dict(name=b"a", value=b"b"),
+ ),
+ b"",
+ id="ParameterStatus",
+ ),
+ pytest.param(
+ b"Z\x00\x00\x00\x05x",
+ dict(type=b"Z", len=5, payload=dict(status=b"x")),
+ b"",
+ id="ReadyForQuery",
+ ),
+ pytest.param(
+ b"Q\x00\x00\x00\x06!\x00",
+ dict(type=pq3.types.Query, len=6, payload=dict(query=b"!")),
+ b"",
+ id="Query",
+ ),
+ pytest.param(
+ b"D\x00\x00\x00\x0B\x00\x01\x00\x00\x00\x01!",
+ dict(type=pq3.types.DataRow, len=11, payload=dict(columns=[b"!"])),
+ b"",
+ id="DataRow",
+ ),
+ pytest.param(
+ b"D\x00\x00\x00\x06\x00\x00extra",
+ dict(type=pq3.types.DataRow, len=6, payload=dict(columns=[])),
+ b"extra",
+ id="DataRow with extra data",
+ ),
+ pytest.param(
+ b"I\x00\x00\x00\x04",
+ dict(type=pq3.types.EmptyQueryResponse, len=4, payload=None),
+ b"",
+ id="EmptyQueryResponse",
+ ),
+ pytest.param(
+ b"I\x00\x00\x00\x04\xFF",
+ dict(type=b"I", len=4, payload=None),
+ b"\xFF",
+ id="EmptyQueryResponse with extra bytes",
+ ),
+ pytest.param(
+ b"X\x00\x00\x00\x04",
+ dict(type=pq3.types.Terminate, len=4, payload=None),
+ b"",
+ id="Terminate",
+ ),
+ ],
+)
+def test_Pq3_parse(raw, expected, extra):
+ with io.BytesIO(raw) as stream:
+ actual = pq3.Pq3.parse_stream(stream)
+
+ assert actual == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(type=b"*", len=5),
+ b"*\x00\x00\x00\x05\x00",
+ id="type and len set explicitly",
+ ),
+ pytest.param(
+ dict(type=b"*"),
+ b"*\x00\x00\x00\x04",
+ id="implied len with no payload",
+ ),
+ pytest.param(
+ dict(type=b"*", payload=b"1234"),
+ b"*\x00\x00\x00\x081234",
+ id="implied len with payload",
+ ),
+ pytest.param(
+ dict(type=pq3.types.AuthnRequest, payload=dict(type=pq3.authn.OK)),
+ b"R\x00\x00\x00\x08\x00\x00\x00\x00",
+ id="implied len/type for AuthenticationOK",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(
+ type=pq3.authn.SASL,
+ body=[b"SCRAM-SHA-256-PLUS", b"SCRAM-SHA-256", b""],
+ ),
+ ),
+ b"R\x00\x00\x00\x2A\x00\x00\x00\x0ASCRAM-SHA-256-PLUS\x00SCRAM-SHA-256\x00\x00",
+ id="implied len/type for AuthenticationSASL",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
+ ),
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
+ id="implied len/type for AuthenticationSASLContinue",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
+ ),
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
+ id="implied len/type for AuthenticationSASLFinal",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.PasswordMessage,
+ payload=b"hunter2",
+ ),
+ b"p\x00\x00\x00\x0Bhunter2",
+ id="implied len/type for PasswordMessage",
+ ),
+ pytest.param(
+ dict(type=pq3.types.BackendKeyData, payload=dict(pid=1, key=7)),
+ b"K\x00\x00\x00\x0C\x00\x00\x00\x01\x00\x00\x00\x07",
+ id="implied len/type for BackendKeyData",
+ ),
+ pytest.param(
+ dict(type=pq3.types.CommandComplete, payload=dict(tag=b"SET")),
+ b"C\x00\x00\x00\x08SET\x00",
+ id="implied len/type for CommandComplete",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ErrorResponse, payload=dict(fields=[b"error", b""])),
+ b"E\x00\x00\x00\x0Berror\x00\x00",
+ id="implied len/type for ErrorResponse",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ParameterStatus, payload=dict(name=b"a", value=b"b")),
+ b"S\x00\x00\x00\x08a\x00b\x00",
+ id="implied len/type for ParameterStatus",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ReadyForQuery, payload=dict(status=b"I")),
+ b"Z\x00\x00\x00\x05I",
+ id="implied len/type for ReadyForQuery",
+ ),
+ pytest.param(
+ dict(type=pq3.types.Query, payload=dict(query=b"SELECT 1;")),
+ b"Q\x00\x00\x00\x0eSELECT 1;\x00",
+ id="implied len/type for Query",
+ ),
+ pytest.param(
+ dict(type=pq3.types.DataRow, payload=dict(columns=[b"abcd"])),
+ b"D\x00\x00\x00\x0E\x00\x01\x00\x00\x00\x04abcd",
+ id="implied len/type for DataRow",
+ ),
+ pytest.param(
+ dict(type=pq3.types.EmptyQueryResponse),
+ b"I\x00\x00\x00\x04",
+ id="implied len for EmptyQueryResponse",
+ ),
+ pytest.param(
+ dict(type=pq3.types.Terminate),
+ b"X\x00\x00\x00\x04",
+ id="implied len for Terminate",
+ ),
+ ],
+)
+def test_Pq3_build(fields, expected):
+ actual = pq3.Pq3.build(fields)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"\x00\x00",
+ dict(columns=[]),
+ b"",
+ id="no columns",
+ ),
+ pytest.param(
+ b"\x00\x01\x00\x00\x00\x04abcd",
+ dict(columns=[b"abcd"]),
+ b"",
+ id="one column",
+ ),
+ pytest.param(
+ b"\x00\x02\x00\x00\x00\x04abcd\x00\x00\x00\x01x",
+ dict(columns=[b"abcd", b"x"]),
+ b"",
+ id="multiple columns",
+ ),
+ pytest.param(
+ b"\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01x",
+ dict(columns=[b"", b"x"]),
+ b"",
+ id="empty column value",
+ ),
+ pytest.param(
+ b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
+ dict(columns=[None, None]),
+ b"",
+ id="null columns",
+ ),
+ ],
+)
+def test_DataRow_parse(raw, expected, extra):
+ pkt = b"D" + struct.pack("!i", len(raw) + 4) + raw
+ with io.BytesIO(pkt) as stream:
+ actual = pq3.Pq3.parse_stream(stream)
+
+ assert actual.type == pq3.types.DataRow
+ assert actual.payload == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(),
+ b"\x00\x00",
+ id="no columns",
+ ),
+ pytest.param(
+ dict(columns=[None, None]),
+ b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
+ id="null columns",
+ ),
+ ],
+)
+def test_DataRow_build(fields, expected):
+ actual = pq3.Pq3.build(dict(type=pq3.types.DataRow, payload=fields))
+
+ expected = b"D" + struct.pack("!i", len(expected) + 4) + expected
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "raw,expected,exception",
+ [
+ pytest.param(
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ dict(name=b"EXTERNAL", len=-1, data=None),
+ None,
+ id="no initial response",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\x02me",
+ dict(name=b"EXTERNAL", len=2, data=b"me"),
+ None,
+ id="initial response",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\x02meextra",
+ None,
+ TerminatedError,
+ id="extra data",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\xFFme",
+ None,
+ StreamError,
+ id="underflow",
+ ),
+ ],
+)
+def test_SASLInitialResponse_parse(raw, expected, exception):
+ ctx = contextlib.nullcontext()
+ if exception:
+ ctx = pytest.raises(exception)
+
+ with ctx:
+ actual = pq3.SASLInitialResponse.parse(raw)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(name=b"EXTERNAL"),
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ id="no initial response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=None),
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ id="no initial response (explicit None)",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=b""),
+ b"EXTERNAL\x00\x00\x00\x00\x00",
+ id="empty response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=b"me@example.com"),
+ b"EXTERNAL\x00\x00\x00\x00\x0Eme@example.com",
+ id="initial response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", len=2, data=b"me@example.com"),
+ b"EXTERNAL\x00\x00\x00\x00\x02me@example.com",
+ id="data overflow",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", len=14, data=b"me"),
+ b"EXTERNAL\x00\x00\x00\x00\x0Eme",
+ id="data underflow",
+ ),
+ ],
+)
+def test_SASLInitialResponse_build(fields, expected):
+ actual = pq3.SASLInitialResponse.build(fields)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "version,expected_bytes",
+ [
+ pytest.param((3, 0), b"\x00\x03\x00\x00", id="version 3"),
+ pytest.param((1234, 5679), b"\x04\xd2\x16\x2f", id="SSLRequest"),
+ ],
+)
+def test_protocol(version, expected_bytes):
+ # Make sure the integer returned by protocol is correctly serialized on the
+ # wire.
+ assert struct.pack("!i", pq3.protocol(*version)) == expected_bytes
+
+
+@pytest.mark.parametrize(
+ "envvar,func,expected",
+ [
+ ("PGHOST", pq3.pghost, "localhost"),
+ ("PGPORT", pq3.pgport, 5432),
+ ("PGUSER", pq3.pguser, getpass.getuser()),
+ ("PGDATABASE", pq3.pgdatabase, "postgres"),
+ ],
+)
+def test_env_defaults(monkeypatch, envvar, func, expected):
+ monkeypatch.delenv(envvar, raising=False)
+
+ actual = func()
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "envvars,func,expected",
+ [
+ (dict(PGHOST="otherhost"), pq3.pghost, "otherhost"),
+ (dict(PGPORT="6789"), pq3.pgport, 6789),
+ (dict(PGUSER="postgres"), pq3.pguser, "postgres"),
+ (dict(PGDATABASE="template1"), pq3.pgdatabase, "template1"),
+ ],
+)
+def test_env(monkeypatch, envvars, func, expected):
+ for k, v in envvars.items():
+ monkeypatch.setenv(k, v)
+
+ actual = func()
+ assert actual == expected
diff --git a/src/test/python/tls.py b/src/test/python/tls.py
new file mode 100644
index 0000000000..075c02c1ca
--- /dev/null
+++ b/src/test/python/tls.py
@@ -0,0 +1,195 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+from construct import *
+
+#
+# TLS 1.3
+#
+# Most of the types below are transcribed from RFC 8446:
+#
+# https://tools.ietf.org/html/rfc8446
+#
+
+
+def _Vector(size_field, element):
+ return Prefixed(size_field, GreedyRange(element))
+
+
+# Alerts
+
+AlertLevel = Enum(
+ Byte,
+ warning=1,
+ fatal=2,
+)
+
+AlertDescription = Enum(
+ Byte,
+ close_notify=0,
+ unexpected_message=10,
+ bad_record_mac=20,
+ decryption_failed_RESERVED=21,
+ record_overflow=22,
+ decompression_failure=30,
+ handshake_failure=40,
+ no_certificate_RESERVED=41,
+ bad_certificate=42,
+ unsupported_certificate=43,
+ certificate_revoked=44,
+ certificate_expired=45,
+ certificate_unknown=46,
+ illegal_parameter=47,
+ unknown_ca=48,
+ access_denied=49,
+ decode_error=50,
+ decrypt_error=51,
+ export_restriction_RESERVED=60,
+ protocol_version=70,
+ insufficient_security=71,
+ internal_error=80,
+ user_canceled=90,
+ no_renegotiation=100,
+ unsupported_extension=110,
+)
+
+Alert = Struct(
+ "level" / AlertLevel,
+ "description" / AlertDescription,
+)
+
+
+# Extensions
+
+ExtensionType = Enum(
+ Int16ub,
+ server_name=0,
+ max_fragment_length=1,
+ status_request=5,
+ supported_groups=10,
+ signature_algorithms=13,
+ use_srtp=14,
+ heartbeat=15,
+ application_layer_protocol_negotiation=16,
+ signed_certificate_timestamp=18,
+ client_certificate_type=19,
+ server_certificate_type=20,
+ padding=21,
+ pre_shared_key=41,
+ early_data=42,
+ supported_versions=43,
+ cookie=44,
+ psk_key_exchange_modes=45,
+ certificate_authorities=47,
+ oid_filters=48,
+ post_handshake_auth=49,
+ signature_algorithms_cert=50,
+ key_share=51,
+)
+
+Extension = Struct(
+ "extension_type" / ExtensionType,
+ "extension_data" / Prefixed(Int16ub, GreedyBytes),
+)
+
+
+# ClientHello
+
+
+class _CipherSuiteAdapter(Adapter):
+ class _hextuple(tuple):
+ def __repr__(self):
+ return f"(0x{self[0]:02X}, 0x{self[1]:02X})"
+
+ def _encode(self, obj, context, path):
+ return bytes(obj)
+
+ def _decode(self, obj, context, path):
+ assert len(obj) == 2
+ return self._hextuple(obj)
+
+
+ProtocolVersion = Hex(Int16ub)
+
+Random = Hex(Bytes(32))
+
+CipherSuite = _CipherSuiteAdapter(Byte[2])
+
+ClientHello = Struct(
+ "legacy_version" / ProtocolVersion,
+ "random" / Random,
+ "legacy_session_id" / Prefixed(Byte, Hex(GreedyBytes)),
+ "cipher_suites" / _Vector(Int16ub, CipherSuite),
+ "legacy_compression_methods" / Prefixed(Byte, GreedyBytes),
+ "extensions" / _Vector(Int16ub, Extension),
+)
+
+# ServerHello
+
+ServerHello = Struct(
+ "legacy_version" / ProtocolVersion,
+ "random" / Random,
+ "legacy_session_id_echo" / Prefixed(Byte, Hex(GreedyBytes)),
+ "cipher_suite" / CipherSuite,
+ "legacy_compression_method" / Hex(Byte),
+ "extensions" / _Vector(Int16ub, Extension),
+)
+
+# Handshake
+
+HandshakeType = Enum(
+ Byte,
+ client_hello=1,
+ server_hello=2,
+ new_session_ticket=4,
+ end_of_early_data=5,
+ encrypted_extensions=8,
+ certificate=11,
+ certificate_request=13,
+ certificate_verify=15,
+ finished=20,
+ key_update=24,
+ message_hash=254,
+)
+
+Handshake = Struct(
+ "msg_type" / HandshakeType,
+ "length" / Int24ub,
+ "payload"
+ / Switch(
+ this.msg_type,
+ {
+ HandshakeType.client_hello: ClientHello,
+ HandshakeType.server_hello: ServerHello,
+ # HandshakeType.end_of_early_data: EndOfEarlyData,
+ # HandshakeType.encrypted_extensions: EncryptedExtensions,
+ # HandshakeType.certificate_request: CertificateRequest,
+ # HandshakeType.certificate: Certificate,
+ # HandshakeType.certificate_verify: CertificateVerify,
+ # HandshakeType.finished: Finished,
+ # HandshakeType.new_session_ticket: NewSessionTicket,
+ # HandshakeType.key_update: KeyUpdate,
+ },
+ default=FixedSized(this.length, GreedyBytes),
+ ),
+)
+
+# Records
+
+ContentType = Enum(
+ Byte,
+ invalid=0,
+ change_cipher_spec=20,
+ alert=21,
+ handshake=22,
+ application_data=23,
+)
+
+Plaintext = Struct(
+ "type" / ContentType,
+ "legacy_record_version" / ProtocolVersion,
+ "length" / Int16ub,
+ "fragment" / FixedSized(this.length, GreedyBytes),
+)
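As a stdlib-only sanity check of the TLSPlaintext record layout defined above (this sketch deliberately avoids the construct library and is not part of the patch), the five-byte record header can be unpacked directly:

```python
import struct


def parse_plaintext_record(raw):
    # Mirror the Plaintext struct: 1-byte content type, 2-byte
    # legacy_record_version, 2-byte length, then the fragment.
    ctype, version, length = struct.unpack("!BHH", raw[:5])
    fragment = raw[5 : 5 + length]
    assert len(fragment) == length, "truncated record"
    return ctype, version, fragment


# A minimal alert record: content type 21 (alert), legacy version 0x0303,
# and a 2-byte fragment encoding level=warning(1), description=close_notify(0).
record = b"\x15\x03\x03\x00\x02\x01\x00"
ctype, version, fragment = parse_plaintext_record(record)
assert ctype == 21 and version == 0x0303 and fragment == b"\x01\x00"
```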
--
2.25.1
On Wed, Aug 25, 2021 at 11:42 AM Jacob Champion <pchampion@vmware.com>
wrote:
On Tue, 2021-06-22 at 23:22 +0000, Jacob Champion wrote:
On Fri, 2021-06-18 at 11:31 +0300, Heikki Linnakangas wrote:
A few small things caught my eye in the backend oauth_exchange
function:
+ /* Handle the client's initial message. */
+ p = strdup(input);

this strdup() should be pstrdup().
Thanks, I'll fix that in the next re-roll.
In the same function, there are a bunch of reports like this:
ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Comma expected, but found character \"%s\".",
+ sanitize_char(*p))));
I don't think the double quotes are needed here, because sanitize_char
will return quotes if it's a single character. So it would end up
looking like this: ... found character "'x'".

I'll fix this too. Thanks!
v2, attached, incorporates Heikki's suggested fixes and also rebases on
top of latest HEAD, which had the SASL refactoring changes committed
last month.

The biggest change from the last patchset is 0001, an attempt at
enabling jsonapi in the frontend without the use of palloc(), based on
suggestions by Michael and Tom from last commitfest. I've also made
some improvements to the pytest suite. No major changes to the OAuth
implementation yet.

--Jacob
Hi,
For v2-0001-common-jsonapi-support-FRONTEND-clients.patch :
+ /* Clean up. */
+ termJsonLexContext(&lex);
At the end of termJsonLexContext(), empty is copied to lex. For stack
based JsonLexContext, the copy seems unnecessary.
Maybe introduce a boolean parameter for termJsonLexContext() to signal that
the copy can be omitted ?
+#ifdef FRONTEND
+ /* make sure initialization succeeded */
+ if (lex->strval == NULL)
+ return JSON_OUT_OF_MEMORY;
Should PQExpBufferBroken(lex->strval) be used for the check ?
Thanks
On Wed, Aug 25, 2021 at 3:25 PM Zhihong Yu <zyu@yugabyte.com> wrote:
Hi,
For v2-0002-libpq-add-OAUTHBEARER-SASL-mechanism.patch :
+ i_init_session(&session);
+
+ if (!conn->oauth_client_id)
+ {
+ /* We can't talk to a server without a client identifier. */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("no oauth_client_id is set for
the connection"));
+ goto cleanup;
Can conn->oauth_client_id check be performed ahead of i_init_session() ?
That way, ```goto cleanup``` can be replaced with return.
+ if (!error_code || (strcmp(error_code, "authorization_pending")
+ && strcmp(error_code, "slow_down")))
What if, in the future, there is error code different from the above two
which doesn't represent "OAuth token retrieval failed" scenario ?
For client_initial_response(),
+ token_buf = createPQExpBuffer();
+ if (!token_buf)
+ goto cleanup;
If token_buf is NULL, there doesn't seem to be anything to free. We can
return directly.
Cheers
On Wed, 2021-08-25 at 15:25 -0700, Zhihong Yu wrote:
Hi,
For v2-0001-common-jsonapi-support-FRONTEND-clients.patch :

+ /* Clean up. */
+ termJsonLexContext(&lex);

At the end of termJsonLexContext(), empty is copied to lex. For stack
based JsonLexContext, the copy seems unnecessary.
Maybe introduce a boolean parameter for termJsonLexContext() to
signal that the copy can be omitted ?
Do you mean heap-based? i.e. destroyJsonLexContext() does an
unnecessary copy before free? Yeah, in that case it's not super useful,
but I think I'd want some evidence that the performance hit matters
before optimizing it.
Are there any other internal APIs that take a boolean parameter like
that? If not, I think we'd probably just want to remove the copy
entirely if it's a problem.
+#ifdef FRONTEND
+ /* make sure initialization succeeded */
+ if (lex->strval == NULL)
+ return JSON_OUT_OF_MEMORY;

Should PQExpBufferBroken(lex->strval) be used for the check ?
It should be okay to continue if the strval is broken but non-NULL,
since it's about to be reset. That has the fringe benefit of allowing
the function to go as far as possible without failing, though that's
probably a pretty weak justification.
In practice, do you think that the probability of success is low enough
that we should just short-circuit and be done with it?
On Wed, 2021-08-25 at 16:24 -0700, Zhihong Yu wrote:
For v2-0002-libpq-add-OAUTHBEARER-SASL-mechanism.patch :
+ i_init_session(&session);
+
+ if (!conn->oauth_client_id)
+ {
+ /* We can't talk to a server without a client identifier. */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("no oauth_client_id is set for the connection"));
+ goto cleanup;

Can conn->oauth_client_id check be performed ahead
of i_init_session() ? That way, ```goto cleanup``` can be replaced
with return.
Yeah, I think that makes sense. FYI, this is probably one of the
functions that will be rewritten completely once iddawc is removed.
+ if (!error_code || (strcmp(error_code, "authorization_pending")
+ && strcmp(error_code, "slow_down")))

What if, in the future, there is an error code different from the above
two which doesn't represent the "OAuth token retrieval failed" scenario ?
We'd have to update our code; that would be a breaking change to the
Device Authorization spec. Here's what it says today [1]:
The "authorization_pending" and "slow_down" error codes define
particularly unique behavior, as they indicate that the OAuth client
should continue to poll the token endpoint by repeating the token
request (implementing the precise behavior defined above). If the
client receives an error response with any other error code, it MUST
stop polling and SHOULD react accordingly, for example, by displaying
an error to the user.
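Put differently, exactly those two codes mean "poll again" and anything else is fatal. A minimal illustration of that decision (Python for brevity; this is not the libpq code, which expresses the same test with strcmp()):

```python
def should_keep_polling(error_code):
    # Per RFC 8628 section 3.5: only "authorization_pending" and
    # "slow_down" tell the client to repeat the token request; a missing
    # code or any other value means it MUST stop polling.
    return error_code in ("authorization_pending", "slow_down")


assert should_keep_polling("authorization_pending")
assert should_keep_polling("slow_down")
assert not should_keep_polling("access_denied")
assert not should_keep_polling(None)
```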
For client_initial_response(),
+ token_buf = createPQExpBuffer();
+ if (!token_buf)
+ goto cleanup;

If token_buf is NULL, there doesn't seem to be anything to free. We
can return directly.
That's true today, but implementations have a habit of changing. I
personally prefer not to introduce too many exit points from a function
that's already using goto. In my experience, that makes future
maintenance harder.
Thanks for the reviews! Have you been able to give the patchset a try
with an OAuth deployment?
--Jacob
[1]: https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
On Thu, Aug 26, 2021 at 9:13 AM Jacob Champion <pchampion@vmware.com> wrote:
On Wed, 2021-08-25 at 15:25 -0700, Zhihong Yu wrote:
Hi,
For v2-0001-common-jsonapi-support-FRONTEND-clients.patch :+ /* Clean up. */
+ termJsonLexContext(&lex);At the end of termJsonLexContext(), empty is copied to lex. For stack
based JsonLexContext, the copy seems unnecessary.
Maybe introduce a boolean parameter for termJsonLexContext() to
signal that the copy can be omitted ?Do you mean heap-based? i.e. destroyJsonLexContext() does an
unnecessary copy before free? Yeah, in that case it's not super useful,
but I think I'd want some evidence that the performance hit matters
before optimizing it.Are there any other internal APIs that take a boolean parameter like
that? If not, I think we'd probably just want to remove the copy
entirely if it's a problem.+#ifdef FRONTEND + /* make sure initialization succeeded */ + if (lex->strval == NULL) + return JSON_OUT_OF_MEMORY;Should PQExpBufferBroken(lex->strval) be used for the check ?
It should be okay to continue if the strval is broken but non-NULL,
since it's about to be reset. That has the fringe benefit of allowing
the function to go as far as possible without failing, though that's
probably a pretty weak justification.In practice, do you think that the probability of success is low enough
that we should just short-circuit and be done with it?On Wed, 2021-08-25 at 16:24 -0700, Zhihong Yu wrote:
For v2-0002-libpq-add-OAUTHBEARER-SASL-mechanism.patch :

+ i_init_session(&session);
+
+ if (!conn->oauth_client_id)
+ {
+ /* We can't talk to a server without a client identifier. */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("no oauth_client_id is set for the connection"));
+ goto cleanup;

Can conn->oauth_client_id check be performed ahead
of i_init_session() ? That way, ```goto cleanup``` can be replaced
with return.

Yeah, I think that makes sense. FYI, this is probably one of the
functions that will be rewritten completely once iddawc is removed.

+ if (!error_code || (strcmp(error_code, "authorization_pending")
+ && strcmp(error_code, "slow_down")))

What if, in the future, there is an error code different from the above
two which doesn't represent the "OAuth token retrieval failed" scenario ?

We'd have to update our code; that would be a breaking change to the
Device Authorization spec. Here's what it says today [1]:

    The "authorization_pending" and "slow_down" error codes define
    particularly unique behavior, as they indicate that the OAuth client
    should continue to poll the token endpoint by repeating the token
    request (implementing the precise behavior defined above). If the
    client receives an error response with any other error code, it MUST
    stop polling and SHOULD react accordingly, for example, by displaying
    an error to the user.

For client_initial_response(),

+ token_buf = createPQExpBuffer();
+ if (!token_buf)
+ goto cleanup;

If token_buf is NULL, there doesn't seem to be anything to free. We
can return directly.

That's true today, but implementations have a habit of changing. I
personally prefer not to introduce too many exit points from a function
that's already using goto. In my experience, that makes future
maintenance harder.

Thanks for the reviews! Have you been able to give the patchset a try
with an OAuth deployment?

--Jacob

[1]: https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
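As a standalone illustration of the RFC 8628 rule quoted above: only "authorization_pending" and "slow_down" mean "keep polling", and "slow_down" additionally requires the client to add five seconds to its polling interval. This is a sketch, not the patch's actual code; the function name is made up:

```c
#include <stdbool.h>
#include <string.h>

/*
 * Decide whether a device-flow client should keep polling the token
 * endpoint, per RFC 8628 section 3.5.  Any error code other than the two
 * "try again" codes (or a missing code entirely) means the flow has
 * failed and polling must stop.
 */
static bool
should_continue_polling(const char *error_code, int *interval)
{
	if (error_code == NULL)
		return false;			/* malformed response: stop polling */

	if (strcmp(error_code, "authorization_pending") == 0)
		return true;			/* user hasn't approved yet; same interval */

	if (strcmp(error_code, "slow_down") == 0)
	{
		*interval += 5;			/* spec: increase interval by 5 seconds */
		return true;
	}

	return false;				/* access_denied, expired_token, etc. */
}
```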
Hi,
bq. destroyJsonLexContext() does an unnecessary copy before free? Yeah, in
that case it's not super useful,
but I think I'd want some evidence that the performance hit matters before
optimizing it.
Yes I agree.
bq. In practice, do you think that the probability of success is low enough
that we should just short-circuit and be done with it?
Haven't had a chance to try your patches out yet.
I will leave this to people who are more familiar with OAuth
implementation(s).
bq. I personally prefer not to introduce too many exit points from a
function that's already using goto.
Fair enough.
Cheers
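For what it's worth, the two checks discussed above (NULL-only versus PQExpBufferBroken()) differ only for a buffer that allocated but later hit out-of-memory. The following is a standalone mimic of libpq's documented convention (a "broken" buffer is marked by maxlen == 0), using made-up names rather than the real pqexpbuffer.h:

```c
#include <stddef.h>

/* Minimal stand-in for libpq's PQExpBufferData; the real struct differs. */
typedef struct
{
	char	   *data;
	size_t		len;
	size_t		maxlen;			/* 0 marks a "broken" (OOM) buffer */
} MimicExpBuffer;

/* Mirrors the documented PQExpBufferBroken() test. */
#define MimicExpBufferBroken(b) ((b) == NULL || (b)->maxlen == 0)

/*
 * The check in the patch only rejects a NULL pointer (i.e.
 * createPQExpBuffer() itself failed).  A broken-but-non-NULL buffer
 * passes, since it is about to be reset anyway.
 */
static int
passes_null_only_check(MimicExpBuffer *strval)
{
	return strval != NULL;
}

/* The stricter alternative raised in review: also reject broken buffers. */
static int
passes_broken_check(MimicExpBuffer *strval)
{
	return !MimicExpBufferBroken(strval);
}
```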
On Thu, Aug 26, 2021 at 04:13:08PM +0000, Jacob Champion wrote:
Do you mean heap-based? i.e. destroyJsonLexContext() does an
unnecessary copy before free? Yeah, in that case it's not super useful,
but I think I'd want some evidence that the performance hit matters
before optimizing it.
As an authentication code path, the impact is minimal and my take on
that would be to keep the code simple. Now if you'd really wish to
stress that without relying on the backend, one simple way is to use
pgbench -C -n with a mostly-empty script (one meta-command) coupled
with some profiling.
--
Michael
On Fri, 2021-08-27 at 11:32 +0900, Michael Paquier wrote:
Now if you'd really wish to
stress that without relying on the backend, one simple way is to use
pgbench -C -n with a mostly-empty script (one meta-command) coupled
with some profiling.
Ah, thanks! I'll add that to the toolbox.
--Jacob
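Michael's pgbench recipe above can be set up like this (a sketch; the script path and transaction count are arbitrary):

```shell
# Build a mostly-empty pgbench script: one meta-command, no SQL, so the
# cost being measured per "transaction" is connection setup itself.
cat > /tmp/empty.sql <<'EOF'
\set dummy 0
EOF

# -C reconnects for every transaction, -n skips vacuuming the standard
# pgbench tables (which this script never touches).  Profile the client
# while this runs to see the JSON/auth code paths under load.
if command -v pgbench >/dev/null 2>&1; then
    pgbench -C -n -f /tmp/empty.sql -t 1000 postgres
else
    echo "pgbench not installed; skipping run"
fi
```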
Hi all,
v3 rebases this patchset over the top of Samay's pluggable auth
provider API [1], included here as patches 0001-3. The final patch in
the set ports the server implementation from a core feature to a
contrib module; to switch between the two approaches, simply leave out
that final patch.
There are still some backend changes that must be made to get this
working, as pointed out in 0009, and obviously libpq support still
requires code changes.
--Jacob
[1]: /messages/by-id/CAJxrbyxTRn5P8J-p+wHLwFahK5y56PhK28VOb55jqMO05Y-DJw@mail.gmail.com
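For reference, routing a connection through a registered provider would presumably look like the pg_hba.conf pattern that patches 0001-0003 establish; this line is an illustration only, with a made-up provider name:

```
# pg_hba.conf: hand local connections to a registered custom provider.
# The named provider must have been loaded via shared_preload_libraries,
# or reloading pg_hba.conf will fail.
local   all   all   custom   provider=oauth_validator
```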
Attachments:
v3-0001-Add-support-for-custom-authentication-methods.patch
From 206060ed1b31fcec48fb6ee05d61b135ec98cecf Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Tue, 15 Feb 2022 22:23:29 -0800
Subject: [PATCH v3 1/9] Add support for custom authentication methods
Currently, PostgreSQL supports only a set of pre-defined authentication
methods. This patch adds support for 2 hooks which allow users to add
their custom authentication methods by defining a check function and an
error function. Users can then use these methods by using a new "custom"
keyword in pg_hba.conf and specifying the authentication provider they
want to use.
---
src/backend/libpq/auth.c | 89 ++++++++++++++++++++++++++++++----------
src/backend/libpq/hba.c | 44 ++++++++++++++++++++
src/include/libpq/auth.h | 27 ++++++++++++
src/include/libpq/hba.h | 4 ++
4 files changed, 143 insertions(+), 21 deletions(-)
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index efc53f3135..3533b0bc50 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -47,8 +47,6 @@
*----------------------------------------------------------------
*/
static void auth_failed(Port *port, int status, const char *logdetail);
-static char *recv_password_packet(Port *port);
-static void set_authn_id(Port *port, const char *id);
/*----------------------------------------------------------------
@@ -206,23 +204,6 @@ static int pg_SSPI_make_upn(char *accountname,
static int CheckRADIUSAuth(Port *port);
static int PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd);
-
-/*
- * Maximum accepted size of GSS and SSPI authentication tokens.
- * We also use this as a limit on ordinary password packet lengths.
- *
- * Kerberos tickets are usually quite small, but the TGTs issued by Windows
- * domain controllers include an authorization field known as the Privilege
- * Attribute Certificate (PAC), which contains the user's Windows permissions
- * (group memberships etc.). The PAC is copied into all tickets obtained on
- * the basis of this TGT (even those issued by Unix realms which the Windows
- * realm trusts), and can be several kB in size. The maximum token size
- * accepted by Windows systems is determined by the MaxAuthToken Windows
- * registry setting. Microsoft recommends that it is not set higher than
- * 65535 bytes, so that seems like a reasonable limit for us as well.
- */
-#define PG_MAX_AUTH_TOKEN_LENGTH 65535
-
/*----------------------------------------------------------------
* Global authentication functions
*----------------------------------------------------------------
@@ -235,6 +216,16 @@ static int PerformRadiusTransaction(const char *server, const char *secret, cons
*/
ClientAuthentication_hook_type ClientAuthentication_hook = NULL;
+/*
+ * These hooks allow plugins to get control of the client authentication check
+ * and error reporting logic. This allows users to write extensions to
+ * implement authentication using any protocol of their choice. To acquire these
+ * hooks, plugins need to call the RegisterAuthProvider() function.
+ */
+static CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook = NULL;
+static CustomAuthenticationError_hook_type CustomAuthenticationError_hook = NULL;
+char *custom_provider_name = NULL;
+
/*
* Tell the user the authentication failed, but not (much about) why.
*
@@ -311,6 +302,12 @@ auth_failed(Port *port, int status, const char *logdetail)
case uaRADIUS:
errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
break;
+ case uaCustom:
+ if (CustomAuthenticationError_hook)
+ errstr = CustomAuthenticationError_hook(port);
+ else
+ errstr = gettext_noop("Custom authentication failed for user \"%s\"");
+ break;
default:
errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method");
break;
@@ -345,7 +342,7 @@ auth_failed(Port *port, int status, const char *logdetail)
* lifetime of the Port, so it is safe to pass a string that is managed by an
* external library.
*/
-static void
+void
set_authn_id(Port *port, const char *id)
{
Assert(id);
@@ -630,6 +627,10 @@ ClientAuthentication(Port *port)
case uaTrust:
status = STATUS_OK;
break;
+ case uaCustom:
+ if (CustomAuthenticationCheck_hook)
+ status = CustomAuthenticationCheck_hook(port);
+ break;
}
if ((status == STATUS_OK && port->hba->clientcert == clientCertFull)
@@ -689,7 +690,7 @@ sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extrale
*
* Returns NULL if couldn't get password, else palloc'd string.
*/
-static char *
+char *
recv_password_packet(Port *port)
{
StringInfoData buf;
@@ -3343,3 +3344,49 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
}
} /* while (true) */
}
+
+/*----------------------------------------------------------------
+ * Custom authentication
+ *----------------------------------------------------------------
+ */
+
+/*
+ * RegisterAuthProvider registers a custom authentication provider to be
+ * used for authentication. Currently, we allow only one authentication
+ * provider to be registered for use at a time.
+ *
+ * This function should be called in _PG_init() by any extension looking to
+ * add a custom authentication method.
+ */
+void RegisterAuthProvider(const char *provider_name,
+ CustomAuthenticationCheck_hook_type AuthenticationCheckFunction,
+ CustomAuthenticationError_hook_type AuthenticationErrorFunction)
+{
+ if (provider_name == NULL)
+ {
+ ereport(ERROR,
+ (errmsg("cannot register authentication provider without name")));
+ }
+
+ if (AuthenticationCheckFunction == NULL)
+ {
+ ereport(ERROR,
+ (errmsg("cannot register authentication provider without a check function")));
+ }
+
+ if (custom_provider_name)
+ {
+ ereport(ERROR,
+ (errmsg("cannot register authentication provider %s", provider_name),
+ errdetail("Only one authentication provider allowed. Provider %s is already registered.",
+ custom_provider_name)));
+ }
+
+ /*
+ * Allocate in top memory context as we need to read this whenever
+ * we parse pg_hba.conf
+ */
+ custom_provider_name = MemoryContextStrdup(TopMemoryContext,provider_name);
+ CustomAuthenticationCheck_hook = AuthenticationCheckFunction;
+ CustomAuthenticationError_hook = AuthenticationErrorFunction;
+}
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index d84a40b726..ebae992964 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -134,6 +134,7 @@ static const char *const UserAuthName[] =
"ldap",
"cert",
"radius",
+ "custom",
"peer"
};
@@ -1399,6 +1400,8 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
#endif
else if (strcmp(token->string, "radius") == 0)
parsedline->auth_method = uaRADIUS;
+ else if (strcmp(token->string, "custom") == 0)
+ parsedline->auth_method = uaCustom;
else
{
ereport(elevel,
@@ -1691,6 +1694,14 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
parsedline->clientcert = clientCertFull;
}
+ /*
+ * Ensure that the provider name is specified for custom authentication method.
+ */
+ if (parsedline->auth_method == uaCustom)
+ {
+ MANDATORY_AUTH_ARG(parsedline->custom_provider, "provider", "custom");
+ }
+
return parsedline;
}
@@ -2102,6 +2113,32 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->radiusidentifiers = parsed_identifiers;
hbaline->radiusidentifiers_s = pstrdup(val);
}
+ else if (strcmp(name, "provider") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaCustom, "provider", "custom");
+
+ /*
+ * Verify that the provider mentioned is same as the one loaded
+ * via shared_preload_libraries.
+ */
+
+ if (custom_provider_name == NULL || strcmp(val,custom_provider_name) != 0)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("cannot use authentication provider %s",val),
+ errhint("Load authentication provider via shared_preload_libraries."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("cannot use authentication provider %s", val);
+
+ return false;
+ }
+ else
+ {
+ hbaline->custom_provider = pstrdup(val);
+ }
+ }
else
{
ereport(elevel,
@@ -2442,6 +2479,13 @@ gethba_options(HbaLine *hba)
CStringGetTextDatum(psprintf("radiusports=%s", hba->radiusports_s));
}
+ if (hba->auth_method == uaCustom)
+ {
+ if (hba->custom_provider)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("provider=%s",hba->custom_provider));
+ }
+
/* If you add more options, consider increasing MAX_HBA_OPTIONS. */
Assert(noptions <= MAX_HBA_OPTIONS);
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 6d7ee1acb9..1d10cccc1b 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -23,9 +23,36 @@ extern char *pg_krb_realm;
extern void ClientAuthentication(Port *port);
extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
int extralen);
+extern void set_authn_id(Port *port, const char *id);
+extern char *recv_password_packet(Port *port);
/* Hook for plugins to get control in ClientAuthentication() */
+typedef int (*CustomAuthenticationCheck_hook_type) (Port *);
typedef void (*ClientAuthentication_hook_type) (Port *, int);
extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook;
+/* Hook for plugins to report error messages in auth_failed() */
+typedef const char * (*CustomAuthenticationError_hook_type) (Port *);
+
+extern void RegisterAuthProvider
+ (const char *provider_name,
+ CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook,
+ CustomAuthenticationError_hook_type CustomAuthenticationError_hook);
+
+/*
+ * Maximum accepted size of GSS and SSPI authentication tokens.
+ * We also use this as a limit on ordinary password packet lengths.
+ *
+ * Kerberos tickets are usually quite small, but the TGTs issued by Windows
+ * domain controllers include an authorization field known as the Privilege
+ * Attribute Certificate (PAC), which contains the user's Windows permissions
+ * (group memberships etc.). The PAC is copied into all tickets obtained on
+ * the basis of this TGT (even those issued by Unix realms which the Windows
+ * realm trusts), and can be several kB in size. The maximum token size
+ * accepted by Windows systems is determined by the MaxAuthToken Windows
+ * registry setting. Microsoft recommends that it is not set higher than
+ * 65535 bytes, so that seems like a reasonable limit for us as well.
+ */
+#define PG_MAX_AUTH_TOKEN_LENGTH 65535
+
#endif /* AUTH_H */
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 8d9f3821b1..c5aef6994c 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -38,6 +38,7 @@ typedef enum UserAuth
uaLDAP,
uaCert,
uaRADIUS,
+ uaCustom,
uaPeer
#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */
} UserAuth;
@@ -120,6 +121,7 @@ typedef struct HbaLine
char *radiusidentifiers_s;
List *radiusports;
char *radiusports_s;
+ char *custom_provider;
} HbaLine;
typedef struct IdentLine
@@ -144,4 +146,6 @@ extern int check_usermap(const char *usermap_name,
bool case_sensitive);
extern bool pg_isblank(const char c);
+extern char *custom_provider_name;
+
#endif /* HBA_H */
--
2.25.1
v3-0002-Add-sample-extension-to-test-custom-auth-provider.patch
From 93b108334fa9fcc2d5f68d75fc320cd2714eb9da Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Tue, 15 Feb 2022 22:28:40 -0800
Subject: [PATCH v3 2/9] Add sample extension to test custom auth provider
hooks
This change adds a new extension to src/test/modules to
test the custom authentication provider hooks. In this
extension, we use an array to define which users to
authenticate and what passwords to use.
---
src/test/modules/test_auth_provider/Makefile | 16 ++++
.../test_auth_provider/test_auth_provider.c | 90 +++++++++++++++++++
2 files changed, 106 insertions(+)
create mode 100644 src/test/modules/test_auth_provider/Makefile
create mode 100644 src/test/modules/test_auth_provider/test_auth_provider.c
diff --git a/src/test/modules/test_auth_provider/Makefile b/src/test/modules/test_auth_provider/Makefile
new file mode 100644
index 0000000000..17971a5c7a
--- /dev/null
+++ b/src/test/modules/test_auth_provider/Makefile
@@ -0,0 +1,16 @@
+# src/test/modules/test_auth_provider/Makefile
+
+MODULE_big = test_auth_provider
+OBJS = test_auth_provider.o
+PGFILEDESC = "test_auth_provider - provider to test auth hooks"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = src/test/modules/test_auth_provider
+top_builddir = ../../../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/src/test/modules/test_auth_provider/test_auth_provider.c b/src/test/modules/test_auth_provider/test_auth_provider.c
new file mode 100644
index 0000000000..477ef8b2c3
--- /dev/null
+++ b/src/test/modules/test_auth_provider/test_auth_provider.c
@@ -0,0 +1,90 @@
+/* -------------------------------------------------------------------------
+ *
+ * test_auth_provider.c
+ * example authentication provider plugin
+ *
+ * Copyright (c) 2022, PostgreSQL Global Development Group
+ *
+ * IDENTIFICATION
+ * contrib/test_auth_provider/test_auth_provider.c
+ *
+ * -------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+#include "fmgr.h"
+#include "libpq/auth.h"
+#include "libpq/libpq.h"
+
+PG_MODULE_MAGIC;
+
+void _PG_init(void);
+
+static char *get_password_for_user(char *user_name);
+
+/*
+ * List of usernames / passwords to approve. Here we are not
+ * getting passwords from Postgres but from this list. In a more real-life
+ * extension, you can fetch valid credentials and authentication tokens /
+ * passwords from an external authentication provider.
+ */
+char credentials[3][3][50] = {
+ {"bob","alice","carol"},
+ {"bob123","alice123","carol123"}
+};
+
+static int TestAuthenticationCheck(Port *port)
+{
+ char *passwd;
+ int result = STATUS_ERROR;
+ char *real_pass;
+
+ sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
+
+ passwd = recv_password_packet(port);
+ if (passwd == NULL)
+ return STATUS_EOF;
+
+ real_pass = get_password_for_user(port->user_name);
+ if (real_pass)
+ {
+ if(strcmp(passwd, real_pass) == 0)
+ {
+ result = STATUS_OK;
+ }
+ pfree(real_pass);
+ }
+
+ pfree(passwd);
+
+ return result;
+}
+
+static char *
+get_password_for_user(char *user_name)
+{
+ char *password = NULL;
+ int i;
+ for (i=0; i<3; i++)
+ {
+ if (strcmp(user_name, credentials[0][i]) == 0)
+ {
+ password = pstrdup(credentials[1][i]);
+ }
+ }
+
+ return password;
+}
+
+static const char *TestAuthenticationError(Port *port)
+{
+ char *error_message = (char *)palloc (100);
+ sprintf(error_message, "Test authentication failed for user %s", port->user_name);
+ return error_message;
+}
+
+void
+_PG_init(void)
+{
+ RegisterAuthProvider("test", TestAuthenticationCheck, TestAuthenticationError);
+}
--
2.25.1
v3-0003-Add-tests-for-test_auth_provider-extension.patch
From 3aa2535fa42b142beabdcef234d1939738f36b9a Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Wed, 16 Feb 2022 12:28:36 -0800
Subject: [PATCH v3 3/9] Add tests for test_auth_provider extension
Add TAP tests for the test_auth_provider extension and allow make check
in src/test/modules to run them.
---
src/test/modules/Makefile | 1 +
src/test/modules/test_auth_provider/Makefile | 2 +
.../test_auth_provider/t/001_custom_auth.pl | 125 ++++++++++++++++++
3 files changed, 128 insertions(+)
create mode 100644 src/test/modules/test_auth_provider/t/001_custom_auth.pl
diff --git a/src/test/modules/Makefile b/src/test/modules/Makefile
index dffc79b2d9..f56533ea13 100644
--- a/src/test/modules/Makefile
+++ b/src/test/modules/Makefile
@@ -14,6 +14,7 @@ SUBDIRS = \
plsample \
snapshot_too_old \
spgist_name_ops \
+ test_auth_provider \
test_bloomfilter \
test_ddl_deparse \
test_extensions \
diff --git a/src/test/modules/test_auth_provider/Makefile b/src/test/modules/test_auth_provider/Makefile
index 17971a5c7a..7d601cf7d5 100644
--- a/src/test/modules/test_auth_provider/Makefile
+++ b/src/test/modules/test_auth_provider/Makefile
@@ -4,6 +4,8 @@ MODULE_big = test_auth_provider
OBJS = test_auth_provider.o
PGFILEDESC = "test_auth_provider - provider to test auth hooks"
+TAP_TESTS = 1
+
ifdef USE_PGXS
PG_CONFIG = pg_config
PGXS := $(shell $(PG_CONFIG) --pgxs)
diff --git a/src/test/modules/test_auth_provider/t/001_custom_auth.pl b/src/test/modules/test_auth_provider/t/001_custom_auth.pl
new file mode 100644
index 0000000000..3b7472dc7f
--- /dev/null
+++ b/src/test/modules/test_auth_provider/t/001_custom_auth.pl
@@ -0,0 +1,125 @@
+
+# Copyright (c) 2021-2022, PostgreSQL Global Development Group
+
+# Set of tests for testing custom authentication hooks.
+
+use strict;
+use warnings;
+use PostgreSQL::Test::Cluster;
+use PostgreSQL::Test::Utils;
+use Test::More;
+
+# Delete pg_hba.conf from the given node, add a new entry to it
+# and then execute a reload to refresh it.
+sub reset_pg_hba
+{
+ my $node = shift;
+ my $hba_method = shift;
+
+ unlink($node->data_dir . '/pg_hba.conf');
+ # just for testing purposes, use a continuation line
+ $node->append_conf('pg_hba.conf', "local all all\\\n $hba_method");
+ $node->reload;
+ return;
+}
+
+# Test if you get expected results in pg_hba_file_rules error column after
+# changing pg_hba.conf and reloading it.
+sub test_hba_reload
+{
+ my ($node, $method, $expected_res) = @_;
+ my $status_string = 'failed';
+ $status_string = 'success' if ($expected_res eq 0);
+ my $testname = "pg_hba.conf reload $status_string for method $method";
+
+ reset_pg_hba($node, $method);
+
+ my ($cmdret, $stdout, $stderr) = $node->psql("postgres",
+ "select count(*) from pg_hba_file_rules where error is not null",extra_params => ['-U','bob']);
+
+ is($stdout, $expected_res, $testname);
+}
+
+# Test access for a single role, useful to wrap all tests into one. Extra
+# named parameters are passed to connect_ok/fails as-is.
+sub test_role
+{
+ local $Test::Builder::Level = $Test::Builder::Level + 1;
+
+ my ($node, $role, $method, $expected_res, %params) = @_;
+ my $status_string = 'failed';
+ $status_string = 'success' if ($expected_res eq 0);
+
+ my $connstr = "user=$role";
+ my $testname =
+ "authentication $status_string for method $method, role $role";
+
+ if ($expected_res eq 0)
+ {
+ $node->connect_ok($connstr, $testname, %params);
+ }
+ else
+ {
+ # No checks of the error message, only the status code.
+ $node->connect_fails($connstr, $testname, %params);
+ }
+}
+
+# Initialize server node
+my $node = PostgreSQL::Test::Cluster->new('server');
+$node->init;
+$node->append_conf('postgresql.conf', "log_connections = on\n");
+$node->append_conf('postgresql.conf', "shared_preload_libraries = 'test_auth_provider.so'\n");
+$node->start;
+
+$node->safe_psql('postgres', "CREATE ROLE bob SUPERUSER LOGIN;");
+$node->safe_psql('postgres', "CREATE ROLE alice LOGIN;");
+$node->safe_psql('postgres', "CREATE ROLE test LOGIN;");
+
+# Add custom auth method to pg_hba.conf
+reset_pg_hba($node, 'custom provider=test');
+
+# Test that users are able to login with correct passwords.
+$ENV{"PGPASSWORD"} = 'bob123';
+test_role($node, 'bob', 'custom', 0, log_like => [qr/connection authorized: user=bob/]);
+$ENV{"PGPASSWORD"} = 'alice123';
+test_role($node, 'alice', 'custom', 0, log_like => [qr/connection authorized: user=alice/]);
+
+# Test that bad passwords are rejected.
+$ENV{"PGPASSWORD"} = 'badpassword';
+test_role($node, 'bob', 'custom', 2, log_unlike => [qr/connection authorized:/]);
+test_role($node, 'alice', 'custom', 2, log_unlike => [qr/connection authorized:/]);
+
+# Test that users not in authentication list are rejected.
+test_role($node, 'test', 'custom', 2, log_unlike => [qr/connection authorized:/]);
+
+$ENV{"PGPASSWORD"} = 'bob123';
+
+# Tests for invalid auth options
+
+# Test that an incorrect provider name is not accepted.
+test_hba_reload($node, 'custom provider=wrong', 1);
+
+# Test that specifying provider option with different auth method is not allowed.
+test_hba_reload($node, 'trust provider=test', 1);
+
+# Test that provider name is a mandatory option for custom auth.
+test_hba_reload($node, 'custom', 1);
+
+# Test that correct provider name allows reload to succeed.
+test_hba_reload($node, 'custom provider=test', 0);
+
+# Custom auth modules require mentioning extension in shared_preload_libraries.
+
+# Remove extension from shared_preload_libraries and try to restart.
+$node->adjust_conf('postgresql.conf', 'shared_preload_libraries', "''");
+command_fails(['pg_ctl', '-w', '-D', $node->data_dir, '-l', $node->logfile, 'restart'],'restart with empty shared_preload_libraries failed');
+
+# Fix shared_preload_libraries and confirm that you can now restart.
+$node->adjust_conf('postgresql.conf', 'shared_preload_libraries', "'test_auth_provider.so'");
+command_ok(['pg_ctl', '-w', '-D', $node->data_dir, '-l', $node->logfile,'start'],'restart with correct shared_preload_libraries succeeded');
+
+# Test that we can connect again
+test_role($node, 'bob', 'custom', 0, log_like => [qr/connection authorized: user=bob/]);
+
+done_testing();
--
2.25.1
v3-0004-common-jsonapi-support-FRONTEND-clients.patch
From 13bee49d8c674e921804b4e6edc363dcf33211ce Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Mon, 3 May 2021 15:38:26 -0700
Subject: [PATCH v3 4/9] common/jsonapi: support FRONTEND clients
Based on a patch by Michael Paquier.
For frontend code, use PQExpBuffer instead of StringInfo. This requires
us to track allocation failures so that we can return JSON_OUT_OF_MEMORY
as needed. json_errdetail() now allocates its error message inside
memory owned by the JsonLexContext, so clients don't need to worry about
freeing it.
For convenience, the backend now has destroyJsonLexContext() to mirror
other create/destroy APIs. The frontend has init/term versions of the
API to handle stack-allocated JsonLexContexts.
We can now partially revert b44669b2ca, now that json_errdetail() works
correctly.
---
src/backend/utils/adt/jsonfuncs.c | 4 +-
src/bin/pg_verifybackup/parse_manifest.c | 13 +-
src/bin/pg_verifybackup/t/005_bad_manifest.pl | 2 +-
src/common/Makefile | 2 +-
src/common/jsonapi.c | 290 +++++++++++++-----
src/include/common/jsonapi.h | 47 ++-
6 files changed, 270 insertions(+), 88 deletions(-)
diff --git a/src/backend/utils/adt/jsonfuncs.c b/src/backend/utils/adt/jsonfuncs.c
index 2457061f97..f58233cda9 100644
--- a/src/backend/utils/adt/jsonfuncs.c
+++ b/src/backend/utils/adt/jsonfuncs.c
@@ -723,9 +723,7 @@ json_object_keys(PG_FUNCTION_ARGS)
pg_parse_json_or_ereport(lex, sem);
/* keys are now in state->result */
- pfree(lex->strval->data);
- pfree(lex->strval);
- pfree(lex);
+ destroyJsonLexContext(lex);
pfree(sem);
MemoryContextSwitchTo(oldcontext);
diff --git a/src/bin/pg_verifybackup/parse_manifest.c b/src/bin/pg_verifybackup/parse_manifest.c
index 6364b01282..4b38fd3963 100644
--- a/src/bin/pg_verifybackup/parse_manifest.c
+++ b/src/bin/pg_verifybackup/parse_manifest.c
@@ -119,7 +119,7 @@ void
json_parse_manifest(JsonManifestParseContext *context, char *buffer,
size_t size)
{
- JsonLexContext *lex;
+ JsonLexContext lex = {0};
JsonParseErrorType json_error;
JsonSemAction sem;
JsonManifestParseState parse;
@@ -129,8 +129,8 @@ json_parse_manifest(JsonManifestParseContext *context, char *buffer,
parse.state = JM_EXPECT_TOPLEVEL_START;
parse.saw_version_field = false;
- /* Create a JSON lexing context. */
- lex = makeJsonLexContextCstringLen(buffer, size, PG_UTF8, true);
+ /* Initialize a JSON lexing context. */
+ initJsonLexContextCstringLen(&lex, buffer, size, PG_UTF8, true);
/* Set up semantic actions. */
sem.semstate = &parse;
@@ -145,14 +145,17 @@ json_parse_manifest(JsonManifestParseContext *context, char *buffer,
sem.scalar = json_manifest_scalar;
/* Run the actual JSON parser. */
- json_error = pg_parse_json(lex, &sem);
+ json_error = pg_parse_json(&lex, &sem);
if (json_error != JSON_SUCCESS)
- json_manifest_parse_failure(context, "parsing failed");
+ json_manifest_parse_failure(context, json_errdetail(json_error, &lex));
if (parse.state != JM_EXPECT_EOF)
json_manifest_parse_failure(context, "manifest ended unexpectedly");
/* Verify the manifest checksum. */
verify_manifest_checksum(&parse, buffer, size);
+
+ /* Clean up. */
+ termJsonLexContext(&lex);
}
/*
diff --git a/src/bin/pg_verifybackup/t/005_bad_manifest.pl b/src/bin/pg_verifybackup/t/005_bad_manifest.pl
index 118beb53d7..f2692972fe 100644
--- a/src/bin/pg_verifybackup/t/005_bad_manifest.pl
+++ b/src/bin/pg_verifybackup/t/005_bad_manifest.pl
@@ -16,7 +16,7 @@ my $tempdir = PostgreSQL::Test::Utils::tempdir;
test_bad_manifest(
'input string ended unexpectedly',
- qr/could not parse backup manifest: parsing failed/,
+ qr/could not parse backup manifest: The input string ended unexpectedly/,
<<EOM);
{
EOM
diff --git a/src/common/Makefile b/src/common/Makefile
index 31c0dd366d..8e8b27546e 100644
--- a/src/common/Makefile
+++ b/src/common/Makefile
@@ -40,7 +40,7 @@ override CPPFLAGS += -DVAL_LDFLAGS_EX="\"$(LDFLAGS_EX)\""
override CPPFLAGS += -DVAL_LDFLAGS_SL="\"$(LDFLAGS_SL)\""
override CPPFLAGS += -DVAL_LIBS="\"$(LIBS)\""
-override CPPFLAGS := -DFRONTEND -I. -I$(top_srcdir)/src/common $(CPPFLAGS)
+override CPPFLAGS := -DFRONTEND -I. -I$(top_srcdir)/src/common -I$(libpq_srcdir) $(CPPFLAGS)
LIBS += $(PTHREAD_LIBS)
# If you add objects here, see also src/tools/msvc/Mkvcbuild.pm
diff --git a/src/common/jsonapi.c b/src/common/jsonapi.c
index 6666077a93..7fc5eaf460 100644
--- a/src/common/jsonapi.c
+++ b/src/common/jsonapi.c
@@ -20,10 +20,39 @@
#include "common/jsonapi.h"
#include "mb/pg_wchar.h"
-#ifndef FRONTEND
+#ifdef FRONTEND
+#include "pqexpbuffer.h"
+#else
+#include "lib/stringinfo.h"
#include "miscadmin.h"
#endif
+/*
+ * In backend, we will use palloc/pfree along with StringInfo. In frontend, use
+ * malloc and PQExpBuffer, and return JSON_OUT_OF_MEMORY on out-of-memory.
+ */
+#ifdef FRONTEND
+
+#define STRDUP(s) strdup(s)
+#define ALLOC(size) malloc(size)
+
+#define appendStrVal appendPQExpBuffer
+#define appendStrValChar appendPQExpBufferChar
+#define createStrVal createPQExpBuffer
+#define resetStrVal resetPQExpBuffer
+
+#else /* !FRONTEND */
+
+#define STRDUP(s) pstrdup(s)
+#define ALLOC(size) palloc(size)
+
+#define appendStrVal appendStringInfo
+#define appendStrValChar appendStringInfoChar
+#define createStrVal makeStringInfo
+#define resetStrVal resetStringInfo
+
+#endif
+
/*
* The context of the parser is maintained by the recursive descent
* mechanism, but is passed explicitly to the error reporting routine
@@ -132,10 +161,12 @@ IsValidJsonNumber(const char *str, int len)
return (!numeric_error) && (total_len == dummy_lex.input_length);
}
+#ifndef FRONTEND
+
/*
* makeJsonLexContextCstringLen
*
- * lex constructor, with or without StringInfo object for de-escaped lexemes.
+ * lex constructor, with or without a string object for de-escaped lexemes.
*
* Without is better as it makes the processing faster, so only make one
* if really required.
@@ -145,13 +176,66 @@ makeJsonLexContextCstringLen(char *json, int len, int encoding, bool need_escape
{
JsonLexContext *lex = palloc0(sizeof(JsonLexContext));
+ initJsonLexContextCstringLen(lex, json, len, encoding, need_escapes);
+
+ return lex;
+}
+
+void
+destroyJsonLexContext(JsonLexContext *lex)
+{
+ termJsonLexContext(lex);
+ pfree(lex);
+}
+
+#endif /* !FRONTEND */
+
+void
+initJsonLexContextCstringLen(JsonLexContext *lex, char *json, int len, int encoding, bool need_escapes)
+{
lex->input = lex->token_terminator = lex->line_start = json;
lex->line_number = 1;
lex->input_length = len;
lex->input_encoding = encoding;
- if (need_escapes)
- lex->strval = makeStringInfo();
- return lex;
+ lex->parse_strval = need_escapes;
+ if (lex->parse_strval)
+ {
+ /*
+ * This call can fail in FRONTEND code. We defer error handling to time
+ * of use (json_lex_string()) since there's no way to signal failure
+ * here, and we might not need to parse any strings anyway.
+ */
+ lex->strval = createStrVal();
+ }
+ lex->errormsg = NULL;
+}
+
+void
+termJsonLexContext(JsonLexContext *lex)
+{
+ static const JsonLexContext empty = {0};
+
+ if (lex->strval)
+ {
+#ifdef FRONTEND
+ destroyPQExpBuffer(lex->strval);
+#else
+ pfree(lex->strval->data);
+ pfree(lex->strval);
+#endif
+ }
+
+ if (lex->errormsg)
+ {
+#ifdef FRONTEND
+ destroyPQExpBuffer(lex->errormsg);
+#else
+ pfree(lex->errormsg->data);
+ pfree(lex->errormsg);
+#endif
+ }
+
+ *lex = empty;
}
/*
@@ -217,7 +301,7 @@ json_count_array_elements(JsonLexContext *lex, int *elements)
* etc, so doing this with a copy makes that safe.
*/
	memcpy(&copylex, lex, sizeof(JsonLexContext));
- copylex.strval = NULL; /* not interested in values here */
+ copylex.parse_strval = false; /* not interested in values here */
copylex.lex_level++;
count = 0;
@@ -279,14 +363,21 @@ parse_scalar(JsonLexContext *lex, JsonSemAction *sem)
/* extract the de-escaped string value, or the raw lexeme */
if (lex_peek(lex) == JSON_TOKEN_STRING)
{
- if (lex->strval != NULL)
- val = pstrdup(lex->strval->data);
+ if (lex->parse_strval)
+ {
+ val = STRDUP(lex->strval->data);
+ if (val == NULL)
+ return JSON_OUT_OF_MEMORY;
+ }
}
else
{
int len = (lex->token_terminator - lex->token_start);
- val = palloc(len + 1);
+ val = ALLOC(len + 1);
+ if (val == NULL)
+ return JSON_OUT_OF_MEMORY;
+
memcpy(val, lex->token_start, len);
val[len] = '\0';
}
@@ -320,8 +411,12 @@ parse_object_field(JsonLexContext *lex, JsonSemAction *sem)
if (lex_peek(lex) != JSON_TOKEN_STRING)
return report_parse_error(JSON_PARSE_STRING, lex);
- if ((ostart != NULL || oend != NULL) && lex->strval != NULL)
- fname = pstrdup(lex->strval->data);
+ if ((ostart != NULL || oend != NULL) && lex->parse_strval)
+ {
+ fname = STRDUP(lex->strval->data);
+ if (fname == NULL)
+ return JSON_OUT_OF_MEMORY;
+ }
result = json_lex(lex);
if (result != JSON_SUCCESS)
return result;
@@ -368,6 +463,10 @@ parse_object(JsonLexContext *lex, JsonSemAction *sem)
JsonParseErrorType result;
#ifndef FRONTEND
+ /*
+ * TODO: clients need some way to put a bound on stack growth. Parse level
+ * limits maybe?
+ */
check_stack_depth();
#endif
@@ -676,8 +775,15 @@ json_lex_string(JsonLexContext *lex)
int len;
int hi_surrogate = -1;
- if (lex->strval != NULL)
- resetStringInfo(lex->strval);
+ if (lex->parse_strval)
+ {
+#ifdef FRONTEND
+ /* make sure initialization succeeded */
+ if (lex->strval == NULL)
+ return JSON_OUT_OF_MEMORY;
+#endif
+ resetStrVal(lex->strval);
+ }
Assert(lex->input_length > 0);
s = lex->token_start;
@@ -737,7 +843,7 @@ json_lex_string(JsonLexContext *lex)
return JSON_UNICODE_ESCAPE_FORMAT;
}
}
- if (lex->strval != NULL)
+ if (lex->parse_strval)
{
/*
* Combine surrogate pairs.
@@ -797,19 +903,19 @@ json_lex_string(JsonLexContext *lex)
unicode_to_utf8(ch, (unsigned char *) utf8str);
utf8len = pg_utf_mblen((unsigned char *) utf8str);
- appendBinaryStringInfo(lex->strval, utf8str, utf8len);
+ appendBinaryPQExpBuffer(lex->strval, utf8str, utf8len);
}
else if (ch <= 0x007f)
{
/* The ASCII range is the same in all encodings */
- appendStringInfoChar(lex->strval, (char) ch);
+ appendPQExpBufferChar(lex->strval, (char) ch);
}
else
return JSON_UNICODE_HIGH_ESCAPE;
#endif /* FRONTEND */
}
}
- else if (lex->strval != NULL)
+ else if (lex->parse_strval)
{
if (hi_surrogate != -1)
return JSON_UNICODE_LOW_SURROGATE;
@@ -819,22 +925,22 @@ json_lex_string(JsonLexContext *lex)
case '"':
case '\\':
case '/':
- appendStringInfoChar(lex->strval, *s);
+ appendStrValChar(lex->strval, *s);
break;
case 'b':
- appendStringInfoChar(lex->strval, '\b');
+ appendStrValChar(lex->strval, '\b');
break;
case 'f':
- appendStringInfoChar(lex->strval, '\f');
+ appendStrValChar(lex->strval, '\f');
break;
case 'n':
- appendStringInfoChar(lex->strval, '\n');
+ appendStrValChar(lex->strval, '\n');
break;
case 'r':
- appendStringInfoChar(lex->strval, '\r');
+ appendStrValChar(lex->strval, '\r');
break;
case 't':
- appendStringInfoChar(lex->strval, '\t');
+ appendStrValChar(lex->strval, '\t');
break;
default:
/* Not a valid string escape, so signal error. */
@@ -858,12 +964,12 @@ json_lex_string(JsonLexContext *lex)
}
}
- else if (lex->strval != NULL)
+ else if (lex->parse_strval)
{
if (hi_surrogate != -1)
return JSON_UNICODE_LOW_SURROGATE;
- appendStringInfoChar(lex->strval, *s);
+ appendStrValChar(lex->strval, *s);
}
}
@@ -871,6 +977,11 @@ json_lex_string(JsonLexContext *lex)
if (hi_surrogate != -1)
return JSON_UNICODE_LOW_SURROGATE;
+#ifdef FRONTEND
+ if (lex->parse_strval && PQExpBufferBroken(lex->strval))
+ return JSON_OUT_OF_MEMORY;
+#endif
+
/* Hooray, we found the end of the string! */
lex->prev_token_terminator = lex->token_terminator;
lex->token_terminator = s + 1;
@@ -1043,72 +1154,93 @@ report_parse_error(JsonParseContext ctx, JsonLexContext *lex)
return JSON_SUCCESS; /* silence stupider compilers */
}
-
-#ifndef FRONTEND
-/*
- * Extract the current token from a lexing context, for error reporting.
- */
-static char *
-extract_token(JsonLexContext *lex)
-{
- int toklen = lex->token_terminator - lex->token_start;
- char *token = palloc(toklen + 1);
-
- memcpy(token, lex->token_start, toklen);
- token[toklen] = '\0';
- return token;
-}
-
/*
* Construct a detail message for a JSON error.
*
- * Note that the error message generated by this routine may not be
- * palloc'd, making it unsafe for frontend code as there is no way to
- * know if this can be safery pfree'd or not.
+ * The returned allocation is either static or owned by the JsonLexContext and
+ * should not be freed.
*/
char *
json_errdetail(JsonParseErrorType error, JsonLexContext *lex)
{
+ int toklen = lex->token_terminator - lex->token_start;
+
+ if (error == JSON_OUT_OF_MEMORY)
+ {
+ /* Short circuit. Allocating anything for this case is unhelpful. */
+ return _("out of memory");
+ }
+
+ if (lex->errormsg)
+ resetStrVal(lex->errormsg);
+ else
+ lex->errormsg = createStrVal();
+
switch (error)
{
case JSON_SUCCESS:
/* fall through to the error code after switch */
break;
case JSON_ESCAPING_INVALID:
- return psprintf(_("Escape sequence \"\\%s\" is invalid."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Escape sequence \"\\%.*s\" is invalid."),
+ toklen, lex->token_start);
+ break;
case JSON_ESCAPING_REQUIRED:
- return psprintf(_("Character with value 0x%02x must be escaped."),
- (unsigned char) *(lex->token_terminator));
+ appendStrVal(lex->errormsg,
+ _("Character with value 0x%02x must be escaped."),
+ (unsigned char) *(lex->token_terminator));
+ break;
case JSON_EXPECTED_END:
- return psprintf(_("Expected end of input, but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected end of input, but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_ARRAY_FIRST:
- return psprintf(_("Expected array element or \"]\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected array element or \"]\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_ARRAY_NEXT:
- return psprintf(_("Expected \",\" or \"]\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected \",\" or \"]\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_COLON:
- return psprintf(_("Expected \":\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected \":\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_JSON:
- return psprintf(_("Expected JSON value, but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected JSON value, but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_MORE:
return _("The input string ended unexpectedly.");
case JSON_EXPECTED_OBJECT_FIRST:
- return psprintf(_("Expected string or \"}\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected string or \"}\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_OBJECT_NEXT:
- return psprintf(_("Expected \",\" or \"}\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected \",\" or \"}\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_STRING:
- return psprintf(_("Expected string, but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected string, but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_INVALID_TOKEN:
- return psprintf(_("Token \"%s\" is invalid."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Token \"%.*s\" is invalid."),
+ toklen, lex->token_start);
+ break;
+ case JSON_OUT_OF_MEMORY:
+ /* should have been handled above; use the error path */
+ break;
case JSON_UNICODE_CODE_POINT_ZERO:
return _("\\u0000 cannot be converted to text.");
case JSON_UNICODE_ESCAPE_FORMAT:
@@ -1122,12 +1254,22 @@ json_errdetail(JsonParseErrorType error, JsonLexContext *lex)
return _("Unicode low surrogate must follow a high surrogate.");
}
- /*
- * We don't use a default: case, so that the compiler will warn about
- * unhandled enum values. But this needs to be here anyway to cover the
- * possibility of an incorrect input.
- */
- elog(ERROR, "unexpected json parse error type: %d", (int) error);
- return NULL;
-}
+ /* Note that lex->errormsg can be NULL in FRONTEND code. */
+ if (lex->errormsg && !lex->errormsg->data[0])
+ {
+ /*
+ * We don't use a default: case, so that the compiler will warn about
+ * unhandled enum values. But this needs to be here anyway to cover the
+ * possibility of an incorrect input.
+ */
+ appendStrVal(lex->errormsg,
+ "unexpected json parse error type: %d", (int) error);
+ }
+
+#ifdef FRONTEND
+ if (PQExpBufferBroken(lex->errormsg))
+ return _("out of memory while constructing error description");
#endif
+
+ return lex->errormsg->data;
+}
diff --git a/src/include/common/jsonapi.h b/src/include/common/jsonapi.h
index 52cb4a9339..d7cafc84fe 100644
--- a/src/include/common/jsonapi.h
+++ b/src/include/common/jsonapi.h
@@ -14,8 +14,6 @@
#ifndef JSONAPI_H
#define JSONAPI_H
-#include "lib/stringinfo.h"
-
typedef enum
{
JSON_TOKEN_INVALID,
@@ -48,6 +46,7 @@ typedef enum
JSON_EXPECTED_OBJECT_NEXT,
JSON_EXPECTED_STRING,
JSON_INVALID_TOKEN,
+ JSON_OUT_OF_MEMORY,
JSON_UNICODE_CODE_POINT_ZERO,
JSON_UNICODE_ESCAPE_FORMAT,
JSON_UNICODE_HIGH_ESCAPE,
@@ -55,6 +54,17 @@ typedef enum
JSON_UNICODE_LOW_SURROGATE
} JsonParseErrorType;
+/*
+ * Don't depend on the internal type header for strval; if callers need access
+ * then they can include the appropriate header themselves.
+ */
+#ifdef FRONTEND
+#define StrValType PQExpBufferData
+#else
+#define StrValType StringInfoData
+#endif
+
+typedef struct StrValType StrValType;
/*
* All the fields in this structure should be treated as read-only.
@@ -81,7 +91,9 @@ typedef struct JsonLexContext
int lex_level;
int line_number; /* line number, starting from 1 */
char *line_start; /* where that line starts within input */
- StringInfo strval;
+ bool parse_strval;
+ StrValType *strval; /* only used if parse_strval == true */
+ StrValType *errormsg;
} JsonLexContext;
typedef void (*json_struct_action) (void *state);
@@ -141,9 +153,10 @@ extern JsonSemAction nullSemAction;
*/
extern JsonParseErrorType json_count_array_elements(JsonLexContext *lex,
int *elements);
+#ifndef FRONTEND
/*
- * constructor for JsonLexContext, with or without strval element.
+ * allocating constructor for JsonLexContext, with or without strval element.
* If supplied, the strval element will contain a de-escaped version of
* the lexeme. However, doing this imposes a performance penalty, so
* it should be avoided if the de-escaped lexeme is not required.
@@ -153,6 +166,32 @@ extern JsonLexContext *makeJsonLexContextCstringLen(char *json,
int encoding,
bool need_escapes);
+/*
+ * Counterpart to makeJsonLexContextCstringLen(): clears and deallocates lex.
+ * The context pointer should not be used after this call.
+ */
+extern void destroyJsonLexContext(JsonLexContext *lex);
+
+#endif /* !FRONTEND */
+
+/*
+ * stack constructor for JsonLexContext, with or without strval element.
+ * If supplied, the strval element will contain a de-escaped version of
+ * the lexeme. However, doing this imposes a performance penalty, so
+ * it should be avoided if the de-escaped lexeme is not required.
+ */
+extern void initJsonLexContextCstringLen(JsonLexContext *lex,
+ char *json,
+ int len,
+ int encoding,
+ bool need_escapes);
+
+/*
+ * Counterpart to initJsonLexContextCstringLen(): clears the contents of lex,
+ * but does not deallocate lex itself.
+ */
+extern void termJsonLexContext(JsonLexContext *lex);
+
/* lex one token */
extern JsonParseErrorType json_lex(JsonLexContext *lex);
--
2.25.1
Attachment: v3-0005-libpq-add-OAUTHBEARER-SASL-mechanism.patch
From 7f6b02652cd771a93ce4269607d498f4ac574e7f Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 13 Apr 2021 10:27:27 -0700
Subject: [PATCH v3 5/9] libpq: add OAUTHBEARER SASL mechanism
DO NOT USE THIS PROOF OF CONCEPT IN PRODUCTION.
Implement OAUTHBEARER (RFC 7628) and OAuth 2.0 Device Authorization
Grants (RFC 8628) on the client side. When speaking to an OAuth-enabled
server, it looks a bit like this:
$ psql 'host=example.org oauth_client_id=f02c6361-0635-...'
Visit https://oauth.example.org/login and enter the code: FPQ2-M4BG
The OAuth issuer must support device authorization. No other OAuth flows
are currently implemented.
The client implementation requires libiddawc and its development
headers. Configure --with-oauth (and --with-includes/--with-libraries to
point at the iddawc installation, if it's in a custom location).
Several TODOs:
- don't retry forever if the server won't accept our token
- perform several sanity checks on the OAuth issuer's responses
- handle cases where the client has been set up with an issuer and
scope, but the Postgres server wants to use something different
- improve error debuggability during the OAuth handshake
- ...and more.
---
configure | 100 ++++
configure.ac | 19 +
src/Makefile.global.in | 1 +
src/include/common/oauth-common.h | 19 +
src/include/pg_config.h.in | 6 +
src/interfaces/libpq/Makefile | 7 +-
src/interfaces/libpq/fe-auth-oauth.c | 744 +++++++++++++++++++++++++++
src/interfaces/libpq/fe-auth-sasl.h | 5 +-
src/interfaces/libpq/fe-auth-scram.c | 6 +-
src/interfaces/libpq/fe-auth.c | 42 +-
src/interfaces/libpq/fe-auth.h | 3 +
src/interfaces/libpq/fe-connect.c | 38 ++
src/interfaces/libpq/libpq-int.h | 8 +
13 files changed, 979 insertions(+), 19 deletions(-)
create mode 100644 src/include/common/oauth-common.h
create mode 100644 src/interfaces/libpq/fe-auth-oauth.c
diff --git a/configure b/configure
index f3cb5c2b51..cd0c50a951 100755
--- a/configure
+++ b/configure
@@ -718,6 +718,7 @@ with_uuid
with_readline
with_systemd
with_selinux
+with_oauth
with_ldap
with_krb_srvnam
krb_srvtab
@@ -861,6 +862,7 @@ with_krb_srvnam
with_pam
with_bsd_auth
with_ldap
+with_oauth
with_bonjour
with_selinux
with_systemd
@@ -1570,6 +1572,7 @@ Optional Packages:
--with-pam build with PAM support
--with-bsd-auth build with BSD Authentication support
--with-ldap build with LDAP support
+ --with-oauth build with OAuth 2.0 support
--with-bonjour build with Bonjour support
--with-selinux build with SELinux support
--with-systemd build with systemd support
@@ -8377,6 +8380,42 @@ $as_echo "$with_ldap" >&6; }
+#
+# OAuth 2.0
+#
+{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether to build with OAuth support" >&5
+$as_echo_n "checking whether to build with OAuth support... " >&6; }
+
+
+
+# Check whether --with-oauth was given.
+if test "${with_oauth+set}" = set; then :
+ withval=$with_oauth;
+ case $withval in
+ yes)
+
+$as_echo "#define USE_OAUTH 1" >>confdefs.h
+
+ ;;
+ no)
+ :
+ ;;
+ *)
+ as_fn_error $? "no argument expected for --with-oauth option" "$LINENO" 5
+ ;;
+ esac
+
+else
+ with_oauth=no
+
+fi
+
+
+{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $with_oauth" >&5
+$as_echo "$with_oauth" >&6; }
+
+
+
#
# Bonjour
#
@@ -13500,6 +13539,56 @@ fi
+if test "$with_oauth" = yes ; then
+ { $as_echo "$as_me:${as_lineno-$LINENO}: checking for i_init_session in -liddawc" >&5
+$as_echo_n "checking for i_init_session in -liddawc... " >&6; }
+if ${ac_cv_lib_iddawc_i_init_session+:} false; then :
+ $as_echo_n "(cached) " >&6
+else
+ ac_check_lib_save_LIBS=$LIBS
+LIBS="-liddawc $LIBS"
+cat confdefs.h - <<_ACEOF >conftest.$ac_ext
+/* end confdefs.h. */
+
+/* Override any GCC internal prototype to avoid an error.
+ Use char because int might match the return type of a GCC
+ builtin and then its argument prototype would still apply. */
+#ifdef __cplusplus
+extern "C"
+#endif
+char i_init_session ();
+int
+main ()
+{
+return i_init_session ();
+ ;
+ return 0;
+}
+_ACEOF
+if ac_fn_c_try_link "$LINENO"; then :
+ ac_cv_lib_iddawc_i_init_session=yes
+else
+ ac_cv_lib_iddawc_i_init_session=no
+fi
+rm -f core conftest.err conftest.$ac_objext \
+ conftest$ac_exeext conftest.$ac_ext
+LIBS=$ac_check_lib_save_LIBS
+fi
+{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_iddawc_i_init_session" >&5
+$as_echo "$ac_cv_lib_iddawc_i_init_session" >&6; }
+if test "x$ac_cv_lib_iddawc_i_init_session" = xyes; then :
+ cat >>confdefs.h <<_ACEOF
+#define HAVE_LIBIDDAWC 1
+_ACEOF
+
+ LIBS="-liddawc $LIBS"
+
+else
+ as_fn_error $? "library 'iddawc' is required for OAuth support" "$LINENO" 5
+fi
+
+fi
+
# for contrib/sepgsql
if test "$with_selinux" = yes; then
{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for security_compute_create_name in -lselinux" >&5
@@ -14513,6 +14602,17 @@ fi
done
+fi
+
+if test "$with_oauth" != no; then
+ ac_fn_c_check_header_mongrel "$LINENO" "iddawc.h" "ac_cv_header_iddawc_h" "$ac_includes_default"
+if test "x$ac_cv_header_iddawc_h" = xyes; then :
+
+else
+ as_fn_error $? "header file <iddawc.h> is required for OAuth" "$LINENO" 5
+fi
+
+
fi
if test "$PORTNAME" = "win32" ; then
diff --git a/configure.ac b/configure.ac
index 19d1a80367..922608065f 100644
--- a/configure.ac
+++ b/configure.ac
@@ -887,6 +887,17 @@ AC_MSG_RESULT([$with_ldap])
AC_SUBST(with_ldap)
+#
+# OAuth 2.0
+#
+AC_MSG_CHECKING([whether to build with OAuth support])
+PGAC_ARG_BOOL(with, oauth, no,
+ [build with OAuth 2.0 support],
+ [AC_DEFINE([USE_OAUTH], 1, [Define to 1 to build with OAuth 2.0 support. (--with-oauth)])])
+AC_MSG_RESULT([$with_oauth])
+AC_SUBST(with_oauth)
+
+
#
# Bonjour
#
@@ -1385,6 +1396,10 @@ fi
AC_SUBST(LDAP_LIBS_FE)
AC_SUBST(LDAP_LIBS_BE)
+if test "$with_oauth" = yes ; then
+ AC_CHECK_LIB(iddawc, i_init_session, [], [AC_MSG_ERROR([library 'iddawc' is required for OAuth support])])
+fi
+
# for contrib/sepgsql
if test "$with_selinux" = yes; then
AC_CHECK_LIB(selinux, security_compute_create_name, [],
@@ -1603,6 +1618,10 @@ elif test "$with_uuid" = ossp ; then
[AC_MSG_ERROR([header file <ossp/uuid.h> or <uuid.h> is required for OSSP UUID])])])
fi
+if test "$with_oauth" != no; then
+ AC_CHECK_HEADER(iddawc.h, [], [AC_MSG_ERROR([header file <iddawc.h> is required for OAuth])])
+fi
+
if test "$PORTNAME" = "win32" ; then
AC_CHECK_HEADERS(crtdefs.h)
fi
diff --git a/src/Makefile.global.in b/src/Makefile.global.in
index bbdc1c4bda..c9c61a9c99 100644
--- a/src/Makefile.global.in
+++ b/src/Makefile.global.in
@@ -193,6 +193,7 @@ with_ldap = @with_ldap@
with_libxml = @with_libxml@
with_libxslt = @with_libxslt@
with_llvm = @with_llvm@
+with_oauth = @with_oauth@
with_system_tzdata = @with_system_tzdata@
with_uuid = @with_uuid@
with_zlib = @with_zlib@
diff --git a/src/include/common/oauth-common.h b/src/include/common/oauth-common.h
new file mode 100644
index 0000000000..3fa95ac7e8
--- /dev/null
+++ b/src/include/common/oauth-common.h
@@ -0,0 +1,19 @@
+/*-------------------------------------------------------------------------
+ *
+ * oauth-common.h
+ * Declarations for helper functions used for OAuth/OIDC authentication
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/common/oauth-common.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef OAUTH_COMMON_H
+#define OAUTH_COMMON_H
+
+/* Name of SASL mechanism per IANA */
+#define OAUTHBEARER_NAME "OAUTHBEARER"
+
+#endif /* OAUTH_COMMON_H */
diff --git a/src/include/pg_config.h.in b/src/include/pg_config.h.in
index 635fbb2181..1b3332601e 100644
--- a/src/include/pg_config.h.in
+++ b/src/include/pg_config.h.in
@@ -319,6 +319,9 @@
/* Define to 1 if you have the `crypto' library (-lcrypto). */
#undef HAVE_LIBCRYPTO
+/* Define to 1 if you have the `iddawc' library (-liddawc). */
+#undef HAVE_LIBIDDAWC
+
/* Define to 1 if you have the `ldap' library (-lldap). */
#undef HAVE_LIBLDAP
@@ -922,6 +925,9 @@
/* Define to select named POSIX semaphores. */
#undef USE_NAMED_POSIX_SEMAPHORES
+/* Define to 1 to build with OAuth 2.0 support. (--with-oauth) */
+#undef USE_OAUTH
+
/* Define to 1 to build with OpenSSL support. (--with-ssl=openssl) */
#undef USE_OPENSSL
diff --git a/src/interfaces/libpq/Makefile b/src/interfaces/libpq/Makefile
index 3c53393fa4..727305c578 100644
--- a/src/interfaces/libpq/Makefile
+++ b/src/interfaces/libpq/Makefile
@@ -62,6 +62,11 @@ OBJS += \
fe-secure-gssapi.o
endif
+ifeq ($(with_oauth),yes)
+OBJS += \
+ fe-auth-oauth.o
+endif
+
ifeq ($(PORTNAME), cygwin)
override shlib = cyg$(NAME)$(DLSUFFIX)
endif
@@ -83,7 +88,7 @@ endif
# that are built correctly for use in a shlib.
SHLIB_LINK_INTERNAL = -lpgcommon_shlib -lpgport_shlib
ifneq ($(PORTNAME), win32)
-SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi_krb5 -lgss -lgssapi -lssl -lsocket -lnsl -lresolv -lintl -lm, $(LIBS)) $(LDAP_LIBS_FE) $(PTHREAD_LIBS)
+SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi_krb5 -lgss -lgssapi -lssl -liddawc -lsocket -lnsl -lresolv -lintl -lm, $(LIBS)) $(LDAP_LIBS_FE) $(PTHREAD_LIBS)
else
SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi32 -lssl -lsocket -lnsl -lresolv -lintl -lm $(PTHREAD_LIBS), $(LIBS)) $(LDAP_LIBS_FE)
endif
diff --git a/src/interfaces/libpq/fe-auth-oauth.c b/src/interfaces/libpq/fe-auth-oauth.c
new file mode 100644
index 0000000000..383c9d4bdb
--- /dev/null
+++ b/src/interfaces/libpq/fe-auth-oauth.c
@@ -0,0 +1,744 @@
+/*-------------------------------------------------------------------------
+ *
+ * fe-auth-oauth.c
+ * The front-end (client) implementation of OAuth/OIDC authentication.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * IDENTIFICATION
+ * src/interfaces/libpq/fe-auth-oauth.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include <iddawc.h>
+
+#include "postgres_fe.h"
+
+#include "common/base64.h"
+#include "common/hmac.h"
+#include "common/jsonapi.h"
+#include "common/oauth-common.h"
+#include "fe-auth.h"
+#include "mb/pg_wchar.h"
+
+/* The exported OAuth callback mechanism. */
+static void *oauth_init(PGconn *conn, const char *password,
+ const char *sasl_mechanism);
+static void oauth_exchange(void *opaq, bool final,
+ char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success);
+static bool oauth_channel_bound(void *opaq);
+static void oauth_free(void *opaq);
+
+const pg_fe_sasl_mech pg_oauth_mech = {
+ oauth_init,
+ oauth_exchange,
+ oauth_channel_bound,
+ oauth_free,
+};
+
+typedef enum
+{
+ FE_OAUTH_INIT,
+ FE_OAUTH_BEARER_SENT,
+ FE_OAUTH_SERVER_ERROR,
+} fe_oauth_state_enum;
+
+typedef struct
+{
+ fe_oauth_state_enum state;
+
+ PGconn *conn;
+} fe_oauth_state;
+
+static void *
+oauth_init(PGconn *conn, const char *password,
+ const char *sasl_mechanism)
+{
+ fe_oauth_state *state;
+
+ /*
+ * We only support one SASL mechanism here; anything else is programmer
+ * error.
+ */
+ Assert(sasl_mechanism != NULL);
+ Assert(!strcmp(sasl_mechanism, OAUTHBEARER_NAME));
+
+ state = malloc(sizeof(*state));
+ if (!state)
+ return NULL;
+
+ state->state = FE_OAUTH_INIT;
+ state->conn = conn;
+
+ return state;
+}
+
+static const char *
+iddawc_error_string(int errcode)
+{
+ switch (errcode)
+ {
+ case I_OK:
+ return "I_OK";
+
+ case I_ERROR:
+ return "I_ERROR";
+
+ case I_ERROR_PARAM:
+ return "I_ERROR_PARAM";
+
+ case I_ERROR_MEMORY:
+ return "I_ERROR_MEMORY";
+
+ case I_ERROR_UNAUTHORIZED:
+ return "I_ERROR_UNAUTHORIZED";
+
+ case I_ERROR_SERVER:
+ return "I_ERROR_SERVER";
+ }
+
+ return "<unknown>";
+}
+
+static void
+iddawc_error(PGconn *conn, int errcode, const char *msg)
+{
+ appendPQExpBufferStr(&conn->errorMessage, libpq_gettext(msg));
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext(" (iddawc error %s)\n"),
+ iddawc_error_string(errcode));
+}
+
+static void
+iddawc_request_error(PGconn *conn, struct _i_session *i, int err, const char *msg)
+{
+ const char *error_code;
+ const char *desc;
+
+ appendPQExpBuffer(&conn->errorMessage, "%s: ", libpq_gettext(msg));
+
+ error_code = i_get_str_parameter(i, I_OPT_ERROR);
+ if (!error_code)
+ {
+ /*
+ * The server didn't give us any useful information, so just print the
+ * error code.
+ */
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("(iddawc error %s)\n"),
+ iddawc_error_string(err));
+ return;
+ }
+
+ /* If the server gave a string description, print that too. */
+ desc = i_get_str_parameter(i, I_OPT_ERROR_DESCRIPTION);
+ if (desc)
+ appendPQExpBuffer(&conn->errorMessage, "%s ", desc);
+
+ appendPQExpBuffer(&conn->errorMessage, "(%s)\n", error_code);
+}
+
+static char *
+get_auth_token(PGconn *conn)
+{
+ PQExpBuffer token_buf = NULL;
+ struct _i_session session;
+ int err;
+ int auth_method;
+ bool user_prompted = false;
+ const char *verification_uri;
+ const char *user_code;
+ const char *access_token;
+ const char *token_type;
+ char *token = NULL;
+
+ if (!conn->oauth_discovery_uri)
+ return strdup(""); /* ask the server for one */
+
+ if (!conn->oauth_client_id)
+ {
+ /* We can't talk to a server without a client identifier. */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("no oauth_client_id is set for the connection"));
+ return NULL;
+ }
+
+ i_init_session(&session);
+
+ token_buf = createPQExpBuffer();
+ if (!token_buf)
+ goto cleanup;
+
+ err = i_set_str_parameter(&session, I_OPT_OPENID_CONFIG_ENDPOINT, conn->oauth_discovery_uri);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set OpenID config endpoint");
+ goto cleanup;
+ }
+
+ err = i_get_openid_config(&session);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to fetch OpenID discovery document");
+ goto cleanup;
+ }
+
+ if (!i_get_str_parameter(&session, I_OPT_TOKEN_ENDPOINT))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer has no token endpoint"));
+ goto cleanup;
+ }
+
+ if (!i_get_str_parameter(&session, I_OPT_DEVICE_AUTHORIZATION_ENDPOINT))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer does not support device authorization"));
+ goto cleanup;
+ }
+
+ err = i_set_response_type(&session, I_RESPONSE_TYPE_DEVICE_CODE);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set device code response type");
+ goto cleanup;
+ }
+
+ auth_method = I_TOKEN_AUTH_METHOD_NONE;
+ if (conn->oauth_client_secret && *conn->oauth_client_secret)
+ auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
+
+ err = i_set_parameter_list(&session,
+ I_OPT_CLIENT_ID, conn->oauth_client_id,
+ I_OPT_CLIENT_SECRET, conn->oauth_client_secret,
+ I_OPT_TOKEN_METHOD, auth_method,
+ I_OPT_SCOPE, conn->oauth_scope,
+ I_OPT_NONE
+ );
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set client identifier");
+ goto cleanup;
+ }
+
+ err = i_run_device_auth_request(&session);
+ if (err)
+ {
+ iddawc_request_error(conn, &session, err,
+ "failed to obtain device authorization");
+ goto cleanup;
+ }
+
+ verification_uri = i_get_str_parameter(&session, I_OPT_DEVICE_AUTH_VERIFICATION_URI);
+ if (!verification_uri)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a verification URI"));
+ goto cleanup;
+ }
+
+ user_code = i_get_str_parameter(&session, I_OPT_DEVICE_AUTH_USER_CODE);
+ if (!user_code)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a user code"));
+ goto cleanup;
+ }
+
+ /*
+ * Poll the token endpoint until either the user logs in and authorizes the
+ * use of a token, or a hard failure occurs. We perform one ping _before_
+ * prompting the user, so that we don't make them do the work of logging in
+ * only to find that the token endpoint is completely unreachable.
+ */
+ err = i_run_token_request(&session);
+ while (err)
+ {
+ const char *error_code;
+		unsigned int interval;
+
+ error_code = i_get_str_parameter(&session, I_OPT_ERROR);
+
+ /*
+ * authorization_pending and slow_down are the only acceptable errors;
+ * anything else and we bail.
+ */
+ if (!error_code || (strcmp(error_code, "authorization_pending")
+ && strcmp(error_code, "slow_down")))
+ {
+ iddawc_request_error(conn, &session, err,
+ "OAuth token retrieval failed");
+ goto cleanup;
+ }
+
+ if (!user_prompted)
+ {
+ /*
+ * Now that we know the token endpoint isn't broken, give the user
+ * the login instructions.
+ */
+ pqInternalNotice(&conn->noticeHooks,
+ "Visit %s and enter the code: %s",
+ verification_uri, user_code);
+
+ user_prompted = true;
+ }
+
+ /*
+ * We are required to wait between polls; the server tells us how long.
+ * TODO: if interval's not set, we need to default to five seconds
+ * TODO: sanity check the interval
+ */
+ interval = i_get_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL);
+
+ /*
+ * A slow_down error requires us to permanently increase our retry
+ * interval by five seconds. RFC 8628, Sec. 3.5.
+ */
+ if (!strcmp(error_code, "slow_down"))
+ {
+ interval += 5;
+ i_set_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL, interval);
+ }
+
+ sleep(interval);
+
+ /*
+ * XXX Reset the error code before every call, because iddawc won't do
+ * that for us. This matters if the server first sends a "pending" error
+ * code, then later hard-fails without sending an error code to
+ * overwrite the first one.
+ *
+ * That we have to do this at all seems like a bug in iddawc.
+ */
+ i_set_str_parameter(&session, I_OPT_ERROR, NULL);
+
+ err = i_run_token_request(&session);
+ }
+
+ access_token = i_get_str_parameter(&session, I_OPT_ACCESS_TOKEN);
+ token_type = i_get_str_parameter(&session, I_OPT_TOKEN_TYPE);
+
+ if (!access_token || !token_type || strcasecmp(token_type, "Bearer"))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a bearer token"));
+ goto cleanup;
+ }
+
+ appendPQExpBufferStr(token_buf, "Bearer ");
+ appendPQExpBufferStr(token_buf, access_token);
+
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+
+ token = strdup(token_buf->data);
+
+cleanup:
+ if (token_buf)
+ destroyPQExpBuffer(token_buf);
+ i_clean_session(&session);
+
+ return token;
+}
+
+#define kvsep "\x01"
+
+static char *
+client_initial_response(PGconn *conn)
+{
+ static const char * const resp_format = "n,," kvsep "auth=%s" kvsep kvsep;
+
+ PQExpBuffer token_buf;
+ PQExpBuffer discovery_buf = NULL;
+ char *token = NULL;
+ char *response = NULL;
+
+ token_buf = createPQExpBuffer();
+ if (!token_buf)
+ goto cleanup;
+
+ /*
+ * If we don't yet have a discovery URI, but the user gave us an explicit
+ * issuer, use the .well-known discovery URI for that issuer.
+ */
+ if (!conn->oauth_discovery_uri && conn->oauth_issuer)
+ {
+ discovery_buf = createPQExpBuffer();
+ if (!discovery_buf)
+ goto cleanup;
+
+ appendPQExpBufferStr(discovery_buf, conn->oauth_issuer);
+ appendPQExpBufferStr(discovery_buf, "/.well-known/openid-configuration");
+
+ if (PQExpBufferBroken(discovery_buf))
+ goto cleanup;
+
+ conn->oauth_discovery_uri = strdup(discovery_buf->data);
+ }
+
+ token = get_auth_token(conn);
+ if (!token)
+ goto cleanup;
+
+ appendPQExpBuffer(token_buf, resp_format, token);
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+
+ response = strdup(token_buf->data);
+
+cleanup:
+ if (token)
+ free(token);
+ if (discovery_buf)
+ destroyPQExpBuffer(discovery_buf);
+ if (token_buf)
+ destroyPQExpBuffer(token_buf);
+
+ return response;
+}
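As a standalone illustration of the two strings assembled above — the issuer's .well-known discovery URI and the OAUTHBEARER client-first message (RFC 7628, Sec. 3.1) — here is a minimal sketch. The helper names are hypothetical and not part of the patch:

```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define KVSEP "\x01"

/* Hypothetical helper: derive the OIDC discovery URI from an issuer URL,
 * as client_initial_response() does. Returns a malloc'd string or NULL. */
static char *
discovery_uri_for(const char *issuer)
{
	static const char *const suffix = "/.well-known/openid-configuration";
	size_t		len = strlen(issuer) + strlen(suffix) + 1;
	char	   *uri = malloc(len);

	if (uri)
		snprintf(uri, len, "%s%s", issuer, suffix);
	return uri;
}

/* Hypothetical helper: build the client-first message
 * "n,," kvsep "auth=Bearer <token>" kvsep kvsep (RFC 7628, Sec. 3.1). */
static char *
build_client_first(const char *access_token)
{
	static const char *const fmt = "n,," KVSEP "auth=Bearer %s" KVSEP KVSEP;
	size_t		len = strlen(fmt) + strlen(access_token) + 1;
	char	   *resp = malloc(len);

	if (resp)
		snprintf(resp, len, fmt, access_token);
	return resp;
}
```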
+
+#define ERROR_STATUS_FIELD "status"
+#define ERROR_SCOPE_FIELD "scope"
+#define ERROR_OPENID_CONFIGURATION_FIELD "openid-configuration"
+
+struct json_ctx
+{
+ char *errmsg; /* any non-NULL value stops all processing */
+ PQExpBufferData errbuf; /* backing memory for errmsg */
+ int nested; /* nesting level (zero is the top) */
+
+ const char *target_field_name; /* points to a static allocation */
+ char **target_field; /* see below */
+
+ /* target_field, if set, points to one of the following: */
+ char *status;
+ char *scope;
+ char *discovery_uri;
+};
+
+#define oauth_json_has_error(ctx) \
+ (PQExpBufferDataBroken((ctx)->errbuf) || (ctx)->errmsg)
+
+#define oauth_json_set_error(ctx, ...) \
+ do { \
+ appendPQExpBuffer(&(ctx)->errbuf, __VA_ARGS__); \
+ (ctx)->errmsg = (ctx)->errbuf.data; \
+ } while (0)
+
+static void
+oauth_json_object_start(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ return; /* short-circuit */
+
+ if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ oauth_json_set_error(ctx,
+ libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+
+ ++ctx->nested;
+}
+
+static void
+oauth_json_object_end(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ return; /* short-circuit */
+
+ --ctx->nested;
+}
+
+static void
+oauth_json_object_field_start(void *state, char *name, bool isnull)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ {
+ /* short-circuit */
+ free(name);
+ return;
+ }
+
+ if (ctx->nested == 1)
+ {
+ if (!strcmp(name, ERROR_STATUS_FIELD))
+ {
+ ctx->target_field_name = ERROR_STATUS_FIELD;
+ ctx->target_field = &ctx->status;
+ }
+ else if (!strcmp(name, ERROR_SCOPE_FIELD))
+ {
+ ctx->target_field_name = ERROR_SCOPE_FIELD;
+ ctx->target_field = &ctx->scope;
+ }
+ else if (!strcmp(name, ERROR_OPENID_CONFIGURATION_FIELD))
+ {
+ ctx->target_field_name = ERROR_OPENID_CONFIGURATION_FIELD;
+ ctx->target_field = &ctx->discovery_uri;
+ }
+ }
+
+ free(name);
+}
+
+static void
+oauth_json_array_start(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ return; /* short-circuit */
+
+ if (!ctx->nested)
+ {
+ ctx->errmsg = libpq_gettext("top-level element must be an object");
+ }
+ else if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ oauth_json_set_error(ctx,
+ libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+}
+
+static void
+oauth_json_scalar(void *state, char *token, JsonTokenType type)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ {
+ /* short-circuit */
+ free(token);
+ return;
+ }
+
+ if (!ctx->nested)
+ {
+ ctx->errmsg = libpq_gettext("top-level element must be an object");
+ }
+ else if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ if (type == JSON_TOKEN_STRING)
+ {
+ *ctx->target_field = token;
+
+ ctx->target_field = NULL;
+ ctx->target_field_name = NULL;
+
+ return; /* don't free the token we're using */
+ }
+
+ oauth_json_set_error(ctx,
+ libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+
+ free(token);
+}
+
+static bool
+handle_oauth_sasl_error(PGconn *conn, char *msg, int msglen)
+{
+ JsonLexContext lex = {0};
+ JsonSemAction sem = {0};
+ JsonParseErrorType err;
+ struct json_ctx ctx = {0};
+ char *errmsg = NULL;
+
+ /* Sanity check. */
+ if (strlen(msg) != msglen)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server's error message contained an embedded NULL\n"));
+ return false;
+ }
+
+ initJsonLexContextCstringLen(&lex, msg, msglen, PG_UTF8, true);
+
+ initPQExpBuffer(&ctx.errbuf);
+ sem.semstate = &ctx;
+
+ sem.object_start = oauth_json_object_start;
+ sem.object_end = oauth_json_object_end;
+ sem.object_field_start = oauth_json_object_field_start;
+ sem.array_start = oauth_json_array_start;
+ sem.scalar = oauth_json_scalar;
+
+ err = pg_parse_json(&lex, &sem);
+
+ if (err != JSON_SUCCESS)
+ {
+ errmsg = json_errdetail(err, &lex);
+ }
+ else if (PQExpBufferDataBroken(ctx.errbuf))
+ {
+ errmsg = libpq_gettext("out of memory");
+ }
+ else if (ctx.errmsg)
+ {
+ errmsg = ctx.errmsg;
+ }
+
+ if (errmsg)
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("failed to parse server's error response: %s\n"),
+ errmsg);
+
+ /* Don't need the error buffer or the JSON lexer anymore. */
+ termPQExpBuffer(&ctx.errbuf);
+ termJsonLexContext(&lex);
+
+ if (errmsg)
+ return false;
+
+ /* TODO: what if these override what the user already specified? */
+ if (ctx.discovery_uri)
+ {
+ if (conn->oauth_discovery_uri)
+ free(conn->oauth_discovery_uri);
+
+ conn->oauth_discovery_uri = ctx.discovery_uri;
+ }
+
+ if (ctx.scope)
+ {
+ if (conn->oauth_scope)
+ free(conn->oauth_scope);
+
+ conn->oauth_scope = ctx.scope;
+ }
+ /* TODO: missing error scope should clear any existing connection scope */
+
+ if (!ctx.status)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server sent error response without a status\n"));
+ return false;
+ }
+
+ if (!strcmp(ctx.status, "invalid_token"))
+ {
+ /*
+ * invalid_token is the only error code we'll automatically retry for,
+ * but only if we have enough information to do so.
+ */
+ if (conn->oauth_discovery_uri)
+ conn->oauth_want_retry = true;
+ }
+ /* TODO: include status in hard failure message */
+
+ return true;
+}
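The retry policy implemented in handle_oauth_sasl_error() boils down to a single predicate, sketched here with a hypothetical helper name:

```c
#include <assert.h>
#include <stdbool.h>
#include <string.h>

/* Hypothetical condensation of the policy above: retry the connection only
 * when the server reported "invalid_token" (the one status for which a
 * freshly obtained token might succeed) and a discovery URI is available
 * to go get one. */
static bool
oauth_should_retry(const char *status, const char *discovery_uri)
{
	if (!status || !discovery_uri)
		return false;
	return strcmp(status, "invalid_token") == 0;
}
```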
+
+static void
+oauth_exchange(void *opaq, bool final,
+ char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success)
+{
+ fe_oauth_state *state = opaq;
+ PGconn *conn = state->conn;
+
+ *done = false;
+ *success = false;
+ *output = NULL;
+ *outputlen = 0;
+
+ switch (state->state)
+ {
+ case FE_OAUTH_INIT:
+ Assert(inputlen == -1);
+
+ *output = client_initial_response(conn);
+ if (!*output)
+ goto error;
+
+ *outputlen = strlen(*output);
+ state->state = FE_OAUTH_BEARER_SENT;
+
+ break;
+
+ case FE_OAUTH_BEARER_SENT:
+ if (final)
+ {
+ /* TODO: ensure there is no message content here. */
+ *done = true;
+ *success = true;
+
+ break;
+ }
+
+ /*
+ * Error message sent by the server.
+ */
+ if (!handle_oauth_sasl_error(conn, input, inputlen))
+ goto error;
+
+ /*
+ * Respond with the required dummy message (RFC 7628, sec. 3.2.3).
+ */
+ *output = strdup(kvsep);
+ if (!*output)
+ goto error;
+
+ *outputlen = strlen(*output); /* == 1 */
+
+ state->state = FE_OAUTH_SERVER_ERROR;
+ break;
+
+ case FE_OAUTH_SERVER_ERROR:
+ /*
+ * After an error, the server should send an error response to fail
+ * the SASL handshake, which is handled in higher layers.
+ *
+ * If we get here, the server either sent *another* challenge which
+ * isn't defined in the RFC, or completed the handshake successfully
+ * after telling us it was going to fail. Neither is acceptable.
+ */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server sent additional OAuth data after error\n"));
+ goto error;
+
+ default:
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("invalid OAuth exchange state\n"));
+ goto error;
+ }
+
+ return;
+
+error:
+ *done = true;
+ *success = false;
+}
+
+static bool
+oauth_channel_bound(void *opaq)
+{
+ /* This mechanism does not support channel binding. */
+ return false;
+}
+
+static void
+oauth_free(void *opaq)
+{
+ fe_oauth_state *state = opaq;
+
+ free(state);
+}
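The device-authorization polling at the top of this file follows RFC 8628, Sec. 3.5: "authorization_pending" means poll again unchanged, while "slow_down" permanently adds five seconds. A standalone sketch of the interval handling (hypothetical helper, not the patch's API):

```c
#include <assert.h>
#include <string.h>

/* Hypothetical helper mirroring the loop above: compute the next polling
 * interval given the token endpoint's error code. Per RFC 8628, Sec. 3.5,
 * "slow_down" adds 5 seconds for this and all subsequent requests; any
 * other (or no) error code leaves the interval unchanged. */
static int
next_poll_interval(int interval, const char *error_code)
{
	if (error_code && strcmp(error_code, "slow_down") == 0)
		interval += 5;
	return interval;
}
```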
diff --git a/src/interfaces/libpq/fe-auth-sasl.h b/src/interfaces/libpq/fe-auth-sasl.h
index da3c30b87b..b1bb382f70 100644
--- a/src/interfaces/libpq/fe-auth-sasl.h
+++ b/src/interfaces/libpq/fe-auth-sasl.h
@@ -65,6 +65,8 @@ typedef struct pg_fe_sasl_mech
*
* state: The opaque mechanism state returned by init()
*
+ * final: true if the server has sent a final exchange outcome
+ *
* input: The challenge data sent by the server, or NULL when
* generating a client-first initial response (that is, when
* the server expects the client to send a message to start
@@ -92,7 +94,8 @@ typedef struct pg_fe_sasl_mech
* Ignored if *done is false.
*--------
*/
- void (*exchange) (void *state, char *input, int inputlen,
+ void (*exchange) (void *state, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success);
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index e616200704..681b76adbe 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -24,7 +24,8 @@
/* The exported SCRAM callback mechanism. */
static void *scram_init(PGconn *conn, const char *password,
const char *sasl_mechanism);
-static void scram_exchange(void *opaq, char *input, int inputlen,
+static void scram_exchange(void *opaq, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success);
static bool scram_channel_bound(void *opaq);
@@ -206,7 +207,8 @@ scram_free(void *opaq)
* Exchange a SCRAM message with backend.
*/
static void
-scram_exchange(void *opaq, char *input, int inputlen,
+scram_exchange(void *opaq, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success)
{
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 6fceff561b..2567a34023 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -38,6 +38,7 @@
#endif
#include "common/md5.h"
+#include "common/oauth-common.h"
#include "common/scram-common.h"
#include "fe-auth.h"
#include "fe-auth-sasl.h"
@@ -422,7 +423,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
bool success;
const char *selected_mechanism;
PQExpBufferData mechanism_buf;
- char *password;
+ char *password = NULL;
initPQExpBuffer(&mechanism_buf);
@@ -444,8 +445,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
/*
* Parse the list of SASL authentication mechanisms in the
* AuthenticationSASL message, and select the best mechanism that we
- * support. SCRAM-SHA-256-PLUS and SCRAM-SHA-256 are the only ones
- * supported at the moment, listed by order of decreasing importance.
+ * support. Mechanisms are listed by order of decreasing importance.
*/
selected_mechanism = NULL;
for (;;)
@@ -485,6 +485,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
{
selected_mechanism = SCRAM_SHA_256_PLUS_NAME;
conn->sasl = &pg_scram_mech;
+ conn->password_needed = true;
}
#else
/*
@@ -522,7 +523,17 @@ pg_SASL_init(PGconn *conn, int payloadlen)
{
selected_mechanism = SCRAM_SHA_256_NAME;
conn->sasl = &pg_scram_mech;
+ conn->password_needed = true;
}
+#ifdef USE_OAUTH
+ else if (strcmp(mechanism_buf.data, OAUTHBEARER_NAME) == 0 &&
+ !selected_mechanism)
+ {
+ selected_mechanism = OAUTHBEARER_NAME;
+ conn->sasl = &pg_oauth_mech;
+ conn->password_needed = false;
+ }
+#endif
}
if (!selected_mechanism)
@@ -547,18 +558,19 @@ pg_SASL_init(PGconn *conn, int payloadlen)
/*
* First, select the password to use for the exchange, complaining if
- * there isn't one. Currently, all supported SASL mechanisms require a
- * password, so we can just go ahead here without further distinction.
+ * there isn't one and the SASL mechanism needs it.
*/
- conn->password_needed = true;
- password = conn->connhost[conn->whichhost].password;
- if (password == NULL)
- password = conn->pgpass;
- if (password == NULL || password[0] == '\0')
+ if (conn->password_needed)
{
- appendPQExpBufferStr(&conn->errorMessage,
- PQnoPasswordSupplied);
- goto error;
+ password = conn->connhost[conn->whichhost].password;
+ if (password == NULL)
+ password = conn->pgpass;
+ if (password == NULL || password[0] == '\0')
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ PQnoPasswordSupplied);
+ goto error;
+ }
}
Assert(conn->sasl);
@@ -576,7 +588,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
goto oom_error;
/* Get the mechanism-specific Initial Client Response, if any */
- conn->sasl->exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state, false,
NULL, -1,
&initialresponse, &initialresponselen,
&done, &success);
@@ -657,7 +669,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
/* For safety and convenience, ensure the buffer is NULL-terminated. */
challenge[payloadlen] = '\0';
- conn->sasl->exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state, final,
challenge, payloadlen,
&output, &outputlen,
&done, &success);
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 049a8bb1a1..2a56774019 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -28,4 +28,7 @@ extern const pg_fe_sasl_mech pg_scram_mech;
extern char *pg_fe_scram_build_secret(const char *password,
const char **errstr);
+/* Mechanisms in fe-auth-oauth.c */
+extern const pg_fe_sasl_mech pg_oauth_mech;
+
#endif /* FE_AUTH_H */
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 1c5a2b43e9..5f78439586 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -344,6 +344,23 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
"Target-Session-Attrs", "", 15, /* sizeof("prefer-standby") = 15 */
offsetof(struct pg_conn, target_session_attrs)},
+ /* OAuth v2 */
+ {"oauth_issuer", NULL, NULL, NULL,
+ "OAuth-Issuer", "", 40,
+ offsetof(struct pg_conn, oauth_issuer)},
+
+ {"oauth_client_id", NULL, NULL, NULL,
+ "OAuth-Client-ID", "", 40,
+ offsetof(struct pg_conn, oauth_client_id)},
+
+ {"oauth_client_secret", NULL, NULL, NULL,
+ "OAuth-Client-Secret", "", 40,
+ offsetof(struct pg_conn, oauth_client_secret)},
+
+ {"oauth_scope", NULL, NULL, NULL,
+ "OAuth-Scope", "", 15,
+ offsetof(struct pg_conn, oauth_scope)},
+
/* Terminating entry --- MUST BE LAST */
{NULL, NULL, NULL, NULL,
NULL, NULL, 0}
@@ -606,6 +623,7 @@ pqDropServerData(PGconn *conn)
conn->write_err_msg = NULL;
conn->be_pid = 0;
conn->be_key = 0;
+ /* conn->oauth_want_retry = false; TODO */
}
@@ -3381,6 +3399,16 @@ keep_going: /* We will come back to here until there is
/* Check to see if we should mention pgpassfile */
pgpassfileWarning(conn);
+#ifdef USE_OAUTH
+ if (conn->sasl == &pg_oauth_mech
+ && conn->oauth_want_retry)
+ {
+ /* TODO: only allow retry once */
+ need_new_connection = true;
+ goto keep_going;
+ }
+#endif
+
#ifdef ENABLE_GSS
/*
@@ -4161,6 +4189,16 @@ freePGconn(PGconn *conn)
free(conn->rowBuf);
if (conn->target_session_attrs)
free(conn->target_session_attrs);
+ if (conn->oauth_issuer)
+ free(conn->oauth_issuer);
+ if (conn->oauth_discovery_uri)
+ free(conn->oauth_discovery_uri);
+ if (conn->oauth_client_id)
+ free(conn->oauth_client_id);
+ if (conn->oauth_client_secret)
+ free(conn->oauth_client_secret);
+ if (conn->oauth_scope)
+ free(conn->oauth_scope);
termPQExpBuffer(&conn->errorMessage);
termPQExpBuffer(&conn->workBuffer);
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index e0cee4b142..0dff13505a 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -394,6 +394,14 @@ struct pg_conn
char *ssl_max_protocol_version; /* maximum TLS protocol version */
char *target_session_attrs; /* desired session properties */
+ /* OAuth v2 */
+ char *oauth_issuer; /* token issuer URL */
+ char *oauth_discovery_uri; /* URI of the issuer's discovery document */
+ char *oauth_client_id; /* client identifier */
+ char *oauth_client_secret; /* client secret */
+ char *oauth_scope; /* access token scope */
+ bool oauth_want_retry; /* should we retry on failure? */
+
/* Optional file to write trace info to */
FILE *Pfdebug;
int traceFlags;
--
2.25.1
[Attachment: v3-0006-backend-add-OAUTHBEARER-SASL-mechanism.patch (text/x-patch)]
From 43ab0310ce8ee26a469167a2e4eae4c4bc295518 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 4 May 2021 16:21:11 -0700
Subject: [PATCH v3 6/9] backend: add OAUTHBEARER SASL mechanism
DO NOT USE THIS PROOF OF CONCEPT IN PRODUCTION.
Implement OAUTHBEARER (RFC 7628) on the server side. This adds a new
auth method, oauth, to pg_hba.
Because OAuth implementations vary so wildly, and bearer token
validation is heavily dependent on the issuing party, authn/z is done by
communicating with an external program: the oauth_validator_command.
This command must do the following:
1. Receive the bearer token by reading its contents from a file
descriptor passed from the server. (The numeric value of this
descriptor may be inserted into the oauth_validator_command using the
%f specifier.)
This MUST be the first action the command performs. The server will
not begin reading stdout from the command until the token has been
read in full, so if the command tries to print anything and hits a
buffer limit, the backend will deadlock and time out.
2. Validate the bearer token. The correct way to do this depends on the
issuer, but it generally involves either cryptographic operations to
prove that the token was issued by a trusted party, or the
presentation of the bearer token to some other party so that _it_ can
perform validation.
The command MUST maintain confidentiality of the bearer token, since
in most cases it can be used just like a password. (There are ways to
cryptographically bind tokens to client certificates, but they are
way beyond the scope of this commit message.)
If the token cannot be validated, the command must exit with a
non-zero status. Further authentication/authorization is pointless if
the bearer token wasn't issued by someone you trust.
3. Authenticate the user, authorize the user, or both:
a. To authenticate the user, use the bearer token to retrieve some
trusted identifier string for the end user. The exact process for
this is, again, issuer-dependent. The command should print the
authenticated identity string to stdout, followed by a newline.
If the user cannot be authenticated, the validator should not
print anything to stdout. It should also exit with a non-zero
status, unless the token may be used to authorize the connection
through some other means (see below).
On a success, the command may then exit with a zero success code.
By default, the server will then check to make sure the identity
string matches the role that is being used (or matches a usermap
entry, if one is in use).
b. To optionally authorize the user, in combination with the HBA
option trust_validator_authz=1 (see below), the validator simply
returns a zero exit code if the client should be allowed to
connect with its presented role (which can be passed to the
command using the %r specifier), or a non-zero code otherwise.
The hard part is in determining whether the given token truly
authorizes the client to use the given role, which must
unfortunately be left as an exercise to the reader.
This obviously requires some care, as a poorly implemented token
validator may silently open the entire database to anyone with a
bearer token. But it may be a more portable approach, since OAuth
is designed as an authorization framework, not an authentication
framework. For example, the user's bearer token could carry an
"allow_superuser_access" claim, which would authorize pseudonymous
database access as any role. It's then up to the OAuth system
administrators to ensure that allow_superuser_access is doled out
only to the proper users.
c. It's possible that the user can be successfully authenticated but
isn't authorized to connect. In this case, the command may print
the authenticated ID and then fail with a non-zero exit code.
(This makes it easier to see what's going on in the Postgres
logs.)
4. Token validators may optionally log to stderr. This will be printed
verbatim into the Postgres server logs.
The oauth method supports the following HBA options (but note that two
of them are not optional, since we have no way of choosing sensible
defaults):
issuer: Required. The URL of the OAuth issuing party, which the client
must contact to receive a bearer token.
Some real-world examples as of the time of writing:
- https://accounts.google.com
- https://login.microsoft.com/[tenant-id]/v2.0
scope: Required. The OAuth scope(s) required for the server to
authenticate and/or authorize the user. This is heavily
deployment-specific, but a simple example is "openid email".
map: Optional. Specify a standard PostgreSQL user map; this works
the same as with other auth methods such as peer. If a map is
not specified, the user ID returned by the token validator
must exactly match the role that's being requested (but see
trust_validator_authz, below).
trust_validator_authz:
Optional. When set to 1, this allows the token validator to
take full control of the authorization process. Standard user
mapping is skipped: if the validator command succeeds, the
client is allowed to connect under its desired role and no
further checks are done.
Unlike the client, servers support OAuth without needing to be built
against libiddawc (since the responsibility for "speaking" OAuth/OIDC
correctly is delegated entirely to the oauth_validator_command).
Several TODOs:
- port to platforms other than "modern Linux"
- overhaul the communication with oauth_validator_command, which is
currently a bad hack on OpenPipeStream()
- implement more sanity checks on the OAUTHBEARER message format and
tokens sent by the client
- implement more helpful handling of HBA misconfigurations
- properly interpolate JSON when generating error responses
- use logdetail during auth failures
- deal with role names that can't be safely passed to system() without
shell-escaping
- allow passing the configured issuer to the oauth_validator_command, to
deal with multi-issuer setups
- ...and more.
---
src/backend/libpq/Makefile | 1 +
src/backend/libpq/auth-oauth.c | 797 +++++++++++++++++++++++++++++++++
src/backend/libpq/auth-sasl.c | 10 +-
src/backend/libpq/auth-scram.c | 4 +-
src/backend/libpq/auth.c | 7 +
src/backend/libpq/hba.c | 29 +-
src/backend/utils/misc/guc.c | 12 +
src/include/libpq/hba.h | 8 +-
src/include/libpq/oauth.h | 24 +
src/include/libpq/sasl.h | 11 +
10 files changed, 889 insertions(+), 14 deletions(-)
create mode 100644 src/backend/libpq/auth-oauth.c
create mode 100644 src/include/libpq/oauth.h
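The fd-handling contract in step 1 of the commit message — read the whole token from the passed descriptor before writing anything — can be sketched as the core of a hypothetical validator written in C (the helper name and structure are illustrative, not an API this patch provides):

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/* Read the full bearer token from fd into a malloc'd, NUL-terminated
 * buffer, growing as needed. A real validator would call this first,
 * verify the token, then print the authenticated identity to stdout and
 * exit zero (or nonzero on failure). Returns NULL on error. */
static char *
read_token(int fd)
{
	size_t		cap = 256, len = 0;
	char	   *buf = malloc(cap);
	ssize_t		n;

	if (!buf)
		return NULL;
	while ((n = read(fd, buf + len, cap - len - 1)) > 0)
	{
		len += n;
		if (cap - len < 2)
		{
			char	   *tmp = realloc(buf, cap *= 2);

			if (!tmp)
			{
				free(buf);
				return NULL;
			}
			buf = tmp;
		}
	}
	if (n < 0)
	{
		free(buf);
		return NULL;
	}
	buf[len] = '\0';
	return buf;
}
```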
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 6d385fd6a4..98eb2a8242 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,6 +15,7 @@ include $(top_builddir)/src/Makefile.global
# be-fsstubs is here for historical reasons, probably belongs elsewhere
OBJS = \
+ auth-oauth.o \
auth-sasl.o \
auth-scram.o \
auth.o \
diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c
new file mode 100644
index 0000000000..c1232a31a0
--- /dev/null
+++ b/src/backend/libpq/auth-oauth.c
@@ -0,0 +1,797 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth-oauth.c
+ * Server-side implementation of the SASL OAUTHBEARER mechanism.
+ *
+ * See the following RFC for more details:
+ * - RFC 7628: https://tools.ietf.org/html/rfc7628
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/backend/libpq/auth-oauth.c
+ *
+ *-------------------------------------------------------------------------
+ */
+#include "postgres.h"
+
+#include <unistd.h>
+#include <fcntl.h>
+
+#include "common/oauth-common.h"
+#include "lib/stringinfo.h"
+#include "libpq/auth.h"
+#include "libpq/hba.h"
+#include "libpq/oauth.h"
+#include "libpq/sasl.h"
+#include "storage/fd.h"
+
+/* GUC */
+char *oauth_validator_command;
+
+static void oauth_get_mechanisms(Port *port, StringInfo buf);
+static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
+static int oauth_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, const char **logdetail);
+
+/* Mechanism declaration */
+const pg_be_sasl_mech pg_be_oauth_mech = {
+ oauth_get_mechanisms,
+ oauth_init,
+ oauth_exchange,
+
+ PG_MAX_AUTH_TOKEN_LENGTH,
+};
+
+
+typedef enum
+{
+ OAUTH_STATE_INIT = 0,
+ OAUTH_STATE_ERROR,
+ OAUTH_STATE_FINISHED,
+} oauth_state;
+
+struct oauth_ctx
+{
+ oauth_state state;
+ Port *port;
+ const char *issuer;
+ const char *scope;
+};
+
+static char *sanitize_char(char c);
+static char *parse_kvpairs_for_auth(char **input);
+static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
+static bool validate(Port *port, const char *auth, const char **logdetail);
+static bool run_validator_command(Port *port, const char *token);
+static bool check_exit(FILE **fh, const char *command);
+static bool unset_cloexec(int fd);
+static bool username_ok_for_shell(const char *username);
+
+#define KVSEP 0x01
+#define AUTH_KEY "auth"
+#define BEARER_SCHEME "Bearer "
+
+static void
+oauth_get_mechanisms(Port *port, StringInfo buf)
+{
+ /* Only OAUTHBEARER is supported. */
+ appendStringInfoString(buf, OAUTHBEARER_NAME);
+ appendStringInfoChar(buf, '\0');
+}
+
+static void *
+oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
+{
+ struct oauth_ctx *ctx;
+
+ if (strcmp(selected_mech, OAUTHBEARER_NAME))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("client selected an invalid SASL authentication mechanism")));
+
+ ctx = palloc0(sizeof(*ctx));
+
+ ctx->state = OAUTH_STATE_INIT;
+ ctx->port = port;
+
+ Assert(port->hba);
+ ctx->issuer = port->hba->oauth_issuer;
+ ctx->scope = port->hba->oauth_scope;
+
+ return ctx;
+}
+
+static int
+oauth_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, const char **logdetail)
+{
+ char *p;
+ char cbind_flag;
+ char *auth;
+
+ struct oauth_ctx *ctx = opaq;
+
+ *output = NULL;
+ *outputlen = -1;
+
+ /*
+ * If the client didn't include an "Initial Client Response" in the
+ * SASLInitialResponse message, send an empty challenge, to which the
+ * client will respond with the same data that usually comes in the
+ * Initial Client Response.
+ */
+ if (input == NULL)
+ {
+ Assert(ctx->state == OAUTH_STATE_INIT);
+
+ *output = pstrdup("");
+ *outputlen = 0;
+ return PG_SASL_EXCHANGE_CONTINUE;
+ }
+
+ /*
+ * Check that the input length agrees with the string length of the input.
+ */
+ if (inputlen == 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("The message is empty.")));
+ if (inputlen != strlen(input))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message length does not match input length.")));
+
+ switch (ctx->state)
+ {
+ case OAUTH_STATE_INIT:
+ /* Handle this case below. */
+ break;
+
+ case OAUTH_STATE_ERROR:
+ /*
+ * Only one response is valid for the client during authentication
+ * failure: a single kvsep.
+ */
+ if (inputlen != 1 || *input != KVSEP)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Client did not send a kvsep response.")));
+
+ /* The (failed) handshake is now complete. */
+ ctx->state = OAUTH_STATE_FINISHED;
+ return PG_SASL_EXCHANGE_FAILURE;
+
+ default:
+ elog(ERROR, "invalid OAUTHBEARER exchange state");
+ return PG_SASL_EXCHANGE_FAILURE;
+ }
+
+ /* Handle the client's initial message. */
+ p = pstrdup(input);
+
+ /*
+ * OAUTHBEARER does not currently define a channel binding (so there is no
+ * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a 'y'
+ * specifier purely for the remote chance that a future specification could
+ * define one; then future clients can still interoperate with this server
+ * implementation. 'n' is the expected case.
+ */
+ cbind_flag = *p;
+ switch (cbind_flag)
+ {
+ case 'p':
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data.")));
+ break;
+
+ case 'y': /* fall through */
+ case 'n':
+ p++;
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Comma expected, but found character %s.",
+ sanitize_char(*p))));
+ p++;
+ break;
+
+ default:
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Unexpected channel-binding flag %s.",
+ sanitize_char(cbind_flag))));
+ }
+
+ /*
+ * Forbid optional authzid (authorization identity). We don't support it.
+ */
+ if (*p == 'a')
+ ereport(ERROR,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("client uses authorization identity, but it is not supported")));
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Unexpected attribute %s in client-first-message.",
+ sanitize_char(*p))));
+ p++;
+
+ /* All remaining fields are separated by the RFC's kvsep (\x01). */
+ if (*p != KVSEP)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Key-value separator expected, but found character %s.",
+ sanitize_char(*p))));
+ p++;
+
+ auth = parse_kvpairs_for_auth(&p);
+ if (!auth)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message does not contain an auth value.")));
+
+ /* We should be at the end of our message. */
+ if (*p)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains additional data after the final terminator.")));
+
+ if (!validate(ctx->port, auth, logdetail))
+ {
+ generate_error_response(ctx, output, outputlen);
+
+ ctx->state = OAUTH_STATE_ERROR;
+ return PG_SASL_EXCHANGE_CONTINUE;
+ }
+
+ ctx->state = OAUTH_STATE_FINISHED;
+ return PG_SASL_EXCHANGE_SUCCESS;
+}
+
+/*
+ * Convert an arbitrary byte to printable form. For error messages.
+ *
+ * If it's a printable ASCII character, print it as a single character.
+ * Otherwise, print it in hex.
+ *
+ * The returned pointer points to a static buffer.
+ */
+static char *
+sanitize_char(char c)
+{
+ static char buf[5];
+
+ if (c >= 0x21 && c <= 0x7E)
+ snprintf(buf, sizeof(buf), "'%c'", c);
+ else
+ snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
+ return buf;
+}
+
+/*
+ * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is
+ * found, its value is returned.
+ */
+static char *
+parse_kvpairs_for_auth(char **input)
+{
+ char *pos = *input;
+ char *auth = NULL;
+
+ /*
+ * The relevant ABNF, from Sec. 3.1:
+ *
+ * kvsep = %x01
+ * key = 1*(ALPHA)
+ * value = *(VCHAR / SP / HTAB / CR / LF )
+ * kvpair = key "=" value kvsep
+ * ;;gs2-header = See RFC 5801
+ * client-resp = (gs2-header kvsep *kvpair kvsep) / kvsep
+ *
+ * By the time we reach this code, the gs2-header and initial kvsep have
+ * already been validated. We start at the beginning of the first kvpair.
+ */
+
+ while (*pos)
+ {
+ char *end;
+ char *sep;
+ char *key;
+ char *value;
+
+ /*
+ * Find the end of this kvpair. Note that input is null-terminated by
+ * the SASL code, so the strchr() is bounded.
+ */
+ end = strchr(pos, KVSEP);
+ if (!end)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains an unterminated key/value pair.")));
+ *end = '\0';
+
+ if (pos == end)
+ {
+ /* Empty kvpair, signifying the end of the list. */
+ *input = pos + 1;
+ return auth;
+ }
+
+ /*
+ * Find the end of the key name.
+ *
+ * TODO further validate the key/value grammar? empty keys, bad chars...
+ */
+ sep = strchr(pos, '=');
+ if (!sep)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains a key without a value.")));
+ *sep = '\0';
+
+ /* Both key and value are now safely terminated. */
+ key = pos;
+ value = sep + 1;
+
+ if (!strcmp(key, AUTH_KEY))
+ {
+ if (auth)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains multiple auth values.")));
+
+ auth = value;
+ }
+ else
+ {
+ /*
+ * The RFC also defines the host and port keys, but they are not
+ * required for OAUTHBEARER and we do not use them. Also, per
+ * Sec. 3.1, any key/value pairs we don't recognize must be ignored.
+ */
+ }
+
+ /* Move to the next pair. */
+ pos = end + 1;
+ }
+
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message did not contain a final terminator.")));
+
+ return NULL; /* unreachable */
+}
+
+static void
+generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
+{
+ StringInfoData buf;
+
+ /*
+ * The admin needs to set an issuer and scope for OAuth to work. There's not
+ * really a way to hide this from the user, either, because we can't choose
+ * a "default" issuer, so be honest in the failure message.
+ *
+ * TODO: see if there's a better place to fail, earlier than this.
+ */
+ if (!ctx->issuer || !ctx->scope)
+ ereport(FATAL,
+ (errcode(ERRCODE_INTERNAL_ERROR),
+ errmsg("OAuth is not properly configured for this user"),
+ errdetail_log("The issuer and scope parameters must be set in pg_hba.conf.")));
+
+ initStringInfo(&buf);
+
+ /*
+ * TODO: JSON escaping
+ */
+ appendStringInfo(&buf,
+ "{ "
+ "\"status\": \"invalid_token\", "
+ "\"openid-configuration\": \"%s/.well-known/openid-configuration\","
+ "\"scope\": \"%s\" "
+ "}",
+ ctx->issuer, ctx->scope);
+
+ *output = buf.data;
+ *outputlen = buf.len;
+}
+
+static bool
+validate(Port *port, const char *auth, const char **logdetail)
+{
+ static const char * const b64_set = "abcdefghijklmnopqrstuvwxyz"
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "0123456789-._~+/";
+
+ const char *token;
+ size_t span;
+ int ret;
+
+ /* TODO: handle logdetail when the test framework can check it */
+
+ /*
+ * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec.
+ * 2.1:
+ *
+ * b64token = 1*( ALPHA / DIGIT /
+ * "-" / "." / "_" / "~" / "+" / "/" ) *"="
+ * credentials = "Bearer" 1*SP b64token
+ *
+ * The "credentials" construction is what we receive in our auth value.
+ *
+ * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization
+ * header format; RFC 7235 Sec. 2), the "Bearer" scheme string must be
+ * compared case-insensitively. (This is not mentioned in RFC 6750, but it's
+ * pointed out in RFC 7628 Sec. 4.)
+ *
+ * TODO: handle the Authorization spec, RFC 7235 Sec. 2.1.
+ */
+ if (strncasecmp(auth, BEARER_SCHEME, strlen(BEARER_SCHEME)))
+ return false;
+
+ /* Pull the bearer token out of the auth value. */
+ token = auth + strlen(BEARER_SCHEME);
+
+ /* Swallow any additional spaces. */
+ while (*token == ' ')
+ token++;
+
+ /*
+ * Before invoking the validator command, sanity-check the token format to
+ * avoid any injection attacks later in the chain. Invalid formats are
+ * technically a protocol violation, but don't reflect any information about
+ * the sensitive Bearer token back to the client; log at COMMERROR instead.
+ */
+
+ /* Tokens must not be empty. */
+ if (!*token)
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Bearer token is empty.")));
+ return false;
+ }
+
+ /*
+ * Make sure the token contains only allowed characters. Tokens may end with
+ * any number of '=' characters.
+ */
+ span = strspn(token, b64_set);
+ while (token[span] == '=')
+ span++;
+
+ if (token[span] != '\0')
+ {
+ /*
+ * This error message could be more helpful by printing the problematic
+ * character(s), but that'd be a bit like printing a piece of someone's
+ * password into the logs.
+ */
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Bearer token is not in the correct format.")));
+ return false;
+ }
+
+ /* Have the validator check the token. */
+ if (!run_validator_command(port, token))
+ return false;
+
+ if (port->hba->oauth_skip_usermap)
+ {
+ /*
+ * If the validator is our authorization authority, we're done.
+ * Authentication may or may not have been performed depending on the
+ * validator implementation; all that matters is that the validator says
+ * the user can log in with the target role.
+ */
+ return true;
+ }
+
+ /* Make sure the validator authenticated the user. */
+ if (!port->authn_id)
+ {
+ /* TODO: use logdetail; reduce message duplication */
+ ereport(LOG,
+ (errmsg("OAuth bearer authentication failed for user \"%s\": validator provided no identity",
+ port->user_name)));
+ return false;
+ }
+
+ /* Finally, check the user map. */
+ ret = check_usermap(port->hba->usermap, port->user_name, port->authn_id,
+ false);
+ return (ret == STATUS_OK);
+}
+
+static bool
+run_validator_command(Port *port, const char *token)
+{
+ bool success = false;
+ int rc;
+ int pipefd[2];
+ int rfd = -1;
+ int wfd = -1;
+
+ StringInfoData command = { 0 };
+ char *p;
+ FILE *fh = NULL;
+
+ ssize_t written;
+ char *line = NULL;
+ size_t size = 0;
+ ssize_t len;
+
+ Assert(oauth_validator_command);
+
+ if (!oauth_validator_command[0])
+ {
+ ereport(COMMERROR,
+ (errmsg("oauth_validator_command is not set"),
+ errhint("To allow OAuth authenticated connections, set "
+ "oauth_validator_command in postgresql.conf.")));
+ return false;
+ }
+
+ /*
+ * Since popen() is unidirectional, open up a pipe for the other direction.
+ * Use CLOEXEC to ensure that our write end doesn't accidentally get copied
+ * into child processes, which would prevent us from closing it cleanly.
+ *
+ * XXX this is ugly. We should just read from the child process's stdout,
+ * but that's a lot more code.
+ * XXX by bypassing the popen API, we open up the possibility of process
+ * deadlock. Clearly document the child process requirements (i.e. the child
+ * MUST read all data off of the pipe before writing anything).
+ * TODO: port to Windows using _pipe().
+ */
+ rc = pipe2(pipefd, O_CLOEXEC);
+ if (rc < 0)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not create child pipe: %m")));
+ return false;
+ }
+
+ rfd = pipefd[0];
+ wfd = pipefd[1];
+
+ /* Allow the read pipe to be passed to the child. */
+ if (!unset_cloexec(rfd))
+ {
+ /* error message was already logged */
+ goto cleanup;
+ }
+
+ /*
+ * Construct the command, substituting any recognized %-specifiers:
+ *
+ * %f: the file descriptor of the input pipe
+ * %r: the role that the client wants to assume (port->user_name)
+ * %%: a literal '%'
+ */
+ initStringInfo(&command);
+
+ for (p = oauth_validator_command; *p; p++)
+ {
+ if (p[0] == '%')
+ {
+ switch (p[1])
+ {
+ case 'f':
+ appendStringInfo(&command, "%d", rfd);
+ p++;
+ break;
+ case 'r':
+ /*
+ * TODO: decide how this string should be escaped. The role
+ * is controlled by the client, so if we don't escape it,
+ * command injections are inevitable.
+ *
+ * This is probably an indication that the role name needs
+ * to be communicated to the validator process in some other
+ * way. For this proof of concept, just be incredibly strict
+ * about the characters that are allowed in user names.
+ */
+ if (!username_ok_for_shell(port->user_name))
+ goto cleanup;
+
+ appendStringInfoString(&command, port->user_name);
+ p++;
+ break;
+ case '%':
+ appendStringInfoChar(&command, '%');
+ p++;
+ break;
+ default:
+ appendStringInfoChar(&command, p[0]);
+ }
+ }
+ else
+ appendStringInfoChar(&command, p[0]);
+ }
+
+ /* Execute the command. */
+ fh = OpenPipeStream(command.data, "re");
+ /* TODO: handle failures */
+
+ /* We don't need the read end of the pipe anymore. */
+ close(rfd);
+ rfd = -1;
+
+ /* Give the command the token to validate. */
+ written = write(wfd, token, strlen(token));
+ if (written != strlen(token))
+ {
+ /* TODO must loop for short writes, EINTR et al */
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not write token to child pipe: %m")));
+ goto cleanup;
+ }
+
+ close(wfd);
+ wfd = -1;
+
+ /*
+ * Read the command's response.
+ *
+ * TODO: getline() is probably too new to use, unfortunately.
+ * TODO: loop over all lines
+ */
+ if ((len = getline(&line, &size, fh)) >= 0)
+ {
+ /* TODO: fail if the authn_id doesn't end with a newline */
+ if (len > 0)
+ line[len - 1] = '\0';
+
+ set_authn_id(port, line);
+ }
+ else if (ferror(fh))
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not read from command \"%s\": %m",
+ command.data)));
+ goto cleanup;
+ }
+
+ /* Make sure the command exits cleanly. */
+ if (!check_exit(&fh, command.data))
+ {
+ /* error message already logged */
+ goto cleanup;
+ }
+
+ /* Done. */
+ success = true;
+
+cleanup:
+ if (line)
+ free(line);
+
+ /*
+ * In the successful case, the pipe fds are already closed. For the error
+ * case, always close out the pipe before waiting for the command, to
+ * prevent deadlock.
+ */
+ if (rfd >= 0)
+ close(rfd);
+ if (wfd >= 0)
+ close(wfd);
+
+ if (fh)
+ {
+ Assert(!success);
+ check_exit(&fh, command.data);
+ }
+
+ if (command.data)
+ pfree(command.data);
+
+ return success;
+}
+
+static bool
+check_exit(FILE **fh, const char *command)
+{
+ int rc;
+
+ rc = ClosePipeStream(*fh);
+ *fh = NULL;
+
+ if (rc == -1)
+ {
+ /* pclose() itself failed. */
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not close pipe to command \"%s\": %m",
+ command)));
+ }
+ else if (rc != 0)
+ {
+ char *reason = wait_result_to_str(rc);
+
+ ereport(COMMERROR,
+ (errmsg("failed to execute command \"%s\": %s",
+ command, reason)));
+
+ pfree(reason);
+ }
+
+ return (rc == 0);
+}
+
+static bool
+unset_cloexec(int fd)
+{
+ int flags;
+ int rc;
+
+ flags = fcntl(fd, F_GETFD);
+ if (flags == -1)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not get fd flags for child pipe: %m")));
+ return false;
+ }
+
+ rc = fcntl(fd, F_SETFD, flags & ~FD_CLOEXEC);
+ if (rc < 0)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not unset FD_CLOEXEC for child pipe: %m")));
+ return false;
+ }
+
+ return true;
+}
+
+/*
+ * XXX This should go away eventually and be replaced with either a proper
+ * escape or a different strategy for communication with the validator command.
+ */
+static bool
+username_ok_for_shell(const char *username)
+{
+ /* This set is borrowed from fe_utils' appendShellStringNoError(). */
+ static const char * const allowed = "abcdefghijklmnopqrstuvwxyz"
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "0123456789-_./:";
+ size_t span;
+
+ Assert(username && username[0]); /* should have already been checked */
+
+ span = strspn(username, allowed);
+ if (username[span] != '\0')
+ {
+ ereport(COMMERROR,
+ (errmsg("PostgreSQL user name contains unsafe characters and cannot be passed to the OAuth validator")));
+ return false;
+ }
+
+ return true;
+}
diff --git a/src/backend/libpq/auth-sasl.c b/src/backend/libpq/auth-sasl.c
index a1d7dbb6d5..0f461a6696 100644
--- a/src/backend/libpq/auth-sasl.c
+++ b/src/backend/libpq/auth-sasl.c
@@ -20,14 +20,6 @@
#include "libpq/pqformat.h"
#include "libpq/sasl.h"
-/*
- * Maximum accepted size of SASL messages.
- *
- * The messages that the server or libpq generate are much smaller than this,
- * but have some headroom.
- */
-#define PG_MAX_SASL_MESSAGE_LENGTH 1024
-
/*
* Perform a SASL exchange with a libpq client, using a specific mechanism
* implementation.
@@ -103,7 +95,7 @@ CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
/* Get the actual SASL message */
initStringInfo(&buf);
- if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
+ if (pq_getmessage(&buf, mech->max_message_length))
{
/* EOF - pq_getmessage already logged error */
pfree(buf.data);
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index ee7f52218a..4049ace470 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -118,7 +118,9 @@ static int scram_exchange(void *opaq, const char *input, int inputlen,
const pg_be_sasl_mech pg_be_scram_mech = {
scram_get_mechanisms,
scram_init,
- scram_exchange
+ scram_exchange,
+
+ PG_MAX_SASL_MESSAGE_LENGTH
};
/*
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 3533b0bc50..5c30904e2b 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -30,6 +30,7 @@
#include "libpq/auth.h"
#include "libpq/crypt.h"
#include "libpq/libpq.h"
+#include "libpq/oauth.h"
#include "libpq/pqformat.h"
#include "libpq/sasl.h"
#include "libpq/scram.h"
@@ -302,6 +303,9 @@ auth_failed(Port *port, int status, const char *logdetail)
case uaRADIUS:
errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
break;
+ case uaOAuth:
+ errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
+ break;
case uaCustom:
if (CustomAuthenticationError_hook)
errstr = CustomAuthenticationError_hook(port);
@@ -627,6 +631,9 @@ ClientAuthentication(Port *port)
case uaTrust:
status = STATUS_OK;
break;
+ case uaOAuth:
+ status = CheckSASLAuth(&pg_be_oauth_mech, port, NULL, NULL);
+ break;
case uaCustom:
if (CustomAuthenticationCheck_hook)
status = CustomAuthenticationCheck_hook(port);
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index ebae992964..f7f3059927 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -135,7 +135,8 @@ static const char *const UserAuthName[] =
"cert",
"radius",
"custom",
- "peer"
+ "peer",
+ "oauth",
};
@@ -1400,6 +1401,8 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
#endif
else if (strcmp(token->string, "radius") == 0)
parsedline->auth_method = uaRADIUS;
+ else if (strcmp(token->string, "oauth") == 0)
+ parsedline->auth_method = uaOAuth;
else if (strcmp(token->string, "custom") == 0)
parsedline->auth_method = uaCustom;
else
@@ -1728,8 +1731,9 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->auth_method != uaPeer &&
hbaline->auth_method != uaGSS &&
hbaline->auth_method != uaSSPI &&
+ hbaline->auth_method != uaOAuth &&
hbaline->auth_method != uaCert)
- INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, and cert"));
+ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, oauth, and cert"));
hbaline->usermap = pstrdup(val);
}
else if (strcmp(name, "clientcert") == 0)
@@ -2113,6 +2117,27 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->radiusidentifiers = parsed_identifiers;
hbaline->radiusidentifiers_s = pstrdup(val);
}
+ else if (strcmp(name, "issuer") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("issuer", gettext_noop("oauth"));
+ hbaline->oauth_issuer = pstrdup(val);
+ }
+ else if (strcmp(name, "scope") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("scope", gettext_noop("oauth"));
+ hbaline->oauth_scope = pstrdup(val);
+ }
+ else if (strcmp(name, "trust_validator_authz") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("trust_validator_authz", gettext_noop("oauth"));
+ if (strcmp(val, "1") == 0)
+ hbaline->oauth_skip_usermap = true;
+ else
+ hbaline->oauth_skip_usermap = false;
+ }
else if (strcmp(name, "provider") == 0)
{
REQUIRE_AUTH_OPTION(uaCustom, "provider", "custom");
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index 1e3650184b..791c7c83df 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -58,6 +58,7 @@
#include "libpq/auth.h"
#include "libpq/libpq.h"
#include "libpq/pqformat.h"
+#include "libpq/oauth.h"
#include "miscadmin.h"
#include "optimizer/cost.h"
#include "optimizer/geqo.h"
@@ -4662,6 +4663,17 @@ static struct config_string ConfigureNamesString[] =
check_backtrace_functions, assign_backtrace_functions, NULL
},
+ {
+ {"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH,
+ gettext_noop("Command to validate OAuth v2 bearer tokens."),
+ NULL,
+ GUC_SUPERUSER_ONLY
+ },
+ &oauth_validator_command,
+ "",
+ NULL, NULL, NULL
+ },
+
/* End-of-list marker */
{
{NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index c5aef6994c..d46c2108eb 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -39,8 +39,9 @@ typedef enum UserAuth
uaCert,
uaRADIUS,
uaCustom,
- uaPeer
-#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */
+ uaPeer,
+ uaOAuth
+#define USER_AUTH_LAST uaOAuth /* Must be last value of this enum */
} UserAuth;
/*
@@ -121,6 +122,9 @@ typedef struct HbaLine
char *radiusidentifiers_s;
List *radiusports;
char *radiusports_s;
+ char *oauth_issuer;
+ char *oauth_scope;
+ bool oauth_skip_usermap;
char *custom_provider;
} HbaLine;
diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h
new file mode 100644
index 0000000000..870e426af1
--- /dev/null
+++ b/src/include/libpq/oauth.h
@@ -0,0 +1,24 @@
+/*-------------------------------------------------------------------------
+ *
+ * oauth.h
+ * Interface to libpq/auth-oauth.c
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/libpq/oauth.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef PG_OAUTH_H
+#define PG_OAUTH_H
+
+#include "libpq/libpq-be.h"
+#include "libpq/sasl.h"
+
+extern char *oauth_validator_command;
+
+/* Implementation */
+extern const pg_be_sasl_mech pg_be_oauth_mech;
+
+#endif /* PG_OAUTH_H */
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
index 71cc0dc251..3d481cc807 100644
--- a/src/include/libpq/sasl.h
+++ b/src/include/libpq/sasl.h
@@ -26,6 +26,14 @@
#define PG_SASL_EXCHANGE_SUCCESS 1
#define PG_SASL_EXCHANGE_FAILURE 2
+/*
+ * Maximum accepted size of SASL messages.
+ *
+ * The messages that the server or libpq generate are much smaller than this,
+ * but have some headroom.
+ */
+#define PG_MAX_SASL_MESSAGE_LENGTH 1024
+
/*
* Backend SASL mechanism callbacks.
*
@@ -127,6 +135,9 @@ typedef struct pg_be_sasl_mech
const char *input, int inputlen,
char **output, int *outputlen,
const char **logdetail);
+
+ /* The maximum size allowed for client SASLResponses. */
+ int max_message_length;
} pg_be_sasl_mech;
/* Common implementation for auth.c */
--
2.25.1
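For reference, the kvpair parsing implemented by parse_kvpairs_for_auth() in the patch above can be sketched in Python. This is only a rough model of the same logic (RFC 7628, Sec. 3.1), not part of the patch; it assumes the gs2-header and its trailing kvsep have already been stripped, just as the C code does:

```python
KVSEP = "\x01"
AUTH_KEY = "auth"

def parse_kvpairs_for_auth(msg):
    """Walk the kvpairs of an OAUTHBEARER client message and return the
    value of the "auth" key, or None if it never appears."""
    auth = None
    pos = 0
    while pos < len(msg):
        end = msg.find(KVSEP, pos)
        if end < 0:
            raise ValueError("unterminated key/value pair")
        if end == pos:
            # Empty kvpair signifies the end of the list.
            return auth
        key, sep, value = msg[pos:end].partition("=")
        if not sep:
            raise ValueError("key without a value")
        if key == AUTH_KEY:
            if auth is not None:
                raise ValueError("multiple auth values")
            auth = value
        # host, port, and unrecognized keys are ignored, per Sec. 3.1.
        pos = end + 1
    raise ValueError("missing final terminator")
```

As in the C implementation, a message that runs out of input before the empty terminating kvpair is treated as malformed.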
Attachment: v3-0007-Add-a-very-simple-authn_id-extension.patch (text/x-patch)
From 667fbd709f67232155cbaa3e09d1a4b4c02eeb22 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 18 May 2021 15:01:29 -0700
Subject: [PATCH v3 7/9] Add a very simple authn_id extension
...for retrieving the authn_id from the server in tests.
---
contrib/authn_id/Makefile | 19 +++++++++++++++++++
contrib/authn_id/authn_id--1.0.sql | 8 ++++++++
contrib/authn_id/authn_id.c | 28 ++++++++++++++++++++++++++++
contrib/authn_id/authn_id.control | 5 +++++
4 files changed, 60 insertions(+)
create mode 100644 contrib/authn_id/Makefile
create mode 100644 contrib/authn_id/authn_id--1.0.sql
create mode 100644 contrib/authn_id/authn_id.c
create mode 100644 contrib/authn_id/authn_id.control
diff --git a/contrib/authn_id/Makefile b/contrib/authn_id/Makefile
new file mode 100644
index 0000000000..46026358e0
--- /dev/null
+++ b/contrib/authn_id/Makefile
@@ -0,0 +1,19 @@
+# contrib/authn_id/Makefile
+
+MODULE_big = authn_id
+OBJS = authn_id.o
+
+EXTENSION = authn_id
+DATA = authn_id--1.0.sql
+PGFILEDESC = "authn_id - information about the authenticated user"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = contrib/authn_id
+top_builddir = ../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/contrib/authn_id/authn_id--1.0.sql b/contrib/authn_id/authn_id--1.0.sql
new file mode 100644
index 0000000000..af2a4d3991
--- /dev/null
+++ b/contrib/authn_id/authn_id--1.0.sql
@@ -0,0 +1,8 @@
+/* contrib/authn_id/authn_id--1.0.sql */
+
+-- complain if script is sourced in psql, rather than via CREATE EXTENSION
+\echo Use "CREATE EXTENSION authn_id" to load this file. \quit
+
+CREATE FUNCTION authn_id() RETURNS text
+AS 'MODULE_PATHNAME', 'authn_id'
+LANGUAGE C IMMUTABLE;
diff --git a/contrib/authn_id/authn_id.c b/contrib/authn_id/authn_id.c
new file mode 100644
index 0000000000..0fecac36a8
--- /dev/null
+++ b/contrib/authn_id/authn_id.c
@@ -0,0 +1,28 @@
+/*
+ * Extension to expose the current user's authn_id.
+ *
+ * contrib/authn_id/authn_id.c
+ */
+
+#include "postgres.h"
+
+#include "fmgr.h"
+#include "libpq/libpq-be.h"
+#include "miscadmin.h"
+#include "utils/builtins.h"
+
+PG_MODULE_MAGIC;
+
+PG_FUNCTION_INFO_V1(authn_id);
+
+/*
+ * Returns the current user's authenticated identity.
+ */
+Datum
+authn_id(PG_FUNCTION_ARGS)
+{
+ if (!MyProcPort->authn_id)
+ PG_RETURN_NULL();
+
+ PG_RETURN_TEXT_P(cstring_to_text(MyProcPort->authn_id));
+}
diff --git a/contrib/authn_id/authn_id.control b/contrib/authn_id/authn_id.control
new file mode 100644
index 0000000000..e0f9e06bed
--- /dev/null
+++ b/contrib/authn_id/authn_id.control
@@ -0,0 +1,5 @@
+# authn_id extension
+comment = 'current user identity'
+default_version = '1.0'
+module_pathname = '$libdir/authn_id'
+relocatable = true
--
2.25.1
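To make the validator-command protocol from the server patch concrete: the server substitutes %f with the read end of a pipe carrying the bearer token, substitutes %r with the requested role, expects the authenticated identity on the command's stdout, and treats a nonzero exit code as a rejection. Below is a toy validator sketching that contract. The file name, token value, and identity are placeholders for illustration only; a real validator would introspect the token with its issuer:

```python
#!/usr/bin/env python3
# Toy validator for oauth_validator_command, e.g. (hypothetical path):
#   oauth_validator_command = '/usr/local/bin/validator.py %f "%r"'
import os
import sys

def main(fd, role):
    # The server writes the bearer token to this pipe; read it all before
    # producing any output, to avoid deadlocking against the server.
    with os.fdopen(fd, "r") as pipe:
        token = pipe.read()

    # Placeholder check: accept a single hard-coded token. (Assumption for
    # illustration only; never do this in production.)
    if token != "sekrit-token":
        sys.exit(1)

    # Report the authenticated identity back to the server on stdout.
    print("user@example.org")

if __name__ == "__main__" and len(sys.argv) == 3:
    main(int(sys.argv[1]), sys.argv[2])
```

With trust_validator_authz the role argument would drive the authorization decision instead; here it is accepted but unused.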
Attachment: v3-0008-Add-pytest-suite-for-OAuth.patch (text/x-patch)
From 1c24881d4b1e8777ce176d2c276fe8120bd6e648 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Fri, 4 Jun 2021 09:06:38 -0700
Subject: [PATCH v3 8/9] Add pytest suite for OAuth
Requires Python 3; on the first run of `make installcheck` the
dependencies will be installed into ./venv for you. See the README for
more details.
---
src/test/python/.gitignore | 2 +
src/test/python/Makefile | 38 +
src/test/python/README | 54 ++
src/test/python/client/__init__.py | 0
src/test/python/client/conftest.py | 126 +++
src/test/python/client/test_client.py | 180 ++++
src/test/python/client/test_oauth.py | 936 ++++++++++++++++++
src/test/python/pq3.py | 727 ++++++++++++++
src/test/python/pytest.ini | 4 +
src/test/python/requirements.txt | 7 +
src/test/python/server/__init__.py | 0
src/test/python/server/conftest.py | 45 +
src/test/python/server/test_oauth.py | 1012 ++++++++++++++++++++
src/test/python/server/test_server.py | 21 +
src/test/python/server/validate_bearer.py | 101 ++
src/test/python/server/validate_reflect.py | 34 +
src/test/python/test_internals.py | 138 +++
src/test/python/test_pq3.py | 558 +++++++++++
src/test/python/tls.py | 195 ++++
19 files changed, 4178 insertions(+)
create mode 100644 src/test/python/.gitignore
create mode 100644 src/test/python/Makefile
create mode 100644 src/test/python/README
create mode 100644 src/test/python/client/__init__.py
create mode 100644 src/test/python/client/conftest.py
create mode 100644 src/test/python/client/test_client.py
create mode 100644 src/test/python/client/test_oauth.py
create mode 100644 src/test/python/pq3.py
create mode 100644 src/test/python/pytest.ini
create mode 100644 src/test/python/requirements.txt
create mode 100644 src/test/python/server/__init__.py
create mode 100644 src/test/python/server/conftest.py
create mode 100644 src/test/python/server/test_oauth.py
create mode 100644 src/test/python/server/test_server.py
create mode 100755 src/test/python/server/validate_bearer.py
create mode 100755 src/test/python/server/validate_reflect.py
create mode 100644 src/test/python/test_internals.py
create mode 100644 src/test/python/test_pq3.py
create mode 100644 src/test/python/tls.py
diff --git a/src/test/python/.gitignore b/src/test/python/.gitignore
new file mode 100644
index 0000000000..0e8f027b2e
--- /dev/null
+++ b/src/test/python/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+/venv/
diff --git a/src/test/python/Makefile b/src/test/python/Makefile
new file mode 100644
index 0000000000..b0695b6287
--- /dev/null
+++ b/src/test/python/Makefile
@@ -0,0 +1,38 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+# Only Python 3 is supported, but if it's named something different on your
+# system you can override it with the PYTHON3 variable.
+PYTHON3 := python3
+
+# All dependencies are placed into this directory. The default is .gitignored
+# for you, but you can override it if you'd like.
+VENV := ./venv
+
+override VBIN := $(VENV)/bin
+override PIP := $(VBIN)/pip
+override PYTEST := $(VBIN)/py.test
+override ISORT := $(VBIN)/isort
+override BLACK := $(VBIN)/black
+
+.PHONY: installcheck indent
+
+installcheck: $(PYTEST)
+ $(PYTEST) -v -rs
+
+indent: $(ISORT) $(BLACK)
+ $(ISORT) --profile black *.py client/*.py server/*.py
+ $(BLACK) *.py client/*.py server/*.py
+
+$(PYTEST) $(ISORT) $(BLACK) &: requirements.txt | $(PIP)
+ $(PIP) install --force-reinstall -r $<
+
+$(PIP):
+ $(PYTHON3) -m venv $(VENV)
+
+# A convenience recipe to rebuild psycopg2 against the local libpq.
+.PHONY: rebuild-psycopg2
+rebuild-psycopg2: | $(PIP)
+ $(PIP) install --force-reinstall --no-binary :all: $(shell grep psycopg2 requirements.txt)
diff --git a/src/test/python/README b/src/test/python/README
new file mode 100644
index 0000000000..0bda582c4b
--- /dev/null
+++ b/src/test/python/README
@@ -0,0 +1,54 @@
+A test suite for exercising both the libpq client and the server backend at the
+protocol level, based on pytest and Construct.
+
+The test suite currently assumes that the standard PG* environment variables
+point to the database under test and are sufficient to log in a superuser on
+that system. In other words, a bare `psql` needs to Just Work before the test
+suite can do its thing. For a newly built dev cluster, typically all that I need
+to do is a
+
+ export PGDATABASE=postgres
+
+but you can adjust as needed for your setup.
+
+## Requirements
+
+A supported version (3.6+) of Python.
+
+The first run of
+
+ make installcheck
+
+will install a local virtual environment and all needed dependencies. During
+development, if libpq changes incompatibly, you can issue
+
+ $ make rebuild-psycopg2
+
+to force a rebuild of the client library.
+
+## Hacking
+
+The code style is enforced by a _very_ opinionated autoformatter. Running the
+
+ make indent
+
+recipe will invoke it for you automatically. Don't fight the tool; part of the
+zen is in knowing that if the formatter makes your code ugly, there's probably a
+cleaner way to write your code.
+
+## Advanced Usage
+
+The Makefile is there for convenience, but you don't have to use it. Activate
+the virtualenv to be able to use pytest directly:
+
+ $ source venv/bin/activate
+ $ py.test -k oauth
+ ...
+ $ py.test ./server/test_server.py
+ ...
+ $ deactivate # puts the PATH et al back the way it was before
+
+To make quick smoke tests possible, slow tests have been marked explicitly. You
+can skip them by saying e.g.
+
+ $ py.test -m 'not slow'
diff --git a/src/test/python/client/__init__.py b/src/test/python/client/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/client/conftest.py b/src/test/python/client/conftest.py
new file mode 100644
index 0000000000..f38da7a138
--- /dev/null
+++ b/src/test/python/client/conftest.py
@@ -0,0 +1,126 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import socket
+import sys
+import threading
+
+import psycopg2
+import pytest
+
+import pq3
+
+BLOCKING_TIMEOUT = 2 # the number of seconds to wait for blocking calls
+
+
+@pytest.fixture
+def server_socket(unused_tcp_port_factory):
+ """
+ Returns a listening socket bound to an ephemeral port.
+ """
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("127.0.0.1", unused_tcp_port_factory()))
+ s.listen(1)
+ s.settimeout(BLOCKING_TIMEOUT)
+ yield s
+
+
+class ClientHandshake(threading.Thread):
+ """
+ A thread that connects to a local Postgres server using psycopg2. Once the
+ opening handshake completes, the connection will be immediately closed.
+ """
+
+ def __init__(self, *, port, **kwargs):
+ super().__init__()
+
+ kwargs["port"] = port
+ self._kwargs = kwargs
+
+ self.exception = None
+
+ def run(self):
+ try:
+ conn = psycopg2.connect(host="127.0.0.1", **self._kwargs)
+ conn.close()
+ except Exception as e:
+ self.exception = e
+
+ def check_completed(self, timeout=BLOCKING_TIMEOUT):
+ """
+ Joins the client thread. Raises an exception if the thread could not be
+ joined, or if it threw an exception itself. (The exception will be
+ cleared, so future calls to check_completed will succeed.)
+ """
+ self.join(timeout)
+
+ if self.is_alive():
+ raise TimeoutError("client thread did not handshake within the timeout")
+ elif self.exception:
+ e = self.exception
+ self.exception = None
+ raise e
+
+
+@pytest.fixture
+def accept(server_socket):
+ """
+ Returns a factory function that, when called, returns a pair (sock, client)
+ where sock is a server socket that has accepted a connection from client,
+ and client is an instance of ClientHandshake. Clients will complete their
+ handshakes and cleanly disconnect.
+
+ The default connstring options may be extended or overridden by passing
+ arbitrary keyword arguments. Keep in mind that you generally should not
+ override the host or port, since they point to the local test server.
+
+ For situations where a client needs to connect more than once to complete a
+ handshake, the accept function may be called more than once. (The client
+ returned for subsequent calls will always be the same client that was
+ returned for the first call.)
+
+ Tests must either complete the handshake so that the client thread can be
+ automatically joined during teardown, or else call client.check_completed()
+ and manually handle any expected errors.
+ """
+ _, port = server_socket.getsockname()
+
+ client = None
+ default_opts = dict(
+ port=port,
+ user=pq3.pguser(),
+ sslmode="disable",
+ )
+
+ def factory(**kwargs):
+ nonlocal client
+
+ if client is None:
+ opts = dict(default_opts)
+ opts.update(kwargs)
+
+ # The server_socket is already listening, so the client thread can
+ # be safely started; it'll block on the connection until we accept.
+ client = ClientHandshake(**opts)
+ client.start()
+
+ sock, _ = server_socket.accept()
+ return sock, client
+
+ yield factory
+ client.check_completed()
+
+
+@pytest.fixture
+def conn(accept):
+ """
+ Returns an accepted, wrapped pq3 connection to a psycopg2 client. The socket
+ will be closed when the test finishes, and the client will be checked for a
+ cleanly completed handshake.
+ """
+ sock, client = accept()
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ yield conn
diff --git a/src/test/python/client/test_client.py b/src/test/python/client/test_client.py
new file mode 100644
index 0000000000..c4c946fda4
--- /dev/null
+++ b/src/test/python/client/test_client.py
@@ -0,0 +1,180 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import sys
+
+import psycopg2
+import pytest
+from cryptography.hazmat.primitives import hashes, hmac
+
+import pq3
+
+
+def finish_handshake(conn):
+ """
+ Sends the AuthenticationOK message and the standard opening salvo of server
+ messages, then asserts that the client immediately sends a Terminate message
+ to close the connection cleanly.
+ """
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+ pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.Terminate
+
+
+def test_handshake(conn):
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ finish_handshake(conn)
+
+
+def test_aborted_connection(accept):
+ """
+ Make sure the client correctly reports an early close during handshakes.
+ """
+ sock, client = accept()
+ sock.close()
+
+ expected = "server closed the connection unexpectedly"
+ with pytest.raises(psycopg2.OperationalError, match=expected):
+ client.check_completed()
+
+
+#
+# SCRAM-SHA-256 (see RFC 5802: https://tools.ietf.org/html/rfc5802)
+#
+
+
+@pytest.fixture
+def password():
+ """
+ Returns a password for use by both client and server.
+ """
+ # TODO: parameterize this with passwords that require SASLprep.
+ return "secret"
+
+
+@pytest.fixture
+def pwconn(accept, password):
+ """
+ Like the conn fixture, but uses a password in the connection.
+ """
+ sock, client = accept(password=password)
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ yield conn
+
+
+def sha256(data):
+ """The H(str) function from Section 2.2."""
+ digest = hashes.Hash(hashes.SHA256())
+ digest.update(data)
+ return digest.finalize()
+
+
+def hmac_256(key, data):
+ """The HMAC(key, str) function from Section 2.2."""
+ h = hmac.HMAC(key, hashes.SHA256())
+ h.update(data)
+ return h.finalize()
+
+
+def xor(a, b):
+ """The XOR operation from Section 2.2."""
+ res = bytearray(a)
+ for i, byte in enumerate(b):
+ res[i] ^= byte
+ return bytes(res)
+
+
+def h_i(data, salt, i):
+ """The Hi(str, salt, i) function from Section 2.2."""
+ assert i > 0
+
+ acc = hmac_256(data, salt + b"\x00\x00\x00\x01")
+ last = acc
+ i -= 1
+
+ while i:
+ u = hmac_256(data, last)
+ acc = xor(acc, u)
+
+ last = u
+ i -= 1
+
+ return acc
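(As RFC 5802 notes, the Hi() function implemented above is simply PBKDF2 with
HMAC-SHA-256 as the PRF and a full-length output block. Here is a standalone
cross-check against the stdlib, independent of the cryptography-based helpers
in the patch, that transcribes the same U1/Un/XOR recurrence:)

```python
import hashlib
import hmac


def h_i_manual(data, salt, i):
    # RFC 5802 Section 2.2:
    #   U1 = HMAC(str, salt + INT(1)); Un = HMAC(str, U(n-1))
    #   Hi = U1 XOR U2 XOR ... XOR Ui
    u = hmac.new(data, salt + b"\x00\x00\x00\x01", hashlib.sha256).digest()
    acc = u
    for _ in range(i - 1):
        u = hmac.new(data, u, hashlib.sha256).digest()
        acc = bytes(a ^ b for a, b in zip(acc, u))
    return acc


# Hi(str, salt, i) is PBKDF2-HMAC-SHA-256 with a 32-byte output.
assert h_i_manual(b"secret", b"12345", 2) == hashlib.pbkdf2_hmac(
    "sha256", b"secret", b"12345", 2
)
```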
+
+
+def test_scram(pwconn, password):
+ startup = pq3.recv1(pwconn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ pwconn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"SCRAM-SHA-256", b""],
+ )
+
+ # Get the client-first-message.
+ pkt = pq3.recv1(pwconn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ initial = pq3.SASLInitialResponse.parse(pkt.payload)
+ assert initial.name == b"SCRAM-SHA-256"
+
+ c_bind, authzid, c_name, c_nonce = initial.data.split(b",")
+ assert c_bind == b"n" # no channel bindings on a plaintext connection
+ assert authzid == b"" # we don't support authzid currently
+ assert c_name == b"n=" # libpq sends an empty SCRAM username (startup name is used)
+ assert c_nonce.startswith(b"r=")
+
+ # Send the server-first-message.
+ salt = b"12345"
+ iterations = 2
+
+ s_nonce = c_nonce + b"somenonce"
+ s_salt = b"s=" + base64.b64encode(salt)
+ s_iterations = b"i=%d" % iterations
+
+ msg = b",".join([s_nonce, s_salt, s_iterations])
+ pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=msg)
+
+ # Get the client-final-message.
+ pkt = pq3.recv1(pwconn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ c_bind_final, c_nonce_final, c_proof = pkt.payload.split(b",")
+ assert c_bind_final == b"c=" + base64.b64encode(c_bind + b"," + authzid + b",")
+ assert c_nonce_final == s_nonce
+
+ # Calculate what the client proof should be.
+ salted_password = h_i(password.encode("ascii"), salt, iterations)
+ client_key = hmac_256(salted_password, b"Client Key")
+ stored_key = sha256(client_key)
+
+ auth_message = b",".join(
+ [c_name, c_nonce, s_nonce, s_salt, s_iterations, c_bind_final, c_nonce_final]
+ )
+ client_signature = hmac_256(stored_key, auth_message)
+ client_proof = xor(client_key, client_signature)
+
+ expected = b"p=" + base64.b64encode(client_proof)
+ assert c_proof == expected
+
+ # Send the correct server signature.
+ server_key = hmac_256(salted_password, b"Server Key")
+ server_signature = hmac_256(server_key, auth_message)
+
+ s_verify = b"v=" + base64.b64encode(server_signature)
+ pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal, body=s_verify)
+
+ # Done!
+ finish_handshake(pwconn)
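(test_scram above recomputes the client proof the way a client would; for
reference, the server-side check it stands in for runs the same math in the
other direction, recovering ClientKey from the proof. A stdlib-only sketch,
with names following RFC 5802 Section 3 and all inputs invented for
illustration:)

```python
import hashlib
import hmac


def verify_client_proof(stored_key, auth_message, client_proof):
    # ClientSignature := HMAC(StoredKey, AuthMessage)
    signature = hmac.new(stored_key, auth_message, hashlib.sha256).digest()
    # ClientKey := ClientProof XOR ClientSignature
    client_key = bytes(p ^ s for p, s in zip(client_proof, signature))
    # Authentication succeeds iff H(ClientKey) == StoredKey
    return hashlib.sha256(client_key).digest() == stored_key


# Round-trip: derive the keys the way the client does, then verify.
salted = hashlib.pbkdf2_hmac("sha256", b"secret", b"12345", 2)
client_key = hmac.new(salted, b"Client Key", hashlib.sha256).digest()
stored_key = hashlib.sha256(client_key).digest()
auth_message = b"n=,r=abc,r=abcdef,s=MTIzNDU=,i=2,c=biws,r=abcdef"
signature = hmac.new(stored_key, auth_message, hashlib.sha256).digest()
proof = bytes(k ^ s for k, s in zip(client_key, signature))
assert verify_client_proof(stored_key, auth_message, proof)
```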
diff --git a/src/test/python/client/test_oauth.py b/src/test/python/client/test_oauth.py
new file mode 100644
index 0000000000..a754a9c0b6
--- /dev/null
+++ b/src/test/python/client/test_oauth.py
@@ -0,0 +1,936 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import http.server
+import json
+import secrets
+import sys
+import threading
+import time
+import urllib.parse
+
+import psycopg2
+import pytest
+
+import pq3
+
+from .conftest import BLOCKING_TIMEOUT
+
+
+def finish_handshake(conn):
+ """
+ Sends the AuthenticationOK message and the standard opening salvo of server
+ messages, then asserts that the client immediately sends a Terminate message
+ to close the connection cleanly.
+ """
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+ pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.Terminate
+
+
+#
+# OAUTHBEARER (see RFC 7628: https://tools.ietf.org/html/rfc7628)
+#
+
+
+def start_oauth_handshake(conn):
+ """
+ Negotiates an OAUTHBEARER SASL challenge. Returns the client's initial
+ response data.
+ """
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn, pq3.types.AuthnRequest, type=pq3.authn.SASL, body=[b"OAUTHBEARER", b""]
+ )
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ initial = pq3.SASLInitialResponse.parse(pkt.payload)
+ assert initial.name == b"OAUTHBEARER"
+
+ return initial.data
+
+
+def get_auth_value(initial):
+ """
+ Finds the auth value (e.g. "Bearer somedata...") in the client's initial
+ SASL response.
+ """
+ kvpairs = initial.split(b"\x01")
+ assert kvpairs[0] == b"n,," # no channel binding or authzid
+ assert kvpairs[2] == b"" # ends with an empty kvpair
+ assert kvpairs[3] == b"" # ...and there's nothing after it
+ assert len(kvpairs) == 4
+
+ key, value = kvpairs[1].split(b"=", 1)
+ assert key == b"auth"
+
+ return value
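(The message being unpacked here is the OAUTHBEARER initial client response
from RFC 7628 Section 3.1: a GS2 header, ^A-separated key/value pairs, and a
terminating empty pair. A sketch of the inverse, building the message the way
a client would, which is why get_auth_value expects exactly four fields:)

```python
def build_oauthbearer_initial(token):
    # gs2-header "n,," (no channel binding, no authzid), the auth kvpair,
    # then the terminating empty kvpair (hence the trailing \x01\x01).
    return b"n,,\x01auth=Bearer " + token + b"\x01\x01"


msg = build_oauthbearer_initial(b"sometoken")
assert msg.split(b"\x01") == [b"n,,", b"auth=Bearer sometoken", b"", b""]
```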
+
+
+def xtest_oauth_success(conn): # TODO
+ initial = start_oauth_handshake(conn)
+
+ auth = get_auth_value(initial)
+ assert auth.startswith(b"Bearer ")
+
+ # Accept the token. TODO actually validate
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+class OpenIDProvider(threading.Thread):
+ """
+ A thread that runs a mock OpenID provider server.
+ """
+
+ def __init__(self, *, port):
+ super().__init__()
+
+ self.exception = None
+
+ addr = ("", port)
+ self.server = self._Server(addr, self._Handler)
+
+ # TODO: allow HTTPS only, somehow
+ oauth = self._OAuthState()
+ oauth.host = f"localhost:{port}"
+ oauth.issuer = f"http://localhost:{port}"
+
+ # The following endpoints are required to be advertised by providers,
+ # even though our chosen client implementation does not actually make
+ # use of them.
+ oauth.register_endpoint(
+ "authorization_endpoint", "POST", "/authorize", self._authorization_handler
+ )
+ oauth.register_endpoint("jwks_uri", "GET", "/keys", self._jwks_handler)
+
+ self.server.oauth = oauth
+
+ def run(self):
+ try:
+ self.server.serve_forever()
+ except Exception as e:
+ self.exception = e
+
+ def stop(self, timeout=BLOCKING_TIMEOUT):
+ """
+ Shuts down the server and joins its thread. Raises an exception if the
+ thread could not be joined, or if it threw an exception itself. Must
+ only be called once, after start().
+ """
+ self.server.shutdown()
+ self.join(timeout)
+
+ if self.is_alive():
+ raise TimeoutError("OAuth provider thread did not shut down within the timeout")
+ elif self.exception:
+ e = self.exception
+ raise e
+
+ class _OAuthState(object):
+ def __init__(self):
+ self.endpoint_paths = {}
+ self._endpoints = {}
+
+ def register_endpoint(self, name, method, path, func):
+ if method not in self._endpoints:
+ self._endpoints[method] = {}
+
+ self._endpoints[method][path] = func
+ self.endpoint_paths[name] = path
+
+ def endpoint(self, method, path):
+ if method not in self._endpoints:
+ return None
+
+ return self._endpoints[method].get(path)
+
+ class _Server(http.server.HTTPServer):
+ def handle_error(self, request, addr):
+ self.shutdown_request(request)
+ raise
+
+ @staticmethod
+ def _jwks_handler(headers, params):
+ return 200, {"keys": []}
+
+ @staticmethod
+ def _authorization_handler(headers, params):
+ # We don't actually want this to be called during these tests -- we
+ # should be using the device authorization endpoint instead.
+ assert (
+ False
+ ), "authorization handler called instead of device authorization handler"
+
+ class _Handler(http.server.BaseHTTPRequestHandler):
+ timeout = BLOCKING_TIMEOUT
+
+ def _discovery_handler(self, headers, params):
+ oauth = self.server.oauth
+
+ doc = {
+ "issuer": oauth.issuer,
+ "response_types_supported": ["token"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+
+ for name, path in oauth.endpoint_paths.items():
+ doc[name] = oauth.issuer + path
+
+ return 200, doc
+
+ def _handle(self, *, params=None, handler=None):
+ oauth = self.server.oauth
+ assert self.headers["Host"] == oauth.host
+
+ if handler is None:
+ handler = oauth.endpoint(self.command, self.path)
+ assert (
+ handler is not None
+ ), f"no registered endpoint for {self.command} {self.path}"
+
+ code, resp = handler(self.headers, params)
+
+ self.send_response(code)
+ self.send_header("Content-Type", "application/json")
+ self.end_headers()
+
+ resp = json.dumps(resp)
+ resp = resp.encode("utf-8")
+ self.wfile.write(resp)
+
+ self.close_connection = True
+
+ def do_GET(self):
+ if self.path == "/.well-known/openid-configuration":
+ self._handle(handler=self._discovery_handler)
+ return
+
+ self._handle()
+
+ def _request_body(self):
+ length = self.headers["Content-Length"]
+
+ # Handle only an explicit content-length.
+ assert length is not None
+ length = int(length)
+
+ return self.rfile.read(length).decode("utf-8")
+
+ def do_POST(self):
+ assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
+
+ body = self._request_body()
+ params = urllib.parse.parse_qs(body)
+
+ self._handle(params=params)
+
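(The _discovery_handler above serves a minimal OpenID Provider Metadata
document, per OpenID Connect Discovery 1.0. Outside the mock, such a document
looks roughly like the sketch below, with endpoint URLs rooted at the issuer;
the concrete issuer URL here is invented, since the tests bind an ephemeral
port:)

```python
import json

issuer = "http://localhost:8000"  # hypothetical; the mock uses an unused TCP port
doc = {
    "issuer": issuer,
    "authorization_endpoint": issuer + "/authorize",
    "token_endpoint": issuer + "/token",
    "device_authorization_endpoint": issuer + "/device",
    "jwks_uri": issuer + "/keys",
    "response_types_supported": ["token"],
    "subject_types_supported": ["public"],
    "id_token_signing_alg_values_supported": ["RS256"],
}

# Every advertised endpoint must live under the issuer URL.
assert all(
    v.startswith(issuer)
    for k, v in doc.items()
    if k.endswith(("_endpoint", "_uri"))
)
body = json.dumps(doc).encode("utf-8")  # what the mock writes to the wire
```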
+
+@pytest.fixture
+def openid_provider(unused_tcp_port_factory):
+ """
+ A fixture that returns the OAuth state of a running OpenID provider server. The
+ server will be stopped when the fixture is torn down.
+ """
+ thread = OpenIDProvider(port=unused_tcp_port_factory())
+ thread.start()
+
+ try:
+ yield thread.server.oauth
+ finally:
+ thread.stop()
+
+
+@pytest.mark.parametrize("secret", [None, "", "hunter2"])
+@pytest.mark.parametrize("scope", [None, "", "openid email"])
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_with_explicit_issuer(
+ capfd, accept, openid_provider, retries, scope, secret
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ oauth_client_secret=secret,
+ oauth_scope=scope,
+ )
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+ verification_url = "https://example.com/device"
+
+ access_token = secrets.token_urlsafe()
+
+ def check_client_authn(headers, params):
+ if not secret:
+ assert params["client_id"] == [client_id]
+ return
+
+ # Require the client to use Basic authn; request-body credentials are
+ # NOT RECOMMENDED (RFC 6749, Sec. 2.3.1).
+ assert "Authorization" in headers
+
+ method, creds = headers["Authorization"].split()
+ assert method == "Basic"
+
+ expected = f"{client_id}:{secret}"
+ assert base64.b64decode(creds) == expected.encode("ascii")
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ check_client_authn(headers, params)
+
+ if scope:
+ assert params["scope"] == [scope]
+ else:
+ assert "scope" not in params
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": verification_url,
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ attempts = 0
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ check_client_authn(headers, params)
+
+ assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+ assert params["device_code"] == [device_code]
+
+
+ with retry_lock:
+ nonlocal attempts
+
+ # If the test wants to force the client to retry, return an
+ # authorization_pending response and decrement the retry count.
+ if attempts < retries:
+ attempts += 1
+ return 400, {"error": "authorization_pending"}
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+ if retries:
+ # Finally, make sure that the client prompted the user with the expected
+ # authorization URL and user code.
+ expected = f"Visit {verification_url} and enter the code: {user_code}"
+ _, stderr = capfd.readouterr()
+ assert expected in stderr
+
+
+def test_oauth_requires_client_id(accept, openid_provider):
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ # Do not set a client ID; this should cause a client error after the
+ # server asks for OAUTHBEARER and the client tries to contact the
+ # issuer.
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should disconnect at this point.
+ assert not conn.read()
+
+ expected_error = "no oauth_client_id is set"
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("error_code", ["authorization_pending", "slow_down"])
+@pytest.mark.parametrize("retries", [1, 2])
+def test_oauth_retry_interval(accept, openid_provider, retries, error_code):
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id="some-id",
+ )
+
+ expected_retry_interval = 1
+ access_token = secrets.token_urlsafe()
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ resp = {
+ "device_code": "my-device-code",
+ "user_code": "my-user-code",
+ "interval": expected_retry_interval,
+ "verification_uri": "https://example.com",
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ attempts = 0
+ last_retry = None
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ now = time.monotonic()
+
+ with retry_lock:
+ nonlocal attempts, last_retry, expected_retry_interval
+
+ # Make sure the retry interval is being respected by the client.
+ if last_retry is not None:
+ interval = now - last_retry
+ assert interval >= expected_retry_interval
+
+ last_retry = now
+
+ # If the test wants to force the client to retry, return the desired
+ # error response and decrement the retry count.
+ if attempts < retries:
+ attempts += 1
+
+ # A slow_down code requires the client to additionally increase
+ # its interval by five seconds.
+ if error_code == "slow_down":
+ expected_retry_interval += 5
+
+ return 400, {"error": error_code}
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+@pytest.mark.parametrize(
+ "failure_mode, error_pattern",
+ [
+ pytest.param(
+ {
+ "error": "invalid_client",
+ "error_description": "client authentication failed",
+ },
+ r"client authentication failed \(invalid_client\)",
+ id="authentication failure with description",
+ ),
+ pytest.param(
+ {"error": "invalid_request"},
+ r"\(invalid_request\)",
+ id="invalid request without description",
+ ),
+ pytest.param(
+ {},
+ r"failed to obtain device authorization",
+ id="broken error response",
+ ),
+ ],
+)
+def test_oauth_device_authorization_failures(
+ accept, openid_provider, failure_mode, error_pattern
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ )
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ return 400, failure_mode
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ def token_endpoint(headers, params):
+ assert False, "token endpoint was invoked unexpectedly"
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should not continue the connection due to the hardcoded
+ # provider failure; we disconnect here.
+
+ # Now make sure the client correctly failed.
+ with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+ client.check_completed()
+
+
+@pytest.mark.parametrize(
+ "failure_mode, error_pattern",
+ [
+ pytest.param(
+ {
+ "error": "expired_token",
+ "error_description": "the device code has expired",
+ },
+ r"the device code has expired \(expired_token\)",
+ id="expired token with description",
+ ),
+ pytest.param(
+ {"error": "access_denied"},
+ r"\(access_denied\)",
+ id="access denied without description",
+ ),
+ pytest.param(
+ {},
+ r"OAuth token retrieval failed",
+ id="broken error response",
+ ),
+ ],
+)
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_token_failures(
+ accept, openid_provider, retries, failure_mode, error_pattern
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ )
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ assert params["client_id"] == [client_id]
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": "https://example.com/device",
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ with retry_lock:
+ nonlocal retries
+
+ # If the test wants to force the client to retry, return an
+ # authorization_pending response and decrement the retry count.
+ if retries > 0:
+ retries -= 1
+ return 400, {"error": "authorization_pending"}
+
+ return 400, failure_mode
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should not continue the connection due to the hardcoded
+ # provider failure; we disconnect here.
+
+ # Now make sure the client correctly failed.
+ with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+ client.check_completed()
+
+
+@pytest.mark.parametrize("scope", [None, "openid email"])
+@pytest.mark.parametrize(
+ "base_response",
+ [
+ {"status": "invalid_token"},
+ {"extra_object": {"key": "value"}, "status": "invalid_token"},
+ {"extra_object": {"status": 1}, "status": "invalid_token"},
+ ],
+)
+def test_oauth_discovery(accept, openid_provider, base_response, scope):
+ sock, client = accept(oauth_client_id=secrets.token_hex())
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+ verification_url = "https://example.com/device"
+
+ access_token = secrets.token_urlsafe()
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ if scope:
+ assert params["scope"] == [scope]
+ else:
+ assert "scope" not in params
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": verification_url,
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ def token_endpoint(headers, params):
+ assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+ assert params["device_code"] == [device_code]
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # For discovery, the client should send an empty auth header. See
+ # RFC 7628, Sec. 4.3.
+ auth = get_auth_value(initial)
+ assert auth == b""
+
+ # We will fail the first SASL exchange. First return a link to the
+ # discovery document, pointing to the test provider server.
+ resp = dict(base_response)
+
+ discovery_uri = f"{openid_provider.issuer}/.well-known/openid-configuration"
+ resp["openid-configuration"] = discovery_uri
+
+ if scope:
+ resp["scope"] = scope
+
+ resp = json.dumps(resp)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASLContinue,
+ body=resp.encode("ascii"),
+ )
+
+ # Per RFC, the client is required to send a dummy ^A response.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+ assert pkt.payload == b"\x01"
+
+ # Now fail the SASL exchange.
+ pq3.send(
+ conn,
+ pq3.types.ErrorResponse,
+ fields=[
+ b"SFATAL",
+ b"C28000",
+ b"Mdoesn't matter",
+ b"",
+ ],
+ )
+
+ # The client will connect to us a second time, using the parameters we sent
+ # it.
+ sock, _ = accept()
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+@pytest.mark.parametrize(
+ "response,expected_error",
+ [
+ pytest.param(
+ "abcde",
+ 'Token "abcde" is invalid',
+ id="bad JSON: invalid syntax",
+ ),
+ pytest.param(
+ '"abcde"',
+ "top-level element must be an object",
+ id="bad JSON: top-level element is a string",
+ ),
+ pytest.param(
+ "[]",
+ "top-level element must be an object",
+ id="bad JSON: top-level element is an array",
+ ),
+ pytest.param(
+ "{}",
+ "server sent error response without a status",
+ id="bad JSON: no status member",
+ ),
+ pytest.param(
+ '{ "status": null }',
+ 'field "status" must be a string',
+ id="bad JSON: null status member",
+ ),
+ pytest.param(
+ '{ "status": 0 }',
+ 'field "status" must be a string',
+ id="bad JSON: int status member",
+ ),
+ pytest.param(
+ '{ "status": [ "bad" ] }',
+ 'field "status" must be a string',
+ id="bad JSON: array status member",
+ ),
+ pytest.param(
+ '{ "status": { "bad": "bad" } }',
+ 'field "status" must be a string',
+ id="bad JSON: object status member",
+ ),
+ pytest.param(
+ '{ "nested": { "status": "bad" } }',
+ "server sent error response without a status",
+ id="bad JSON: nested status",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token" ',
+ "The input string ended unexpectedly",
+ id="bad JSON: unterminated object",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token" } { }',
+ 'Expected end of input, but found "{"',
+ id="bad JSON: trailing data",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token", "openid-configuration": 1 }',
+ 'field "openid-configuration" must be a string',
+ id="bad JSON: int openid-configuration member",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token", "scope": 1 }',
+ 'field "scope" must be a string',
+ id="bad JSON: int scope member",
+ ),
+ ],
+)
+def test_oauth_discovery_server_error(accept, response, expected_error):
+ sock, client = accept(oauth_client_id=secrets.token_hex())
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # Fail the SASL exchange with an invalid JSON response.
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASLContinue,
+ body=response.encode("utf-8"),
+ )
+
+ # The client should disconnect, so the socket is closed here. (If
+ # the client doesn't disconnect, it will report a different error
+ # below and the test will fail.)
+
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
+
+
+@pytest.mark.parametrize(
+ "sasl_err,resp_type,resp_payload,expected_error",
+ [
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.ErrorResponse,
+ dict(
+ fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+ ),
+ "expected error message",
+ id="standard server error: invalid_request",
+ ),
+ pytest.param(
+ {"status": "invalid_token"},
+ pq3.types.ErrorResponse,
+ dict(
+ fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+ ),
+ "expected error message",
+ id="standard server error: invalid_token without discovery URI",
+ ),
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.AuthnRequest,
+ dict(type=pq3.authn.SASLContinue, body=b""),
+ "server sent additional OAuth data",
+ id="broken server: additional challenge after error",
+ ),
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.AuthnRequest,
+ dict(type=pq3.authn.SASLFinal),
+ "server sent additional OAuth data",
+ id="broken server: SASL success after error",
+ ),
+ ],
+)
+def test_oauth_server_error(accept, sasl_err, resp_type, resp_payload, expected_error):
+ sock, client = accept()
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ start_oauth_handshake(conn)
+
+ # Ignore the client data. Return an error "challenge".
+ resp = json.dumps(sasl_err)
+ resp = resp.encode("utf-8")
+
+ pq3.send(
+ conn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=resp
+ )
+
+ # Per RFC, the client is required to send a dummy ^A response.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+ assert pkt.payload == b"\x01"
+
+ # Now fail the SASL exchange (in either a valid way, or an invalid
+ # one, depending on the test).
+ pq3.send(conn, resp_type, **resp_payload)
+
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
diff --git a/src/test/python/pq3.py b/src/test/python/pq3.py
new file mode 100644
index 0000000000..3a22dad0b6
--- /dev/null
+++ b/src/test/python/pq3.py
@@ -0,0 +1,727 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import os
+import ssl
+import sys
+import textwrap
+
+from construct import *
+
+import tls
+
+
+def protocol(major, minor):
+ """
+ Returns the protocol version, in integer format, corresponding to the given
+ major and minor version numbers.
+ """
+ return (major << 16) | minor
+
+
+# Startup
+
+StringList = GreedyRange(NullTerminated(GreedyBytes))
+
+
+class KeyValueAdapter(Adapter):
+ """
+ Turns a key-value store into a null-terminated list of null-terminated
+ strings, as presented on the wire in the startup packet.
+ """
+
+ def _encode(self, obj, context, path):
+ if isinstance(obj, list):
+ return obj
+
+ l = []
+
+ for k, v in obj.items():
+ if isinstance(k, str):
+ k = k.encode("utf-8")
+ l.append(k)
+
+ if isinstance(v, str):
+ v = v.encode("utf-8")
+ l.append(v)
+
+ l.append(b"")
+ return l
+
+ def _decode(self, obj, context, path):
+ # TODO: turn a list back into a dict
+ return obj
+
+
+KeyValues = KeyValueAdapter(StringList)
+
+_startup_payload = Switch(
+ this.proto,
+ {
+ protocol(3, 0): KeyValues,
+ },
+ default=GreedyBytes,
+)
+
+
+def _default_protocol(this):
+ try:
+ if isinstance(this.payload, (list, dict)):
+ return protocol(3, 0)
+ except AttributeError:
+ pass # no payload passed during build
+
+ return 0
+
+
+def _startup_payload_len(this):
+ """
+ The payload field has a fixed size based on the length of the packet. But
+ if the caller hasn't supplied an explicit length at build time, we have to
+ build the payload to figure out how long it is, which requires us to know
+ the length first... This function exists solely to break the cycle.
+ """
+ assert this._building, "_startup_payload_len() cannot be called during parsing"
+
+ try:
+ payload = this.payload
+ except AttributeError:
+ return 0 # no payload
+
+ if isinstance(payload, bytes):
+ # already serialized; just use the given length
+ return len(payload)
+
+ try:
+ proto = this.proto
+ except AttributeError:
+ proto = _default_protocol(this)
+
+ data = _startup_payload.build(payload, proto=proto)
+ return len(data)
+
+
+Startup = Struct(
+ "len" / Default(Int32sb, lambda this: _startup_payload_len(this) + 8),
+ "proto" / Default(Hex(Int32sb), _default_protocol),
+ "payload" / FixedSized(this.len - 8, Default(_startup_payload, b"")),
+)
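(The Startup struct encodes the v3 startup packet: a 4-byte big-endian total
length that includes itself, a 4-byte protocol version, and null-terminated
key/value strings closed by a final empty string. A construct-free sketch of
the same framing, to make the byte layout concrete:)

```python
import struct


def build_startup(params, proto=(3 << 16)):
    # Each parameter is "key\0value\0"; the list ends with an empty string.
    payload = b""
    for k, v in params.items():
        payload += k.encode("utf-8") + b"\x00" + v.encode("utf-8") + b"\x00"
    payload += b"\x00"
    # Length field counts itself and the protocol field (8 bytes).
    return struct.pack("!ii", len(payload) + 8, proto) + payload


pkt = build_startup({"user": "alice"})
length, proto = struct.unpack("!ii", pkt[:8])
assert length == len(pkt)
assert proto == (3 << 16)  # protocol(3, 0)
assert pkt[8:] == b"user\x00alice\x00\x00"
```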
+
+# Pq3
+
+# Adapted from construct.core.EnumIntegerString
+class EnumNamedByte:
+ def __init__(self, val, name):
+ self._val = val
+ self._name = name
+
+ def __int__(self):
+ return ord(self._val)
+
+ def __str__(self):
+ return "(enum) %s %r" % (self._name, self._val)
+
+ def __repr__(self):
+ return "EnumNamedByte(%r)" % self._val
+
+ def __eq__(self, other):
+ if isinstance(other, EnumNamedByte):
+ other = other._val
+ if not isinstance(other, bytes):
+ return NotImplemented
+
+ return self._val == other
+
+ def __hash__(self):
+ return hash(self._val)
+
+
+# Adapted from construct.core.Enum
+class ByteEnum(Adapter):
+ def __init__(self, **mapping):
+ super(ByteEnum, self).__init__(Byte)
+ self.namemapping = {k: EnumNamedByte(v, k) for k, v in mapping.items()}
+ self.decmapping = {v: EnumNamedByte(v, k) for k, v in mapping.items()}
+
+ def __getattr__(self, name):
+ if name in self.namemapping:
+ return self.decmapping[self.namemapping[name]]
+ raise AttributeError
+
+ def _decode(self, obj, context, path):
+ b = bytes([obj])
+ try:
+ return self.decmapping[b]
+ except KeyError:
+ return EnumNamedByte(b, "(unknown)")
+
+ def _encode(self, obj, context, path):
+ if isinstance(obj, int):
+ return obj
+ elif isinstance(obj, bytes):
+ return ord(obj)
+ return int(obj)
+
+
+types = ByteEnum(
+ ErrorResponse=b"E",
+ ReadyForQuery=b"Z",
+ Query=b"Q",
+ EmptyQueryResponse=b"I",
+ AuthnRequest=b"R",
+ PasswordMessage=b"p",
+ BackendKeyData=b"K",
+ CommandComplete=b"C",
+ ParameterStatus=b"S",
+ DataRow=b"D",
+ Terminate=b"X",
+)
+
+
+authn = Enum(
+ Int32ub,
+ OK=0,
+ SASL=10,
+ SASLContinue=11,
+ SASLFinal=12,
+)
+
+
+_authn_body = Switch(
+ this.type,
+ {
+ authn.OK: Terminated,
+ authn.SASL: StringList,
+ },
+ default=GreedyBytes,
+)
+
+
+def _data_len(this):
+ assert this._building, "_data_len() cannot be called during parsing"
+
+ if not hasattr(this, "data") or this.data is None:
+ return -1
+
+ return len(this.data)
+
+
+# The protocol reuses the PasswordMessage for several authentication response
+# types, and there's no good way to figure out which is which without keeping
+# state for the entire stream. So this is a separate Construct that can be
+# explicitly parsed/built by code that knows it's needed.
+SASLInitialResponse = Struct(
+ "name" / NullTerminated(GreedyBytes),
+ "len" / Default(Int32sb, lambda this: _data_len(this)),
+ "data"
+ / IfThenElse(
+ # Allow tests to explicitly pass an incorrect length, by not
+ # enforcing a FixedSized during build. (The len calculation above
+ # defaults to the correct size.)
+ this._building,
+ Optional(GreedyBytes),
+ If(this.len != -1, Default(FixedSized(this.len, GreedyBytes), b"")),
+ ),
+ Terminated, # make sure the entire response is consumed
+)
+
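
The SASLInitialResponse layout declared above is simple enough to frame by hand: the mechanism name is NUL-terminated, followed by a signed int32 length and the raw client response. A minimal sketch (helper name is ours):

```python
import struct

def build_sasl_initial(mech, data):
    # mechanism name (NUL-terminated), int32 data length, raw response
    return mech + b"\x00" + struct.pack("!i", len(data)) + data

msg = build_sasl_initial(b"OAUTHBEARER", b"n,,\x01auth=Bearer t\x01\x01")
```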
+
+_column = FocusedSeq(
+ "data",
+ "len" / Default(Int32sb, lambda this: _data_len(this)),
+ "data" / If(this.len != -1, FixedSized(this.len, GreedyBytes)),
+)
+
+
+_payload_map = {
+ types.ErrorResponse: Struct("fields" / StringList),
+ types.ReadyForQuery: Struct("status" / Bytes(1)),
+ types.Query: Struct("query" / NullTerminated(GreedyBytes)),
+ types.EmptyQueryResponse: Terminated,
+ types.AuthnRequest: Struct("type" / authn, "body" / Default(_authn_body, b"")),
+ types.BackendKeyData: Struct("pid" / Int32ub, "key" / Hex(Int32ub)),
+ types.CommandComplete: Struct("tag" / NullTerminated(GreedyBytes)),
+ types.ParameterStatus: Struct(
+ "name" / NullTerminated(GreedyBytes), "value" / NullTerminated(GreedyBytes)
+ ),
+ types.DataRow: Struct("columns" / Default(PrefixedArray(Int16sb, _column), b"")),
+ types.Terminate: Terminated,
+}
+
+
+_payload = FocusedSeq(
+ "_payload",
+ "_payload"
+ / Switch(
+ this._.type,
+ _payload_map,
+ default=GreedyBytes,
+ ),
+ Terminated, # make sure every payload consumes the entire packet
+)
+
+
+def _payload_len(this):
+ """
+ See _startup_payload_len() for an explanation.
+ """
+ assert this._building, "_payload_len() cannot be called during parsing"
+
+ try:
+ payload = this.payload
+ except AttributeError:
+ return 0 # no payload
+
+ if isinstance(payload, bytes):
+ # already serialized; just use the given length
+ return len(payload)
+
+ data = _payload.build(payload, type=this.type)
+ return len(data)
+
+
+Pq3 = Struct(
+ "type" / types,
+ "len" / Default(Int32ub, lambda this: _payload_len(this) + 4),
+ "payload" / FixedSized(this.len - 4, Default(_payload, b"")),
+)
+
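
The Pq3 struct mirrors the regular v3 message framing: one type byte, then an int32 length that counts itself and the payload but not the type byte. A stdlib sketch of the same framing (helper name is ours):

```python
import struct

def build_pq3(type_byte, payload):
    # the length field covers itself (4 bytes) plus the payload,
    # but not the leading type byte
    return type_byte + struct.pack("!I", len(payload) + 4) + payload

query = build_pq3(b"Q", b"SELECT 1;\x00")
```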
+
+# Environment
+
+
+def pghost():
+ return os.environ.get("PGHOST", default="localhost")
+
+
+def pgport():
+ return int(os.environ.get("PGPORT", default=5432))
+
+
+def pguser():
+ try:
+ return os.environ["PGUSER"]
+ except KeyError:
+ return getpass.getuser()
+
+
+def pgdatabase():
+ return os.environ.get("PGDATABASE", default="postgres")
+
+
+# Connections
+
+
+ def _hexdump_translation_map():
+ """
+ Builds a translation map for hexdumps: every unprintable or non-ASCII
+ byte is rendered as '.'.
+ """
+ unprintable = bytearray()
+
+ for i in range(128):
+ c = chr(i)
+
+ if not c.isprintable():
+ unprintable += bytes([i])
+
+ unprintable += bytes(range(128, 256))
+
+ return bytes.maketrans(unprintable, b"." * len(unprintable))
+
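
The translation-map idea boils down to a single bytes.maketrans() call; a condensed sketch of the same behavior:

```python
# Any byte that is non-ASCII or unprintable maps to '.', the rest pass through.
unprintable = bytes(i for i in range(256) if i >= 128 or not chr(i).isprintable())
table = bytes.maketrans(unprintable, b"." * len(unprintable))

assert b"\x00abc\xff".translate(table) == b".abc."
```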
+
+class _DebugStream(object):
+ """
+ Wraps a file-like object and adds hexdumps of the read and write data. Call
+ end_packet() on a _DebugStream to write the accumulated hexdumps to the
+ output stream, along with the packet that was sent.
+ """
+
+ _translation_map = _hexdump_translation_map()
+
+ def __init__(self, stream, out=sys.stdout):
+ """
+ Creates a new _DebugStream wrapping the given stream (which must have
+ been created by wrap()). All attributes not provided by the _DebugStream
+ are delegated to the wrapped stream. out is the text stream to which
+ hexdumps are written.
+ """
+ self.raw = stream
+ self._out = out
+ self._rbuf = io.BytesIO()
+ self._wbuf = io.BytesIO()
+
+ def __getattr__(self, name):
+ return getattr(self.raw, name)
+
+ def __setattr__(self, name, value):
+ if name in ("raw", "_out", "_rbuf", "_wbuf"):
+ return object.__setattr__(self, name, value)
+
+ setattr(self.raw, name, value)
+
+ def read(self, *args, **kwargs):
+ buf = self.raw.read(*args, **kwargs)
+
+ self._rbuf.write(buf)
+ return buf
+
+ def write(self, b):
+ self._wbuf.write(b)
+ return self.raw.write(b)
+
+ def recv(self, *args):
+ buf = self.raw.recv(*args)
+
+ self._rbuf.write(buf)
+ return buf
+
+ def _flush(self, buf, prefix):
+ width = 16
+ hexwidth = width * 3 - 1
+
+ count = 0
+ buf.seek(0)
+
+ while True:
+ line = buf.read(width)
+
+ if not line:
+ if count:
+ self._out.write("\n") # separate the output block with a newline
+ return
+
+ self._out.write("%s %04X:\t" % (prefix, count))
+ self._out.write("%*s\t" % (-hexwidth, line.hex(" ")))
+ self._out.write(line.translate(self._translation_map).decode("ascii"))
+ self._out.write("\n")
+
+ count += width
+
+ def print_debug(self, obj, *, prefix=""):
+ contents = ""
+ if obj is not None:
+ contents = str(obj)
+
+ for line in contents.splitlines():
+ self._out.write("%s%s\n" % (prefix, line))
+
+ self._out.write("\n")
+
+ def flush_debug(self, *, prefix=""):
+ self._flush(self._rbuf, prefix + "<")
+ self._rbuf = io.BytesIO()
+
+ self._flush(self._wbuf, prefix + ">")
+ self._wbuf = io.BytesIO()
+
+ def end_packet(self, pkt, *, read=False, prefix="", indent=" "):
+ """
+ Marks the end of a logical "packet" of data. A string representation of
+ pkt will be printed, and the debug buffers will be flushed with an
+ indent. All lines can be optionally prefixed.
+
+ If read is True, the packet representation is written after the debug
+ buffers; otherwise the default of False (meaning write) causes the
+ packet representation to be dumped first. This is meant to capture the
+ logical flow of layer translation.
+ """
+ write = not read
+
+ if write:
+ self.print_debug(pkt, prefix=prefix + "> ")
+
+ self.flush_debug(prefix=prefix + indent)
+
+ if read:
+ self.print_debug(pkt, prefix=prefix + "< ")
+
+
+@contextlib.contextmanager
+def wrap(socket, *, debug_stream=None):
+ """
+ Transforms a raw socket into a connection that can be used for Construct
+ building and parsing. The return value is a context manager and can be used
+ in a with statement.
+ """
+ # It is critical that buffering be disabled here, so that we can still
+ # manipulate the raw socket without desyncing the stream.
+ with socket.makefile("rwb", buffering=0) as sfile:
+ # Expose the original socket's recv() on the SocketIO object we return.
+ def recv(self, *args):
+ return socket.recv(*args)
+
+ sfile.recv = recv.__get__(sfile)
+
+ conn = sfile
+ if debug_stream:
+ conn = _DebugStream(conn, debug_stream)
+
+ try:
+ yield conn
+ finally:
+ if debug_stream:
+ conn.flush_debug(prefix="? ")
+
+
+def _send(stream, cls, obj):
+ debugging = hasattr(stream, "flush_debug")
+ out = io.BytesIO()
+
+ # Ideally we would build directly to the passed stream, but because we need
+ # to reparse the generated output for the debugging case, build to an
+ # intermediate BytesIO and send it instead.
+ cls.build_stream(obj, out)
+ buf = out.getvalue()
+
+ stream.write(buf)
+ if debugging:
+ pkt = cls.parse(buf)
+ stream.end_packet(pkt)
+
+ stream.flush()
+
+
+def send(stream, packet_type, payload_data=None, **payloadkw):
+ """
+ Sends a packet on the given pq3 connection. packet_type is the pq3.types
+ member that should be assigned to the packet. If payload_data is given, it
+ will be used as the packet payload; otherwise the key/value pairs in
+ payloadkw will be the payload contents.
+ """
+ data = payloadkw
+
+ if payload_data is not None:
+ if payloadkw:
+ raise ValueError(
+ "payload_data and payload keywords may not be used simultaneously"
+ )
+
+ data = payload_data
+
+ _send(stream, Pq3, dict(type=packet_type, payload=data))
+
+
+def send_startup(stream, proto=None, **kwargs):
+ """
+ Sends a startup packet on the given pq3 connection. In most cases you should
+ use the handshake functions instead, which will do this for you.
+
+ By default, a protocol version 3 packet will be sent. This can be overridden
+ with the proto parameter.
+ """
+ pkt = {}
+
+ if proto is not None:
+ pkt["proto"] = proto
+ if kwargs:
+ pkt["payload"] = kwargs
+
+ _send(stream, Startup, pkt)
+
+
+def recv1(stream, *, cls=Pq3):
+ """
+ Receives a single pq3 packet from the given stream and returns it.
+ """
+ resp = cls.parse_stream(stream)
+
+ debugging = hasattr(stream, "flush_debug")
+ if debugging:
+ stream.end_packet(resp, read=True)
+
+ return resp
+
+
+def handshake(stream, **kwargs):
+ """
+ Performs a libpq v3 startup handshake. kwargs should contain the key/value
+ parameters to send to the server in the startup packet.
+ """
+ # Send our startup parameters.
+ send_startup(stream, **kwargs)
+
+ # Receive and dump packets until the server indicates it's ready for our
+ # first query.
+ while True:
+ resp = recv1(stream)
+ if resp is None:
+ raise RuntimeError("server closed connection during handshake")
+
+ if resp.type == types.ReadyForQuery:
+ return
+ elif resp.type == types.ErrorResponse:
+ raise RuntimeError(
+ f"received error response from peer: {resp.payload.fields!r}"
+ )
+
+
+# TLS
+
+
+class _TLSStream(object):
+ """
+ A file-like object that performs TLS encryption/decryption on a wrapped
+ stream. Differs from ssl.SSLSocket in that we have full visibility and
+ control over the TLS layer.
+ """
+
+ def __init__(self, stream, context):
+ self._stream = stream
+ self._debugging = hasattr(stream, "flush_debug")
+
+ self._in = ssl.MemoryBIO()
+ self._out = ssl.MemoryBIO()
+ self._ssl = context.wrap_bio(self._in, self._out)
+
+ def handshake(self):
+ try:
+ self._pump(lambda: self._ssl.do_handshake())
+ finally:
+ self._flush_debug(prefix="? ")
+
+ def read(self, *args):
+ return self._pump(lambda: self._ssl.read(*args))
+
+ def write(self, *args):
+ return self._pump(lambda: self._ssl.write(*args))
+
+ def _decode(self, buf):
+ """
+ Attempts to decode a buffer of TLS data into a packet representation
+ that can be printed.
+
+ TODO: handle buffers (and record fragments) that don't align with packet
+ boundaries.
+ """
+ end = len(buf)
+ bio = io.BytesIO(buf)
+
+ ret = io.StringIO()
+
+ while bio.tell() < end:
+ record = tls.Plaintext.parse_stream(bio)
+
+ if ret.tell() > 0:
+ ret.write("\n")
+ ret.write("[Record] ")
+ ret.write(str(record))
+ ret.write("\n")
+
+ if record.type == tls.ContentType.handshake:
+ record_cls = tls.Handshake
+ else:
+ continue
+
+ innerlen = len(record.fragment)
+ inner = io.BytesIO(record.fragment)
+
+ while inner.tell() < innerlen:
+ msg = record_cls.parse_stream(inner)
+
+ indented = "[Message] " + str(msg)
+ indented = textwrap.indent(indented, " ")
+
+ ret.write("\n")
+ ret.write(indented)
+ ret.write("\n")
+
+ return ret.getvalue()
+
+ def flush(self):
+ if not self._out.pending:
+ self._stream.flush()
+ return
+
+ buf = self._out.read()
+ self._stream.write(buf)
+
+ if self._debugging:
+ pkt = self._decode(buf)
+ self._stream.end_packet(pkt, prefix=" ")
+
+ self._stream.flush()
+
+ def _pump(self, operation):
+ while True:
+ try:
+ return operation()
+ except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
+ want = e
+ self._read_write(want)
+
+ def _recv(self, maxsize):
+ buf = self._stream.recv(maxsize)
+ if not buf:
+ self._in.write_eof()
+ return
+
+ self._in.write(buf)
+
+ if not self._debugging:
+ return
+
+ pkt = self._decode(buf)
+ self._stream.end_packet(pkt, read=True, prefix=" ")
+
+ def _read_write(self, want):
+ # XXX This needs work. So many corner cases yet to handle. For one,
+ # doing blocking writes in flush may lead to distributed deadlock if the
+ # peer is already blocking on its writes.
+
+ if isinstance(want, ssl.SSLWantWriteError):
+ assert self._out.pending, "SSL backend wants write without data"
+
+ self.flush()
+
+ if isinstance(want, ssl.SSLWantReadError):
+ self._recv(4096)
+
+ def _flush_debug(self, prefix):
+ if not self._debugging:
+ return
+
+ self._stream.flush_debug(prefix=prefix)
+
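
The _pump() loop is the heart of the MemoryBIO approach: run the SSL operation, and whenever the backend reports that it is starved for input or has pending output, service the BIOs and retry. A toy sketch of that retry shape, with a fake backend standing in for the SSLObject (all names here are ours):

```python
import ssl

def pump(operation, on_want_read, on_want_write):
    # Retry an SSL operation, servicing the memory BIOs whenever the
    # backend raises a Want{Read,Write}Error.
    while True:
        try:
            return operation()
        except ssl.SSLWantReadError:
            on_want_read()
        except ssl.SSLWantWriteError:
            on_want_write()

# Simulate a backend that needs one "read" before it can finish.
state = {"fed": False}

def fake_op():
    if not state["fed"]:
        raise ssl.SSLWantReadError()
    return "handshake done"

result = pump(fake_op, lambda: state.update(fed=True), lambda: None)
```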
+
+@contextlib.contextmanager
+def tls_handshake(stream, context):
+ """
+ Performs a TLS handshake over the given stream (which must have been created
+ via a call to wrap()), and returns a new stream which transparently tunnels
+ data over the TLS connection.
+
+ If the passed stream has debugging enabled, the returned stream will also
+ have debugging, using the same output IO.
+ """
+ debugging = hasattr(stream, "flush_debug")
+
+ # Request TLS via an SSLRequest startup packet (protocol 1234,5679).
+ send_startup(stream, proto=protocol(1234, 5679))
+
+ # Look at the SSL response.
+ resp = stream.read(1)
+ if debugging:
+ stream.flush_debug(prefix=" ")
+
+ if resp == b"N":
+ raise RuntimeError("server does not support SSLRequest")
+ if resp != b"S":
+ raise RuntimeError(f"unexpected response of type {resp!r} during TLS startup")
+
+ tls = _TLSStream(stream, context)
+ tls.handshake()
+
+ if debugging:
+ tls = _DebugStream(tls, stream._out)
+
+ try:
+ yield tls
+ # TODO: teardown/unwrap the connection?
+ finally:
+ if debugging:
+ tls.flush_debug(prefix="? ")
diff --git a/src/test/python/pytest.ini b/src/test/python/pytest.ini
new file mode 100644
index 0000000000..ab7a6e7fb9
--- /dev/null
+++ b/src/test/python/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+
+markers =
+ slow: mark test as slow
diff --git a/src/test/python/requirements.txt b/src/test/python/requirements.txt
new file mode 100644
index 0000000000..32f105ea84
--- /dev/null
+++ b/src/test/python/requirements.txt
@@ -0,0 +1,7 @@
+black
+cryptography~=3.4.6
+construct~=2.10.61
+isort~=5.6
+psycopg2~=2.8.6
+pytest~=6.1
+pytest-asyncio~=0.14.0
diff --git a/src/test/python/server/__init__.py b/src/test/python/server/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/server/conftest.py b/src/test/python/server/conftest.py
new file mode 100644
index 0000000000..ba7342a453
--- /dev/null
+++ b/src/test/python/server/conftest.py
@@ -0,0 +1,45 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import socket
+import sys
+
+import pytest
+
+import pq3
+
+
+@pytest.fixture
+def connect():
+ """
+ A factory fixture that, when called, returns a socket connected to a
+ Postgres server, wrapped in a pq3 connection. The calling test will be
+ skipped automatically if a server is not running at PGHOST:PGPORT, so it's
+ best to connect as soon as possible after the test case begins, to avoid
+ doing unnecessary work.
+ """
+ # Set up an ExitStack to handle safe cleanup of all of the moving pieces.
+ with contextlib.ExitStack() as stack:
+
+ def conn_factory():
+ addr = (pq3.pghost(), pq3.pgport())
+
+ try:
+ sock = socket.create_connection(addr, timeout=2)
+ except ConnectionError as e:
+ pytest.skip(f"unable to connect to {addr}: {e}")
+
+ # Have ExitStack close our socket.
+ stack.enter_context(sock)
+
+ # Wrap the connection in a pq3 layer and have ExitStack clean it up
+ # too.
+ wrap_ctx = pq3.wrap(sock, debug_stream=sys.stdout)
+ conn = stack.enter_context(wrap_ctx)
+
+ return conn
+
+ yield conn_factory
diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py
new file mode 100644
index 0000000000..cb5ca7fa23
--- /dev/null
+++ b/src/test/python/server/test_oauth.py
@@ -0,0 +1,1012 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import contextlib
+import json
+import os
+import pathlib
+import secrets
+import shlex
+import shutil
+import socket
+import struct
+from multiprocessing import shared_memory
+
+import psycopg2
+import pytest
+from psycopg2 import sql
+
+import pq3
+
+MAX_SASL_MESSAGE_LENGTH = 65535
+
+INVALID_AUTHORIZATION_ERRCODE = b"28000"
+PROTOCOL_VIOLATION_ERRCODE = b"08P01"
+FEATURE_NOT_SUPPORTED_ERRCODE = b"0A000"
+
+SHARED_MEM_NAME = "oauth-pytest"
+MAX_TOKEN_SIZE = 4096
+MAX_UINT16 = 2 ** 16 - 1
+
+
+def skip_if_no_postgres():
+ """
+ Used by the oauth_ctx fixture to skip this test module if no Postgres server
+ is running.
+
+ This logic nearly duplicates that of the connect fixture. Ideally oauth_ctx
+ would depend on that fixture, but a module-scoped fixture can't depend on a
+ test-scoped one, and we haven't reached the rule of three yet.
+ """
+ addr = (pq3.pghost(), pq3.pgport())
+
+ try:
+ with socket.create_connection(addr, timeout=2):
+ pass
+ except ConnectionError as e:
+ pytest.skip(f"unable to connect to {addr}: {e}")
+
+
+@contextlib.contextmanager
+def prepend_file(path, lines):
+ """
+ A context manager that prepends a file on disk with the desired lines of
+ text. When the context manager is exited, the file will be restored to its
+ original contents.
+ """
+ # First make a backup of the original file.
+ bak = path + ".bak"
+ shutil.copy2(path, bak)
+
+ try:
+ # Write the new lines, followed by the original file content.
+ with open(path, "w") as new, open(bak, "r") as orig:
+ new.writelines(lines)
+ shutil.copyfileobj(orig, new)
+
+ # Return control to the calling code.
+ yield
+
+ finally:
+ # Put the backup back into place.
+ os.replace(bak, path)
+
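
The prepend_file manager above is self-contained enough to demonstrate in isolation; this runnable copy shows the prepend-then-restore lifecycle against a throwaway file:

```python
import contextlib
import os
import shutil
import tempfile

@contextlib.contextmanager
def prepend_file(path, lines):
    # Back up the file, rewrite it as new lines + original content,
    # and restore the backup when the context exits.
    bak = path + ".bak"
    shutil.copy2(path, bak)
    try:
        with open(path, "w") as new, open(bak, "r") as orig:
            new.writelines(lines)
            shutil.copyfileobj(orig, new)
        yield
    finally:
        os.replace(bak, path)

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "pg_hba.conf")
    with open(path, "w") as f:
        f.write("host all all samehost trust\n")

    with prepend_file(path, ["local all all trust\n"]):
        inside = open(path).read()   # prepended content is visible here

    after = open(path).read()        # original content is restored
```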
+
+@pytest.fixture(scope="module")
+def oauth_ctx():
+ """
+ Creates a database and user that use the oauth auth method. The context
+ object contains the dbname and user attributes as strings to be used during
+ connection, as well as the issuer and scope that have been set in the HBA
+ configuration.
+
+ This fixture assumes that the standard PG* environment variables point to a
+ server running on a local machine, and that the PGUSER has rights to create
+ databases and roles.
+ """
+ skip_if_no_postgres() # don't bother running these tests without a server
+
+ id = secrets.token_hex(4)
+
+ class Context:
+ dbname = "oauth_test_" + id
+
+ user = "oauth_user_" + id
+ map_user = "oauth_map_user_" + id
+ authz_user = "oauth_authz_user_" + id
+
+ issuer = "https://example.com/" + id
+ scope = "openid " + id
+
+ ctx = Context()
+ hba_lines = (
+ f'host {ctx.dbname} {ctx.map_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
+ f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
+ f'host {ctx.dbname} all samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
+ )
+ ident_lines = (r"oauth /^(.*)@example\.com$ \1",)
+
+ conn = psycopg2.connect("")
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+
+ # Create our roles and database.
+ user = sql.Identifier(ctx.user)
+ map_user = sql.Identifier(ctx.map_user)
+ authz_user = sql.Identifier(ctx.authz_user)
+ dbname = sql.Identifier(ctx.dbname)
+
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(user))
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(map_user))
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user))
+ c.execute(sql.SQL("CREATE DATABASE {};").format(dbname))
+
+ # Make this test script the server's oauth_validator.
+ path = pathlib.Path(__file__).parent / "validate_bearer.py"
+ path = str(path.absolute())
+
+ cmd = f"{shlex.quote(path)} {SHARED_MEM_NAME} <&%f"
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+
+ # Replace pg_hba and pg_ident.
+ c.execute("SHOW hba_file;")
+ hba = c.fetchone()[0]
+
+ c.execute("SHOW ident_file;")
+ ident = c.fetchone()[0]
+
+ with prepend_file(hba, hba_lines), prepend_file(ident, ident_lines):
+ c.execute("SELECT pg_reload_conf();")
+
+ # Use the new database and user.
+ yield ctx
+
+ # Put things back the way they were.
+ c.execute("SELECT pg_reload_conf();")
+
+ c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+ c.execute(sql.SQL("DROP DATABASE {};").format(dbname))
+ c.execute(sql.SQL("DROP ROLE {};").format(authz_user))
+ c.execute(sql.SQL("DROP ROLE {};").format(map_user))
+ c.execute(sql.SQL("DROP ROLE {};").format(user))
+
+
+@pytest.fixture()
+def conn(oauth_ctx, connect):
+ """
+ A convenience wrapper for connect(). The main purpose of this fixture is to
+ make sure oauth_ctx runs its setup code before the connection is made.
+ """
+ return connect()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def authn_id_extension(oauth_ctx):
+ """
+ Performs a `CREATE EXTENSION authn_id` in the test database. This fixture is
+ autoused, so tests don't need to request it explicitly.
+ """
+ conn = psycopg2.connect(database=oauth_ctx.dbname)
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+ c.execute("CREATE EXTENSION authn_id;")
+
+
+@pytest.fixture(scope="session")
+def shared_mem():
+ """
+ Yields a shared memory segment that can be used for communication between
+ the bearer_token fixture and ./validate_bearer.py.
+ """
+ size = MAX_TOKEN_SIZE + 2 # two byte length prefix
+ mem = shared_memory.SharedMemory(SHARED_MEM_NAME, create=True, size=size)
+
+ try:
+ with contextlib.closing(mem):
+ yield mem
+ finally:
+ mem.unlink()
+
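
The fixture and ./validate_bearer.py share a tiny protocol over this segment: a two-byte native-endian length prefix (with 0xFFFF as the "accept anything" magic value), followed by the token bytes. A sketch with a plain bytearray standing in for the shared segment (helper names are ours):

```python
import struct

MAX_UINT16 = 2 ** 16 - 1
buf = bytearray(4096 + 2)  # stands in for the shared memory buffer

def store_token(token):
    buf[:2] = struct.pack("H", len(token))
    buf[2 : 2 + len(token)] = token

def load_token():
    (n,) = struct.unpack("H", bytes(buf[:2]))
    if n == MAX_UINT16:
        return None  # magic value: accept any token
    return bytes(buf[2 : 2 + n])

store_token(b"sekrit")
```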
+
+@pytest.fixture()
+def bearer_token(shared_mem):
+ """
+ Returns a factory function that, when called, will store a Bearer token in
+ shared_mem. If token is None (the default), a new token will be generated
+ using secrets.token_urlsafe() and returned; otherwise the passed token will
+ be used as-is.
+
+ When token is None, the generated token size in bytes may be specified as an
+ argument; if unset, a small 16-byte token will be generated. The token size
+ may not exceed MAX_TOKEN_SIZE in any case.
+
+ The return value is the token, converted to a bytes object.
+
+ As a special case for testing failure modes, accept_any may be set to True.
+ This signals to the validator command that any bearer token should be
+ accepted. The returned token in this case may be used or discarded as needed
+ by the test.
+ """
+
+ def set_token(token=None, *, size=16, accept_any=False):
+ if token is not None:
+ size = len(token)
+
+ if size > MAX_TOKEN_SIZE:
+ raise ValueError(f"token size {size} exceeds maximum size {MAX_TOKEN_SIZE}")
+
+ if token is None:
+ if size % 4:
+ raise ValueError(f"requested token size {size} is not a multiple of 4")
+
+ token = secrets.token_urlsafe(size // 4 * 3)
+ assert len(token) == size
+
+ try:
+ token = token.encode("ascii")
+ except AttributeError:
+ pass # already encoded
+
+ if accept_any:
+ # Two-byte magic value.
+ shared_mem.buf[:2] = struct.pack("H", MAX_UINT16)
+ else:
+ # Two-byte length prefix, then the token data.
+ shared_mem.buf[:2] = struct.pack("H", len(token))
+ shared_mem.buf[2 : size + 2] = token
+
+ return token
+
+ return set_token
+
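
The size // 4 * 3 arithmetic above leans on secrets.token_urlsafe(n) producing exactly 4*n/3 characters of unpadded base64 whenever n is a multiple of 3, which is why the requested size must be a multiple of 4:

```python
import secrets

for size in (16, 1024, 4096):
    nbytes = size // 4 * 3              # token_urlsafe takes a byte count...
    token = secrets.token_urlsafe(nbytes)
    assert len(token) == size           # ...and emits 4 base64 chars per 3 bytes
```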
+
+def begin_oauth_handshake(conn, oauth_ctx, *, user=None):
+ if user is None:
+ user = oauth_ctx.authz_user
+
+ pq3.send_startup(conn, user=user, database=oauth_ctx.dbname)
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.AuthnRequest
+
+ # The server should advertise exactly one mechanism.
+ assert resp.payload.type == pq3.authn.SASL
+ assert resp.payload.body == [b"OAUTHBEARER", b""]
+
+
+def send_initial_response(conn, *, auth=None, bearer=None):
+ """
+ Sends the OAUTHBEARER initial response on the connection, using the given
+ bearer token. Alternatively, the initial response's auth field may be
+ specified explicitly, to test corner cases.
+ """
+ if bearer is not None and auth is not None:
+ raise ValueError("exactly one of the auth and bearer kwargs must be set")
+
+ if bearer is not None:
+ auth = b"Bearer " + bearer
+
+ if auth is None:
+ raise ValueError("exactly one of the auth and bearer kwargs must be set")
+
+ initial = pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"n,,\x01auth=" + auth + b"\x01\x01",
+ )
+ )
+ pq3.send(conn, pq3.types.PasswordMessage, initial)
+
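
The \x01 bytes here are the kvsep separators from RFC 7628: the client-first message is the GS2 header "n,," followed by a kvsep, then key=value pairs each terminated by a kvsep, with a final kvsep closing the message. A sketch of the framing (helper name is ours):

```python
def oauthbearer_client_first(token):
    # GS2 header (no channel binding, no authzid), kvsep, the auth
    # key/value pair, then the kvsep that closes the message (RFC 7628)
    kvsep = b"\x01"
    gs2 = b"n,,"
    return gs2 + kvsep + b"auth=Bearer " + token + kvsep + kvsep

msg = oauthbearer_client_first(b"sometoken")
```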
+
+def expect_handshake_success(conn):
+ """
+ Validates that the server responds with an AuthnOK message, and then drains
+ the connection until a ReadyForQuery message is received.
+ """
+ resp = pq3.recv1(conn)
+
+ assert resp.type == pq3.types.AuthnRequest
+ assert resp.payload.type == pq3.authn.OK
+ assert not resp.payload.body
+
+ receive_until(conn, pq3.types.ReadyForQuery)
+
+
+def expect_handshake_failure(conn, oauth_ctx):
+ """
+ Performs the OAUTHBEARER SASL failure "handshake" and validates the server's
+ side of the conversation, including the final ErrorResponse.
+ """
+
+ # We expect a discovery "challenge" back from the server before the authn
+ # failure message.
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.AuthnRequest
+
+ req = resp.payload
+ assert req.type == pq3.authn.SASLContinue
+
+ body = json.loads(req.body)
+ assert body["status"] == "invalid_token"
+ assert body["scope"] == oauth_ctx.scope
+
+ expected_config = oauth_ctx.issuer + "/.well-known/openid-configuration"
+ assert body["openid-configuration"] == expected_config
+
+ # Send the dummy response to complete the failed handshake.
+ pq3.send(conn, pq3.types.PasswordMessage, b"\x01")
+ resp = pq3.recv1(conn)
+
+ err = ExpectedError(INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed")
+ err.match(resp)
+
+
+def receive_until(conn, type):
+ """
+ receive_until pulls packets off the pq3 connection until a packet with the
+ desired type is found, or an error response is received.
+ """
+ while True:
+ pkt = pq3.recv1(conn)
+
+ if pkt.type == type:
+ return pkt
+ elif pkt.type == pq3.types.ErrorResponse:
+ raise RuntimeError(
+ f"received error response from peer: {pkt.payload.fields!r}"
+ )
+
+
+@pytest.mark.parametrize("token_len", [16, 1024, 4096])
+@pytest.mark.parametrize(
+ "auth_prefix",
+ [
+ b"Bearer ",
+ b"bearer ",
+ b"Bearer ",
+ ],
+)
+def test_oauth(conn, oauth_ctx, bearer_token, auth_prefix, token_len):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Generate our bearer token with the desired length.
+ token = bearer_token(size=token_len)
+ auth = auth_prefix + token
+
+ send_initial_response(conn, auth=auth)
+ expect_handshake_success(conn)
+
+ # Make sure that the server has not set an authenticated ID.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ assert row.columns == [None]
+
+
+@pytest.mark.parametrize(
+ "token_value",
+ [
+ "abcdzA==",
+ "123456M=",
+ "x-._~+/x",
+ ],
+)
+def test_oauth_bearer_corner_cases(conn, oauth_ctx, bearer_token, token_value):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ send_initial_response(conn, bearer=bearer_token(token_value))
+
+ expect_handshake_success(conn)
+
+
+@pytest.mark.parametrize(
+ "user,authn_id,should_succeed",
+ [
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: ctx.user,
+ True,
+ id="validator authn: succeeds when authn_id == username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: None,
+ False,
+ id="validator authn: fails when authn_id is not set",
+ ),
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: ctx.authz_user,
+ False,
+ id="validator authn: fails when authn_id != username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: ctx.map_user + "@example.com",
+ True,
+ id="validator with map: succeeds when authn_id matches map",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: None,
+ False,
+ id="validator with map: fails when authn_id is not set",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: ctx.map_user + "@example.net",
+ False,
+ id="validator with map: fails when authn_id doesn't match map",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: None,
+ True,
+ id="validator authz: succeeds with no authn_id",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "",
+ True,
+ id="validator authz: succeeds with empty authn_id",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "postgres",
+ True,
+ id="validator authz: succeeds with basic username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "me@example.com",
+ True,
+ id="validator authz: succeeds with email address",
+ ),
+ ],
+)
+def test_oauth_authn_id(conn, oauth_ctx, bearer_token, user, authn_id, should_succeed):
+ token = None
+
+ authn_id = authn_id(oauth_ctx)
+ if authn_id is not None:
+ authn_id = authn_id.encode("ascii")
+
+ # As a hack to get the validator to reflect arbitrary output from this
+ # test, encode the desired output as a base64 token. The validator will
+ # key on the leading "output=" to differentiate this from the random
+ # tokens generated by secrets.token_urlsafe().
+ output = b"output=" + authn_id + b"\n"
+ token = base64.urlsafe_b64encode(output)
+
+ token = bearer_token(token)
+ username = user(oauth_ctx)
+
+ begin_oauth_handshake(conn, oauth_ctx, user=username)
+ send_initial_response(conn, bearer=token)
+
+ if not should_succeed:
+ expect_handshake_failure(conn, oauth_ctx)
+ return
+
+ expect_handshake_success(conn)
+
+ # Check the reported authn_id.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ assert row.columns == [authn_id]
+
+
+class ExpectedError(object):
+ def __init__(self, code, msg=None, detail=None):
+ self.code = code
+ self.msg = msg
+ self.detail = detail
+
+ # Protect against the footgun of an accidental empty string, which will
+ # "match" anything. If you don't want to match message or detail, just
+ # don't pass them.
+ if self.msg == "":
+ raise ValueError("msg must be non-empty or None")
+ if self.detail == "":
+ raise ValueError("detail must be non-empty or None")
+
+ def _getfield(self, resp, type):
+ """
+ Searches an ErrorResponse for a single field of the given type (e.g.
+ "M", "C", "D") and returns its value. Asserts if it doesn't find exactly
+ one field.
+ """
+ prefix = type.encode("ascii")
+ fields = [f for f in resp.payload.fields if f.startswith(prefix)]
+
+ assert len(fields) == 1
+ return fields[0][1:] # strip off the type byte
+
+ def match(self, resp):
+ """
+ Checks that the given response matches the expected code, message, and
+ detail (if given). The error code must match exactly. The expected
+ message and detail must be contained within the actual strings.
+ """
+ assert resp.type == pq3.types.ErrorResponse
+
+ code = self._getfield(resp, "C")
+ assert code == self.code
+
+ if self.msg:
+ msg = self._getfield(resp, "M")
+ expected = self.msg.encode("utf-8")
+ assert expected in msg
+
+ if self.detail:
+ detail = self._getfield(resp, "D")
+ expected = self.detail.encode("utf-8")
+ assert expected in detail
+
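
An ErrorResponse payload is a list of byte strings, each tagged with a one-byte field code ('C' for the SQLSTATE, 'M' for the message, 'D' for the detail). The matching logic above boils down to a filter over those tags; a standalone sketch (helper name and sample fields are ours):

```python
fields = [b"SFATAL", b"C28000", b"Mbearer authentication failed"]

def getfield(fields, code):
    # find the single field carrying the given one-byte type code
    found = [f[1:] for f in fields if f[:1] == code]
    assert len(found) == 1
    return found[0]

sqlstate = getfield(fields, b"C")
```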
+
+def test_oauth_rejected_bearer(conn, oauth_ctx, bearer_token):
+ # Generate a new bearer token, which we will proceed not to use.
+ _ = bearer_token()
+
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send a bearer token that doesn't match what the validator expects. It
+ # should fail the connection.
+ send_initial_response(conn, bearer=b"xxxxxx")
+
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.parametrize(
+ "bad_bearer",
+ [
+ b"Bearer ",
+ b"Bearer a===b",
+ b"Bearer hello!",
+ b"Bearer me@example.com",
+ b'OAuth realm="Example"',
+ b"",
+ ],
+)
+def test_oauth_invalid_bearer(conn, oauth_ctx, bearer_token, bad_bearer):
+ # Tell the validator to accept any token. This ensures that the invalid
+ # bearer tokens are rejected before the validation step.
+ _ = bearer_token(accept_any=True)
+
+ begin_oauth_handshake(conn, oauth_ctx)
+ send_initial_response(conn, auth=bad_bearer)
+
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.parametrize(
+ "resp_type,resp,err",
+ [
+ pytest.param(
+ None,
+ None,
+ None,
+ marks=pytest.mark.slow,
+ id="no response (expect timeout)",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"hello",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="bad dummy response",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"\x01\x01",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="multiple kvseps",
+ ),
+ pytest.param(
+ pq3.types.Query,
+ dict(query=b""),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="bad response message type",
+ ),
+ ],
+)
+def test_oauth_bad_response_to_error_challenge(conn, oauth_ctx, resp_type, resp, err):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send an empty auth initial response, which will force an authn failure.
+ send_initial_response(conn, auth=b"")
+
+ # We expect a discovery "challenge" back from the server before the authn
+ # failure message.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.AuthnRequest
+
+ req = pkt.payload
+ assert req.type == pq3.authn.SASLContinue
+
+ body = json.loads(req.body)
+ assert body["status"] == "invalid_token"
+
+ if resp_type is None:
+ # Do not send the dummy response. We should time out and not get a
+ # response from the server.
+ with pytest.raises(socket.timeout):
+ conn.read(1)
+
+ # Done with the test.
+ return
+
+ # Send the bad response.
+ pq3.send(conn, resp_type, resp)
+
+ # Make sure the server fails the connection correctly.
+ pkt = pq3.recv1(conn)
+ err.match(pkt)
+
+
+@pytest.mark.parametrize(
+ "type,payload,err",
+ [
+ pytest.param(
+ pq3.types.ErrorResponse,
+ dict(fields=[b""]),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="error response in initial message",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"x" * (MAX_SASL_MESSAGE_LENGTH + 1),
+ ExpectedError(
+ INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed"
+ ),
+ id="overlong initial response data",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"SCRAM-SHA-256")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE, "invalid SASL authentication mechanism"
+ ),
+ id="bad SASL mechanism selection",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=2, data=b"x")),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "insufficient data"),
+ id="SASL data underflow",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=0, data=b"x")),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "invalid message format"),
+ id="SASL data overflow",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "message is empty",
+ ),
+ id="empty",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"n,,\x01auth=\x01\x01\0")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "length does not match input length",
+ ),
+ id="contains null byte",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected channel-binding flag", # XXX this is a bit strange
+ ),
+ id="initial error response",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"p=tls-server-end-point,,\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "server does not support channel binding",
+ ),
+ id="uses channel binding",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"x,,\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected channel-binding flag",
+ ),
+ id="invalid channel binding specifier",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Comma expected",
+ ),
+ id="bad GS2 header: missing channel binding terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,a")),
+ ExpectedError(
+ FEATURE_NOT_SUPPORTED_ERRCODE,
+ "client uses authorization identity",
+ ),
+ id="bad GS2 header: authzid in use",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,b,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected attribute",
+ ),
+ id="bad GS2 header: extra attribute",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected attribute 0x00", # XXX this is a bit strange
+ ),
+ id="bad GS2 header: missing authzid terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Key-value separator expected",
+ ),
+ id="missing initial kvsep",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: empty key-value list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: other keys present",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "unterminated key/value pair",
+ ),
+ id="missing value terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: empty list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: with auth value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01\x01blah")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "additional data after the final terminator",
+ ),
+ id="additional key after terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01key\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "key without a value",
+ ),
+ id="key without value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"y,,\x01auth=Bearer 0\x01auth=Bearer 1\x01\x01",
+ )
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "contains multiple auth values",
+ ),
+ id="multiple auth values",
+ ),
+ ],
+)
+def test_oauth_bad_initial_response(conn, oauth_ctx, type, payload, err):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # The server expects a SASL response; give it something else instead.
+ if not isinstance(payload, dict):
+ payload = dict(payload_data=payload)
+ pq3.send(conn, type, **payload)
+
+ resp = pq3.recv1(conn)
+ err.match(resp)
+
+
+def test_oauth_empty_initial_response(conn, oauth_ctx, bearer_token):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send an initial response without data.
+ initial = pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER"))
+ pq3.send(conn, pq3.types.PasswordMessage, initial)
+
+ # The server should respond with an empty challenge so we can send the data
+ # it wants.
+ pkt = pq3.recv1(conn)
+
+ assert pkt.type == pq3.types.AuthnRequest
+ assert pkt.payload.type == pq3.authn.SASLContinue
+ assert not pkt.payload.body
+
+ # Now send the initial data.
+ data = b"n,,\x01auth=Bearer " + bearer_token() + b"\x01\x01"
+ pq3.send(conn, pq3.types.PasswordMessage, data)
+
+ # Server should now complete the handshake.
+ expect_handshake_success(conn)
+
+
+@pytest.fixture()
+def set_validator():
+ """
+ A per-test fixture that allows a test to override the setting of
+ oauth_validator_command for the cluster. The setting will be reverted during
+ teardown.
+
+ Passing None will perform an ALTER SYSTEM RESET.
+ """
+ conn = psycopg2.connect("")
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+
+ # Save the previous value.
+ c.execute("SHOW oauth_validator_command;")
+ prev_cmd = c.fetchone()[0]
+
+        def setter(cmd):
+            if cmd is None:
+                c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+            else:
+                c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+            c.execute("SELECT pg_reload_conf();")
+
+ yield setter
+
+ # Restore the previous value.
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (prev_cmd,))
+ c.execute("SELECT pg_reload_conf();")
+
+
+def test_oauth_no_validator(oauth_ctx, set_validator, connect, bearer_token):
+ # Clear out our validator command, then establish a new connection.
+ set_validator("")
+ conn = connect()
+
+ begin_oauth_handshake(conn, oauth_ctx)
+ send_initial_response(conn, bearer=bearer_token())
+
+ # The server should fail the connection.
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+def test_oauth_validator_role(oauth_ctx, set_validator, connect):
+ # Switch the validator implementation. This validator will reflect the
+ # PGUSER as the authenticated identity.
+ path = pathlib.Path(__file__).parent / "validate_reflect.py"
+ path = str(path.absolute())
+
+ set_validator(f"{shlex.quote(path)} '%r' <&%f")
+ conn = connect()
+
+ # Log in. Note that the reflection validator ignores the bearer token.
+ begin_oauth_handshake(conn, oauth_ctx, user=oauth_ctx.user)
+ send_initial_response(conn, bearer=b"dontcare")
+ expect_handshake_success(conn)
+
+ # Check the user identity.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ expected = oauth_ctx.user.encode("utf-8")
+ assert row.columns == [expected]
+
+
+def test_oauth_role_with_shell_unsafe_characters(oauth_ctx, set_validator, connect):
+ """
+ XXX This test pins undesirable behavior. We should be able to handle any
+ valid Postgres role name.
+ """
+ # Switch the validator implementation. This validator will reflect the
+ # PGUSER as the authenticated identity.
+ path = pathlib.Path(__file__).parent / "validate_reflect.py"
+ path = str(path.absolute())
+
+ set_validator(f"{shlex.quote(path)} '%r' <&%f")
+ conn = connect()
+
+ unsafe_username = "hello'there"
+ begin_oauth_handshake(conn, oauth_ctx, user=unsafe_username)
+
+ # The server should reject the handshake.
+ send_initial_response(conn, bearer=b"dontcare")
+ expect_handshake_failure(conn, oauth_ctx)
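(Editor's note: the malformed cases above are all perturbations of one well-formed OAUTHBEARER client initial response. For reference, the happy-path message can be sketched with a small helper; the helper is illustrative and not part of the patch.)

```python
def oauthbearer_initial_response(token: bytes) -> bytes:
    """Builds a well-formed RFC 7628 client initial response: the GS2
    header "n,," (no channel binding, no authzid), then 0x01-separated
    key/value pairs, ending with a double-0x01 terminator."""
    return b"n,,\x01auth=Bearer " + token + b"\x01\x01"


# Matches the shape the tests above send via send_initial_response().
assert oauthbearer_initial_response(b"token") == b"n,,\x01auth=Bearer token\x01\x01"
```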
diff --git a/src/test/python/server/test_server.py b/src/test/python/server/test_server.py
new file mode 100644
index 0000000000..02126dba79
--- /dev/null
+++ b/src/test/python/server/test_server.py
@@ -0,0 +1,21 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import pq3
+
+
+def test_handshake(connect):
+ """Basic sanity check."""
+ conn = connect()
+
+ pq3.handshake(conn, user=pq3.pguser(), database=pq3.pgdatabase())
+
+ pq3.send(conn, pq3.types.Query, query=b"")
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.EmptyQueryResponse
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.ReadyForQuery
diff --git a/src/test/python/server/validate_bearer.py b/src/test/python/server/validate_bearer.py
new file mode 100755
index 0000000000..2cc73ff154
--- /dev/null
+++ b/src/test/python/server/validate_bearer.py
@@ -0,0 +1,101 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It doesn't actually validate
+# anything, and it logs the bearer token data, which is sensitive.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. The expected token is communicated from that test module's
+# bearer_token() fixture via shared memory.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
+
+import base64
+import binascii
+import contextlib
+import struct
+import sys
+from multiprocessing import shared_memory
+
+MAX_UINT16 = 2 ** 16 - 1
+
+
+def remove_shm_from_resource_tracker():
+ """
+ Monkey-patch multiprocessing.resource_tracker so SharedMemory won't be
+ tracked. Pulled from this thread, where there are more details:
+
+ https://bugs.python.org/issue38119
+
+ TL;DR: all clients of shared memory segments automatically destroy them on
+ process exit, which makes shared memory segments much less useful. This
+ monkeypatch removes that behavior so that we can defer to the test to manage
+ the segment lifetime.
+
+ Ideally a future Python patch will pull in this fix and then the entire
+ function can go away.
+ """
+ from multiprocessing import resource_tracker
+
+ def fix_register(name, rtype):
+ if rtype == "shared_memory":
+ return
+        return resource_tracker._resource_tracker.register(name, rtype)
+
+ resource_tracker.register = fix_register
+
+ def fix_unregister(name, rtype):
+ if rtype == "shared_memory":
+ return
+        return resource_tracker._resource_tracker.unregister(name, rtype)
+
+ resource_tracker.unregister = fix_unregister
+
+ if "shared_memory" in resource_tracker._CLEANUP_FUNCS:
+ del resource_tracker._CLEANUP_FUNCS["shared_memory"]
+
+
+def main(args):
+ remove_shm_from_resource_tracker() # XXX remove some day
+
+ # Get the expected token from the currently running test.
+ shared_mem_name = args[0]
+
+ mem = shared_memory.SharedMemory(shared_mem_name)
+ with contextlib.closing(mem):
+ # First two bytes are the token length.
+ size = struct.unpack("H", mem.buf[:2])[0]
+
+ if size == MAX_UINT16:
+ # Special case: the test wants us to accept any token.
+ sys.stderr.write("accepting token without validation\n")
+ return
+
+ # The remainder of the buffer contains the expected token.
+ assert size <= (mem.size - 2)
+ expected_token = mem.buf[2 : size + 2].tobytes()
+
+ mem.buf[:] = b"\0" * mem.size # scribble over the token
+
+ token = sys.stdin.buffer.read()
+ if token != expected_token:
+ sys.exit(f"failed to match Bearer token ({token!r} != {expected_token!r})")
+
+ # See if the test wants us to print anything. If so, it will have encoded
+ # the desired output in the token with an "output=" prefix.
+ try:
+ # altchars="-_" corresponds to the urlsafe alphabet.
+ data = base64.b64decode(token, altchars="-_", validate=True)
+
+ if data.startswith(b"output="):
+ sys.stdout.buffer.write(data[7:])
+
+ except binascii.Error:
+ pass
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
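(Editor's note: the length-prefixed layout that main() unpacks implies a matching writer on the test side. Below is a minimal round-trip sketch of that handoff, assuming the bearer_token() fixture writes a native-order uint16 length followed by the token bytes; the segment is created locally here so the sketch is self-contained.)

```python
import struct
from multiprocessing import shared_memory

token = b"abc123"

# Writer side: 2-byte native-order length, then the token bytes.
mem = shared_memory.SharedMemory(create=True, size=2 + len(token))
try:
    mem.buf[:2] = struct.pack("H", len(token))
    mem.buf[2 : 2 + len(token)] = token

    # Reader side, as in main() above (given the segment name).
    size = struct.unpack("H", bytes(mem.buf[:2]))[0]
    recovered = bytes(mem.buf[2 : 2 + size])
finally:
    mem.close()
    mem.unlink()

assert recovered == token
```

A length of MAX_UINT16 would instead trigger the accept-any special case handled above.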
diff --git a/src/test/python/server/validate_reflect.py b/src/test/python/server/validate_reflect.py
new file mode 100755
index 0000000000..24c3a7e715
--- /dev/null
+++ b/src/test/python/server/validate_reflect.py
@@ -0,0 +1,34 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It ignores the bearer token
+# entirely and automatically logs the user in.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. It expects the user's desired role name as an argument; the
+# actual token will be discarded and the user will be logged in with the role
+# name as the authenticated identity.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
+
+import sys
+
+
+def main(args):
+ # We have to read the entire token as our first action to unblock the
+ # server, but we won't actually use it.
+ _ = sys.stdin.buffer.read()
+
+ if len(args) != 1:
+ sys.exit("usage: ./validate_reflect.py ROLE")
+
+ # Log the user in as the provided role.
+ role = args[0]
+ print(role)
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
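(Editor's note: both bundled validators follow the same oauth_validator_command contract — the token arrives on stdin, the authenticated identity goes to stdout, and a failure exit rejects the login. A third, hypothetical validator under that contract might look like the sketch below; the token and role values are made up, and in-memory streams stand in for stdin/stdout so the sketch is self-contained.)

```python
import io


def validate(token):
    """Returns the authn_id for an accepted token, or None to reject."""
    if token == b"expected-token":
        return "alice"
    return None


def run(stream, out):
    # Read the whole token first to unblock the server, then decide.
    token = stream.read()

    authn_id = validate(token)
    if authn_id is None:
        return 1  # nonzero status fails authentication

    print(authn_id, file=out)
    return 0


# Demonstration with in-memory streams instead of real stdin/stdout:
buf = io.StringIO()
status = run(io.BytesIO(b"expected-token"), buf)
assert status == 0
assert buf.getvalue() == "alice\n"
```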
diff --git a/src/test/python/test_internals.py b/src/test/python/test_internals.py
new file mode 100644
index 0000000000..dee4855fc0
--- /dev/null
+++ b/src/test/python/test_internals.py
@@ -0,0 +1,138 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import io
+
+from pq3 import _DebugStream
+
+
+def test_DebugStream_read():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+
+ stream = _DebugStream(under, out)
+
+ res = stream.read(5)
+ assert res == b"abcde"
+
+ res = stream.read(16)
+ assert res == b"fghijklmnopqrstu"
+
+ stream.flush_debug()
+
+ res = stream.read()
+ assert res == b"vwxyz"
+
+ stream.flush_debug()
+
+ expected = (
+ "< 0000:\t61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70\tabcdefghijklmnop\n"
+ "< 0010:\t71 72 73 74 75 \tqrstu\n"
+ "\n"
+ "< 0000:\t76 77 78 79 7a \tvwxyz\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_write():
+ under = io.BytesIO()
+ out = io.StringIO()
+
+ stream = _DebugStream(under, out)
+
+ stream.write(b"\x00\x01\x02")
+ stream.flush()
+
+ assert under.getvalue() == b"\x00\x01\x02"
+
+ stream.write(b"\xc0\xc1\xc2")
+ stream.flush()
+
+ assert under.getvalue() == b"\x00\x01\x02\xc0\xc1\xc2"
+
+ stream.flush_debug()
+
+ expected = "> 0000:\t00 01 02 c0 c1 c2 \t......\n\n"
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_read_write():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+ stream = _DebugStream(under, out)
+
+ res = stream.read(5)
+ assert res == b"abcde"
+
+ stream.write(b"xxxxx")
+ stream.flush()
+
+ assert under.getvalue() == b"abcdexxxxxklmnopqrstuvwxyz"
+
+ res = stream.read(5)
+ assert res == b"klmno"
+
+ stream.write(b"xxxxx")
+ stream.flush()
+
+ assert under.getvalue() == b"abcdexxxxxklmnoxxxxxuvwxyz"
+
+ stream.flush_debug()
+
+ expected = (
+ "< 0000:\t61 62 63 64 65 6b 6c 6d 6e 6f \tabcdeklmno\n"
+ "\n"
+ "> 0000:\t78 78 78 78 78 78 78 78 78 78 \txxxxxxxxxx\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_end_packet():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+ stream = _DebugStream(under, out)
+
+ stream.read(5)
+ stream.end_packet("read description", read=True, indent=" ")
+
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("write description", indent=" ")
+
+ stream.read(5)
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("read/write combo for read", read=True, indent=" ")
+
+ stream.read(5)
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("read/write combo for write", indent=" ")
+
+ expected = (
+ " < 0000:\t61 62 63 64 65 \tabcde\n"
+ "\n"
+ "< read description\n"
+ "\n"
+ "> write description\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ " < 0000:\t6b 6c 6d 6e 6f \tklmno\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ "< read/write combo for read\n"
+ "\n"
+ "> read/write combo for write\n"
+ "\n"
+ " < 0000:\t75 76 77 78 79 \tuvwxy\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
diff --git a/src/test/python/test_pq3.py b/src/test/python/test_pq3.py
new file mode 100644
index 0000000000..e0c0e0568d
--- /dev/null
+++ b/src/test/python/test_pq3.py
@@ -0,0 +1,558 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import struct
+import sys
+
+import pytest
+from construct import Container, PaddingError, StreamError, TerminatedError
+
+import pq3
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"\x00\x00\x00\x10\x00\x04\x00\x00abcdefgh",
+ Container(len=16, proto=0x40000, payload=b"abcdefgh"),
+ b"",
+ id="8-byte payload",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x08\x00\x04\x00\x00",
+ Container(len=8, proto=0x40000, payload=b""),
+ b"",
+ id="no payload",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x09\x00\x04\x00\x00abcde",
+ Container(len=9, proto=0x40000, payload=b"a"),
+ b"bcde",
+ id="1-byte payload and extra padding",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x0B\x00\x03\x00\x00hi\x00",
+ Container(len=11, proto=pq3.protocol(3, 0), payload=[b"hi"]),
+ b"",
+ id="implied parameter list when using proto version 3.0",
+ ),
+ ],
+)
+def test_Startup_parse(raw, expected, extra):
+ with io.BytesIO(raw) as stream:
+ actual = pq3.Startup.parse_stream(stream)
+
+ assert actual == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "packet,expected_bytes",
+ [
+ pytest.param(
+ dict(),
+ b"\x00\x00\x00\x08\x00\x00\x00\x00",
+ id="nothing set",
+ ),
+ pytest.param(
+ dict(len=10, proto=0x12345678),
+ b"\x00\x00\x00\x0A\x12\x34\x56\x78\x00\x00",
+ id="len and proto set explicitly",
+ ),
+ pytest.param(
+ dict(proto=0x12345678),
+ b"\x00\x00\x00\x08\x12\x34\x56\x78",
+ id="implied len with no payload",
+ ),
+ pytest.param(
+ dict(proto=0x12345678, payload=b"abcd"),
+ b"\x00\x00\x00\x0C\x12\x34\x56\x78abcd",
+ id="implied len with payload",
+ ),
+ pytest.param(
+ dict(payload=[b""]),
+ b"\x00\x00\x00\x09\x00\x03\x00\x00\x00",
+ id="implied proto version 3 when sending parameters",
+ ),
+ pytest.param(
+ dict(payload=[b"hi", b""]),
+ b"\x00\x00\x00\x0C\x00\x03\x00\x00hi\x00\x00",
+ id="implied proto version 3 and len when sending more than one parameter",
+ ),
+ pytest.param(
+ dict(payload=dict(user="jsmith", database="postgres")),
+ b"\x00\x00\x00\x27\x00\x03\x00\x00user\x00jsmith\x00database\x00postgres\x00\x00",
+ id="auto-serialization of dict parameters",
+ ),
+ ],
+)
+def test_Startup_build(packet, expected_bytes):
+ actual = pq3.Startup.build(packet)
+ assert actual == expected_bytes
+
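(Editor's note: the implied-length cases can be cross-checked by hand — a version-3 startup packet is a big-endian int32 length that counts itself, a 4-byte protocol word, then NUL-terminated parameter strings with a trailing NUL. For example, building the "auto-serialization" case above manually:)

```python
import struct

# 0x00030000 is protocol version 3.0; the length word includes itself.
params = b"user\x00jsmith\x00database\x00postgres\x00\x00"
pkt = struct.pack("!ii", 8 + len(params), (3 << 16) | 0) + params

assert pkt == (
    b"\x00\x00\x00\x27\x00\x03\x00\x00"
    b"user\x00jsmith\x00database\x00postgres\x00\x00"
)
```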
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"*\x00\x00\x00\x08abcd",
+ dict(type=b"*", len=8, payload=b"abcd"),
+ b"",
+ id="4-byte payload",
+ ),
+ pytest.param(
+ b"*\x00\x00\x00\x04",
+ dict(type=b"*", len=4, payload=b""),
+ b"",
+ id="no payload",
+ ),
+ pytest.param(
+ b"*\x00\x00\x00\x05xabcd",
+ dict(type=b"*", len=5, payload=b"x"),
+ b"abcd",
+ id="1-byte payload with extra padding",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x08\x00\x00\x00\x00",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=8,
+ payload=dict(type=pq3.authn.OK, body=None),
+ ),
+ b"",
+ id="AuthenticationOk",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x12\x00\x00\x00\x0AEXTERNAL\x00\x00",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=18,
+ payload=dict(type=pq3.authn.SASL, body=[b"EXTERNAL", b""]),
+ ),
+ b"",
+ id="AuthenticationSASL",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=13,
+ payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
+ ),
+ b"",
+ id="AuthenticationSASLContinue",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=13,
+ payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
+ ),
+ b"",
+ id="AuthenticationSASLFinal",
+ ),
+ pytest.param(
+ b"p\x00\x00\x00\x0Bhunter2",
+ dict(
+ type=pq3.types.PasswordMessage,
+ len=11,
+ payload=b"hunter2",
+ ),
+ b"",
+ id="PasswordMessage",
+ ),
+ pytest.param(
+ b"K\x00\x00\x00\x0C\x00\x00\x00\x00\x12\x34\x56\x78",
+ dict(
+ type=pq3.types.BackendKeyData,
+ len=12,
+ payload=dict(pid=0, key=0x12345678),
+ ),
+ b"",
+ id="BackendKeyData",
+ ),
+ pytest.param(
+ b"C\x00\x00\x00\x08SET\x00",
+ dict(
+ type=pq3.types.CommandComplete,
+ len=8,
+ payload=dict(tag=b"SET"),
+ ),
+ b"",
+ id="CommandComplete",
+ ),
+ pytest.param(
+ b"E\x00\x00\x00\x11Mbad!\x00Mdog!\x00\x00",
+ dict(type=b"E", len=17, payload=dict(fields=[b"Mbad!", b"Mdog!", b""])),
+ b"",
+ id="ErrorResponse",
+ ),
+ pytest.param(
+ b"S\x00\x00\x00\x08a\x00b\x00",
+ dict(
+ type=pq3.types.ParameterStatus,
+ len=8,
+ payload=dict(name=b"a", value=b"b"),
+ ),
+ b"",
+ id="ParameterStatus",
+ ),
+ pytest.param(
+ b"Z\x00\x00\x00\x05x",
+ dict(type=b"Z", len=5, payload=dict(status=b"x")),
+ b"",
+ id="ReadyForQuery",
+ ),
+ pytest.param(
+ b"Q\x00\x00\x00\x06!\x00",
+ dict(type=pq3.types.Query, len=6, payload=dict(query=b"!")),
+ b"",
+ id="Query",
+ ),
+ pytest.param(
+ b"D\x00\x00\x00\x0B\x00\x01\x00\x00\x00\x01!",
+ dict(type=pq3.types.DataRow, len=11, payload=dict(columns=[b"!"])),
+ b"",
+ id="DataRow",
+ ),
+ pytest.param(
+ b"D\x00\x00\x00\x06\x00\x00extra",
+ dict(type=pq3.types.DataRow, len=6, payload=dict(columns=[])),
+ b"extra",
+ id="DataRow with extra data",
+ ),
+ pytest.param(
+ b"I\x00\x00\x00\x04",
+ dict(type=pq3.types.EmptyQueryResponse, len=4, payload=None),
+ b"",
+ id="EmptyQueryResponse",
+ ),
+ pytest.param(
+ b"I\x00\x00\x00\x04\xFF",
+ dict(type=b"I", len=4, payload=None),
+ b"\xFF",
+ id="EmptyQueryResponse with extra bytes",
+ ),
+ pytest.param(
+ b"X\x00\x00\x00\x04",
+ dict(type=pq3.types.Terminate, len=4, payload=None),
+ b"",
+ id="Terminate",
+ ),
+ ],
+)
+def test_Pq3_parse(raw, expected, extra):
+ with io.BytesIO(raw) as stream:
+ actual = pq3.Pq3.parse_stream(stream)
+
+ assert actual == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(type=b"*", len=5),
+ b"*\x00\x00\x00\x05\x00",
+ id="type and len set explicitly",
+ ),
+ pytest.param(
+ dict(type=b"*"),
+ b"*\x00\x00\x00\x04",
+ id="implied len with no payload",
+ ),
+ pytest.param(
+ dict(type=b"*", payload=b"1234"),
+ b"*\x00\x00\x00\x081234",
+ id="implied len with payload",
+ ),
+ pytest.param(
+ dict(type=pq3.types.AuthnRequest, payload=dict(type=pq3.authn.OK)),
+ b"R\x00\x00\x00\x08\x00\x00\x00\x00",
+ id="implied len/type for AuthenticationOK",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(
+ type=pq3.authn.SASL,
+ body=[b"SCRAM-SHA-256-PLUS", b"SCRAM-SHA-256", b""],
+ ),
+ ),
+ b"R\x00\x00\x00\x2A\x00\x00\x00\x0ASCRAM-SHA-256-PLUS\x00SCRAM-SHA-256\x00\x00",
+ id="implied len/type for AuthenticationSASL",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
+ ),
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
+ id="implied len/type for AuthenticationSASLContinue",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
+ ),
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
+ id="implied len/type for AuthenticationSASLFinal",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.PasswordMessage,
+ payload=b"hunter2",
+ ),
+ b"p\x00\x00\x00\x0Bhunter2",
+ id="implied len/type for PasswordMessage",
+ ),
+ pytest.param(
+ dict(type=pq3.types.BackendKeyData, payload=dict(pid=1, key=7)),
+ b"K\x00\x00\x00\x0C\x00\x00\x00\x01\x00\x00\x00\x07",
+ id="implied len/type for BackendKeyData",
+ ),
+ pytest.param(
+ dict(type=pq3.types.CommandComplete, payload=dict(tag=b"SET")),
+ b"C\x00\x00\x00\x08SET\x00",
+ id="implied len/type for CommandComplete",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ErrorResponse, payload=dict(fields=[b"error", b""])),
+ b"E\x00\x00\x00\x0Berror\x00\x00",
+ id="implied len/type for ErrorResponse",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ParameterStatus, payload=dict(name=b"a", value=b"b")),
+ b"S\x00\x00\x00\x08a\x00b\x00",
+ id="implied len/type for ParameterStatus",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ReadyForQuery, payload=dict(status=b"I")),
+ b"Z\x00\x00\x00\x05I",
+ id="implied len/type for ReadyForQuery",
+ ),
+ pytest.param(
+ dict(type=pq3.types.Query, payload=dict(query=b"SELECT 1;")),
+ b"Q\x00\x00\x00\x0eSELECT 1;\x00",
+ id="implied len/type for Query",
+ ),
+ pytest.param(
+ dict(type=pq3.types.DataRow, payload=dict(columns=[b"abcd"])),
+ b"D\x00\x00\x00\x0E\x00\x01\x00\x00\x00\x04abcd",
+ id="implied len/type for DataRow",
+ ),
+ pytest.param(
+ dict(type=pq3.types.EmptyQueryResponse),
+ b"I\x00\x00\x00\x04",
+ id="implied len for EmptyQueryResponse",
+ ),
+ pytest.param(
+ dict(type=pq3.types.Terminate),
+ b"X\x00\x00\x00\x04",
+ id="implied len for Terminate",
+ ),
+ ],
+)
+def test_Pq3_build(fields, expected):
+ actual = pq3.Pq3.build(fields)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"\x00\x00",
+ dict(columns=[]),
+ b"",
+ id="no columns",
+ ),
+ pytest.param(
+ b"\x00\x01\x00\x00\x00\x04abcd",
+ dict(columns=[b"abcd"]),
+ b"",
+ id="one column",
+ ),
+ pytest.param(
+ b"\x00\x02\x00\x00\x00\x04abcd\x00\x00\x00\x01x",
+ dict(columns=[b"abcd", b"x"]),
+ b"",
+ id="multiple columns",
+ ),
+ pytest.param(
+ b"\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01x",
+ dict(columns=[b"", b"x"]),
+ b"",
+ id="empty column value",
+ ),
+ pytest.param(
+ b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
+ dict(columns=[None, None]),
+ b"",
+ id="null columns",
+ ),
+ ],
+)
+def test_DataRow_parse(raw, expected, extra):
+ pkt = b"D" + struct.pack("!i", len(raw) + 4) + raw
+ with io.BytesIO(pkt) as stream:
+ actual = pq3.Pq3.parse_stream(stream)
+
+ assert actual.type == pq3.types.DataRow
+ assert actual.payload == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(),
+ b"\x00\x00",
+ id="no columns",
+ ),
+ pytest.param(
+ dict(columns=[None, None]),
+ b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
+ id="null columns",
+ ),
+ ],
+)
+def test_DataRow_build(fields, expected):
+ actual = pq3.Pq3.build(dict(type=pq3.types.DataRow, payload=fields))
+
+ expected = b"D" + struct.pack("!i", len(expected) + 4) + expected
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "raw,expected,exception",
+ [
+ pytest.param(
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ dict(name=b"EXTERNAL", len=-1, data=None),
+ None,
+ id="no initial response",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\x02me",
+ dict(name=b"EXTERNAL", len=2, data=b"me"),
+ None,
+ id="initial response",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\x02meextra",
+ None,
+ TerminatedError,
+ id="extra data",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\xFFme",
+ None,
+ StreamError,
+ id="underflow",
+ ),
+ ],
+)
+def test_SASLInitialResponse_parse(raw, expected, exception):
+ ctx = contextlib.nullcontext()
+ if exception:
+ ctx = pytest.raises(exception)
+
+ with ctx:
+ actual = pq3.SASLInitialResponse.parse(raw)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(name=b"EXTERNAL"),
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ id="no initial response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=None),
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ id="no initial response (explicit None)",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=b""),
+ b"EXTERNAL\x00\x00\x00\x00\x00",
+ id="empty response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=b"me@example.com"),
+ b"EXTERNAL\x00\x00\x00\x00\x0Eme@example.com",
+ id="initial response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", len=2, data=b"me@example.com"),
+ b"EXTERNAL\x00\x00\x00\x00\x02me@example.com",
+ id="data overflow",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", len=14, data=b"me"),
+ b"EXTERNAL\x00\x00\x00\x00\x0Eme",
+ id="data underflow",
+ ),
+ ],
+)
+def test_SASLInitialResponse_build(fields, expected):
+ actual = pq3.SASLInitialResponse.build(fields)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "version,expected_bytes",
+ [
+ pytest.param((3, 0), b"\x00\x03\x00\x00", id="version 3"),
+ pytest.param((1234, 5679), b"\x04\xd2\x16\x2f", id="SSLRequest"),
+ ],
+)
+def test_protocol(version, expected_bytes):
+ # Make sure the integer returned by protocol is correctly serialized on the
+ # wire.
+ assert struct.pack("!i", pq3.protocol(*version)) == expected_bytes
+
+
+@pytest.mark.parametrize(
+ "envvar,func,expected",
+ [
+ ("PGHOST", pq3.pghost, "localhost"),
+ ("PGPORT", pq3.pgport, 5432),
+ ("PGUSER", pq3.pguser, getpass.getuser()),
+ ("PGDATABASE", pq3.pgdatabase, "postgres"),
+ ],
+)
+def test_env_defaults(monkeypatch, envvar, func, expected):
+ monkeypatch.delenv(envvar, raising=False)
+
+ actual = func()
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "envvars,func,expected",
+ [
+ (dict(PGHOST="otherhost"), pq3.pghost, "otherhost"),
+ (dict(PGPORT="6789"), pq3.pgport, 6789),
+ (dict(PGUSER="postgres"), pq3.pguser, "postgres"),
+ (dict(PGDATABASE="template1"), pq3.pgdatabase, "template1"),
+ ],
+)
+def test_env(monkeypatch, envvars, func, expected):
+ for k, v in envvars.items():
+ monkeypatch.setenv(k, v)
+
+ actual = func()
+ assert actual == expected
diff --git a/src/test/python/tls.py b/src/test/python/tls.py
new file mode 100644
index 0000000000..075c02c1ca
--- /dev/null
+++ b/src/test/python/tls.py
@@ -0,0 +1,195 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+from construct import *
+
+#
+# TLS 1.3
+#
+# Most of the types below are transcribed from RFC 8446:
+#
+# https://tools.ietf.org/html/rfc8446
+#
+
+
+def _Vector(size_field, element):
+ return Prefixed(size_field, GreedyRange(element))
+
+
+# Alerts
+
+AlertLevel = Enum(
+ Byte,
+ warning=1,
+ fatal=2,
+)
+
+AlertDescription = Enum(
+ Byte,
+ close_notify=0,
+ unexpected_message=10,
+ bad_record_mac=20,
+ decryption_failed_RESERVED=21,
+ record_overflow=22,
+ decompression_failure=30,
+ handshake_failure=40,
+ no_certificate_RESERVED=41,
+ bad_certificate=42,
+ unsupported_certificate=43,
+ certificate_revoked=44,
+ certificate_expired=45,
+ certificate_unknown=46,
+ illegal_parameter=47,
+ unknown_ca=48,
+ access_denied=49,
+ decode_error=50,
+ decrypt_error=51,
+ export_restriction_RESERVED=60,
+ protocol_version=70,
+ insufficient_security=71,
+ internal_error=80,
+ user_canceled=90,
+ no_renegotiation=100,
+ unsupported_extension=110,
+)
+
+Alert = Struct(
+ "level" / AlertLevel,
+ "description" / AlertDescription,
+)
+
+
+# Extensions
+
+ExtensionType = Enum(
+ Int16ub,
+ server_name=0,
+ max_fragment_length=1,
+ status_request=5,
+ supported_groups=10,
+ signature_algorithms=13,
+ use_srtp=14,
+ heartbeat=15,
+ application_layer_protocol_negotiation=16,
+ signed_certificate_timestamp=18,
+ client_certificate_type=19,
+ server_certificate_type=20,
+ padding=21,
+ pre_shared_key=41,
+ early_data=42,
+ supported_versions=43,
+ cookie=44,
+ psk_key_exchange_modes=45,
+ certificate_authorities=47,
+ oid_filters=48,
+ post_handshake_auth=49,
+ signature_algorithms_cert=50,
+ key_share=51,
+)
+
+Extension = Struct(
+ "extension_type" / ExtensionType,
+ "extension_data" / Prefixed(Int16ub, GreedyBytes),
+)
+
+
+# ClientHello
+
+
+class _CipherSuiteAdapter(Adapter):
+ class _hextuple(tuple):
+ def __repr__(self):
+ return f"(0x{self[0]:02X}, 0x{self[1]:02X})"
+
+ def _encode(self, obj, context, path):
+ return bytes(obj)
+
+ def _decode(self, obj, context, path):
+ assert len(obj) == 2
+ return self._hextuple(obj)
+
+
+ProtocolVersion = Hex(Int16ub)
+
+Random = Hex(Bytes(32))
+
+CipherSuite = _CipherSuiteAdapter(Byte[2])
+
+ClientHello = Struct(
+ "legacy_version" / ProtocolVersion,
+ "random" / Random,
+ "legacy_session_id" / Prefixed(Byte, Hex(GreedyBytes)),
+ "cipher_suites" / _Vector(Int16ub, CipherSuite),
+ "legacy_compression_methods" / Prefixed(Byte, GreedyBytes),
+ "extensions" / _Vector(Int16ub, Extension),
+)
+
+# ServerHello
+
+ServerHello = Struct(
+ "legacy_version" / ProtocolVersion,
+ "random" / Random,
+ "legacy_session_id_echo" / Prefixed(Byte, Hex(GreedyBytes)),
+ "cipher_suite" / CipherSuite,
+ "legacy_compression_method" / Hex(Byte),
+ "extensions" / _Vector(Int16ub, Extension),
+)
+
+# Handshake
+
+HandshakeType = Enum(
+ Byte,
+ client_hello=1,
+ server_hello=2,
+ new_session_ticket=4,
+ end_of_early_data=5,
+ encrypted_extensions=8,
+ certificate=11,
+ certificate_request=13,
+ certificate_verify=15,
+ finished=20,
+ key_update=24,
+ message_hash=254,
+)
+
+Handshake = Struct(
+ "msg_type" / HandshakeType,
+ "length" / Int24ub,
+ "payload"
+ / Switch(
+ this.msg_type,
+ {
+ HandshakeType.client_hello: ClientHello,
+ HandshakeType.server_hello: ServerHello,
+ # HandshakeType.end_of_early_data: EndOfEarlyData,
+ # HandshakeType.encrypted_extensions: EncryptedExtensions,
+ # HandshakeType.certificate_request: CertificateRequest,
+ # HandshakeType.certificate: Certificate,
+ # HandshakeType.certificate_verify: CertificateVerify,
+ # HandshakeType.finished: Finished,
+ # HandshakeType.new_session_ticket: NewSessionTicket,
+ # HandshakeType.key_update: KeyUpdate,
+ },
+ default=FixedSized(this.length, GreedyBytes),
+ ),
+)
+
+# Records
+
+ContentType = Enum(
+ Byte,
+ invalid=0,
+ change_cipher_spec=20,
+ alert=21,
+ handshake=22,
+ application_data=23,
+)
+
+Plaintext = Struct(
+ "type" / ContentType,
+ "legacy_record_version" / ProtocolVersion,
+ "length" / Int16ub,
+ "fragment" / FixedSized(this.length, GreedyBytes),
+)
--
2.25.1
Attachment: v3-0009-contrib-oauth-switch-to-pluggable-auth-API.patch
From d275b329aaed6c6ef2403e4c313725a1ae88fa40 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Fri, 4 Mar 2022 08:48:47 -0800
Subject: [PATCH v3 9/9] contrib/oauth: switch to pluggable auth API
Move the core server implementation to contrib/oauth as a pluggable
provider, using the RegisterAuthProvider() API. oauth_validator_command
has been moved from core to a custom GUC. Tests have been updated to
handle the new implementations.
Some server modifications remain:
- Adding new HBA options for custom providers
- Registering support for usermaps
- Allowing custom SASL mechanisms to declare their own maximum message
length
This patch is optional; you can apply/revert it to compare the two
approaches.
---
contrib/oauth/Makefile | 16 +++++++
.../auth-oauth.c => contrib/oauth/oauth.c | 47 +++++++++++++++----
src/backend/libpq/Makefile | 1 -
src/backend/libpq/auth.c | 7 ---
src/backend/libpq/hba.c | 21 +++++----
src/backend/utils/misc/guc.c | 12 -----
src/include/libpq/hba.h | 3 +-
src/include/libpq/oauth.h | 24 ----------
src/test/python/server/test_oauth.py | 20 ++++----
9 files changed, 78 insertions(+), 73 deletions(-)
create mode 100644 contrib/oauth/Makefile
rename src/backend/libpq/auth-oauth.c => contrib/oauth/oauth.c (95%)
delete mode 100644 src/include/libpq/oauth.h
diff --git a/contrib/oauth/Makefile b/contrib/oauth/Makefile
new file mode 100644
index 0000000000..880bc1fef3
--- /dev/null
+++ b/contrib/oauth/Makefile
@@ -0,0 +1,16 @@
+# contrib/oauth/Makefile
+
+MODULE_big = oauth
+OBJS = oauth.o
+PGFILEDESC = "oauth - auth provider supporting OAuth 2.0/OIDC"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = contrib/oauth
+top_builddir = ../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/src/backend/libpq/auth-oauth.c b/contrib/oauth/oauth.c
similarity index 95%
rename from src/backend/libpq/auth-oauth.c
rename to contrib/oauth/oauth.c
index c1232a31a0..3a6dab19d9 100644
--- a/src/backend/libpq/auth-oauth.c
+++ b/contrib/oauth/oauth.c
@@ -1,33 +1,39 @@
-/*-------------------------------------------------------------------------
+/* -------------------------------------------------------------------------
*
- * auth-oauth.c
+ * oauth.c
* Server-side implementation of the SASL OAUTHBEARER mechanism.
*
* See the following RFC for more details:
* - RFC 7628: https://tools.ietf.org/html/rfc7628
*
- * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1996-2022, PostgreSQL Global Development Group
* Portions Copyright (c) 1994, Regents of the University of California
*
- * src/backend/libpq/auth-oauth.c
+ * contrib/oauth/oauth.c
*
- *-------------------------------------------------------------------------
+ * -------------------------------------------------------------------------
*/
+
#include "postgres.h"
#include <unistd.h>
#include <fcntl.h>
#include "common/oauth-common.h"
+#include "fmgr.h"
#include "lib/stringinfo.h"
#include "libpq/auth.h"
#include "libpq/hba.h"
-#include "libpq/oauth.h"
#include "libpq/sasl.h"
#include "storage/fd.h"
+#include "utils/guc.h"
+
+PG_MODULE_MAGIC;
+
+void _PG_init(void);
/* GUC */
-char *oauth_validator_command;
+static char *oauth_validator_command;
static void oauth_get_mechanisms(Port *port, StringInfo buf);
static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
@@ -35,7 +41,7 @@ static int oauth_exchange(void *opaq, const char *input, int inputlen,
char **output, int *outputlen, const char **logdetail);
/* Mechanism declaration */
-const pg_be_sasl_mech pg_be_oauth_mech = {
+static const pg_be_sasl_mech oauth_mech = {
oauth_get_mechanisms,
oauth_init,
oauth_exchange,
@@ -795,3 +801,28 @@ username_ok_for_shell(const char *username)
return true;
}
+
+static int CheckOAuth(Port *port)
+{
+ return CheckSASLAuth(&oauth_mech, port, NULL, NULL);
+}
+
+static const char *OAuthError(Port *port)
+{
+ return psprintf("OAuth bearer authentication failed for user \"%s\"",
+ port->user_name);
+}
+
+void
+_PG_init(void)
+{
+ RegisterAuthProvider("oauth", CheckOAuth, OAuthError);
+
+ DefineCustomStringVariable("oauth.validator_command",
+ gettext_noop("Command to validate OAuth v2 bearer tokens."),
+ NULL,
+ &oauth_validator_command,
+ "",
+ PGC_SIGHUP, GUC_SUPERUSER_ONLY,
+ NULL, NULL, NULL);
+}
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 98eb2a8242..6d385fd6a4 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,7 +15,6 @@ include $(top_builddir)/src/Makefile.global
# be-fsstubs is here for historical reasons, probably belongs elsewhere
OBJS = \
- auth-oauth.o \
auth-sasl.o \
auth-scram.o \
auth.o \
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 5c30904e2b..3533b0bc50 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -30,7 +30,6 @@
#include "libpq/auth.h"
#include "libpq/crypt.h"
#include "libpq/libpq.h"
-#include "libpq/oauth.h"
#include "libpq/pqformat.h"
#include "libpq/sasl.h"
#include "libpq/scram.h"
@@ -303,9 +302,6 @@ auth_failed(Port *port, int status, const char *logdetail)
case uaRADIUS:
errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
break;
- case uaOAuth:
- errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
- break;
case uaCustom:
if (CustomAuthenticationError_hook)
errstr = CustomAuthenticationError_hook(port);
@@ -631,9 +627,6 @@ ClientAuthentication(Port *port)
case uaTrust:
status = STATUS_OK;
break;
- case uaOAuth:
- status = CheckSASLAuth(&pg_be_oauth_mech, port, NULL, NULL);
- break;
case uaCustom:
if (CustomAuthenticationCheck_hook)
status = CustomAuthenticationCheck_hook(port);
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index f7f3059927..fb51c53cc0 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -136,7 +136,6 @@ static const char *const UserAuthName[] =
"radius",
"custom",
"peer",
- "oauth",
};
@@ -1401,8 +1400,6 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
#endif
else if (strcmp(token->string, "radius") == 0)
parsedline->auth_method = uaRADIUS;
- else if (strcmp(token->string, "oauth") == 0)
- parsedline->auth_method = uaOAuth;
else if (strcmp(token->string, "custom") == 0)
parsedline->auth_method = uaCustom;
else
@@ -1731,9 +1728,9 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->auth_method != uaPeer &&
hbaline->auth_method != uaGSS &&
hbaline->auth_method != uaSSPI &&
- hbaline->auth_method != uaOAuth &&
- hbaline->auth_method != uaCert)
- INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, oauth, and cert"));
+ hbaline->auth_method != uaCert &&
+ hbaline->auth_method != uaCustom)
+ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, and custom"));
hbaline->usermap = pstrdup(val);
}
else if (strcmp(name, "clientcert") == 0)
@@ -2119,19 +2116,25 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
}
else if (strcmp(name, "issuer") == 0)
{
- if (hbaline->auth_method != uaOAuth)
+ if (hbaline->auth_method != uaCustom
+ && (custom_provider_name != NULL
+ && strcmp(custom_provider_name, "oauth")))
INVALID_AUTH_OPTION("issuer", gettext_noop("oauth"));
hbaline->oauth_issuer = pstrdup(val);
}
else if (strcmp(name, "scope") == 0)
{
- if (hbaline->auth_method != uaOAuth)
+ if (hbaline->auth_method != uaCustom
+ && (custom_provider_name != NULL
+ && strcmp(custom_provider_name, "oauth")))
INVALID_AUTH_OPTION("scope", gettext_noop("oauth"));
hbaline->oauth_scope = pstrdup(val);
}
else if (strcmp(name, "trust_validator_authz") == 0)
{
- if (hbaline->auth_method != uaOAuth)
+ if (hbaline->auth_method != uaCustom
+ && (custom_provider_name != NULL
+ && strcmp(custom_provider_name, "oauth")))
INVALID_AUTH_OPTION("trust_validator_authz", gettext_noop("oauth"));
if (strcmp(val, "1") == 0)
hbaline->oauth_skip_usermap = true;
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index 791c7c83df..1e3650184b 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -58,7 +58,6 @@
#include "libpq/auth.h"
#include "libpq/libpq.h"
#include "libpq/pqformat.h"
-#include "libpq/oauth.h"
#include "miscadmin.h"
#include "optimizer/cost.h"
#include "optimizer/geqo.h"
@@ -4663,17 +4662,6 @@ static struct config_string ConfigureNamesString[] =
check_backtrace_functions, assign_backtrace_functions, NULL
},
- {
- {"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH,
- gettext_noop("Command to validate OAuth v2 bearer tokens."),
- NULL,
- GUC_SUPERUSER_ONLY
- },
- &oauth_validator_command,
- "",
- NULL, NULL, NULL
- },
-
/* End-of-list marker */
{
{NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index d46c2108eb..0c6a7dd823 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -40,8 +40,7 @@ typedef enum UserAuth
uaRADIUS,
uaCustom,
uaPeer,
- uaOAuth
-#define USER_AUTH_LAST uaOAuth /* Must be last value of this enum */
+#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */
} UserAuth;
/*
diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h
deleted file mode 100644
index 870e426af1..0000000000
--- a/src/include/libpq/oauth.h
+++ /dev/null
@@ -1,24 +0,0 @@
-/*-------------------------------------------------------------------------
- *
- * oauth.h
- * Interface to libpq/auth-oauth.c
- *
- * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
- * Portions Copyright (c) 1994, Regents of the University of California
- *
- * src/include/libpq/oauth.h
- *
- *-------------------------------------------------------------------------
- */
-#ifndef PG_OAUTH_H
-#define PG_OAUTH_H
-
-#include "libpq/libpq-be.h"
-#include "libpq/sasl.h"
-
-extern char *oauth_validator_command;
-
-/* Implementation */
-extern const pg_be_sasl_mech pg_be_oauth_mech;
-
-#endif /* PG_OAUTH_H */
diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py
index cb5ca7fa23..07fc25edc2 100644
--- a/src/test/python/server/test_oauth.py
+++ b/src/test/python/server/test_oauth.py
@@ -103,9 +103,9 @@ def oauth_ctx():
ctx = Context()
hba_lines = (
- f'host {ctx.dbname} {ctx.map_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
- f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
- f'host {ctx.dbname} all samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
+ f'host {ctx.dbname} {ctx.map_user} samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
+ f'host {ctx.dbname} {ctx.authz_user} samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
+ f'host {ctx.dbname} all samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
)
ident_lines = (r"oauth /^(.*)@example\.com$ \1",)
@@ -126,12 +126,12 @@ def oauth_ctx():
c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user))
c.execute(sql.SQL("CREATE DATABASE {};").format(dbname))
- # Make this test script the server's oauth_validator.
+ # Make this test script the server's oauth validator.
path = pathlib.Path(__file__).parent / "validate_bearer.py"
path = str(path.absolute())
cmd = f"{shlex.quote(path)} {SHARED_MEM_NAME} <&%f"
- c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+ c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (cmd,))
# Replace pg_hba and pg_ident.
c.execute("SHOW hba_file;")
@@ -149,7 +149,7 @@ def oauth_ctx():
# Put things back the way they were.
c.execute("SELECT pg_reload_conf();")
- c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+ c.execute("ALTER SYSTEM RESET oauth.validator_command;")
c.execute(sql.SQL("DROP DATABASE {};").format(dbname))
c.execute(sql.SQL("DROP ROLE {};").format(authz_user))
c.execute(sql.SQL("DROP ROLE {};").format(map_user))
@@ -930,7 +930,7 @@ def test_oauth_empty_initial_response(conn, oauth_ctx, bearer_token):
def set_validator():
"""
A per-test fixture that allows a test to override the setting of
- oauth_validator_command for the cluster. The setting will be reverted during
+ oauth.validator_command for the cluster. The setting will be reverted during
teardown.
Passing None will perform an ALTER SYSTEM RESET.
@@ -942,17 +942,17 @@ def set_validator():
c = conn.cursor()
# Save the previous value.
- c.execute("SHOW oauth_validator_command;")
+ c.execute("SHOW oauth.validator_command;")
prev_cmd = c.fetchone()[0]
def setter(cmd):
- c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+ c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (cmd,))
c.execute("SELECT pg_reload_conf();")
yield setter
# Restore the previous value.
- c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (prev_cmd,))
+ c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (prev_cmd,))
c.execute("SELECT pg_reload_conf();")
--
2.25.1
Hi Jacob,
Thank you for porting this on top of the pluggable auth methods API. I've
addressed the feedback around other backend changes in my latest patch, but
the client side changes still remain. I had a few questions to understand
them better.
(a) What specifically do the client side changes in the patch implement?
(b) Are the changes you made on the client side specific to OAUTH or are
they about making SASL more generic? As an additional question, if someone
wanted to implement something similar on top of your patch, would they
still have to make client side changes?
Regards,
Samay
On Fri, Mar 4, 2022 at 11:13 AM Jacob Champion <pchampion@vmware.com> wrote:
Hi all,
v3 rebases this patchset over the top of Samay's pluggable auth
provider API [1], included here as patches 0001-3. The final patch in
the set ports the server implementation from a core feature to a
contrib module; to switch between the two approaches, simply leave out
that final patch.

There are still some backend changes that must be made to get this
working, as pointed out in 0009, and obviously libpq support still
requires code changes.

--Jacob
[1]
/messages/by-id/CAJxrbyxTRn5P8J-p+wHLwFahK5y56PhK28VOb55jqMO05Y-DJw@mail.gmail.com
On Tue, 2022-03-22 at 14:48 -0700, samay sharma wrote:
Thank you for porting this on top of the pluggable auth methods API.
I've addressed the feedback around other backend changes in my latest
patch, but the client side changes still remain. I had a few
questions to understand them better.

(a) What specifically do the client side changes in the patch implement?
Hi Samay,
The client-side changes are an implementation of the OAuth 2.0 Device
Authorization Grant [1]https://datatracker.ietf.org/doc/html/rfc8628 in libpq. The majority of the OAuth logic is
handled by the third-party iddawc library.
The server tells the client what OIDC provider to contact, and then
libpq prompts you to log into that provider on your
smartphone/browser/etc. using a one-time code. After you give libpq
permission to act on your behalf, the Bearer token gets sent to libpq
via a direct connection, and libpq forwards it to the server so that
the server can determine whether you're allowed in.
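For reference, the shape of that device flow is roughly as follows. This is a Python sketch of RFC 8628, not libpq's actual code: the endpoint paths, field names, and the pluggable `post` callable are illustrative stand-ins, since a real client takes its endpoints from the issuer's discovery document.

```python
import time

# The grant type URN is defined by RFC 8628; everything else below is
# an illustrative stand-in for the real issuer integration.
DEVICE_GRANT = "urn:ietf:params:oauth:grant-type:device_code"

def device_flow(issuer, client_id, scope, post, sleep=time.sleep):
    """Run the device grant; post(url, fields) must return a parsed dict."""
    # Step 1: ask the issuer for a device code plus a short user code.
    dev = post(issuer + "/device/code",
               {"client_id": client_id, "scope": scope})
    print("Visit %s and enter the code: %s"
          % (dev["verification_uri"], dev["user_code"]))

    # Step 2: poll the token endpoint until the user approves on their
    # smartphone/browser, then hand the Bearer token back to the caller.
    while True:
        sleep(dev.get("interval", 5))
        tok = post(issuer + "/token",
                   {"grant_type": DEVICE_GRANT,
                    "device_code": dev["device_code"],
                    "client_id": client_id})
        if tok.get("error") == "authorization_pending":
            continue
        return tok["access_token"]
```

In the real patch this token is then forwarded in the OAUTHBEARER SASL exchange rather than returned to application code.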
(b) Are the changes you made on the client side specific to OAUTH or
are they about making SASL more generic?
The original patchset included changes to make SASL more generic. Many
of those changes have since been merged, and the remaining code is
mostly OAuth-specific, but there are still improvements to be made.
(And there's some JSON crud to sift through in the first couple of
patches. I'm still mad that the OAUTHBEARER spec requires clients to
parse JSON in the first place.)
As an additional question,
if someone wanted to implement something similar on top of your
patch, would they still have to make client side changes?
Any new SASL mechanisms require changes to libpq at this point. You
need to implement a new pg_sasl_mech, modify pg_SASL_init() to select
the mechanism correctly, and add whatever connection string options you
need, along with the associated state in pg_conn. Patch 0004 has all
the client-side magic for OAUTHBEARER.
--Jacob
On Fri, 2022-03-04 at 19:13 +0000, Jacob Champion wrote:
v3 rebases this patchset over the top of Samay's pluggable auth
provider API [1], included here as patches 0001-3.
v4 rebases over the latest version of the pluggable auth patchset
(included as 0001-4). Note that there's a recent conflict as
of d4781d887; use an older commit as the base (or wait for the other
thread to be updated).
--Jacob
Attachments:
v4-0007-backend-add-OAUTHBEARER-SASL-mechanism.patch
From b3ceda62e9cc6cbbc24c63c05c5ce072ae771c1b Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 4 May 2021 16:21:11 -0700
Subject: [PATCH v4 07/10] backend: add OAUTHBEARER SASL mechanism
DO NOT USE THIS PROOF OF CONCEPT IN PRODUCTION.
Implement OAUTHBEARER (RFC 7628) on the server side. This adds a new
auth method, oauth, to pg_hba.
Because OAuth implementations vary so wildly, and bearer token
validation is heavily dependent on the issuing party, authn/z is done by
communicating with an external program: the oauth_validator_command.
This command must do the following:
1. Receive the bearer token by reading its contents from a file
descriptor passed from the server. (The numeric value of this
descriptor may be inserted into the oauth_validator_command using the
%f specifier.)
This MUST be the first action the command performs. The server will
not begin reading stdout from the command until the token has been
read in full, so if the command tries to print anything and hits a
buffer limit, the backend will deadlock and time out.
2. Validate the bearer token. The correct way to do this depends on the
issuer, but it generally involves either cryptographic operations to
prove that the token was issued by a trusted party, or the
presentation of the bearer token to some other party so that _it_ can
perform validation.
The command MUST maintain confidentiality of the bearer token, since
in most cases it can be used just like a password. (There are ways to
cryptographically bind tokens to client certificates, but they are
way beyond the scope of this commit message.)
If the token cannot be validated, the command must exit with a
non-zero status. Further authentication/authorization is pointless if
the bearer token wasn't issued by someone you trust.
3. Authenticate the user, authorize the user, or both:
a. To authenticate the user, use the bearer token to retrieve some
trusted identifier string for the end user. The exact process for
this is, again, issuer-dependent. The command should print the
authenticated identity string to stdout, followed by a newline.
If the user cannot be authenticated, the validator should not
print anything to stdout. It should also exit with a non-zero
status, unless the token may be used to authorize the connection
through some other means (see below).
On a success, the command may then exit with a zero success code.
By default, the server will then check to make sure the identity
string matches the role that is being used (or matches a usermap
entry, if one is in use).
b. To optionally authorize the user, in combination with the HBA
option trust_validator_authz=1 (see below), the validator simply
returns a zero exit code if the client should be allowed to
connect with its presented role (which can be passed to the
command using the %r specifier), or a non-zero code otherwise.
The hard part is in determining whether the given token truly
authorizes the client to use the given role, which must
unfortunately be left as an exercise to the reader.
This obviously requires some care, as a poorly implemented token
validator may silently open the entire database to anyone with a
bearer token. But it may be a more portable approach, since OAuth
is designed as an authorization framework, not an authentication
framework. For example, the user's bearer token could carry an
"allow_superuser_access" claim, which would authorize pseudonymous
database access as any role. It's then up to the OAuth system
administrators to ensure that allow_superuser_access is doled out
only to the proper users.
c. It's possible that the user can be successfully authenticated but
isn't authorized to connect. In this case, the command may print
the authenticated ID and then fail with a non-zero exit code.
(This makes it easier to see what's going on in the Postgres
logs.)
4. Token validators may optionally log to stderr. This will be printed
verbatim into the Postgres server logs.
The oauth method supports the following HBA options (but note that two
of them are not optional, since we have no way of choosing sensible
defaults):
issuer: Required. The URL of the OAuth issuing party, which the client
must contact to receive a bearer token.
Some real-world examples as of time of writing:
- https://accounts.google.com
- https://login.microsoft.com/[tenant-id]/v2.0
scope: Required. The OAuth scope(s) required for the server to
authenticate and/or authorize the user. This is heavily
deployment-specific, but a simple example is "openid email".
map: Optional. Specify a standard PostgreSQL user map; this works
the same as with other auth methods such as peer. If a map is
not specified, the user ID returned by the token validator
must exactly match the role that's being requested (but see
trust_validator_authz, below).
trust_validator_authz:
Optional. When set to 1, this allows the token validator to
take full control of the authorization process. Standard user
mapping is skipped: if the validator command succeeds, the
client is allowed to connect under its desired role and no
further checks are done.
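Putting these options together, a pair of illustrative pg_hba.conf lines (the database, role, and issuer values are examples only):

```
# Authentication via user map: the validator's identity must map to the role.
host  mydb  all  samehost  oauth  issuer="https://accounts.google.com" scope="openid email" map=oauth

# Pseudonymous authorization: the validator alone decides role access.
host  mydb  app  samehost  oauth  issuer="https://accounts.google.com" scope="openid email" trust_validator_authz=1
```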
Unlike the client, servers support OAuth without needing to be built
against libiddawc (since the responsibility for "speaking" OAuth/OIDC
correctly is delegated entirely to the oauth_validator_command).
Several TODOs:
- port to platforms other than "modern Linux"
- overhaul the communication with oauth_validator_command, which is
currently a bad hack on OpenPipeStream()
- implement more sanity checks on the OAUTHBEARER message format and
tokens sent by the client
- implement more helpful handling of HBA misconfigurations
- properly interpolate JSON when generating error responses
- use logdetail during auth failures
- deal with role names that can't be safely passed to system() without
shell-escaping
- allow passing the configured issuer to the oauth_validator_command, to
deal with multi-issuer setups
- ...and more.
---
src/backend/libpq/Makefile | 1 +
src/backend/libpq/auth-oauth.c | 797 +++++++++++++++++++++++++++++++++
src/backend/libpq/auth-sasl.c | 10 +-
src/backend/libpq/auth-scram.c | 4 +-
src/backend/libpq/auth.c | 7 +
src/backend/libpq/hba.c | 29 +-
src/backend/utils/misc/guc.c | 12 +
src/include/libpq/hba.h | 8 +-
src/include/libpq/oauth.h | 24 +
src/include/libpq/sasl.h | 11 +
10 files changed, 889 insertions(+), 14 deletions(-)
create mode 100644 src/backend/libpq/auth-oauth.c
create mode 100644 src/include/libpq/oauth.h
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 6d385fd6a4..98eb2a8242 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,6 +15,7 @@ include $(top_builddir)/src/Makefile.global
# be-fsstubs is here for historical reasons, probably belongs elsewhere
OBJS = \
+ auth-oauth.o \
auth-sasl.o \
auth-scram.o \
auth.o \
diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c
new file mode 100644
index 0000000000..c1232a31a0
--- /dev/null
+++ b/src/backend/libpq/auth-oauth.c
@@ -0,0 +1,797 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth-oauth.c
+ * Server-side implementation of the SASL OAUTHBEARER mechanism.
+ *
+ * See the following RFC for more details:
+ * - RFC 7628: https://tools.ietf.org/html/rfc7628
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/backend/libpq/auth-oauth.c
+ *
+ *-------------------------------------------------------------------------
+ */
+#include "postgres.h"
+
+#include <unistd.h>
+#include <fcntl.h>
+
+#include "common/oauth-common.h"
+#include "lib/stringinfo.h"
+#include "libpq/auth.h"
+#include "libpq/hba.h"
+#include "libpq/oauth.h"
+#include "libpq/sasl.h"
+#include "storage/fd.h"
+
+/* GUC */
+char *oauth_validator_command;
+
+static void oauth_get_mechanisms(Port *port, StringInfo buf);
+static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
+static int oauth_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, const char **logdetail);
+
+/* Mechanism declaration */
+const pg_be_sasl_mech pg_be_oauth_mech = {
+ oauth_get_mechanisms,
+ oauth_init,
+ oauth_exchange,
+
+ PG_MAX_AUTH_TOKEN_LENGTH,
+};
+
+
+typedef enum
+{
+ OAUTH_STATE_INIT = 0,
+ OAUTH_STATE_ERROR,
+ OAUTH_STATE_FINISHED,
+} oauth_state;
+
+struct oauth_ctx
+{
+ oauth_state state;
+ Port *port;
+ const char *issuer;
+ const char *scope;
+};
+
+static char *sanitize_char(char c);
+static char *parse_kvpairs_for_auth(char **input);
+static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
+static bool validate(Port *port, const char *auth, const char **logdetail);
+static bool run_validator_command(Port *port, const char *token);
+static bool check_exit(FILE **fh, const char *command);
+static bool unset_cloexec(int fd);
+static bool username_ok_for_shell(const char *username);
+
+#define KVSEP 0x01
+#define AUTH_KEY "auth"
+#define BEARER_SCHEME "Bearer "
+
+static void
+oauth_get_mechanisms(Port *port, StringInfo buf)
+{
+ /* Only OAUTHBEARER is supported. */
+ appendStringInfoString(buf, OAUTHBEARER_NAME);
+ appendStringInfoChar(buf, '\0');
+}
+
+static void *
+oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
+{
+ struct oauth_ctx *ctx;
+
+ if (strcmp(selected_mech, OAUTHBEARER_NAME))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("client selected an invalid SASL authentication mechanism")));
+
+ ctx = palloc0(sizeof(*ctx));
+
+ ctx->state = OAUTH_STATE_INIT;
+ ctx->port = port;
+
+ Assert(port->hba);
+ ctx->issuer = port->hba->oauth_issuer;
+ ctx->scope = port->hba->oauth_scope;
+
+ return ctx;
+}
+
+static int
+oauth_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, const char **logdetail)
+{
+ char *p;
+ char cbind_flag;
+ char *auth;
+
+ struct oauth_ctx *ctx = opaq;
+
+ *output = NULL;
+ *outputlen = -1;
+
+ /*
+ * If the client didn't include an "Initial Client Response" in the
+ * SASLInitialResponse message, send an empty challenge, to which the
+ * client will respond with the same data that usually comes in the
+ * Initial Client Response.
+ */
+ if (input == NULL)
+ {
+ Assert(ctx->state == OAUTH_STATE_INIT);
+
+ *output = pstrdup("");
+ *outputlen = 0;
+ return PG_SASL_EXCHANGE_CONTINUE;
+ }
+
+ /*
+ * Check that the input length agrees with the string length of the input.
+ */
+ if (inputlen == 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("The message is empty.")));
+ if (inputlen != strlen(input))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message length does not match input length.")));
+
+ switch (ctx->state)
+ {
+ case OAUTH_STATE_INIT:
+ /* Handle this case below. */
+ break;
+
+ case OAUTH_STATE_ERROR:
+ /*
+ * Only one response is valid for the client during authentication
+ * failure: a single kvsep.
+ */
+ if (inputlen != 1 || *input != KVSEP)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Client did not send a kvsep response.")));
+
+ /* The (failed) handshake is now complete. */
+ ctx->state = OAUTH_STATE_FINISHED;
+ return PG_SASL_EXCHANGE_FAILURE;
+
+ default:
+ elog(ERROR, "invalid OAUTHBEARER exchange state");
+ return PG_SASL_EXCHANGE_FAILURE;
+ }
+
+ /* Handle the client's initial message. */
+ p = pstrdup(input);
+
+ /*
+ * OAUTHBEARER does not currently define a channel binding (so there is no
+ * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a 'y'
+ * specifier purely for the remote chance that a future specification could
+ * define one; then future clients can still interoperate with this server
+ * implementation. 'n' is the expected case.
+ */
+ cbind_flag = *p;
+ switch (cbind_flag)
+ {
+ case 'p':
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data.")));
+ break;
+
+ case 'y': /* fall through */
+ case 'n':
+ p++;
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Comma expected, but found character %s.",
+ sanitize_char(*p))));
+ p++;
+ break;
+
+ default:
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Unexpected channel-binding flag %s.",
+ sanitize_char(cbind_flag))));
+ }
+
+ /*
+ * Forbid optional authzid (authorization identity). We don't support it.
+ */
+ if (*p == 'a')
+ ereport(ERROR,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("client uses authorization identity, but it is not supported")));
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Unexpected attribute %s in client-first-message.",
+ sanitize_char(*p))));
+ p++;
+
+ /* All remaining fields are separated by the RFC's kvsep (\x01). */
+ if (*p != KVSEP)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Key-value separator expected, but found character %s.",
+ sanitize_char(*p))));
+ p++;
+
+ auth = parse_kvpairs_for_auth(&p);
+ if (!auth)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message does not contain an auth value.")));
+
+ /* We should be at the end of our message. */
+ if (*p)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains additional data after the final terminator.")));
+
+ if (!validate(ctx->port, auth, logdetail))
+ {
+ generate_error_response(ctx, output, outputlen);
+
+ ctx->state = OAUTH_STATE_ERROR;
+ return PG_SASL_EXCHANGE_CONTINUE;
+ }
+
+ ctx->state = OAUTH_STATE_FINISHED;
+ return PG_SASL_EXCHANGE_SUCCESS;
+}
+
+/*
+ * Convert an arbitrary byte to printable form. For error messages.
+ *
+ * If it's a printable ASCII character, print it as a single character.
+ * Otherwise, print it in hex.
+ *
+ * The returned pointer points to a static buffer.
+ */
+static char *
+sanitize_char(char c)
+{
+ static char buf[5];
+
+ if (c >= 0x21 && c <= 0x7E)
+ snprintf(buf, sizeof(buf), "'%c'", c);
+ else
+ snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
+ return buf;
+}
+
+/*
+ * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is
+ * found, its value is returned.
+ */
+static char *
+parse_kvpairs_for_auth(char **input)
+{
+ char *pos = *input;
+ char *auth = NULL;
+
+ /*
+ * The relevant ABNF, from Sec. 3.1:
+ *
+ * kvsep = %x01
+ * key = 1*(ALPHA)
+ * value = *(VCHAR / SP / HTAB / CR / LF )
+ * kvpair = key "=" value kvsep
+ * ;;gs2-header = See RFC 5801
+ * client-resp = (gs2-header kvsep *kvpair kvsep) / kvsep
+ *
+ * By the time we reach this code, the gs2-header and initial kvsep have
+ * already been validated. We start at the beginning of the first kvpair.
+ */
+
+ while (*pos)
+ {
+ char *end;
+ char *sep;
+ char *key;
+ char *value;
+
+ /*
+ * Find the end of this kvpair. Note that input is null-terminated by
+ * the SASL code, so the strchr() is bounded.
+ */
+ end = strchr(pos, KVSEP);
+ if (!end)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains an unterminated key/value pair.")));
+ *end = '\0';
+
+ if (pos == end)
+ {
+ /* Empty kvpair, signifying the end of the list. */
+ *input = pos + 1;
+ return auth;
+ }
+
+ /*
+ * Find the end of the key name.
+ *
+ * TODO further validate the key/value grammar? empty keys, bad chars...
+ */
+ sep = strchr(pos, '=');
+ if (!sep)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains a key without a value.")));
+ *sep = '\0';
+
+ /* Both key and value are now safely terminated. */
+ key = pos;
+ value = sep + 1;
+
+ if (!strcmp(key, AUTH_KEY))
+ {
+ if (auth)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message contains multiple auth values.")));
+
+ auth = value;
+ }
+ else
+ {
+ /*
+ * The RFC also defines the host and port keys, but they are not
+ * required for OAUTHBEARER and we do not use them. Also, per
+ * Sec. 3.1, any key/value pairs we don't recognize must be ignored.
+ */
+ }
+
+ /* Move to the next pair. */
+ pos = end + 1;
+ }
+
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Message did not contain a final terminator.")));
+
+ return NULL; /* unreachable */
+}
+
+static void
+generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
+{
+ StringInfoData buf;
+
+ /*
+ * The admin needs to set an issuer and scope for OAuth to work. There's not
+ * really a way to hide this from the user, either, because we can't choose
+ * a "default" issuer, so be honest in the failure message.
+ *
+ * TODO: see if there's a better place to fail, earlier than this.
+ */
+ if (!ctx->issuer || !ctx->scope)
+ ereport(FATAL,
+ (errcode(ERRCODE_INTERNAL_ERROR),
+ errmsg("OAuth is not properly configured for this user"),
+ errdetail_log("The issuer and scope parameters must be set in pg_hba.conf.")));
+
+
+ initStringInfo(&buf);
+
+ /*
+ * TODO: JSON escaping
+ */
+ appendStringInfo(&buf,
+ "{ "
+ "\"status\": \"invalid_token\", "
+ "\"openid-configuration\": \"%s/.well-known/openid-configuration\","
+ "\"scope\": \"%s\" "
+ "}",
+ ctx->issuer, ctx->scope);
+
+ *output = buf.data;
+ *outputlen = buf.len;
+}
+
+static bool
+validate(Port *port, const char *auth, const char **logdetail)
+{
+ static const char * const b64_set = "abcdefghijklmnopqrstuvwxyz"
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "0123456789-._~+/";
+
+ const char *token;
+ size_t span;
+ int ret;
+
+ /* TODO: handle logdetail when the test framework can check it */
+
+ /*
+ * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec.
+ * 2.1:
+ *
+ * b64token = 1*( ALPHA / DIGIT /
+ * "-" / "." / "_" / "~" / "+" / "/" ) *"="
+ * credentials = "Bearer" 1*SP b64token
+ *
+ * The "credentials" construction is what we receive in our auth value.
+ *
+ * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization
+ * header format; RFC 7235 Sec. 2), the "Bearer" scheme string must be
+ * compared case-insensitively. (This is not mentioned in RFC 6750, but it's
+ * pointed out in RFC 7628 Sec. 4.)
+ *
+ * TODO: handle the Authorization spec, RFC 7235 Sec. 2.1.
+ */
+ if (strncasecmp(auth, BEARER_SCHEME, strlen(BEARER_SCHEME)))
+ return false;
+
+ /* Pull the bearer token out of the auth value. */
+ token = auth + strlen(BEARER_SCHEME);
+
+ /* Swallow any additional spaces. */
+ while (*token == ' ')
+ token++;
+
+ /*
+ * Before invoking the validator command, sanity-check the token format to
+ * avoid any injection attacks later in the chain. Invalid formats are
+ * technically a protocol violation, but don't reflect any information about
+ * the sensitive Bearer token back to the client; log at COMMERROR instead.
+ */
+
+ /* Tokens must not be empty. */
+ if (!*token)
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Bearer token is empty.")));
+ return false;
+ }
+
+ /*
+ * Make sure the token contains only allowed characters. Tokens may end with
+ * any number of '=' characters.
+ */
+ span = strspn(token, b64_set);
+ while (token[span] == '=')
+ span++;
+
+ if (token[span] != '\0')
+ {
+ /*
+ * This error message could be more helpful by printing the problematic
+ * character(s), but that'd be a bit like printing a piece of someone's
+ * password into the logs.
+ */
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed OAUTHBEARER message"),
+ errdetail("Bearer token is not in the correct format.")));
+ return false;
+ }
+
+ /* Have the validator check the token. */
+ if (!run_validator_command(port, token))
+ return false;
+
+ if (port->hba->oauth_skip_usermap)
+ {
+ /*
+ * If the validator is our authorization authority, we're done.
+ * Authentication may or may not have been performed depending on the
+ * validator implementation; all that matters is that the validator says
+ * the user can log in with the target role.
+ */
+ return true;
+ }
+
+ /* Make sure the validator authenticated the user. */
+ if (!port->authn_id)
+ {
+ /* TODO: use logdetail; reduce message duplication */
+ ereport(LOG,
+ (errmsg("OAuth bearer authentication failed for user \"%s\": validator provided no identity",
+ port->user_name)));
+ return false;
+ }
+
+ /* Finally, check the user map. */
+ ret = check_usermap(port->hba->usermap, port->user_name, port->authn_id,
+ false);
+ return (ret == STATUS_OK);
+}
+
+static bool
+run_validator_command(Port *port, const char *token)
+{
+ bool success = false;
+ int rc;
+ int pipefd[2];
+ int rfd = -1;
+ int wfd = -1;
+
+ StringInfoData command = { 0 };
+ char *p;
+ FILE *fh = NULL;
+
+ ssize_t written;
+ char *line = NULL;
+ size_t size = 0;
+ ssize_t len;
+
+ Assert(oauth_validator_command);
+
+ if (!oauth_validator_command[0])
+ {
+ ereport(COMMERROR,
+ (errmsg("oauth_validator_command is not set"),
+ errhint("To allow OAuth authenticated connections, set "
+ "oauth_validator_command in postgresql.conf.")));
+ return false;
+ }
+
+ /*
+ * Since popen() is unidirectional, open up a pipe for the other direction.
+ * Use CLOEXEC to ensure that our write end doesn't accidentally get copied
+ * into child processes, which would prevent us from closing it cleanly.
+ *
+ * XXX this is ugly. We should just read from the child process's stdout,
+ * but that's a lot more code.
+ * XXX by bypassing the popen API, we open the potential of process
+ * deadlock. Clearly document child process requirements (i.e. the child
+ * MUST read all data off of the pipe before writing anything).
+ * TODO: port to Windows using _pipe().
+ */
+ rc = pipe2(pipefd, O_CLOEXEC);
+ if (rc < 0)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not create child pipe: %m")));
+ return false;
+ }
+
+ rfd = pipefd[0];
+ wfd = pipefd[1];
+
+ /* Allow the read pipe to be passed to the child. */
+ if (!unset_cloexec(rfd))
+ {
+ /* error message was already logged */
+ goto cleanup;
+ }
+
+ /*
+ * Construct the command, substituting any recognized %-specifiers:
+ *
+ * %f: the file descriptor of the input pipe
+ * %r: the role that the client wants to assume (port->user_name)
+ * %%: a literal '%'
+ */
+ initStringInfo(&command);
+
+ for (p = oauth_validator_command; *p; p++)
+ {
+ if (p[0] == '%')
+ {
+ switch (p[1])
+ {
+ case 'f':
+ appendStringInfo(&command, "%d", rfd);
+ p++;
+ break;
+ case 'r':
+ /*
+ * TODO: decide how this string should be escaped. The role
+ * is controlled by the client, so if we don't escape it,
+ * command injections are inevitable.
+ *
+ * This is probably an indication that the role name needs
+ * to be communicated to the validator process in some other
+ * way. For this proof of concept, just be incredibly strict
+ * about the characters that are allowed in user names.
+ */
+ if (!username_ok_for_shell(port->user_name))
+ goto cleanup;
+
+ appendStringInfoString(&command, port->user_name);
+ p++;
+ break;
+ case '%':
+ appendStringInfoChar(&command, '%');
+ p++;
+ break;
+ default:
+ appendStringInfoChar(&command, p[0]);
+ }
+ }
+ else
+ appendStringInfoChar(&command, p[0]);
+ }
+
+ /* Execute the command. */
+ fh = OpenPipeStream(command.data, "re");
+ /* TODO: handle failures */
+
+ /* We don't need the read end of the pipe anymore. */
+ close(rfd);
+ rfd = -1;
+
+ /* Give the command the token to validate. */
+ written = write(wfd, token, strlen(token));
+ if (written != strlen(token))
+ {
+ /* TODO must loop for short writes, EINTR et al */
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not write token to child pipe: %m")));
+ goto cleanup;
+ }
+
+ close(wfd);
+ wfd = -1;
+
+ /*
+ * Read the command's response.
+ *
+ * TODO: getline() is probably too new to use, unfortunately.
+ * TODO: loop over all lines
+ */
+ if ((len = getline(&line, &size, fh)) >= 0)
+ {
+ /* TODO: fail if the authn_id doesn't end with a newline */
+ if (len > 0)
+ line[len - 1] = '\0';
+
+ set_authn_id(port, line);
+ }
+ else if (ferror(fh))
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not read from command \"%s\": %m",
+ command.data)));
+ goto cleanup;
+ }
+
+ /* Make sure the command exits cleanly. */
+ if (!check_exit(&fh, command.data))
+ {
+ /* error message already logged */
+ goto cleanup;
+ }
+
+ /* Done. */
+ success = true;
+
+cleanup:
+ if (line)
+ free(line);
+
+ /*
+ * In the successful case, the pipe fds are already closed. For the error
+ * case, always close out the pipe before waiting for the command, to
+ * prevent deadlock.
+ */
+ if (rfd >= 0)
+ close(rfd);
+ if (wfd >= 0)
+ close(wfd);
+
+ if (fh)
+ {
+ Assert(!success);
+ check_exit(&fh, command.data);
+ }
+
+ if (command.data)
+ pfree(command.data);
+
+ return success;
+}
+
+static bool
+check_exit(FILE **fh, const char *command)
+{
+ int rc;
+
+ rc = ClosePipeStream(*fh);
+ *fh = NULL;
+
+ if (rc == -1)
+ {
+ /* pclose() itself failed. */
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not close pipe to command \"%s\": %m",
+ command)));
+ }
+ else if (rc != 0)
+ {
+ char *reason = wait_result_to_str(rc);
+
+ ereport(COMMERROR,
+ (errmsg("failed to execute command \"%s\": %s",
+ command, reason)));
+
+ pfree(reason);
+ }
+
+ return (rc == 0);
+}
+
+static bool
+unset_cloexec(int fd)
+{
+ int flags;
+ int rc;
+
+ flags = fcntl(fd, F_GETFD);
+ if (flags == -1)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not get fd flags for child pipe: %m")));
+ return false;
+ }
+
+ rc = fcntl(fd, F_SETFD, flags & ~FD_CLOEXEC);
+ if (rc < 0)
+ {
+ ereport(COMMERROR,
+ (errcode_for_file_access(),
+ errmsg("could not unset FD_CLOEXEC for child pipe: %m")));
+ return false;
+ }
+
+ return true;
+}
+
+/*
+ * XXX This should go away eventually and be replaced with either a proper
+ * escape or a different strategy for communication with the validator command.
+ */
+static bool
+username_ok_for_shell(const char *username)
+{
+ /* This set is borrowed from fe_utils' appendShellStringNoError(). */
+ static const char * const allowed = "abcdefghijklmnopqrstuvwxyz"
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "0123456789-_./:";
+ size_t span;
+
+ Assert(username && username[0]); /* should have already been checked */
+
+ span = strspn(username, allowed);
+ if (username[span] != '\0')
+ {
+ ereport(COMMERROR,
+ (errmsg("PostgreSQL user name contains unsafe characters and cannot be passed to the OAuth validator")));
+ return false;
+ }
+
+ return true;
+}
diff --git a/src/backend/libpq/auth-sasl.c b/src/backend/libpq/auth-sasl.c
index a1d7dbb6d5..0f461a6696 100644
--- a/src/backend/libpq/auth-sasl.c
+++ b/src/backend/libpq/auth-sasl.c
@@ -20,14 +20,6 @@
#include "libpq/pqformat.h"
#include "libpq/sasl.h"
-/*
- * Maximum accepted size of SASL messages.
- *
- * The messages that the server or libpq generate are much smaller than this,
- * but have some headroom.
- */
-#define PG_MAX_SASL_MESSAGE_LENGTH 1024
-
/*
* Perform a SASL exchange with a libpq client, using a specific mechanism
* implementation.
@@ -103,7 +95,7 @@ CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
/* Get the actual SASL message */
initStringInfo(&buf);
- if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
+ if (pq_getmessage(&buf, mech->max_message_length))
{
/* EOF - pq_getmessage already logged error */
pfree(buf.data);
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index ee7f52218a..4049ace470 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -118,7 +118,9 @@ static int scram_exchange(void *opaq, const char *input, int inputlen,
const pg_be_sasl_mech pg_be_scram_mech = {
scram_get_mechanisms,
scram_init,
- scram_exchange
+ scram_exchange,
+
+ PG_MAX_SASL_MESSAGE_LENGTH
};
/*
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 4a8a63922a..17042d84ad 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -30,6 +30,7 @@
#include "libpq/auth.h"
#include "libpq/crypt.h"
#include "libpq/libpq.h"
+#include "libpq/oauth.h"
#include "libpq/pqformat.h"
#include "libpq/sasl.h"
#include "libpq/scram.h"
@@ -298,6 +299,9 @@ auth_failed(Port *port, int status, const char *logdetail)
case uaRADIUS:
errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
break;
+ case uaOAuth:
+ errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
+ break;
case uaCustom:
{
CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
@@ -626,6 +630,9 @@ ClientAuthentication(Port *port)
case uaTrust:
status = STATUS_OK;
break;
+ case uaOAuth:
+ status = CheckSASLAuth(&pg_be_oauth_mech, port, NULL, NULL);
+ break;
case uaCustom:
{
CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 42cb1ce51d..cd3b1cc140 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -136,7 +136,8 @@ static const char *const UserAuthName[] =
"cert",
"radius",
"custom",
- "peer"
+ "peer",
+ "oauth",
};
@@ -1401,6 +1402,8 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
#endif
else if (strcmp(token->string, "radius") == 0)
parsedline->auth_method = uaRADIUS;
+ else if (strcmp(token->string, "oauth") == 0)
+ parsedline->auth_method = uaOAuth;
else if (strcmp(token->string, "custom") == 0)
parsedline->auth_method = uaCustom;
else
@@ -1730,8 +1733,9 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->auth_method != uaGSS &&
hbaline->auth_method != uaSSPI &&
hbaline->auth_method != uaCert &&
+ hbaline->auth_method != uaOAuth &&
hbaline->auth_method != uaCustom)
- INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert and custom"));
+ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, oauth, and custom"));
hbaline->usermap = pstrdup(val);
}
else if (strcmp(name, "clientcert") == 0)
@@ -2115,6 +2119,27 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->radiusidentifiers = parsed_identifiers;
hbaline->radiusidentifiers_s = pstrdup(val);
}
+ else if (strcmp(name, "issuer") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("issuer", gettext_noop("oauth"));
+ hbaline->oauth_issuer = pstrdup(val);
+ }
+ else if (strcmp(name, "scope") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("scope", gettext_noop("oauth"));
+ hbaline->oauth_scope = pstrdup(val);
+ }
+ else if (strcmp(name, "trust_validator_authz") == 0)
+ {
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("trust_validator_authz", gettext_noop("oauth"));
+ if (strcmp(val, "1") == 0)
+ hbaline->oauth_skip_usermap = true;
+ else
+ hbaline->oauth_skip_usermap = false;
+ }
else if (strcmp(name, "provider") == 0)
{
REQUIRE_AUTH_OPTION(uaCustom, "provider", "custom");
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index f70f7f5c01..9a5b2aa496 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -59,6 +59,7 @@
#include "libpq/auth.h"
#include "libpq/libpq.h"
#include "libpq/pqformat.h"
+#include "libpq/oauth.h"
#include "miscadmin.h"
#include "optimizer/cost.h"
#include "optimizer/geqo.h"
@@ -4666,6 +4667,17 @@ static struct config_string ConfigureNamesString[] =
check_backtrace_functions, assign_backtrace_functions, NULL
},
+ {
+ {"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH,
+ gettext_noop("Command to validate OAuth v2 bearer tokens."),
+ NULL,
+ GUC_SUPERUSER_ONLY
+ },
+ &oauth_validator_command,
+ "",
+ NULL, NULL, NULL
+ },
+
/* End-of-list marker */
{
{NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 31a00c4b71..e405103a2e 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -39,8 +39,9 @@ typedef enum UserAuth
uaCert,
uaRADIUS,
uaCustom,
- uaPeer
-#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */
+ uaPeer,
+ uaOAuth
+#define USER_AUTH_LAST uaOAuth /* Must be last value of this enum */
} UserAuth;
/*
@@ -128,6 +129,9 @@ typedef struct HbaLine
char *radiusidentifiers_s;
List *radiusports;
char *radiusports_s;
+ char *oauth_issuer;
+ char *oauth_scope;
+ bool oauth_skip_usermap;
char *custom_provider;
List *custom_auth_options;
} HbaLine;
diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h
new file mode 100644
index 0000000000..870e426af1
--- /dev/null
+++ b/src/include/libpq/oauth.h
@@ -0,0 +1,24 @@
+/*-------------------------------------------------------------------------
+ *
+ * oauth.h
+ * Interface to libpq/auth-oauth.c
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/libpq/oauth.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef PG_OAUTH_H
+#define PG_OAUTH_H
+
+#include "libpq/libpq-be.h"
+#include "libpq/sasl.h"
+
+extern char *oauth_validator_command;
+
+/* Implementation */
+extern const pg_be_sasl_mech pg_be_oauth_mech;
+
+#endif /* PG_OAUTH_H */
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
index 39ccf8f0e3..f7d905591a 100644
--- a/src/include/libpq/sasl.h
+++ b/src/include/libpq/sasl.h
@@ -26,6 +26,14 @@
#define PG_SASL_EXCHANGE_SUCCESS 1
#define PG_SASL_EXCHANGE_FAILURE 2
+/*
+ * Maximum accepted size of SASL messages.
+ *
+ * The messages that the server or libpq generate are much smaller than this,
+ * but have some headroom.
+ */
+#define PG_MAX_SASL_MESSAGE_LENGTH 1024
+
/*
* Backend SASL mechanism callbacks.
*
@@ -127,6 +135,9 @@ typedef struct pg_be_sasl_mech
const char *input, int inputlen,
char **output, int *outputlen,
const char **logdetail);
+
+ /* The maximum size allowed for client SASLResponses. */
+ int max_message_length;
} pg_be_sasl_mech;
/* Common implementation for auth.c */
--
2.25.1
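For anyone wanting to experiment with the oauth_validator_command contract in the patch above (the token arrives on the file descriptor substituted for %f, the role for %r, the validator prints the authenticated identity on stdout, and a nonzero exit means failure), here's a toy validator sketch. This is not part of the patch: the token-to-identity mapping is a stand-in, and a real validator would introspect the token with the issuer. Note the comment in run_validator_command(): the child must read the entire token off the pipe before writing anything back, or both processes can deadlock.

```python
import os
import string
import sys

# Characters allowed in an RFC 6750 b64token (trailing '=' padding aside).
B64TOKEN_CHARS = set(string.ascii_letters + string.digits + "-._~+/")

def is_b64token(token):
    """Check a bearer token against the RFC 6750 b64token grammar."""
    if not token:
        return False
    body = token.rstrip("=")
    return len(body) > 0 and all(c in B64TOKEN_CHARS for c in body)

def validate(token):
    """Map a token to an authenticated identity. Stand-in logic only."""
    if not is_b64token(token):
        return None
    # Hypothetical static mapping; a real validator would call the issuer's
    # token introspection or verify a JWT signature here.
    known_tokens = {"abcd1234": "alice@example.org"}
    return known_tokens.get(token)

if __name__ == "__main__" and len(sys.argv) > 1:
    # The server substitutes the read end of its pipe for %f. Read the
    # whole token first -- the server is blocked writing it.
    fd = int(sys.argv[1])
    with os.fdopen(fd, "r") as pipe:
        token = pipe.read()
    authn_id = validate(token)
    if authn_id is None:
        sys.exit(1)     # nonzero exit -> authentication failure
    print(authn_id)     # stdout line becomes port->authn_id
```

A matching (hypothetical) postgresql.conf line might look like `oauth_validator_command = '/usr/local/bin/validator.py %f'`.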
Attachment: v4-0008-Add-a-very-simple-authn_id-extension.patch (text/x-patch)
From 9dd8e024fde29239829e822b8f2b82028044cd8b Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 18 May 2021 15:01:29 -0700
Subject: [PATCH v4 08/10] Add a very simple authn_id extension
...for retrieving the authn_id from the server in tests.
---
contrib/authn_id/Makefile | 19 +++++++++++++++++++
contrib/authn_id/authn_id--1.0.sql | 8 ++++++++
contrib/authn_id/authn_id.c | 28 ++++++++++++++++++++++++++++
contrib/authn_id/authn_id.control | 5 +++++
4 files changed, 60 insertions(+)
create mode 100644 contrib/authn_id/Makefile
create mode 100644 contrib/authn_id/authn_id--1.0.sql
create mode 100644 contrib/authn_id/authn_id.c
create mode 100644 contrib/authn_id/authn_id.control
diff --git a/contrib/authn_id/Makefile b/contrib/authn_id/Makefile
new file mode 100644
index 0000000000..46026358e0
--- /dev/null
+++ b/contrib/authn_id/Makefile
@@ -0,0 +1,19 @@
+# contrib/authn_id/Makefile
+
+MODULE_big = authn_id
+OBJS = authn_id.o
+
+EXTENSION = authn_id
+DATA = authn_id--1.0.sql
+PGFILEDESC = "authn_id - information about the authenticated user"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = contrib/authn_id
+top_builddir = ../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/contrib/authn_id/authn_id--1.0.sql b/contrib/authn_id/authn_id--1.0.sql
new file mode 100644
index 0000000000..af2a4d3991
--- /dev/null
+++ b/contrib/authn_id/authn_id--1.0.sql
@@ -0,0 +1,8 @@
+/* contrib/authn_id/authn_id--1.0.sql */
+
+-- complain if script is sourced in psql, rather than via CREATE EXTENSION
+\echo Use "CREATE EXTENSION authn_id" to load this file. \quit
+
+CREATE FUNCTION authn_id() RETURNS text
+AS 'MODULE_PATHNAME', 'authn_id'
+LANGUAGE C IMMUTABLE;
diff --git a/contrib/authn_id/authn_id.c b/contrib/authn_id/authn_id.c
new file mode 100644
index 0000000000..0fecac36a8
--- /dev/null
+++ b/contrib/authn_id/authn_id.c
@@ -0,0 +1,28 @@
+/*
+ * Extension to expose the current user's authn_id.
+ *
+ * contrib/authn_id/authn_id.c
+ */
+
+#include "postgres.h"
+
+#include "fmgr.h"
+#include "libpq/libpq-be.h"
+#include "miscadmin.h"
+#include "utils/builtins.h"
+
+PG_MODULE_MAGIC;
+
+PG_FUNCTION_INFO_V1(authn_id);
+
+/*
+ * Returns the current user's authenticated identity.
+ */
+Datum
+authn_id(PG_FUNCTION_ARGS)
+{
+ if (!MyProcPort->authn_id)
+ PG_RETURN_NULL();
+
+ PG_RETURN_TEXT_P(cstring_to_text(MyProcPort->authn_id));
+}
diff --git a/contrib/authn_id/authn_id.control b/contrib/authn_id/authn_id.control
new file mode 100644
index 0000000000..e0f9e06bed
--- /dev/null
+++ b/contrib/authn_id/authn_id.control
@@ -0,0 +1,5 @@
+# authn_id extension
+comment = 'current user identity'
+default_version = '1.0'
+module_pathname = '$libdir/authn_id'
+relocatable = true
--
2.25.1
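To make the wire format handled by oauth_exchange() and parse_kvpairs_for_auth() concrete, here is a sketch of building and parsing the OAUTHBEARER initial client response from RFC 7628 Sec. 3.1: the gs2-header, the \x01 kvsep-delimited key/value pairs, and the "auth" value carrying the Bearer credentials. The token string is a placeholder; this mirrors the server-side parsing logic but is not code from the patch.

```python
KVSEP = "\x01"

def build_client_resp(token, host=None):
    """Assemble an OAUTHBEARER initial response (RFC 7628 Sec. 3.1)."""
    gs2_header = "n,,"  # no channel binding, no authzid
    kvpairs = ""
    if host:
        kvpairs += "host=%s%s" % (host, KVSEP)
    kvpairs += "auth=Bearer %s%s" % (token, KVSEP)
    # client-resp = gs2-header kvsep *kvpair kvsep
    return gs2_header + KVSEP + kvpairs + KVSEP

def parse_auth_value(resp):
    """Extract the auth value, mirroring parse_kvpairs_for_auth()."""
    header, _, rest = resp.partition(KVSEP)
    if header not in ("n,,", "y,,"):  # 'p' / authzid not supported
        raise ValueError("unsupported gs2-header")
    auth = None
    while rest:
        pair, _, rest = rest.partition(KVSEP)
        if pair == "":
            return auth  # empty kvpair signifies the end of the list
        key, sep, value = pair.partition("=")
        if not sep:
            raise ValueError("key without a value")
        if key == "auth":
            if auth is not None:
                raise ValueError("multiple auth values")
            auth = value
        # any other keys (host, port, ...) are ignored, per Sec. 3.1
    raise ValueError("no final terminator")
```

Running the parser against a message built with a placeholder token reproduces the round trip the server performs before handing the token to the validator.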
Attachment: v4-0001-Add-support-for-custom-authentication-methods.patch (text/x-patch)
From 575431b4e035c266b55a25414f802fbf8ba16b97 Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Tue, 15 Feb 2022 22:23:29 -0800
Subject: [PATCH v4 01/10] Add support for custom authentication methods
Currently, PostgreSQL supports only a set of pre-defined authentication
methods. This patch adds support for 2 hooks which allow users to add
their custom authentication methods by defining a check function and an
error function. Users can then use these methods by using a new "custom"
keyword in pg_hba.conf and specifying the authentication provider they
want to use.
---
src/backend/libpq/auth.c | 108 ++++++++++++++++++++++++++++++++-------
src/backend/libpq/hba.c | 44 ++++++++++++++++
src/include/libpq/auth.h | 37 ++++++++++++++
src/include/libpq/hba.h | 2 +
4 files changed, 172 insertions(+), 19 deletions(-)
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index efc53f3135..375ee33892 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -47,8 +47,6 @@
*----------------------------------------------------------------
*/
static void auth_failed(Port *port, int status, const char *logdetail);
-static char *recv_password_packet(Port *port);
-static void set_authn_id(Port *port, const char *id);
/*----------------------------------------------------------------
@@ -206,22 +204,11 @@ static int pg_SSPI_make_upn(char *accountname,
static int CheckRADIUSAuth(Port *port);
static int PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd);
-
-/*
- * Maximum accepted size of GSS and SSPI authentication tokens.
- * We also use this as a limit on ordinary password packet lengths.
- *
- * Kerberos tickets are usually quite small, but the TGTs issued by Windows
- * domain controllers include an authorization field known as the Privilege
- * Attribute Certificate (PAC), which contains the user's Windows permissions
- * (group memberships etc.). The PAC is copied into all tickets obtained on
- * the basis of this TGT (even those issued by Unix realms which the Windows
- * realm trusts), and can be several kB in size. The maximum token size
- * accepted by Windows systems is determined by the MaxAuthToken Windows
- * registry setting. Microsoft recommends that it is not set higher than
- * 65535 bytes, so that seems like a reasonable limit for us as well.
+/*----------------------------------------------------------------
+ * Custom Authentication
+ *----------------------------------------------------------------
*/
-#define PG_MAX_AUTH_TOKEN_LENGTH 65535
+static List *custom_auth_providers = NIL;
/*----------------------------------------------------------------
* Global authentication functions
@@ -311,6 +298,15 @@ auth_failed(Port *port, int status, const char *logdetail)
case uaRADIUS:
errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
break;
+ case uaCustom:
+ {
+ CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
+ if (provider->auth_error_hook)
+ errstr = provider->auth_error_hook(port);
+ else
+ errstr = gettext_noop("Custom authentication failed for user \"%s\"");
+ break;
+ }
default:
errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method");
break;
@@ -345,7 +341,7 @@ auth_failed(Port *port, int status, const char *logdetail)
* lifetime of the Port, so it is safe to pass a string that is managed by an
* external library.
*/
-static void
+void
set_authn_id(Port *port, const char *id)
{
Assert(id);
@@ -630,6 +626,13 @@ ClientAuthentication(Port *port)
case uaTrust:
status = STATUS_OK;
break;
+ case uaCustom:
+ {
+ CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
+ if (provider->auth_check_hook)
+ status = provider->auth_check_hook(port);
+ break;
+ }
}
if ((status == STATUS_OK && port->hba->clientcert == clientCertFull)
@@ -689,7 +692,7 @@ sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extrale
*
* Returns NULL if couldn't get password, else palloc'd string.
*/
-static char *
+char *
recv_password_packet(Port *port)
{
StringInfoData buf;
@@ -3343,3 +3346,70 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
}
} /* while (true) */
}
+
+/*----------------------------------------------------------------
+ * Custom authentication
+ *----------------------------------------------------------------
+ */
+
+/*
+ * RegisterAuthProvider registers a custom authentication provider to be
+ * used for authentication. It validates the inputs and adds the provider
+ name and its hooks to a list of loaded providers. The right provider's
+ * hooks can then be called based on the provider name specified in
+ * pg_hba.conf.
+ *
+ * This function should be called in _PG_init() by any extension looking to
+ * add a custom authentication method.
+ */
+void RegisterAuthProvider(const char *provider_name,
+ CustomAuthenticationCheck_hook_type AuthenticationCheckFunction,
+ CustomAuthenticationError_hook_type AuthenticationErrorFunction)
+{
+ CustomAuthProvider *provider = NULL;
+ MemoryContext old_context;
+
+ if (provider_name == NULL)
+ {
+ ereport(ERROR,
+ (errmsg("cannot register authentication provider without name")));
+ }
+
+ if (AuthenticationCheckFunction == NULL)
+ {
+ ereport(ERROR,
+ (errmsg("cannot register authentication provider without a check function")));
+ }
+
+ /*
+ * Allocate in top memory context as we need to read this whenever
+ * we parse pg_hba.conf
+ */
+ old_context = MemoryContextSwitchTo(TopMemoryContext);
+ provider = palloc(sizeof(CustomAuthProvider));
+ provider->name = MemoryContextStrdup(TopMemoryContext,provider_name);
+ provider->auth_check_hook = AuthenticationCheckFunction;
+ provider->auth_error_hook = AuthenticationErrorFunction;
+ custom_auth_providers = lappend(custom_auth_providers, provider);
+ MemoryContextSwitchTo(old_context);
+}
+
+/*
+ * Returns the authentication provider (which includes its
+ * callback functions) based on the name specified.
+ */
+CustomAuthProvider *get_provider_by_name(const char *name)
+{
+ ListCell *lc;
+
+ foreach(lc, custom_auth_providers)
+ {
+ CustomAuthProvider *provider = (CustomAuthProvider *) lfirst(lc);
+ if (strcmp(provider->name,name) == 0)
+ {
+ return provider;
+ }
+ }
+
+ return NULL;
+}
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 90953c38f3..9f15252789 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -31,6 +31,7 @@
#include "common/ip.h"
#include "common/string.h"
#include "funcapi.h"
+#include "libpq/auth.h"
#include "libpq/ifaddr.h"
#include "libpq/libpq.h"
#include "miscadmin.h"
@@ -134,6 +135,7 @@ static const char *const UserAuthName[] =
"ldap",
"cert",
"radius",
+ "custom",
"peer"
};
@@ -1399,6 +1401,8 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
#endif
else if (strcmp(token->string, "radius") == 0)
parsedline->auth_method = uaRADIUS;
+ else if (strcmp(token->string, "custom") == 0)
+ parsedline->auth_method = uaCustom;
else
{
ereport(elevel,
@@ -1691,6 +1695,14 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
parsedline->clientcert = clientCertFull;
}
+ /*
+ * Ensure that the provider name is specified for custom authentication method.
+ */
+ if (parsedline->auth_method == uaCustom)
+ {
+ MANDATORY_AUTH_ARG(parsedline->custom_provider, "provider", "custom");
+ }
+
return parsedline;
}
@@ -2102,6 +2114,31 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->radiusidentifiers = parsed_identifiers;
hbaline->radiusidentifiers_s = pstrdup(val);
}
+ else if (strcmp(name, "provider") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaCustom, "provider", "custom");
+
+ /*
+ * Verify that the provider mentioned is loaded via shared_preload_libraries.
+ */
+
+ if (get_provider_by_name(val) == NULL)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("cannot use authentication provider %s",val),
+ errhint("Load authentication provider via shared_preload_libraries."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("cannot use authentication provider %s", val);
+
+ return false;
+ }
+ else
+ {
+ hbaline->custom_provider = pstrdup(val);
+ }
+ }
else
{
ereport(elevel,
@@ -2442,6 +2479,13 @@ gethba_options(HbaLine *hba)
CStringGetTextDatum(psprintf("radiusports=%s", hba->radiusports_s));
}
+ if (hba->auth_method == uaCustom)
+ {
+ if (hba->custom_provider)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("provider=%s",hba->custom_provider));
+ }
+
/* If you add more options, consider increasing MAX_HBA_OPTIONS. */
Assert(noptions <= MAX_HBA_OPTIONS);
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 6d7ee1acb9..7aff98d919 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -23,9 +23,46 @@ extern char *pg_krb_realm;
extern void ClientAuthentication(Port *port);
extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
int extralen);
+extern void set_authn_id(Port *port, const char *id);
+extern char *recv_password_packet(Port *port);
/* Hook for plugins to get control in ClientAuthentication() */
+typedef int (*CustomAuthenticationCheck_hook_type) (Port *);
typedef void (*ClientAuthentication_hook_type) (Port *, int);
extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook;
+/* Hook for plugins to report error messages in auth_failed() */
+typedef const char * (*CustomAuthenticationError_hook_type) (Port *);
+
+extern void RegisterAuthProvider
+ (const char *provider_name,
+ CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook,
+ CustomAuthenticationError_hook_type CustomAuthenticationError_hook);
+
+/* Declarations for custom authentication providers */
+typedef struct CustomAuthProvider
+{
+ const char *name;
+ CustomAuthenticationCheck_hook_type auth_check_hook;
+ CustomAuthenticationError_hook_type auth_error_hook;
+} CustomAuthProvider;
+
+extern CustomAuthProvider *get_provider_by_name(const char *name);
+
+/*
+ * Maximum accepted size of GSS and SSPI authentication tokens.
+ * We also use this as a limit on ordinary password packet lengths.
+ *
+ * Kerberos tickets are usually quite small, but the TGTs issued by Windows
+ * domain controllers include an authorization field known as the Privilege
+ * Attribute Certificate (PAC), which contains the user's Windows permissions
+ * (group memberships etc.). The PAC is copied into all tickets obtained on
+ * the basis of this TGT (even those issued by Unix realms which the Windows
+ * realm trusts), and can be several kB in size. The maximum token size
+ * accepted by Windows systems is determined by the MaxAuthToken Windows
+ * registry setting. Microsoft recommends that it is not set higher than
+ * 65535 bytes, so that seems like a reasonable limit for us as well.
+ */
+#define PG_MAX_AUTH_TOKEN_LENGTH 65535
+
#endif /* AUTH_H */
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 8d9f3821b1..48490c44ed 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -38,6 +38,7 @@ typedef enum UserAuth
uaLDAP,
uaCert,
uaRADIUS,
+ uaCustom,
uaPeer
#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */
} UserAuth;
@@ -120,6 +121,7 @@ typedef struct HbaLine
char *radiusidentifiers_s;
List *radiusports;
char *radiusports_s;
+ char *custom_provider;
} HbaLine;
typedef struct IdentLine
--
2.25.1
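To make the hba changes above concrete: once an extension that calls RegisterAuthProvider under the name "test" is listed in shared_preload_libraries, selecting it from pg_hba.conf might look like the sketch below. This is only an illustration of the keywords the patch adds ("custom" method, "provider" option); adjust the connection type and database/user fields as needed.

```
# postgresql.conf
shared_preload_libraries = 'test_auth_provider'

# pg_hba.conf
# TYPE  DATABASE  USER  METHOD  OPTIONS
local   all       all   custom  provider=test
```

A provider name that was not registered at startup is rejected when pg_hba.conf is (re)loaded, per the check in parse_hba_auth_opt().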
Attachment: v4-0002-Add-sample-extension-to-test-custom-auth-provider.patch (text/x-patch)
From cb4131d8424861826e443708bc0f4e6baa76c871 Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Tue, 15 Feb 2022 22:28:40 -0800
Subject: [PATCH v4 02/10] Add sample extension to test custom auth provider
hooks
This change adds a new extension to src/test/modules to
test the custom authentication provider hooks. The extension
uses an in-memory array to define which users to authenticate
and which passwords to accept. The client's SASL response is
then verified against the SCRAM secret built from the
password stored in the array.
---
src/include/libpq/scram.h | 2 +-
src/test/modules/test_auth_provider/Makefile | 16 ++++
.../test_auth_provider/test_auth_provider.c | 86 +++++++++++++++++++
3 files changed, 103 insertions(+), 1 deletion(-)
create mode 100644 src/test/modules/test_auth_provider/Makefile
create mode 100644 src/test/modules/test_auth_provider/test_auth_provider.c
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index e60992a0d2..c51e848c24 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -18,7 +18,7 @@
#include "libpq/sasl.h"
/* SASL implementation callbacks */
-extern const pg_be_sasl_mech pg_be_scram_mech;
+extern PGDLLIMPORT const pg_be_sasl_mech pg_be_scram_mech;
/* Routines to handle and check SCRAM-SHA-256 secret */
extern char *pg_be_scram_build_secret(const char *password);
diff --git a/src/test/modules/test_auth_provider/Makefile b/src/test/modules/test_auth_provider/Makefile
new file mode 100644
index 0000000000..17971a5c7a
--- /dev/null
+++ b/src/test/modules/test_auth_provider/Makefile
@@ -0,0 +1,16 @@
+# src/test/modules/test_auth_provider/Makefile
+
+MODULE_big = test_auth_provider
+OBJS = test_auth_provider.o
+PGFILEDESC = "test_auth_provider - provider to test auth hooks"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = src/test/modules/test_auth_provider
+top_builddir = ../../../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/src/test/modules/test_auth_provider/test_auth_provider.c b/src/test/modules/test_auth_provider/test_auth_provider.c
new file mode 100644
index 0000000000..7c4b1f3500
--- /dev/null
+++ b/src/test/modules/test_auth_provider/test_auth_provider.c
@@ -0,0 +1,86 @@
+/* -------------------------------------------------------------------------
+ *
+ * test_auth_provider.c
+ * example authentication provider plugin
+ *
+ * Copyright (c) 2022, PostgreSQL Global Development Group
+ *
+ * IDENTIFICATION
+ * contrib/test_auth_provider/test_auth_provider.c
+ *
+ * -------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+#include "fmgr.h"
+#include "libpq/auth.h"
+#include "libpq/libpq.h"
+#include "libpq/scram.h"
+
+PG_MODULE_MAGIC;
+
+void _PG_init(void);
+
+static char *get_encrypted_password_for_user(char *user_name);
+
+/*
+ * List of usernames / passwords to approve. Here we are not
+ * getting passwords from Postgres but from this list. In a real-world
+ * extension, you could fetch valid credentials and authentication tokens /
+ * passwords from an external authentication provider.
+ */
+char credentials[2][3][50] = {
+ {"bob","alice","carol"},
+ {"bob123","alice123","carol123"}
+};
+
+static int TestAuthenticationCheck(Port *port)
+{
+ int result = STATUS_ERROR;
+ char *real_pass;
+ const char *logdetail = NULL;
+
+ real_pass = get_encrypted_password_for_user(port->user_name);
+ if (real_pass)
+ {
+ result = CheckSASLAuth(&pg_be_scram_mech, port, real_pass, &logdetail);
+ pfree(real_pass);
+ }
+
+ if (result == STATUS_OK)
+ set_authn_id(port, port->user_name);
+
+ return result;
+}
+
+/*
+ * Get SCRAM encrypted version of the password for user.
+ */
+static char *
+get_encrypted_password_for_user(char *user_name)
+{
+ char *password = NULL;
+ int i;
+ for (i=0; i<3; i++)
+ {
+ if (strcmp(user_name, credentials[0][i]) == 0)
+ {
+ password = pstrdup(pg_be_scram_build_secret(credentials[1][i]));
+ }
+ }
+
+ return password;
+}
+
+static const char *TestAuthenticationError(Port *port)
+{
+ char *error_message = psprintf("Test authentication failed for user %s",
+                                port->user_name);
+ return error_message;
+}
+
+void
+_PG_init(void)
+{
+ RegisterAuthProvider("test", TestAuthenticationCheck, TestAuthenticationError);
+}
--
2.25.1
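The test provider's password lookup reduces to a scan over a fixed table. A freestanding sketch of that logic, with hypothetical names and no Postgres dependencies (the real module additionally runs the match through pg_be_scram_build_secret() and CheckSASLAuth()), behaves like this:

```c
#include <string.h>
#include <stddef.h>

/*
 * Stand-in for the patch's credentials table: row 0 holds user names,
 * row 1 the matching plaintext passwords. The real module stores a
 * SCRAM secret derived from these rather than the plaintext.
 */
static const char *credentials[2][3] = {
	{"bob", "alice", "carol"},
	{"bob123", "alice123", "carol123"}
};

/* Return the password for user_name, or NULL if the user is unknown. */
static const char *
lookup_password(const char *user_name)
{
	for (int i = 0; i < 3; i++)
	{
		if (strcmp(user_name, credentials[0][i]) == 0)
			return credentials[1][i];
	}
	return NULL;
}
```

In the extension itself the returned value is fed into the SASL machinery rather than compared directly, so the plaintext never crosses the wire.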
Attachment: v4-0003-Add-tests-for-test_auth_provider-extension.patch (text/x-patch)
From 24702486bfaca691d6ca9388544fb23e6d765055 Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Wed, 16 Feb 2022 12:28:36 -0800
Subject: [PATCH v4 03/10] Add tests for test_auth_provider extension
Add TAP tests for the test_auth_provider extension and allow
make check in src/test/modules to run them.
---
src/test/modules/Makefile | 1 +
src/test/modules/test_auth_provider/Makefile | 2 +
.../test_auth_provider/t/001_custom_auth.pl | 125 ++++++++++++++++++
3 files changed, 128 insertions(+)
create mode 100644 src/test/modules/test_auth_provider/t/001_custom_auth.pl
diff --git a/src/test/modules/Makefile b/src/test/modules/Makefile
index 9090226daa..d0d461ef9e 100644
--- a/src/test/modules/Makefile
+++ b/src/test/modules/Makefile
@@ -14,6 +14,7 @@ SUBDIRS = \
plsample \
snapshot_too_old \
spgist_name_ops \
+ test_auth_provider \
test_bloomfilter \
test_ddl_deparse \
test_extensions \
diff --git a/src/test/modules/test_auth_provider/Makefile b/src/test/modules/test_auth_provider/Makefile
index 17971a5c7a..7d601cf7d5 100644
--- a/src/test/modules/test_auth_provider/Makefile
+++ b/src/test/modules/test_auth_provider/Makefile
@@ -4,6 +4,8 @@ MODULE_big = test_auth_provider
OBJS = test_auth_provider.o
PGFILEDESC = "test_auth_provider - provider to test auth hooks"
+TAP_TESTS = 1
+
ifdef USE_PGXS
PG_CONFIG = pg_config
PGXS := $(shell $(PG_CONFIG) --pgxs)
diff --git a/src/test/modules/test_auth_provider/t/001_custom_auth.pl b/src/test/modules/test_auth_provider/t/001_custom_auth.pl
new file mode 100644
index 0000000000..3b7472dc7f
--- /dev/null
+++ b/src/test/modules/test_auth_provider/t/001_custom_auth.pl
@@ -0,0 +1,125 @@
+
+# Copyright (c) 2021-2022, PostgreSQL Global Development Group
+
+# Set of tests for testing custom authentication hooks.
+
+use strict;
+use warnings;
+use PostgreSQL::Test::Cluster;
+use PostgreSQL::Test::Utils;
+use Test::More;
+
+# Delete pg_hba.conf from the given node, add a new entry to it
+# and then execute a reload to refresh it.
+sub reset_pg_hba
+{
+ my $node = shift;
+ my $hba_method = shift;
+
+ unlink($node->data_dir . '/pg_hba.conf');
+ # just for testing purposes, use a continuation line
+ $node->append_conf('pg_hba.conf', "local all all\\\n $hba_method");
+ $node->reload;
+ return;
+}
+
+# Test if you get expected results in pg_hba_file_rules error column after
+# changing pg_hba.conf and reloading it.
+sub test_hba_reload
+{
+ my ($node, $method, $expected_res) = @_;
+ my $status_string = 'failed';
+ $status_string = 'success' if ($expected_res eq 0);
+ my $testname = "pg_hba.conf reload $status_string for method $method";
+
+ reset_pg_hba($node, $method);
+
+ my ($cmdret, $stdout, $stderr) = $node->psql("postgres",
+ "select count(*) from pg_hba_file_rules where error is not null",extra_params => ['-U','bob']);
+
+ is($stdout, $expected_res, $testname);
+}
+
+# Test access for a single role, useful to wrap all tests into one. Extra
+# named parameters are passed to connect_ok/fails as-is.
+sub test_role
+{
+ local $Test::Builder::Level = $Test::Builder::Level + 1;
+
+ my ($node, $role, $method, $expected_res, %params) = @_;
+ my $status_string = 'failed';
+ $status_string = 'success' if ($expected_res eq 0);
+
+ my $connstr = "user=$role";
+ my $testname =
+ "authentication $status_string for method $method, role $role";
+
+ if ($expected_res eq 0)
+ {
+ $node->connect_ok($connstr, $testname, %params);
+ }
+ else
+ {
+ # No checks of the error message, only the status code.
+ $node->connect_fails($connstr, $testname, %params);
+ }
+}
+
+# Initialize server node
+my $node = PostgreSQL::Test::Cluster->new('server');
+$node->init;
+$node->append_conf('postgresql.conf', "log_connections = on\n");
+$node->append_conf('postgresql.conf', "shared_preload_libraries = 'test_auth_provider.so'\n");
+$node->start;
+
+$node->safe_psql('postgres', "CREATE ROLE bob SUPERUSER LOGIN;");
+$node->safe_psql('postgres', "CREATE ROLE alice LOGIN;");
+$node->safe_psql('postgres', "CREATE ROLE test LOGIN;");
+
+# Add custom auth method to pg_hba.conf
+reset_pg_hba($node, 'custom provider=test');
+
+# Test that users are able to login with correct passwords.
+$ENV{"PGPASSWORD"} = 'bob123';
+test_role($node, 'bob', 'custom', 0, log_like => [qr/connection authorized: user=bob/]);
+$ENV{"PGPASSWORD"} = 'alice123';
+test_role($node, 'alice', 'custom', 0, log_like => [qr/connection authorized: user=alice/]);
+
+# Test that bad passwords are rejected.
+$ENV{"PGPASSWORD"} = 'badpassword';
+test_role($node, 'bob', 'custom', 2, log_unlike => [qr/connection authorized:/]);
+test_role($node, 'alice', 'custom', 2, log_unlike => [qr/connection authorized:/]);
+
+# Test that users not in authentication list are rejected.
+test_role($node, 'test', 'custom', 2, log_unlike => [qr/connection authorized:/]);
+
+$ENV{"PGPASSWORD"} = 'bob123';
+
+# Tests for invalid auth options
+
+# Test that an incorrect provider name is not accepted.
+test_hba_reload($node, 'custom provider=wrong', 1);
+
+# Test that specifying provider option with different auth method is not allowed.
+test_hba_reload($node, 'trust provider=test', 1);
+
+# Test that provider name is a mandatory option for custom auth.
+test_hba_reload($node, 'custom', 1);
+
+# Test that correct provider name allows reload to succeed.
+test_hba_reload($node, 'custom provider=test', 0);
+
+# Custom auth modules require mentioning extension in shared_preload_libraries.
+
+# Remove extension from shared_preload_libraries and try to restart.
+$node->adjust_conf('postgresql.conf', 'shared_preload_libraries', "''");
+command_fails(['pg_ctl', '-w', '-D', $node->data_dir, '-l', $node->logfile, 'restart'],'restart with empty shared_preload_libraries failed');
+
+# Fix shared_preload_libraries and confirm that you can now restart.
+$node->adjust_conf('postgresql.conf', 'shared_preload_libraries', "'test_auth_provider.so'");
+command_ok(['pg_ctl', '-w', '-D', $node->data_dir, '-l', $node->logfile,'start'],'restart with correct shared_preload_libraries succeeded');
+
+# Test that we can connect again
+test_role($node, 'bob', 'custom', 0, log_like => [qr/connection authorized: user=bob/]);
+
+done_testing();
--
2.25.1
Attachment: v4-0004-Add-support-for-map-and-custom-auth-options.patch (text/x-patch)
From c30970a354b23f26eaf3e1db7c7d7759f2f828b3 Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Mon, 14 Mar 2022 14:54:08 -0700
Subject: [PATCH v4 04/10] Add support for "map" and custom auth options
This commit allows extensions to specify, validate and use
custom options for their custom auth methods. This is done by
exposing a validation hook which extensions can define. Valid
options are stored as key / value pairs which can be consulted
while checking authentication. We also allow custom auth
providers to use the "map" option for usermaps.
The test module was updated to use custom options and new tests
were added.
---
src/backend/libpq/auth.c | 4 +-
src/backend/libpq/hba.c | 76 +++++++++++++++----
src/include/libpq/auth.h | 17 +++--
src/include/libpq/hba.h | 8 ++
.../test_auth_provider/t/001_custom_auth.pl | 22 ++++++
.../test_auth_provider/test_auth_provider.c | 50 +++++++++++-
6 files changed, 157 insertions(+), 20 deletions(-)
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 375ee33892..4a8a63922a 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -3364,7 +3364,8 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
*/
void RegisterAuthProvider(const char *provider_name,
CustomAuthenticationCheck_hook_type AuthenticationCheckFunction,
- CustomAuthenticationError_hook_type AuthenticationErrorFunction)
+ CustomAuthenticationError_hook_type AuthenticationErrorFunction,
+ CustomAuthenticationValidateOptions_hook_type AuthenticationOptionsFunction)
{
CustomAuthProvider *provider = NULL;
MemoryContext old_context;
@@ -3390,6 +3391,7 @@ void RegisterAuthProvider(const char *provider_name,
provider->name = MemoryContextStrdup(TopMemoryContext,provider_name);
provider->auth_check_hook = AuthenticationCheckFunction;
provider->auth_error_hook = AuthenticationErrorFunction;
+ provider->auth_options_hook = AuthenticationOptionsFunction;
custom_auth_providers = lappend(custom_auth_providers, provider);
MemoryContextSwitchTo(old_context);
}
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 9f15252789..42cb1ce51d 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -1729,8 +1729,9 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->auth_method != uaPeer &&
hbaline->auth_method != uaGSS &&
hbaline->auth_method != uaSSPI &&
- hbaline->auth_method != uaCert)
- INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, and cert"));
+ hbaline->auth_method != uaCert &&
+ hbaline->auth_method != uaCustom)
+ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert and custom"));
hbaline->usermap = pstrdup(val);
}
else if (strcmp(name, "clientcert") == 0)
@@ -2121,7 +2122,6 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
/*
* Verify that the provider mentioned is loaded via shared_preload_libraries.
*/
-
if (get_provider_by_name(val) == NULL)
{
ereport(elevel,
@@ -2129,7 +2129,7 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
errmsg("cannot use authentication provider %s",val),
errhint("Load authentication provider via shared_preload_libraries."),
errcontext("line %d of configuration file \"%s\"",
- line_num, HbaFileName)));
+ line_num, HbaFileName)));
*err_msg = psprintf("cannot use authentication provider %s", val);
return false;
@@ -2141,15 +2141,55 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
}
else
{
- ereport(elevel,
- (errcode(ERRCODE_CONFIG_FILE_ERROR),
- errmsg("unrecognized authentication option name: \"%s\"",
- name),
- errcontext("line %d of configuration file \"%s\"",
- line_num, HbaFileName)));
- *err_msg = psprintf("unrecognized authentication option name: \"%s\"",
- name);
- return false;
+ /*
+ * Allow custom providers to validate their options if they have an
+ * option validation function defined.
+ */
+ if (hbaline->auth_method == uaCustom && (hbaline->custom_provider != NULL))
+ {
+ bool valid_option = false;
+ CustomAuthProvider *provider = get_provider_by_name(hbaline->custom_provider);
+ if (provider->auth_options_hook)
+ {
+ valid_option = provider->auth_options_hook(name, val, hbaline, err_msg);
+ if (valid_option)
+ {
+ CustomOption *option = palloc(sizeof(CustomOption));
+ option->name = pstrdup(name);
+ option->value = pstrdup(val);
+ hbaline->custom_auth_options = lappend(hbaline->custom_auth_options,
+ option);
+ }
+ }
+ else
+ {
+ *err_msg = psprintf("unrecognized authentication option name: \"%s\"",
+ name);
+ }
+
+ /* Report the error returned by the provider as it is */
+ if (!valid_option)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("%s", *err_msg),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ return false;
+ }
+ }
+ else
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("unrecognized authentication option name: \"%s\"",
+ name),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("unrecognized authentication option name: \"%s\"",
+ name);
+ return false;
+ }
}
return true;
}
@@ -2484,6 +2524,16 @@ gethba_options(HbaLine *hba)
if (hba->custom_provider)
options[noptions++] =
CStringGetTextDatum(psprintf("provider=%s",hba->custom_provider));
+ if (hba->custom_auth_options)
+ {
+ ListCell *lc;
+ foreach(lc, hba->custom_auth_options)
+ {
+ CustomOption *option = (CustomOption *)lfirst(lc);
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("%s=%s",option->name, option->value));
+ }
+ }
}
/* If you add more options, consider increasing MAX_HBA_OPTIONS. */
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 7aff98d919..cbdc63b4df 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -31,22 +31,29 @@ typedef int (*CustomAuthenticationCheck_hook_type) (Port *);
typedef void (*ClientAuthentication_hook_type) (Port *, int);
extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook;
+/* Declarations for custom authentication providers */
+
/* Hook for plugins to report error messages in auth_failed() */
typedef const char * (*CustomAuthenticationError_hook_type) (Port *);
-extern void RegisterAuthProvider
- (const char *provider_name,
- CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook,
- CustomAuthenticationError_hook_type CustomAuthenticationError_hook);
+/* Hook for plugins to validate custom authentication options */
+typedef bool (*CustomAuthenticationValidateOptions_hook_type)
+ (char *, char *, HbaLine *, char **);
-/* Declarations for custom authentication providers */
typedef struct CustomAuthProvider
{
const char *name;
CustomAuthenticationCheck_hook_type auth_check_hook;
CustomAuthenticationError_hook_type auth_error_hook;
+ CustomAuthenticationValidateOptions_hook_type auth_options_hook;
} CustomAuthProvider;
+extern void RegisterAuthProvider
+ (const char *provider_name,
+ CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook,
+ CustomAuthenticationError_hook_type CustomAuthenticationError_hook,
+ CustomAuthenticationValidateOptions_hook_type CustomAuthenticationOptions_hook);
+
extern CustomAuthProvider *get_provider_by_name(const char *name);
/*
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 48490c44ed..31a00c4b71 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -78,6 +78,13 @@ typedef enum ClientCertName
clientCertDN
} ClientCertName;
+/* Struct for custom options defined by custom auth plugins */
+typedef struct CustomOption
+{
+ char *name;
+ char *value;
+}CustomOption;
+
typedef struct HbaLine
{
int linenumber;
@@ -122,6 +129,7 @@ typedef struct HbaLine
List *radiusports;
char *radiusports_s;
char *custom_provider;
+ List *custom_auth_options;
} HbaLine;
typedef struct IdentLine
diff --git a/src/test/modules/test_auth_provider/t/001_custom_auth.pl b/src/test/modules/test_auth_provider/t/001_custom_auth.pl
index 3b7472dc7f..e964c2f723 100644
--- a/src/test/modules/test_auth_provider/t/001_custom_auth.pl
+++ b/src/test/modules/test_auth_provider/t/001_custom_auth.pl
@@ -109,6 +109,28 @@ test_hba_reload($node, 'custom', 1);
# Test that correct provider name allows reload to succeed.
test_hba_reload($node, 'custom provider=test', 0);
+# Tests for custom auth options
+
+# Test that a custom option doesn't work without a provider.
+test_hba_reload($node, 'custom allow=bob', 1);
+
+# Test that options other than allowed ones are not accepted.
+test_hba_reload($node, 'custom provider=test wrong=true', 1);
+
+# Test that only valid values are accepted for allowed options.
+test_hba_reload($node, 'custom provider=test allow=wrong', 1);
+
+# Test that setting allow option for a user doesn't look at the password.
+test_hba_reload($node, 'custom provider=test allow=bob', 0);
+$ENV{"PGPASSWORD"} = 'bad123';
+test_role($node, 'bob', 'custom', 0, log_like => [qr/connection authorized: user=bob/]);
+
+# Password is still checked for other users.
+test_role($node, 'alice', 'custom', 2, log_unlike => [qr/connection authorized:/]);
+
+# Reset the password for future tests.
+$ENV{"PGPASSWORD"} = 'bob123';
+
# Custom auth modules require mentioning extension in shared_preload_libraries.
# Remove extension from shared_preload_libraries and try to restart.
diff --git a/src/test/modules/test_auth_provider/test_auth_provider.c b/src/test/modules/test_auth_provider/test_auth_provider.c
index 7c4b1f3500..5ac425f5b6 100644
--- a/src/test/modules/test_auth_provider/test_auth_provider.c
+++ b/src/test/modules/test_auth_provider/test_auth_provider.c
@@ -39,7 +39,27 @@ static int TestAuthenticationCheck(Port *port)
int result = STATUS_ERROR;
char *real_pass;
const char *logdetail = NULL;
+ ListCell *lc;
+ /*
+ * If the user's name is in the "allow" list, do not request a password
+ * for them and allow them to authenticate.
+ */
+ foreach(lc,port->hba->custom_auth_options)
+ {
+ CustomOption *option = (CustomOption *) lfirst(lc);
+ if (strcmp(option->name, "allow") == 0 &&
+ strcmp(option->value, port->user_name) == 0)
+ {
+ set_authn_id(port, port->user_name);
+ return STATUS_OK;
+ }
+ }
+
+ /*
+ * Encrypt the password and validate that it's the same as the one
+ * returned by the client.
+ */
real_pass = get_encrypted_password_for_user(port->user_name);
if (real_pass)
{
@@ -79,8 +99,36 @@ static const char *TestAuthenticationError(Port *port)
return error_message;
}
+/*
+ * Returns if the options passed are supported by the extension
+ * and are valid. Currently only "allow" is supported.
+ */
+static bool TestAuthenticationOptions(char *name, char *val, HbaLine *hbaline, char **err_msg)
+{
+ /* Validate that an actual user is in the "allow" list. */
+ if (strcmp(name,"allow") == 0)
+ {
+ for (int i=0;i<3;i++)
+ {
+ if (strcmp(val,credentials[0][i]) == 0)
+ {
+ return true;
+ }
+ }
+
+ *err_msg = psprintf("\"%s\" is not a valid value for option \"%s\"", val, name);
+ return false;
+ }
+ else
+ {
+ *err_msg = psprintf("option \"%s\" not recognized by \"%s\" provider", name, hbaline->custom_provider);
+ return false;
+ }
+}
+
void
_PG_init(void)
{
- RegisterAuthProvider("test", TestAuthenticationCheck, TestAuthenticationError);
+ RegisterAuthProvider("test", TestAuthenticationCheck,
+ TestAuthenticationError,TestAuthenticationOptions);
}
--
2.25.1
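The "allow" shortcut in TestAuthenticationCheck() is just a scan over the parsed key/value options. A freestanding sketch of that scan (hypothetical array-based stand-in for the List of CustomOption structs hung off the HbaLine) looks like this:

```c
#include <stdbool.h>
#include <string.h>
#include <stddef.h>

/* Simplified stand-in for the patch's CustomOption key/value pairs. */
typedef struct Option
{
	const char *name;
	const char *value;
} Option;

/*
 * Return true if an allow=<user> option matches the connecting user,
 * mirroring the password-skipping shortcut in the test provider.
 */
static bool
user_is_allowed(const Option *opts, size_t nopts, const char *user_name)
{
	for (size_t i = 0; i < nopts; i++)
	{
		if (strcmp(opts[i].name, "allow") == 0 &&
			strcmp(opts[i].value, user_name) == 0)
			return true;
	}
	return false;
}
```

Because the options were already vetted by the provider's validation hook at pg_hba.conf load time, the check hook can trust the pairs it finds here.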
Attachment: v4-0005-common-jsonapi-support-FRONTEND-clients.patch (text/x-patch)
From 0ca324b52c94760e799974e5661fe29c3912d2d8 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Mon, 3 May 2021 15:38:26 -0700
Subject: [PATCH v4 05/10] common/jsonapi: support FRONTEND clients
Based on a patch by Michael Paquier.
For frontend code, use PQExpBuffer instead of StringInfo. This requires
us to track allocation failures so that we can return JSON_OUT_OF_MEMORY
as needed. json_errdetail() now allocates its error message inside
memory owned by the JsonLexContext, so clients don't need to worry about
freeing it.
For convenience, the backend now has destroyJsonLexContext() to mirror
other create/destroy APIs. The frontend has init/term versions of the
API to handle stack-allocated JsonLexContexts.
We can now partially revert b44669b2ca, since json_errdetail() works
correctly.
---
src/backend/utils/adt/jsonfuncs.c | 4 +-
src/bin/pg_verifybackup/parse_manifest.c | 13 +-
src/bin/pg_verifybackup/t/005_bad_manifest.pl | 2 +-
src/common/Makefile | 2 +-
src/common/jsonapi.c | 290 +++++++++++++-----
src/include/common/jsonapi.h | 47 ++-
6 files changed, 270 insertions(+), 88 deletions(-)
diff --git a/src/backend/utils/adt/jsonfuncs.c b/src/backend/utils/adt/jsonfuncs.c
index 29664aa6e4..7d32a99d8c 100644
--- a/src/backend/utils/adt/jsonfuncs.c
+++ b/src/backend/utils/adt/jsonfuncs.c
@@ -723,9 +723,7 @@ json_object_keys(PG_FUNCTION_ARGS)
pg_parse_json_or_ereport(lex, sem);
/* keys are now in state->result */
- pfree(lex->strval->data);
- pfree(lex->strval);
- pfree(lex);
+ destroyJsonLexContext(lex);
pfree(sem);
MemoryContextSwitchTo(oldcontext);
diff --git a/src/bin/pg_verifybackup/parse_manifest.c b/src/bin/pg_verifybackup/parse_manifest.c
index 6364b01282..4b38fd3963 100644
--- a/src/bin/pg_verifybackup/parse_manifest.c
+++ b/src/bin/pg_verifybackup/parse_manifest.c
@@ -119,7 +119,7 @@ void
json_parse_manifest(JsonManifestParseContext *context, char *buffer,
size_t size)
{
- JsonLexContext *lex;
+ JsonLexContext lex = {0};
JsonParseErrorType json_error;
JsonSemAction sem;
JsonManifestParseState parse;
@@ -129,8 +129,8 @@ json_parse_manifest(JsonManifestParseContext *context, char *buffer,
parse.state = JM_EXPECT_TOPLEVEL_START;
parse.saw_version_field = false;
- /* Create a JSON lexing context. */
- lex = makeJsonLexContextCstringLen(buffer, size, PG_UTF8, true);
+ /* Initialize a JSON lexing context. */
+ initJsonLexContextCstringLen(&lex, buffer, size, PG_UTF8, true);
/* Set up semantic actions. */
sem.semstate = &parse;
@@ -145,14 +145,17 @@ json_parse_manifest(JsonManifestParseContext *context, char *buffer,
sem.scalar = json_manifest_scalar;
/* Run the actual JSON parser. */
- json_error = pg_parse_json(lex, &sem);
+ json_error = pg_parse_json(&lex, &sem);
if (json_error != JSON_SUCCESS)
- json_manifest_parse_failure(context, "parsing failed");
+ json_manifest_parse_failure(context, json_errdetail(json_error, &lex));
if (parse.state != JM_EXPECT_EOF)
json_manifest_parse_failure(context, "manifest ended unexpectedly");
/* Verify the manifest checksum. */
verify_manifest_checksum(&parse, buffer, size);
+
+ /* Clean up. */
+ termJsonLexContext(&lex);
}
/*
diff --git a/src/bin/pg_verifybackup/t/005_bad_manifest.pl b/src/bin/pg_verifybackup/t/005_bad_manifest.pl
index 118beb53d7..f2692972fe 100644
--- a/src/bin/pg_verifybackup/t/005_bad_manifest.pl
+++ b/src/bin/pg_verifybackup/t/005_bad_manifest.pl
@@ -16,7 +16,7 @@ my $tempdir = PostgreSQL::Test::Utils::tempdir;
test_bad_manifest(
'input string ended unexpectedly',
- qr/could not parse backup manifest: parsing failed/,
+ qr/could not parse backup manifest: The input string ended unexpectedly/,
<<EOM);
{
EOM
diff --git a/src/common/Makefile b/src/common/Makefile
index f627349835..694da03658 100644
--- a/src/common/Makefile
+++ b/src/common/Makefile
@@ -40,7 +40,7 @@ override CPPFLAGS += -DVAL_LDFLAGS_EX="\"$(LDFLAGS_EX)\""
override CPPFLAGS += -DVAL_LDFLAGS_SL="\"$(LDFLAGS_SL)\""
override CPPFLAGS += -DVAL_LIBS="\"$(LIBS)\""
-override CPPFLAGS := -DFRONTEND -I. -I$(top_srcdir)/src/common $(CPPFLAGS)
+override CPPFLAGS := -DFRONTEND -I. -I$(top_srcdir)/src/common -I$(libpq_srcdir) $(CPPFLAGS)
LIBS += $(PTHREAD_LIBS)
# If you add objects here, see also src/tools/msvc/Mkvcbuild.pm
diff --git a/src/common/jsonapi.c b/src/common/jsonapi.c
index 6666077a93..7fc5eaf460 100644
--- a/src/common/jsonapi.c
+++ b/src/common/jsonapi.c
@@ -20,10 +20,39 @@
#include "common/jsonapi.h"
#include "mb/pg_wchar.h"
-#ifndef FRONTEND
+#ifdef FRONTEND
+#include "pqexpbuffer.h"
+#else
+#include "lib/stringinfo.h"
#include "miscadmin.h"
#endif
+/*
+ * In backend, we will use palloc/pfree along with StringInfo. In frontend, use
+ * malloc and PQExpBuffer, and return JSON_OUT_OF_MEMORY on out-of-memory.
+ */
+#ifdef FRONTEND
+
+#define STRDUP(s) strdup(s)
+#define ALLOC(size) malloc(size)
+
+#define appendStrVal appendPQExpBuffer
+#define appendStrValChar appendPQExpBufferChar
+#define createStrVal createPQExpBuffer
+#define resetStrVal resetPQExpBuffer
+
+#else /* !FRONTEND */
+
+#define STRDUP(s) pstrdup(s)
+#define ALLOC(size) palloc(size)
+
+#define appendStrVal appendStringInfo
+#define appendStrValChar appendStringInfoChar
+#define createStrVal makeStringInfo
+#define resetStrVal resetStringInfo
+
+#endif
+
/*
* The context of the parser is maintained by the recursive descent
* mechanism, but is passed explicitly to the error reporting routine
@@ -132,10 +161,12 @@ IsValidJsonNumber(const char *str, int len)
return (!numeric_error) && (total_len == dummy_lex.input_length);
}
+#ifndef FRONTEND
+
/*
* makeJsonLexContextCstringLen
*
- * lex constructor, with or without StringInfo object for de-escaped lexemes.
+ * lex constructor, with or without a string object for de-escaped lexemes.
*
* Without is better as it makes the processing faster, so only make one
* if really required.
@@ -145,13 +176,66 @@ makeJsonLexContextCstringLen(char *json, int len, int encoding, bool need_escape
{
JsonLexContext *lex = palloc0(sizeof(JsonLexContext));
+ initJsonLexContextCstringLen(lex, json, len, encoding, need_escapes);
+
+ return lex;
+}
+
+void
+destroyJsonLexContext(JsonLexContext *lex)
+{
+ termJsonLexContext(lex);
+ pfree(lex);
+}
+
+#endif /* !FRONTEND */
+
+void
+initJsonLexContextCstringLen(JsonLexContext *lex, char *json, int len, int encoding, bool need_escapes)
+{
lex->input = lex->token_terminator = lex->line_start = json;
lex->line_number = 1;
lex->input_length = len;
lex->input_encoding = encoding;
- if (need_escapes)
- lex->strval = makeStringInfo();
- return lex;
+ lex->parse_strval = need_escapes;
+ if (lex->parse_strval)
+ {
+ /*
+ * This call can fail in FRONTEND code. We defer error handling to time
+ * of use (json_lex_string()) since there's no way to signal failure
+ * here, and we might not need to parse any strings anyway.
+ */
+ lex->strval = createStrVal();
+ }
+ lex->errormsg = NULL;
+}
+
+void
+termJsonLexContext(JsonLexContext *lex)
+{
+ static const JsonLexContext empty = {0};
+
+ if (lex->strval)
+ {
+#ifdef FRONTEND
+ destroyPQExpBuffer(lex->strval);
+#else
+ pfree(lex->strval->data);
+ pfree(lex->strval);
+#endif
+ }
+
+ if (lex->errormsg)
+ {
+#ifdef FRONTEND
+ destroyPQExpBuffer(lex->errormsg);
+#else
+ pfree(lex->errormsg->data);
+ pfree(lex->errormsg);
+#endif
+ }
+
+ *lex = empty;
}
/*
@@ -217,7 +301,7 @@ json_count_array_elements(JsonLexContext *lex, int *elements)
* etc, so doing this with a copy makes that safe.
*/
memcpy(&copylex, lex, sizeof(JsonLexContext));
- copylex.strval = NULL; /* not interested in values here */
+ copylex.parse_strval = false; /* not interested in values here */
copylex.lex_level++;
count = 0;
@@ -279,14 +363,21 @@ parse_scalar(JsonLexContext *lex, JsonSemAction *sem)
/* extract the de-escaped string value, or the raw lexeme */
if (lex_peek(lex) == JSON_TOKEN_STRING)
{
- if (lex->strval != NULL)
- val = pstrdup(lex->strval->data);
+ if (lex->parse_strval)
+ {
+ val = STRDUP(lex->strval->data);
+ if (val == NULL)
+ return JSON_OUT_OF_MEMORY;
+ }
}
else
{
int len = (lex->token_terminator - lex->token_start);
- val = palloc(len + 1);
+ val = ALLOC(len + 1);
+ if (val == NULL)
+ return JSON_OUT_OF_MEMORY;
+
memcpy(val, lex->token_start, len);
val[len] = '\0';
}
@@ -320,8 +411,12 @@ parse_object_field(JsonLexContext *lex, JsonSemAction *sem)
if (lex_peek(lex) != JSON_TOKEN_STRING)
return report_parse_error(JSON_PARSE_STRING, lex);
- if ((ostart != NULL || oend != NULL) && lex->strval != NULL)
- fname = pstrdup(lex->strval->data);
+ if ((ostart != NULL || oend != NULL) && lex->parse_strval)
+ {
+ fname = STRDUP(lex->strval->data);
+ if (fname == NULL)
+ return JSON_OUT_OF_MEMORY;
+ }
result = json_lex(lex);
if (result != JSON_SUCCESS)
return result;
@@ -368,6 +463,10 @@ parse_object(JsonLexContext *lex, JsonSemAction *sem)
JsonParseErrorType result;
#ifndef FRONTEND
+ /*
+ * TODO: clients need some way to put a bound on stack growth. Parse level
+ * limits maybe?
+ */
check_stack_depth();
#endif
@@ -676,8 +775,15 @@ json_lex_string(JsonLexContext *lex)
int len;
int hi_surrogate = -1;
- if (lex->strval != NULL)
- resetStringInfo(lex->strval);
+ if (lex->parse_strval)
+ {
+#ifdef FRONTEND
+ /* make sure initialization succeeded */
+ if (lex->strval == NULL)
+ return JSON_OUT_OF_MEMORY;
+#endif
+ resetStrVal(lex->strval);
+ }
Assert(lex->input_length > 0);
s = lex->token_start;
@@ -737,7 +843,7 @@ json_lex_string(JsonLexContext *lex)
return JSON_UNICODE_ESCAPE_FORMAT;
}
}
- if (lex->strval != NULL)
+ if (lex->parse_strval)
{
/*
* Combine surrogate pairs.
@@ -797,19 +903,19 @@ json_lex_string(JsonLexContext *lex)
unicode_to_utf8(ch, (unsigned char *) utf8str);
utf8len = pg_utf_mblen((unsigned char *) utf8str);
- appendBinaryStringInfo(lex->strval, utf8str, utf8len);
+ appendBinaryPQExpBuffer(lex->strval, utf8str, utf8len);
}
else if (ch <= 0x007f)
{
/* The ASCII range is the same in all encodings */
- appendStringInfoChar(lex->strval, (char) ch);
+ appendPQExpBufferChar(lex->strval, (char) ch);
}
else
return JSON_UNICODE_HIGH_ESCAPE;
#endif /* FRONTEND */
}
}
- else if (lex->strval != NULL)
+ else if (lex->parse_strval)
{
if (hi_surrogate != -1)
return JSON_UNICODE_LOW_SURROGATE;
@@ -819,22 +925,22 @@ json_lex_string(JsonLexContext *lex)
case '"':
case '\\':
case '/':
- appendStringInfoChar(lex->strval, *s);
+ appendStrValChar(lex->strval, *s);
break;
case 'b':
- appendStringInfoChar(lex->strval, '\b');
+ appendStrValChar(lex->strval, '\b');
break;
case 'f':
- appendStringInfoChar(lex->strval, '\f');
+ appendStrValChar(lex->strval, '\f');
break;
case 'n':
- appendStringInfoChar(lex->strval, '\n');
+ appendStrValChar(lex->strval, '\n');
break;
case 'r':
- appendStringInfoChar(lex->strval, '\r');
+ appendStrValChar(lex->strval, '\r');
break;
case 't':
- appendStringInfoChar(lex->strval, '\t');
+ appendStrValChar(lex->strval, '\t');
break;
default:
/* Not a valid string escape, so signal error. */
@@ -858,12 +964,12 @@ json_lex_string(JsonLexContext *lex)
}
}
- else if (lex->strval != NULL)
+ else if (lex->parse_strval)
{
if (hi_surrogate != -1)
return JSON_UNICODE_LOW_SURROGATE;
- appendStringInfoChar(lex->strval, *s);
+ appendStrValChar(lex->strval, *s);
}
}
@@ -871,6 +977,11 @@ json_lex_string(JsonLexContext *lex)
if (hi_surrogate != -1)
return JSON_UNICODE_LOW_SURROGATE;
+#ifdef FRONTEND
+ if (lex->parse_strval && PQExpBufferBroken(lex->strval))
+ return JSON_OUT_OF_MEMORY;
+#endif
+
/* Hooray, we found the end of the string! */
lex->prev_token_terminator = lex->token_terminator;
lex->token_terminator = s + 1;
@@ -1043,72 +1154,93 @@ report_parse_error(JsonParseContext ctx, JsonLexContext *lex)
return JSON_SUCCESS; /* silence stupider compilers */
}
-
-#ifndef FRONTEND
-/*
- * Extract the current token from a lexing context, for error reporting.
- */
-static char *
-extract_token(JsonLexContext *lex)
-{
- int toklen = lex->token_terminator - lex->token_start;
- char *token = palloc(toklen + 1);
-
- memcpy(token, lex->token_start, toklen);
- token[toklen] = '\0';
- return token;
-}
-
/*
* Construct a detail message for a JSON error.
*
- * Note that the error message generated by this routine may not be
- * palloc'd, making it unsafe for frontend code as there is no way to
- * know if this can be safery pfree'd or not.
+ * The returned allocation is either static or owned by the JsonLexContext and
+ * should not be freed.
*/
char *
json_errdetail(JsonParseErrorType error, JsonLexContext *lex)
{
+ int toklen = lex->token_terminator - lex->token_start;
+
+ if (error == JSON_OUT_OF_MEMORY)
+ {
+ /* Short circuit. Allocating anything for this case is unhelpful. */
+ return _("out of memory");
+ }
+
+ if (lex->errormsg)
+ resetStrVal(lex->errormsg);
+ else
+ lex->errormsg = createStrVal();
+
switch (error)
{
case JSON_SUCCESS:
/* fall through to the error code after switch */
break;
case JSON_ESCAPING_INVALID:
- return psprintf(_("Escape sequence \"\\%s\" is invalid."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Escape sequence \"\\%.*s\" is invalid."),
+ toklen, lex->token_start);
+ break;
case JSON_ESCAPING_REQUIRED:
- return psprintf(_("Character with value 0x%02x must be escaped."),
- (unsigned char) *(lex->token_terminator));
+ appendStrVal(lex->errormsg,
+ _("Character with value 0x%02x must be escaped."),
+ (unsigned char) *(lex->token_terminator));
+ break;
case JSON_EXPECTED_END:
- return psprintf(_("Expected end of input, but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected end of input, but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_ARRAY_FIRST:
- return psprintf(_("Expected array element or \"]\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected array element or \"]\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_ARRAY_NEXT:
- return psprintf(_("Expected \",\" or \"]\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected \",\" or \"]\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_COLON:
- return psprintf(_("Expected \":\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected \":\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_JSON:
- return psprintf(_("Expected JSON value, but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected JSON value, but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_MORE:
return _("The input string ended unexpectedly.");
case JSON_EXPECTED_OBJECT_FIRST:
- return psprintf(_("Expected string or \"}\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected string or \"}\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_OBJECT_NEXT:
- return psprintf(_("Expected \",\" or \"}\", but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected \",\" or \"}\", but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_EXPECTED_STRING:
- return psprintf(_("Expected string, but found \"%s\"."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Expected string, but found \"%.*s\"."),
+ toklen, lex->token_start);
+ break;
case JSON_INVALID_TOKEN:
- return psprintf(_("Token \"%s\" is invalid."),
- extract_token(lex));
+ appendStrVal(lex->errormsg,
+ _("Token \"%.*s\" is invalid."),
+ toklen, lex->token_start);
+ break;
+ case JSON_OUT_OF_MEMORY:
+ /* should have been handled above; use the error path */
+ break;
case JSON_UNICODE_CODE_POINT_ZERO:
return _("\\u0000 cannot be converted to text.");
case JSON_UNICODE_ESCAPE_FORMAT:
@@ -1122,12 +1254,22 @@ json_errdetail(JsonParseErrorType error, JsonLexContext *lex)
return _("Unicode low surrogate must follow a high surrogate.");
}
- /*
- * We don't use a default: case, so that the compiler will warn about
- * unhandled enum values. But this needs to be here anyway to cover the
- * possibility of an incorrect input.
- */
- elog(ERROR, "unexpected json parse error type: %d", (int) error);
- return NULL;
-}
+ /* Note that lex->errormsg can be NULL in FRONTEND code. */
+ if (lex->errormsg && !lex->errormsg->data[0])
+ {
+ /*
+ * We don't use a default: case, so that the compiler will warn about
+ * unhandled enum values. But this needs to be here anyway to cover the
+ * possibility of an incorrect input.
+ */
+ appendStrVal(lex->errormsg,
+ "unexpected json parse error type: %d", (int) error);
+ }
+
+#ifdef FRONTEND
+ if (PQExpBufferBroken(lex->errormsg))
+ return _("out of memory while constructing error description");
#endif
+
+ return lex->errormsg->data;
+}
diff --git a/src/include/common/jsonapi.h b/src/include/common/jsonapi.h
index 52cb4a9339..d7cafc84fe 100644
--- a/src/include/common/jsonapi.h
+++ b/src/include/common/jsonapi.h
@@ -14,8 +14,6 @@
#ifndef JSONAPI_H
#define JSONAPI_H
-#include "lib/stringinfo.h"
-
typedef enum
{
JSON_TOKEN_INVALID,
@@ -48,6 +46,7 @@ typedef enum
JSON_EXPECTED_OBJECT_NEXT,
JSON_EXPECTED_STRING,
JSON_INVALID_TOKEN,
+ JSON_OUT_OF_MEMORY,
JSON_UNICODE_CODE_POINT_ZERO,
JSON_UNICODE_ESCAPE_FORMAT,
JSON_UNICODE_HIGH_ESCAPE,
@@ -55,6 +54,17 @@ typedef enum
JSON_UNICODE_LOW_SURROGATE
} JsonParseErrorType;
+/*
+ * Don't depend on the internal type header for strval; if callers need access
+ * then they can include the appropriate header themselves.
+ */
+#ifdef FRONTEND
+#define StrValType PQExpBufferData
+#else
+#define StrValType StringInfoData
+#endif
+
+typedef struct StrValType StrValType;
/*
* All the fields in this structure should be treated as read-only.
@@ -81,7 +91,9 @@ typedef struct JsonLexContext
int lex_level;
int line_number; /* line number, starting from 1 */
char *line_start; /* where that line starts within input */
- StringInfo strval;
+ bool parse_strval;
+ StrValType *strval; /* only used if parse_strval == true */
+ StrValType *errormsg;
} JsonLexContext;
typedef void (*json_struct_action) (void *state);
@@ -141,9 +153,10 @@ extern JsonSemAction nullSemAction;
*/
extern JsonParseErrorType json_count_array_elements(JsonLexContext *lex,
int *elements);
+#ifndef FRONTEND
/*
- * constructor for JsonLexContext, with or without strval element.
+ * allocating constructor for JsonLexContext, with or without strval element.
* If supplied, the strval element will contain a de-escaped version of
* the lexeme. However, doing this imposes a performance penalty, so
* it should be avoided if the de-escaped lexeme is not required.
@@ -153,6 +166,32 @@ extern JsonLexContext *makeJsonLexContextCstringLen(char *json,
int encoding,
bool need_escapes);
+/*
+ * Counterpart to makeJsonLexContextCstringLen(): clears and deallocates lex.
+ * The context pointer should not be used after this call.
+ */
+extern void destroyJsonLexContext(JsonLexContext *lex);
+
+#endif /* !FRONTEND */
+
+/*
+ * stack constructor for JsonLexContext, with or without strval element.
+ * If supplied, the strval element will contain a de-escaped version of
+ * the lexeme. However, doing this imposes a performance penalty, so
+ * it should be avoided if the de-escaped lexeme is not required.
+ */
+extern void initJsonLexContextCstringLen(JsonLexContext *lex,
+ char *json,
+ int len,
+ int encoding,
+ bool need_escapes);
+
+/*
+ * Counterpart to initJsonLexContextCstringLen(): clears the contents of lex,
+ * but does not deallocate lex itself.
+ */
+extern void termJsonLexContext(JsonLexContext *lex);
+
/* lex one token */
extern JsonParseErrorType json_lex(JsonLexContext *lex);
--
2.25.1
Attachment: v4-0006-libpq-add-OAUTHBEARER-SASL-mechanism.patch (text/x-patch)
From f4752b6daea9b519ff3482094a1f445a4731cc15 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Tue, 13 Apr 2021 10:27:27 -0700
Subject: [PATCH v4 06/10] libpq: add OAUTHBEARER SASL mechanism
DO NOT USE THIS PROOF OF CONCEPT IN PRODUCTION.
Implement OAUTHBEARER (RFC 7628) and OAuth 2.0 Device Authorization
Grants (RFC 8628) on the client side. When speaking to an OAuth-enabled
server, it looks a bit like this:
$ psql 'host=example.org oauth_client_id=f02c6361-0635-...'
Visit https://oauth.example.org/login and enter the code: FPQ2-M4BG
The OAuth issuer must support device authorization. No other OAuth flows
are currently implemented.
The client implementation requires libiddawc and its development
headers. Configure --with-oauth (and --with-includes/--with-libraries to
point at the iddawc installation, if it's in a custom location).
Several TODOs:
- don't retry forever if the server won't accept our token
- perform several sanity checks on the OAuth issuer's responses
- handle cases where the client has been set up with an issuer and
scope, but the Postgres server wants to use something different
- improve error debuggability during the OAuth handshake
- ...and more.
---
configure | 100 ++++
configure.ac | 19 +
src/Makefile.global.in | 1 +
src/include/common/oauth-common.h | 19 +
src/include/pg_config.h.in | 6 +
src/interfaces/libpq/Makefile | 7 +-
src/interfaces/libpq/fe-auth-oauth.c | 744 +++++++++++++++++++++++++++
src/interfaces/libpq/fe-auth-sasl.h | 5 +-
src/interfaces/libpq/fe-auth-scram.c | 6 +-
src/interfaces/libpq/fe-auth.c | 42 +-
src/interfaces/libpq/fe-auth.h | 3 +
src/interfaces/libpq/fe-connect.c | 38 ++
src/interfaces/libpq/libpq-int.h | 8 +
13 files changed, 979 insertions(+), 19 deletions(-)
create mode 100644 src/include/common/oauth-common.h
create mode 100644 src/interfaces/libpq/fe-auth-oauth.c
diff --git a/configure b/configure
index e066cbe2c8..42a3304681 100755
--- a/configure
+++ b/configure
@@ -718,6 +718,7 @@ with_uuid
with_readline
with_systemd
with_selinux
+with_oauth
with_ldap
with_krb_srvnam
krb_srvtab
@@ -861,6 +862,7 @@ with_krb_srvnam
with_pam
with_bsd_auth
with_ldap
+with_oauth
with_bonjour
with_selinux
with_systemd
@@ -1570,6 +1572,7 @@ Optional Packages:
--with-pam build with PAM support
--with-bsd-auth build with BSD Authentication support
--with-ldap build with LDAP support
+ --with-oauth build with OAuth 2.0 support
--with-bonjour build with Bonjour support
--with-selinux build with SELinux support
--with-systemd build with systemd support
@@ -8377,6 +8380,42 @@ $as_echo "$with_ldap" >&6; }
+#
+# OAuth 2.0
+#
+{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether to build with OAuth support" >&5
+$as_echo_n "checking whether to build with OAuth support... " >&6; }
+
+
+
+# Check whether --with-oauth was given.
+if test "${with_oauth+set}" = set; then :
+ withval=$with_oauth;
+ case $withval in
+ yes)
+
+$as_echo "#define USE_OAUTH 1" >>confdefs.h
+
+ ;;
+ no)
+ :
+ ;;
+ *)
+ as_fn_error $? "no argument expected for --with-oauth option" "$LINENO" 5
+ ;;
+ esac
+
+else
+ with_oauth=no
+
+fi
+
+
+{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $with_oauth" >&5
+$as_echo "$with_oauth" >&6; }
+
+
+
#
# Bonjour
#
@@ -13503,6 +13542,56 @@ fi
+if test "$with_oauth" = yes ; then
+ { $as_echo "$as_me:${as_lineno-$LINENO}: checking for i_init_session in -liddawc" >&5
+$as_echo_n "checking for i_init_session in -liddawc... " >&6; }
+if ${ac_cv_lib_iddawc_i_init_session+:} false; then :
+ $as_echo_n "(cached) " >&6
+else
+ ac_check_lib_save_LIBS=$LIBS
+LIBS="-liddawc $LIBS"
+cat confdefs.h - <<_ACEOF >conftest.$ac_ext
+/* end confdefs.h. */
+
+/* Override any GCC internal prototype to avoid an error.
+ Use char because int might match the return type of a GCC
+ builtin and then its argument prototype would still apply. */
+#ifdef __cplusplus
+extern "C"
+#endif
+char i_init_session ();
+int
+main ()
+{
+return i_init_session ();
+ ;
+ return 0;
+}
+_ACEOF
+if ac_fn_c_try_link "$LINENO"; then :
+ ac_cv_lib_iddawc_i_init_session=yes
+else
+ ac_cv_lib_iddawc_i_init_session=no
+fi
+rm -f core conftest.err conftest.$ac_objext \
+ conftest$ac_exeext conftest.$ac_ext
+LIBS=$ac_check_lib_save_LIBS
+fi
+{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_iddawc_i_init_session" >&5
+$as_echo "$ac_cv_lib_iddawc_i_init_session" >&6; }
+if test "x$ac_cv_lib_iddawc_i_init_session" = xyes; then :
+ cat >>confdefs.h <<_ACEOF
+#define HAVE_LIBIDDAWC 1
+_ACEOF
+
+ LIBS="-liddawc $LIBS"
+
+else
+ as_fn_error $? "library 'iddawc' is required for OAuth support" "$LINENO" 5
+fi
+
+fi
+
# for contrib/sepgsql
if test "$with_selinux" = yes; then
{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for security_compute_create_name in -lselinux" >&5
@@ -14516,6 +14605,17 @@ fi
done
+fi
+
+if test "$with_oauth" != no; then
+ ac_fn_c_check_header_mongrel "$LINENO" "iddawc.h" "ac_cv_header_iddawc_h" "$ac_includes_default"
+if test "x$ac_cv_header_iddawc_h" = xyes; then :
+
+else
+ as_fn_error $? "header file <iddawc.h> is required for OAuth" "$LINENO" 5
+fi
+
+
fi
if test "$PORTNAME" = "win32" ; then
diff --git a/configure.ac b/configure.ac
index 078381e568..4050f91dbd 100644
--- a/configure.ac
+++ b/configure.ac
@@ -887,6 +887,17 @@ AC_MSG_RESULT([$with_ldap])
AC_SUBST(with_ldap)
+#
+# OAuth 2.0
+#
+AC_MSG_CHECKING([whether to build with OAuth support])
+PGAC_ARG_BOOL(with, oauth, no,
+ [build with OAuth 2.0 support],
+ [AC_DEFINE([USE_OAUTH], 1, [Define to 1 to build with OAuth 2.0 support. (--with-oauth)])])
+AC_MSG_RESULT([$with_oauth])
+AC_SUBST(with_oauth)
+
+
#
# Bonjour
#
@@ -1388,6 +1399,10 @@ fi
AC_SUBST(LDAP_LIBS_FE)
AC_SUBST(LDAP_LIBS_BE)
+if test "$with_oauth" = yes ; then
+ AC_CHECK_LIB(iddawc, i_init_session, [], [AC_MSG_ERROR([library 'iddawc' is required for OAuth support])])
+fi
+
# for contrib/sepgsql
if test "$with_selinux" = yes; then
AC_CHECK_LIB(selinux, security_compute_create_name, [],
@@ -1606,6 +1621,10 @@ elif test "$with_uuid" = ossp ; then
[AC_MSG_ERROR([header file <ossp/uuid.h> or <uuid.h> is required for OSSP UUID])])])
fi
+if test "$with_oauth" != no; then
+ AC_CHECK_HEADER(iddawc.h, [], [AC_MSG_ERROR([header file <iddawc.h> is required for OAuth])])
+fi
+
if test "$PORTNAME" = "win32" ; then
AC_CHECK_HEADERS(crtdefs.h)
fi
diff --git a/src/Makefile.global.in b/src/Makefile.global.in
index bbdc1c4bda..c9c61a9c99 100644
--- a/src/Makefile.global.in
+++ b/src/Makefile.global.in
@@ -193,6 +193,7 @@ with_ldap = @with_ldap@
with_libxml = @with_libxml@
with_libxslt = @with_libxslt@
with_llvm = @with_llvm@
+with_oauth = @with_oauth@
with_system_tzdata = @with_system_tzdata@
with_uuid = @with_uuid@
with_zlib = @with_zlib@
diff --git a/src/include/common/oauth-common.h b/src/include/common/oauth-common.h
new file mode 100644
index 0000000000..3fa95ac7e8
--- /dev/null
+++ b/src/include/common/oauth-common.h
@@ -0,0 +1,19 @@
+/*-------------------------------------------------------------------------
+ *
+ * oauth-common.h
+ * Declarations for helper functions used for OAuth/OIDC authentication
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/common/oauth-common.h
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef OAUTH_COMMON_H
+#define OAUTH_COMMON_H
+
+/* Name of SASL mechanism per IANA */
+#define OAUTHBEARER_NAME "OAUTHBEARER"
+
+#endif /* OAUTH_COMMON_H */
diff --git a/src/include/pg_config.h.in b/src/include/pg_config.h.in
index 635fbb2181..1b3332601e 100644
--- a/src/include/pg_config.h.in
+++ b/src/include/pg_config.h.in
@@ -319,6 +319,9 @@
/* Define to 1 if you have the `crypto' library (-lcrypto). */
#undef HAVE_LIBCRYPTO
+/* Define to 1 if you have the `iddawc' library (-liddawc). */
+#undef HAVE_LIBIDDAWC
+
/* Define to 1 if you have the `ldap' library (-lldap). */
#undef HAVE_LIBLDAP
@@ -922,6 +925,9 @@
/* Define to select named POSIX semaphores. */
#undef USE_NAMED_POSIX_SEMAPHORES
+/* Define to 1 to build with OAuth 2.0 support. (--with-oauth) */
+#undef USE_OAUTH
+
/* Define to 1 to build with OpenSSL support. (--with-ssl=openssl) */
#undef USE_OPENSSL
diff --git a/src/interfaces/libpq/Makefile b/src/interfaces/libpq/Makefile
index 3c53393fa4..727305c578 100644
--- a/src/interfaces/libpq/Makefile
+++ b/src/interfaces/libpq/Makefile
@@ -62,6 +62,11 @@ OBJS += \
fe-secure-gssapi.o
endif
+ifeq ($(with_oauth),yes)
+OBJS += \
+ fe-auth-oauth.o
+endif
+
ifeq ($(PORTNAME), cygwin)
override shlib = cyg$(NAME)$(DLSUFFIX)
endif
@@ -83,7 +88,7 @@ endif
# that are built correctly for use in a shlib.
SHLIB_LINK_INTERNAL = -lpgcommon_shlib -lpgport_shlib
ifneq ($(PORTNAME), win32)
-SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi_krb5 -lgss -lgssapi -lssl -lsocket -lnsl -lresolv -lintl -lm, $(LIBS)) $(LDAP_LIBS_FE) $(PTHREAD_LIBS)
+SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi_krb5 -lgss -lgssapi -lssl -liddawc -lsocket -lnsl -lresolv -lintl -lm, $(LIBS)) $(LDAP_LIBS_FE) $(PTHREAD_LIBS)
else
SHLIB_LINK += $(filter -lcrypt -ldes -lcom_err -lcrypto -lk5crypto -lkrb5 -lgssapi32 -lssl -lsocket -lnsl -lresolv -lintl -lm $(PTHREAD_LIBS), $(LIBS)) $(LDAP_LIBS_FE)
endif
diff --git a/src/interfaces/libpq/fe-auth-oauth.c b/src/interfaces/libpq/fe-auth-oauth.c
new file mode 100644
index 0000000000..383c9d4bdb
--- /dev/null
+++ b/src/interfaces/libpq/fe-auth-oauth.c
@@ -0,0 +1,744 @@
+/*-------------------------------------------------------------------------
+ *
+ * fe-auth-oauth.c
+ * The front-end (client) implementation of OAuth/OIDC authentication.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * IDENTIFICATION
+ * src/interfaces/libpq/fe-auth-oauth.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include <iddawc.h>
+
+#include "postgres_fe.h"
+
+#include "common/base64.h"
+#include "common/hmac.h"
+#include "common/jsonapi.h"
+#include "common/oauth-common.h"
+#include "fe-auth.h"
+#include "mb/pg_wchar.h"
+
+/* The exported OAuth callback mechanism. */
+static void *oauth_init(PGconn *conn, const char *password,
+ const char *sasl_mechanism);
+static void oauth_exchange(void *opaq, bool final,
+ char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success);
+static bool oauth_channel_bound(void *opaq);
+static void oauth_free(void *opaq);
+
+const pg_fe_sasl_mech pg_oauth_mech = {
+ oauth_init,
+ oauth_exchange,
+ oauth_channel_bound,
+ oauth_free,
+};
+
+typedef enum
+{
+ FE_OAUTH_INIT,
+ FE_OAUTH_BEARER_SENT,
+ FE_OAUTH_SERVER_ERROR,
+} fe_oauth_state_enum;
+
+typedef struct
+{
+ fe_oauth_state_enum state;
+
+ PGconn *conn;
+} fe_oauth_state;
+
+static void *
+oauth_init(PGconn *conn, const char *password,
+ const char *sasl_mechanism)
+{
+ fe_oauth_state *state;
+
+ /*
+ * We only support one SASL mechanism here; anything else is programmer
+ * error.
+ */
+ Assert(sasl_mechanism != NULL);
+ Assert(!strcmp(sasl_mechanism, OAUTHBEARER_NAME));
+
+ state = malloc(sizeof(*state));
+ if (!state)
+ return NULL;
+
+ state->state = FE_OAUTH_INIT;
+ state->conn = conn;
+
+ return state;
+}
+
+static const char *
+iddawc_error_string(int errcode)
+{
+ switch (errcode)
+ {
+ case I_OK:
+ return "I_OK";
+
+ case I_ERROR:
+ return "I_ERROR";
+
+ case I_ERROR_PARAM:
+ return "I_ERROR_PARAM";
+
+ case I_ERROR_MEMORY:
+ return "I_ERROR_MEMORY";
+
+ case I_ERROR_UNAUTHORIZED:
+ return "I_ERROR_UNAUTHORIZED";
+
+ case I_ERROR_SERVER:
+ return "I_ERROR_SERVER";
+ }
+
+ return "<unknown>";
+}
+
+static void
+iddawc_error(PGconn *conn, int errcode, const char *msg)
+{
+ appendPQExpBufferStr(&conn->errorMessage, libpq_gettext(msg));
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext(" (iddawc error %s)\n"),
+ iddawc_error_string(errcode));
+}
+
+static void
+iddawc_request_error(PGconn *conn, struct _i_session *i, int err, const char *msg)
+{
+ const char *error_code;
+ const char *desc;
+
+ appendPQExpBuffer(&conn->errorMessage, "%s: ", libpq_gettext(msg));
+
+ error_code = i_get_str_parameter(i, I_OPT_ERROR);
+ if (!error_code)
+ {
+ /*
+ * The server didn't give us any useful information, so just print the
+ * error code.
+ */
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("(iddawc error %s)\n"),
+ iddawc_error_string(err));
+ return;
+ }
+
+ /* If the server gave a string description, print that too. */
+ desc = i_get_str_parameter(i, I_OPT_ERROR_DESCRIPTION);
+ if (desc)
+ appendPQExpBuffer(&conn->errorMessage, "%s ", desc);
+
+ appendPQExpBuffer(&conn->errorMessage, "(%s)\n", error_code);
+}
+
+static char *
+get_auth_token(PGconn *conn)
+{
+ PQExpBuffer token_buf = NULL;
+ struct _i_session session;
+ int err;
+ int auth_method;
+ bool user_prompted = false;
+ const char *verification_uri;
+ const char *user_code;
+ const char *access_token;
+ const char *token_type;
+ char *token = NULL;
+
+ if (!conn->oauth_discovery_uri)
+ return strdup(""); /* ask the server for one */
+
+ if (!conn->oauth_client_id)
+ {
+ /* We can't talk to a server without a client identifier. */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("no oauth_client_id is set for the connection"));
+ return NULL;
+ }
+
+ i_init_session(&session);
+
+ token_buf = createPQExpBuffer();
+ if (!token_buf)
+ goto cleanup;
+
+ err = i_set_str_parameter(&session, I_OPT_OPENID_CONFIG_ENDPOINT, conn->oauth_discovery_uri);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set OpenID config endpoint");
+ goto cleanup;
+ }
+
+ err = i_get_openid_config(&session);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to fetch OpenID discovery document");
+ goto cleanup;
+ }
+
+ if (!i_get_str_parameter(&session, I_OPT_TOKEN_ENDPOINT))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer has no token endpoint"));
+ goto cleanup;
+ }
+
+ if (!i_get_str_parameter(&session, I_OPT_DEVICE_AUTHORIZATION_ENDPOINT))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer does not support device authorization"));
+ goto cleanup;
+ }
+
+ err = i_set_response_type(&session, I_RESPONSE_TYPE_DEVICE_CODE);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set device code response type");
+ goto cleanup;
+ }
+
+ auth_method = I_TOKEN_AUTH_METHOD_NONE;
+ if (conn->oauth_client_secret && *conn->oauth_client_secret)
+ auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
+
+ err = i_set_parameter_list(&session,
+ I_OPT_CLIENT_ID, conn->oauth_client_id,
+ I_OPT_CLIENT_SECRET, conn->oauth_client_secret,
+ I_OPT_TOKEN_METHOD, auth_method,
+ I_OPT_SCOPE, conn->oauth_scope,
+ I_OPT_NONE
+ );
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set client identifier");
+ goto cleanup;
+ }
+
+ err = i_run_device_auth_request(&session);
+ if (err)
+ {
+ iddawc_request_error(conn, &session, err,
+ "failed to obtain device authorization");
+ goto cleanup;
+ }
+
+ verification_uri = i_get_str_parameter(&session, I_OPT_DEVICE_AUTH_VERIFICATION_URI);
+ if (!verification_uri)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a verification URI"));
+ goto cleanup;
+ }
+
+ user_code = i_get_str_parameter(&session, I_OPT_DEVICE_AUTH_USER_CODE);
+ if (!user_code)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a user code"));
+ goto cleanup;
+ }
+
+ /*
+ * Poll the token endpoint until either the user logs in and authorizes the
+ * use of a token, or a hard failure occurs. We perform one ping _before_
+ * prompting the user, so that we don't make them do the work of logging in
+ * only to find that the token endpoint is completely unreachable.
+ */
+ err = i_run_token_request(&session);
+ while (err)
+ {
+ const char *error_code;
+ uint interval;
+
+ error_code = i_get_str_parameter(&session, I_OPT_ERROR);
+
+ /*
+ * authorization_pending and slow_down are the only acceptable errors;
+ * anything else and we bail.
+ */
+ if (!error_code || (strcmp(error_code, "authorization_pending")
+ && strcmp(error_code, "slow_down")))
+ {
+ iddawc_request_error(conn, &session, err,
+ "OAuth token retrieval failed");
+ goto cleanup;
+ }
+
+ if (!user_prompted)
+ {
+ /*
+ * Now that we know the token endpoint isn't broken, give the user
+ * the login instructions.
+ */
+ pqInternalNotice(&conn->noticeHooks,
+ "Visit %s and enter the code: %s",
+ verification_uri, user_code);
+
+ user_prompted = true;
+ }
+
+ /*
+ * We are required to wait between polls; the server tells us how long.
+ * If the server doesn't specify an interval, RFC 8628, Sec. 3.2 defaults
+ * it to five seconds. TODO: sanity check an excessively long interval
+ */
+ interval = i_get_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL);
+ if (interval < 1)
+ interval = 5;
+
+ /*
+ * A slow_down error requires us to permanently increase our retry
+ * interval by five seconds. RFC 8628, Sec. 3.5.
+ */
+ if (!strcmp(error_code, "slow_down"))
+ {
+ interval += 5;
+ i_set_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL, interval);
+ }
+
+ sleep(interval);
+
+ /*
+ * XXX Reset the error code before every call, because iddawc won't do
+ * that for us. This matters if the server first sends a "pending" error
+ * code, then later hard-fails without sending an error code to
+ * overwrite the first one.
+ *
+ * That we have to do this at all seems like a bug in iddawc.
+ */
+ i_set_str_parameter(&session, I_OPT_ERROR, NULL);
+
+ err = i_run_token_request(&session);
+ }
+
+ access_token = i_get_str_parameter(&session, I_OPT_ACCESS_TOKEN);
+ token_type = i_get_str_parameter(&session, I_OPT_TOKEN_TYPE);
+
+ if (!access_token || !token_type || pg_strcasecmp(token_type, "Bearer"))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a bearer token"));
+ goto cleanup;
+ }
+
+ appendPQExpBufferStr(token_buf, "Bearer ");
+ appendPQExpBufferStr(token_buf, access_token);
+
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+
+ token = strdup(token_buf->data);
+
+cleanup:
+ if (token_buf)
+ destroyPQExpBuffer(token_buf);
+ i_clean_session(&session);
+
+ return token;
+}
+
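The polling loop above is RFC 8628's device authorization grant: ping the token endpoint, treat authorization_pending and slow_down as the only retryable errors, and permanently add five seconds to the interval on slow_down. A minimal Python sketch of that state machine (the simulated endpoint responses are illustrative, not the iddawc API):

```python
import time

def poll_for_token(token_endpoint, interval=5, sleep=time.sleep):
    """Poll until the endpoint returns a token or hard-fails (RFC 8628, Sec. 3.4/3.5)."""
    while True:
        resp = token_endpoint()
        if "access_token" in resp:
            return resp["access_token"], interval
        err = resp.get("error")
        if err == "slow_down":
            interval += 5            # permanent increase; RFC 8628, Sec. 3.5
        elif err != "authorization_pending":
            raise RuntimeError(f"token retrieval failed: {err}")
        sleep(interval)              # required wait between polls

# Simulated endpoint: pending once, slow_down once, then success.
responses = iter([
    {"error": "authorization_pending"},
    {"error": "slow_down"},
    {"access_token": "opaque-token", "token_type": "Bearer"},
])
token, final_interval = poll_for_token(lambda: next(responses), sleep=lambda _: None)
```

Note the interval increase survives across retries, matching the `i_set_int_parameter(..., I_OPT_DEVICE_AUTH_INTERVAL, interval)` call in the C code.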
+#define kvsep "\x01"
+
+static char *
+client_initial_response(PGconn *conn)
+{
+ static const char * const resp_format = "n,," kvsep "auth=%s" kvsep kvsep;
+
+ PQExpBuffer token_buf;
+ PQExpBuffer discovery_buf = NULL;
+ char *token = NULL;
+ char *response = NULL;
+
+ token_buf = createPQExpBuffer();
+ if (!token_buf)
+ goto cleanup;
+
+ /*
+ * If we don't yet have a discovery URI, but the user gave us an explicit
+ * issuer, use the .well-known discovery URI for that issuer.
+ */
+ if (!conn->oauth_discovery_uri && conn->oauth_issuer)
+ {
+ discovery_buf = createPQExpBuffer();
+ if (!discovery_buf)
+ goto cleanup;
+
+ appendPQExpBufferStr(discovery_buf, conn->oauth_issuer);
+ appendPQExpBufferStr(discovery_buf, "/.well-known/openid-configuration");
+
+ if (PQExpBufferBroken(discovery_buf))
+ goto cleanup;
+
+ conn->oauth_discovery_uri = strdup(discovery_buf->data);
+ if (!conn->oauth_discovery_uri)
+ goto cleanup;
+ }
+
+ token = get_auth_token(conn);
+ if (!token)
+ goto cleanup;
+
+ appendPQExpBuffer(token_buf, resp_format, token);
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+
+ response = strdup(token_buf->data);
+
+cleanup:
+ if (token)
+ free(token);
+ if (discovery_buf)
+ destroyPQExpBuffer(discovery_buf);
+ if (token_buf)
+ destroyPQExpBuffer(token_buf);
+
+ return response;
+}
+
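The initial client response assembled above is the OAUTHBEARER gs2 header plus key/value pairs, each terminated by a 0x01 separator (RFC 7628, Sec. 3.1). A quick sketch of what goes over the wire for a given token (the token value here is made up):

```python
KVSEP = "\x01"

def client_initial_response(token):
    # gs2 header "n,," (no channel binding, no authzid), then the
    # auth key/value pair, then the double-separator terminator.
    return "n,," + KVSEP + "auth=" + token + KVSEP + KVSEP

resp = client_initial_response("Bearer eyJr...")
```

This mirrors the C format string `"n,," kvsep "auth=%s" kvsep kvsep`, where `%s` is the `Bearer <access_token>` string built by get_auth_token().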
+#define ERROR_STATUS_FIELD "status"
+#define ERROR_SCOPE_FIELD "scope"
+#define ERROR_OPENID_CONFIGURATION_FIELD "openid-configuration"
+
+struct json_ctx
+{
+ char *errmsg; /* any non-NULL value stops all processing */
+ PQExpBufferData errbuf; /* backing memory for errmsg */
+ int nested; /* nesting level (zero is the top) */
+
+ const char *target_field_name; /* points to a static allocation */
+ char **target_field; /* see below */
+
+ /* target_field, if set, points to one of the following: */
+ char *status;
+ char *scope;
+ char *discovery_uri;
+};
+
+#define oauth_json_has_error(ctx) \
+ (PQExpBufferDataBroken((ctx)->errbuf) || (ctx)->errmsg)
+
+#define oauth_json_set_error(ctx, ...) \
+ do { \
+ appendPQExpBuffer(&(ctx)->errbuf, __VA_ARGS__); \
+ (ctx)->errmsg = (ctx)->errbuf.data; \
+ } while (0)
+
+static void
+oauth_json_object_start(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ return; /* short-circuit */
+
+ if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ oauth_json_set_error(ctx,
+ libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+
+ ++ctx->nested;
+}
+
+static void
+oauth_json_object_end(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ return; /* short-circuit */
+
+ --ctx->nested;
+}
+
+static void
+oauth_json_object_field_start(void *state, char *name, bool isnull)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ {
+ /* short-circuit */
+ free(name);
+ return;
+ }
+
+ if (ctx->nested == 1)
+ {
+ if (!strcmp(name, ERROR_STATUS_FIELD))
+ {
+ ctx->target_field_name = ERROR_STATUS_FIELD;
+ ctx->target_field = &ctx->status;
+ }
+ else if (!strcmp(name, ERROR_SCOPE_FIELD))
+ {
+ ctx->target_field_name = ERROR_SCOPE_FIELD;
+ ctx->target_field = &ctx->scope;
+ }
+ else if (!strcmp(name, ERROR_OPENID_CONFIGURATION_FIELD))
+ {
+ ctx->target_field_name = ERROR_OPENID_CONFIGURATION_FIELD;
+ ctx->target_field = &ctx->discovery_uri;
+ }
+ }
+
+ free(name);
+}
+
+static void
+oauth_json_array_start(void *state)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ return; /* short-circuit */
+
+ if (!ctx->nested)
+ {
+ ctx->errmsg = libpq_gettext("top-level element must be an object");
+ }
+ else if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ oauth_json_set_error(ctx,
+ libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+}
+
+static void
+oauth_json_scalar(void *state, char *token, JsonTokenType type)
+{
+ struct json_ctx *ctx = state;
+
+ if (oauth_json_has_error(ctx))
+ {
+ /* short-circuit */
+ free(token);
+ return;
+ }
+
+ if (!ctx->nested)
+ {
+ ctx->errmsg = libpq_gettext("top-level element must be an object");
+ }
+ else if (ctx->target_field)
+ {
+ Assert(ctx->nested == 1);
+
+ if (type == JSON_TOKEN_STRING)
+ {
+ *ctx->target_field = token;
+
+ ctx->target_field = NULL;
+ ctx->target_field_name = NULL;
+
+ return; /* don't free the token we're using */
+ }
+
+ oauth_json_set_error(ctx,
+ libpq_gettext("field \"%s\" must be a string"),
+ ctx->target_field_name);
+ }
+
+ free(token);
+}
+
+static bool
+handle_oauth_sasl_error(PGconn *conn, char *msg, int msglen)
+{
+ JsonLexContext lex = {0};
+ JsonSemAction sem = {0};
+ JsonParseErrorType err;
+ struct json_ctx ctx = {0};
+ char *errmsg = NULL;
+
+ /* Sanity check. */
+ if (strlen(msg) != msglen)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server's error message contained an embedded NUL byte"));
+ return false;
+ }
+
+ initJsonLexContextCstringLen(&lex, msg, msglen, PG_UTF8, true);
+
+ initPQExpBuffer(&ctx.errbuf);
+ sem.semstate = &ctx;
+
+ sem.object_start = oauth_json_object_start;
+ sem.object_end = oauth_json_object_end;
+ sem.object_field_start = oauth_json_object_field_start;
+ sem.array_start = oauth_json_array_start;
+ sem.scalar = oauth_json_scalar;
+
+ err = pg_parse_json(&lex, &sem);
+
+ if (err != JSON_SUCCESS)
+ {
+ errmsg = json_errdetail(err, &lex);
+ }
+ else if (PQExpBufferDataBroken(ctx.errbuf))
+ {
+ errmsg = libpq_gettext("out of memory");
+ }
+ else if (ctx.errmsg)
+ {
+ errmsg = ctx.errmsg;
+ }
+
+ if (errmsg)
+ appendPQExpBuffer(&conn->errorMessage,
+ libpq_gettext("failed to parse server's error response: %s"),
+ errmsg);
+
+ /* Don't need the error buffer or the JSON lexer anymore. */
+ termPQExpBuffer(&ctx.errbuf);
+ termJsonLexContext(&lex);
+
+ if (errmsg)
+ return false;
+
+ /* TODO: what if these override what the user already specified? */
+ if (ctx.discovery_uri)
+ {
+ if (conn->oauth_discovery_uri)
+ free(conn->oauth_discovery_uri);
+
+ conn->oauth_discovery_uri = ctx.discovery_uri;
+ }
+
+ if (ctx.scope)
+ {
+ if (conn->oauth_scope)
+ free(conn->oauth_scope);
+
+ conn->oauth_scope = ctx.scope;
+ }
+ /* TODO: missing error scope should clear any existing connection scope */
+
+ if (!ctx.status)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server sent error response without a status"));
+ return false;
+ }
+
+ if (!strcmp(ctx.status, "invalid_token"))
+ {
+ /*
+ * invalid_token is the only error code we'll automatically retry for,
+ * but only if we have enough information to do so.
+ */
+ if (conn->oauth_discovery_uri)
+ conn->oauth_want_retry = true;
+ }
+ /* TODO: include status in hard failure message */
+
+ return true;
+}
+
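The server's OAUTHBEARER failure message is a JSON object carrying at minimum a status, plus optional scope and openid-configuration members (RFC 7628, Sec. 3.2.2); the handler above extracts those three string fields and retries only on invalid_token. A compact sketch of the same checks using Python's json module in place of pg_parse_json (function name and return shape are illustrative):

```python
import json

def parse_sasl_error(msg):
    obj = json.loads(msg)
    if not isinstance(obj, dict):
        raise ValueError("top-level element must be an object")
    fields = {}
    for key in ("status", "scope", "openid-configuration"):
        if key in obj:
            if not isinstance(obj[key], str):
                raise ValueError(f'field "{key}" must be a string')
            fields[key] = obj[key]
    if "status" not in fields:
        raise ValueError("server sent error response without a status")
    # invalid_token is the only status worth an automatic retry, and only
    # if the server also told us where to discover the issuer.
    fields["retry"] = fields["status"] == "invalid_token"
    return fields

err = parse_sasl_error(
    '{"status": "invalid_token", '
    '"openid-configuration": "https://example.org/.well-known/openid-configuration"}'
)
```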
+static void
+oauth_exchange(void *opaq, bool final,
+ char *input, int inputlen,
+ char **output, int *outputlen,
+ bool *done, bool *success)
+{
+ fe_oauth_state *state = opaq;
+ PGconn *conn = state->conn;
+
+ *done = false;
+ *success = false;
+ *output = NULL;
+ *outputlen = 0;
+
+ switch (state->state)
+ {
+ case FE_OAUTH_INIT:
+ Assert(inputlen == -1);
+
+ *output = client_initial_response(conn);
+ if (!*output)
+ goto error;
+
+ *outputlen = strlen(*output);
+ state->state = FE_OAUTH_BEARER_SENT;
+
+ break;
+
+ case FE_OAUTH_BEARER_SENT:
+ if (final)
+ {
+ /* TODO: ensure there is no message content here. */
+ *done = true;
+ *success = true;
+
+ break;
+ }
+
+ /*
+ * Error message sent by the server.
+ */
+ if (!handle_oauth_sasl_error(conn, input, inputlen))
+ goto error;
+
+ /*
+ * Respond with the required dummy message (RFC 7628, sec. 3.2.3).
+ */
+ *output = strdup(kvsep);
+ if (!*output)
+ goto error;
+ *outputlen = strlen(*output); /* == 1 */
+
+ state->state = FE_OAUTH_SERVER_ERROR;
+ break;
+
+ case FE_OAUTH_SERVER_ERROR:
+ /*
+ * After an error, the server should send an error response to fail
+ * the SASL handshake, which is handled in higher layers.
+ *
+ * If we get here, the server either sent *another* challenge which
+ * isn't defined in the RFC, or completed the handshake successfully
+ * after telling us it was going to fail. Neither is acceptable.
+ */
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("server sent additional OAuth data after error\n"));
+ goto error;
+
+ default:
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("invalid OAuth exchange state\n"));
+ goto error;
+ }
+
+ return;
+
+error:
+ *done = true;
+ *success = false;
+}
+
+static bool
+oauth_channel_bound(void *opaq)
+{
+ /* This mechanism does not support channel binding. */
+ return false;
+}
+
+static void
+oauth_free(void *opaq)
+{
+ fe_oauth_state *state = opaq;
+
+ free(state);
+}
diff --git a/src/interfaces/libpq/fe-auth-sasl.h b/src/interfaces/libpq/fe-auth-sasl.h
index da3c30b87b..b1bb382f70 100644
--- a/src/interfaces/libpq/fe-auth-sasl.h
+++ b/src/interfaces/libpq/fe-auth-sasl.h
@@ -65,6 +65,8 @@ typedef struct pg_fe_sasl_mech
*
* state: The opaque mechanism state returned by init()
*
+ * final: true if the server has sent a final exchange outcome
+ *
* input: The challenge data sent by the server, or NULL when
* generating a client-first initial response (that is, when
* the server expects the client to send a message to start
@@ -92,7 +94,8 @@ typedef struct pg_fe_sasl_mech
* Ignored if *done is false.
*--------
*/
- void (*exchange) (void *state, char *input, int inputlen,
+ void (*exchange) (void *state, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success);
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index e616200704..681b76adbe 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -24,7 +24,8 @@
/* The exported SCRAM callback mechanism. */
static void *scram_init(PGconn *conn, const char *password,
const char *sasl_mechanism);
-static void scram_exchange(void *opaq, char *input, int inputlen,
+static void scram_exchange(void *opaq, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success);
static bool scram_channel_bound(void *opaq);
@@ -206,7 +207,8 @@ scram_free(void *opaq)
* Exchange a SCRAM message with backend.
*/
static void
-scram_exchange(void *opaq, char *input, int inputlen,
+scram_exchange(void *opaq, bool final,
+ char *input, int inputlen,
char **output, int *outputlen,
bool *done, bool *success)
{
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 6fceff561b..2567a34023 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -38,6 +38,7 @@
#endif
#include "common/md5.h"
+#include "common/oauth-common.h"
#include "common/scram-common.h"
#include "fe-auth.h"
#include "fe-auth-sasl.h"
@@ -422,7 +423,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
bool success;
const char *selected_mechanism;
PQExpBufferData mechanism_buf;
- char *password;
+ char *password = NULL;
initPQExpBuffer(&mechanism_buf);
@@ -444,8 +445,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
/*
* Parse the list of SASL authentication mechanisms in the
* AuthenticationSASL message, and select the best mechanism that we
- * support. SCRAM-SHA-256-PLUS and SCRAM-SHA-256 are the only ones
- * supported at the moment, listed by order of decreasing importance.
+ * support. Mechanisms are listed by order of decreasing importance.
*/
selected_mechanism = NULL;
for (;;)
@@ -485,6 +485,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
{
selected_mechanism = SCRAM_SHA_256_PLUS_NAME;
conn->sasl = &pg_scram_mech;
+ conn->password_needed = true;
}
#else
/*
@@ -522,7 +523,17 @@ pg_SASL_init(PGconn *conn, int payloadlen)
{
selected_mechanism = SCRAM_SHA_256_NAME;
conn->sasl = &pg_scram_mech;
+ conn->password_needed = true;
}
+#ifdef USE_OAUTH
+ else if (strcmp(mechanism_buf.data, OAUTHBEARER_NAME) == 0 &&
+ !selected_mechanism)
+ {
+ selected_mechanism = OAUTHBEARER_NAME;
+ conn->sasl = &pg_oauth_mech;
+ conn->password_needed = false;
+ }
+#endif
}
if (!selected_mechanism)
@@ -547,18 +558,19 @@ pg_SASL_init(PGconn *conn, int payloadlen)
/*
* First, select the password to use for the exchange, complaining if
- * there isn't one. Currently, all supported SASL mechanisms require a
- * password, so we can just go ahead here without further distinction.
+ * there isn't one and the SASL mechanism needs it.
*/
- conn->password_needed = true;
- password = conn->connhost[conn->whichhost].password;
- if (password == NULL)
- password = conn->pgpass;
- if (password == NULL || password[0] == '\0')
+ if (conn->password_needed)
{
- appendPQExpBufferStr(&conn->errorMessage,
- PQnoPasswordSupplied);
- goto error;
+ password = conn->connhost[conn->whichhost].password;
+ if (password == NULL)
+ password = conn->pgpass;
+ if (password == NULL || password[0] == '\0')
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ PQnoPasswordSupplied);
+ goto error;
+ }
}
Assert(conn->sasl);
@@ -576,7 +588,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
goto oom_error;
/* Get the mechanism-specific Initial Client Response, if any */
- conn->sasl->exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state, false,
NULL, -1,
&initialresponse, &initialresponselen,
&done, &success);
@@ -657,7 +669,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
/* For safety and convenience, ensure the buffer is NULL-terminated. */
challenge[payloadlen] = '\0';
- conn->sasl->exchange(conn->sasl_state,
+ conn->sasl->exchange(conn->sasl_state, final,
challenge, payloadlen,
&output, &outputlen,
&done, &success);
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 049a8bb1a1..2a56774019 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -28,4 +28,7 @@ extern const pg_fe_sasl_mech pg_scram_mech;
extern char *pg_fe_scram_build_secret(const char *password,
const char **errstr);
+/* Mechanisms in fe-auth-oauth.c */
+extern const pg_fe_sasl_mech pg_oauth_mech;
+
#endif /* FE_AUTH_H */
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index cf554d389f..fdd30d71de 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -344,6 +344,23 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
"Target-Session-Attrs", "", 15, /* sizeof("prefer-standby") = 15 */
offsetof(struct pg_conn, target_session_attrs)},
+ /* OAuth v2 */
+ {"oauth_issuer", NULL, NULL, NULL,
+ "OAuth-Issuer", "", 40,
+ offsetof(struct pg_conn, oauth_issuer)},
+
+ {"oauth_client_id", NULL, NULL, NULL,
+ "OAuth-Client-ID", "", 40,
+ offsetof(struct pg_conn, oauth_client_id)},
+
+ {"oauth_client_secret", NULL, NULL, NULL,
+ "OAuth-Client-Secret", "", 40,
+ offsetof(struct pg_conn, oauth_client_secret)},
+
+ {"oauth_scope", NULL, NULL, NULL,
+ "OAuth-Scope", "", 15,
+ offsetof(struct pg_conn, oauth_scope)},
+
/* Terminating entry --- MUST BE LAST */
{NULL, NULL, NULL, NULL,
NULL, NULL, 0}
@@ -606,6 +623,7 @@ pqDropServerData(PGconn *conn)
conn->write_err_msg = NULL;
conn->be_pid = 0;
conn->be_key = 0;
+ /* conn->oauth_want_retry = false; TODO */
}
@@ -3386,6 +3404,16 @@ keep_going: /* We will come back to here until there is
/* Check to see if we should mention pgpassfile */
pgpassfileWarning(conn);
+#ifdef USE_OAUTH
+ if (conn->sasl == &pg_oauth_mech
+ && conn->oauth_want_retry)
+ {
+ /* TODO: only allow retry once */
+ need_new_connection = true;
+ goto keep_going;
+ }
+#endif
+
#ifdef ENABLE_GSS
/*
@@ -4166,6 +4194,16 @@ freePGconn(PGconn *conn)
free(conn->rowBuf);
if (conn->target_session_attrs)
free(conn->target_session_attrs);
+ if (conn->oauth_issuer)
+ free(conn->oauth_issuer);
+ if (conn->oauth_discovery_uri)
+ free(conn->oauth_discovery_uri);
+ if (conn->oauth_client_id)
+ free(conn->oauth_client_id);
+ if (conn->oauth_client_secret)
+ free(conn->oauth_client_secret);
+ if (conn->oauth_scope)
+ free(conn->oauth_scope);
termPQExpBuffer(&conn->errorMessage);
termPQExpBuffer(&conn->workBuffer);
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index e0cee4b142..0dff13505a 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -394,6 +394,14 @@ struct pg_conn
char *ssl_max_protocol_version; /* maximum TLS protocol version */
char *target_session_attrs; /* desired session properties */
+ /* OAuth v2 */
+ char *oauth_issuer; /* token issuer URL */
+ char *oauth_discovery_uri; /* URI of the issuer's discovery document */
+ char *oauth_client_id; /* client identifier */
+ char *oauth_client_secret; /* client secret */
+ char *oauth_scope; /* access token scope */
+ bool oauth_want_retry; /* should we retry on failure? */
+
/* Optional file to write trace info to */
FILE *Pfdebug;
int traceFlags;
--
2.25.1
v4-0009-Add-pytest-suite-for-OAuth.patch (text/x-patch)
From 6d8fd9e5b352fd0847c9454ced2b763a6b11e73f Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Fri, 4 Jun 2021 09:06:38 -0700
Subject: [PATCH v4 09/10] Add pytest suite for OAuth
Requires Python 3; on the first run of `make installcheck` the
dependencies will be installed into ./venv for you. See the README for
more details.
---
src/test/python/.gitignore | 2 +
src/test/python/Makefile | 38 +
src/test/python/README | 54 ++
src/test/python/client/__init__.py | 0
src/test/python/client/conftest.py | 126 +++
src/test/python/client/test_client.py | 180 ++++
src/test/python/client/test_oauth.py | 936 ++++++++++++++++++
src/test/python/pq3.py | 727 ++++++++++++++
src/test/python/pytest.ini | 4 +
src/test/python/requirements.txt | 7 +
src/test/python/server/__init__.py | 0
src/test/python/server/conftest.py | 45 +
src/test/python/server/test_oauth.py | 1012 ++++++++++++++++++++
src/test/python/server/test_server.py | 21 +
src/test/python/server/validate_bearer.py | 101 ++
src/test/python/server/validate_reflect.py | 34 +
src/test/python/test_internals.py | 138 +++
src/test/python/test_pq3.py | 558 +++++++++++
src/test/python/tls.py | 195 ++++
19 files changed, 4178 insertions(+)
create mode 100644 src/test/python/.gitignore
create mode 100644 src/test/python/Makefile
create mode 100644 src/test/python/README
create mode 100644 src/test/python/client/__init__.py
create mode 100644 src/test/python/client/conftest.py
create mode 100644 src/test/python/client/test_client.py
create mode 100644 src/test/python/client/test_oauth.py
create mode 100644 src/test/python/pq3.py
create mode 100644 src/test/python/pytest.ini
create mode 100644 src/test/python/requirements.txt
create mode 100644 src/test/python/server/__init__.py
create mode 100644 src/test/python/server/conftest.py
create mode 100644 src/test/python/server/test_oauth.py
create mode 100644 src/test/python/server/test_server.py
create mode 100755 src/test/python/server/validate_bearer.py
create mode 100755 src/test/python/server/validate_reflect.py
create mode 100644 src/test/python/test_internals.py
create mode 100644 src/test/python/test_pq3.py
create mode 100644 src/test/python/tls.py
diff --git a/src/test/python/.gitignore b/src/test/python/.gitignore
new file mode 100644
index 0000000000..0e8f027b2e
--- /dev/null
+++ b/src/test/python/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+/venv/
diff --git a/src/test/python/Makefile b/src/test/python/Makefile
new file mode 100644
index 0000000000..b0695b6287
--- /dev/null
+++ b/src/test/python/Makefile
@@ -0,0 +1,38 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+# Only Python 3 is supported, but if it's named something different on your
+# system you can override it with the PYTHON3 variable.
+PYTHON3 := python3
+
+# All dependencies are placed into this directory. The default is .gitignored
+# for you, but you can override it if you'd like.
+VENV := ./venv
+
+override VBIN := $(VENV)/bin
+override PIP := $(VBIN)/pip
+override PYTEST := $(VBIN)/py.test
+override ISORT := $(VBIN)/isort
+override BLACK := $(VBIN)/black
+
+.PHONY: installcheck indent
+
+installcheck: $(PYTEST)
+ $(PYTEST) -v -rs
+
+indent: $(ISORT) $(BLACK)
+ $(ISORT) --profile black *.py client/*.py server/*.py
+ $(BLACK) *.py client/*.py server/*.py
+
+$(PYTEST) $(ISORT) $(BLACK) &: requirements.txt | $(PIP)
+ $(PIP) install --force-reinstall -r $<
+
+$(PIP):
+ $(PYTHON3) -m venv $(VENV)
+
+# A convenience recipe to rebuild psycopg2 against the local libpq.
+.PHONY: rebuild-psycopg2
+rebuild-psycopg2: | $(PIP)
+ $(PIP) install --force-reinstall --no-binary :all: $(shell grep psycopg2 requirements.txt)
diff --git a/src/test/python/README b/src/test/python/README
new file mode 100644
index 0000000000..0bda582c4b
--- /dev/null
+++ b/src/test/python/README
@@ -0,0 +1,54 @@
+A test suite for exercising both the libpq client and the server backend at the
+protocol level, based on pytest and Construct.
+
+The test suite currently assumes that the standard PG* environment variables
+point to the database under test and are sufficient to log in a superuser on
+that system. In other words, a bare `psql` needs to Just Work before the test
+suite can do its thing. For a newly built dev cluster, typically all I need to
+do is
+
+ export PGDATABASE=postgres
+
+but you can adjust as needed for your setup.
+
+## Requirements
+
+A supported version (3.6+) of Python.
+
+The first run of
+
+ make installcheck
+
+will install a local virtual environment and all needed dependencies. During
+development, if libpq changes incompatibly, you can issue
+
+ $ make rebuild-psycopg2
+
+to force a rebuild of the client library.
+
+## Hacking
+
+The code style is enforced by a _very_ opinionated autoformatter. Running the
+
+ make indent
+
+recipe will invoke it for you automatically. Don't fight the tool; part of the
+zen is in knowing that if the formatter makes your code ugly, there's probably a
+cleaner way to write your code.
+
+## Advanced Usage
+
+The Makefile is there for convenience, but you don't have to use it. Activate
+the virtualenv to be able to use pytest directly:
+
+ $ source venv/bin/activate
+ $ py.test -k oauth
+ ...
+ $ py.test ./server/test_server.py
+ ...
+ $ deactivate # puts the PATH et al back the way it was before
+
+To make quick smoke tests possible, slow tests have been marked explicitly. You
+can skip them by saying e.g.
+
+ $ py.test -m 'not slow'
diff --git a/src/test/python/client/__init__.py b/src/test/python/client/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/client/conftest.py b/src/test/python/client/conftest.py
new file mode 100644
index 0000000000..f38da7a138
--- /dev/null
+++ b/src/test/python/client/conftest.py
@@ -0,0 +1,126 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import socket
+import sys
+import threading
+
+import psycopg2
+import pytest
+
+import pq3
+
+BLOCKING_TIMEOUT = 2 # the number of seconds to wait for blocking calls
+
+
+@pytest.fixture
+def server_socket(unused_tcp_port_factory):
+ """
+ Returns a listening socket bound to an ephemeral port.
+ """
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("127.0.0.1", unused_tcp_port_factory()))
+ s.listen(1)
+ s.settimeout(BLOCKING_TIMEOUT)
+ yield s
+
+
+class ClientHandshake(threading.Thread):
+ """
+ A thread that connects to a local Postgres server using psycopg2. Once the
+ opening handshake completes, the connection will be immediately closed.
+ """
+
+ def __init__(self, *, port, **kwargs):
+ super().__init__()
+
+ kwargs["port"] = port
+ self._kwargs = kwargs
+
+ self.exception = None
+
+ def run(self):
+ try:
+ conn = psycopg2.connect(host="127.0.0.1", **self._kwargs)
+ conn.close()
+ except Exception as e:
+ self.exception = e
+
+ def check_completed(self, timeout=BLOCKING_TIMEOUT):
+ """
+ Joins the client thread. Raises an exception if the thread could not be
+ joined, or if it threw an exception itself. (The exception will be
+ cleared, so future calls to check_completed will succeed.)
+ """
+ self.join(timeout)
+
+ if self.is_alive():
+ raise TimeoutError("client thread did not handshake within the timeout")
+ elif self.exception:
+ e = self.exception
+ self.exception = None
+ raise e
+
+
+@pytest.fixture
+def accept(server_socket):
+ """
+ Returns a factory function that, when called, returns a pair (sock, client)
+ where sock is a server socket that has accepted a connection from client,
+ and client is an instance of ClientHandshake. Clients will complete their
+ handshakes and cleanly disconnect.
+
+ The default connstring options may be extended or overridden by passing
+ arbitrary keyword arguments. Keep in mind that you generally should not
+ override the host or port, since they point to the local test server.
+
+ For situations where a client needs to connect more than once to complete a
+ handshake, the accept function may be called more than once. (The client
+ returned for subsequent calls will always be the same client that was
+ returned for the first call.)
+
+ Tests must either complete the handshake so that the client thread can be
+ automatically joined during teardown, or else call client.check_completed()
+ and manually handle any expected errors.
+ """
+ _, port = server_socket.getsockname()
+
+ client = None
+ default_opts = dict(
+ port=port,
+ user=pq3.pguser(),
+ sslmode="disable",
+ )
+
+ def factory(**kwargs):
+ nonlocal client
+
+ if client is None:
+ opts = dict(default_opts)
+ opts.update(kwargs)
+
+ # The server_socket is already listening, so the client thread can
+ # be safely started; it'll block on the connection until we accept.
+ client = ClientHandshake(**opts)
+ client.start()
+
+ sock, _ = server_socket.accept()
+ return sock, client
+
+ yield factory
+ client.check_completed()
+
+
+@pytest.fixture
+def conn(accept):
+ """
+ Returns an accepted, wrapped pq3 connection to a psycopg2 client. The socket
+ will be closed when the test finishes, and the client will be checked for a
+ cleanly completed handshake.
+ """
+ sock, client = accept()
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ yield conn
diff --git a/src/test/python/client/test_client.py b/src/test/python/client/test_client.py
new file mode 100644
index 0000000000..c4c946fda4
--- /dev/null
+++ b/src/test/python/client/test_client.py
@@ -0,0 +1,180 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import sys
+
+import psycopg2
+import pytest
+from cryptography.hazmat.primitives import hashes, hmac
+
+import pq3
+
+
+def finish_handshake(conn):
+ """
+ Sends the AuthenticationOK message and the standard opening salvo of server
+ messages, then asserts that the client immediately sends a Terminate message
+ to close the connection cleanly.
+ """
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+ pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.Terminate
+
+
+def test_handshake(conn):
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ finish_handshake(conn)
+
+
+def test_aborted_connection(accept):
+ """
+ Make sure the client correctly reports an early close during handshakes.
+ """
+ sock, client = accept()
+ sock.close()
+
+ expected = "server closed the connection unexpectedly"
+ with pytest.raises(psycopg2.OperationalError, match=expected):
+ client.check_completed()
+
+
+#
+# SCRAM-SHA-256 (see RFC 5802: https://tools.ietf.org/html/rfc5802)
+#
+
+
+@pytest.fixture
+def password():
+ """
+ Returns a password for use by both client and server.
+ """
+ # TODO: parameterize this with passwords that require SASLprep.
+ return "secret"
+
+
+@pytest.fixture
+def pwconn(accept, password):
+ """
+ Like the conn fixture, but uses a password in the connection.
+ """
+ sock, client = accept(password=password)
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ yield conn
+
+
+def sha256(data):
+ """The H(str) function from Section 2.2."""
+ digest = hashes.Hash(hashes.SHA256())
+ digest.update(data)
+ return digest.finalize()
+
+
+def hmac_256(key, data):
+ """The HMAC(key, str) function from Section 2.2."""
+ h = hmac.HMAC(key, hashes.SHA256())
+ h.update(data)
+ return h.finalize()
+
+
+def xor(a, b):
+ """The XOR operation from Section 2.2."""
+ res = bytearray(a)
+ for i, byte in enumerate(b):
+ res[i] ^= byte
+ return bytes(res)
+
+
+def h_i(data, salt, i):
+ """The Hi(str, salt, i) function from Section 2.2."""
+ assert i > 0
+
+ acc = hmac_256(data, salt + b"\x00\x00\x00\x01")
+ last = acc
+ i -= 1
+
+ while i:
+ u = hmac_256(data, last)
+ acc = xor(acc, u)
+
+ last = u
+ i -= 1
+
+ return acc
+
+
+def test_scram(pwconn, password):
+ startup = pq3.recv1(pwconn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ pwconn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"SCRAM-SHA-256", b""],
+ )
+
+ # Get the client-first-message.
+ pkt = pq3.recv1(pwconn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ initial = pq3.SASLInitialResponse.parse(pkt.payload)
+ assert initial.name == b"SCRAM-SHA-256"
+
+ c_bind, authzid, c_name, c_nonce = initial.data.split(b",")
+ assert c_bind == b"n" # no channel bindings on a plaintext connection
+ assert authzid == b"" # we don't support authzid currently
+ assert c_name == b"n=" # libpq doesn't honor the GS2 username
+ assert c_nonce.startswith(b"r=")
+
+ # Send the server-first-message.
+ salt = b"12345"
+ iterations = 2
+
+ s_nonce = c_nonce + b"somenonce"
+ s_salt = b"s=" + base64.b64encode(salt)
+ s_iterations = b"i=%d" % iterations
+
+ msg = b",".join([s_nonce, s_salt, s_iterations])
+ pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=msg)
+
+ # Get the client-final-message.
+ pkt = pq3.recv1(pwconn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ c_bind_final, c_nonce_final, c_proof = pkt.payload.split(b",")
+ assert c_bind_final == b"c=" + base64.b64encode(c_bind + b"," + authzid + b",")
+ assert c_nonce_final == s_nonce
+
+ # Calculate what the client proof should be.
+ salted_password = h_i(password.encode("ascii"), salt, iterations)
+ client_key = hmac_256(salted_password, b"Client Key")
+ stored_key = sha256(client_key)
+
+ auth_message = b",".join(
+ [c_name, c_nonce, s_nonce, s_salt, s_iterations, c_bind_final, c_nonce_final]
+ )
+ client_signature = hmac_256(stored_key, auth_message)
+ client_proof = xor(client_key, client_signature)
+
+ expected = b"p=" + base64.b64encode(client_proof)
+ assert c_proof == expected
+
+ # Send the correct server signature.
+ server_key = hmac_256(salted_password, b"Server Key")
+ server_signature = hmac_256(server_key, auth_message)
+
+ s_verify = b"v=" + base64.b64encode(server_signature)
+ pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal, body=s_verify)
+
+ # Done!
+ finish_handshake(pwconn)
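(A note on the h_i() helper above: the Hi(str, salt, i) construction from RFC 5802, Sec. 2.2 is exactly PBKDF2 with HMAC-SHA-256 as the PRF and a single 32-byte output block, so its result can be cross-checked against hashlib. A standalone sketch using only the standard library:)

```python
import hashlib
import hmac

def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))

def h_i(data, salt, i):
    # U1 = HMAC(str, salt || INT(1)); Uk = HMAC(str, Uk-1); Hi = U1 XOR ... XOR Ui
    last = acc = hmac.new(data, salt + b"\x00\x00\x00\x01", hashlib.sha256).digest()
    for _ in range(i - 1):
        last = hmac.new(data, last, hashlib.sha256).digest()
        acc = xor(acc, last)
    return acc

# Hi() matches PBKDF2-HMAC-SHA-256 at its default (32-byte) output length.
expected = hashlib.pbkdf2_hmac("sha256", b"pencil", b"12345", 4096)
assert h_i(b"pencil", b"12345", 4096) == expected
```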
diff --git a/src/test/python/client/test_oauth.py b/src/test/python/client/test_oauth.py
new file mode 100644
index 0000000000..a754a9c0b6
--- /dev/null
+++ b/src/test/python/client/test_oauth.py
@@ -0,0 +1,936 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import http.server
+import json
+import secrets
+import sys
+import threading
+import time
+import urllib.parse
+
+import psycopg2
+import pytest
+
+import pq3
+
+from .conftest import BLOCKING_TIMEOUT
+
+
+def finish_handshake(conn):
+ """
+ Sends the AuthenticationOK message and the standard opening salvo of server
+ messages, then asserts that the client immediately sends a Terminate message
+ to close the connection cleanly.
+ """
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+ pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+ pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.Terminate
+
+
+#
+# OAUTHBEARER (see RFC 7628: https://tools.ietf.org/html/rfc7628)
+#
+
+
+def start_oauth_handshake(conn):
+ """
+ Negotiates an OAUTHBEARER SASL challenge. Returns the client's initial
+ response data.
+ """
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn, pq3.types.AuthnRequest, type=pq3.authn.SASL, body=[b"OAUTHBEARER", b""]
+ )
+
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+
+ initial = pq3.SASLInitialResponse.parse(pkt.payload)
+ assert initial.name == b"OAUTHBEARER"
+
+ return initial.data
+
+
+def get_auth_value(initial):
+ """
+    Finds the auth value (e.g. "Bearer somedata...") in the client's initial
+    SASL response.
+ """
+ kvpairs = initial.split(b"\x01")
+ assert kvpairs[0] == b"n,," # no channel binding or authzid
+ assert kvpairs[2] == b"" # ends with an empty kvpair
+ assert kvpairs[3] == b"" # ...and there's nothing after it
+ assert len(kvpairs) == 4
+
+    key, value = kvpairs[1].split(b"=", 1)
+ assert key == b"auth"
+
+ return value
+
+
+def xtest_oauth_success(conn): # TODO
+ initial = start_oauth_handshake(conn)
+
+ auth = get_auth_value(initial)
+ assert auth.startswith(b"Bearer ")
+
+ # Accept the token. TODO actually validate
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+class OpenIDProvider(threading.Thread):
+ """
+ A thread that runs a mock OpenID provider server.
+ """
+
+ def __init__(self, *, port):
+ super().__init__()
+
+ self.exception = None
+
+ addr = ("", port)
+ self.server = self._Server(addr, self._Handler)
+
+ # TODO: allow HTTPS only, somehow
+ oauth = self._OAuthState()
+ oauth.host = f"localhost:{port}"
+ oauth.issuer = f"http://localhost:{port}"
+
+ # The following endpoints are required to be advertised by providers,
+ # even though our chosen client implementation does not actually make
+ # use of them.
+ oauth.register_endpoint(
+ "authorization_endpoint", "POST", "/authorize", self._authorization_handler
+ )
+ oauth.register_endpoint("jwks_uri", "GET", "/keys", self._jwks_handler)
+
+ self.server.oauth = oauth
+
+ def run(self):
+ try:
+ self.server.serve_forever()
+ except Exception as e:
+ self.exception = e
+
+ def stop(self, timeout=BLOCKING_TIMEOUT):
+ """
+ Shuts down the server and joins its thread. Raises an exception if the
+ thread could not be joined, or if it threw an exception itself. Must
+ only be called once, after start().
+ """
+ self.server.shutdown()
+ self.join(timeout)
+
+ if self.is_alive():
+            raise TimeoutError(
+                "OpenID provider thread did not shut down within the timeout"
+            )
+ elif self.exception:
+ e = self.exception
+ raise e
+
+ class _OAuthState(object):
+ def __init__(self):
+ self.endpoint_paths = {}
+ self._endpoints = {}
+
+ def register_endpoint(self, name, method, path, func):
+ if method not in self._endpoints:
+ self._endpoints[method] = {}
+
+ self._endpoints[method][path] = func
+ self.endpoint_paths[name] = path
+
+ def endpoint(self, method, path):
+ if method not in self._endpoints:
+ return None
+
+ return self._endpoints[method].get(path)
+
+ class _Server(http.server.HTTPServer):
+ def handle_error(self, request, addr):
+ self.shutdown_request(request)
+ raise
+
+ @staticmethod
+ def _jwks_handler(headers, params):
+ return 200, {"keys": []}
+
+ @staticmethod
+ def _authorization_handler(headers, params):
+ # We don't actually want this to be called during these tests -- we
+ # should be using the device authorization endpoint instead.
+ assert (
+ False
+ ), "authorization handler called instead of device authorization handler"
+
+ class _Handler(http.server.BaseHTTPRequestHandler):
+ timeout = BLOCKING_TIMEOUT
+
+ def _discovery_handler(self, headers, params):
+ oauth = self.server.oauth
+
+ doc = {
+ "issuer": oauth.issuer,
+ "response_types_supported": ["token"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+
+ for name, path in oauth.endpoint_paths.items():
+ doc[name] = oauth.issuer + path
+
+ return 200, doc
+
+ def _handle(self, *, params=None, handler=None):
+ oauth = self.server.oauth
+ assert self.headers["Host"] == oauth.host
+
+ if handler is None:
+ handler = oauth.endpoint(self.command, self.path)
+ assert (
+ handler is not None
+ ), f"no registered endpoint for {self.command} {self.path}"
+
+ code, resp = handler(self.headers, params)
+
+ self.send_response(code)
+ self.send_header("Content-Type", "application/json")
+ self.end_headers()
+
+ resp = json.dumps(resp)
+ resp = resp.encode("utf-8")
+ self.wfile.write(resp)
+
+ self.close_connection = True
+
+ def do_GET(self):
+ if self.path == "/.well-known/openid-configuration":
+ self._handle(handler=self._discovery_handler)
+ return
+
+ self._handle()
+
+ def _request_body(self):
+ length = self.headers["Content-Length"]
+
+ # Handle only an explicit content-length.
+ assert length is not None
+ length = int(length)
+
+ return self.rfile.read(length).decode("utf-8")
+
+ def do_POST(self):
+ assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
+
+ body = self._request_body()
+ params = urllib.parse.parse_qs(body)
+
+ self._handle(params=params)
+
+
+@pytest.fixture
+def openid_provider(unused_tcp_port_factory):
+ """
+ A fixture that returns the OAuth state of a running OpenID provider server. The
+ server will be stopped when the fixture is torn down.
+ """
+ thread = OpenIDProvider(port=unused_tcp_port_factory())
+ thread.start()
+
+ try:
+ yield thread.server.oauth
+ finally:
+ thread.stop()
+
+
+@pytest.mark.parametrize("secret", [None, "", "hunter2"])
+@pytest.mark.parametrize("scope", [None, "", "openid email"])
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_with_explicit_issuer(
+ capfd, accept, openid_provider, retries, scope, secret
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ oauth_client_secret=secret,
+ oauth_scope=scope,
+ )
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+ verification_url = "https://example.com/device"
+
+ access_token = secrets.token_urlsafe()
+
+ def check_client_authn(headers, params):
+ if not secret:
+ assert params["client_id"] == [client_id]
+ return
+
+ # Require the client to use Basic authn; request-body credentials are
+ # NOT RECOMMENDED (RFC 6749, Sec. 2.3.1).
+ assert "Authorization" in headers
+
+ method, creds = headers["Authorization"].split()
+ assert method == "Basic"
+
+ expected = f"{client_id}:{secret}"
+ assert base64.b64decode(creds) == expected.encode("ascii")
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ check_client_authn(headers, params)
+
+ if scope:
+ assert params["scope"] == [scope]
+ else:
+ assert "scope" not in params
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": verification_url,
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ attempts = 0
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ check_client_authn(headers, params)
+
+ assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+ assert params["device_code"] == [device_code]
+
+ with retry_lock:
+ nonlocal attempts
+
+ # If the test wants to force the client to retry, return an
+            # authorization_pending response and count the attempt.
+ if attempts < retries:
+ attempts += 1
+ return 400, {"error": "authorization_pending"}
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+ if retries:
+ # Finally, make sure that the client prompted the user with the expected
+ # authorization URL and user code.
+ expected = f"Visit {verification_url} and enter the code: {user_code}"
+ _, stderr = capfd.readouterr()
+ assert expected in stderr
+
+
+def test_oauth_requires_client_id(accept, openid_provider):
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ # Do not set a client ID; this should cause a client error after the
+ # server asks for OAUTHBEARER and the client tries to contact the
+ # issuer.
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should disconnect at this point.
+ assert not conn.read()
+
+ expected_error = "no oauth_client_id is set"
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("error_code", ["authorization_pending", "slow_down"])
+@pytest.mark.parametrize("retries", [1, 2])
+def test_oauth_retry_interval(accept, openid_provider, retries, error_code):
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id="some-id",
+ )
+
+ expected_retry_interval = 1
+ access_token = secrets.token_urlsafe()
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ resp = {
+ "device_code": "my-device-code",
+ "user_code": "my-user-code",
+ "interval": expected_retry_interval,
+ "verification_uri": "https://example.com",
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ attempts = 0
+ last_retry = None
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ now = time.monotonic()
+
+ with retry_lock:
+ nonlocal attempts, last_retry, expected_retry_interval
+
+ # Make sure the retry interval is being respected by the client.
+ if last_retry is not None:
+ interval = now - last_retry
+ assert interval >= expected_retry_interval
+
+ last_retry = now
+
+ # If the test wants to force the client to retry, return the desired
+            # error response and increment the attempt count.
+ if attempts < retries:
+ attempts += 1
+
+ # A slow_down code requires the client to additionally increase
+ # its interval by five seconds.
+ if error_code == "slow_down":
+ expected_retry_interval += 5
+
+ return 400, {"error": error_code}
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+@pytest.mark.parametrize(
+ "failure_mode, error_pattern",
+ [
+ pytest.param(
+ {
+ "error": "invalid_client",
+ "error_description": "client authentication failed",
+ },
+ r"client authentication failed \(invalid_client\)",
+ id="authentication failure with description",
+ ),
+ pytest.param(
+ {"error": "invalid_request"},
+ r"\(invalid_request\)",
+ id="invalid request without description",
+ ),
+ pytest.param(
+ {},
+ r"failed to obtain device authorization",
+ id="broken error response",
+ ),
+ ],
+)
+def test_oauth_device_authorization_failures(
+ accept, openid_provider, failure_mode, error_pattern
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ )
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ return 400, failure_mode
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ def token_endpoint(headers, params):
+ assert False, "token endpoint was invoked unexpectedly"
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should not continue the connection due to the hardcoded
+ # provider failure; we disconnect here.
+
+ # Now make sure the client correctly failed.
+ with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+ client.check_completed()
+
+
+@pytest.mark.parametrize(
+ "failure_mode, error_pattern",
+ [
+ pytest.param(
+ {
+ "error": "expired_token",
+ "error_description": "the device code has expired",
+ },
+ r"the device code has expired \(expired_token\)",
+ id="expired token with description",
+ ),
+ pytest.param(
+ {"error": "access_denied"},
+ r"\(access_denied\)",
+ id="access denied without description",
+ ),
+ pytest.param(
+ {},
+ r"OAuth token retrieval failed",
+ id="broken error response",
+ ),
+ ],
+)
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_token_failures(
+ accept, openid_provider, retries, failure_mode, error_pattern
+):
+ client_id = secrets.token_hex()
+
+ sock, client = accept(
+ oauth_issuer=openid_provider.issuer,
+ oauth_client_id=client_id,
+ )
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ assert params["client_id"] == [client_id]
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": "https://example.com/device",
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ retry_lock = threading.Lock()
+
+ def token_endpoint(headers, params):
+ with retry_lock:
+ nonlocal retries
+
+ # If the test wants to force the client to retry, return an
+ # authorization_pending response and decrement the retry count.
+ if retries > 0:
+ retries -= 1
+ return 400, {"error": "authorization_pending"}
+
+ return 400, failure_mode
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ # Initiate a handshake, which should result in the above endpoints
+ # being called.
+ startup = pq3.recv1(conn, cls=pq3.Startup)
+ assert startup.proto == pq3.protocol(3, 0)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASL,
+ body=[b"OAUTHBEARER", b""],
+ )
+
+ # The client should not continue the connection due to the hardcoded
+ # provider failure; we disconnect here.
+
+ # Now make sure the client correctly failed.
+ with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+ client.check_completed()
+
+
+@pytest.mark.parametrize("scope", [None, "openid email"])
+@pytest.mark.parametrize(
+ "base_response",
+ [
+ {"status": "invalid_token"},
+ {"extra_object": {"key": "value"}, "status": "invalid_token"},
+ {"extra_object": {"status": 1}, "status": "invalid_token"},
+ ],
+)
+def test_oauth_discovery(accept, openid_provider, base_response, scope):
+ sock, client = accept(oauth_client_id=secrets.token_hex())
+
+ device_code = secrets.token_hex()
+ user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+ verification_url = "https://example.com/device"
+
+ access_token = secrets.token_urlsafe()
+
+ # Set up our provider callbacks.
+ # NOTE that these callbacks will be called on a background thread. Don't do
+ # any unprotected state mutation here.
+
+ def authorization_endpoint(headers, params):
+ if scope:
+ assert params["scope"] == [scope]
+ else:
+ assert "scope" not in params
+
+ resp = {
+ "device_code": device_code,
+ "user_code": user_code,
+ "interval": 0,
+ "verification_uri": verification_url,
+ "expires_in": 5,
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+ )
+
+ def token_endpoint(headers, params):
+ assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+ assert params["device_code"] == [device_code]
+
+ # Successfully finish the request by sending the access bearer token.
+ resp = {
+ "access_token": access_token,
+ "token_type": "bearer",
+ }
+
+ return 200, resp
+
+ openid_provider.register_endpoint(
+ "token_endpoint", "POST", "/token", token_endpoint
+ )
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # For discovery, the client should send an empty auth header. See
+ # RFC 7628, Sec. 4.3.
+ auth = get_auth_value(initial)
+ assert auth == b""
+
+ # We will fail the first SASL exchange. First return a link to the
+ # discovery document, pointing to the test provider server.
+ resp = dict(base_response)
+
+ discovery_uri = f"{openid_provider.issuer}/.well-known/openid-configuration"
+ resp["openid-configuration"] = discovery_uri
+
+ if scope:
+ resp["scope"] = scope
+
+ resp = json.dumps(resp)
+
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASLContinue,
+ body=resp.encode("ascii"),
+ )
+
+ # Per RFC, the client is required to send a dummy ^A response.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+ assert pkt.payload == b"\x01"
+
+ # Now fail the SASL exchange.
+ pq3.send(
+ conn,
+ pq3.types.ErrorResponse,
+ fields=[
+ b"SFATAL",
+ b"C28000",
+ b"Mdoesn't matter",
+ b"",
+ ],
+ )
+
+ # The client will connect to us a second time, using the parameters we sent
+ # it.
+ sock, _ = accept()
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # Validate and accept the token.
+ auth = get_auth_value(initial)
+ assert auth == f"Bearer {access_token}".encode("ascii")
+
+ pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+ finish_handshake(conn)
+
+
+@pytest.mark.parametrize(
+ "response,expected_error",
+ [
+ pytest.param(
+ "abcde",
+ 'Token "abcde" is invalid',
+ id="bad JSON: invalid syntax",
+ ),
+ pytest.param(
+ '"abcde"',
+ "top-level element must be an object",
+ id="bad JSON: top-level element is a string",
+ ),
+ pytest.param(
+ "[]",
+ "top-level element must be an object",
+ id="bad JSON: top-level element is an array",
+ ),
+ pytest.param(
+ "{}",
+ "server sent error response without a status",
+ id="bad JSON: no status member",
+ ),
+ pytest.param(
+ '{ "status": null }',
+ 'field "status" must be a string',
+ id="bad JSON: null status member",
+ ),
+ pytest.param(
+ '{ "status": 0 }',
+ 'field "status" must be a string',
+ id="bad JSON: int status member",
+ ),
+ pytest.param(
+ '{ "status": [ "bad" ] }',
+ 'field "status" must be a string',
+ id="bad JSON: array status member",
+ ),
+ pytest.param(
+ '{ "status": { "bad": "bad" } }',
+ 'field "status" must be a string',
+ id="bad JSON: object status member",
+ ),
+ pytest.param(
+ '{ "nested": { "status": "bad" } }',
+ "server sent error response without a status",
+ id="bad JSON: nested status",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token" ',
+ "The input string ended unexpectedly",
+ id="bad JSON: unterminated object",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token" } { }',
+ 'Expected end of input, but found "{"',
+ id="bad JSON: trailing data",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token", "openid-configuration": 1 }',
+ 'field "openid-configuration" must be a string',
+ id="bad JSON: int openid-configuration member",
+ ),
+ pytest.param(
+ '{ "status": "invalid_token", "scope": 1 }',
+ 'field "scope" must be a string',
+ id="bad JSON: int scope member",
+ ),
+ ],
+)
+def test_oauth_discovery_server_error(accept, response, expected_error):
+ sock, client = accept(oauth_client_id=secrets.token_hex())
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ initial = start_oauth_handshake(conn)
+
+ # Fail the SASL exchange with an invalid JSON response.
+ pq3.send(
+ conn,
+ pq3.types.AuthnRequest,
+ type=pq3.authn.SASLContinue,
+ body=response.encode("utf-8"),
+ )
+
+ # The client should disconnect, so the socket is closed here. (If
+ # the client doesn't disconnect, it will report a different error
+ # below and the test will fail.)
+
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
+
+
+@pytest.mark.parametrize(
+ "sasl_err,resp_type,resp_payload,expected_error",
+ [
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.ErrorResponse,
+ dict(
+ fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+ ),
+ "expected error message",
+ id="standard server error: invalid_request",
+ ),
+ pytest.param(
+ {"status": "invalid_token"},
+ pq3.types.ErrorResponse,
+ dict(
+ fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+ ),
+ "expected error message",
+ id="standard server error: invalid_token without discovery URI",
+ ),
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.AuthnRequest,
+ dict(type=pq3.authn.SASLContinue, body=b""),
+ "server sent additional OAuth data",
+ id="broken server: additional challenge after error",
+ ),
+ pytest.param(
+ {"status": "invalid_request"},
+ pq3.types.AuthnRequest,
+ dict(type=pq3.authn.SASLFinal),
+ "server sent additional OAuth data",
+ id="broken server: SASL success after error",
+ ),
+ ],
+)
+def test_oauth_server_error(accept, sasl_err, resp_type, resp_payload, expected_error):
+ sock, client = accept()
+
+ with sock:
+ with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+ start_oauth_handshake(conn)
+
+ # Ignore the client data. Return an error "challenge".
+ resp = json.dumps(sasl_err)
+ resp = resp.encode("utf-8")
+
+ pq3.send(
+ conn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=resp
+ )
+
+ # Per RFC, the client is required to send a dummy ^A response.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.PasswordMessage
+ assert pkt.payload == b"\x01"
+
+ # Now fail the SASL exchange (in either a valid way, or an invalid
+ # one, depending on the test).
+ pq3.send(conn, resp_type, **resp_payload)
+
+ with pytest.raises(psycopg2.OperationalError, match=expected_error):
+ client.check_completed()
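(For reviewers unfamiliar with OAUTHBEARER framing: the initial client response that get_auth_value() parses above is a GS2 header followed by \x01-separated key/value pairs and a terminating empty pair, per RFC 7628, Sec. 3.1. A standalone sketch of building one, with a made-up token value:)

```python
def oauthbearer_client_initial(token):
    # gs2 header (no channel binding, no authzid), one auth kvpair, then the
    # terminating empty kvpair; \x01 is the kvpair separator (RFC 7628, Sec. 3.1)
    return b"n,,\x01auth=Bearer " + token + b"\x01\x01"

msg = oauthbearer_client_initial(b"abc123")  # illustrative token value
assert msg.split(b"\x01") == [b"n,,", b"auth=Bearer abc123", b"", b""]
```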
diff --git a/src/test/python/pq3.py b/src/test/python/pq3.py
new file mode 100644
index 0000000000..3a22dad0b6
--- /dev/null
+++ b/src/test/python/pq3.py
@@ -0,0 +1,727 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import os
+import ssl
+import sys
+import textwrap
+
+from construct import *
+
+import tls
+
+
+def protocol(major, minor):
+ """
+ Returns the protocol version, in integer format, corresponding to the given
+ major and minor version numbers.
+ """
+ return (major << 16) | minor
+
+
+# Startup
+
+StringList = GreedyRange(NullTerminated(GreedyBytes))
+
+
+class KeyValueAdapter(Adapter):
+ """
+ Turns a key-value store into a null-terminated list of null-terminated
+ strings, as presented on the wire in the startup packet.
+ """
+
+ def _encode(self, obj, context, path):
+ if isinstance(obj, list):
+ return obj
+
+ l = []
+
+ for k, v in obj.items():
+ if isinstance(k, str):
+ k = k.encode("utf-8")
+ l.append(k)
+
+ if isinstance(v, str):
+ v = v.encode("utf-8")
+ l.append(v)
+
+ l.append(b"")
+ return l
+
+ def _decode(self, obj, context, path):
+ # TODO: turn a list back into a dict
+ return obj
+
+
+KeyValues = KeyValueAdapter(StringList)
+
+_startup_payload = Switch(
+ this.proto,
+ {
+ protocol(3, 0): KeyValues,
+ },
+ default=GreedyBytes,
+)
+
+
+def _default_protocol(this):
+ try:
+ if isinstance(this.payload, (list, dict)):
+ return protocol(3, 0)
+ except AttributeError:
+ pass # no payload passed during build
+
+ return 0
+
+
+def _startup_payload_len(this):
+ """
+ The payload field has a fixed size based on the length of the packet. But
+ if the caller hasn't supplied an explicit length at build time, we have to
+ build the payload to figure out how long it is, which requires us to know
+ the length first... This function exists solely to break the cycle.
+ """
+ assert this._building, "_startup_payload_len() cannot be called during parsing"
+
+ try:
+ payload = this.payload
+ except AttributeError:
+ return 0 # no payload
+
+ if isinstance(payload, bytes):
+ # already serialized; just use the given length
+ return len(payload)
+
+ try:
+ proto = this.proto
+ except AttributeError:
+ proto = _default_protocol(this)
+
+ data = _startup_payload.build(payload, proto=proto)
+ return len(data)
+
+
+Startup = Struct(
+ "len" / Default(Int32sb, lambda this: _startup_payload_len(this) + 8),
+ "proto" / Default(Hex(Int32sb), _default_protocol),
+ "payload" / FixedSized(this.len - 8, Default(_startup_payload, b"")),
+)
+
+# Pq3
+
+# Adapted from construct.core.EnumIntegerString
+class EnumNamedByte:
+ def __init__(self, val, name):
+ self._val = val
+ self._name = name
+
+ def __int__(self):
+ return ord(self._val)
+
+ def __str__(self):
+ return "(enum) %s %r" % (self._name, self._val)
+
+ def __repr__(self):
+ return "EnumNamedByte(%r)" % self._val
+
+ def __eq__(self, other):
+ if isinstance(other, EnumNamedByte):
+ other = other._val
+ if not isinstance(other, bytes):
+ return NotImplemented
+
+ return self._val == other
+
+ def __hash__(self):
+ return hash(self._val)
+
+
+# Adapted from construct.core.Enum
+class ByteEnum(Adapter):
+ def __init__(self, **mapping):
+ super(ByteEnum, self).__init__(Byte)
+ self.namemapping = {k: EnumNamedByte(v, k) for k, v in mapping.items()}
+ self.decmapping = {v: EnumNamedByte(v, k) for k, v in mapping.items()}
+
+ def __getattr__(self, name):
+ if name in self.namemapping:
+ return self.decmapping[self.namemapping[name]]
+ raise AttributeError
+
+ def _decode(self, obj, context, path):
+ b = bytes([obj])
+ try:
+ return self.decmapping[b]
+ except KeyError:
+ return EnumNamedByte(b, "(unknown)")
+
+ def _encode(self, obj, context, path):
+ if isinstance(obj, int):
+ return obj
+ elif isinstance(obj, bytes):
+ return ord(obj)
+ return int(obj)
+
+
+types = ByteEnum(
+ ErrorResponse=b"E",
+ ReadyForQuery=b"Z",
+ Query=b"Q",
+ EmptyQueryResponse=b"I",
+ AuthnRequest=b"R",
+ PasswordMessage=b"p",
+ BackendKeyData=b"K",
+ CommandComplete=b"C",
+ ParameterStatus=b"S",
+ DataRow=b"D",
+ Terminate=b"X",
+)
+
+
+authn = Enum(
+ Int32ub,
+ OK=0,
+ SASL=10,
+ SASLContinue=11,
+ SASLFinal=12,
+)
+
+
+_authn_body = Switch(
+ this.type,
+ {
+ authn.OK: Terminated,
+ authn.SASL: StringList,
+ },
+ default=GreedyBytes,
+)
+
+
+def _data_len(this):
+ assert this._building, "_data_len() cannot be called during parsing"
+
+ if not hasattr(this, "data") or this.data is None:
+ return -1
+
+ return len(this.data)
+
+
+# The protocol reuses the PasswordMessage for several authentication response
+# types, and there's no good way to figure out which is which without keeping
+# state for the entire stream. So this is a separate Construct that can be
+# explicitly parsed/built by code that knows it's needed.
+SASLInitialResponse = Struct(
+ "name" / NullTerminated(GreedyBytes),
+ "len" / Default(Int32sb, lambda this: _data_len(this)),
+ "data"
+ / IfThenElse(
+ # Allow tests to explicitly pass an incorrect length during testing, by
+ # not enforcing a FixedSized during build. (The len calculation above
+ # defaults to the correct size.)
+ this._building,
+ Optional(GreedyBytes),
+ If(this.len != -1, Default(FixedSized(this.len, GreedyBytes), b"")),
+ ),
+ Terminated, # make sure the entire response is consumed
+)
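On the wire, that message body is just the mechanism name, a NUL, a 4-byte big-endian length, and the raw client response. A hypothetical stand-alone encoder:

```python
import struct

def sasl_initial_response(mechanism, data):
    # The mechanism name is NUL-terminated; the client-first data is
    # prefixed with its int32 length.
    return mechanism + b"\x00" + struct.pack("!i", len(data)) + data
```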
+
+
+_column = FocusedSeq(
+ "data",
+ "len" / Default(Int32sb, lambda this: _data_len(this)),
+ "data" / If(this.len != -1, FixedSized(this.len, GreedyBytes)),
+)
+
+
+_payload_map = {
+ types.ErrorResponse: Struct("fields" / StringList),
+ types.ReadyForQuery: Struct("status" / Bytes(1)),
+ types.Query: Struct("query" / NullTerminated(GreedyBytes)),
+ types.EmptyQueryResponse: Terminated,
+ types.AuthnRequest: Struct("type" / authn, "body" / Default(_authn_body, b"")),
+ types.BackendKeyData: Struct("pid" / Int32ub, "key" / Hex(Int32ub)),
+ types.CommandComplete: Struct("tag" / NullTerminated(GreedyBytes)),
+ types.ParameterStatus: Struct(
+ "name" / NullTerminated(GreedyBytes), "value" / NullTerminated(GreedyBytes)
+ ),
+ types.DataRow: Struct("columns" / Default(PrefixedArray(Int16sb, _column), b"")),
+ types.Terminate: Terminated,
+}
+
+
+_payload = FocusedSeq(
+ "_payload",
+ "_payload"
+ / Switch(
+ this._.type,
+ _payload_map,
+ default=GreedyBytes,
+ ),
+ Terminated, # make sure every payload consumes the entire packet
+)
+
+
+def _payload_len(this):
+ """
+ See _startup_payload_len() for an explanation.
+ """
+ assert this._building, "_payload_len() cannot be called during parsing"
+
+ try:
+ payload = this.payload
+ except AttributeError:
+ return 0 # no payload
+
+ if isinstance(payload, bytes):
+ # already serialized; just use the given length
+ return len(payload)
+
+ data = _payload.build(payload, type=this.type)
+ return len(data)
+
+
+Pq3 = Struct(
+ "type" / types,
+ "len" / Default(Int32ub, lambda this: _payload_len(this) + 4),
+ "payload" / FixedSized(this.len - 4, Default(_payload, b"")),
+)
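The framing Pq3 models is the standard v3 message format: a one-byte type, a 4-byte big-endian length that includes itself but not the type byte (hence the `+ 4` above), then the payload. As a sketch:

```python
import struct

def frame(msg_type, payload):
    # The length field covers itself plus the payload, but not the
    # leading type byte.
    return msg_type + struct.pack("!I", len(payload) + 4) + payload
```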
+
+
+# Environment
+
+
+def pghost():
+ return os.environ.get("PGHOST", default="localhost")
+
+
+def pgport():
+ return int(os.environ.get("PGPORT", default=5432))
+
+
+def pguser():
+ try:
+ return os.environ["PGUSER"]
+ except KeyError:
+ return getpass.getuser()
+
+
+def pgdatabase():
+ return os.environ.get("PGDATABASE", default="postgres")
+
+
+# Connections
+
+
+def _hexdump_translation_map():
+ """
+ For hexdumps. Translates any unprintable or non-ASCII bytes into '.'.
+ """
+ input = bytearray()
+
+ for i in range(128):
+ c = chr(i)
+
+ if not c.isprintable():
+ input += bytes([i])
+
+ input += bytes(range(128, 256))
+
+ return bytes.maketrans(input, b"." * len(input))
+
+
+class _DebugStream(object):
+ """
+ Wraps a file-like object and adds hexdumps of the read and write data. Call
+ end_packet() on a _DebugStream to write the accumulated hexdumps to the
+ output stream, along with the packet that was sent.
+ """
+
+ _translation_map = _hexdump_translation_map()
+
+ def __init__(self, stream, out=sys.stdout):
+ """
+ Creates a new _DebugStream wrapping the given stream (which must have
+ been created by wrap()). All attributes not provided by the _DebugStream
+ are delegated to the wrapped stream. out is the text stream to which
+ hexdumps are written.
+ """
+ self.raw = stream
+ self._out = out
+ self._rbuf = io.BytesIO()
+ self._wbuf = io.BytesIO()
+
+ def __getattr__(self, name):
+ return getattr(self.raw, name)
+
+ def __setattr__(self, name, value):
+ if name in ("raw", "_out", "_rbuf", "_wbuf"):
+ return object.__setattr__(self, name, value)
+
+ setattr(self.raw, name, value)
+
+ def read(self, *args, **kwargs):
+ buf = self.raw.read(*args, **kwargs)
+
+ self._rbuf.write(buf)
+ return buf
+
+ def write(self, b):
+ self._wbuf.write(b)
+ return self.raw.write(b)
+
+ def recv(self, *args):
+ buf = self.raw.recv(*args)
+
+ self._rbuf.write(buf)
+ return buf
+
+ def _flush(self, buf, prefix):
+ width = 16
+ hexwidth = width * 3 - 1
+
+ count = 0
+ buf.seek(0)
+
+ while True:
+            line = buf.read(width)
+
+ if not line:
+ if count:
+ self._out.write("\n") # separate the output block with a newline
+ return
+
+ self._out.write("%s %04X:\t" % (prefix, count))
+ self._out.write("%*s\t" % (-hexwidth, line.hex(" ")))
+ self._out.write(line.translate(self._translation_map).decode("ascii"))
+ self._out.write("\n")
+
+            count += width
+
+ def print_debug(self, obj, *, prefix=""):
+ contents = ""
+ if obj is not None:
+ contents = str(obj)
+
+ for line in contents.splitlines():
+ self._out.write("%s%s\n" % (prefix, line))
+
+ self._out.write("\n")
+
+ def flush_debug(self, *, prefix=""):
+ self._flush(self._rbuf, prefix + "<")
+ self._rbuf = io.BytesIO()
+
+ self._flush(self._wbuf, prefix + ">")
+ self._wbuf = io.BytesIO()
+
+ def end_packet(self, pkt, *, read=False, prefix="", indent=" "):
+ """
+ Marks the end of a logical "packet" of data. A string representation of
+ pkt will be printed, and the debug buffers will be flushed with an
+ indent. All lines can be optionally prefixed.
+
+ If read is True, the packet representation is written after the debug
+ buffers; otherwise the default of False (meaning write) causes the
+ packet representation to be dumped first. This is meant to capture the
+ logical flow of layer translation.
+ """
+ write = not read
+
+ if write:
+ self.print_debug(pkt, prefix=prefix + "> ")
+
+ self.flush_debug(prefix=prefix + indent)
+
+ if read:
+ self.print_debug(pkt, prefix=prefix + "< ")
+
+
+@contextlib.contextmanager
+def wrap(socket, *, debug_stream=None):
+ """
+ Transforms a raw socket into a connection that can be used for Construct
+ building and parsing. The return value is a context manager and can be used
+ in a with statement.
+ """
+ # It is critical that buffering be disabled here, so that we can still
+ # manipulate the raw socket without desyncing the stream.
+ with socket.makefile("rwb", buffering=0) as sfile:
+ # Expose the original socket's recv() on the SocketIO object we return.
+ def recv(self, *args):
+ return socket.recv(*args)
+
+ sfile.recv = recv.__get__(sfile)
+
+ conn = sfile
+ if debug_stream:
+ conn = _DebugStream(conn, debug_stream)
+
+ try:
+ yield conn
+ finally:
+ if debug_stream:
+ conn.flush_debug(prefix="? ")
+
+
+def _send(stream, cls, obj):
+ debugging = hasattr(stream, "flush_debug")
+ out = io.BytesIO()
+
+ # Ideally we would build directly to the passed stream, but because we need
+ # to reparse the generated output for the debugging case, build to an
+ # intermediate BytesIO and send it instead.
+ cls.build_stream(obj, out)
+ buf = out.getvalue()
+
+ stream.write(buf)
+ if debugging:
+ pkt = cls.parse(buf)
+ stream.end_packet(pkt)
+
+ stream.flush()
+
+
+def send(stream, packet_type, payload_data=None, **payloadkw):
+ """
+    Sends a packet on the given pq3 connection. packet_type is the pq3.types
+    member that should be assigned to the packet. If payload_data is given, it
+    will be used as the packet payload; otherwise the key/value pairs in
+    payloadkw will be the payload contents.
+ """
+ data = payloadkw
+
+ if payload_data is not None:
+ if payloadkw:
+ raise ValueError(
+ "payload_data and payload keywords may not be used simultaneously"
+ )
+
+ data = payload_data
+
+ _send(stream, Pq3, dict(type=packet_type, payload=data))
+
+
+def send_startup(stream, proto=None, **kwargs):
+ """
+ Sends a startup packet on the given pq3 connection. In most cases you should
+ use the handshake functions instead, which will do this for you.
+
+ By default, a protocol version 3 packet will be sent. This can be overridden
+ with the proto parameter.
+ """
+ pkt = {}
+
+ if proto is not None:
+ pkt["proto"] = proto
+ if kwargs:
+ pkt["payload"] = kwargs
+
+ _send(stream, Startup, pkt)
+
+
+def recv1(stream, *, cls=Pq3):
+ """
+ Receives a single pq3 packet from the given stream and returns it.
+ """
+ resp = cls.parse_stream(stream)
+
+ debugging = hasattr(stream, "flush_debug")
+ if debugging:
+ stream.end_packet(resp, read=True)
+
+ return resp
+
+
+def handshake(stream, **kwargs):
+ """
+ Performs a libpq v3 startup handshake. kwargs should contain the key/value
+ parameters to send to the server in the startup packet.
+ """
+ # Send our startup parameters.
+ send_startup(stream, **kwargs)
+
+ # Receive and dump packets until the server indicates it's ready for our
+ # first query.
+ while True:
+ resp = recv1(stream)
+ if resp is None:
+ raise RuntimeError("server closed connection during handshake")
+
+ if resp.type == types.ReadyForQuery:
+ return
+ elif resp.type == types.ErrorResponse:
+ raise RuntimeError(
+ f"received error response from peer: {resp.payload.fields!r}"
+ )
+
+
+# TLS
+
+
+class _TLSStream(object):
+ """
+ A file-like object that performs TLS encryption/decryption on a wrapped
+ stream. Differs from ssl.SSLSocket in that we have full visibility and
+ control over the TLS layer.
+ """
+
+ def __init__(self, stream, context):
+ self._stream = stream
+ self._debugging = hasattr(stream, "flush_debug")
+
+ self._in = ssl.MemoryBIO()
+ self._out = ssl.MemoryBIO()
+ self._ssl = context.wrap_bio(self._in, self._out)
+
+ def handshake(self):
+ try:
+ self._pump(lambda: self._ssl.do_handshake())
+ finally:
+ self._flush_debug(prefix="? ")
+
+ def read(self, *args):
+ return self._pump(lambda: self._ssl.read(*args))
+
+ def write(self, *args):
+ return self._pump(lambda: self._ssl.write(*args))
+
+ def _decode(self, buf):
+ """
+ Attempts to decode a buffer of TLS data into a packet representation
+ that can be printed.
+
+ TODO: handle buffers (and record fragments) that don't align with packet
+ boundaries.
+ """
+ end = len(buf)
+ bio = io.BytesIO(buf)
+
+ ret = io.StringIO()
+
+ while bio.tell() < end:
+ record = tls.Plaintext.parse_stream(bio)
+
+ if ret.tell() > 0:
+ ret.write("\n")
+ ret.write("[Record] ")
+ ret.write(str(record))
+ ret.write("\n")
+
+ if record.type == tls.ContentType.handshake:
+ record_cls = tls.Handshake
+ else:
+ continue
+
+ innerlen = len(record.fragment)
+ inner = io.BytesIO(record.fragment)
+
+ while inner.tell() < innerlen:
+ msg = record_cls.parse_stream(inner)
+
+ indented = "[Message] " + str(msg)
+ indented = textwrap.indent(indented, " ")
+
+ ret.write("\n")
+ ret.write(indented)
+ ret.write("\n")
+
+ return ret.getvalue()
+
+ def flush(self):
+ if not self._out.pending:
+ self._stream.flush()
+ return
+
+ buf = self._out.read()
+ self._stream.write(buf)
+
+ if self._debugging:
+ pkt = self._decode(buf)
+ self._stream.end_packet(pkt, prefix=" ")
+
+ self._stream.flush()
+
+ def _pump(self, operation):
+ while True:
+ try:
+ return operation()
+ except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
+ want = e
+ self._read_write(want)
+
+ def _recv(self, maxsize):
+        buf = self._stream.recv(maxsize)
+ if not buf:
+ self._in.write_eof()
+ return
+
+ self._in.write(buf)
+
+ if not self._debugging:
+ return
+
+ pkt = self._decode(buf)
+ self._stream.end_packet(pkt, read=True, prefix=" ")
+
+ def _read_write(self, want):
+ # XXX This needs work. So many corner cases yet to handle. For one,
+ # doing blocking writes in flush may lead to distributed deadlock if the
+ # peer is already blocking on its writes.
+
+ if isinstance(want, ssl.SSLWantWriteError):
+ assert self._out.pending, "SSL backend wants write without data"
+
+ self.flush()
+
+ if isinstance(want, ssl.SSLWantReadError):
+ self._recv(4096)
+
+ def _flush_debug(self, prefix):
+ if not self._debugging:
+ return
+
+ self._stream.flush_debug(prefix=prefix)
+
+
+@contextlib.contextmanager
+def tls_handshake(stream, context):
+ """
+ Performs a TLS handshake over the given stream (which must have been created
+ via a call to wrap()), and returns a new stream which transparently tunnels
+ data over the TLS connection.
+
+ If the passed stream has debugging enabled, the returned stream will also
+ have debugging, using the same output IO.
+ """
+ debugging = hasattr(stream, "flush_debug")
+
+    # Send an SSLRequest, which reuses the startup packet format.
+ send_startup(stream, proto=protocol(1234, 5679))
+
+ # Look at the SSL response.
+ resp = stream.read(1)
+ if debugging:
+ stream.flush_debug(prefix=" ")
+
+ if resp == b"N":
+ raise RuntimeError("server does not support SSLRequest")
+ if resp != b"S":
+ raise RuntimeError(f"unexpected response of type {resp!r} during TLS startup")
+
+ tls = _TLSStream(stream, context)
+ tls.handshake()
+
+ if debugging:
+ tls = _DebugStream(tls, stream._out)
+
+ try:
+ yield tls
+ # TODO: teardown/unwrap the connection?
+ finally:
+ if debugging:
+ tls.flush_debug(prefix="? ")
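The SSLRequest sent by tls_handshake() is an ordinary startup packet whose "version" is the magic number 1234.5679; as a sanity check, its eight bytes on the wire are:

```python
import struct

# SSLRequest: length 8 (the whole packet), then the magic version number.
sslrequest = struct.pack("!ii", 8, (1234 << 16) | 5679)
assert sslrequest == b"\x00\x00\x00\x08\x04\xd2\x16\x2f"
```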
diff --git a/src/test/python/pytest.ini b/src/test/python/pytest.ini
new file mode 100644
index 0000000000..ab7a6e7fb9
--- /dev/null
+++ b/src/test/python/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+
+markers =
+ slow: mark test as slow
diff --git a/src/test/python/requirements.txt b/src/test/python/requirements.txt
new file mode 100644
index 0000000000..32f105ea84
--- /dev/null
+++ b/src/test/python/requirements.txt
@@ -0,0 +1,7 @@
+black
+cryptography~=3.4.6
+construct~=2.10.61
+isort~=5.6
+psycopg2~=2.8.6
+pytest~=6.1
+pytest-asyncio~=0.14.0
diff --git a/src/test/python/server/__init__.py b/src/test/python/server/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/server/conftest.py b/src/test/python/server/conftest.py
new file mode 100644
index 0000000000..ba7342a453
--- /dev/null
+++ b/src/test/python/server/conftest.py
@@ -0,0 +1,45 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import socket
+import sys
+
+import pytest
+
+import pq3
+
+
+@pytest.fixture
+def connect():
+ """
+ A factory fixture that, when called, returns a socket connected to a
+ Postgres server, wrapped in a pq3 connection. The calling test will be
+ skipped automatically if a server is not running at PGHOST:PGPORT, so it's
+ best to connect as soon as possible after the test case begins, to avoid
+ doing unnecessary work.
+ """
+ # Set up an ExitStack to handle safe cleanup of all of the moving pieces.
+ with contextlib.ExitStack() as stack:
+
+ def conn_factory():
+ addr = (pq3.pghost(), pq3.pgport())
+
+ try:
+ sock = socket.create_connection(addr, timeout=2)
+ except ConnectionError as e:
+ pytest.skip(f"unable to connect to {addr}: {e}")
+
+ # Have ExitStack close our socket.
+ stack.enter_context(sock)
+
+ # Wrap the connection in a pq3 layer and have ExitStack clean it up
+ # too.
+ wrap_ctx = pq3.wrap(sock, debug_stream=sys.stdout)
+ conn = stack.enter_context(wrap_ctx)
+
+ return conn
+
+ yield conn_factory
diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py
new file mode 100644
index 0000000000..cb5ca7fa23
--- /dev/null
+++ b/src/test/python/server/test_oauth.py
@@ -0,0 +1,1012 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import contextlib
+import json
+import os
+import pathlib
+import secrets
+import shlex
+import shutil
+import socket
+import struct
+from multiprocessing import shared_memory
+
+import psycopg2
+import pytest
+from psycopg2 import sql
+
+import pq3
+
+MAX_SASL_MESSAGE_LENGTH = 65535
+
+INVALID_AUTHORIZATION_ERRCODE = b"28000"
+PROTOCOL_VIOLATION_ERRCODE = b"08P01"
+FEATURE_NOT_SUPPORTED_ERRCODE = b"0A000"
+
+SHARED_MEM_NAME = "oauth-pytest"
+MAX_TOKEN_SIZE = 4096
+MAX_UINT16 = 2 ** 16 - 1
+
+
+def skip_if_no_postgres():
+ """
+ Used by the oauth_ctx fixture to skip this test module if no Postgres server
+ is running.
+
+    This logic nearly duplicates the connect fixture. Ideally oauth_ctx
+ would depend on that, but a module-scope fixture can't depend on a
+ test-scope fixture, and we haven't reached the rule of three yet.
+ """
+ addr = (pq3.pghost(), pq3.pgport())
+
+ try:
+ with socket.create_connection(addr, timeout=2):
+ pass
+ except ConnectionError as e:
+ pytest.skip(f"unable to connect to {addr}: {e}")
+
+
+@contextlib.contextmanager
+def prepend_file(path, lines):
+ """
+ A context manager that prepends a file on disk with the desired lines of
+ text. When the context manager is exited, the file will be restored to its
+ original contents.
+ """
+ # First make a backup of the original file.
+ bak = path + ".bak"
+ shutil.copy2(path, bak)
+
+ try:
+ # Write the new lines, followed by the original file content.
+ with open(path, "w") as new, open(bak, "r") as orig:
+ new.writelines(lines)
+ shutil.copyfileobj(orig, new)
+
+ # Return control to the calling code.
+ yield
+
+ finally:
+ # Put the backup back into place.
+ os.replace(bak, path)
+
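To illustrate prepend_file's contract, here is a self-contained sketch of the same back-up/prepend/restore shape, exercised against a throwaway file (file names here are illustrative only):

```python
import contextlib, os, shutil, tempfile

@contextlib.contextmanager
def prepend_file(path, lines):
    # Back up the original, write the new lines followed by the original
    # content, and restore the backup on exit.
    bak = path + ".bak"
    shutil.copy2(path, bak)
    try:
        with open(path, "w") as new, open(bak, "r") as orig:
            new.writelines(lines)
            shutil.copyfileobj(orig, new)
        yield
    finally:
        os.replace(bak, path)

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "pg_hba.conf")
    with open(path, "w") as f:
        f.write("original\n")

    with prepend_file(path, ["added\n"]):
        with open(path) as f:
            during = f.read()

    with open(path) as f:
        after = f.read()
```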
+
+@pytest.fixture(scope="module")
+def oauth_ctx():
+ """
+ Creates a database and user that use the oauth auth method. The context
+ object contains the dbname and user attributes as strings to be used during
+ connection, as well as the issuer and scope that have been set in the HBA
+ configuration.
+
+ This fixture assumes that the standard PG* environment variables point to a
+ server running on a local machine, and that the PGUSER has rights to create
+ databases and roles.
+ """
+ skip_if_no_postgres() # don't bother running these tests without a server
+
+ id = secrets.token_hex(4)
+
+ class Context:
+ dbname = "oauth_test_" + id
+
+ user = "oauth_user_" + id
+ map_user = "oauth_map_user_" + id
+ authz_user = "oauth_authz_user_" + id
+
+ issuer = "https://example.com/" + id
+ scope = "openid " + id
+
+ ctx = Context()
+ hba_lines = (
+ f'host {ctx.dbname} {ctx.map_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
+ f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
+ f'host {ctx.dbname} all samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
+ )
+ ident_lines = (r"oauth /^(.*)@example\.com$ \1",)
+
+ conn = psycopg2.connect("")
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+
+ # Create our roles and database.
+ user = sql.Identifier(ctx.user)
+ map_user = sql.Identifier(ctx.map_user)
+ authz_user = sql.Identifier(ctx.authz_user)
+ dbname = sql.Identifier(ctx.dbname)
+
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(user))
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(map_user))
+ c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user))
+ c.execute(sql.SQL("CREATE DATABASE {};").format(dbname))
+
+ # Make this test script the server's oauth_validator.
+ path = pathlib.Path(__file__).parent / "validate_bearer.py"
+ path = str(path.absolute())
+
+ cmd = f"{shlex.quote(path)} {SHARED_MEM_NAME} <&%f"
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+
+ # Replace pg_hba and pg_ident.
+ c.execute("SHOW hba_file;")
+ hba = c.fetchone()[0]
+
+ c.execute("SHOW ident_file;")
+ ident = c.fetchone()[0]
+
+ with prepend_file(hba, hba_lines), prepend_file(ident, ident_lines):
+ c.execute("SELECT pg_reload_conf();")
+
+ # Use the new database and user.
+ yield ctx
+
+ # Put things back the way they were.
+ c.execute("SELECT pg_reload_conf();")
+
+ c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+ c.execute(sql.SQL("DROP DATABASE {};").format(dbname))
+ c.execute(sql.SQL("DROP ROLE {};").format(authz_user))
+ c.execute(sql.SQL("DROP ROLE {};").format(map_user))
+ c.execute(sql.SQL("DROP ROLE {};").format(user))
+
+
+@pytest.fixture()
+def conn(oauth_ctx, connect):
+ """
+ A convenience wrapper for connect(). The main purpose of this fixture is to
+ make sure oauth_ctx runs its setup code before the connection is made.
+ """
+ return connect()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def authn_id_extension(oauth_ctx):
+ """
+ Performs a `CREATE EXTENSION authn_id` in the test database. This fixture is
+    autoused, so tests don't need to request it explicitly.
+ """
+ conn = psycopg2.connect(database=oauth_ctx.dbname)
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+ c.execute("CREATE EXTENSION authn_id;")
+
+
+@pytest.fixture(scope="session")
+def shared_mem():
+ """
+ Yields a shared memory segment that can be used for communication between
+ the bearer_token fixture and ./validate_bearer.py.
+ """
+ size = MAX_TOKEN_SIZE + 2 # two byte length prefix
+ mem = shared_memory.SharedMemory(SHARED_MEM_NAME, create=True, size=size)
+
+ try:
+ with contextlib.closing(mem):
+ yield mem
+ finally:
+ mem.unlink()
+
+
+@pytest.fixture()
+def bearer_token(shared_mem):
+ """
+ Returns a factory function that, when called, will store a Bearer token in
+ shared_mem. If token is None (the default), a new token will be generated
+ using secrets.token_urlsafe() and returned; otherwise the passed token will
+ be used as-is.
+
+ When token is None, the generated token size in bytes may be specified as an
+ argument; if unset, a small 16-byte token will be generated. The token size
+ may not exceed MAX_TOKEN_SIZE in any case.
+
+ The return value is the token, converted to a bytes object.
+
+ As a special case for testing failure modes, accept_any may be set to True.
+ This signals to the validator command that any bearer token should be
+ accepted. The returned token in this case may be used or discarded as needed
+ by the test.
+ """
+
+ def set_token(token=None, *, size=16, accept_any=False):
+ if token is not None:
+ size = len(token)
+
+ if size > MAX_TOKEN_SIZE:
+ raise ValueError(f"token size {size} exceeds maximum size {MAX_TOKEN_SIZE}")
+
+ if token is None:
+ if size % 4:
+ raise ValueError(f"requested token size {size} is not a multiple of 4")
+
+ token = secrets.token_urlsafe(size // 4 * 3)
+ assert len(token) == size
+
+ try:
+ token = token.encode("ascii")
+ except AttributeError:
+ pass # already encoded
+
+ if accept_any:
+ # Two-byte magic value.
+ shared_mem.buf[:2] = struct.pack("H", MAX_UINT16)
+ else:
+ # Two-byte length prefix, then the token data.
+ shared_mem.buf[:2] = struct.pack("H", len(token))
+ shared_mem.buf[2 : size + 2] = token
+
+ return token
+
+ return set_token
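The size arithmetic above relies on secrets.token_urlsafe(n) returning unpadded base64url, i.e. ceil(4n/3) characters; requesting size // 4 * 3 bytes therefore yields exactly size characters whenever size is a multiple of 4:

```python
import secrets

# Each 3 bytes of entropy become 4 base64url characters, with no padding.
for size in (16, 1024, 4096):
    token = secrets.token_urlsafe(size // 4 * 3)
    assert len(token) == size
```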
+
+
+def begin_oauth_handshake(conn, oauth_ctx, *, user=None):
+ if user is None:
+ user = oauth_ctx.authz_user
+
+ pq3.send_startup(conn, user=user, database=oauth_ctx.dbname)
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.AuthnRequest
+
+ # The server should advertise exactly one mechanism.
+ assert resp.payload.type == pq3.authn.SASL
+ assert resp.payload.body == [b"OAUTHBEARER", b""]
+
+
+def send_initial_response(conn, *, auth=None, bearer=None):
+ """
+ Sends the OAUTHBEARER initial response on the connection, using the given
+    bearer token. Alternatively, the initial response's auth field may be
+    specified explicitly, to exercise corner cases.
+ """
+ if bearer is not None and auth is not None:
+ raise ValueError("exactly one of the auth and bearer kwargs must be set")
+
+ if bearer is not None:
+ auth = b"Bearer " + bearer
+
+ if auth is None:
+ raise ValueError("exactly one of the auth and bearer kwargs must be set")
+
+ initial = pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"n,,\x01auth=" + auth + b"\x01\x01",
+ )
+ )
+ pq3.send(conn, pq3.types.PasswordMessage, initial)
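For reference, the client-first data built here follows RFC 7628: a GS2 header ("n,," meaning no channel binding and no authzid), then \x01-separated key=value pairs, terminated by a double \x01 (the kvsep). A sketch:

```python
def oauthbearer_client_response(token):
    # "n,," = GS2 header; \x01 is the RFC 7628 key/value separator.
    return b"n,,\x01auth=Bearer " + token + b"\x01\x01"
```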
+
+
+def expect_handshake_success(conn):
+ """
+ Validates that the server responds with an AuthnOK message, and then drains
+ the connection until a ReadyForQuery message is received.
+ """
+ resp = pq3.recv1(conn)
+
+ assert resp.type == pq3.types.AuthnRequest
+ assert resp.payload.type == pq3.authn.OK
+ assert not resp.payload.body
+
+ receive_until(conn, pq3.types.ReadyForQuery)
+
+
+def expect_handshake_failure(conn, oauth_ctx):
+ """
+ Performs the OAUTHBEARER SASL failure "handshake" and validates the server's
+ side of the conversation, including the final ErrorResponse.
+ """
+
+ # We expect a discovery "challenge" back from the server before the authn
+ # failure message.
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.AuthnRequest
+
+ req = resp.payload
+ assert req.type == pq3.authn.SASLContinue
+
+ body = json.loads(req.body)
+ assert body["status"] == "invalid_token"
+ assert body["scope"] == oauth_ctx.scope
+
+ expected_config = oauth_ctx.issuer + "/.well-known/openid-configuration"
+ assert body["openid-configuration"] == expected_config
+
+ # Send the dummy response to complete the failed handshake.
+ pq3.send(conn, pq3.types.PasswordMessage, b"\x01")
+ resp = pq3.recv1(conn)
+
+ err = ExpectedError(INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed")
+ err.match(resp)
+
+
+def receive_until(conn, type):
+ """
+    receive_until pulls packets off the pq3 connection until a packet with the
+    desired type is found and returned, or an error response is received, in
+    which case a RuntimeError is raised.
+ """
+ while True:
+ pkt = pq3.recv1(conn)
+
+ if pkt.type == type:
+ return pkt
+ elif pkt.type == pq3.types.ErrorResponse:
+ raise RuntimeError(
+ f"received error response from peer: {pkt.payload.fields!r}"
+ )
+
+
+@pytest.mark.parametrize("token_len", [16, 1024, 4096])
+@pytest.mark.parametrize(
+ "auth_prefix",
+ [
+ b"Bearer ",
+ b"bearer ",
+        b"Bearer    ",
+ ],
+)
+def test_oauth(conn, oauth_ctx, bearer_token, auth_prefix, token_len):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Generate our bearer token with the desired length.
+ token = bearer_token(size=token_len)
+ auth = auth_prefix + token
+
+ send_initial_response(conn, auth=auth)
+ expect_handshake_success(conn)
+
+ # Make sure that the server has not set an authenticated ID.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ assert row.columns == [None]
+
+
+@pytest.mark.parametrize(
+ "token_value",
+ [
+ "abcdzA==",
+ "123456M=",
+ "x-._~+/x",
+ ],
+)
+def test_oauth_bearer_corner_cases(conn, oauth_ctx, bearer_token, token_value):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ send_initial_response(conn, bearer=bearer_token(token_value))
+
+ expect_handshake_success(conn)
+
+
+@pytest.mark.parametrize(
+ "user,authn_id,should_succeed",
+ [
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: ctx.user,
+ True,
+ id="validator authn: succeeds when authn_id == username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: None,
+ False,
+ id="validator authn: fails when authn_id is not set",
+ ),
+ pytest.param(
+ lambda ctx: ctx.user,
+ lambda ctx: ctx.authz_user,
+ False,
+ id="validator authn: fails when authn_id != username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: ctx.map_user + "@example.com",
+ True,
+ id="validator with map: succeeds when authn_id matches map",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: None,
+ False,
+ id="validator with map: fails when authn_id is not set",
+ ),
+ pytest.param(
+ lambda ctx: ctx.map_user,
+ lambda ctx: ctx.map_user + "@example.net",
+ False,
+ id="validator with map: fails when authn_id doesn't match map",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: None,
+ True,
+ id="validator authz: succeeds with no authn_id",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "",
+ True,
+ id="validator authz: succeeds with empty authn_id",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "postgres",
+ True,
+ id="validator authz: succeeds with basic username",
+ ),
+ pytest.param(
+ lambda ctx: ctx.authz_user,
+ lambda ctx: "me@example.com",
+ True,
+ id="validator authz: succeeds with email address",
+ ),
+ ],
+)
+def test_oauth_authn_id(conn, oauth_ctx, bearer_token, user, authn_id, should_succeed):
+ token = None
+
+ authn_id = authn_id(oauth_ctx)
+ if authn_id is not None:
+ authn_id = authn_id.encode("ascii")
+
+ # As a hack to get the validator to reflect arbitrary output from this
+ # test, encode the desired output as a base64 token. The validator will
+ # key on the leading "output=" to differentiate this from the random
+ # tokens generated by secrets.token_urlsafe().
+ output = b"output=" + authn_id + b"\n"
+ token = base64.urlsafe_b64encode(output)
+
+ token = bearer_token(token)
+ username = user(oauth_ctx)
+
+ begin_oauth_handshake(conn, oauth_ctx, user=username)
+ send_initial_response(conn, bearer=token)
+
+ if not should_succeed:
+ expect_handshake_failure(conn, oauth_ctx)
+ return
+
+ expect_handshake_success(conn)
+
+ # Check the reported authn_id.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ assert row.columns == [authn_id]
+
+
+class ExpectedError(object):
+ def __init__(self, code, msg=None, detail=None):
+ self.code = code
+ self.msg = msg
+ self.detail = detail
+
+ # Protect against the footgun of an accidental empty string, which will
+ # "match" anything. If you don't want to match message or detail, just
+ # don't pass them.
+ if self.msg == "":
+ raise ValueError("msg must be non-empty or None")
+ if self.detail == "":
+ raise ValueError("detail must be non-empty or None")
+
+ def _getfield(self, resp, type):
+ """
+ Searches an ErrorResponse for a single field of the given type (e.g.
+ "M", "C", "D") and returns its value. Asserts if it doesn't find exactly
+ one field.
+ """
+ prefix = type.encode("ascii")
+ fields = [f for f in resp.payload.fields if f.startswith(prefix)]
+
+ assert len(fields) == 1
+ return fields[0][1:] # strip off the type byte
+
+ def match(self, resp):
+ """
+ Checks that the given response matches the expected code, message, and
+ detail (if given). The error code must match exactly. The expected
+ message and detail must be contained within the actual strings.
+ """
+ assert resp.type == pq3.types.ErrorResponse
+
+ code = self._getfield(resp, "C")
+ assert code == self.code
+
+ if self.msg:
+ msg = self._getfield(resp, "M")
+ expected = self.msg.encode("utf-8")
+ assert expected in msg
+
+ if self.detail:
+ detail = self._getfield(resp, "D")
+ expected = self.detail.encode("utf-8")
+ assert expected in detail
+
+
+def test_oauth_rejected_bearer(conn, oauth_ctx, bearer_token):
+ # Generate a new bearer token, which we will proceed not to use.
+ _ = bearer_token()
+
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send a bearer token that doesn't match what the validator expects. It
+ # should fail the connection.
+ send_initial_response(conn, bearer=b"xxxxxx")
+
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.parametrize(
+ "bad_bearer",
+ [
+ b"Bearer ",
+ b"Bearer a===b",
+ b"Bearer hello!",
+ b"Bearer me@example.com",
+ b'OAuth realm="Example"',
+ b"",
+ ],
+)
+def test_oauth_invalid_bearer(conn, oauth_ctx, bearer_token, bad_bearer):
+ # Tell the validator to accept any token. This ensures that the invalid
+ # bearer tokens are rejected before the validation step.
+ _ = bearer_token(accept_any=True)
+
+ begin_oauth_handshake(conn, oauth_ctx)
+ send_initial_response(conn, auth=bad_bearer)
+
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.parametrize(
+ "resp_type,resp,err",
+ [
+ pytest.param(
+ None,
+ None,
+ None,
+ marks=pytest.mark.slow,
+ id="no response (expect timeout)",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"hello",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="bad dummy response",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"\x01\x01",
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not send a kvsep response",
+ ),
+ id="multiple kvseps",
+ ),
+ pytest.param(
+ pq3.types.Query,
+ dict(query=b""),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="bad response message type",
+ ),
+ ],
+)
+def test_oauth_bad_response_to_error_challenge(conn, oauth_ctx, resp_type, resp, err):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send an empty auth initial response, which will force an authn failure.
+ send_initial_response(conn, auth=b"")
+
+ # We expect a discovery "challenge" back from the server before the authn
+ # failure message.
+ pkt = pq3.recv1(conn)
+ assert pkt.type == pq3.types.AuthnRequest
+
+ req = pkt.payload
+ assert req.type == pq3.authn.SASLContinue
+
+ body = json.loads(req.body)
+ assert body["status"] == "invalid_token"
+
+ if resp_type is None:
+ # Do not send the dummy response. We should time out and not get a
+ # response from the server.
+ with pytest.raises(socket.timeout):
+ conn.read(1)
+
+ # Done with the test.
+ return
+
+ # Send the bad response.
+ pq3.send(conn, resp_type, resp)
+
+ # Make sure the server fails the connection correctly.
+ pkt = pq3.recv1(conn)
+ err.match(pkt)
+
+
+@pytest.mark.parametrize(
+ "type,payload,err",
+ [
+ pytest.param(
+ pq3.types.ErrorResponse,
+ dict(fields=[b""]),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+ id="error response in initial message",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ b"x" * (MAX_SASL_MESSAGE_LENGTH + 1),
+ ExpectedError(
+ INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed"
+ ),
+ id="overlong initial response data",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"SCRAM-SHA-256")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE, "invalid SASL authentication mechanism"
+ ),
+ id="bad SASL mechanism selection",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=2, data=b"x")),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "insufficient data"),
+ id="SASL data underflow",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=0, data=b"x")),
+ ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "invalid message format"),
+ id="SASL data overflow",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "message is empty",
+ ),
+ id="empty",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"n,,\x01auth=\x01\x01\0")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "length does not match input length",
+ ),
+ id="contains null byte",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected channel-binding flag", # XXX this is a bit strange
+ ),
+ id="initial error response",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"p=tls-server-end-point,,\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "server does not support channel binding",
+ ),
+ id="uses channel binding",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"x,,\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected channel-binding flag",
+ ),
+ id="invalid channel binding specifier",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Comma expected",
+ ),
+ id="bad GS2 header: missing channel binding terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,a")),
+ ExpectedError(
+ FEATURE_NOT_SUPPORTED_ERRCODE,
+ "client uses authorization identity",
+ ),
+ id="bad GS2 header: authzid in use",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,b,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected attribute",
+ ),
+ id="bad GS2 header: extra attribute",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Unexpected attribute 0x00", # XXX this is a bit strange
+ ),
+ id="bad GS2 header: missing authzid terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "Key-value separator expected",
+ ),
+ id="missing initial kvsep",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: empty key-value list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "does not contain an auth value",
+ ),
+ id="missing auth value: other keys present",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "unterminated key/value pair",
+ ),
+ id="missing value terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,\x01")),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: empty list",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "did not contain a final terminator",
+ ),
+ id="missing list terminator: with auth value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01\x01blah")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "additional data after the final terminator",
+ ),
+ id="additional key after terminator",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(name=b"OAUTHBEARER", data=b"y,,\x01key\x01\x01")
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "key without a value",
+ ),
+ id="key without value",
+ ),
+ pytest.param(
+ pq3.types.PasswordMessage,
+ pq3.SASLInitialResponse.build(
+ dict(
+ name=b"OAUTHBEARER",
+ data=b"y,,\x01auth=Bearer 0\x01auth=Bearer 1\x01\x01",
+ )
+ ),
+ ExpectedError(
+ PROTOCOL_VIOLATION_ERRCODE,
+ "malformed OAUTHBEARER message",
+ "contains multiple auth values",
+ ),
+ id="multiple auth values",
+ ),
+ ],
+)
+def test_oauth_bad_initial_response(conn, oauth_ctx, type, payload, err):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # The server expects a SASL response; give it something else instead.
+ if not isinstance(payload, dict):
+ payload = dict(payload_data=payload)
+ pq3.send(conn, type, **payload)
+
+ resp = pq3.recv1(conn)
+ err.match(resp)
+
+
+def test_oauth_empty_initial_response(conn, oauth_ctx, bearer_token):
+ begin_oauth_handshake(conn, oauth_ctx)
+
+ # Send an initial response without data.
+ initial = pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER"))
+ pq3.send(conn, pq3.types.PasswordMessage, initial)
+
+ # The server should respond with an empty challenge so we can send the data
+ # it wants.
+ pkt = pq3.recv1(conn)
+
+ assert pkt.type == pq3.types.AuthnRequest
+ assert pkt.payload.type == pq3.authn.SASLContinue
+ assert not pkt.payload.body
+
+ # Now send the initial data.
+ data = b"n,,\x01auth=Bearer " + bearer_token() + b"\x01\x01"
+ pq3.send(conn, pq3.types.PasswordMessage, data)
+
+ # Server should now complete the handshake.
+ expect_handshake_success(conn)
+
+
+@pytest.fixture()
+def set_validator():
+ """
+ A per-test fixture that allows a test to override the setting of
+ oauth_validator_command for the cluster. The setting will be reverted during
+ teardown.
+
+ Passing None will perform an ALTER SYSTEM RESET.
+ """
+ conn = psycopg2.connect("")
+ conn.autocommit = True
+
+ with contextlib.closing(conn):
+ c = conn.cursor()
+
+ # Save the previous value.
+ c.execute("SHOW oauth_validator_command;")
+ prev_cmd = c.fetchone()[0]
+
+    def setter(cmd):
+        if cmd is None:
+            c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+        else:
+            c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+        c.execute("SELECT pg_reload_conf();")
+
+ yield setter
+
+ # Restore the previous value.
+ c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (prev_cmd,))
+ c.execute("SELECT pg_reload_conf();")
+
+
+def test_oauth_no_validator(oauth_ctx, set_validator, connect, bearer_token):
+ # Clear out our validator command, then establish a new connection.
+ set_validator("")
+ conn = connect()
+
+ begin_oauth_handshake(conn, oauth_ctx)
+ send_initial_response(conn, bearer=bearer_token())
+
+ # The server should fail the connection.
+ expect_handshake_failure(conn, oauth_ctx)
+
+
+def test_oauth_validator_role(oauth_ctx, set_validator, connect):
+ # Switch the validator implementation. This validator will reflect the
+ # PGUSER as the authenticated identity.
+ path = pathlib.Path(__file__).parent / "validate_reflect.py"
+ path = str(path.absolute())
+
+ set_validator(f"{shlex.quote(path)} '%r' <&%f")
+ conn = connect()
+
+ # Log in. Note that the reflection validator ignores the bearer token.
+ begin_oauth_handshake(conn, oauth_ctx, user=oauth_ctx.user)
+ send_initial_response(conn, bearer=b"dontcare")
+ expect_handshake_success(conn)
+
+ # Check the user identity.
+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+ resp = receive_until(conn, pq3.types.DataRow)
+
+ row = resp.payload
+ expected = oauth_ctx.user.encode("utf-8")
+ assert row.columns == [expected]
+
+
+def test_oauth_role_with_shell_unsafe_characters(oauth_ctx, set_validator, connect):
+ """
+ XXX This test pins undesirable behavior. We should be able to handle any
+ valid Postgres role name.
+ """
+ # Switch the validator implementation. This validator will reflect the
+ # PGUSER as the authenticated identity.
+ path = pathlib.Path(__file__).parent / "validate_reflect.py"
+ path = str(path.absolute())
+
+ set_validator(f"{shlex.quote(path)} '%r' <&%f")
+ conn = connect()
+
+ unsafe_username = "hello'there"
+ begin_oauth_handshake(conn, oauth_ctx, user=unsafe_username)
+
+ # The server should reject the handshake.
+ send_initial_response(conn, bearer=b"dontcare")
+ expect_handshake_failure(conn, oauth_ctx)
diff --git a/src/test/python/server/test_server.py b/src/test/python/server/test_server.py
new file mode 100644
index 0000000000..02126dba79
--- /dev/null
+++ b/src/test/python/server/test_server.py
@@ -0,0 +1,21 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import pq3
+
+
+def test_handshake(connect):
+ """Basic sanity check."""
+ conn = connect()
+
+ pq3.handshake(conn, user=pq3.pguser(), database=pq3.pgdatabase())
+
+ pq3.send(conn, pq3.types.Query, query=b"")
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.EmptyQueryResponse
+
+ resp = pq3.recv1(conn)
+ assert resp.type == pq3.types.ReadyForQuery
diff --git a/src/test/python/server/validate_bearer.py b/src/test/python/server/validate_bearer.py
new file mode 100755
index 0000000000..2cc73ff154
--- /dev/null
+++ b/src/test/python/server/validate_bearer.py
@@ -0,0 +1,101 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It doesn't actually validate
+# anything, and it logs the bearer token data, which is sensitive.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. The expected token is communicated through a shared memory
+# segment created by that test module's bearer_token() fixture.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
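+#
+# The expected-token handshake relies on the following shared memory layout
+# (as written by the bearer_token() fixture and read back in main() below):
+#
+#   bytes [0, 2):        expected token length, a native-endian unsigned
+#                        short; the special value 0xFFFF (MAX_UINT16) means
+#                        "accept any token without comparing"
+#   bytes [2, 2 + len):  the expected token itself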
+
+import base64
+import binascii
+import contextlib
+import struct
+import sys
+from multiprocessing import shared_memory
+
+MAX_UINT16 = 2 ** 16 - 1
+
+
+def remove_shm_from_resource_tracker():
+ """
+ Monkey-patch multiprocessing.resource_tracker so SharedMemory won't be
+ tracked. Pulled from this thread, where there are more details:
+
+ https://bugs.python.org/issue38119
+
+ TL;DR: all clients of shared memory segments automatically destroy them on
+ process exit, which makes shared memory segments much less useful. This
+ monkeypatch removes that behavior so that we can defer to the test to manage
+ the segment lifetime.
+
+ Ideally a future Python patch will pull in this fix and then the entire
+ function can go away.
+ """
+ from multiprocessing import resource_tracker
+
+ def fix_register(name, rtype):
+ if rtype == "shared_memory":
+ return
+        return resource_tracker._resource_tracker.register(name, rtype)
+
+ resource_tracker.register = fix_register
+
+ def fix_unregister(name, rtype):
+ if rtype == "shared_memory":
+ return
+        return resource_tracker._resource_tracker.unregister(name, rtype)
+
+ resource_tracker.unregister = fix_unregister
+
+ if "shared_memory" in resource_tracker._CLEANUP_FUNCS:
+ del resource_tracker._CLEANUP_FUNCS["shared_memory"]
+
+
+def main(args):
+ remove_shm_from_resource_tracker() # XXX remove some day
+
+ # Get the expected token from the currently running test.
+ shared_mem_name = args[0]
+
+ mem = shared_memory.SharedMemory(shared_mem_name)
+ with contextlib.closing(mem):
+ # First two bytes are the token length.
+ size = struct.unpack("H", mem.buf[:2])[0]
+
+ if size == MAX_UINT16:
+ # Special case: the test wants us to accept any token.
+ sys.stderr.write("accepting token without validation\n")
+ return
+
+ # The remainder of the buffer contains the expected token.
+ assert size <= (mem.size - 2)
+ expected_token = mem.buf[2 : size + 2].tobytes()
+
+ mem.buf[:] = b"\0" * mem.size # scribble over the token
+
+ token = sys.stdin.buffer.read()
+ if token != expected_token:
+ sys.exit(f"failed to match Bearer token ({token!r} != {expected_token!r})")
+
+ # See if the test wants us to print anything. If so, it will have encoded
+ # the desired output in the token with an "output=" prefix.
+ try:
+ # altchars="-_" corresponds to the urlsafe alphabet.
+ data = base64.b64decode(token, altchars="-_", validate=True)
+
+ if data.startswith(b"output="):
+ sys.stdout.buffer.write(data[7:])
+
+ except binascii.Error:
+ pass
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/src/test/python/server/validate_reflect.py b/src/test/python/server/validate_reflect.py
new file mode 100755
index 0000000000..24c3a7e715
--- /dev/null
+++ b/src/test/python/server/validate_reflect.py
@@ -0,0 +1,34 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It ignores the bearer token
+# entirely and automatically logs the user in.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. It expects the user's desired role name as an argument; the
+# actual token will be discarded and the user will be logged in with the role
+# name as the authenticated identity.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
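+#
+# For example, test_oauth.py configures this script roughly as follows (the
+# exact path varies), where the server expands %r to the connection's role
+# name and %f to the file descriptor carrying the token:
+#
+#   oauth_validator_command = '/path/to/validate_reflect.py "%r" <&%f'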
+
+import sys
+
+
+def main(args):
+ # We have to read the entire token as our first action to unblock the
+ # server, but we won't actually use it.
+ _ = sys.stdin.buffer.read()
+
+ if len(args) != 1:
+ sys.exit("usage: ./validate_reflect.py ROLE")
+
+ # Log the user in as the provided role.
+ role = args[0]
+ print(role)
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/src/test/python/test_internals.py b/src/test/python/test_internals.py
new file mode 100644
index 0000000000..dee4855fc0
--- /dev/null
+++ b/src/test/python/test_internals.py
@@ -0,0 +1,138 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import io
+
+from pq3 import _DebugStream
+
+
+def test_DebugStream_read():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+
+ stream = _DebugStream(under, out)
+
+ res = stream.read(5)
+ assert res == b"abcde"
+
+ res = stream.read(16)
+ assert res == b"fghijklmnopqrstu"
+
+ stream.flush_debug()
+
+ res = stream.read()
+ assert res == b"vwxyz"
+
+ stream.flush_debug()
+
+ expected = (
+ "< 0000:\t61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70\tabcdefghijklmnop\n"
+ "< 0010:\t71 72 73 74 75 \tqrstu\n"
+ "\n"
+ "< 0000:\t76 77 78 79 7a \tvwxyz\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_write():
+ under = io.BytesIO()
+ out = io.StringIO()
+
+ stream = _DebugStream(under, out)
+
+ stream.write(b"\x00\x01\x02")
+ stream.flush()
+
+ assert under.getvalue() == b"\x00\x01\x02"
+
+ stream.write(b"\xc0\xc1\xc2")
+ stream.flush()
+
+ assert under.getvalue() == b"\x00\x01\x02\xc0\xc1\xc2"
+
+ stream.flush_debug()
+
+ expected = "> 0000:\t00 01 02 c0 c1 c2 \t......\n\n"
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_read_write():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+ stream = _DebugStream(under, out)
+
+ res = stream.read(5)
+ assert res == b"abcde"
+
+ stream.write(b"xxxxx")
+ stream.flush()
+
+ assert under.getvalue() == b"abcdexxxxxklmnopqrstuvwxyz"
+
+ res = stream.read(5)
+ assert res == b"klmno"
+
+ stream.write(b"xxxxx")
+ stream.flush()
+
+ assert under.getvalue() == b"abcdexxxxxklmnoxxxxxuvwxyz"
+
+ stream.flush_debug()
+
+ expected = (
+ "< 0000:\t61 62 63 64 65 6b 6c 6d 6e 6f \tabcdeklmno\n"
+ "\n"
+ "> 0000:\t78 78 78 78 78 78 78 78 78 78 \txxxxxxxxxx\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
+
+
+def test_DebugStream_end_packet():
+ under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
+ out = io.StringIO()
+ stream = _DebugStream(under, out)
+
+ stream.read(5)
+ stream.end_packet("read description", read=True, indent=" ")
+
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("write description", indent=" ")
+
+ stream.read(5)
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("read/write combo for read", read=True, indent=" ")
+
+ stream.read(5)
+ stream.write(b"xxxxx")
+ stream.flush()
+ stream.end_packet("read/write combo for write", indent=" ")
+
+ expected = (
+ " < 0000:\t61 62 63 64 65 \tabcde\n"
+ "\n"
+ "< read description\n"
+ "\n"
+ "> write description\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ " < 0000:\t6b 6c 6d 6e 6f \tklmno\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ "< read/write combo for read\n"
+ "\n"
+ "> read/write combo for write\n"
+ "\n"
+ " < 0000:\t75 76 77 78 79 \tuvwxy\n"
+ "\n"
+ " > 0000:\t78 78 78 78 78 \txxxxx\n"
+ "\n"
+ )
+ assert out.getvalue() == expected
diff --git a/src/test/python/test_pq3.py b/src/test/python/test_pq3.py
new file mode 100644
index 0000000000..e0c0e0568d
--- /dev/null
+++ b/src/test/python/test_pq3.py
@@ -0,0 +1,558 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import struct
+import sys
+
+import pytest
+from construct import Container, PaddingError, StreamError, TerminatedError
+
+import pq3
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"\x00\x00\x00\x10\x00\x04\x00\x00abcdefgh",
+ Container(len=16, proto=0x40000, payload=b"abcdefgh"),
+ b"",
+ id="8-byte payload",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x08\x00\x04\x00\x00",
+ Container(len=8, proto=0x40000, payload=b""),
+ b"",
+ id="no payload",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x09\x00\x04\x00\x00abcde",
+ Container(len=9, proto=0x40000, payload=b"a"),
+ b"bcde",
+ id="1-byte payload and extra padding",
+ ),
+ pytest.param(
+ b"\x00\x00\x00\x0B\x00\x03\x00\x00hi\x00",
+ Container(len=11, proto=pq3.protocol(3, 0), payload=[b"hi"]),
+ b"",
+ id="implied parameter list when using proto version 3.0",
+ ),
+ ],
+)
+def test_Startup_parse(raw, expected, extra):
+ with io.BytesIO(raw) as stream:
+ actual = pq3.Startup.parse_stream(stream)
+
+ assert actual == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "packet,expected_bytes",
+ [
+ pytest.param(
+ dict(),
+ b"\x00\x00\x00\x08\x00\x00\x00\x00",
+ id="nothing set",
+ ),
+ pytest.param(
+ dict(len=10, proto=0x12345678),
+ b"\x00\x00\x00\x0A\x12\x34\x56\x78\x00\x00",
+ id="len and proto set explicitly",
+ ),
+ pytest.param(
+ dict(proto=0x12345678),
+ b"\x00\x00\x00\x08\x12\x34\x56\x78",
+ id="implied len with no payload",
+ ),
+ pytest.param(
+ dict(proto=0x12345678, payload=b"abcd"),
+ b"\x00\x00\x00\x0C\x12\x34\x56\x78abcd",
+ id="implied len with payload",
+ ),
+ pytest.param(
+ dict(payload=[b""]),
+ b"\x00\x00\x00\x09\x00\x03\x00\x00\x00",
+ id="implied proto version 3 when sending parameters",
+ ),
+ pytest.param(
+ dict(payload=[b"hi", b""]),
+ b"\x00\x00\x00\x0C\x00\x03\x00\x00hi\x00\x00",
+ id="implied proto version 3 and len when sending more than one parameter",
+ ),
+ pytest.param(
+ dict(payload=dict(user="jsmith", database="postgres")),
+ b"\x00\x00\x00\x27\x00\x03\x00\x00user\x00jsmith\x00database\x00postgres\x00\x00",
+ id="auto-serialization of dict parameters",
+ ),
+ ],
+)
+def test_Startup_build(packet, expected_bytes):
+ actual = pq3.Startup.build(packet)
+ assert actual == expected_bytes
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"*\x00\x00\x00\x08abcd",
+ dict(type=b"*", len=8, payload=b"abcd"),
+ b"",
+ id="4-byte payload",
+ ),
+ pytest.param(
+ b"*\x00\x00\x00\x04",
+ dict(type=b"*", len=4, payload=b""),
+ b"",
+ id="no payload",
+ ),
+ pytest.param(
+ b"*\x00\x00\x00\x05xabcd",
+ dict(type=b"*", len=5, payload=b"x"),
+ b"abcd",
+ id="1-byte payload with extra padding",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x08\x00\x00\x00\x00",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=8,
+ payload=dict(type=pq3.authn.OK, body=None),
+ ),
+ b"",
+ id="AuthenticationOk",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x12\x00\x00\x00\x0AEXTERNAL\x00\x00",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=18,
+ payload=dict(type=pq3.authn.SASL, body=[b"EXTERNAL", b""]),
+ ),
+ b"",
+ id="AuthenticationSASL",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=13,
+ payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
+ ),
+ b"",
+ id="AuthenticationSASLContinue",
+ ),
+ pytest.param(
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
+ dict(
+ type=pq3.types.AuthnRequest,
+ len=13,
+ payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
+ ),
+ b"",
+ id="AuthenticationSASLFinal",
+ ),
+ pytest.param(
+ b"p\x00\x00\x00\x0Bhunter2",
+ dict(
+ type=pq3.types.PasswordMessage,
+ len=11,
+ payload=b"hunter2",
+ ),
+ b"",
+ id="PasswordMessage",
+ ),
+ pytest.param(
+ b"K\x00\x00\x00\x0C\x00\x00\x00\x00\x12\x34\x56\x78",
+ dict(
+ type=pq3.types.BackendKeyData,
+ len=12,
+ payload=dict(pid=0, key=0x12345678),
+ ),
+ b"",
+ id="BackendKeyData",
+ ),
+ pytest.param(
+ b"C\x00\x00\x00\x08SET\x00",
+ dict(
+ type=pq3.types.CommandComplete,
+ len=8,
+ payload=dict(tag=b"SET"),
+ ),
+ b"",
+ id="CommandComplete",
+ ),
+ pytest.param(
+ b"E\x00\x00\x00\x11Mbad!\x00Mdog!\x00\x00",
+ dict(type=b"E", len=17, payload=dict(fields=[b"Mbad!", b"Mdog!", b""])),
+ b"",
+ id="ErrorResponse",
+ ),
+ pytest.param(
+ b"S\x00\x00\x00\x08a\x00b\x00",
+ dict(
+ type=pq3.types.ParameterStatus,
+ len=8,
+ payload=dict(name=b"a", value=b"b"),
+ ),
+ b"",
+ id="ParameterStatus",
+ ),
+ pytest.param(
+ b"Z\x00\x00\x00\x05x",
+ dict(type=b"Z", len=5, payload=dict(status=b"x")),
+ b"",
+ id="ReadyForQuery",
+ ),
+ pytest.param(
+ b"Q\x00\x00\x00\x06!\x00",
+ dict(type=pq3.types.Query, len=6, payload=dict(query=b"!")),
+ b"",
+ id="Query",
+ ),
+ pytest.param(
+ b"D\x00\x00\x00\x0B\x00\x01\x00\x00\x00\x01!",
+ dict(type=pq3.types.DataRow, len=11, payload=dict(columns=[b"!"])),
+ b"",
+ id="DataRow",
+ ),
+ pytest.param(
+ b"D\x00\x00\x00\x06\x00\x00extra",
+ dict(type=pq3.types.DataRow, len=6, payload=dict(columns=[])),
+ b"extra",
+ id="DataRow with extra data",
+ ),
+ pytest.param(
+ b"I\x00\x00\x00\x04",
+ dict(type=pq3.types.EmptyQueryResponse, len=4, payload=None),
+ b"",
+ id="EmptyQueryResponse",
+ ),
+ pytest.param(
+ b"I\x00\x00\x00\x04\xFF",
+ dict(type=b"I", len=4, payload=None),
+ b"\xFF",
+ id="EmptyQueryResponse with extra bytes",
+ ),
+ pytest.param(
+ b"X\x00\x00\x00\x04",
+ dict(type=pq3.types.Terminate, len=4, payload=None),
+ b"",
+ id="Terminate",
+ ),
+ ],
+)
+def test_Pq3_parse(raw, expected, extra):
+ with io.BytesIO(raw) as stream:
+ actual = pq3.Pq3.parse_stream(stream)
+
+ assert actual == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(type=b"*", len=5),
+ b"*\x00\x00\x00\x05\x00",
+ id="type and len set explicitly",
+ ),
+ pytest.param(
+ dict(type=b"*"),
+ b"*\x00\x00\x00\x04",
+ id="implied len with no payload",
+ ),
+ pytest.param(
+ dict(type=b"*", payload=b"1234"),
+ b"*\x00\x00\x00\x081234",
+ id="implied len with payload",
+ ),
+ pytest.param(
+ dict(type=pq3.types.AuthnRequest, payload=dict(type=pq3.authn.OK)),
+ b"R\x00\x00\x00\x08\x00\x00\x00\x00",
+ id="implied len/type for AuthenticationOK",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(
+ type=pq3.authn.SASL,
+ body=[b"SCRAM-SHA-256-PLUS", b"SCRAM-SHA-256", b""],
+ ),
+ ),
+ b"R\x00\x00\x00\x2A\x00\x00\x00\x0ASCRAM-SHA-256-PLUS\x00SCRAM-SHA-256\x00\x00",
+ id="implied len/type for AuthenticationSASL",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
+ ),
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
+ id="implied len/type for AuthenticationSASLContinue",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.AuthnRequest,
+ payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
+ ),
+ b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
+ id="implied len/type for AuthenticationSASLFinal",
+ ),
+ pytest.param(
+ dict(
+ type=pq3.types.PasswordMessage,
+ payload=b"hunter2",
+ ),
+ b"p\x00\x00\x00\x0Bhunter2",
+ id="implied len/type for PasswordMessage",
+ ),
+ pytest.param(
+ dict(type=pq3.types.BackendKeyData, payload=dict(pid=1, key=7)),
+ b"K\x00\x00\x00\x0C\x00\x00\x00\x01\x00\x00\x00\x07",
+ id="implied len/type for BackendKeyData",
+ ),
+ pytest.param(
+ dict(type=pq3.types.CommandComplete, payload=dict(tag=b"SET")),
+ b"C\x00\x00\x00\x08SET\x00",
+ id="implied len/type for CommandComplete",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ErrorResponse, payload=dict(fields=[b"error", b""])),
+ b"E\x00\x00\x00\x0Berror\x00\x00",
+ id="implied len/type for ErrorResponse",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ParameterStatus, payload=dict(name=b"a", value=b"b")),
+ b"S\x00\x00\x00\x08a\x00b\x00",
+ id="implied len/type for ParameterStatus",
+ ),
+ pytest.param(
+ dict(type=pq3.types.ReadyForQuery, payload=dict(status=b"I")),
+ b"Z\x00\x00\x00\x05I",
+ id="implied len/type for ReadyForQuery",
+ ),
+ pytest.param(
+ dict(type=pq3.types.Query, payload=dict(query=b"SELECT 1;")),
+ b"Q\x00\x00\x00\x0eSELECT 1;\x00",
+ id="implied len/type for Query",
+ ),
+ pytest.param(
+ dict(type=pq3.types.DataRow, payload=dict(columns=[b"abcd"])),
+ b"D\x00\x00\x00\x0E\x00\x01\x00\x00\x00\x04abcd",
+ id="implied len/type for DataRow",
+ ),
+ pytest.param(
+ dict(type=pq3.types.EmptyQueryResponse),
+ b"I\x00\x00\x00\x04",
+ id="implied len for EmptyQueryResponse",
+ ),
+ pytest.param(
+ dict(type=pq3.types.Terminate),
+ b"X\x00\x00\x00\x04",
+ id="implied len for Terminate",
+ ),
+ ],
+)
+def test_Pq3_build(fields, expected):
+ actual = pq3.Pq3.build(fields)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "raw,expected,extra",
+ [
+ pytest.param(
+ b"\x00\x00",
+ dict(columns=[]),
+ b"",
+ id="no columns",
+ ),
+ pytest.param(
+ b"\x00\x01\x00\x00\x00\x04abcd",
+ dict(columns=[b"abcd"]),
+ b"",
+ id="one column",
+ ),
+ pytest.param(
+ b"\x00\x02\x00\x00\x00\x04abcd\x00\x00\x00\x01x",
+ dict(columns=[b"abcd", b"x"]),
+ b"",
+ id="multiple columns",
+ ),
+ pytest.param(
+ b"\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01x",
+ dict(columns=[b"", b"x"]),
+ b"",
+ id="empty column value",
+ ),
+ pytest.param(
+ b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
+ dict(columns=[None, None]),
+ b"",
+ id="null columns",
+ ),
+ ],
+)
+def test_DataRow_parse(raw, expected, extra):
+ pkt = b"D" + struct.pack("!i", len(raw) + 4) + raw
+ with io.BytesIO(pkt) as stream:
+ actual = pq3.Pq3.parse_stream(stream)
+
+ assert actual.type == pq3.types.DataRow
+ assert actual.payload == expected
+ assert stream.read() == extra
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(),
+ b"\x00\x00",
+ id="no columns",
+ ),
+ pytest.param(
+ dict(columns=[None, None]),
+ b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
+ id="null columns",
+ ),
+ ],
+)
+def test_DataRow_build(fields, expected):
+ actual = pq3.Pq3.build(dict(type=pq3.types.DataRow, payload=fields))
+
+ expected = b"D" + struct.pack("!i", len(expected) + 4) + expected
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "raw,expected,exception",
+ [
+ pytest.param(
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ dict(name=b"EXTERNAL", len=-1, data=None),
+ None,
+ id="no initial response",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\x02me",
+ dict(name=b"EXTERNAL", len=2, data=b"me"),
+ None,
+ id="initial response",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\x02meextra",
+ None,
+ TerminatedError,
+ id="extra data",
+ ),
+ pytest.param(
+ b"EXTERNAL\x00\x00\x00\x00\xFFme",
+ None,
+ StreamError,
+ id="underflow",
+ ),
+ ],
+)
+def test_SASLInitialResponse_parse(raw, expected, exception):
+ ctx = contextlib.nullcontext()
+ if exception:
+ ctx = pytest.raises(exception)
+
+ with ctx:
+ actual = pq3.SASLInitialResponse.parse(raw)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "fields,expected",
+ [
+ pytest.param(
+ dict(name=b"EXTERNAL"),
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ id="no initial response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=None),
+ b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
+ id="no initial response (explicit None)",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=b""),
+ b"EXTERNAL\x00\x00\x00\x00\x00",
+ id="empty response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", data=b"me@example.com"),
+ b"EXTERNAL\x00\x00\x00\x00\x0Eme@example.com",
+ id="initial response",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", len=2, data=b"me@example.com"),
+ b"EXTERNAL\x00\x00\x00\x00\x02me@example.com",
+ id="data overflow",
+ ),
+ pytest.param(
+ dict(name=b"EXTERNAL", len=14, data=b"me"),
+ b"EXTERNAL\x00\x00\x00\x00\x0Eme",
+ id="data underflow",
+ ),
+ ],
+)
+def test_SASLInitialResponse_build(fields, expected):
+ actual = pq3.SASLInitialResponse.build(fields)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "version,expected_bytes",
+ [
+ pytest.param((3, 0), b"\x00\x03\x00\x00", id="version 3"),
+ pytest.param((1234, 5679), b"\x04\xd2\x16\x2f", id="SSLRequest"),
+ ],
+)
+def test_protocol(version, expected_bytes):
+ # Make sure the integer returned by protocol is correctly serialized on the
+ # wire.
+ assert struct.pack("!i", pq3.protocol(*version)) == expected_bytes
+
+
+@pytest.mark.parametrize(
+ "envvar,func,expected",
+ [
+ ("PGHOST", pq3.pghost, "localhost"),
+ ("PGPORT", pq3.pgport, 5432),
+ ("PGUSER", pq3.pguser, getpass.getuser()),
+ ("PGDATABASE", pq3.pgdatabase, "postgres"),
+ ],
+)
+def test_env_defaults(monkeypatch, envvar, func, expected):
+ monkeypatch.delenv(envvar, raising=False)
+
+ actual = func()
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ "envvars,func,expected",
+ [
+ (dict(PGHOST="otherhost"), pq3.pghost, "otherhost"),
+ (dict(PGPORT="6789"), pq3.pgport, 6789),
+ (dict(PGUSER="postgres"), pq3.pguser, "postgres"),
+ (dict(PGDATABASE="template1"), pq3.pgdatabase, "template1"),
+ ],
+)
+def test_env(monkeypatch, envvars, func, expected):
+ for k, v in envvars.items():
+ monkeypatch.setenv(k, v)
+
+ actual = func()
+ assert actual == expected
diff --git a/src/test/python/tls.py b/src/test/python/tls.py
new file mode 100644
index 0000000000..075c02c1ca
--- /dev/null
+++ b/src/test/python/tls.py
@@ -0,0 +1,195 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+from construct import *
+
+#
+# TLS 1.3
+#
+# Most of the types below are transcribed from RFC 8446:
+#
+# https://tools.ietf.org/html/rfc8446
+#
+
+
+def _Vector(size_field, element):
+ return Prefixed(size_field, GreedyRange(element))
+
+
+# Alerts
+
+AlertLevel = Enum(
+ Byte,
+ warning=1,
+ fatal=2,
+)
+
+AlertDescription = Enum(
+ Byte,
+ close_notify=0,
+ unexpected_message=10,
+ bad_record_mac=20,
+ decryption_failed_RESERVED=21,
+ record_overflow=22,
+ decompression_failure=30,
+ handshake_failure=40,
+ no_certificate_RESERVED=41,
+ bad_certificate=42,
+ unsupported_certificate=43,
+ certificate_revoked=44,
+ certificate_expired=45,
+ certificate_unknown=46,
+ illegal_parameter=47,
+ unknown_ca=48,
+ access_denied=49,
+ decode_error=50,
+ decrypt_error=51,
+ export_restriction_RESERVED=60,
+ protocol_version=70,
+ insufficient_security=71,
+ internal_error=80,
+ user_canceled=90,
+ no_renegotiation=100,
+ unsupported_extension=110,
+)
+
+Alert = Struct(
+ "level" / AlertLevel,
+ "description" / AlertDescription,
+)
+
+
+# Extensions
+
+ExtensionType = Enum(
+ Int16ub,
+ server_name=0,
+ max_fragment_length=1,
+ status_request=5,
+ supported_groups=10,
+ signature_algorithms=13,
+ use_srtp=14,
+ heartbeat=15,
+ application_layer_protocol_negotiation=16,
+ signed_certificate_timestamp=18,
+ client_certificate_type=19,
+ server_certificate_type=20,
+ padding=21,
+ pre_shared_key=41,
+ early_data=42,
+ supported_versions=43,
+ cookie=44,
+ psk_key_exchange_modes=45,
+ certificate_authorities=47,
+ oid_filters=48,
+ post_handshake_auth=49,
+ signature_algorithms_cert=50,
+ key_share=51,
+)
+
+Extension = Struct(
+ "extension_type" / ExtensionType,
+ "extension_data" / Prefixed(Int16ub, GreedyBytes),
+)
+
+
+# ClientHello
+
+
+class _CipherSuiteAdapter(Adapter):
+ class _hextuple(tuple):
+ def __repr__(self):
+ return f"(0x{self[0]:02X}, 0x{self[1]:02X})"
+
+ def _encode(self, obj, context, path):
+ return bytes(obj)
+
+ def _decode(self, obj, context, path):
+ assert len(obj) == 2
+ return self._hextuple(obj)
+
+
+ProtocolVersion = Hex(Int16ub)
+
+Random = Hex(Bytes(32))
+
+CipherSuite = _CipherSuiteAdapter(Byte[2])
+
+ClientHello = Struct(
+ "legacy_version" / ProtocolVersion,
+ "random" / Random,
+ "legacy_session_id" / Prefixed(Byte, Hex(GreedyBytes)),
+ "cipher_suites" / _Vector(Int16ub, CipherSuite),
+ "legacy_compression_methods" / Prefixed(Byte, GreedyBytes),
+ "extensions" / _Vector(Int16ub, Extension),
+)
+
+# ServerHello
+
+ServerHello = Struct(
+ "legacy_version" / ProtocolVersion,
+ "random" / Random,
+ "legacy_session_id_echo" / Prefixed(Byte, Hex(GreedyBytes)),
+ "cipher_suite" / CipherSuite,
+ "legacy_compression_method" / Hex(Byte),
+ "extensions" / _Vector(Int16ub, Extension),
+)
+
+# Handshake
+
+HandshakeType = Enum(
+ Byte,
+ client_hello=1,
+ server_hello=2,
+ new_session_ticket=4,
+ end_of_early_data=5,
+ encrypted_extensions=8,
+ certificate=11,
+ certificate_request=13,
+ certificate_verify=15,
+ finished=20,
+ key_update=24,
+ message_hash=254,
+)
+
+Handshake = Struct(
+ "msg_type" / HandshakeType,
+ "length" / Int24ub,
+ "payload"
+ / Switch(
+ this.msg_type,
+ {
+ HandshakeType.client_hello: ClientHello,
+ HandshakeType.server_hello: ServerHello,
+ # HandshakeType.end_of_early_data: EndOfEarlyData,
+ # HandshakeType.encrypted_extensions: EncryptedExtensions,
+ # HandshakeType.certificate_request: CertificateRequest,
+ # HandshakeType.certificate: Certificate,
+ # HandshakeType.certificate_verify: CertificateVerify,
+ # HandshakeType.finished: Finished,
+ # HandshakeType.new_session_ticket: NewSessionTicket,
+ # HandshakeType.key_update: KeyUpdate,
+ },
+ default=FixedSized(this.length, GreedyBytes),
+ ),
+)
+
+# Records
+
+ContentType = Enum(
+ Byte,
+ invalid=0,
+ change_cipher_spec=20,
+ alert=21,
+ handshake=22,
+ application_data=23,
+)
+
+Plaintext = Struct(
+ "type" / ContentType,
+ "legacy_record_version" / ProtocolVersion,
+ "length" / Int16ub,
+ "fragment" / FixedSized(this.length, GreedyBytes),
+)
--
2.25.1
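As a side note for reviewers following the pq3 tests above: the DataRow wire format they exercise (a big-endian int16 column count, then an int32 length per column with -1 marking NULL, followed by the column bytes, all wrapped in the type byte and packet length) can be sketched with nothing but the standard library. This is an illustrative re-derivation of what the tests assert, not part of the patchset:

```python
import struct

def build_datarow(columns):
    """Build a DataRow ('D') message from a list of bytes-or-None columns."""
    body = struct.pack("!h", len(columns))
    for col in columns:
        if col is None:
            body += struct.pack("!i", -1)  # a -1 length marks a NULL column
        else:
            body += struct.pack("!i", len(col)) + col
    # The packet length covers the 4-byte length field itself, but not the
    # leading type byte.
    return b"D" + struct.pack("!i", len(body) + 4) + body

# Matches the "one column" case in test_DataRow_parse:
assert build_datarow([b"abcd"]) == b"D\x00\x00\x00\x0e\x00\x01\x00\x00\x00\x04abcd"

# Matches the "null columns" case:
assert build_datarow([None, None]) == (
    b"D\x00\x00\x00\x0e\x00\x02\xff\xff\xff\xff\xff\xff\xff\xff"
)
```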
v4-0010-contrib-oauth-switch-to-pluggable-auth-API.patch (text/x-patch)
From f520c08e1aee5239051e304c8a8faf5cb25bdbf2 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Fri, 25 Mar 2022 16:35:30 -0700
Subject: [PATCH v4 10/10] contrib/oauth: switch to pluggable auth API
Move the core server implementation to contrib/oauth as a pluggable
provider, using the RegisterAuthProvider() API. oauth_validator_command
has been moved from core to a custom GUC. HBA options are handled
using the new hook. Tests have been updated to handle the new
implementations.
One server modification remains: allowing custom SASL mechanisms to
declare their own maximum message length.
This patch is optional; you can apply/revert it to compare the two
approaches.
---
contrib/oauth/Makefile | 16 ++++
.../auth-oauth.c => contrib/oauth/oauth.c | 88 ++++++++++++++++---
src/backend/libpq/Makefile | 1 -
src/backend/libpq/auth.c | 7 --
src/backend/libpq/hba.c | 27 +-----
src/backend/utils/misc/guc.c | 12 ---
src/include/libpq/hba.h | 6 +-
src/include/libpq/oauth.h | 24 -----
src/test/python/README | 3 +-
src/test/python/server/test_oauth.py | 20 ++---
10 files changed, 104 insertions(+), 100 deletions(-)
create mode 100644 contrib/oauth/Makefile
rename src/backend/libpq/auth-oauth.c => contrib/oauth/oauth.c (90%)
delete mode 100644 src/include/libpq/oauth.h
diff --git a/contrib/oauth/Makefile b/contrib/oauth/Makefile
new file mode 100644
index 0000000000..880bc1fef3
--- /dev/null
+++ b/contrib/oauth/Makefile
@@ -0,0 +1,16 @@
+# contrib/oauth/Makefile
+
+MODULE_big = oauth
+OBJS = oauth.o
+PGFILEDESC = "oauth - auth provider supporting OAuth 2.0/OIDC"
+
+ifdef USE_PGXS
+PG_CONFIG = pg_config
+PGXS := $(shell $(PG_CONFIG) --pgxs)
+include $(PGXS)
+else
+subdir = contrib/oauth
+top_builddir = ../..
+include $(top_builddir)/src/Makefile.global
+include $(top_srcdir)/contrib/contrib-global.mk
+endif
diff --git a/src/backend/libpq/auth-oauth.c b/contrib/oauth/oauth.c
similarity index 90%
rename from src/backend/libpq/auth-oauth.c
rename to contrib/oauth/oauth.c
index c1232a31a0..e83f3c5d99 100644
--- a/src/backend/libpq/auth-oauth.c
+++ b/contrib/oauth/oauth.c
@@ -1,33 +1,39 @@
-/*-------------------------------------------------------------------------
+/* -------------------------------------------------------------------------
*
- * auth-oauth.c
+ * oauth.c
* Server-side implementation of the SASL OAUTHBEARER mechanism.
*
* See the following RFC for more details:
* - RFC 7628: https://tools.ietf.org/html/rfc7628
*
- * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1996-2022, PostgreSQL Global Development Group
* Portions Copyright (c) 1994, Regents of the University of California
*
- * src/backend/libpq/auth-oauth.c
+ * contrib/oauth/oauth.c
*
- *-------------------------------------------------------------------------
+ * -------------------------------------------------------------------------
*/
+
#include "postgres.h"
#include <unistd.h>
#include <fcntl.h>
#include "common/oauth-common.h"
+#include "fmgr.h"
#include "lib/stringinfo.h"
#include "libpq/auth.h"
#include "libpq/hba.h"
-#include "libpq/oauth.h"
#include "libpq/sasl.h"
#include "storage/fd.h"
+#include "utils/guc.h"
+
+PG_MODULE_MAGIC;
+
+void _PG_init(void);
/* GUC */
-char *oauth_validator_command;
+static char *oauth_validator_command;
static void oauth_get_mechanisms(Port *port, StringInfo buf);
static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
@@ -35,7 +41,7 @@ static int oauth_exchange(void *opaq, const char *input, int inputlen,
char **output, int *outputlen, const char **logdetail);
/* Mechanism declaration */
-const pg_be_sasl_mech pg_be_oauth_mech = {
+static const pg_be_sasl_mech oauth_mech = {
oauth_get_mechanisms,
oauth_init,
oauth_exchange,
@@ -57,12 +63,13 @@ struct oauth_ctx
Port *port;
const char *issuer;
const char *scope;
+ bool skip_usermap;
};
static char *sanitize_char(char c);
static char *parse_kvpairs_for_auth(char **input);
static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
-static bool validate(Port *port, const char *auth, const char **logdetail);
+static bool validate(struct oauth_ctx *ctx, const char *auth, const char **logdetail);
static bool run_validator_command(Port *port, const char *token);
static bool check_exit(FILE **fh, const char *command);
static bool unset_cloexec(int fd);
@@ -84,6 +91,7 @@ static void *
oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
{
struct oauth_ctx *ctx;
+ ListCell *lc;
if (strcmp(selected_mech, OAUTHBEARER_NAME))
ereport(ERROR,
@@ -96,8 +104,21 @@ oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
ctx->port = port;
Assert(port->hba);
- ctx->issuer = port->hba->oauth_issuer;
- ctx->scope = port->hba->oauth_scope;
+
+ foreach (lc, port->hba->custom_auth_options)
+ {
+ CustomOption *option = lfirst(lc);
+
+ if (strcmp(option->name, "issuer") == 0)
+ ctx->issuer = option->value;
+ else if (strcmp(option->name, "scope") == 0)
+ ctx->scope = option->value;
+ else if (strcmp(option->name, "trust_validator_authz") == 0)
+ {
+ if (strcmp(option->value, "1") == 0)
+ ctx->skip_usermap = true;
+ }
+ }
return ctx;
}
@@ -248,7 +269,7 @@ oauth_exchange(void *opaq, const char *input, int inputlen,
errmsg("malformed OAUTHBEARER message"),
errdetail("Message contains additional data after the final terminator.")));
- if (!validate(ctx->port, auth, logdetail))
+ if (!validate(ctx, auth, logdetail))
{
generate_error_response(ctx, output, outputlen);
@@ -415,12 +436,13 @@ generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
}
static bool
-validate(Port *port, const char *auth, const char **logdetail)
+validate(struct oauth_ctx *ctx, const char *auth, const char **logdetail)
{
static const char * const b64_set = "abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"0123456789-._~+/";
+ Port *port = ctx->port;
const char *token;
size_t span;
int ret;
@@ -497,7 +519,7 @@ validate(Port *port, const char *auth, const char **logdetail)
if (!run_validator_command(port, token))
return false;
- if (port->hba->oauth_skip_usermap)
+ if (ctx->skip_usermap)
{
/*
* If the validator is our authorization authority, we're done.
@@ -795,3 +817,41 @@ username_ok_for_shell(const char *username)
return true;
}
+
+static int CheckOAuth(Port *port)
+{
+ return CheckSASLAuth(&oauth_mech, port, NULL, NULL);
+}
+
+static const char *OAuthError(Port *port)
+{
+ return psprintf("OAuth bearer authentication failed for user \"%s\"",
+ port->user_name);
+}
+
+static bool OAuthCheckOption(char *name, char *val,
+ struct HbaLine *hbaline, char **errmsg)
+{
+ if (!strcmp(name, "issuer"))
+ return true;
+ if (!strcmp(name, "scope"))
+ return true;
+ if (!strcmp(name, "trust_validator_authz"))
+ return true;
+
+ return false;
+}
+
+void
+_PG_init(void)
+{
+ RegisterAuthProvider("oauth", CheckOAuth, OAuthError, OAuthCheckOption);
+
+ DefineCustomStringVariable("oauth.validator_command",
+ gettext_noop("Command to validate OAuth v2 bearer tokens."),
+ NULL,
+ &oauth_validator_command,
+ "",
+ PGC_SIGHUP, GUC_SUPERUSER_ONLY,
+ NULL, NULL, NULL);
+}
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 98eb2a8242..6d385fd6a4 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,7 +15,6 @@ include $(top_builddir)/src/Makefile.global
# be-fsstubs is here for historical reasons, probably belongs elsewhere
OBJS = \
- auth-oauth.o \
auth-sasl.o \
auth-scram.o \
auth.o \
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 17042d84ad..4a8a63922a 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -30,7 +30,6 @@
#include "libpq/auth.h"
#include "libpq/crypt.h"
#include "libpq/libpq.h"
-#include "libpq/oauth.h"
#include "libpq/pqformat.h"
#include "libpq/sasl.h"
#include "libpq/scram.h"
@@ -299,9 +298,6 @@ auth_failed(Port *port, int status, const char *logdetail)
case uaRADIUS:
errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
break;
- case uaOAuth:
- errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
- break;
case uaCustom:
{
CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
@@ -630,9 +626,6 @@ ClientAuthentication(Port *port)
case uaTrust:
status = STATUS_OK;
break;
- case uaOAuth:
- status = CheckSASLAuth(&pg_be_oauth_mech, port, NULL, NULL);
- break;
case uaCustom:
{
CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index cd3b1cc140..6bf986d5b3 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -137,7 +137,6 @@ static const char *const UserAuthName[] =
"radius",
"custom",
"peer",
- "oauth",
};
@@ -1402,8 +1401,6 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
#endif
else if (strcmp(token->string, "radius") == 0)
parsedline->auth_method = uaRADIUS;
- else if (strcmp(token->string, "oauth") == 0)
- parsedline->auth_method = uaOAuth;
else if (strcmp(token->string, "custom") == 0)
parsedline->auth_method = uaCustom;
else
@@ -1733,9 +1730,8 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->auth_method != uaGSS &&
hbaline->auth_method != uaSSPI &&
hbaline->auth_method != uaCert &&
- hbaline->auth_method != uaOAuth &&
hbaline->auth_method != uaCustom)
- INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, oauth, and custom"));
+ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, and custom"));
hbaline->usermap = pstrdup(val);
}
else if (strcmp(name, "clientcert") == 0)
@@ -2119,27 +2115,6 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
hbaline->radiusidentifiers = parsed_identifiers;
hbaline->radiusidentifiers_s = pstrdup(val);
}
- else if (strcmp(name, "issuer") == 0)
- {
- if (hbaline->auth_method != uaOAuth)
- INVALID_AUTH_OPTION("issuer", gettext_noop("oauth"));
- hbaline->oauth_issuer = pstrdup(val);
- }
- else if (strcmp(name, "scope") == 0)
- {
- if (hbaline->auth_method != uaOAuth)
- INVALID_AUTH_OPTION("scope", gettext_noop("oauth"));
- hbaline->oauth_scope = pstrdup(val);
- }
- else if (strcmp(name, "trust_validator_authz") == 0)
- {
- if (hbaline->auth_method != uaOAuth)
- INVALID_AUTH_OPTION("trust_validator_authz", gettext_noop("oauth"));
- if (strcmp(val, "1") == 0)
- hbaline->oauth_skip_usermap = true;
- else
- hbaline->oauth_skip_usermap = false;
- }
else if (strcmp(name, "provider") == 0)
{
REQUIRE_AUTH_OPTION(uaCustom, "provider", "custom");
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index 9a5b2aa496..f70f7f5c01 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -59,7 +59,6 @@
#include "libpq/auth.h"
#include "libpq/libpq.h"
#include "libpq/pqformat.h"
-#include "libpq/oauth.h"
#include "miscadmin.h"
#include "optimizer/cost.h"
#include "optimizer/geqo.h"
@@ -4667,17 +4666,6 @@ static struct config_string ConfigureNamesString[] =
check_backtrace_functions, assign_backtrace_functions, NULL
},
- {
- {"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH,
- gettext_noop("Command to validate OAuth v2 bearer tokens."),
- NULL,
- GUC_SUPERUSER_ONLY
- },
- &oauth_validator_command,
- "",
- NULL, NULL, NULL
- },
-
/* End-of-list marker */
{
{NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index e405103a2e..bbc94363cb 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -40,8 +40,7 @@ typedef enum UserAuth
uaRADIUS,
uaCustom,
uaPeer,
- uaOAuth
-#define USER_AUTH_LAST uaOAuth /* Must be last value of this enum */
+#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */
} UserAuth;
/*
@@ -129,9 +128,6 @@ typedef struct HbaLine
char *radiusidentifiers_s;
List *radiusports;
char *radiusports_s;
- char *oauth_issuer;
- char *oauth_scope;
- bool oauth_skip_usermap;
char *custom_provider;
List *custom_auth_options;
} HbaLine;
diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h
deleted file mode 100644
index 870e426af1..0000000000
--- a/src/include/libpq/oauth.h
+++ /dev/null
@@ -1,24 +0,0 @@
-/*-------------------------------------------------------------------------
- *
- * oauth.h
- * Interface to libpq/auth-oauth.c
- *
- * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
- * Portions Copyright (c) 1994, Regents of the University of California
- *
- * src/include/libpq/oauth.h
- *
- *-------------------------------------------------------------------------
- */
-#ifndef PG_OAUTH_H
-#define PG_OAUTH_H
-
-#include "libpq/libpq-be.h"
-#include "libpq/sasl.h"
-
-extern char *oauth_validator_command;
-
-/* Implementation */
-extern const pg_be_sasl_mech pg_be_oauth_mech;
-
-#endif /* PG_OAUTH_H */
diff --git a/src/test/python/README b/src/test/python/README
index 0bda582c4b..0fbc1046cf 100644
--- a/src/test/python/README
+++ b/src/test/python/README
@@ -13,7 +13,8 @@ but you can adjust as needed for your setup.
## Requirements
-A supported version (3.6+) of Python.
+- A supported version (3.6+) of Python.
+- The oauth extension must be installed and loaded via shared_preload_libraries.
The first run of
diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py
index cb5ca7fa23..07fc25edc2 100644
--- a/src/test/python/server/test_oauth.py
+++ b/src/test/python/server/test_oauth.py
@@ -103,9 +103,9 @@ def oauth_ctx():
ctx = Context()
hba_lines = (
- f'host {ctx.dbname} {ctx.map_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
- f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
- f'host {ctx.dbname} all samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
+ f'host {ctx.dbname} {ctx.map_user} samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
+ f'host {ctx.dbname} {ctx.authz_user} samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
+ f'host {ctx.dbname} all samehost custom provider=oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
)
ident_lines = (r"oauth /^(.*)@example\.com$ \1",)
@@ -126,12 +126,12 @@ def oauth_ctx():
c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user))
c.execute(sql.SQL("CREATE DATABASE {};").format(dbname))
- # Make this test script the server's oauth_validator.
+ # Make this test script the server's oauth validator.
path = pathlib.Path(__file__).parent / "validate_bearer.py"
path = str(path.absolute())
cmd = f"{shlex.quote(path)} {SHARED_MEM_NAME} <&%f"
- c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+ c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (cmd,))
# Replace pg_hba and pg_ident.
c.execute("SHOW hba_file;")
@@ -149,7 +149,7 @@ def oauth_ctx():
# Put things back the way they were.
c.execute("SELECT pg_reload_conf();")
- c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+ c.execute("ALTER SYSTEM RESET oauth.validator_command;")
c.execute(sql.SQL("DROP DATABASE {};").format(dbname))
c.execute(sql.SQL("DROP ROLE {};").format(authz_user))
c.execute(sql.SQL("DROP ROLE {};").format(map_user))
@@ -930,7 +930,7 @@ def test_oauth_empty_initial_response(conn, oauth_ctx, bearer_token):
def set_validator():
"""
A per-test fixture that allows a test to override the setting of
- oauth_validator_command for the cluster. The setting will be reverted during
+ oauth.validator_command for the cluster. The setting will be reverted during
teardown.
Passing None will perform an ALTER SYSTEM RESET.
@@ -942,17 +942,17 @@ def set_validator():
c = conn.cursor()
# Save the previous value.
- c.execute("SHOW oauth_validator_command;")
+ c.execute("SHOW oauth.validator_command;")
prev_cmd = c.fetchone()[0]
def setter(cmd):
- c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+ c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (cmd,))
c.execute("SELECT pg_reload_conf();")
yield setter
# Restore the previous value.
- c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (prev_cmd,))
+ c.execute("ALTER SYSTEM SET oauth.validator_command TO %s;", (prev_cmd,))
c.execute("SELECT pg_reload_conf();")
--
2.25.1
Hi Hackers,
We are trying to implement AAD (Azure AD) support in PostgreSQL, which
can be achieved with support for the OAuth method. To support AAD on
top of OAuth in a generic fashion (i.e., for any other OAuth provider),
we are proposing this patch. It exposes two new hooks (one for error
reporting and one for OAuth provider-specific token validation) and
passes the OAuth bearer token through to the backend. It also adds
support for the OAuth client credentials flow, in addition to the
device code flow which Jacob has proposed.
The changes for each component are summarized below.
1. Provider-specific extension:
Each OAuth provider implements their own token validator as an
extension. Extension registers an OAuth provider hook which is matched
to a line in the HBA file.
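Under this proposal, a provider extension would look roughly like the sketch below. It is illustrative only: it compiles only against a server patched with the attached RegisterOAuthProvider() API, the error hook's exact signature is assumed from context, and the validator body (check_aad_token) is entirely hypothetical. A real validator would verify the token's signature, issuer, and audience.

```c
/* Hypothetical OAuth provider extension using the proposed hook API. */
#include "postgres.h"
#include "fmgr.h"
#include "libpq/oauth.h"		/* RegisterOAuthProvider(), per this patch */

PG_MODULE_MAGIC;

void		_PG_init(void);

/*
 * Validate the bearer token and return the authenticated identity, or
 * NULL on failure. The actual AAD logic (JWT signature, issuer, and
 * audience checks) is omitted here.
 */
static char *
check_aad_token(Port *port, const char *token)
{
	/* ...verify the token against the provider; placeholder only... */
	return NULL;
}

/* Assumed signature: produce the auth-failure message for this provider. */
static const char *
aad_error(Port *port)
{
	return "AAD bearer token validation failed";
}

void
_PG_init(void)
{
	/* Must be loaded via shared_preload_libraries. */
	RegisterOAuthProvider("aad", check_aad_token, aad_error);
}
```

The provider name passed to RegisterOAuthProvider() is what a pg_hba.conf line would then reference via the provider option.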
2. Bearer token passthrough: add support for passing on the OAuth
bearer token. In this mode, obtaining the bearer token is left to a
third-party application or the user:
./psql -U <username> -d 'dbname=postgres
oauth_client_id=<client_id> oauth_bearer_token=<token>'
3. HBA: an additional parameter 'provider' is added for the oauth
method, defining "oauth" as the method and passing the provider, the
issuer endpoint, and the expected audience:
* * * * oauth provider=<token validation extension>
issuer=.... scope=....
4. Engine backend: support for the generic OAUTHBEARER mechanism,
requesting that the client provide a token and passing the token on to
the provider-specific extension.
5. Engine frontend: a two-tiered approach.
        a) libpq transparently passes the token received
from a third-party client, as is, to the backend.
        b) libpq can optionally be compiled for clients that
explicitly need libpq to orchestrate the OAuth communication with the
issuer. (This depends heavily on the third-party iddawc library, as
Jacob already pointed out; the library appears to support all of the
OAuth flows.)
Please let us know your thoughts; the proposed method supports
different OAuth flows through the use of provider-specific hooks. We
think the proposal would be useful for a variety of OAuth providers.
Thanks,
Mahendrakar.
On Tue, 20 Sept 2022 at 10:18, Jacob Champion <pchampion@vmware.com> wrote:
On Tue, 2021-06-22 at 23:22 +0000, Jacob Champion wrote:
On Fri, 2021-06-18 at 11:31 +0300, Heikki Linnakangas wrote:
A few small things caught my eye in the backend oauth_exchange function:
    + /* Handle the client's initial message. */
    + p = strdup(input);

this strdup() should be pstrdup().
Thanks, I'll fix that in the next re-roll.
In the same function, there are a bunch of reports like this:
    ereport(ERROR,
    + (errcode(ERRCODE_PROTOCOL_VIOLATION),
    + errmsg("malformed OAUTHBEARER message"),
    + errdetail("Comma expected, but found character \"%s\".",
    + sanitize_char(*p))));

I don't think the double quotes are needed here, because sanitize_char
will return quotes if it's a single character. So it would end up
looking like this: ... found character "'x'".

I'll fix this too. Thanks!
v2, attached, incorporates Heikki's suggested fixes and also rebases on
top of latest HEAD, which had the SASL refactoring changes committed
last month.

The biggest change from the last patchset is 0001, an attempt at
enabling jsonapi in the frontend without the use of palloc(), based on
suggestions by Michael and Tom from last commitfest. I've also made
some improvements to the pytest suite. No major changes to the OAuth
implementation yet.

--Jacob
Attachments:
v1-0001-oauth-provider-support.patch (application/x-patch)
diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c
index c47211132c..86f820482b 100644
--- a/src/backend/libpq/auth-oauth.c
+++ b/src/backend/libpq/auth-oauth.c
@@ -24,7 +24,9 @@
#include "libpq/hba.h"
#include "libpq/oauth.h"
#include "libpq/sasl.h"
+#include "miscadmin.h"
#include "storage/fd.h"
+#include "utils/memutils.h"
/* GUC */
char *oauth_validator_command;
@@ -34,6 +36,13 @@ static void *oauth_init(Port *port, const char *selected_mech, const char *shado
static int oauth_exchange(void *opaq, const char *input, int inputlen,
char **output, int *outputlen, char **logdetail);
+/*----------------------------------------------------------------
+ * OAuth Authentication
+ *----------------------------------------------------------------
+ */
+static List *oauth_providers = NIL;
+static OAuthProvider* oauth_provider = NULL;
+
/* Mechanism declaration */
const pg_be_sasl_mech pg_be_oauth_mech = {
oauth_get_mechanisms,
@@ -63,15 +72,90 @@ static char *sanitize_char(char c);
static char *parse_kvpairs_for_auth(char **input);
static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
static bool validate(Port *port, const char *auth, char **logdetail);
-static bool run_validator_command(Port *port, const char *token);
+static const char* run_validator_command(Port *port, const char *token);
static bool check_exit(FILE **fh, const char *command);
static bool unset_cloexec(int fd);
-static bool username_ok_for_shell(const char *username);
#define KVSEP 0x01
#define AUTH_KEY "auth"
#define BEARER_SCHEME "Bearer "
+/*----------------------------------------------------------------
+ * OAuth Token Validator
+ *----------------------------------------------------------------
+ */
+
+/*
+ * RegisterOAuthProvider registers a OAuth Token Validator to be
+ * used for oauth token validation. It validates the token and adds the valiator
+ * name and it's hooks to a list of loaded token validator. The right validator's
+ * hooks can then be called based on the validator name specified in
+ * pg_hba.conf.
+ *
+ * This function should be called in _PG_init() by any extension looking to
+ * add a custom authentication method.
+ */
+void
+RegisterOAuthProvider(
+ const char *provider_name,
+ OAuthProviderCheck_hook_type OAuthProviderCheck_hook,
+ OAuthProviderError_hook_type OAuthProviderError_hook
+)
+{
+ if (!process_shared_preload_libraries_in_progress)
+ {
+ ereport(ERROR,
+ (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
+ errmsg("RegisterOAuthProvider can only be called by a shared_preload_library")));
+ return;
+ }
+
+ MemoryContext oldcxt;
+ if (oauth_provider == NULL)
+ {
+ oldcxt = MemoryContextSwitchTo(TopMemoryContext);
+ oauth_provider = palloc(sizeof(OAuthProvider));
+ oauth_provider->name = pstrdup(provider_name);
+ oauth_provider->oauth_provider_hook = OAuthProviderCheck_hook;
+ oauth_provider->oauth_error_hook = OAuthProviderError_hook;
+ oauth_providers = lappend(oauth_providers, oauth_provider);
+ MemoryContextSwitchTo(oldcxt);
+ }
+ else
+ {
+ if (oauth_provider && oauth_provider->name)
+ {
+ ereport(ERROR,
+ (errmsg("OAuth provider \"%s\" is already loaded.",
+ oauth_provider->name)));
+ }
+ else
+ {
+ ereport(ERROR,
+ (errmsg("OAuth provider is already loaded.")));
+ }
+ }
+}
+
+/*
+ * Returns the oauth provider (which includes it's
+ * callback functions) based on name specified.
+ */
+OAuthProvider *get_provider_by_name(const char *name)
+{
+ ListCell *lc;
+ foreach(lc, oauth_providers)
+ {
+ OAuthProvider *provider = (OAuthProvider *) lfirst(lc);
+ if (strcmp(provider->name, name) == 0)
+ {
+ return provider;
+ }
+ }
+
+ return NULL;
+}
+
static void
oauth_get_mechanisms(Port *port, StringInfo buf)
{
@@ -494,17 +578,17 @@ validate(Port *port, const char *auth, char **logdetail)
}
/* Have the validator check the token. */
- if (!run_validator_command(port, token))
+ if (run_validator_command(port, token) == NULL)
return false;
-
+
if (port->hba->oauth_skip_usermap)
{
/*
- * If the validator is our authorization authority, we're done.
- * Authentication may or may not have been performed depending on the
- * validator implementation; all that matters is that the validator says
- * the user can log in with the target role.
- */
+ * If the validator is our authorization authority, we're done.
+ * Authentication may or may not have been performed depending on the
+ * validator implementation; all that matters is that the validator says
+ * the user can log in with the target role.
+ */
return true;
}
@@ -524,193 +608,26 @@ validate(Port *port, const char *auth, char **logdetail)
return (ret == STATUS_OK);
}
-static bool
+static const char*
run_validator_command(Port *port, const char *token)
{
- bool success = false;
- int rc;
- int pipefd[2];
- int rfd = -1;
- int wfd = -1;
-
- StringInfoData command = { 0 };
- char *p;
- FILE *fh = NULL;
-
- ssize_t written;
- char *line = NULL;
- size_t size = 0;
- ssize_t len;
-
- Assert(oauth_validator_command);
-
- if (!oauth_validator_command[0])
- {
- ereport(COMMERROR,
- (errmsg("oauth_validator_command is not set"),
- errhint("To allow OAuth authenticated connections, set "
- "oauth_validator_command in postgresql.conf.")));
- return false;
- }
-
- /*
- * Since popen() is unidirectional, open up a pipe for the other direction.
- * Use CLOEXEC to ensure that our write end doesn't accidentally get copied
- * into child processes, which would prevent us from closing it cleanly.
- *
- * XXX this is ugly. We should just read from the child process's stdout,
- * but that's a lot more code.
- * XXX by bypassing the popen API, we open the potential of process
- * deadlock. Clearly document child process requirements (i.e. the child
- * MUST read all data off of the pipe before writing anything).
- * TODO: port to Windows using _pipe().
- */
- rc = pipe2(pipefd, O_CLOEXEC);
- if (rc < 0)
+ if(oauth_provider->oauth_provider_hook == NULL)
{
- ereport(COMMERROR,
- (errcode_for_file_access(),
- errmsg("could not create child pipe: %m")));
return false;
}
- rfd = pipefd[0];
- wfd = pipefd[1];
-
- /* Allow the read pipe be passed to the child. */
- if (!unset_cloexec(rfd))
+ char *id = oauth_provider->
+ oauth_provider_hook(port, token);
+ if(id == NULL)
{
- /* error message was already logged */
- goto cleanup;
- }
-
- /*
- * Construct the command, substituting any recognized %-specifiers:
- *
- * %f: the file descriptor of the input pipe
- * %r: the role that the client wants to assume (port->user_name)
- * %%: a literal '%'
- */
- initStringInfo(&command);
-
- for (p = oauth_validator_command; *p; p++)
- {
- if (p[0] == '%')
- {
- switch (p[1])
- {
- case 'f':
- appendStringInfo(&command, "%d", rfd);
- p++;
- break;
- case 'r':
- /*
- * TODO: decide how this string should be escaped. The role
- * is controlled by the client, so if we don't escape it,
- * command injections are inevitable.
- *
- * This is probably an indication that the role name needs
- * to be communicated to the validator process in some other
- * way. For this proof of concept, just be incredibly strict
- * about the characters that are allowed in user names.
- */
- if (!username_ok_for_shell(port->user_name))
- goto cleanup;
-
- appendStringInfoString(&command, port->user_name);
- p++;
- break;
- case '%':
- appendStringInfoChar(&command, '%');
- p++;
- break;
- default:
- appendStringInfoChar(&command, p[0]);
- }
- }
- else
- appendStringInfoChar(&command, p[0]);
- }
-
- /* Execute the command. */
- fh = OpenPipeStream(command.data, "re");
- /* TODO: handle failures */
-
- /* We don't need the read end of the pipe anymore. */
- close(rfd);
- rfd = -1;
-
- /* Give the command the token to validate. */
- written = write(wfd, token, strlen(token));
- if (written != strlen(token))
- {
- /* TODO must loop for short writes, EINTR et al */
- ereport(COMMERROR,
- (errcode_for_file_access(),
- errmsg("could not write token to child pipe: %m")));
- goto cleanup;
- }
-
- close(wfd);
- wfd = -1;
-
- /*
- * Read the command's response.
- *
- * TODO: getline() is probably too new to use, unfortunately.
- * TODO: loop over all lines
- */
- if ((len = getline(&line, &size, fh)) >= 0)
- {
- /* TODO: fail if the authn_id doesn't end with a newline */
- if (len > 0)
- line[len - 1] = '\0';
-
- set_authn_id(port, line);
- }
- else if (ferror(fh))
- {
- ereport(COMMERROR,
- (errcode_for_file_access(),
- errmsg("could not read from command \"%s\": %m",
- command.data)));
- goto cleanup;
- }
-
- /* Make sure the command exits cleanly. */
- if (!check_exit(&fh, command.data))
- {
- /* error message already logged */
- goto cleanup;
- }
-
- /* Done. */
- success = true;
-
-cleanup:
- if (line)
- free(line);
-
- /*
- * In the successful case, the pipe fds are already closed. For the error
- * case, always close out the pipe before waiting for the command, to
- * prevent deadlock.
- */
- if (rfd >= 0)
- close(rfd);
- if (wfd >= 0)
- close(wfd);
-
- if (fh)
- {
- Assert(!success);
- check_exit(&fh, command.data);
+ ereport(LOG,
+ (errmsg("OAuth bearer token validation failed" )));
+ return NULL;
}
- if (command.data)
- pfree(command.data);
-
- return success;
+ set_authn_id(port, id);
+
+ return id;
}
static bool
@@ -769,29 +686,3 @@ unset_cloexec(int fd)
return true;
}
-
-/*
- * XXX This should go away eventually and be replaced with either a proper
- * escape or a different strategy for communication with the validator command.
- */
-static bool
-username_ok_for_shell(const char *username)
-{
- /* This set is borrowed from fe_utils' appendShellStringNoError(). */
- static const char * const allowed = "abcdefghijklmnopqrstuvwxyz"
- "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- "0123456789-_./:";
- size_t span;
-
- Assert(username && username[0]); /* should have already been checked */
-
- span = strspn(username, allowed);
- if (username[span] != '\0')
- {
- ereport(COMMERROR,
- (errmsg("PostgreSQL user name contains unsafe characters and cannot be passed to the OAuth validator")));
- return false;
- }
-
- return true;
-}
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 333051ad3c..0bbcf231d2 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -296,8 +296,14 @@ auth_failed(Port *port, int status, const char *logdetail)
errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
break;
case uaOAuth:
- errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
- break;
+ {
+ OAuthProvider *provider = get_provider_by_name(port->hba->oauth_provider);
+ if(provider->oauth_error_hook)
+ errstr = provider->oauth_error_hook(port);
+ else
+ errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\"");
+ break;
+ }
default:
errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method");
break;
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 943e78ddff..94fb5d434d 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -1663,6 +1663,14 @@ parse_hba_line(TokenizedAuthLine *tok_line, int elevel)
parsedline->clientcert = clientCertFull;
}
+ /*
+ * Ensure that the token validation provider name is specified as provider for oauth method.
+ */
+ if (parsedline->auth_method == uaOAuth)
+ {
+ MANDATORY_AUTH_ARG(parsedline->oauth_provider, "provider", "oauth");
+ }
+
return parsedline;
}
@@ -2095,6 +2103,31 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
else
hbaline->oauth_skip_usermap = false;
}
+ else if (strcmp(name, "provider") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaOAuth, "provider", "oauth");
+ if (hbaline->auth_method != uaOAuth)
+ INVALID_AUTH_OPTION("provider", gettext_noop("oauth"));
+ /*
+ * Verify that the token validation mentioned is loaded via shared_preload_libraries.
+ */
+ if (get_provider_by_name(val) == NULL)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("cannot use oauth provider %s",val),
+ errhint("Load provider token validation via shared_preload_libraries."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("cannot use oauth provider %s", val);
+
+ return false;
+ }
+ else
+ {
+ hbaline->oauth_provider = pstrdup(val);
+ }
+ }
else
{
ereport(elevel,
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 485e48970e..938ac399dc 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -44,4 +44,29 @@ extern void set_authn_id(Port *port, const char *id);
typedef void (*ClientAuthentication_hook_type) (Port *, int);
extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook;
+/* Declarations for oAuth authentication providers */
+typedef const char* (*OAuthProviderCheck_hook_type) (Port *, const char*);
+
+/* Hook for plugins to report error messages in validation_failed() */
+typedef const char * (*OAuthProviderError_hook_type) (Port *);
+
+/* Hook for plugins to validate oauth provider options */
+typedef bool (*OAuthProviderValidateOptions_hook_type)
+ (char *, char *, HbaLine *, char **);
+
+typedef struct OAuthProvider
+{
+ const char *name;
+ OAuthProviderCheck_hook_type oauth_provider_hook;
+ OAuthProviderError_hook_type oauth_error_hook;
+} OAuthProvider;
+
+extern void RegisterOAuthProvider
+ (const char *provider_name,
+ OAuthProviderCheck_hook_type OAuthProviderCheck_hook,
+ OAuthProviderError_hook_type OAuthProviderError_hook
+ );
+
+extern OAuthProvider *get_provider_by_name(const char *name);
+
#endif /* AUTH_H */
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index c1b1313989..d65395cc22 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -123,6 +123,7 @@ typedef struct HbaLine
char *radiusports_s;
char *oauth_issuer;
char *oauth_scope;
+ char *oauth_provider;
bool oauth_skip_usermap;
} HbaLine;
diff --git a/src/interfaces/libpq/fe-auth-oauth.c b/src/interfaces/libpq/fe-auth-oauth.c
index 91d2c69f16..61a0b80b7e 100644
--- a/src/interfaces/libpq/fe-auth-oauth.c
+++ b/src/interfaces/libpq/fe-auth-oauth.c
@@ -174,6 +174,16 @@ get_auth_token(PGconn *conn)
if (!token_buf)
goto cleanup;
+ if(conn->oauth_bearer_token)
+ {
+ appendPQExpBufferStr(token_buf, "Bearer ");
+ appendPQExpBufferStr(token_buf, conn->oauth_bearer_token);
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+ token = strdup(token_buf->data);
+ goto cleanup;
+ }
+
err = i_set_str_parameter(&session, I_OPT_OPENID_CONFIG_ENDPOINT, conn->oauth_discovery_uri);
if (err)
{
@@ -201,18 +211,22 @@ get_auth_token(PGconn *conn)
libpq_gettext("issuer does not support device authorization"));
goto cleanup;
}
+
+ //default device flow
+ int session_response_type = I_RESPONSE_TYPE_DEVICE_CODE;
+ auth_method = I_TOKEN_AUTH_METHOD_NONE;
+ if (conn->oauth_client_secret && *conn->oauth_client_secret)
+ {
+ auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
+ }
- err = i_set_response_type(&session, I_RESPONSE_TYPE_DEVICE_CODE);
+ err = i_set_response_type(&session, session_response_type);
if (err)
{
iddawc_error(conn, err, "failed to set device code response type");
goto cleanup;
}
- auth_method = I_TOKEN_AUTH_METHOD_NONE;
- if (conn->oauth_client_secret && *conn->oauth_client_secret)
- auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
-
err = i_set_parameter_list(&session,
I_OPT_CLIENT_ID, conn->oauth_client_id,
I_OPT_CLIENT_SECRET, conn->oauth_client_secret,
@@ -250,6 +264,18 @@ get_auth_token(PGconn *conn)
goto cleanup;
}
+ if (conn->oauth_client_secret && *conn->oauth_client_secret)
+ {
+ session_response_type = I_RESPONSE_TYPE_CLIENT_CREDENTIALS;
+ }
+
+ err = i_set_response_type(&session, session_response_type);
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set session response type");
+ goto cleanup;
+ }
+
/*
* Poll the token endpoint until either the user logs in and authorizes the
* use of a token, or a hard failure occurs. We perform one ping _before_
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 2ff450ce05..5d804c8c0d 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -361,6 +361,10 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
"OAuth-Scope", "", 15,
offsetof(struct pg_conn, oauth_scope)},
+ {"oauth_bearer_token", NULL, NULL, NULL,
+ "OAuth-Bearer", "", 20,
+ offsetof(struct pg_conn, oauth_bearer_token)},
+
/* Terminating entry --- MUST BE LAST */
{NULL, NULL, NULL, NULL,
NULL, NULL, 0}
@@ -4200,6 +4204,8 @@ freePGconn(PGconn *conn)
free(conn->oauth_discovery_uri);
if (conn->oauth_client_id)
free(conn->oauth_client_id);
+ if(conn->oauth_bearer_token)
+ free(conn->oauth_bearer_token);
if (conn->oauth_client_secret)
free(conn->oauth_client_secret);
if (conn->oauth_scope)
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 1b4de3dff0..91e71afe14 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -402,6 +402,7 @@ struct pg_conn
char *oauth_client_id; /* client identifier */
char *oauth_client_secret; /* client secret */
char *oauth_scope; /* access token scope */
+ char *oauth_bearer_token; /* oauth token */
bool oauth_want_retry; /* should we retry on failure? */
/* Optional file to write trace info to */
Hi Mahendrakar, thanks for your interest and for the patch!
On Mon, Sep 19, 2022 at 10:03 PM mahendrakar s
<mahendrakarforpg@gmail.com> wrote:
The changes for each component are summarized below.
1. Provider-specific extension:
Each OAuth provider implements their own token validator as an
extension. Extension registers an OAuth provider hook which is matched
to a line in the HBA file.
How easy is it to write a Bearer validator using C? My limited
understanding was that most providers were publishing libraries in
higher-level languages.
Along those lines, sample validators will need to be provided, both to
help in review and to get the pytest suite green again. (And coverage
for the new code is important, too.)
2. Add support to pass on the OAuth bearer token. In this, obtaining the bearer token is left to the 3rd party application or user.
./psql -U <username> -d 'dbname=postgres oauth_client_id=<client_id> oauth_bearer_token=<token>'
This hurts, but I think people are definitely going to ask for it, given
the frightening practice of copy-pasting these (incredibly sensitive
secret) tokens all over the place... Ideally I'd like to implement
sender constraints for the Bearer token, to *prevent* copy-pasting (or,
you know, outright theft). But I'm not sure that sender constraints are
well-implemented yet for the major providers.
3. HBA: An additional param ‘provider’ is added for the oauth method.
Defining "oauth" as method + passing provider, issuer endpoint
and expected audience* * * * oauth provider=<token validation extension>
issuer=.... scope=....
Naming aside (this conflicts with Samay's previous proposal, I think), I
have concerns about the implementation. There's this code:
+ if (oauth_provider && oauth_provider->name)
+ {
+ ereport(ERROR,
+ (errmsg("OAuth provider \"%s\" is already loaded.",
+ oauth_provider->name)));
+ }
which appears to prevent loading more than one global provider. But
there's also code that deals with a provider list? (Again, it'd help to
have test code covering the new stuff.)
b) libpq optionally compiled for the clients which
explicitly need libpq to orchestrate OAuth communication with the
issuer (it depends heavily on 3rd party library iddawc as Jacob
already pointed out. The library seems to be supporting all the OAuth
flows.)
Speaking of iddawc, I don't think it's a dependency we should choose to
rely on. For all the code that it has, it doesn't seem to provide
compatibility with several real-world providers.
Google, for one, chose not to follow the IETF spec it helped author, and
iddawc doesn't support its flavor of Device Authorization. At another
point, I think iddawc tried to decode Azure's Bearer tokens, which is
incorrect...
I haven't been able to check if those problems have been fixed in a
recent version, but if we're going to tie ourselves to a huge
dependency, I'd at least like to believe that said dependency is
battle-tested and solid, and personally I don't feel like iddawc is.
- auth_method = I_TOKEN_AUTH_METHOD_NONE;
- if (conn->oauth_client_secret && *conn->oauth_client_secret)
- auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
This code got moved, but I'm not sure why? It doesn't appear to have
made a change to the logic.
+ if (conn->oauth_client_secret && *conn->oauth_client_secret)
+ {
+ session_response_type = I_RESPONSE_TYPE_CLIENT_CREDENTIALS;
+ }
Is this an Azure-specific requirement? Ideally a public client (which
psql is) shouldn't have to provide a secret to begin with, if I
understand that bit of the protocol correctly. I think Google also
required provider-specific changes in this part of the code, and
unfortunately I don't think they looked the same as yours.
We'll have to figure all that out... Standards are great; everyone has
one of their own. :)
Thanks,
--Jacob
On Tue, Sep 20, 2022 at 4:19 PM Jacob Champion <jchampion@timescale.com> wrote:
2. Add support to pass on the OAuth bearer token. In this, obtaining the bearer token is left to the 3rd party application or user.
./psql -U <username> -d 'dbname=postgres oauth_client_id=<client_id> oauth_bearer_token=<token>'

This hurts, but I think people are definitely going to ask for it, given
the frightening practice of copy-pasting these (incredibly sensitive
secret) tokens all over the place...
After some further thought -- in this case, you already have an opaque
Bearer token (and therefore you already know, out of band, which
provider needs to be used), you're willing to copy-paste it from
whatever service you got it from, and you have an extension plugged
into Postgres on the backend that verifies this Bearer blob using some
procedure that Postgres knows nothing about.
Why do you need the OAUTHBEARER mechanism logic at that point? Isn't
that identical to a custom password scheme? It seems like that could
be handled completely by Samay's pluggable auth proposal.
--Jacob
We can support both passing the token from an upstream client and libpq implementing the OAUTH2 protocol to obtain one.
Libpq implementing OAUTHBEARER is needed for community/3rd party tools to have user-friendly authentication experience:
1. For community client tools, like pg_admin, psql etc.
Example experience: pg_admin would be able to open a popup dialog to authenticate customer and keep refresh token to avoid asking the user frequently.
2. For 3rd party connectors supporting generic OAUTH with any provider. Useful for datawiz clients, like Tableau or ETL tools. Those can support both user and client OAUTH flows.
Libpq passing the token directly from an upstream client is useful in other scenarios:
1. Enterprise clients, built with .Net / Java and using provider-specific authentication libraries, like MSAL for AAD. Those can also support more advanced provider-specific token acquisition flows.
2. Resource-tight (like IoT) clients. Those can be compiled without the optional libpq flag, excluding the iddawc or other dependency.
Thanks!
Andrey.
On Wed, Sep 21, 2022 at 3:10 PM Andrey Chudnovskiy
<Andrey.Chudnovskiy@microsoft.com> wrote:
We can support both passing the token from an upstream client and libpq implementing OAUTH2 protocol to obtaining one.
Right, I agree that we could potentially do both.
Libpq passing the token directly from an upstream client is useful in other scenarios:
1. Enterprise clients, built with .Net / Java and using provider-specific authentication libraries, like MSAL for AAD. Those can also support more advance provider-specific token acquisition flows.
2. Resource-tight (like IoT) clients. Those can be compiled without optional libpq flag not including the iddawc or other dependency.
What I don't understand is how the OAUTHBEARER mechanism helps you in
this case. You're short-circuiting the negotiation where the server
tells the client what provider to use and what scopes to request, and
instead you're saying "here's a secret string, just take it and
validate it with magic."
I realize the ability to pass an opaque token may be useful, but from
the server's perspective, I don't see what differentiates it from the
password auth method plus a custom authenticator plugin. Why pay for
the additional complexity of OAUTHBEARER if you're not going to use
it?
--Jacob
First, my message from my corp email wasn't displayed in the thread.
That is what Jacob replied to; let me post it here for context:
We can support both passing the token from an upstream client and libpq implementing OAUTH2 protocol to obtain one.
Libpq implementing OAUTHBEARER is needed for community/3rd party tools to have user-friendly authentication experience:
1. For community client tools, like pg_admin, psql etc.
Example experience: pg_admin would be able to open a popup dialog to authenticate customers and keep refresh tokens to avoid asking the user frequently.
2. For 3rd party connectors supporting generic OAUTH with any provider. Useful for datawiz clients, like Tableau or ETL tools. Those can support both user and client OAUTH flows.

Libpq passing the token directly from an upstream client is useful in other scenarios:
1. Enterprise clients, built with .Net / Java and using provider-specific authentication libraries, like MSAL for AAD. Those can also support more advanced provider-specific token acquisition flows.
2. Resource-tight (like IoT) clients. Those can be compiled without the optional libpq flag, excluding the iddawc or other dependency.
-----------------------------------------------------------------------------------------------------
On this:
What I don't understand is how the OAUTHBEARER mechanism helps you in
this case. You're short-circuiting the negotiation where the server
tells the client what provider to use and what scopes to request, and
instead you're saying "here's a secret string, just take it and
validate it with magic."

I realize the ability to pass an opaque token may be useful, but from
the server's perspective, I don't see what differentiates it from the
password auth method plus a custom authenticator plugin. Why pay for
the additional complexity of OAUTHBEARER if you're not going to use
it?
Yes, passing a token as a new auth method won't make much sense in
isolation. However:
1. Since OAUTHBEARER is supported in the ecosystem, passing a token as
a way to authenticate with OAUTHBEARER is more consistent (IMO) than
passing it as a password.
2. Validation on the backend side doesn't depend on whether the token
is obtained by libpq or transparently passed by the upstream client.
3. Single OAUTH auth method on the server side for both scenarios,
would allow both enterprise clients with their own Token acquisition
and community clients using libpq flows to connect as the same PG
users/roles.
On 9/21/22 21:55, Andrey Chudnovsky wrote:
First, My message from corp email wasn't displayed in the thread,
I see it on the public archives [1]. Your client is choosing some pretty
confusing quoting tactics, though, which you may want to adjust. :D
I have what I'll call some "skeptical curiosity" here -- you don't need
to defend your use cases to me by any means, but I'd love to understand
more about them.
Yes, passing a token as a new auth method won't make much sense in
isolation. However:
1. Since OAUTHBEARER is supported in the ecosystem, passing a token as
a way to authenticate with OAUTHBEARER is more consistent (IMO) than
passing it as a password.
Agreed. It's probably not a very strong argument for the new mechanism,
though, especially if you're not using the most expensive code inside it.
2. Validation on the backend side doesn't depend on whether the token
is obtained by libpq or transparently passed by the upstream client.
Sure.
3. Single OAUTH auth method on the server side for both scenarios,
would allow both enterprise clients with their own Token acquisition
and community clients using libpq flows to connect as the same PG
users/roles.
Okay, this is a stronger argument. With that in mind, I want to revisit
your examples and maybe provide some counterproposals:
Libpq passing the token directly from an upstream client is useful in other scenarios:
1. Enterprise clients, built with .Net / Java and using provider-specific authentication libraries, like MSAL for AAD. Those can also support more advanced provider-specific token acquisition flows.
I can see that providing a token directly would help you work around
limitations in libpq's "standard" OAuth flows, whether we use iddawc or
not. And it's cheap in terms of implementation. But I have a feeling it
would fall apart rapidly with error cases, where the server is giving
libpq information via the OAUTHBEARER mechanism, but libpq can only
communicate to your wrapper through human-readable error messages on stderr.
This seems like clear motivation for client-side SASL plugins (which
were also discussed on Samay's proposal thread). That's a lot more
expensive to implement in libpq, but if it were hypothetically
available, wouldn't you rather your provider-specific code be able to
speak OAUTHBEARER directly with the server?
2. Resource-tight (like IoT) clients. Those can be compiled without the optional libpq flag not including the iddawc or other dependency.
I want to dig into this much more; resource-constrained systems are near
and dear to me. I can see two cases here:
Case 1: The device is an IoT client that wants to connect on its own
behalf. Why would you want to use OAuth in that case? And how would the
IoT device get its Bearer token to begin with? I'm much more used to
architectures that provision high-entropy secrets for this, whether
they're incredibly long passwords per device (in which case,
channel-bound SCRAM should be a fairly strong choice?) or client certs
(which can be better decentralized, but make for a lot of bookkeeping).
If the answer to that is, "we want an IoT client to be able to connect
using the same role as a person", then I think that illustrates a clear
need for SASL negotiation. That would let the IoT client choose
SCRAM-*-PLUS or EXTERNAL, and the person at the keyboard can choose
OAUTHBEARER. Then we have incredible flexibility, because you don't have
to engineer one mechanism to handle them all.
Case 2: The constrained device is being used as a jump point. So there's
an actual person at a keyboard, trying to get into a backend server
(maybe behind a firewall layer, etc.), and the middlebox is either not
web-connected or is incredibly tiny for some reason. That might be a
good use case for a copy-pasted Bearer token, but is there actual demand
for that use case? What motivation would you (or your end user) have for
choosing a fairly heavy, web-centric authentication method in such a
constrained environment?
Are there other resource-constrained use cases I've missed?
Thanks,
--Jacob
[1]: /messages/by-id/MN0PR21MB31694BAC193ECE1807FD45358F4F9@MN0PR21MB3169.namprd21.prod.outlook.com
On Fri, Mar 25, 2022 at 5:00 PM Jacob Champion <pchampion@vmware.com> wrote:
v4 rebases over the latest version of the pluggable auth patchset
(included as 0001-4). Note that there's a recent conflict as
of d4781d887; use an older commit as the base (or wait for the other
thread to be updated).
Here's a newly rebased v5. (They're all zipped now, which I probably
should have done a while back, sorry.)
- As before, 0001-4 are the pluggable auth set; they've now diverged
from the official version over on the other thread [1].
- I'm not sure that 0005 is still completely coherent after the
rebase, given the recent changes to jsonapi.c. But for now, the tests
are green, and that should be enough to keep the conversation going.
- 0008 will hopefully be obsoleted when the SYSTEM_USER proposal [2] lands.
Thanks,
--Jacob
[1]: /messages/by-id/CAJxrbyxgFzfqby+VRCkeAhJnwVZE50+ZLPx0JT2TDg9LbZtkCg@mail.gmail.com
[2]: /messages/by-id/7e692b8c-0b11-45db-1cad-3afc5b57409f@amazon.com
Libpq passing the token directly from an upstream client is useful in other scenarios:
1. Enterprise clients, built with .Net / Java and using provider-specific authentication libraries, like MSAL for AAD. Those can also support more advanced provider-specific token acquisition flows.
I can see that providing a token directly would help you work around
limitations in libpq's "standard" OAuth flows, whether we use iddawc or
not. And it's cheap in terms of implementation. But I have a feeling it
would fall apart rapidly with error cases, where the server is giving
libpq information via the OAUTHBEARER mechanism, but libpq can only
communicate to your wrapper through human-readable error messages on stderr.
As for providing the token directly, that would primarily be used for
scenarios where the same party controls both the server and the
client-side wrapper.
I.e. The client knows how to get a token for a particular principal
and doesn't need any additional information other than human readable
messages.
Please clarify the scenarios where you see this falling apart.
I can provide an example in the cloud world. We (Azure) as well as
other providers offer ways to obtain OAUTH tokens for
Service-to-Service communication at IAAS / PAAS level.
On Azure, the "Managed Identity" feature integrated into Compute VMs allows a
client to make a local HTTP call to get a token. The VM itself manages the
certificate lifecycle, as well as implementing the corresponding OAUTH
flow.
This capability is used by both our 1st party PAAS offerings, as well
as 3rd party services deploying on VMs or managed K8S clusters.
Here, the client doesn't need libpq assistance in obtaining the token.
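As a concrete sketch of that local call: the endpoint (169.254.169.254), api-version (2018-02-01), and required "Metadata: true" header are documented Azure IMDS values, while the helper name and the resource parameter below are placeholders, not anything from the patch.

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* Illustrative sketch only: assemble the local HTTP request a VM-hosted
 * client sends to the Azure Instance Metadata Service to obtain a token.
 * The endpoint, api-version, and "Metadata: true" header are documented
 * IMDS values; the helper name and resource parameter are placeholders. */
int
build_imds_token_request(char *out, size_t outlen, const char *resource)
{
	int			n;

	n = snprintf(out, outlen,
				 "GET /metadata/identity/oauth2/token"
				 "?api-version=2018-02-01&resource=%s HTTP/1.1\r\n"
				 "Host: 169.254.169.254\r\n"
				 "Metadata: true\r\n"
				 "\r\n",
				 resource);
	return (n > 0 && (size_t) n < outlen) ? 0 : -1;
}
```

The JSON response carries the access token directly, which is why no libpq involvement is needed on the acquisition side.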
This seems like clear motivation for client-side SASL plugins (which
were also discussed on Samay's proposal thread). That's a lot more
expensive to implement in libpq, but if it were hypothetically
available, wouldn't you rather your provider-specific code be able to
speak OAUTHBEARER directly with the server?
I generally agree that pluggable auth layers in libpq could be
beneficial. However, as you pointed out in Samay's thread, that would
require a new distribution model for libpq / clients to optionally
include provider-specific logic.
My optimistic plan here would be to implement several core OAUTH flows
in libpq core which would be generic enough to support major
enterprise OAUTH providers:
1. Client Credentials flow (Client_id + Client_secret) for backend applications.
2. Authorization Code Flow with PKCE and/or Device code flow for GUI
applications.
(2.) above would require a protocol between libpq and upstream clients
to exchange several messages.
Your patch includes a way for libpq to deliver to the client a message
about the next authentication steps, so planned to build on top of
that.
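To make the shape of that libpq-to-client exchange concrete, here is a minimal sketch, assuming a Device Authorization Grant (RFC 8628); the type, field, and function names are hypothetical, not existing libpq API:

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* Illustrative only: the kind of structure libpq could hand to an upstream
 * client while a device code flow is in progress. The fields mirror the
 * RFC 8628 device authorization response. */
typedef struct PGpromptOAuthDevice
{
	const char *verification_uri;	/* where the user logs in */
	const char *user_code;			/* code the user types there */
	int			interval;			/* seconds between token polls */
} PGpromptOAuthDevice;

/* Default rendering for terminal tools; a GUI client would instead consume
 * the struct fields directly (e.g. to show a clickable link). */
int
format_device_prompt(char *out, size_t outlen, const PGpromptOAuthDevice *p)
{
	int			n;

	n = snprintf(out, outlen, "Visit %s and enter the code: %s",
				 p->verification_uri, p->user_code);
	return (n > 0 && (size_t) n < outlen) ? 0 : -1;
}
```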
A little about the scenarios we look at.
What we're trying to achieve here is an easy integration path for
multiple players in the ecosystem:
- Managed PaaS Postgres providers (both us and multi-cloud solutions)
- SaaS providers deploying postgres on IaaS/PaaS providers' clouds
- Tools - pg_admin, psql and other ones.
- BI, ETL, Federation and other scenarios where postgres is used as
the data source.
If we can offer a provider-agnostic solution for the Backend <=> libpq <=>
Upstream client path, we can have all the players above build support for
OAuth credentials, managed by the cloud provider of their choice.
For us, that would mean:
- Better administrator experience with pg_admin / psql handling of the
AAD (Azure Active Directory) authentication flows.
- Path for integration solutions using Postgres to build AAD
authentication in their management experience.
- Ability to use AAD identity provider for any Postgres deployments
other than our 1st party PaaS offering.
- Ability to offer github as the identity provider for PaaS Postgres offering.
Other players in the ecosystem above would be able to get the same benefits.
Does that make sense, and is it possible without a provider-specific libpq plugin?
-------------------------
On resource-constrained scenarios.
I want to dig into this much more; resource-constrained systems are near
and dear to me. I can see two cases here:
I just referred to the ability to compile libpq without extra
dependencies to save some kilobytes.
Not sure if OAUTH is widely used in those cases. It involves overhead
anyway, and requires the device to talk to an additional party (OAUTH
provider).
Likely cert authentication is easier.
If needed, such a device can get libpq with full OAuth support and use
client code. But I didn't think about this scenario.
On Mon, Sep 26, 2022 at 6:39 PM Andrey Chudnovsky
<achudnovskij@gmail.com> wrote:
As for providing the token directly, that would primarily be used for
scenarios where the same party controls both the server and the
client-side wrapper.
I.e. The client knows how to get a token for a particular principal
and doesn't need any additional information other than human readable
messages.
Please clarify the scenarios where you see this falling apart.
The most concrete example I can see is with the OAUTHBEARER error
response. If you want to eventually handle differing scopes per role,
or different error statuses (which the proof-of-concept currently
hardcodes as `invalid_token`), then the client can't assume it knows
what the server is going to say there. I think that's true even if you
control both sides and are hardcoding the provider.
How should we communicate those pieces to a custom client when it's
passing a token directly? The easiest way I can see is for the custom
client to speak the OAUTHBEARER protocol directly (e.g. SASL plugin).
If you had to parse the libpq error message, I don't think that'd be
particularly maintainable.
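For reference, the message a client speaking OAUTHBEARER directly would assemble is fixed by RFC 7628: a gs2 header, then "auth=Bearer <token>", delimited and terminated by 0x01 bytes. A minimal sketch (the helper name is illustrative; the wire format is from the RFC):

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* Build the OAUTHBEARER initial client response per RFC 7628. The \001
 * byte is the mechanism's key/value separator; the gs2 header here
 * requests authorization identity `authzid`. */
int
build_initial_response(char *out, size_t outlen,
					   const char *authzid, const char *token)
{
	int			n;

	n = snprintf(out, outlen, "n,a=%s,\001auth=Bearer %s\001\001",
				 authzid, token);
	return (n > 0 && (size_t) n < outlen) ? n : -1;
}
```

The server's side of a failed exchange is the JSON error status body, which is exactly the piece a token-passing wrapper currently has no structured access to.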
I can provide an example in the cloud world. We (Azure) as well as
other providers offer ways to obtain OAUTH tokens for
Service-to-Service communication at IAAS / PAAS level.
On Azure, the "Managed Identity" feature integrated into Compute VMs allows a
client to make a local HTTP call to get a token. The VM itself manages the
certificate lifecycle, as well as implements the corresponding OAuth
flow.
This capability is used by both our 1st party PAAS offerings, as well
as 3rd party services deploying on VMs or managed K8S clusters.
Here, the client doesn't need libpq assistance in obtaining the token.
Cool. To me that's the strongest argument yet for directly providing
tokens to libpq.
My optimistic plan here would be to implement several core OAUTH flows
in libpq core which would be generic enough to support major
enterprise OAUTH providers:
1. Client Credentials flow (Client_id + Client_secret) for backend applications.
2. Authorization Code Flow with PKCE and/or Device code flow for GUI
applications.
As long as it's clear to DBAs when to use which flow (because existing
documentation for that is hit-and-miss), I think it's reasonable to
eventually support multiple flows. Personally my preference would be
to start with one or two core flows, and expand outward once we're
sure that we do those perfectly. Otherwise the explosion of knobs and
buttons might be overwhelming, both to users and devs.
Related to the question of flows is the client implementation library.
I've mentioned that I don't think iddawc is production-ready. As far
as I'm aware, there is only one certified OpenID relying party written
in C, and that's... an Apache server plugin. That leaves us either
choosing an untested library, scouring the web for a "tested" library
(and hoping we're right in our assessment), or implementing our own
(which is going to tamp down enthusiasm for supporting many flows,
though that has its own set of benefits). If you know of any reliable
implementations with a C API, please let me know.
(2.) above would require a protocol between libpq and upstream clients
to exchange several messages.
Your patch includes a way for libpq to deliver to the client a message
about the next authentication steps, so planned to build on top of
that.
Specifically it delivers that message to an end user. If you want a
generic machine client to be able to use that, then we'll need to talk
about how.
A little about the scenarios we look at.
What we're trying to achieve here is an easy integration path for
multiple players in the ecosystem:
- Managed PaaS Postgres providers (both us and multi-cloud solutions)
- SaaS providers deploying postgres on IaaS/PaaS providers' clouds
- Tools - pg_admin, psql and other ones.
- BI, ETL, Federation and other scenarios where postgres is used as
the data source.
If we can offer a provider-agnostic solution for the Backend <=> libpq <=>
Upstream client path, we can have all the players above build support for
OAuth credentials, managed by the cloud provider of their choice.
Well... I don't quite understand why we'd go to the trouble of
providing a provider-agnostic communication solution only to have
everyone write their own provider-specific client support. Unless
you're saying Microsoft would provide an officially blessed plugin for
the *server* side only, and Google would provide one of their own, and
so on.
The server side authorization is the only place where I think it makes
sense to specialize by default. libpq should remain agnostic, with the
understanding that we'll need to make hard decisions when a major
provider decides not to follow a spec.
For us, that would mean:
- Better administrator experience with pg_admin / psql handling of the
AAD (Azure Active Directory) authentication flows.
- Path for integration solutions using Postgres to build AAD
authentication in their management experience.
- Ability to use AAD identity provider for any Postgres deployments
other than our 1st party PaaS offering.
- Ability to offer github as the identity provider for PaaS Postgres offering.
GitHub is unfortunately a bit tricky, unless they've started
supporting OpenID recently?
Other players in the ecosystem above would be able to get the same benefits.
Does that make sense, and is it possible without a provider-specific libpq plugin?
If the players involved implement the flows and follow the specs, yes.
That's a big "if", unfortunately. I think GitHub and Google are two
major players who are currently doing things their own way.
I just referred to the ability to compile libpq without extra
dependencies to save some kilobytes.
Not sure if OAUTH is widely used in those cases. It involves overhead
anyway, and requires the device to talk to an additional party (OAUTH
provider).
Likely cert authentication is easier.
If needed, such a device can get libpq with full OAuth support and use
client code. But I didn't think about this scenario.
Makes sense. Thanks!
--Jacob
The most concrete example I can see is with the OAUTHBEARER error
response. If you want to eventually handle differing scopes per role,
or different error statuses (which the proof-of-concept currently
hardcodes as `invalid_token`), then the client can't assume it knows
what the server is going to say there. I think that's true even if you
control both sides and are hardcoding the provider.
Ok, I see the point. It's related to the topic of communication
between libpq and the upstream client.
How should we communicate those pieces to a custom client when it's
passing a token directly? The easiest way I can see is for the custom
client to speak the OAUTHBEARER protocol directly (e.g. SASL plugin).
If you had to parse the libpq error message, I don't think that'd be
particularly maintainable.
I agree that parsing the message is not a sustainable way.
Could you provide more details on the SASL plugin approach you propose?
Specifically, is this basically a set of extension hooks for the client
side?
With the need for the client to be compiled with the plugins based on
the set of providers it needs.
Well... I don't quite understand why we'd go to the trouble of
providing a provider-agnostic communication solution only to have
everyone write their own provider-specific client support. Unless
you're saying Microsoft would provide an officially blessed plugin for
the *server* side only, and Google would provide one of their own, and
so on.
Yes, via extensions. Identity providers can open source extensions to
use their auth services outside of first party PaaS offerings.
For 3rd party Postgres PaaS or on premise deployments.
The server side authorization is the only place where I think it makes
sense to specialize by default. libpq should remain agnostic, with the
understanding that we'll need to make hard decisions when a major
provider decides not to follow a spec.
Completely agree with agnostic libpq. Though needs validation with
several major providers to know if this is possible.
Specifically it delivers that message to an end user. If you want a
generic machine client to be able to use that, then we'll need to talk
about how.
Yes, that's what needs to be decided.
In both Device code and Authorization code scenarios, libpq and the
client would need to exchange a couple of pieces of metadata.
Plus, after success, the client should be able to access a refresh token
for further use.
Can we implement a generic protocol like this between libpq and the
clients?
On Fri, Sep 30, 2022 at 7:47 AM Andrey Chudnovsky
<achudnovskij@gmail.com> wrote:
How should we communicate those pieces to a custom client when it's
passing a token directly? The easiest way I can see is for the custom
client to speak the OAUTHBEARER protocol directly (e.g. SASL plugin).
If you had to parse the libpq error message, I don't think that'd be
particularly maintainable.
I agree that parsing the message is not a sustainable way.
Could you provide more details on the SASL plugin approach you propose?
Specifically, is this basically a set of extension hooks for the client side?
With the need for the client to be compiled with the plugins based on
the set of providers it needs.
That's a good question. I can see two broad approaches, with maybe
some ability to combine them into a hybrid:
1. If there turns out to be serious interest in having libpq itself
handle OAuth natively (with all of the web-facing code that implies,
and all of the questions still left to answer), then we might be able
to provide a "token hook" in the same way that we currently provide a
passphrase hook for OpenSSL keys. By default, libpq would use its
internal machinery to take the provider details, navigate its builtin
flow, and return the Bearer token. If you wanted to override that
behavior as a client, you could replace the builtin flow with your
own, by registering a set of callbacks.
2. Alternatively, OAuth support could be provided via a mechanism
plugin for some third-party SASL library (GNU libgsasl, Cyrus
libsasl2). We could provide an OAuth plugin in contrib that handles
the default flow. Other providers could publish their alternative
plugins to completely replace the OAUTHBEARER mechanism handling.
Approach (2) would make for some duplicated effort since every
provider has to write code to speak the OAUTHBEARER protocol. It might
simplify provider-specific distribution, since (at least for Cyrus) I
think you could build a single plugin that supports both the client
and server side. But it would be a lot easier to unknowingly (or
knowingly) break the spec, since you'd control both the client and
server sides. There would be less incentive to interoperate.
Finally, we could potentially take pieces from both, by having an
official OAuth mechanism plugin that provides a client-side hook to
override the flow. I have no idea if the benefits would offset the
costs of a plugin-for-a-plugin style architecture. And providers would
still be free to ignore it and just provide a full mechanism plugin
anyway.
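A rough sketch of what the approach (1) hook surface could look like, modeled loosely on libpq's existing PQsetSSLKeyPassHook_OpenSSL passphrase-hook pattern; none of these names are real libpq API, they are illustrative only:

```c
#include <assert.h>
#include <stddef.h>
#include <stdio.h>
#include <string.h>

/* Hypothetical client-registered "token hook" that overrides libpq's
 * builtin OAuth flow. */
typedef int (*PQoauthTokenHook) (char *buf, int bufsize, void *conn);

static PQoauthTokenHook oauth_token_hook = NULL;

/* A client calls this once, before connecting, to take over token
 * acquisition. */
void
PQsetOAuthTokenHook(PQoauthTokenHook hook)
{
	oauth_token_hook = hook;
}

/* libpq would call this while running the OAUTHBEARER exchange; -1 means
 * "no hook installed, run the builtin flow instead". */
int
pq_acquire_bearer_token(char *buf, int bufsize, void *conn)
{
	if (oauth_token_hook)
		return oauth_token_hook(buf, bufsize, conn);
	return -1;					/* builtin device/auth-code flow would run here */
}

/* Example hook: a wrapper that already holds a token (say, from a managed
 * identity service) simply copies it in and returns its length. */
int
managed_identity_hook(char *buf, int bufsize, void *conn)
{
	(void) conn;
	return snprintf(buf, bufsize, "%s", "mF_9.B5f-4.1JqM");
}
```

Because libpq still drives the OAUTHBEARER exchange itself, the structured error status from the server stays visible to the hook's caller rather than being flattened into a human-readable message.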
Well... I don't quite understand why we'd go to the trouble of
providing a provider-agnostic communication solution only to have
everyone write their own provider-specific client support. Unless
you're saying Microsoft would provide an officially blessed plugin for
the *server* side only, and Google would provide one of their own, and
so on.
Yes, via extensions. Identity providers can open source extensions to
use their auth services outside of first party PaaS offerings.
For 3rd party Postgres PaaS or on premise deployments.
Sounds reasonable.
The server side authorization is the only place where I think it makes
sense to specialize by default. libpq should remain agnostic, with the
understanding that we'll need to make hard decisions when a major
provider decides not to follow a spec.
Completely agree with agnostic libpq. Though needs validation with
several major providers to know if this is possible.
Agreed.
Specifically it delivers that message to an end user. If you want a
generic machine client to be able to use that, then we'll need to talk
about how.
Yes, that's what needs to be decided.
In both Device code and Authorization code scenarios, libpq and the
client would need to exchange a couple of pieces of metadata.
Plus, after success, the client should be able to access a refresh token for further use.
Can we implement a generic protocol like this between libpq and the clients?
I think we can probably prototype a callback hook for approach (1)
pretty quickly. (2) is a lot more work and investigation, but it's
work that I'm interested in doing (when I get the time). I think there
are other very good reasons to consider a third-party SASL library,
and some good lessons to be learned, even if the community decides not
to go down that road.
Thanks,
--Jacob
I think we can probably prototype a callback hook for approach (1)
pretty quickly. (2) is a lot more work and investigation, but it's
work that I'm interested in doing (when I get the time). I think there
are other very good reasons to consider a third-party SASL library,
and some good lessons to be learned, even if the community decides not
to go down that road.
Makes sense. We will work on (1) and check whether there are any
blockers for a shared solution to support GitHub and Google.
Hi,
We validated libpq handling OAuth natively, with different flows
against different OIDC-certified providers.
Flows: Device Code, Client Credentials and Refresh Token.
Providers: Microsoft, Google and Okta.
We also validated against the OAuth provider GitHub.
We propose using OpenID Connect (OIDC) as the protocol, instead of
plain OAuth, because it has:
- A discovery mechanism to bridge provider differences and supply metadata.
- A stricter protocol and certification process, to reliably identify
which providers can be supported.
- A design aimed at authentication, while the main purpose of OAuth is to
authorize applications on behalf of the user.
GitHub is not OIDC certified, so it won't be supported with this proposal.
However, it may be supported in the future through the ability for the
extension to provide custom discovery document content.
OpenID Connect defines a well-known discovery mechanism for the provider
configuration URI. It allows libpq to fetch metadata about the provider
(i.e. endpoints, supported grants, response types, etc.).
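Concretely, the discovery URL is derived from the issuer identifier as defined by OpenID Connect Discovery 1.0; a minimal sketch (the helper name is illustrative):

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* Derive the provider metadata URL from an issuer identifier, per OpenID
 * Connect Discovery 1.0. The JSON fetched from this URL carries the
 * endpoints, supported grant types, and response types libpq needs. */
int
oidc_discovery_url(char *out, size_t outlen, const char *issuer)
{
	size_t		len = strlen(issuer);
	int			n;

	/* Tolerate a trailing slash on the issuer to avoid "//.well-known". */
	if (len > 0 && issuer[len - 1] == '/')
		len--;

	n = snprintf(out, outlen, "%.*s/.well-known/openid-configuration",
				 (int) len, issuer);
	return (n > 0 && (size_t) n < outlen) ? 0 : -1;
}
```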
In the attached patch (based on the V2 patch in the thread; it does not
contain Samay's changes):
- The provider can configure the issuer URL and scope through the options hook.
- The server passes an open discovery URL and scope on to libpq.
- Libpq handles the OAuth flow based on the flow_type sent in the
connection string [1].
- Added callbacks to hand a structure to client tools if the OAuth flow
requires user interaction.
- The PG backend uses hooks to validate the bearer token.
Note that authentication code flow with PKCE for GUI clients is not
implemented yet.
Proposed next steps:
- Broaden discussion to reach agreement on the approach.
- Implement libpq changes without iddawc
- Prototype GUI flow with pgAdmin
Thanks,
Mahendrakar.
[1]: connection string for refresh token flow:
./psql -U <user> -d 'dbname=postgres oauth_client_id=<client_id>
oauth_flow_type=<flowtype> oauth_refresh_token=<refresh token>'
Attachments:
v1-0001-oauth-flows-validation-hook-approach.patch
diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c
index 3a625847f3..f213a40b65 100644
--- a/src/backend/libpq/auth-oauth.c
+++ b/src/backend/libpq/auth-oauth.c
@@ -24,15 +24,23 @@
#include "libpq/hba.h"
#include "libpq/oauth.h"
#include "libpq/sasl.h"
+#include "miscadmin.h"
#include "storage/fd.h"
/* GUC */
char *oauth_validator_command;
+static OAuthProvider* oauth_provider = NULL;
+
+/*----------------------------------------------------------------
+ * OAuth Authentication
+ *----------------------------------------------------------------
+ */
+static List *oauth_providers = NIL;
static void oauth_get_mechanisms(Port *port, StringInfo buf);
static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
static int oauth_exchange(void *opaq, const char *input, int inputlen,
- char **output, int *outputlen, char **logdetail);
+ char **output, int *outputlen, const char **logdetail);
/* Mechanism declaration */
const pg_be_sasl_mech pg_be_oauth_mech = {
@@ -43,7 +51,6 @@ const pg_be_sasl_mech pg_be_oauth_mech = {
PG_MAX_AUTH_TOKEN_LENGTH,
};
-
typedef enum
{
OAUTH_STATE_INIT = 0,
@@ -62,7 +69,7 @@ struct oauth_ctx
static char *sanitize_char(char c);
static char *parse_kvpairs_for_auth(char **input);
static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
-static bool validate(Port *port, const char *auth, char **logdetail);
+static bool validate(Port *port, const char *auth, const char **logdetail);
static bool run_validator_command(Port *port, const char *token);
static bool check_exit(FILE **fh, const char *command);
static bool unset_cloexec(int fd);
@@ -72,6 +79,86 @@ static bool username_ok_for_shell(const char *username);
#define AUTH_KEY "auth"
#define BEARER_SCHEME "Bearer "
+#include "utils/memutils.h"
+
+/*----------------------------------------------------------------
+ * OAuth Token Validator
+ *----------------------------------------------------------------
+ */
+
+/*
+ * RegistorOAuthProvider registers an OAuth token validator to be used
+ * for OAuth token validation. It adds the validator's name and its hooks
+ * to the list of loaded token validators. The right validator's hooks
+ * can then be called based on the validator name specified in
+ * pg_hba.conf.
+ *
+ * This function should be called in _PG_init() by any extension looking to
+ * add a custom authentication method.
+ */
+void
+RegistorOAuthProvider(
+ const char *provider_name,
+ OAuthProviderCheck_hook_type OAuthProviderCheck_hook,
+ OAuthProviderError_hook_type OAuthProviderError_hook,
+ OAuthProviderOptions_hook_type OAuthProviderOptions_hook
+)
+{
+ if (!process_shared_preload_libraries_in_progress)
+ {
+ ereport(ERROR,
+ (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
+ errmsg("RegistorOAuthProvider can only be called by a shared_preload_library")));
+ return;
+ }
+
+ if (oauth_provider == NULL)
+ {
+ MemoryContext oldcxt = MemoryContextSwitchTo(TopMemoryContext);
+ oauth_provider = palloc(sizeof(OAuthProvider));
+ oauth_provider->name = pstrdup(provider_name);
+ oauth_provider->oauth_provider_hook = OAuthProviderCheck_hook;
+ oauth_provider->oauth_error_hook = OAuthProviderError_hook;
+ oauth_provider->oauth_options_hook = OAuthProviderOptions_hook;
+ oauth_providers = lappend(oauth_providers, oauth_provider);
+ MemoryContextSwitchTo(oldcxt);
+ }
+ else
+ {
+ if (oauth_provider && oauth_provider->name)
+ {
+ ereport(ERROR,
+ (errmsg("OAuth provider \"%s\" is already loaded.",
+ oauth_provider->name)));
+ }
+ else
+ {
+ ereport(ERROR,
+ (errmsg("OAuth provider is already loaded.")));
+ }
+ }
+}
+
+/*
+ * Returns the OAuth provider (which includes its
+ * callback functions) based on the name specified.
+ */
+OAuthProvider *
+get_provider_by_name(const char *name)
+{
+ ListCell *lc;
+ foreach(lc, oauth_providers)
+ {
+ OAuthProvider *provider = (OAuthProvider *) lfirst(lc);
+ if (strcmp(provider->name, name) == 0)
+ {
+ return provider;
+ }
+ }
+
+ return NULL;
+}
+
static void
oauth_get_mechanisms(Port *port, StringInfo buf)
{
@@ -102,9 +189,32 @@ oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
return ctx;
}
+static void process_oauth_flow_type(pg_oauth_flow_type flow_type, struct oauth_ctx *ctx, char **output, int *outputlen)
+{
+ StringInfoData buf;
+ initStringInfo(&buf);
+
+ OAuthProviderOptions *oauth_options = oauth_provider->oauth_options_hook(flow_type);
+ ctx->scope = oauth_options->scope;
+ ctx->issuer = oauth_options->issuer_url;
+ appendStringInfo(&buf,
+ "{ "
+ "\"status\": \"invalid_token\", "
+ "\"openid-configuration\": \"%s/.well-known/openid-configuration\","
+ "\"scope\": \"%s\""
+ "}",
+ oauth_options->issuer_url,
+ oauth_options->scope);
+
+ *output = buf.data;
+ *outputlen = buf.len;
+
+ pfree(oauth_options);
+}
+
static int
oauth_exchange(void *opaq, const char *input, int inputlen,
- char **output, int *outputlen, char **logdetail)
+ char **output, int *outputlen, const char **logdetail)
{
char *p;
char cbind_flag;
@@ -247,11 +357,17 @@ oauth_exchange(void *opaq, const char *input, int inputlen,
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("malformed OAUTHBEARER message"),
errdetail("Message contains additional data after the final terminator.")));
-
- if (!validate(ctx->port, auth, logdetail))
+
+ /* If the client did not send a Bearer token, process the flow type. */
+ if (strncasecmp(auth, BEARER_SCHEME, strlen(BEARER_SCHEME)) != 0)
+ {
+ process_oauth_flow_type(atoi(auth), ctx, output, outputlen);
+ ctx->state = OAUTH_STATE_ERROR;
+ return PG_SASL_EXCHANGE_CONTINUE;
+ }
+ else if(!validate(ctx->port, auth, logdetail))
{
generate_error_response(ctx, output, outputlen);
-
ctx->state = OAUTH_STATE_ERROR;
return PG_SASL_EXCHANGE_CONTINUE;
}
@@ -415,7 +531,7 @@ generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
}
static bool
-validate(Port *port, const char *auth, char **logdetail)
+validate(Port *port, const char *auth, const char **logdetail)
{
static const char * const b64_set = "abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
@@ -508,7 +624,7 @@ validate(Port *port, const char *auth, char **logdetail)
return true;
}
 /* Make sure the validator authenticated the user. */
if (!MyClientConnectionInfo.authn_id)
{
/* TODO: use logdetail; reduce message duplication */
@@ -518,199 +634,22 @@ validate(Port *port, const char *auth, char **logdetail)
return false;
}
 /* Finally, check the user map. */
 ret = check_usermap(port->hba->usermap, port->user_name,
 MyClientConnectionInfo.authn_id, false);
return (ret == STATUS_OK);
}
static bool
run_validator_command(Port *port, const char *token)
{
- bool success = false;
- int rc;
- int pipefd[2];
- int rfd = -1;
- int wfd = -1;
-
- StringInfoData command = { 0 };
- char *p;
- FILE *fh = NULL;
-
- ssize_t written;
- char *line = NULL;
- size_t size = 0;
- ssize_t len;
-
- Assert(oauth_validator_command);
-
- if (!oauth_validator_command[0])
- {
- ereport(COMMERROR,
- (errmsg("oauth_validator_command is not set"),
- errhint("To allow OAuth authenticated connections, set "
- "oauth_validator_command in postgresql.conf.")));
- return false;
- }
-
- /*
- * Since popen() is unidirectional, open up a pipe for the other direction.
- * Use CLOEXEC to ensure that our write end doesn't accidentally get copied
- * into child processes, which would prevent us from closing it cleanly.
- *
- * XXX this is ugly. We should just read from the child process's stdout,
- * but that's a lot more code.
- * XXX by bypassing the popen API, we open the potential of process
- * deadlock. Clearly document child process requirements (i.e. the child
- * MUST read all data off of the pipe before writing anything).
- * TODO: port to Windows using _pipe().
- */
- rc = pipe2(pipefd, O_CLOEXEC);
- if (rc < 0)
- {
- ereport(COMMERROR,
- (errcode_for_file_access(),
- errmsg("could not create child pipe: %m")));
- return false;
- }
-
- rfd = pipefd[0];
- wfd = pipefd[1];
-
- /* Allow the read pipe be passed to the child. */
- if (!unset_cloexec(rfd))
- {
- /* error message was already logged */
- goto cleanup;
- }
-
- /*
- * Construct the command, substituting any recognized %-specifiers:
- *
- * %f: the file descriptor of the input pipe
- * %r: the role that the client wants to assume (port->user_name)
- * %%: a literal '%'
- */
- initStringInfo(&command);
-
- for (p = oauth_validator_command; *p; p++)
- {
- if (p[0] == '%')
- {
- switch (p[1])
- {
- case 'f':
- appendStringInfo(&command, "%d", rfd);
- p++;
- break;
- case 'r':
- /*
- * TODO: decide how this string should be escaped. The role
- * is controlled by the client, so if we don't escape it,
- * command injections are inevitable.
- *
- * This is probably an indication that the role name needs
- * to be communicated to the validator process in some other
- * way. For this proof of concept, just be incredibly strict
- * about the characters that are allowed in user names.
- */
- if (!username_ok_for_shell(port->user_name))
- goto cleanup;
-
- appendStringInfoString(&command, port->user_name);
- p++;
- break;
- case '%':
- appendStringInfoChar(&command, '%');
- p++;
- break;
- default:
- appendStringInfoChar(&command, p[0]);
- }
- }
- else
- appendStringInfoChar(&command, p[0]);
- }
-
- /* Execute the command. */
- fh = OpenPipeStream(command.data, "re");
- /* TODO: handle failures */
-
- /* We don't need the read end of the pipe anymore. */
- close(rfd);
- rfd = -1;
-
- /* Give the command the token to validate. */
- written = write(wfd, token, strlen(token));
- if (written != strlen(token))
- {
- /* TODO must loop for short writes, EINTR et al */
- ereport(COMMERROR,
- (errcode_for_file_access(),
- errmsg("could not write token to child pipe: %m")));
- goto cleanup;
- }
-
- close(wfd);
- wfd = -1;
-
- /*
- * Read the command's response.
- *
- * TODO: getline() is probably too new to use, unfortunately.
- * TODO: loop over all lines
- */
- if ((len = getline(&line, &size, fh)) >= 0)
- {
- /* TODO: fail if the authn_id doesn't end with a newline */
- if (len > 0)
- line[len - 1] = '\0';
-
- set_authn_id(port, line);
- }
- else if (ferror(fh))
- {
- ereport(COMMERROR,
- (errcode_for_file_access(),
- errmsg("could not read from command \"%s\": %m",
- command.data)));
- goto cleanup;
- }
-
- /* Make sure the command exits cleanly. */
- if (!check_exit(&fh, command.data))
+ int result = oauth_provider->oauth_provider_hook(port, token);
+ if(result == STATUS_OK)
{
- /* error message already logged */
- goto cleanup;
- }
-
- /* Done. */
- success = true;
-
-cleanup:
- if (line)
- free(line);
-
- /*
- * In the successful case, the pipe fds are already closed. For the error
- * case, always close out the pipe before waiting for the command, to
- * prevent deadlock.
- */
- if (rfd >= 0)
- close(rfd);
- if (wfd >= 0)
- close(wfd);
-
- if (fh)
- {
- Assert(!success);
- check_exit(&fh, command.data);
+ set_authn_id(port, port->user_name);
+ return true;
}
-
- if (command.data)
- pfree(command.data);
-
- return success;
+ return false;
}
static bool
@@ -780,7 +719,7 @@ username_ok_for_shell(const char *username)
/* This set is borrowed from fe_utils' appendShellStringNoError(). */
static const char * const allowed = "abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- "0123456789-_./:";
+ "0123456789-_./@:";
size_t span;
Assert(username && username[0]); /* should have already been checked */
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index b62457d57c..7b7b6ff9aa 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -28,6 +28,41 @@ extern void set_authn_id(Port *port, const char *id);
/* Hook for plugins to get control in ClientAuthentication() */
typedef void (*ClientAuthentication_hook_type) (Port *, int);
extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook;
+/* Declarations for OAuth authentication providers */
+typedef int (*OAuthProviderCheck_hook_type) (Port *, const char*);
+
+/* Hook for plugins to report error messages in validation_failed() */
+typedef const char * (*OAuthProviderError_hook_type) (Port *);
+
+/* Hook for plugins to validate oauth provider options */
+typedef bool (*OAuthProviderValidateOptions_hook_type)
+ (char *, char *, HbaLine *, char **);
+
+typedef struct OAuthProviderOptions
+{
+ char *issuer_url;
+ char *scope;
+} OAuthProviderOptions;
+
+/* Hook for plugins to get oauth params */
+typedef OAuthProviderOptions *(*OAuthProviderOptions_hook_type) (pg_oauth_flow_type);
+
+typedef struct OAuthProvider
+{
+ const char *name;
+ OAuthProviderCheck_hook_type oauth_provider_hook;
+ OAuthProviderError_hook_type oauth_error_hook;
+ OAuthProviderOptions_hook_type oauth_options_hook;
+} OAuthProvider;
+
+extern void RegistorOAuthProvider
+ (const char *provider_name,
+ OAuthProviderCheck_hook_type OAuthProviderCheck_hook,
+ OAuthProviderError_hook_type OAuthProviderError_hook,
+ OAuthProviderOptions_hook_type OAuthProviderParams_hook
+ );
+
+extern OAuthProvider *get_provider_by_name(const char *name);
#define PG_MAX_AUTH_TOKEN_LENGTH 65535
#endif /* AUTH_H */
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 6d452ec6d9..f7bbb9dcf4 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -68,6 +68,17 @@ typedef enum CAC_state
CAC_TOOMANY
} CAC_state;
+/* OAuth flow types */
+typedef enum pg_oauth_flow_type
+{
+ OAUTH_DEVICE_CODE,
+ OAUTH_CLIENT_CREDENTIALS,
+ OAUTH_AUTH,
+ OAUTH_AUTH_PKCE,
+ OAUTH_REFRESH_TOKEN,
+ OAUTH_NONE
+} pg_oauth_flow_type;
+
/*
* GSSAPI specific state information
diff --git a/src/interfaces/libpq/fe-auth-oauth.c b/src/interfaces/libpq/fe-auth-oauth.c
index 91d2c69f16..1ba2e033c4 100644
--- a/src/interfaces/libpq/fe-auth-oauth.c
+++ b/src/interfaces/libpq/fe-auth-oauth.c
@@ -142,6 +142,43 @@ iddawc_request_error(PGconn *conn, struct _i_session *i, int err, const char *ms
appendPQExpBuffer(&conn->errorMessage, "(%s)\n", error_code);
}
+static pg_oauth_flow_type oauth_get_flow_type(const char *oauthflow)
+{
+ pg_oauth_flow_type flow_type;
+
+ if(!oauthflow)
+ {
+ return OAUTH_NONE;
+ }
+
+ /* client_secret, device_code, auth_code_pkce, refresh_token */
+ if(strcmp(oauthflow, "device_code") == 0)
+ {
+ flow_type = OAUTH_DEVICE_CODE;
+ }
+ else if(strcmp(oauthflow, "client_secret") == 0)
+ {
+ flow_type = OAUTH_CLIENT_CREDENTIALS;
+ }
+ else if(strcmp(oauthflow, "auth_code_pkce") == 0)
+ {
+ flow_type = OAUTH_AUTH_PKCE;
+ }
+ else if(strcmp(oauthflow, "refresh_token") == 0)
+ {
+ flow_type = OAUTH_REFRESH_TOKEN;
+ }
+ else if(strcmp(oauthflow, "auth_code") == 0)
+ {
+ flow_type = OAUTH_AUTH_CODE;
+ }
+ else
+ {
+ flow_type = OAUTH_NONE;
+ }
+ return flow_type;
+}
+
static char *
get_auth_token(PGconn *conn)
{
@@ -150,29 +187,44 @@ get_auth_token(PGconn *conn)
int err;
int auth_method;
bool user_prompted = false;
- const char *verification_uri;
- const char *user_code;
- const char *access_token;
- const char *token_type;
- char *token = NULL;
-
+ char *verification_uri;
+ char *user_code;
+ char *access_token;
+ char *refresh_token;
+ char *token_type;
+ pg_oauth_flow_type flow_type;
+ char *token = NULL;
+ uint session_response_type;
+ PGOAuthMsgObj oauthMsgObj;
+
+ MemSet(&oauthMsgObj, 0x00, sizeof(PGOAuthMsgObj));
+
if (!conn->oauth_discovery_uri)
return strdup(""); /* ask the server for one */
- i_init_session(&session);
-
if (!conn->oauth_client_id)
{
/* We can't talk to a server without a client identifier. */
appendPQExpBufferStr(&conn->errorMessage,
libpq_gettext("no oauth_client_id is set for the connection"));
- goto cleanup;
+ return NULL;
}
- token_buf = createPQExpBuffer();
+ i_init_session(&session);
+ token_buf = createPQExpBuffer();
if (!token_buf)
goto cleanup;
+
+ if(conn->oauth_bearer_token)
+ {
+ appendPQExpBufferStr(token_buf, "Bearer ");
+ appendPQExpBufferStr(token_buf, conn->oauth_bearer_token);
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+ token = strdup(token_buf->data);
+ goto cleanup;
+ }
err = i_set_str_parameter(&session, I_OPT_OPENID_CONFIG_ENDPOINT, conn->oauth_discovery_uri);
if (err)
@@ -181,6 +233,8 @@ get_auth_token(PGconn *conn)
goto cleanup;
}
+ flow_type = oauth_get_flow_type(conn->oauth_flow_type);
+
err = i_get_openid_config(&session);
if (err)
{
@@ -201,18 +255,64 @@ get_auth_token(PGconn *conn)
libpq_gettext("issuer does not support device authorization"));
goto cleanup;
}
+ auth_method = I_TOKEN_AUTH_METHOD_NONE;
+
+ /* for refresh token flow, do not run auth request*/
+ if(flow_type == OAUTH_REFRESH_TOKEN && conn->oauth_refresh_token)
+ {
+ err = i_set_parameter_list(&session,
+ I_OPT_CLIENT_ID, conn->oauth_client_id,
+ I_OPT_REFRESH_TOKEN, conn->oauth_refresh_token,
+ I_OPT_RESPONSE_TYPE, I_RESPONSE_TYPE_REFRESH_TOKEN,
+ I_OPT_TOKEN_METHOD, auth_method,
+ I_OPT_CLIENT_SECRET, conn->oauth_client_secret,
+ I_OPT_SCOPE, conn->oauth_scope,
+ I_OPT_NONE
+ );
+
+ if (err)
+ {
+ iddawc_error(conn, err, "failed to set refresh token flow parameters");
+ goto cleanup;
+ }
- err = i_set_response_type(&session, I_RESPONSE_TYPE_DEVICE_CODE);
+ err = i_run_token_request(&session);
+ if (err)
+ {
+ iddawc_request_error(conn, &session, err,
+ "failed to obtain token authorization with refresh token flow");
+ goto cleanup;
+ }
+
+ access_token = i_get_str_parameter(&session, I_OPT_ACCESS_TOKEN);
+ token_type = i_get_str_parameter(&session, I_OPT_TOKEN_TYPE);
+
+ if (!access_token || !token_type || strcasecmp(token_type, "Bearer"))
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("issuer did not provide a bearer token"));
+ goto cleanup;
+ }
+
+ appendPQExpBufferStr(token_buf, "Bearer ");
+ appendPQExpBufferStr(token_buf, access_token);
+
+ if (PQExpBufferBroken(token_buf))
+ goto cleanup;
+
+ token = strdup(token_buf->data);
+ return token;
+ }
+
+ /* Default: device code flow */
+ session_response_type = I_RESPONSE_TYPE_DEVICE_CODE;
+ err = i_set_response_type(&session, session_response_type);
if (err)
{
iddawc_error(conn, err, "failed to set device code response type");
goto cleanup;
}
- auth_method = I_TOKEN_AUTH_METHOD_NONE;
- if (conn->oauth_client_secret && *conn->oauth_client_secret)
- auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC;
-
err = i_set_parameter_list(&session,
I_OPT_CLIENT_ID, conn->oauth_client_id,
I_OPT_CLIENT_SECRET, conn->oauth_client_secret,
@@ -225,7 +325,7 @@ get_auth_token(PGconn *conn)
iddawc_error(conn, err, "failed to set client identifier");
goto cleanup;
}
-
+
err = i_run_device_auth_request(&session);
if (err)
{
@@ -278,14 +378,15 @@ get_auth_token(PGconn *conn)
if (!user_prompted)
{
+ oauthMsgObj.verification_uri = verification_uri;
+ oauthMsgObj.user_code = user_code;
+ conn->oauthNoticeHooks.noticeRecArg = (void*) &oauthMsgObj;
+
/*
* Now that we know the token endpoint isn't broken, give the user
* the login instructions.
- */
- pqInternalNotice(&conn->noticeHooks,
- "Visit %s and enter the code: %s",
- verification_uri, user_code);
-
+ */
+ pqInternalOAuthNotice(&conn->oauthNoticeHooks, "");
user_prompted = true;
}
@@ -300,7 +401,7 @@ get_auth_token(PGconn *conn)
* A slow_down error requires us to permanently increase our retry
* interval by five seconds. RFC 8628, Sec. 3.5.
*/
 if (!strcmp(error_code, "slow_down"))
{
interval += 5;
i_set_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL, interval);
@@ -323,6 +424,14 @@ get_auth_token(PGconn *conn)
access_token = i_get_str_parameter(&session, I_OPT_ACCESS_TOKEN);
token_type = i_get_str_parameter(&session, I_OPT_TOKEN_TYPE);
+ refresh_token = i_get_str_parameter(&session, I_OPT_REFRESH_TOKEN);
+
+ if(refresh_token)
+ {
+ MemSet(&oauthMsgObj, 0x00, sizeof(PGOAuthMsgObj));
+ oauthMsgObj.refresh_token = refresh_token;
+ pqInternalOAuthNotice(&conn->oauthNoticeHooks, "");
+ }
if (!access_token || !token_type || strcasecmp(token_type, "Bearer"))
{
@@ -358,6 +467,8 @@ client_initial_response(PGconn *conn)
PQExpBuffer discovery_buf = NULL;
char *token = NULL;
char *response = NULL;
+ pg_oauth_flow_type flow_type;
+ char oauth_flow_str[3];
token_buf = createPQExpBuffer();
if (!token_buf)
@@ -385,8 +496,26 @@ client_initial_response(PGconn *conn)
token = get_auth_token(conn);
if (!token)
goto cleanup;
-
+
+ if(strcmp(token, "") == 0)
+ {
+ flow_type = oauth_get_flow_type(conn->oauth_flow_type);
+ if(flow_type == OAUTH_NONE)
+ {
+ appendPQExpBufferStr(&conn->errorMessage,
+ libpq_gettext("value passed in oauth_flow_type is not valid. "
+ "Supported flows: client_secret, device_code, auth_code_pkce, refresh_token\n"));
+ goto cleanup;
+ }
+ else
+ {
+ snprintf(oauth_flow_str, sizeof(oauth_flow_str), "%d", flow_type);
+ free(token);
+ token = strdup(oauth_flow_str);
+ }
+ }
appendPQExpBuffer(token_buf, resp_format, token);
+
if (PQExpBufferBroken(token_buf))
goto cleanup;
@@ -406,6 +535,9 @@ cleanup:
#define ERROR_STATUS_FIELD "status"
#define ERROR_SCOPE_FIELD "scope"
#define ERROR_OPENID_CONFIGURATION_FIELD "openid-configuration"
+#define ERROR_ISSUER_URL_FIELD "issuer"
+#define ERROR_AUTH_ENDPOINT_FIELD "authorization_endpoint"
+#define ERROR_TOKEN_ENDPOINT_FIELD "token_endpoint"
struct json_ctx
{
@@ -420,6 +552,9 @@ struct json_ctx
char *status;
char *scope;
char *discovery_uri;
+ char *issuer_url;
+ char *auth_endpoint;
+ char *token_endpoint;
};
#define oauth_json_has_error(ctx) \
@@ -491,6 +626,21 @@ oauth_json_object_field_start(void *state, char *name, bool isnull)
ctx->target_field_name = ERROR_OPENID_CONFIGURATION_FIELD;
ctx->target_field = &ctx->discovery_uri;
}
+ else if(!strcmp(name, ERROR_ISSUER_URL_FIELD))
+ {
+ ctx->target_field_name = ERROR_ISSUER_URL_FIELD;
+ ctx->target_field = &ctx->issuer_url;
+ }
+ else if(!strcmp(name, ERROR_AUTH_ENDPOINT_FIELD))
+ {
+ ctx->target_field_name = ERROR_AUTH_ENDPOINT_FIELD;
+ ctx->target_field = &ctx->auth_endpoint;
+ }
+ else if(!strcmp(name, ERROR_TOKEN_ENDPOINT_FIELD))
+ {
+ ctx->target_field_name = ERROR_TOKEN_ENDPOINT_FIELD;
+ ctx->target_field = &ctx->token_endpoint;
+ }
}
free(name);
@@ -627,6 +777,15 @@ handle_oauth_sasl_error(PGconn *conn, char *msg, int msglen)
conn->oauth_scope = ctx.scope;
}
+
+ if(ctx.issuer_url)
+ {
+ if(conn->oauth_issuer)
+ free(conn->oauth_issuer);
+
+ conn->oauth_issuer = ctx.issuer_url;
+ }
+
/* TODO: missing error scope should clear any existing connection scope */
if (!ctx.status)
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 64f27fee18..e6e8dc48e2 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -358,6 +358,18 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
"OAuth-Scope", "", 15,
offsetof(struct pg_conn, oauth_scope)},
+ {"oauth_bearer_token", NULL, NULL, NULL,
+ "OAuth-Bearer", "", 20,
+ offsetof(struct pg_conn, oauth_bearer_token)},
+
+ {"oauth_flow_type", NULL, NULL, NULL,
+ "OAuth-Flow-Type", "", 20,
+ offsetof(struct pg_conn, oauth_flow_type)},
+
+ {"oauth_refresh_token", NULL, NULL, NULL,
+ "OAuth-Refresh-Token", "", 40,
+ offsetof(struct pg_conn, oauth_refresh_token)},
+
/* Terminating entry --- MUST BE LAST */
{NULL, NULL, NULL, NULL,
NULL, NULL, 0}
@@ -427,6 +439,7 @@ static PQconninfoOption *conninfo_find(PQconninfoOption *connOptions,
const char *keyword);
static void defaultNoticeReceiver(void *arg, const PGresult *res);
static void defaultNoticeProcessor(void *arg, const char *message);
+static void OAuthMsgObjReceiver(void *arg, const PGresult *res);
static int parseServiceInfo(PQconninfoOption *options,
PQExpBuffer errorMessage);
static int parseServiceFile(const char *serviceFile,
@@ -3926,6 +3939,7 @@ makeEmptyPGconn(void)
/* install default notice hooks */
conn->noticeHooks.noticeRec = defaultNoticeReceiver;
conn->noticeHooks.noticeProc = defaultNoticeProcessor;
+ conn->oauthNoticeHooks.noticeRec = OAuthMsgObjReceiver;
conn->status = CONNECTION_BAD;
conn->asyncStatus = PGASYNC_IDLE;
@@ -4073,6 +4087,12 @@ freePGconn(PGconn *conn)
free(conn->oauth_client_secret);
if (conn->oauth_scope)
free(conn->oauth_scope);
+ if(conn->oauth_bearer_token)
+ free(conn->oauth_bearer_token);
+ if(conn->oauth_flow_type)
+ free(conn->oauth_flow_type);
+ if(conn->oauth_refresh_token)
+ free(conn->oauth_refresh_token);
termPQExpBuffer(&conn->errorMessage);
termPQExpBuffer(&conn->workBuffer);
@@ -6991,6 +7011,32 @@ defaultNoticeProcessor(void *arg, const char *message)
fprintf(stderr, "%s", message);
}
+static void
+OAuthMsgObjReceiver(void *arg, const PGresult *res)
+{
+ PGOAuthMsgObj *oauthMsg = (PGOAuthMsgObj *) arg;
+
+ if(oauthMsg->message)
+ {
+ fprintf(stderr, "%s\n", oauthMsg->message);
+ }
+
+ if(oauthMsg->verification_uri)
+ {
+ fprintf(stderr, "Visit: %s\n", oauthMsg->verification_uri);
+ }
+
+ if(oauthMsg->user_code)
+ {
+ fprintf(stderr, "Enter: %s\n", oauthMsg->user_code);
+ }
+
+ if(oauthMsg->refresh_token)
+ {
+ fprintf(stderr, "Refresh Token: %s\n", oauthMsg->refresh_token);
+ }
+}
+
/*
* returns a pointer to the next token or NULL if the current
* token doesn't match
diff --git a/src/interfaces/libpq/fe-exec.c b/src/interfaces/libpq/fe-exec.c
index da229d632a..4789c1a1fe 100644
--- a/src/interfaces/libpq/fe-exec.c
+++ b/src/interfaces/libpq/fe-exec.c
@@ -976,6 +976,58 @@ pqInternalNotice(const PGNoticeHooks *hooks, const char *fmt,...)
PQclear(res);
}
+/*
+ * pqInternalOAuthNotice - similar to pqInternalNotice,
+ * except that the OAuth notice hooks are invoked.
+ */
+void
+pqInternalOAuthNotice(const PGOAuthNoticeHooks *hooks, const char *fmt,...)
+{
+ char msgBuf[1024];
+ va_list args;
+ PGresult *res;
+
+ if (hooks->noticeRec == NULL)
+ return; /* nobody home to receive notice? */
+
+ /* Format the message */
+ va_start(args, fmt);
+ vsnprintf(msgBuf, sizeof(msgBuf), libpq_gettext(fmt), args);
+ va_end(args);
+ msgBuf[sizeof(msgBuf) - 1] = '\0'; /* make real sure it's terminated */
+
+ /* Make a PGresult to pass to the notice receiver */
+ res = PQmakeEmptyPGresult(NULL, PGRES_NONFATAL_ERROR);
+ if (!res)
+ return;
+ res->oauthNoticeHooks = *hooks;
+ res->oauthNoticeHooks.noticeRecArg = hooks->noticeRecArg;
+
+ /*
+ * Set up fields of notice.
+ */
+ pqSaveMessageField(res, PG_DIAG_MESSAGE_PRIMARY, msgBuf);
+ pqSaveMessageField(res, PG_DIAG_SEVERITY, libpq_gettext("NOTICE"));
+ pqSaveMessageField(res, PG_DIAG_SEVERITY_NONLOCALIZED, "NOTICE");
+ /* XXX should provide a SQLSTATE too? */
+
+ /*
+ * Result text is always just the primary message + newline. If we can't
+ * allocate it, substitute "out of memory", as in pqSetResultError.
+ */
+ res->errMsg = (char *) pqResultAlloc(res, strlen(msgBuf) + 2, false);
+ if (res->errMsg)
+ sprintf(res->errMsg, "%s\n", msgBuf);
+ else
+ res->errMsg = libpq_gettext("out of memory\n");
+
+ /*
+ * Pass to receiver, then free it.
+ */
+ res->oauthNoticeHooks.noticeRec(res->oauthNoticeHooks.noticeRecArg, res);
+ PQclear(res);
+}
+
/*
* pqAddTuple
* add a row pointer to the PGresult structure, growing it if necessary
diff --git a/src/interfaces/libpq/libpq-fe.h b/src/interfaces/libpq/libpq-fe.h
index b7df3224c0..ee5b2e2b59 100644
--- a/src/interfaces/libpq/libpq-fe.h
+++ b/src/interfaces/libpq/libpq-fe.h
@@ -197,6 +197,9 @@ typedef struct pgNotify
typedef void (*PQnoticeReceiver) (void *arg, const PGresult *res);
typedef void (*PQnoticeProcessor) (void *arg, const char *message);
+/* Function type for OAuth notice-handling callbacks */
+typedef void (*PQOAuthNoticeReceiver) (void *arg, const PGresult *res);
+
/* Print options for PQprint() */
typedef char pqbool;
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index ae76ae0e8f..3155d81e00 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -157,6 +157,24 @@ typedef struct
void *noticeProcArg;
} PGNoticeHooks;
+typedef struct
+{
+ char *verification_uri; /* URI the user should visit, with the user_code, in order to sign in */
+ char *user_code; /* used to identify the session on a secondary device */
+ char *refresh_token;
+ char *message; /* string with instructions for the user */
+ char *response_error; /* JSON error response (400 Bad Request) */
+ uint expires_in; /* number of seconds before the device_code expires */
+ uint interval; /* number of seconds the client should wait between polling requests */
+} PGOAuthMsgObj;
+
+/* Fields needed for oauth callback handling */
+typedef struct
+{
+ PQOAuthNoticeReceiver noticeRec; /* OAuth notice message receiver */
+ void *noticeRecArg;
+} PGOAuthNoticeHooks;
+
typedef struct PGEvent
{
PGEventProc proc; /* the function to call on events */
@@ -186,6 +204,7 @@ struct pg_result
* on the PGresult don't have to reference the PGconn.
*/
PGNoticeHooks noticeHooks;
+ PGOAuthNoticeHooks oauthNoticeHooks;
PGEvent *events;
int nEvents;
int client_encoding; /* encoding id */
@@ -343,6 +362,17 @@ typedef struct pg_conn_host
* found in password file. */
} pg_conn_host;
+typedef enum pg_oauth_flow_type
+{
+ OAUTH_DEVICE_CODE,
+ OAUTH_CLIENT_CREDENTIALS,
+ OAUTH_AUTH,
+ OAUTH_AUTH_PKCE,
+ OAUTH_REFRESH_TOKEN,
+ OAUTH_AUTH_CODE,
+ OAUTH_NONE
+} pg_oauth_flow_type;
+
/*
* PGconn stores all the state data associated with a single connection
* to a backend.
@@ -403,6 +433,9 @@ struct pg_conn
char *oauth_client_id; /* client identifier */
char *oauth_client_secret; /* client secret */
char *oauth_scope; /* access token scope */
+ char *oauth_bearer_token; /* oauth token */
+ char *oauth_flow_type; /* oauth flow type */
+ char *oauth_refresh_token; /* refresh token */
bool oauth_want_retry; /* should we retry on failure? */
/* Optional file to write trace info to */
@@ -412,6 +445,9 @@ struct pg_conn
/* Callback procedures for notice message processing */
PGNoticeHooks noticeHooks;
+ /* Callback procedures for notice messages during OAuth flows */
+ PGOAuthNoticeHooks oauthNoticeHooks;
+
/* Event procs registered via PQregisterEventProc */
PGEvent *events; /* expandable array of event data */
int nEvents; /* number of active events */
@@ -677,6 +713,7 @@ extern void pqClearAsyncResult(PGconn *conn);
extern void pqSaveErrorResult(PGconn *conn);
extern PGresult *pqPrepareAsyncResult(PGconn *conn);
extern void pqInternalNotice(const PGNoticeHooks *hooks, const char *fmt,...) pg_attribute_printf(2, 3);
+extern void pqInternalOAuthNotice(const PGOAuthNoticeHooks *hooks, const char *fmt,...);
extern void pqSaveMessageField(PGresult *res, char code,
const char *value);
extern void pqSaveParameterStatus(PGconn *conn, const char *name,
On 11/23/22 01:58, mahendrakar s wrote:
We validated libpq handling OAuth natively, with different flows and
different OIDC certified providers. Flows: Device Code, Client Credentials, and Refresh Token.
Providers: Microsoft, Google and Okta.
Great, thank you!
Also validated with OAuth provider Github.
(How did you get discovery working? I tried this and had to give up
eventually.)
We propose using OpenID Connect (OIDC) as the protocol, instead of
OAuth, as it is:
- Discovery mechanism to bridge the differences and provide metadata.
- Stricter protocol and certification process to reliably identify
which providers can be supported.
- OIDC is designed for authentication, while the main purpose of OAUTH is to
authorize applications on behalf of the user.
How does this differ from the previous proposal? The OAUTHBEARER SASL
mechanism already relies on OIDC for discovery. (I think that decision
is confusing from an architectural and naming standpoint, but I don't
think they really had an alternative...)
Github is not OIDC certified, so won’t be supported with this proposal.
However, it may be supported in the future through the ability for the
extension to provide custom discovery document content.
Right.
OpenID Connect defines a well-known discovery mechanism for the
provider configuration URI. It allows libpq to fetch metadata about
the provider (i.e. endpoints, supported grants, response types, etc.).
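For context, that discovery document is served at <issuer>/.well-known/openid-configuration and, trimmed, looks roughly like this (the URLs are illustrative placeholders; the field names come from the OIDC Discovery and RFC 8628 metadata specs):

```json
{
  "issuer": "https://issuer.example.org",
  "authorization_endpoint": "https://issuer.example.org/authorize",
  "token_endpoint": "https://issuer.example.org/token",
  "device_authorization_endpoint": "https://issuer.example.org/device",
  "jwks_uri": "https://issuer.example.org/jwks",
  "grant_types_supported": ["authorization_code", "refresh_token",
                            "urn:ietf:params:oauth:grant-type:device_code"]
}
```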
Sure, but this is already how the original PoC works. The test suite
implements an OIDC provider, for instance. Is there something different
to this that I'm missing?
In the attached patch (based on V2 patch in the thread and does not
contain Samay's changes):
- Provider can configure the issuer URL and scope through the options hook.
- Server passes an OpenID discovery URL and scope on to libpq.
- Libpq handles OAuth flow based on the flow_type sent in the
connection string [1].
- Added callbacks to notify a structure to client tools if OAuth flow
requires user interaction.
- Pg backend uses hooks to validate bearer token.
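For illustration, flow selection from the connection string might look like this (oauth_flow_type, oauth_refresh_token, and oauth_bearer_token are the parameter names added by the patch; the values are placeholders):

```
$ psql 'host=example.org oauth_client_id=... oauth_flow_type=device_code'
$ psql 'host=example.org oauth_client_id=... oauth_flow_type=refresh_token oauth_refresh_token=...'
$ psql 'host=example.org oauth_bearer_token=...'
```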
Thank you for the sample!
Note that the authorization code flow with PKCE for GUI clients is not
implemented yet. Proposed next steps:
- Broaden discussion to reach agreement on the approach.
High-level thoughts on this particular patch (I assume you're not
looking for low-level implementation comments yet):
0) The original hook proposal upthread, I thought, was about allowing
libpq's flow implementation to be switched out by the application. I
don't see that approach taken here. It's fine if that turned out to be a
bad idea, of course, but this patch doesn't seem to match what we were
talking about.
1) I'm really concerned about the sudden explosion of flows. We went
from one flow (Device Authorization) to six. It's going to be hard
enough to validate that *one* flow is useful and can be securely
deployed by end users; I don't think we're going to be able to maintain
six, especially in combination with my statement that iddawc is not an
appropriate dependency for us.
I'd much rather give applications the ability to use their own OAuth
code, and then maintain within libpq only the flows that are broadly
useful. This ties back to (0) above.
2) Breaking the refresh token into its own pseudoflow is, I think,
passing the buck onto the user for something that's incredibly security
sensitive. The refresh token is powerful; I don't really want it to be
printed anywhere, let alone copy-pasted by the user. Imagine the
phishing opportunities.
If we want to support refresh tokens, I believe we should be developing
a plan to cache and secure them within the client. They should be used
as an accelerator for other flows, not as their own flow.
3) I don't like the departure from the OAUTHBEARER mechanism that's
presented here. For one, since I can't see a sample plugin that makes
use of the "flow type" magic numbers that have been added, I don't
really understand why the extension to the mechanism is necessary.
For two, if we think OAUTHBEARER is insufficient, the people who wrote
it would probably like to hear about it. Claiming support for a spec,
and then implementing an extension without review from the people who
wrote the spec, is not something I'm personally interested in doing.
4) The test suite is still broken, so it's difficult to see these things
in practice for review purposes.

- Implement libpq changes without iddawc

This in particular will be much easier with a functioning test suite,
and with a smaller number of flows.

- Prototype GUI flow with pgAdmin

Cool!
Thanks,
--Jacob
How does this differ from the previous proposal? The OAUTHBEARER SASL
mechanism already relies on OIDC for discovery. (I think that decision
is confusing from an architectural and naming standpoint, but I don't
think they really had an alternative...)
Mostly terminology questions here. OAUTHBEARER SASL appears to be the
spec about using OAUTH2 tokens for Authentication.
While any OAUTH2 can generally work, we propose to specifically
highlight that only OIDC providers can be supported, as we need the
discovery document.
And we won't be able to support Github under that requirement.
Since the original patch used that too - no change on that, just
confirmation that we need OIDC compliance.
0) The original hook proposal upthread, I thought, was about allowing
libpq's flow implementation to be switched out by the application. I
don't see that approach taken here. It's fine if that turned out to be a
bad idea, of course, but this patch doesn't seem to match what we were
talking about.
We still plan to allow the client to pass the token. Which is a
generic way to implement its own OAUTH flows.
1) I'm really concerned about the sudden explosion of flows. We went
from one flow (Device Authorization) to six. It's going to be hard
enough to validate that *one* flow is useful and can be securely
deployed by end users; I don't think we're going to be able to maintain
six, especially in combination with my statement that iddawc is not an
appropriate dependency for us.
I'd much rather give applications the ability to use their own OAuth
code, and then maintain within libpq only the flows that are broadly
useful. This ties back to (0) above.
We consider the following set of flows to be minimum required:
- Client Credentials - For Service to Service scenarios.
- Authorization Code with PKCE - For rich clients, including pgAdmin.
- Device code - for psql (and possibly other non-GUI clients).
- Refresh token (separate discussion)
Which is pretty much the list described here:
https://oauth.net/2/grant-types/ and in the OAuth 2.0 specs.

Client Credentials is very simple, and so is the Refresh Token grant.
If you prefer to pick one of the richer flows, Authorization code for
GUI scenarios is probably much more widely used.
Plus it's easier to implement too, as interaction goes through a
series of callbacks. No polling required.
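To illustrate how simple the Client Credentials grant is, here is a sketch of the token-endpoint request it boils down to. The endpoint, client ID, and secret are placeholders; the code only builds the POST body rather than sending it:

```python
from urllib.parse import urlencode

def client_credentials_request(token_endpoint, client_id, client_secret, scope):
    # RFC 6749 section 4.4: a single POST to the token endpoint,
    # authenticated by the client's own credentials -- no user involved.
    body = urlencode({
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_secret": client_secret,
        "scope": scope,
    })
    return token_endpoint, body

endpoint, body = client_credentials_request(
    "https://issuer.example.com/oauth2/token",  # placeholder issuer
    "my-client", "s3cr3t", "openid")
```

A real implementation would POST this body with Content-Type application/x-www-form-urlencoded and parse the JSON token response.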
2) Breaking the refresh token into its own pseudoflow is, I think,
passing the buck onto the user for something that's incredibly security
sensitive. The refresh token is powerful; I don't really want it to be
printed anywhere, let alone copy-pasted by the user. Imagine the
phishing opportunities.
If we want to support refresh tokens, I believe we should be developing
a plan to cache and secure them within the client. They should be used
as an accelerator for other flows, not as their own flow.
It's considered a separate "grant_type" in the specs / APIs.
https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
For clients, this would mean storing the token and using it to authenticate.
On the question of sensitivity, secure credentials stores are
different for each platform, with a lot of cloud offerings for this.
pgAdmin, for example, has its own way to secure credentials to avoid
asking users for passwords every time the app is opened.
I believe we should delegate the refresh token management to the clients.
3) I don't like the departure from the OAUTHBEARER mechanism that's
presented here. For one, since I can't see a sample plugin that makes
use of the "flow type" magic numbers that have been added, I don't
really understand why the extension to the mechanism is necessary.
I don't think it's much of a departure, but rather a separation of
responsibilities between libpq and upstream clients.
As libpq can be used in different apps, the client would need
different types of flows/grants.
I.e. those need to be provided to libpq at connection initialization
or some other point.
We will change to "grant_type", though, and use a string to be closer to the spec.
What do you think is the best way for the client to signal which OAUTH
flow should be used?
On Wed, Nov 23, 2022 at 12:05 PM Jacob Champion <jchampion@timescale.com> wrote:
On 11/23/22 01:58, mahendrakar s wrote:
We validated libpq handling OAuth natively with different flows
against different OIDC-certified providers.
Flows: Device Code, Client Credentials and Refresh Token.
Providers: Microsoft, Google and Okta.

Great, thank you!

Also validated with the OAuth provider GitHub.

(How did you get discovery working? I tried this and had to give up
eventually.)

We propose using OpenID Connect (OIDC) as the protocol, instead of
OAuth, for these reasons:
- Discovery mechanism to bridge the differences and provide metadata.
- Stricter protocol and certification process to reliably identify
which providers can be supported.
- OIDC is designed for authentication, while the main purpose of OAuth
is to authorize applications on behalf of the user.

How does this differ from the previous proposal? The OAUTHBEARER SASL
mechanism already relies on OIDC for discovery. (I think that decision
is confusing from an architectural and naming standpoint, but I don't
think they really had an alternative...)

GitHub is not OIDC-certified, so it won't be supported with this proposal.
However, it may be supported in the future through the ability for the
extension to provide custom discovery document content.

Right.

OpenID configuration has a well-known discovery mechanism for the
provider configuration URI, which is defined in OpenID Connect. It
allows libpq to fetch metadata about the provider (i.e. endpoints,
supported grants, response types, etc.).

Sure, but this is already how the original PoC works. The test suite
implements an OIDC provider, for instance. Is there something different
to this that I'm missing?

In the attached patch (based on the V2 patch in the thread; it does not
contain Samay's changes):
- The provider can configure the issuer URL and scope through the options hook.
- The server passes on an open discovery URL and scope to libpq.
- libpq handles the OAuth flow based on the flow_type sent in the
connection string [1].
- Added callbacks to notify a structure to client tools if the OAuth flow
requires user interaction.
- The Postgres backend uses hooks to validate the bearer token.

Thank you for the sample!
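For reference, the well-known discovery document discussed here is just a JSON document served at <issuer>/.well-known/openid-configuration. A trimmed sketch with placeholder values (real documents carry many more fields; the field names come from the OIDC Discovery spec):

```python
import json

# Placeholder issuer and endpoints; only a few representative fields shown.
discovery = json.loads("""
{
  "issuer": "https://issuer.example.com",
  "token_endpoint": "https://issuer.example.com/oauth2/token",
  "device_authorization_endpoint": "https://issuer.example.com/oauth2/device",
  "grant_types_supported": ["authorization_code", "refresh_token",
                            "urn:ietf:params:oauth:grant-type:device_code"]
}
""")

# libpq (or a client hook) would pick the endpoint matching the desired grant.
device_endpoint = discovery["device_authorization_endpoint"]
```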
Hi Jacob,
I had validated Github by skipping the discovery mechanism and letting
the provider extension pass on the endpoints. This is just for
validation purposes.
If it needs to be supported, then we need a way to send the discovery
document from the extension.
Thanks,
Mahendrakar.
On 11/23/22 19:45, Andrey Chudnovsky wrote:
Mostly terminology questions here. OAUTHBEARER SASL appears to be the
spec about using OAUTH2 tokens for Authentication.
While any OAUTH2 can generally work, we propose to specifically
highlight that only OIDC providers can be supported, as we need the
discovery document.
*If* you're using in-band discovery, yes. But I thought your use case
was explicitly tailored to out-of-band token retrieval:
The client knows how to get a token for a particular principal
and doesn't need any additional information other than human readable
messages.
In that case, isn't OAuth sufficient? There's definitely a need to
document the distinction, but I don't think we have to require OIDC as
long as the client application makes up for the missing information.
(OAUTHBEARER makes the openid-configuration error member optional,
presumably for this reason.)
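Concretely, the RFC 7628 error result is a small JSON document sent as a SASL challenge, and the openid-configuration member can simply be omitted when the client has out-of-band knowledge. A sketch (the URL is a placeholder):

```python
import json

# Error result with in-band discovery information ...
with_discovery = json.dumps({
    "status": "invalid_token",
    "openid-configuration":
        "https://issuer.example.com/.well-known/openid-configuration",
})

# ... and without it: the client is expected to already know
# where and how to obtain a suitable token.
without_discovery = json.dumps({"status": "invalid_token"})
```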
0) The original hook proposal upthread, I thought, was about allowing
libpq's flow implementation to be switched out by the application. I
don't see that approach taken here. It's fine if that turned out to be a
bad idea, of course, but this patch doesn't seem to match what we were
talking about.

We still plan to allow the client to pass the token, which is a
generic way to implement its own OAuth flows.
Okay. But why push down the implementation into the server?
To illustrate what I mean, here's the architecture of my proposed patchset:
+-------+ +----------+
| | -------------- Empty Token ------------> | |
| libpq | <----- Error Result (w/ Discovery ) ---- | |
| | | |
| +--------+ +--------------+ | |
| | iddawc | <--- [ Flow ] ----> | Issuer/ | | Postgres |
| | | <-- Access Token -- | Authz Server | | |
| +--------+ +--------------+ | +-----------+
| | | | |
| | -------------- Access Token -----------> | > | Validator |
| | <---- Authorization Success/Failure ---- | < | |
| | | +-----------+
+-------+ +----------+
In this implementation, there's only one black box: the validator, which
is responsible for taking an access token from an untrusted client,
verifying that it was issued correctly for the Postgres service, and
either 1) determining whether the bearer is authorized to access the
database, or 2) determining the authenticated ID of the bearer so that
the HBA can decide whether they're authorized. (Or both.)
This approach is limited by the flows that we explicitly enable within
libpq and its OAuth implementation library. You mentioned that you
wanted to support other flows, including clients with out-of-band
knowledge, and I suggested:
If you wanted to override [iddawc's]
behavior as a client, you could replace the builtin flow with your
own, by registering a set of callbacks.
In other words, the hooks would replace iddawc in the above diagram.
In my mind, something like this:
+-------+ +----------+
+------+ | ----------- Empty Token ------------> | Postgres |
| | < | <---------- Error Result ------------ | |
| Hook | | | +-----------+
| | | | | |
+------+ > | ------------ Access Token ----------> | > | Validator |
| | <--- Authorization Success/Failure -- | < | |
| libpq | | +-----------+
+-------+ +----------+
Now there's a second black box -- the client hook -- which takes an
OAUTHBEARER error result (which may or may not have OIDC discovery
information) and returns the access token. How it does this is
unspecified -- it'll probably use some OAuth 2.0 flow, but maybe not.
Maybe it sends the user to a web browser; maybe it uses some of the
magic provider-specific libraries you mentioned upthread. It might have
a refresh token cached so it doesn't have to involve the user at all.
Crucially, though, the two black boxes remain independent of each other.
They have well-defined inputs and outputs (the client hook could be
roughly described as "implement get_auth_token()"). Their correctness
can be independently verified against published OAuth specs and/or
provider documentation. And the client application still makes a single
call to PQconnect*().
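The client hook's contract could be as small as this sketch (names and the returned token are hypothetical; a real hook would run an actual flow, call a provider SDK, or consult a cache):

```python
def get_auth_token(sasl_error):
    """Hypothetical client hook: receives the OAUTHBEARER error result
    (here as a dict) and returns a bearer token by any means available."""
    config_url = sasl_error.get("openid-configuration")  # may be absent
    if config_url is not None:
        # A real hook might fetch the discovery document here and run
        # a standard flow against the advertised endpoints.
        pass
    # Stand-in for a real access token (e.g. a cached or SDK-issued JWT).
    return "header.payload.signature"

token = get_auth_token({"status": "invalid_token"})
```

The point is that libpq drives the SASL conversation on both sides of this call; the application only fills in the token-acquisition step.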
Compare this to the architecture proposed by your patch:
Client App
+----------------------+
| +-------+ +----------+
| | libpq | | Postgres |
| PQconnect > | | | +-------+
| +------+ | ------- Flow Type (!) -------> | > | |
| +- < | Hook | < | <------- Error Result -------- | < | |
| [ get +------+ | | | |
| token ] | | | | |
| | | | | | Hooks |
| v | | | | |
| PQconnect > | ----> | ------ Access Token ---------> | > | |
| | | <--- Authz Success/Failure --- | < | |
| +-------+ | +-------+
+----------------------+ +----------+
Rather than decouple things, I think this proposal drives a spike
through the client app, libpq, and the server. Please correct me if I've
misunderstood pieces of the patch, but the following is my view of it:
What used to be a validator hook on the server side now actively
participates in the client-side flow for some reason. (I still don't
understand what the server is supposed to do with that knowledge.
Changing your authz requirements based on the flow the client wants to
use seems like a good way to introduce bugs.)
The client-side hook is now coupled to the application logic: you have
to know to expect an error from the first PQconnect*() call, then check
whatever magic your hook has done for you to be able to set up the
second call to PQconnect*() with the correctly scoped bearer token. So
if you want to switch between the internal libpq OAuth implementation
and your own hook, you have to rewrite your app logic.
On top of all that, the "flow type code" being sent is a custom
extension to OAUTHBEARER that appears to be incompatible with the RFC's
discovery exchange (which is done by sending an empty auth token during
the first round trip).
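For comparison, the RFC's discovery exchange needs no extension at all: the client's initial OAUTHBEARER message carries an empty auth value, and the server fails the exchange with an error result carrying the discovery info. A sketch of that initial message (framing per RFC 7628; exact details may differ in a real implementation):

```python
KVSEP = "\x01"  # control-A delimiter from RFC 7628

def initial_response(token=""):
    # gs2 header "n,," then key/value pairs, each terminated by KVSEP,
    # then a final KVSEP. An empty token value prompts the server to
    # fail the exchange and return its discovery information.
    return ("n,," + KVSEP + "auth=Bearer " + token + KVSEP + KVSEP).encode()

msg = initial_response()
```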
We consider the following set of flows to be minimum required:
- Client Credentials - For Service to Service scenarios.
Okay, that's simple enough that I think it could probably be maintained
inside libpq with minimal cost. At the same time, is it complicated
enough that you need libpq to do it for you?
Maybe once we get the hooks ironed out, it'll be more obvious what the
tradeoff is...
If you prefer to pick one of the richer flows, Authorization code for
GUI scenarios is probably much more widely used.
Plus it's easier to implement too, as interaction goes through a
series of callbacks. No polling required.
I don't think flows requiring the invocation of web browsers and custom
URL handlers are a clear fit for libpq. For a first draft, at least, I
think that use case should be pushed upward into the client application
via a custom hook.
If we want to support refresh tokens, I believe we should be developing
a plan to cache and secure them within the client. They should be used
as an accelerator for other flows, not as their own flow.

It's considered a separate "grant_type" in the specs / APIs.
https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
Yes, but that doesn't mean we have to expose it to users via a
connection option. You don't get a refresh token out of the blue; you
get it by going through some other flow, and then you use it in
preference to going through that flow again later.
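In other words, a cached refresh token only changes which parameters the client sends to the token endpoint next time; it isn't something the user selects. A sketch of that preference logic (the cache shape is hypothetical):

```python
def next_grant(cache):
    # Prefer a previously cached refresh token; otherwise fall back to
    # an interactive flow (device authorization here, as an example).
    if cache.get("refresh_token"):
        return {"grant_type": "refresh_token",
                "refresh_token": cache["refresh_token"]}
    return {"grant_type": "urn:ietf:params:oauth:grant-type:device_code"}
```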
For the clients, it would be storing the token and using it to authenticate.
On the question of sensitivity, secure credentials stores are
different for each platform, with a lot of cloud offerings for this.
pgAdmin, for example, has its own way to secure credentials to avoid
asking users for passwords every time the app is opened.
I believe we should delegate the refresh token management to the clients.
Delegating to client apps would be fine (and implicitly handled by a
token hook, because the client app would receive the refresh token
directly rather than going through libpq). Delegating to end users, not
so much. Printing a refresh token to stderr as proposed here is, I
think, making things unnecessarily difficult (and/or dangerous) for users.
3) I don't like the departure from the OAUTHBEARER mechanism that's
presented here. For one, since I can't see a sample plugin that makes
use of the "flow type" magic numbers that have been added, I don't
really understand why the extension to the mechanism is necessary.

I don't think it's much of a departure, but rather a separation of
responsibilities between libpq and upstream clients.
Given the proposed architectures above, 1) I think this is further
coupling the components, not separating them; and 2) I can't agree that
an incompatible discovery mechanism is "not much of a departure". If
OAUTHBEARER's functionality isn't good enough for some reason, let's
talk about why.
As libpq can be used in different apps, the client would need
different types of flows/grants.
I.e. those need to be provided to libpq at connection initialization
or some other point.
Why would libpq (or the server!) need to know those things at all, if
they're not going to implement the flow?
We will change to "grant_type" though and use string to be closer to the spec.
What do you think is the best way for the client to signal which OAUTH
flow should be used?
libpq should not need to know the grant type in use if the client is
bypassing its internal implementation entirely.
Thanks,
--Jacob
On 11/24/22 00:20, mahendrakar s wrote:
I had validated Github by skipping the discovery mechanism and letting
the provider extension pass on the endpoints. This is just for
validation purposes.
If it needs to be supported, then we need a way to send the discovery
document from the extension.
Yeah. I had originally bounced around the idea that we could send a
data:// URL, but I think that opens up problems.
You're supposed to be able to link the issuer URI with the URI you got
the configuration from, and if they're different, you bail out. If a
server makes up its own OpenID configuration, we'd have to bypass that
safety check, and decide what the risks and mitigations are... Not sure
it's worth it.
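The safety check in question is roughly the following: per OIDC Discovery, the configuration URI must be derived from the issuer, so a server-supplied document from anywhere else gets rejected.

```python
def issuer_matches(issuer, config_url):
    # OIDC Discovery: the provider configuration must live at the
    # well-known path directly under the issuer URI. Anything else
    # (e.g. a config the Postgres server made up itself) fails.
    expected = issuer.rstrip("/") + "/.well-known/openid-configuration"
    return config_url == expected

ok = issuer_matches(
    "https://issuer.example.com",
    "https://issuer.example.com/.well-known/openid-configuration")
```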
Especially if you could just lobby GitHub to, say, provide an OpenID
config. (Maybe there's a security-related reason they don't.)
--Jacob
Jacob,
Thanks for your feedback.
I think we can focus on the roles and responsibilities of the components first.
Details of the patch can be elaborated later. For example, "flow type
code" is a mistake on our side, and we will use the term "grant_type",
which is defined by the OIDC spec; likewise for the details of
refresh_token usage.
Rather than decouple things, I think this proposal drives a spike
through the client app, libpq, and the server. Please correct me if I've
misunderstood pieces of the patch, but the following is my view of it:
What used to be a validator hook on the server side now actively
participates in the client-side flow for some reason. (I still don't
understand what the server is supposed to do with that knowledge.
Changing your authz requirements based on the flow the client wants to
use seems like a good way to introduce bugs.)
The client-side hook is now coupled to the application logic: you have
to know to expect an error from the first PQconnect*() call, then check
whatever magic your hook has done for you to be able to set up the
second call to PQconnect*() with the correctly scoped bearer token. So
if you want to switch between the internal libpq OAuth implementation
and your own hook, you have to rewrite your app logic.
Basically, yes. We propose increasing the server-side hook's responsibility:
from just validating the token, to also returning the provider root URL
and required audience, and possibly providing more metadata in the
future.
This is, in our opinion, aligned with the SASL protocol, where the
server side is responsible for telling the client the auth requirements
based on the role requested in the startup packet.
Our understanding is that in the original patch that information came
purely from the HBA, and we propose that the extension be able to
control that metadata.
We see the extension as being owned by the identity provider, while the
HBA is owned by the server administrator or cloud provider.
This change of roles is based on a vision of four independent actor
types in the ecosystem:
1. Identity Providers (Okta, Google, Microsoft, other OIDC providers).
- Publish open source extensions for PostgreSQL.
- Don't have to own the server deployments, and must ensure their
extensions can work in any environment. This is where we think
additional hook responsibility helps.
2. Server Owners / PAAS providers (On premise admins, Cloud providers,
multi-cloud PAAS providers).
- Install extensions and configure HBA to allow clients to
authenticate with the identity providers of their choice.
3. Client Application Developers (data visualization tools, integration
tools, pgAdmin, monitoring tools, etc.)
- Independent from specific Identity providers or server providers.
Write one code for all identity providers.
- Rely on application deployment owners to configure which OIDC
provider to use across client and server setups.
4. Application Deployment Owners (End customers setting up applications)
- The only actor actually aware of which identity provider to use.
Configures the stack based on the Identity and PostgreSQL deployments
they have.
The critical piece of the vision is (3) above: applications that are
agnostic of the identity providers. Those applications rely on
properly configured servers and rich driver logic (libpq,
com.postgresql, npgsql) to pop up auth
windows or do service-to-service authentication with any provider. In
our view that would significantly democratize the deployment of OAuth
authentication in the community.
In order to allow this separation, we propose:
1. HBA + extension is the single source of truth for the provider root
URL + required audience for each role. If some backfill for missing
OIDC discovery is needed, the provider-specific extension would
provide it.
2. The client application knows which grant_type to use in which
scenario, but can be coded without knowledge of a specific provider,
so it can't provide discovery details.
3. The driver (libpq, others) coordinates the authentication flow
based on the client's grant_type and the identity provider metadata,
allowing client applications to use any flow with any provider in a
unified way.
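To make (1.) concrete: per RFC 7628, an OAUTHBEARER SASL failure carries a JSON status document, and its "openid-configuration" member is a natural channel for the HBA/extension-provided discovery URL. Here is a minimal sketch of building that document; note that "status" and "openid-configuration" are defined by the RFC, while the "audience" member is a hypothetical extension, not something the RFC specifies:

```python
import json

def build_oauthbearer_error(issuer_root, audience):
    """Build the JSON status document for an OAUTHBEARER SASL failure
    (RFC 7628, section 3.2.2).  "status" and "openid-configuration"
    are defined by the RFC; "audience" is a hypothetical extra member
    for the required audience the HBA/extension wants to advertise."""
    doc = {
        "status": "invalid_token",
        "openid-configuration":
            issuer_root.rstrip("/") + "/.well-known/openid-configuration",
        "audience": audience,  # hypothetical, not part of RFC 7628
    }
    return json.dumps(doc)
```

The client would parse this document out of the SASL failure and use the discovery URL to bootstrap whichever grant it chooses.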
Yes, this would require a somewhat more complicated flow between
components than in your original patch. And yes, more complexity comes
with more opportunity for bugs. However, I see the PG server and libpq
as the places that can absorb more complexity, for the purpose of
making work easier for community participants and simplifying
adoption.
Does this make sense to you?
On Tue, Nov 29, 2022 at 1:20 PM Jacob Champion <jchampion@timescale.com> wrote:
On 11/24/22 00:20, mahendrakar s wrote:
I had validated Github by skipping the discovery mechanism and letting
the provider extension pass on the endpoints. This is just for
validation purposes.
If it needs to be supported, then need a way to send the discovery
document from extension.

Yeah. I had originally bounced around the idea that we could send a
data:// URL, but I think that opens up problems.

You're supposed to be able to link the issuer URI with the URI you got
the configuration from, and if they're different, you bail out. If a
server makes up its own OpenID configuration, we'd have to bypass that
safety check, and decide what the risks and mitigations are... Not sure
it's worth it.

Especially if you could just lobby GitHub to, say, provide an OpenID
config. (Maybe there's a security-related reason they don't.)

--Jacob
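That issuer-linking safety check comes from OIDC Discovery, which requires the "issuer" value inside the returned configuration to match the location the configuration was derived from. A minimal sketch of the check (the helper name is illustrative, not from any patch):

```python
def issuer_matches(config_url, discovered_issuer):
    """OIDC Discovery ties the configuration document to its issuer:
    the "issuer" value inside the document must match the URL the
    document was fetched from (minus the well-known suffix), or the
    client bails out."""
    suffix = "/.well-known/openid-configuration"
    if not config_url.endswith(suffix):
        return False
    expected = config_url[:-len(suffix)].rstrip("/")
    return discovered_issuer.rstrip("/") == expected
```

A server that fabricates its own OpenID configuration cannot pass this check without the client deliberately bypassing it, which is the risk being discussed above.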
On Mon, Dec 5, 2022 at 4:15 PM Andrey Chudnovsky <achudnovskij@gmail.com> wrote:
I think we can focus on the roles and responsibilities of the components first.
Details of the patch can be elaborated. Like "flow type code" is a
mistake on our side, and we will use the term "grant_type" which is
defined by OIDC spec. As well as details of usage of refresh_token.
(For the record, whether we call it "flow type" or "grant type"
doesn't address my concern.)
Basically Yes. We propose an increase of the server side hook responsibility.
From just validating the token, to also return the provider root URL
and required audience. And possibly provide more metadata in the
future.
I think it's okay to have the extension and HBA collaborate to provide
discovery information. Your proposal goes further than that, though,
and makes the server aware of the chosen client flow. That appears to
be an architectural violation: why does an OAuth resource server need
to know the client flow at all?
Which is in our opinion aligned with SASL protocol, where the server
side is responsible for telling the client auth requirements based on
the requested role in the startup packet.
You've proposed an alternative SASL mechanism. There's nothing wrong
with that, per se, but I think it should be clear why we've chosen
something nonstandard.
Our understanding is that in the original patch that information came
purely from hba, and we propose extension being able to control that
metadata.
As we see extension as being owned by the identity provider, compared
to HBA which is owned by the server administrator or cloud provider.
That seems reasonable, considering how tightly coupled the Issuer and
the token validation process are.
2. Server Owners / PAAS providers (On premise admins, Cloud providers,
multi-cloud PAAS providers).
- Install extensions and configure HBA to allow clients to
authenticate with the identity providers of their choice.
(For a future conversation: they need to set up authorization, too,
with custom scopes or some other magic. It's not enough to check who
the token belongs to; even if Postgres is just using the verified
email from OpenID as an authenticator, you have to also know that the
user authorized the token -- and therefore the client -- to access
Postgres on their behalf.)
3. Client Application Developers (Data Wis, integration tools,
PgAdmin, monitoring tools, e.t.c.)
- Independent from specific Identity providers or server providers.
Write one code for all identity providers.
Ideally, yes, but that only works if all identity providers implement
the same flows in compatible ways. We're already seeing instances
where that's not the case and we'll necessarily have to deal with that
up front.
- Rely on application deployment owners to configure which OIDC
provider to use across client and server setups.
4. Application Deployment Owners (End customers setting up applications)
- The only actor actually aware of which identity provider to use.
Configures the stack based on the Identity and PostgreSQL deployments
they have.
(I have doubts that the roles will be as decoupled in practice as you
have described them, but I'd rather defer that for now.)
The critical piece of the vision is (3.) above is applications
agnostic of the identity providers. Those applications rely on
properly configured servers and rich driver logic (libpq,
com.postgresql, npgsql) to allow their application to popup auth
windows or do service-to-service authentication with any provider. In
our view that would significantly democratize the deployment of OAUTH
authentication in the community.
That seems to be restating the goal of OAuth and OIDC. Can you explain
how the incompatible change allows you to accomplish this better than
standard implementations?
In order to allow this separation, we propose:
1. HBA + Extension is the single source of truth of Provider root URL
+ Required Audience for each role. If some backfill for missing OIDC
discovery is needed, the provider-specific extension would be
providing it.
2. Client Application knows which grant_type to use in which scenario.
But can be coded without knowledge of a specific provider. So can't
provide discovery details.
3. Driver (libpq, others) - coordinate the authentication flow based
on client grant_type and identity provider metadata to allow client
applications to use any flow with any provider in a unified way.

Yes, this would require a little more complicated flow between
components than in your original patch.
Why? I claim that standard OAUTHBEARER can handle all of that. What
does your proposed architecture (the third diagram) enable that my
proposed hook (the second diagram) doesn't?
And yes, more complexity comes
with more opportunity to make bugs.
However, I see PG Server and Libpq as the places which can have more
complexity. For the purpose of making work for the community
participants easier and simplify adoption.

Does this make sense to you?
Some of it, but it hasn't really addressed the questions from my last mail.
Thanks,
--Jacob
I think it's okay to have the extension and HBA collaborate to provide
discovery information. Your proposal goes further than that, though,
and makes the server aware of the chosen client flow. That appears to
be an architectural violation: why does an OAuth resource server need
to know the client flow at all?
Ok. It may have been left there from intermediate iterations. We did
consider making the extension drive the flow for a specific
grant_type, but decided against that idea, for the same reason you
point to.
Is it correct that your main concern about the use of grant_type was
that it's propagated to the server? If so, yes, we will remove sending
it to the server.
Ideally, yes, but that only works if all identity providers implement
the same flows in compatible ways. We're already seeing instances
where that's not the case and we'll necessarily have to deal with that
up front.
Yes: based on our analysis, the OIDC spec is detailed enough that
providers implementing it can be supported with generic code in
libpq / the client.
GitHub specifically won't fit there, though; Microsoft Azure AD,
Google, and Okta (including Auth0) will.
Theoretically, discovery documents could be returned from the
(provider-specific) server-side extension, though we didn't plan to
prioritize that.
That seems to be restating the goal of OAuth and OIDC. Can you explain
how the incompatible change allows you to accomplish this better than
standard implementations?
Do you refer to passing grant_type to the server, which we will get
rid of in the next iteration? Or to other incompatible changes as well?
Why? I claim that standard OAUTHBEARER can handle all of that. What
does your proposed architecture (the third diagram) enable that my
proposed hook (the second diagram) doesn't?
The hook proposed in the 2nd diagram effectively delegates all OAuth
flow implementations to the client.
We propose that libpq take care of pulling OpenID discovery and of
coordination, which is effectively Diagram 1 + more flows + a server
hook providing the root URL/audience.
Created the diagrams with all components for 3 flows:
1. Authorization code grant (Clients with Browser access):
- The application calls PQconnect (authorization code grant).
- libpq sends an empty token; the Postgres pre-auth hook replies with
an error carrying the provider root URL and audience.
- libpq fetches the provider metadata from the OIDC Discovery endpoint.
- A client-side hook receives the metadata, and the user obtains an
authorization code (<user action> in the browser).
- On the next PQconnect, iddawc exchanges the authorization code with
the issuer/authz server for an access token.
- libpq sends the access token to Postgres; the validator hook returns
authorization success or failure.
- A client-side hook stores the refresh_token.
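The token exchange in the flow above is the standard RFC 6749 authorization-code exchange. As a sketch, this is roughly the form-encoded request body a driver (here via iddawc) would POST to the issuer's token endpoint; all endpoint and client values are placeholders:

```python
from urllib.parse import urlencode

def auth_code_token_request(token_endpoint, client_id, code, redirect_uri):
    """Build the form-encoded body of the authorization-code token
    exchange (RFC 6749, section 4.1.3) that the driver POSTs to the
    issuer's token endpoint.  All values here are placeholders."""
    body = urlencode({
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
    })
    return token_endpoint, body
```

The issuer's response to this POST carries the access token (and usually the refresh token) that libpq then forwards in the SASL exchange.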
2. Device code grant
- The application calls PQconnect.
- libpq sends an empty token; the Postgres pre-auth hook replies with
an error carrying the provider root URL and audience.
- libpq fetches the provider metadata from the OIDC Discovery endpoint.
- A client-side hook receives the metadata; iddawc requests a device
code from the issuer/authz server, and the user approves
(<user action>).
- iddawc polls the issuer with the device code until an access token
is returned.
- libpq sends the access token to Postgres; the validator hook returns
authorization success or failure.
- A client-side hook stores the refresh_token.
3. Non-interactive flows (Client Secret / Refresh_Token)
- The application calls PQconnect with a grant_type.
- libpq sends an empty token; the Postgres pre-auth hook replies with
an error carrying the provider root URL and audience.
- libpq fetches the provider metadata from the OIDC Discovery endpoint.
- iddawc exchanges the client secret (or refresh token) with the
issuer/authz server for an access token; no user action or client-side
hook is involved.
- libpq sends the access token to Postgres; the validator hook returns
authorization success or failure.
I think the most confusing part of our latest patch was that flow_type
was passed to the server. We are not proposing that going forward.
(For a future conversation: they need to set up authorization, too,
with custom scopes or some other magic. It's not enough to check who
the token belongs to; even if Postgres is just using the verified
email from OpenID as an authenticator, you have to also know that the
user authorized the token -- and therefore the client -- to access
Postgres on their behalf.)
My understanding is that metadata in the tokens is provider-specific,
so the server-side hook would be the right place to handle it. Plus, I
can envision that for some providers it makes sense to make a remote
call to pull some information.
The way we implement Azure AD auth today in the PaaS PostgreSQL offering:
- The server administrator uses special extension functions to create
Azure AD-enabled PostgreSQL roles.
- The PostgreSQL extension maps roles to unique identity IDs (UIDs) in
the directory.
- Connection flow: if the token is valid and the Role => UID mapping
matches, we authenticate as the role.
- Then native PostgreSQL role-based access control takes care of
privileges.
This is the same for both user- and system-to-system authorization,
though I assume different providers may treat user and system
identities differently, so their extension would handle that.
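The Role => UID check described above could be sketched like this; the "sub" claim name and the mapping dict are illustrative placeholders, not the actual extension's schema:

```python
def authorize_role(requested_role, token_claims, role_to_uid):
    """Validator-hook style decision: allow the connection only if the
    (already validated) token's subject matches the UID mapped to the
    requested role.  The "sub" claim name and the mapping dict are
    illustrative placeholders, not the real extension's schema."""
    expected_uid = role_to_uid.get(requested_role)
    return expected_uid is not None and token_claims.get("sub") == expected_uid
```

After this check passes, ordinary PostgreSQL role-based access control governs what the session can do.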
Thanks!
Andrey.
That being said, the Diagram 2 would look like this with our proposal:
- The application calls PQconnect.
- libpq sends an empty token; the Postgres pre-auth hook replies with
an error carrying the provider root URL and audience.
- A client-side hook obtains a token ([get token]) by whatever means
the application chooses.
- On the next PQconnect, libpq sends the access token; the server's
validator hook returns authorization success or failure.
Here the application takes care of all token acquisition logic, while
the server-side hook participates in the pre-authentication reply.
That is definitely a required scenario for the long term, and the
easiest to implement in the client core. And if we can do at least
that flow in PG 16, it will be a strong foundation for adding support
for specific grants in libpq going forward.
Does the diagram above look good to you? We can then start cleaning up
the patch to get that in first.
Thanks!
Andrey.
On Wed, Dec 7, 2022 at 3:22 PM Andrey Chudnovsky
<achudnovskij@gmail.com> wrote:
I think it's okay to have the extension and HBA collaborate to
provide discovery information. Your proposal goes further than
that, though, and makes the server aware of the chosen client flow.
That appears to be an architectural violation: why does an OAuth
resource server need to know the client flow at all?

Ok. It may have left there from intermediate iterations. We did
consider making extension drive the flow for specific grant_type,
but decided against that idea. For the same reason you point to. Is
it correct that your main concern about use of grant_type was that
it's propagated to the server? Then yes, we will remove sending it
to the server.
Okay. Yes, that was my primary concern.
>> Ideally, yes, but that only works if all identity providers
>> implement the same flows in compatible ways. We're already seeing
>> instances where that's not the case, and we'll necessarily have to
>> deal with that up front.
>
> Yes, based on our analysis the OIDC spec is detailed enough that
> providers implementing it can be supported with generic code in
> libpq / the client. GitHub specifically won't fit there, though;
> Microsoft Azure AD, Google, and Okta (including Auth0) will.
> Theoretically, provider-specific discovery documents can be returned
> from the (server-side) extension, though we didn't plan to prioritize
> that.
As another example, Google's device authorization grant is incompatible
with the spec (which they co-authored). I want to say I had problems
with Azure AD not following that spec either, but I don't remember
exactly what they were. I wouldn't be surprised to find more tiny
departures once we get deeper into implementation.
>> That seems to be restating the goal of OAuth and OIDC. Can you
>> explain how the incompatible change allows you to accomplish this
>> better than standard implementations?
>
> Do you refer to passing grant_type to the server? Which we will get
> rid of in the next iteration. Or other incompatible changes as well?

Just the grant type, yeah.
>> Why? I claim that standard OAUTHBEARER can handle all of that.
>> What does your proposed architecture (the third diagram) enable
>> that my proposed hook (the second diagram) doesn't?
>
> The hook proposed in the 2nd diagram effectively delegates all OAuth
> flow implementations to the client. We propose that libpq takes care
> of pulling OpenID discovery and coordination, which is effectively
> Diagram 1 + more flows + a server hook providing the root URL/audience.
>
> Created the diagrams with all components for 3 flows: [snip]

(I'll skip ahead to your later mail on this.)
>> (For a future conversation: they need to set up authorization,
>> too, with custom scopes or some other magic. It's not enough to
>> check who the token belongs to; even if Postgres is just using the
>> verified email from OpenID as an authenticator, you have to also
>> know that the user authorized the token -- and therefore the client
>> -- to access Postgres on their behalf.)
>
> My understanding is that metadata in the tokens is provider specific,
> so a server-side hook would be the right place to handle that. Plus I
> can envision that for some providers it can make sense to make a
> remote call to pull some information.
The server hook is the right place to check the scopes, yes, but I think
the DBA should be able to specify what those scopes are to begin with.
The provider of the extension shouldn't be expected by the architecture
to hardcode those decisions, even if Azure AD chooses to short-circuit
that choice and provide magic instead.
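[Editor's note: a minimal sketch of the check being discussed, with the DBA-configured scopes coming from the HBA side and the validator comparing them against what the token was actually granted. The scope and claim names below are purely illustrative, not tied to any provider.]

```python
# Sketch: server-side scope check. The DBA configures the required
# scopes; the validator hook verifies the token was authorized for them.
def token_authorizes_postgres(token_claims: dict, required_scopes: set) -> bool:
    # OAuth "scope" is conventionally a space-separated string; a
    # missing or empty claim authorizes nothing.
    granted = set(token_claims.get("scope", "").split())
    return required_scopes <= granted  # every required scope must be granted

claims = {"sub": "user@example.com",
          "scope": "openid email postgres.connect"}

print(token_authorizes_postgres(claims, {"postgres.connect"}))    # True
print(token_authorizes_postgres(claims, {"postgres.superuser"}))  # False
```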
On 12/7/22 20:25, Andrey Chudnovsky wrote:
> That being said, Diagram 2 would look like this with our proposal:
> [snip]
> With the application taking care of all token acquisition logic, while
> the server-side hook participates in the pre-authentication reply.

That is definitely a required scenario for the long term, and the
easiest to implement in the client core.

> And if we can do at least that flow in PG16 it will be a strong
> foundation to provide more support for specific grants in libpq going
> forward.

Agreed.

> Does the diagram above look good to you? We can then start cleaning up
> the patch to get that in first.
I maintain that the hook doesn't need to hand back artifacts to the
client for a second PQconnect call. It can just use those artifacts to
obtain the access token and hand that right back to libpq. (I think any
requirement that clients be rewritten to call PQconnect twice will
probably be a sticking point for adoption of an OAuth patch.)
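[Editor's note: a toy model of the control flow argued for here. All names are invented, not libpq's actual API; the point is only that the hook acquires the token during the handshake and hands it back, so the application never has to call connect twice.]

```python
# Toy model: the application installs a token hook; the connection
# attempt invokes it when the server asks for OAUTHBEARER, and the
# driver completes SASL with the returned token in the same call.
from typing import Callable

def connect(server_expects_oauth: bool,
            token_hook: Callable[[dict], str]) -> str:
    if not server_expects_oauth:
        return "connected (no OAuth)"
    # The server's pre-auth reply carries issuer/audience hints the
    # hook can use to acquire a token (interactively or from a cache).
    hints = {"issuer": "https://issuer.example.com", "audience": "postgres"}
    token = token_hook(hints)
    # ...the driver would send the OAUTHBEARER initial response here...
    return f"connected as bearer of {token}"

# Application-supplied hook; here it just "finds" a cached token.
result = connect(True, lambda hints: "cached-token-for-" + hints["audience"])
print(result)  # connected as bearer of cached-token-for-postgres
```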
That said, now that your proposal is also compatible with OAUTHBEARER, I
can pony up some code to hopefully prove my point. (I don't know if I'll
be able to do that by the holidays though.)
Thanks!
--Jacob
> The server hook is the right place to check the scopes, yes, but I
> think the DBA should be able to specify what those scopes are to begin
> with. The provider of the extension shouldn't be expected by the
> architecture to hardcode those decisions, even if Azure AD chooses to
> short-circuit that choice and provide magic instead.
Hardcoding is definitely not expected, but provider-specific
customization, I think, should be allowed.
I can provide a couple of advanced use cases which happen in the cloud
deployment world and require per-role management:
- Multi-tenant deployments, where the root provider URL differs per
  role, based on which tenant the role comes from.
- Federation to multiple providers, via solutions like Amazon Cognito
  that offer a layer of abstraction with several providers transparently
  supported.
If your concern is the extension not honoring DBA-configured values:
would server-side logic that prefers the HBA value over the
extension-provided one resolve this concern?
We are definitely biased towards cloud deployment scenarios, where
direct access to .hba files is usually not offered at all.
Let's find the middle ground here.
A separate reason for creating this pre-authentication hook is further
extensibility to support more metadata.
Specifically, when we add support for OAuth flows to libpq, server-side
extensions can help bridge the gap between the identity provider
implementation and the OAuth/OIDC specs.
For example, that could allow the GitHub extension to provide an OIDC
discovery document.
I definitely see identity providers as institutional actors here which
can be given some power through the extension hooks to customize the
behavior within the framework.
> I maintain that the hook doesn't need to hand back artifacts to the
> client for a second PQconnect call. It can just use those artifacts to
> obtain the access token and hand that right back to libpq. (I think
> any requirement that clients be rewritten to call PQconnect twice will
> probably be a sticking point for adoption of an OAuth patch.)
Obtaining a token is an asynchronous process with a human in the loop.
I'm not sure expecting a hook function to return a token synchronously
is the best option here.
Could that be an optional return value of the hook, for cases where a
token can be obtained synchronously?
On Thu, Dec 8, 2022 at 4:41 PM Jacob Champion <jchampion@timescale.com> wrote:
On Mon, Dec 12, 2022 at 9:06 PM Andrey Chudnovsky
<achudnovskij@gmail.com> wrote:
> If your concern is the extension not honoring DBA-configured values:
> would server-side logic that prefers the HBA value over the
> extension-provided one resolve this concern?
Yeah. It also seals the role of the extension here as "optional".
> We are definitely biased towards cloud deployment scenarios, where
> direct access to .hba files is usually not offered at all.
> Let's find the middle ground here.
Sure. I don't want to make this difficult in cloud scenarios --
obviously I'd like for Timescale Cloud to be able to make use of this
too. But if we make this easy for a lone DBA (who doesn't have any
institutional power with the providers) to use correctly and securely,
then it should follow that the providers who _do_ have power and
resources will have an easy time of it as well. The reverse isn't
necessarily true. So I'm definitely planning to focus on the DBA case
first.
> A separate reason for creating this pre-authentication hook is further
> extensibility to support more metadata.
> Specifically, when we add support for OAuth flows to libpq, server-side
> extensions can help bridge the gap between the identity provider
> implementation and the OAuth/OIDC specs.
> For example, that could allow the GitHub extension to provide an OIDC
> discovery document.
> I definitely see identity providers as institutional actors here which
> can be given some power through the extension hooks to customize the
> behavior within the framework.
We'll probably have to make some compromises in this area, but I think
they should be carefully considered exceptions and not a core feature
of the mechanism. The gaps you point out are just fragmentation, and
adding custom extensions to deal with it leads to further
fragmentation instead of providing pressure on providers to just
implement the specs. Worst case, we open up new exciting security
flaws, and then no one can analyze them independently because no one
other than the provider knows how the two sides work together anymore.
Don't get me wrong; it would be naive to proceed as if the OAUTHBEARER
spec were perfect, because it's clearly not. But if we need to make
extensions to it, we can participate in IETF discussions and make our
case publicly for review, rather than enshrining MS/GitHub/Google/etc.
versions of the RFC and enabling that proliferation as a Postgres core
feature.
> Obtaining a token is an asynchronous process with a human in the loop.
> I'm not sure expecting a hook function to return a token synchronously
> is the best option here.
> Could that be an optional return value of the hook, for cases where a
> token can be obtained synchronously?
I don't think the hook is generally going to be able to return a token
synchronously, and I expect the final design to be async-first. As far
as I know, this will need to be solved for the builtin flows as well
(you don't want a synchronous HTTP call to block your PQconnectPoll
architecture), so the hook should be able to make use of whatever
solution we land on for that.
This is hand-wavy, and I don't expect it to be easy to solve. I just
don't think we have to solve it twice.
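[Editor's note: the "async-first" shape described above can be sketched as a poll-driven state machine in the spirit of PQconnectPoll, where no step blocks the caller. The device-flow steps are simplified and every endpoint interaction is stubbed out; this is an illustration of the control structure, not of any proposed API.]

```python
# Sketch: token acquisition as a non-blocking state machine. The caller
# re-invokes poll() from its own event loop instead of blocking on HTTP.
from enum import Enum, auto

class AuthState(Enum):
    START = auto()
    WAIT_USER = auto()   # user is entering the code in a browser
    DONE = auto()

class DeviceFlow:
    def __init__(self, approve_after: int):
        self.state = AuthState.START
        self._polls_left = approve_after  # stub for user approval latency
        self.token = None

    def poll(self) -> AuthState:
        """Advance one step; the caller re-invokes when its
        select()/epoll loop says there is work to do."""
        if self.state is AuthState.START:
            # would POST to the device_authorization_endpoint here
            self.state = AuthState.WAIT_USER
        elif self.state is AuthState.WAIT_USER:
            # would POST to the token endpoint; "authorization_pending"
            # until the user finishes logging in
            self._polls_left -= 1
            if self._polls_left <= 0:
                self.token = "issued-access-token"
                self.state = AuthState.DONE
        return self.state

flow = DeviceFlow(approve_after=3)
while flow.poll() is not AuthState.DONE:
    pass  # a real caller would return to its event loop between polls
print(flow.token)  # issued-access-token
```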
Have a good end to the year!
--Jacob
Hi All,
Changes have been added to Jacob's patch (v2) as per the discussion in
the thread. The changes allow the customer to send the OAUTHBEARER
token through the psql connection string.
Example:
psql -U user@example.com -d 'dbname=postgres oauth_bearer_token=abc'
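[Editor's note: for reference, a token supplied this way would ultimately be framed in the OAUTHBEARER initial client response as specified by RFC 7628. A sketch of that framing, with placeholder authzid/host/token values:]

```python
# Sketch of the RFC 7628 OAUTHBEARER initial client response:
# gs2-header, then key=value pairs delimited by 0x01, then a final 0x01.
KVSEP = "\x01"

def oauthbearer_initial_response(authzid: str, host: str, token: str) -> str:
    gs2_header = f"n,a={authzid},"
    kvpairs = f"host={host}{KVSEP}auth=Bearer {token}{KVSEP}"
    return gs2_header + KVSEP + kvpairs + KVSEP

msg = oauthbearer_initial_response("user@example.com", "example.org", "abc")
print(msg.replace(KVSEP, "^A"))
# n,a=user@example.com,^Ahost=example.org^Aauth=Bearer abc^A^A
```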
To configure OAuth, the pg_hba.conf line looks like:
local all all oauth
provider=oauth_provider issuer="https://example.com" scope="openid email"
We also added a hook to libpq to pass on metadata about the issuer.
Thanks,
Mahendrakar.
On Sat, 17 Dec 2022 at 04:48, Jacob Champion <jchampion@timescale.com>
wrote:
Attachments:
v3-0002-backend-add-OAUTHBEARER-SASL-mechanishm.patch.gz (application/gzip)