From 0281635f35a44e0fdfd4369423f98ebe5b467ce3 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchampion@vmware.com>
Date: Fri, 4 Jun 2021 09:06:38 -0700
Subject: [PATCH v2 5/5] Add pytest suite for OAuth

Requires Python 3; on the first run of `make installcheck` the
dependencies will be installed into ./venv for you. See the README for
more details.
---
 src/test/python/.gitignore                 |    2 +
 src/test/python/Makefile                   |   38 +
 src/test/python/README                     |   54 ++
 src/test/python/client/__init__.py         |    0
 src/test/python/client/conftest.py         |  126 +++
 src/test/python/client/test_client.py      |  180 ++++
 src/test/python/client/test_oauth.py       |  936 ++++++++++++++++++
 src/test/python/pq3.py                     |  727 ++++++++++++++
 src/test/python/pytest.ini                 |    4 +
 src/test/python/requirements.txt           |    7 +
 src/test/python/server/__init__.py         |    0
 src/test/python/server/conftest.py         |   45 +
 src/test/python/server/test_oauth.py       | 1012 ++++++++++++++++++++
 src/test/python/server/test_server.py      |   21 +
 src/test/python/server/validate_bearer.py  |  101 ++
 src/test/python/server/validate_reflect.py |   34 +
 src/test/python/test_internals.py          |  138 +++
 src/test/python/test_pq3.py                |  558 +++++++++++
 src/test/python/tls.py                     |  195 ++++
 19 files changed, 4178 insertions(+)
 create mode 100644 src/test/python/.gitignore
 create mode 100644 src/test/python/Makefile
 create mode 100644 src/test/python/README
 create mode 100644 src/test/python/client/__init__.py
 create mode 100644 src/test/python/client/conftest.py
 create mode 100644 src/test/python/client/test_client.py
 create mode 100644 src/test/python/client/test_oauth.py
 create mode 100644 src/test/python/pq3.py
 create mode 100644 src/test/python/pytest.ini
 create mode 100644 src/test/python/requirements.txt
 create mode 100644 src/test/python/server/__init__.py
 create mode 100644 src/test/python/server/conftest.py
 create mode 100644 src/test/python/server/test_oauth.py
 create mode 100644 src/test/python/server/test_server.py
 create mode 100755 src/test/python/server/validate_bearer.py
 create mode 100755 src/test/python/server/validate_reflect.py
 create mode 100644 src/test/python/test_internals.py
 create mode 100644 src/test/python/test_pq3.py
 create mode 100644 src/test/python/tls.py

diff --git a/src/test/python/.gitignore b/src/test/python/.gitignore
new file mode 100644
index 0000000000..0e8f027b2e
--- /dev/null
+++ b/src/test/python/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+/venv/
diff --git a/src/test/python/Makefile b/src/test/python/Makefile
new file mode 100644
index 0000000000..b0695b6287
--- /dev/null
+++ b/src/test/python/Makefile
@@ -0,0 +1,38 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+# Only Python 3 is supported, but if it's named something different on your
+# system you can override it with the PYTHON3 variable.
+PYTHON3 := python3
+
+# All dependencies are placed into this directory. The default is .gitignored
+# for you, but you can override it if you'd like.
+VENV := ./venv
+
+override VBIN   := $(VENV)/bin
+override PIP    := $(VBIN)/pip
+override PYTEST := $(VBIN)/py.test
+override ISORT  := $(VBIN)/isort
+override BLACK  := $(VBIN)/black
+
+.PHONY: installcheck indent
+
+installcheck: $(PYTEST)
+	$(PYTEST) -v -rs
+
+indent: $(ISORT) $(BLACK)
+	$(ISORT) --profile black *.py client/*.py server/*.py
+	$(BLACK) *.py client/*.py server/*.py
+
+$(PYTEST) $(ISORT) $(BLACK) &: requirements.txt | $(PIP)
+	$(PIP) install --force-reinstall -r $<
+
+$(PIP):
+	$(PYTHON3) -m venv $(VENV)
+
+# A convenience recipe to rebuild psycopg2 against the local libpq.
+.PHONY: rebuild-psycopg2
+rebuild-psycopg2: | $(PIP)
+	$(PIP) install --force-reinstall --no-binary :all: $(shell grep psycopg2 requirements.txt)
diff --git a/src/test/python/README b/src/test/python/README
new file mode 100644
index 0000000000..0bda582c4b
--- /dev/null
+++ b/src/test/python/README
@@ -0,0 +1,54 @@
+A test suite for exercising both the libpq client and the server backend at the
+protocol level, based on pytest and Construct.
+
+The test suite currently assumes that the standard PG* environment variables
+point to the database under test and are sufficient to log in a superuser on
+that system. In other words, a bare `psql` needs to Just Work before the test
+suite can do its thing. For a newly built dev cluster, typically all that I need
+to do is a
+
+    export PGDATABASE=postgres
+
+but you can adjust as needed for your setup.
+
+## Requirements
+
+A supported version (3.6+) of Python.
+
+The first run of
+
+    make installcheck
+
+will install a local virtual environment and all needed dependencies. During
+development, if libpq changes incompatibly, you can issue
+
+    $ make rebuild-psycopg2
+
+to force a rebuild of the client library.
+
+## Hacking
+
+The code style is enforced by a _very_ opinionated autoformatter. Running the
+
+    make indent
+
+recipe will invoke it for you automatically. Don't fight the tool; part of the
+zen is in knowing that if the formatter makes your code ugly, there's probably a
+cleaner way to write your code.
+
+## Advanced Usage
+
+The Makefile is there for convenience, but you don't have to use it. Activate
+the virtualenv to be able to use pytest directly:
+
+    $ source venv/bin/activate
+    $ py.test -k oauth
+    ...
+    $ py.test ./server/test_server.py
+    ...
+    $ deactivate  # puts the PATH et al back the way it was before
+
+To make quick smoke tests possible, slow tests have been marked explicitly. You
+can skip them by saying e.g.
+
+    $ py.test -m 'not slow'
diff --git a/src/test/python/client/__init__.py b/src/test/python/client/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/client/conftest.py b/src/test/python/client/conftest.py
new file mode 100644
index 0000000000..f38da7a138
--- /dev/null
+++ b/src/test/python/client/conftest.py
@@ -0,0 +1,126 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import socket
+import sys
+import threading
+
+import psycopg2
+import pytest
+
+import pq3
+
+BLOCKING_TIMEOUT = 2  # the number of seconds to wait for blocking calls
+
+
+@pytest.fixture
+def server_socket(unused_tcp_port_factory):
+    """
+    Returns a listening socket bound to an ephemeral port.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("127.0.0.1", unused_tcp_port_factory()))
+        s.listen(1)
+        s.settimeout(BLOCKING_TIMEOUT)
+        yield s
+
+
+class ClientHandshake(threading.Thread):
+    """
+    A thread that connects to a local Postgres server using psycopg2. Once the
+    opening handshake completes, the connection will be immediately closed.
+    """
+
+    def __init__(self, *, port, **kwargs):
+        super().__init__()
+
+        kwargs["port"] = port
+        self._kwargs = kwargs
+
+        self.exception = None
+
+    def run(self):
+        try:
+            conn = psycopg2.connect(host="127.0.0.1", **self._kwargs)
+            conn.close()
+        except Exception as e:
+            self.exception = e
+
+    def check_completed(self, timeout=BLOCKING_TIMEOUT):
+        """
+        Joins the client thread. Raises an exception if the thread could not be
+        joined, or if it threw an exception itself. (The exception will be
+        cleared, so future calls to check_completed will succeed.)
+        """
+        self.join(timeout)
+
+        if self.is_alive():
+            raise TimeoutError("client thread did not handshake within the timeout")
+        elif self.exception:
+            e = self.exception
+            self.exception = None
+            raise e
+
+
+@pytest.fixture
+def accept(server_socket):
+    """
+    Returns a factory function that, when called, returns a pair (sock, client)
+    where sock is a server socket that has accepted a connection from client,
+    and client is an instance of ClientHandshake. Clients will complete their
+    handshakes and cleanly disconnect.
+
+    The default connstring options may be extended or overridden by passing
+    arbitrary keyword arguments. Keep in mind that you generally should not
+    override the host or port, since they point to the local test server.
+
+    For situations where a client needs to connect more than once to complete a
+    handshake, the accept function may be called more than once. (The client
+    returned for subsequent calls will always be the same client that was
+    returned for the first call.)
+
+    Tests must either complete the handshake so that the client thread can be
+    automatically joined during teardown, or else call client.check_completed()
+    and manually handle any expected errors.
+    """
+    _, port = server_socket.getsockname()
+
+    client = None
+    default_opts = dict(
+        port=port,
+        user=pq3.pguser(),
+        sslmode="disable",
+    )
+
+    def factory(**kwargs):
+        nonlocal client
+
+        if client is None:
+            opts = dict(default_opts)
+            opts.update(kwargs)
+
+            # The server_socket is already listening, so the client thread can
+            # be safely started; it'll block on the connection until we accept.
+            client = ClientHandshake(**opts)
+            client.start()
+
+        sock, _ = server_socket.accept()
+        return sock, client
+
+    yield factory
+    client.check_completed()
+
+
+@pytest.fixture
+def conn(accept):
+    """
+    Returns an accepted, wrapped pq3 connection to a psycopg2 client. The socket
+    will be closed when the test finishes, and the client will be checked for a
+    cleanly completed handshake.
+    """
+    sock, client = accept()
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            yield conn
diff --git a/src/test/python/client/test_client.py b/src/test/python/client/test_client.py
new file mode 100644
index 0000000000..c4c946fda4
--- /dev/null
+++ b/src/test/python/client/test_client.py
@@ -0,0 +1,180 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import sys
+
+import psycopg2
+import pytest
+from cryptography.hazmat.primitives import hashes, hmac
+
+import pq3
+
+
+def finish_handshake(conn):
+    """
+    Sends the AuthenticationOK message and the standard opening salvo of server
+    messages, then asserts that the client immediately sends a Terminate message
+    to close the connection cleanly.
+    """
+    pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+    pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+    pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+    pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+    pkt = pq3.recv1(conn)
+    assert pkt.type == pq3.types.Terminate
+
+
+def test_handshake(conn):
+    startup = pq3.recv1(conn, cls=pq3.Startup)
+    assert startup.proto == pq3.protocol(3, 0)
+
+    finish_handshake(conn)
+
+
+def test_aborted_connection(accept):
+    """
+    Make sure the client correctly reports an early close during handshakes.
+    """
+    sock, client = accept()
+    sock.close()
+
+    expected = "server closed the connection unexpectedly"
+    with pytest.raises(psycopg2.OperationalError, match=expected):
+        client.check_completed()
+
+
+#
+# SCRAM-SHA-256 (see RFC 5802: https://tools.ietf.org/html/rfc5802)
+#
+
+
+@pytest.fixture
+def password():
+    """
+    Returns a password for use by both client and server.
+    """
+    # TODO: parameterize this with passwords that require SASLprep.
+    return "secret"
+
+
+@pytest.fixture
+def pwconn(accept, password):
+    """
+    Like the conn fixture, but uses a password in the connection.
+    """
+    sock, client = accept(password=password)
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            yield conn
+
+
+def sha256(data):
+    """The H(str) function from Section 2.2."""
+    digest = hashes.Hash(hashes.SHA256())
+    digest.update(data)
+    return digest.finalize()
+
+
+def hmac_256(key, data):
+    """The HMAC(key, str) function from Section 2.2."""
+    h = hmac.HMAC(key, hashes.SHA256())
+    h.update(data)
+    return h.finalize()
+
+
+def xor(a, b):
+    """The XOR operation from Section 2.2."""
+    res = bytearray(a)
+    for i, byte in enumerate(b):
+        res[i] ^= byte
+    return bytes(res)
+
+
+def h_i(data, salt, i):
+    """The Hi(str, salt, i) function from Section 2.2."""
+    assert i > 0
+
+    acc = hmac_256(data, salt + b"\x00\x00\x00\x01")
+    last = acc
+    i -= 1
+
+    while i:
+        u = hmac_256(data, last)
+        acc = xor(acc, u)
+
+        last = u
+        i -= 1
+
+    return acc
+
+
+def test_scram(pwconn, password):
+    startup = pq3.recv1(pwconn, cls=pq3.Startup)
+    assert startup.proto == pq3.protocol(3, 0)
+
+    pq3.send(
+        pwconn,
+        pq3.types.AuthnRequest,
+        type=pq3.authn.SASL,
+        body=[b"SCRAM-SHA-256", b""],
+    )
+
+    # Get the client-first-message.
+    pkt = pq3.recv1(pwconn)
+    assert pkt.type == pq3.types.PasswordMessage
+
+    initial = pq3.SASLInitialResponse.parse(pkt.payload)
+    assert initial.name == b"SCRAM-SHA-256"
+
+    c_bind, authzid, c_name, c_nonce = initial.data.split(b",")
+    assert c_bind == b"n"  # no channel bindings on a plaintext connection
+    assert authzid == b""  # we don't support authzid currently
+    assert c_name == b"n="  # libpq doesn't honor the GS2 username
+    assert c_nonce.startswith(b"r=")
+
+    # Send the server-first-message.
+    salt = b"12345"
+    iterations = 2
+
+    s_nonce = c_nonce + b"somenonce"
+    s_salt = b"s=" + base64.b64encode(salt)
+    s_iterations = b"i=%d" % iterations
+
+    msg = b",".join([s_nonce, s_salt, s_iterations])
+    pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=msg)
+
+    # Get the client-final-message.
+    pkt = pq3.recv1(pwconn)
+    assert pkt.type == pq3.types.PasswordMessage
+
+    c_bind_final, c_nonce_final, c_proof = pkt.payload.split(b",")
+    assert c_bind_final == b"c=" + base64.b64encode(c_bind + b"," + authzid + b",")
+    assert c_nonce_final == s_nonce
+
+    # Calculate what the client proof should be.
+    salted_password = h_i(password.encode("ascii"), salt, iterations)
+    client_key = hmac_256(salted_password, b"Client Key")
+    stored_key = sha256(client_key)
+
+    auth_message = b",".join(
+        [c_name, c_nonce, s_nonce, s_salt, s_iterations, c_bind_final, c_nonce_final]
+    )
+    client_signature = hmac_256(stored_key, auth_message)
+    client_proof = xor(client_key, client_signature)
+
+    expected = b"p=" + base64.b64encode(client_proof)
+    assert c_proof == expected
+
+    # Send the correct server signature.
+    server_key = hmac_256(salted_password, b"Server Key")
+    server_signature = hmac_256(server_key, auth_message)
+
+    s_verify = b"v=" + base64.b64encode(server_signature)
+    pq3.send(pwconn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal, body=s_verify)
+
+    # Done!
+    finish_handshake(pwconn)
diff --git a/src/test/python/client/test_oauth.py b/src/test/python/client/test_oauth.py
new file mode 100644
index 0000000000..a754a9c0b6
--- /dev/null
+++ b/src/test/python/client/test_oauth.py
@@ -0,0 +1,936 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import http.server
+import json
+import secrets
+import sys
+import threading
+import time
+import urllib.parse
+
+import psycopg2
+import pytest
+
+import pq3
+
+from .conftest import BLOCKING_TIMEOUT
+
+
+def finish_handshake(conn):
+    """
+    Sends the AuthenticationOK message and the standard opening salvo of server
+    messages, then asserts that the client immediately sends a Terminate message
+    to close the connection cleanly.
+    """
+    pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.OK)
+    pq3.send(conn, pq3.types.ParameterStatus, name=b"client_encoding", value=b"UTF-8")
+    pq3.send(conn, pq3.types.ParameterStatus, name=b"DateStyle", value=b"ISO, MDY")
+    pq3.send(conn, pq3.types.ReadyForQuery, status=b"I")
+
+    pkt = pq3.recv1(conn)
+    assert pkt.type == pq3.types.Terminate
+
+
+#
+# OAUTHBEARER (see RFC 7628: https://tools.ietf.org/html/rfc7628)
+#
+
+
+def start_oauth_handshake(conn):
+    """
+    Negotiates an OAUTHBEARER SASL challenge. Returns the client's initial
+    response data.
+    """
+    startup = pq3.recv1(conn, cls=pq3.Startup)
+    assert startup.proto == pq3.protocol(3, 0)
+
+    pq3.send(
+        conn, pq3.types.AuthnRequest, type=pq3.authn.SASL, body=[b"OAUTHBEARER", b""]
+    )
+
+    pkt = pq3.recv1(conn)
+    assert pkt.type == pq3.types.PasswordMessage
+
+    initial = pq3.SASLInitialResponse.parse(pkt.payload)
+    assert initial.name == b"OAUTHBEARER"
+
+    return initial.data
+
+
+def get_auth_value(initial):
+    """
+    Finds the auth value (e.g. "Bearer somedata..." in the client's initial SASL
+    response.
+    """
+    kvpairs = initial.split(b"\x01")
+    assert kvpairs[0] == b"n,,"  # no channel binding or authzid
+    assert kvpairs[2] == b""  # ends with an empty kvpair
+    assert kvpairs[3] == b""  # ...and there's nothing after it
+    assert len(kvpairs) == 4
+
+    key, value = kvpairs[1].split(b"=", 2)
+    assert key == b"auth"
+
+    return value
+
+
+def xtest_oauth_success(conn):  # TODO
+    initial = start_oauth_handshake(conn)
+
+    auth = get_auth_value(initial)
+    assert auth.startswith(b"Bearer ")
+
+    # Accept the token. TODO actually validate
+    pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+    finish_handshake(conn)
+
+
+class OpenIDProvider(threading.Thread):
+    """
+    A thread that runs a mock OpenID provider server.
+    """
+
+    def __init__(self, *, port):
+        super().__init__()
+
+        self.exception = None
+
+        addr = ("", port)
+        self.server = self._Server(addr, self._Handler)
+
+        # TODO: allow HTTPS only, somehow
+        oauth = self._OAuthState()
+        oauth.host = f"localhost:{port}"
+        oauth.issuer = f"http://localhost:{port}"
+
+        # The following endpoints are required to be advertised by providers,
+        # even though our chosen client implementation does not actually make
+        # use of them.
+        oauth.register_endpoint(
+            "authorization_endpoint", "POST", "/authorize", self._authorization_handler
+        )
+        oauth.register_endpoint("jwks_uri", "GET", "/keys", self._jwks_handler)
+
+        self.server.oauth = oauth
+
+    def run(self):
+        try:
+            self.server.serve_forever()
+        except Exception as e:
+            self.exception = e
+
+    def stop(self, timeout=BLOCKING_TIMEOUT):
+        """
+        Shuts down the server and joins its thread. Raises an exception if the
+        thread could not be joined, or if it threw an exception itself. Must
+        only be called once, after start().
+        """
+        self.server.shutdown()
+        self.join(timeout)
+
+        if self.is_alive():
+            raise TimeoutError("client thread did not handshake within the timeout")
+        elif self.exception:
+            e = self.exception
+            raise e
+
+    class _OAuthState(object):
+        def __init__(self):
+            self.endpoint_paths = {}
+            self._endpoints = {}
+
+        def register_endpoint(self, name, method, path, func):
+            if method not in self._endpoints:
+                self._endpoints[method] = {}
+
+            self._endpoints[method][path] = func
+            self.endpoint_paths[name] = path
+
+        def endpoint(self, method, path):
+            if method not in self._endpoints:
+                return None
+
+            return self._endpoints[method].get(path)
+
+    class _Server(http.server.HTTPServer):
+        def handle_error(self, request, addr):
+            self.shutdown_request(request)
+            raise
+
+    @staticmethod
+    def _jwks_handler(headers, params):
+        return 200, {"keys": []}
+
+    @staticmethod
+    def _authorization_handler(headers, params):
+        # We don't actually want this to be called during these tests -- we
+        # should be using the device authorization endpoint instead.
+        assert (
+            False
+        ), "authorization handler called instead of device authorization handler"
+
+    class _Handler(http.server.BaseHTTPRequestHandler):
+        timeout = BLOCKING_TIMEOUT
+
+        def _discovery_handler(self, headers, params):
+            oauth = self.server.oauth
+
+            doc = {
+                "issuer": oauth.issuer,
+                "response_types_supported": ["token"],
+                "subject_types_supported": ["public"],
+                "id_token_signing_alg_values_supported": ["RS256"],
+            }
+
+            for name, path in oauth.endpoint_paths.items():
+                doc[name] = oauth.issuer + path
+
+            return 200, doc
+
+        def _handle(self, *, params=None, handler=None):
+            oauth = self.server.oauth
+            assert self.headers["Host"] == oauth.host
+
+            if handler is None:
+                handler = oauth.endpoint(self.command, self.path)
+                assert (
+                    handler is not None
+                ), f"no registered endpoint for {self.command} {self.path}"
+
+            code, resp = handler(self.headers, params)
+
+            self.send_response(code)
+            self.send_header("Content-Type", "application/json")
+            self.end_headers()
+
+            resp = json.dumps(resp)
+            resp = resp.encode("utf-8")
+            self.wfile.write(resp)
+
+            self.close_connection = True
+
+        def do_GET(self):
+            if self.path == "/.well-known/openid-configuration":
+                self._handle(handler=self._discovery_handler)
+                return
+
+            self._handle()
+
+        def _request_body(self):
+            length = self.headers["Content-Length"]
+
+            # Handle only an explicit content-length.
+            assert length is not None
+            length = int(length)
+
+            return self.rfile.read(length).decode("utf-8")
+
+        def do_POST(self):
+            assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
+
+            body = self._request_body()
+            params = urllib.parse.parse_qs(body)
+
+            self._handle(params=params)
+
+
+@pytest.fixture
+def openid_provider(unused_tcp_port_factory):
+    """
+    A fixture that returns the OAuth state of a running OpenID provider server. The
+    server will be stopped when the fixture is torn down.
+    """
+    thread = OpenIDProvider(port=unused_tcp_port_factory())
+    thread.start()
+
+    try:
+        yield thread.server.oauth
+    finally:
+        thread.stop()
+
+
+@pytest.mark.parametrize("secret", [None, "", "hunter2"])
+@pytest.mark.parametrize("scope", [None, "", "openid email"])
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_with_explicit_issuer(
+    capfd, accept, openid_provider, retries, scope, secret
+):
+    client_id = secrets.token_hex()
+
+    sock, client = accept(
+        oauth_issuer=openid_provider.issuer,
+        oauth_client_id=client_id,
+        oauth_client_secret=secret,
+        oauth_scope=scope,
+    )
+
+    device_code = secrets.token_hex()
+    user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+    verification_url = "https://example.com/device"
+
+    access_token = secrets.token_urlsafe()
+
+    def check_client_authn(headers, params):
+        if not secret:
+            assert params["client_id"] == [client_id]
+            return
+
+        # Require the client to use Basic authn; request-body credentials are
+        # NOT RECOMMENDED (RFC 6749, Sec. 2.3.1).
+        assert "Authorization" in headers
+
+        method, creds = headers["Authorization"].split()
+        assert method == "Basic"
+
+        expected = f"{client_id}:{secret}"
+        assert base64.b64decode(creds) == expected.encode("ascii")
+
+    # Set up our provider callbacks.
+    # NOTE that these callbacks will be called on a background thread. Don't do
+    # any unprotected state mutation here.
+
+    def authorization_endpoint(headers, params):
+        check_client_authn(headers, params)
+
+        if scope:
+            assert params["scope"] == [scope]
+        else:
+            assert "scope" not in params
+
+        resp = {
+            "device_code": device_code,
+            "user_code": user_code,
+            "interval": 0,
+            "verification_uri": verification_url,
+            "expires_in": 5,
+        }
+
+        return 200, resp
+
+    openid_provider.register_endpoint(
+        "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+    )
+
+    attempts = 0
+    retry_lock = threading.Lock()
+
+    def token_endpoint(headers, params):
+        check_client_authn(headers, params)
+
+        assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+        assert params["device_code"] == [device_code]
+
+        now = time.monotonic()
+
+        with retry_lock:
+            nonlocal attempts
+
+            # If the test wants to force the client to retry, return an
+            # authorization_pending response and decrement the retry count.
+            if attempts < retries:
+                attempts += 1
+                return 400, {"error": "authorization_pending"}
+
+        # Successfully finish the request by sending the access bearer token.
+        resp = {
+            "access_token": access_token,
+            "token_type": "bearer",
+        }
+
+        return 200, resp
+
+    openid_provider.register_endpoint(
+        "token_endpoint", "POST", "/token", token_endpoint
+    )
+
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            # Initiate a handshake, which should result in the above endpoints
+            # being called.
+            initial = start_oauth_handshake(conn)
+
+            # Validate and accept the token.
+            auth = get_auth_value(initial)
+            assert auth == f"Bearer {access_token}".encode("ascii")
+
+            pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+            finish_handshake(conn)
+
+    if retries:
+        # Finally, make sure that the client prompted the user with the expected
+        # authorization URL and user code.
+        expected = f"Visit {verification_url} and enter the code: {user_code}"
+        _, stderr = capfd.readouterr()
+        assert expected in stderr
+
+
+def test_oauth_requires_client_id(accept, openid_provider):
+    sock, client = accept(
+        oauth_issuer=openid_provider.issuer,
+        # Do not set a client ID; this should cause a client error after the
+        # server asks for OAUTHBEARER and the client tries to contact the
+        # issuer.
+    )
+
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            # Initiate a handshake.
+            startup = pq3.recv1(conn, cls=pq3.Startup)
+            assert startup.proto == pq3.protocol(3, 0)
+
+            pq3.send(
+                conn,
+                pq3.types.AuthnRequest,
+                type=pq3.authn.SASL,
+                body=[b"OAUTHBEARER", b""],
+            )
+
+            # The client should disconnect at this point.
+            assert not conn.read()
+
+    expected_error = "no oauth_client_id is set"
+    with pytest.raises(psycopg2.OperationalError, match=expected_error):
+        client.check_completed()
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("error_code", ["authorization_pending", "slow_down"])
+@pytest.mark.parametrize("retries", [1, 2])
+def test_oauth_retry_interval(accept, openid_provider, retries, error_code):
+    sock, client = accept(
+        oauth_issuer=openid_provider.issuer,
+        oauth_client_id="some-id",
+    )
+
+    expected_retry_interval = 1
+    access_token = secrets.token_urlsafe()
+
+    # Set up our provider callbacks.
+    # NOTE that these callbacks will be called on a background thread. Don't do
+    # any unprotected state mutation here.
+
+    def authorization_endpoint(headers, params):
+        resp = {
+            "device_code": "my-device-code",
+            "user_code": "my-user-code",
+            "interval": expected_retry_interval,
+            "verification_uri": "https://example.com",
+            "expires_in": 5,
+        }
+
+        return 200, resp
+
+    openid_provider.register_endpoint(
+        "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+    )
+
+    attempts = 0
+    last_retry = None
+    retry_lock = threading.Lock()
+
+    def token_endpoint(headers, params):
+        now = time.monotonic()
+
+        with retry_lock:
+            nonlocal attempts, last_retry, expected_retry_interval
+
+            # Make sure the retry interval is being respected by the client.
+            if last_retry is not None:
+                interval = now - last_retry
+                assert interval >= expected_retry_interval
+
+            last_retry = now
+
+            # If the test wants to force the client to retry, return the desired
+            # error response and decrement the retry count.
+            if attempts < retries:
+                attempts += 1
+
+                # A slow_down code requires the client to additionally increase
+                # its interval by five seconds.
+                if error_code == "slow_down":
+                    expected_retry_interval += 5
+
+                return 400, {"error": error_code}
+
+        # Successfully finish the request by sending the access bearer token.
+        resp = {
+            "access_token": access_token,
+            "token_type": "bearer",
+        }
+
+        return 200, resp
+
+    openid_provider.register_endpoint(
+        "token_endpoint", "POST", "/token", token_endpoint
+    )
+
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            # Initiate a handshake, which should result in the above endpoints
+            # being called.
+            initial = start_oauth_handshake(conn)
+
+            # Validate and accept the token.
+            auth = get_auth_value(initial)
+            assert auth == f"Bearer {access_token}".encode("ascii")
+
+            pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+            finish_handshake(conn)
+
+
+@pytest.mark.parametrize(
+    "failure_mode, error_pattern",
+    [
+        pytest.param(
+            {
+                "error": "invalid_client",
+                "error_description": "client authentication failed",
+            },
+            r"client authentication failed \(invalid_client\)",
+            id="authentication failure with description",
+        ),
+        pytest.param(
+            {"error": "invalid_request"},
+            r"\(invalid_request\)",
+            id="invalid request without description",
+        ),
+        pytest.param(
+            {},
+            r"failed to obtain device authorization",
+            id="broken error response",
+        ),
+    ],
+)
+def test_oauth_device_authorization_failures(
+    accept, openid_provider, failure_mode, error_pattern
+):
+    client_id = secrets.token_hex()
+
+    sock, client = accept(
+        oauth_issuer=openid_provider.issuer,
+        oauth_client_id=client_id,
+    )
+
+    # Set up our provider callbacks.
+    # NOTE that these callbacks will be called on a background thread. Don't do
+    # any unprotected state mutation here.
+
+    def authorization_endpoint(headers, params):
+        return 400, failure_mode
+
+    openid_provider.register_endpoint(
+        "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+    )
+
+    def token_endpoint(headers, params):
+        assert False, "token endpoint was invoked unexpectedly"
+
+    openid_provider.register_endpoint(
+        "token_endpoint", "POST", "/token", token_endpoint
+    )
+
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            # Initiate a handshake, which should result in the above endpoints
+            # being called.
+            startup = pq3.recv1(conn, cls=pq3.Startup)
+            assert startup.proto == pq3.protocol(3, 0)
+
+            pq3.send(
+                conn,
+                pq3.types.AuthnRequest,
+                type=pq3.authn.SASL,
+                body=[b"OAUTHBEARER", b""],
+            )
+
+            # The client should not continue the connection due to the hardcoded
+            # provider failure; we disconnect here.
+
+    # Now make sure the client correctly failed.
+    with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+        client.check_completed()
+
+
+@pytest.mark.parametrize(
+    "failure_mode, error_pattern",
+    [
+        pytest.param(
+            {
+                "error": "expired_token",
+                "error_description": "the device code has expired",
+            },
+            r"the device code has expired \(expired_token\)",
+            id="expired token with description",
+        ),
+        pytest.param(
+            {"error": "access_denied"},
+            r"\(access_denied\)",
+            id="access denied without description",
+        ),
+        pytest.param(
+            {},
+            r"OAuth token retrieval failed",
+            id="broken error response",
+        ),
+    ],
+)
+@pytest.mark.parametrize("retries", [0, 1])
+def test_oauth_token_failures(
+    accept, openid_provider, retries, failure_mode, error_pattern
+):
+    """
+    The client must report token endpoint failures, whether they happen on the
+    first poll or after a retry forced via an authorization_pending response.
+    """
+    client_id = secrets.token_hex()
+
+    sock, client = accept(
+        oauth_issuer=openid_provider.issuer,
+        oauth_client_id=client_id,
+    )
+
+    device_code = secrets.token_hex()
+    user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+
+    # Set up our provider callbacks.
+    # NOTE that these callbacks will be called on a background thread. Don't do
+    # any unprotected state mutation here.
+
+    def authorization_endpoint(headers, params):
+        # Succeed the device authorization step so that the client moves on to
+        # polling the (failing) token endpoint.
+        assert params["client_id"] == [client_id]
+
+        resp = {
+            "device_code": device_code,
+            "user_code": user_code,
+            "interval": 0,
+            "verification_uri": "https://example.com/device",
+            "expires_in": 5,
+        }
+
+        return 200, resp
+
+    openid_provider.register_endpoint(
+        "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+    )
+
+    # Guards the shared `retries` counter, which is decremented from the
+    # provider's background thread.
+    retry_lock = threading.Lock()
+
+    def token_endpoint(headers, params):
+        with retry_lock:
+            nonlocal retries
+
+            # If the test wants to force the client to retry, return an
+            # authorization_pending response and decrement the retry count.
+            if retries > 0:
+                retries -= 1
+                return 400, {"error": "authorization_pending"}
+
+        # Out of retries: return the parametrized failure.
+        return 400, failure_mode
+
+    openid_provider.register_endpoint(
+        "token_endpoint", "POST", "/token", token_endpoint
+    )
+
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            # Initiate a handshake, which should result in the above endpoints
+            # being called.
+            startup = pq3.recv1(conn, cls=pq3.Startup)
+            assert startup.proto == pq3.protocol(3, 0)
+
+            pq3.send(
+                conn,
+                pq3.types.AuthnRequest,
+                type=pq3.authn.SASL,
+                body=[b"OAUTHBEARER", b""],
+            )
+
+            # The client should not continue the connection due to the hardcoded
+            # provider failure; we disconnect here.
+
+    # Now make sure the client correctly failed.
+    with pytest.raises(psycopg2.OperationalError, match=error_pattern):
+        client.check_completed()
+
+
+@pytest.mark.parametrize("scope", [None, "openid email"])
+@pytest.mark.parametrize(
+    "base_response",
+    [
+        {"status": "invalid_token"},
+        {"extra_object": {"key": "value"}, "status": "invalid_token"},
+        {"extra_object": {"status": 1}, "status": "invalid_token"},
+    ],
+)
+def test_oauth_discovery(accept, openid_provider, base_response, scope):
+    """
+    No oauth_issuer is passed to accept() here, so the client must follow the
+    openid-configuration link (and optional scope) sent in the SASL error
+    response, run the device flow against the discovered provider, and then
+    reconnect with the obtained bearer token.
+    """
+    sock, client = accept(oauth_client_id=secrets.token_hex())
+
+    device_code = secrets.token_hex()
+    user_code = f"{secrets.token_hex(2)}-{secrets.token_hex(2)}"
+    verification_url = "https://example.com/device"
+
+    access_token = secrets.token_urlsafe()
+
+    # Set up our provider callbacks.
+    # NOTE that these callbacks will be called on a background thread. Don't do
+    # any unprotected state mutation here.
+
+    def authorization_endpoint(headers, params):
+        # The client must forward the scope we advertised (if any).
+        if scope:
+            assert params["scope"] == [scope]
+        else:
+            assert "scope" not in params
+
+        resp = {
+            "device_code": device_code,
+            "user_code": user_code,
+            "interval": 0,
+            "verification_uri": verification_url,
+            "expires_in": 5,
+        }
+
+        return 200, resp
+
+    openid_provider.register_endpoint(
+        "device_authorization_endpoint", "POST", "/device", authorization_endpoint
+    )
+
+    def token_endpoint(headers, params):
+        assert params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"]
+        assert params["device_code"] == [device_code]
+
+        # Successfully finish the request by sending the access bearer token.
+        resp = {
+            "access_token": access_token,
+            "token_type": "bearer",
+        }
+
+        return 200, resp
+
+    openid_provider.register_endpoint(
+        "token_endpoint", "POST", "/token", token_endpoint
+    )
+
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            initial = start_oauth_handshake(conn)
+
+            # For discovery, the client should send an empty auth header. See
+            # RFC 7628, Sec. 4.3.
+            auth = get_auth_value(initial)
+            assert auth == b""
+
+            # We will fail the first SASL exchange. First return a link to the
+            # discovery document, pointing to the test provider server.
+            resp = dict(base_response)
+
+            discovery_uri = f"{openid_provider.issuer}/.well-known/openid-configuration"
+            resp["openid-configuration"] = discovery_uri
+
+            if scope:
+                resp["scope"] = scope
+
+            resp = json.dumps(resp)
+
+            pq3.send(
+                conn,
+                pq3.types.AuthnRequest,
+                type=pq3.authn.SASLContinue,
+                body=resp.encode("ascii"),
+            )
+
+            # Per RFC, the client is required to send a dummy ^A response.
+            pkt = pq3.recv1(conn)
+            assert pkt.type == pq3.types.PasswordMessage
+            assert pkt.payload == b"\x01"
+
+            # Now fail the SASL exchange.
+            pq3.send(
+                conn,
+                pq3.types.ErrorResponse,
+                fields=[
+                    b"SFATAL",
+                    b"C28000",
+                    b"Mdoesn't matter",
+                    b"",
+                ],
+            )
+
+    # The client will connect to us a second time, using the parameters we sent
+    # it.
+    sock, _ = accept()
+
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            initial = start_oauth_handshake(conn)
+
+            # Validate and accept the token.
+            auth = get_auth_value(initial)
+            assert auth == f"Bearer {access_token}".encode("ascii")
+
+            pq3.send(conn, pq3.types.AuthnRequest, type=pq3.authn.SASLFinal)
+            finish_handshake(conn)
+
+
+@pytest.mark.parametrize(
+    "response,expected_error",
+    [
+        pytest.param(
+            "abcde",
+            'Token "abcde" is invalid',
+            id="bad JSON: invalid syntax",
+        ),
+        pytest.param(
+            '"abcde"',
+            "top-level element must be an object",
+            id="bad JSON: top-level element is a string",
+        ),
+        pytest.param(
+            "[]",
+            "top-level element must be an object",
+            id="bad JSON: top-level element is an array",
+        ),
+        pytest.param(
+            "{}",
+            "server sent error response without a status",
+            id="bad JSON: no status member",
+        ),
+        pytest.param(
+            '{ "status": null }',
+            'field "status" must be a string',
+            id="bad JSON: null status member",
+        ),
+        pytest.param(
+            '{ "status": 0 }',
+            'field "status" must be a string',
+            id="bad JSON: int status member",
+        ),
+        pytest.param(
+            '{ "status": [ "bad" ] }',
+            'field "status" must be a string',
+            id="bad JSON: array status member",
+        ),
+        pytest.param(
+            '{ "status": { "bad": "bad" } }',
+            'field "status" must be a string',
+            id="bad JSON: object status member",
+        ),
+        pytest.param(
+            '{ "nested": { "status": "bad" } }',
+            "server sent error response without a status",
+            id="bad JSON: nested status",
+        ),
+        pytest.param(
+            '{ "status": "invalid_token" ',
+            "The input string ended unexpectedly",
+            id="bad JSON: unterminated object",
+        ),
+        pytest.param(
+            '{ "status": "invalid_token" } { }',
+            'Expected end of input, but found "{"',
+            id="bad JSON: trailing data",
+        ),
+        pytest.param(
+            '{ "status": "invalid_token", "openid-configuration": 1 }',
+            'field "openid-configuration" must be a string',
+            id="bad JSON: int openid-configuration member",
+        ),
+        pytest.param(
+            '{ "status": "invalid_token", "openid-configuration": 1 }',
+            'field "openid-configuration" must be a string',
+            id="bad JSON: int openid-configuration member",
+        ),
+        pytest.param(
+            '{ "status": "invalid_token", "scope": 1 }',
+            'field "scope" must be a string',
+            id="bad JSON: int scope member",
+        ),
+    ],
+)
+def test_oauth_discovery_server_error(accept, response, expected_error):
+    sock, client = accept(oauth_client_id=secrets.token_hex())
+
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            initial = start_oauth_handshake(conn)
+
+            # Fail the SASL exchange with an invalid JSON response.
+            pq3.send(
+                conn,
+                pq3.types.AuthnRequest,
+                type=pq3.authn.SASLContinue,
+                body=response.encode("utf-8"),
+            )
+
+            # The client should disconnect, so the socket is closed here. (If
+            # the client doesn't disconnect, it will report a different error
+            # below and the test will fail.)
+
+    with pytest.raises(psycopg2.OperationalError, match=expected_error):
+        client.check_completed()
+
+
+@pytest.mark.parametrize(
+    "sasl_err,resp_type,resp_payload,expected_error",
+    [
+        pytest.param(
+            {"status": "invalid_request"},
+            pq3.types.ErrorResponse,
+            dict(
+                fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+            ),
+            "expected error message",
+            id="standard server error: invalid_request",
+        ),
+        pytest.param(
+            {"status": "invalid_token"},
+            pq3.types.ErrorResponse,
+            dict(
+                fields=[b"SFATAL", b"C28000", b"Mexpected error message", b""],
+            ),
+            "expected error message",
+            id="standard server error: invalid_token without discovery URI",
+        ),
+        pytest.param(
+            {"status": "invalid_request"},
+            pq3.types.AuthnRequest,
+            dict(type=pq3.authn.SASLContinue, body=b""),
+            "server sent additional OAuth data",
+            id="broken server: additional challenge after error",
+        ),
+        pytest.param(
+            {"status": "invalid_request"},
+            pq3.types.AuthnRequest,
+            dict(type=pq3.authn.SASLFinal),
+            "server sent additional OAuth data",
+            id="broken server: SASL success after error",
+        ),
+    ],
+)
+def test_oauth_server_error(accept, sasl_err, resp_type, resp_payload, expected_error):
+    sock, client = accept()
+
+    with sock:
+        with pq3.wrap(sock, debug_stream=sys.stdout) as conn:
+            start_oauth_handshake(conn)
+
+            # Ignore the client data. Return an error "challenge".
+            resp = json.dumps(sasl_err)
+            resp = resp.encode("utf-8")
+
+            pq3.send(
+                conn, pq3.types.AuthnRequest, type=pq3.authn.SASLContinue, body=resp
+            )
+
+            # Per RFC, the client is required to send a dummy ^A response.
+            pkt = pq3.recv1(conn)
+            assert pkt.type == pq3.types.PasswordMessage
+            assert pkt.payload == b"\x01"
+
+            # Now fail the SASL exchange (in either a valid way, or an invalid
+            # one, depending on the test).
+            pq3.send(conn, resp_type, **resp_payload)
+
+    with pytest.raises(psycopg2.OperationalError, match=expected_error):
+        client.check_completed()
diff --git a/src/test/python/pq3.py b/src/test/python/pq3.py
new file mode 100644
index 0000000000..3a22dad0b6
--- /dev/null
+++ b/src/test/python/pq3.py
@@ -0,0 +1,727 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import os
+import ssl
+import sys
+import textwrap
+
+from construct import *
+
+import tls
+
+
+def protocol(major, minor):
+    """
+    Returns the protocol version, in integer format, corresponding to the given
+    major and minor version numbers.
+    """
+    return (major << 16) | minor
+
+
+# Startup
+
+StringList = GreedyRange(NullTerminated(GreedyBytes))
+
+
+class KeyValueAdapter(Adapter):
+    """
+    Turns a key-value store into a null-terminated list of null-terminated
+    strings, as presented on the wire in the startup packet.
+    """
+
+    def _encode(self, obj, context, path):
+        if isinstance(obj, list):
+            return obj
+
+        l = []
+
+        for k, v in obj.items():
+            if isinstance(k, str):
+                k = k.encode("utf-8")
+            l.append(k)
+
+            if isinstance(v, str):
+                v = v.encode("utf-8")
+            l.append(v)
+
+        l.append(b"")
+        return l
+
+    def _decode(self, obj, context, path):
+        # TODO: turn a list back into a dict
+        return obj
+
+
+KeyValues = KeyValueAdapter(StringList)
+
+# Payload layout for the startup packet, switched on the protocol version.
+# Only v3.0 payloads are decoded as key/value lists; anything else is left as
+# raw bytes.
+_startup_payload = Switch(
+    this.proto,
+    {
+        protocol(3, 0): KeyValues,
+    },
+    default=GreedyBytes,
+)
+
+
+def _default_protocol(this):
+    """
+    Build-time default for the Startup proto field: a structured (list/dict)
+    payload implies protocol 3.0; anything else defaults to 0.
+    """
+    try:
+        if isinstance(this.payload, (list, dict)):
+            return protocol(3, 0)
+    except AttributeError:
+        pass  # no payload passed during build
+
+    return 0
+
+
+def _startup_payload_len(this):
+    """
+    The payload field has a fixed size based on the length of the packet. But
+    if the caller hasn't supplied an explicit length at build time, we have to
+    build the payload to figure out how long it is, which requires us to know
+    the length first... This function exists solely to break the cycle.
+    """
+    assert this._building, "_startup_payload_len() cannot be called during parsing"
+
+    try:
+        payload = this.payload
+    except AttributeError:
+        return 0  # no payload
+
+    if isinstance(payload, bytes):
+        # already serialized; just use the given length
+        return len(payload)
+
+    # The payload layout depends on the protocol version; mirror the default
+    # that will be used when the caller didn't supply one.
+    try:
+        proto = this.proto
+    except AttributeError:
+        proto = _default_protocol(this)
+
+    data = _startup_payload.build(payload, proto=proto)
+    return len(data)
+
+
+# Startup packet: a signed length (which counts the len and proto fields
+# themselves, hence the +8/-8), the protocol version, and the payload.
+Startup = Struct(
+    "len" / Default(Int32sb, lambda this: _startup_payload_len(this) + 8),
+    "proto" / Default(Hex(Int32sb), _default_protocol),
+    "payload" / FixedSized(this.len - 8, Default(_startup_payload, b"")),
+)
+
+# Pq3
+
+# Adapted from construct.core.EnumIntegerString
+class EnumNamedByte:
+    def __init__(self, val, name):
+        self._val = val
+        self._name = name
+
+    def __int__(self):
+        return ord(self._val)
+
+    def __str__(self):
+        return "(enum) %s %r" % (self._name, self._val)
+
+    def __repr__(self):
+        return "EnumNamedByte(%r)" % self._val
+
+    def __eq__(self, other):
+        if isinstance(other, EnumNamedByte):
+            other = other._val
+        if not isinstance(other, bytes):
+            return NotImplemented
+
+        return self._val == other
+
+    def __hash__(self):
+        return hash(self._val)
+
+
+# Adapted from construct.core.Enum
+class ByteEnum(Adapter):
+    def __init__(self, **mapping):
+        super(ByteEnum, self).__init__(Byte)
+        self.namemapping = {k: EnumNamedByte(v, k) for k, v in mapping.items()}
+        self.decmapping = {v: EnumNamedByte(v, k) for k, v in mapping.items()}
+
+    def __getattr__(self, name):
+        if name in self.namemapping:
+            return self.decmapping[self.namemapping[name]]
+        raise AttributeError
+
+    def _decode(self, obj, context, path):
+        b = bytes([obj])
+        try:
+            return self.decmapping[b]
+        except KeyError:
+            return EnumNamedByte(b, "(unknown)")
+
+    def _encode(self, obj, context, path):
+        if isinstance(obj, int):
+            return obj
+        elif isinstance(obj, bytes):
+            return ord(obj)
+        return int(obj)
+
+
+# Frontend and backend message-type bytes for the v3 protocol.
+types = ByteEnum(
+    ErrorResponse=b"E",
+    ReadyForQuery=b"Z",
+    Query=b"Q",
+    EmptyQueryResponse=b"I",
+    AuthnRequest=b"R",
+    PasswordMessage=b"p",
+    BackendKeyData=b"K",
+    CommandComplete=b"C",
+    ParameterStatus=b"S",
+    DataRow=b"D",
+    Terminate=b"X",
+)
+
+
+# Authentication request subtypes carried in an AuthnRequest ('R') message.
+authn = Enum(
+    Int32ub,
+    OK=0,
+    SASL=10,
+    SASLContinue=11,
+    SASLFinal=12,
+)
+
+
+# The AuthnRequest body layout depends on the subtype: OK carries nothing,
+# SASL carries a mechanism-name list, and anything else is raw bytes.
+_authn_body = Switch(
+    this.type,
+    {
+        authn.OK: Terminated,
+        authn.SASL: StringList,
+    },
+    default=GreedyBytes,
+)
+
+
+def _data_len(this):
+    assert this._building, "_data_len() cannot be called during parsing"
+
+    if not hasattr(this, "data") or this.data is None:
+        return -1
+
+    return len(this.data)
+
+
+# The protocol reuses the PasswordMessage for several authentication response
+# types, and there's no good way to figure out which is which without keeping
+# state for the entire stream. So this is a separate Construct that can be
+# explicitly parsed/built by code that knows it's needed.
+SASLInitialResponse = Struct(
+    "name" / NullTerminated(GreedyBytes),
+    "len" / Default(Int32sb, lambda this: _data_len(this)),
+    "data"
+    / IfThenElse(
+        # Allow tests to explicitly pass an incorrect length during testing, by
+        # not enforcing a FixedSized during build. (The len calculation above
+        # defaults to the correct size.)
+        this._building,
+        Optional(GreedyBytes),
+        If(this.len != -1, Default(FixedSized(this.len, GreedyBytes), b"")),
+    ),
+    Terminated,  # make sure the entire response is consumed
+)
+
+
+# A single DataRow column: a signed length followed by that many bytes (-1
+# means the data field is absent entirely).
+_column = FocusedSeq(
+    "data",
+    "len" / Default(Int32sb, lambda this: _data_len(this)),
+    "data" / If(this.len != -1, FixedSized(this.len, GreedyBytes)),
+)
+
+
+# Maps each message-type byte to the Construct describing its payload.
+_payload_map = {
+    types.ErrorResponse: Struct("fields" / StringList),
+    types.ReadyForQuery: Struct("status" / Bytes(1)),
+    types.Query: Struct("query" / NullTerminated(GreedyBytes)),
+    types.EmptyQueryResponse: Terminated,
+    types.AuthnRequest: Struct("type" / authn, "body" / Default(_authn_body, b"")),
+    types.BackendKeyData: Struct("pid" / Int32ub, "key" / Hex(Int32ub)),
+    types.CommandComplete: Struct("tag" / NullTerminated(GreedyBytes)),
+    types.ParameterStatus: Struct(
+        "name" / NullTerminated(GreedyBytes), "value" / NullTerminated(GreedyBytes)
+    ),
+    types.DataRow: Struct("columns" / Default(PrefixedArray(Int16sb, _column), b"")),
+    types.Terminate: Terminated,
+}
+
+
+# Dispatches payload parsing/building on the enclosing packet's type byte;
+# unknown types fall back to raw bytes.
+_payload = FocusedSeq(
+    "_payload",
+    "_payload"
+    / Switch(
+        this._.type,
+        _payload_map,
+        default=GreedyBytes,
+    ),
+    Terminated,  # make sure every payload consumes the entire packet
+)
+
+
+def _payload_len(this):
+    """
+    Build-time length of a Pq3 packet's payload.
+    See _startup_payload_len() for an explanation.
+    """
+    assert this._building, "_payload_len() cannot be called during parsing"
+
+    try:
+        payload = this.payload
+    except AttributeError:
+        return 0  # no payload
+
+    if isinstance(payload, bytes):
+        # already serialized; just use the given length
+        return len(payload)
+
+    data = _payload.build(payload, type=this.type)
+    return len(data)
+
+
+# A single v3 protocol message: one type byte, a four-byte length (which
+# counts itself, hence the +4/-4), and a type-dependent payload.
+Pq3 = Struct(
+    "type" / types,
+    "len" / Default(Int32ub, lambda this: _payload_len(this) + 4),
+    "payload" / FixedSized(this.len - 4, Default(_payload, b"")),
+)
+
+
+# Environment
+
+
+def pghost():
+    return os.environ.get("PGHOST", default="localhost")
+
+
+def pgport():
+    return int(os.environ.get("PGPORT", default=5432))
+
+
+def pguser():
+    try:
+        return os.environ["PGUSER"]
+    except KeyError:
+        return getpass.getuser()
+
+
+def pgdatabase():
+    return os.environ.get("PGDATABASE", default="postgres")
+
+
+# Connections
+
+
+def _hexdump_translation_map():
+    """
+    For hexdumps. Translates any unprintable or non-ASCII bytes into '.'.
+    """
+    input = bytearray()
+
+    for i in range(128):
+        c = chr(i)
+
+        if not c.isprintable():
+            input += bytes([i])
+
+    input += bytes(range(128, 256))
+
+    return bytes.maketrans(input, b"." * len(input))
+
+
+class _DebugStream(object):
+    """
+    Wraps a file-like object and adds hexdumps of the read and write data. Call
+    end_packet() on a _DebugStream to write the accumulated hexdumps to the
+    output stream, along with the packet that was sent.
+    """
+
+    _translation_map = _hexdump_translation_map()
+
+    def __init__(self, stream, out=sys.stdout):
+        """
+        Creates a new _DebugStream wrapping the given stream (which must have
+        been created by wrap()). All attributes not provided by the _DebugStream
+        are delegated to the wrapped stream. out is the text stream to which
+        hexdumps are written.
+        """
+        self.raw = stream
+        self._out = out
+        self._rbuf = io.BytesIO()
+        self._wbuf = io.BytesIO()
+
+    def __getattr__(self, name):
+        return getattr(self.raw, name)
+
+    def __setattr__(self, name, value):
+        if name in ("raw", "_out", "_rbuf", "_wbuf"):
+            return object.__setattr__(self, name, value)
+
+        setattr(self.raw, name, value)
+
+    def read(self, *args, **kwargs):
+        buf = self.raw.read(*args, **kwargs)
+
+        self._rbuf.write(buf)
+        return buf
+
+    def write(self, b):
+        self._wbuf.write(b)
+        return self.raw.write(b)
+
+    def recv(self, *args):
+        buf = self.raw.recv(*args)
+
+        self._rbuf.write(buf)
+        return buf
+
+    def _flush(self, buf, prefix):
+        width = 16
+        hexwidth = width * 3 - 1
+
+        count = 0
+        buf.seek(0)
+
+        while True:
+            line = buf.read(16)
+
+            if not line:
+                if count:
+                    self._out.write("\n")  # separate the output block with a newline
+                return
+
+            self._out.write("%s %04X:\t" % (prefix, count))
+            self._out.write("%*s\t" % (-hexwidth, line.hex(" ")))
+            self._out.write(line.translate(self._translation_map).decode("ascii"))
+            self._out.write("\n")
+
+            count += 16
+
+    def print_debug(self, obj, *, prefix=""):
+        contents = ""
+        if obj is not None:
+            contents = str(obj)
+
+        for line in contents.splitlines():
+            self._out.write("%s%s\n" % (prefix, line))
+
+        self._out.write("\n")
+
+    def flush_debug(self, *, prefix=""):
+        self._flush(self._rbuf, prefix + "<")
+        self._rbuf = io.BytesIO()
+
+        self._flush(self._wbuf, prefix + ">")
+        self._wbuf = io.BytesIO()
+
+    def end_packet(self, pkt, *, read=False, prefix="", indent="  "):
+        """
+        Marks the end of a logical "packet" of data. A string representation of
+        pkt will be printed, and the debug buffers will be flushed with an
+        indent. All lines can be optionally prefixed.
+
+        If read is True, the packet representation is written after the debug
+        buffers; otherwise the default of False (meaning write) causes the
+        packet representation to be dumped first. This is meant to capture the
+        logical flow of layer translation.
+        """
+        write = not read
+
+        if write:
+            self.print_debug(pkt, prefix=prefix + "> ")
+
+        self.flush_debug(prefix=prefix + indent)
+
+        if read:
+            self.print_debug(pkt, prefix=prefix + "< ")
+
+
+@contextlib.contextmanager
+def wrap(socket, *, debug_stream=None):
+    """
+    Transforms a raw socket into a connection that can be used for Construct
+    building and parsing. The return value is a context manager and can be used
+    in a with statement.
+    """
+    # It is critical that buffering be disabled here, so that we can still
+    # manipulate the raw socket without desyncing the stream.
+    with socket.makefile("rwb", buffering=0) as sfile:
+        # Expose the original socket's recv() on the SocketIO object we return.
+        def recv(self, *args):
+            return socket.recv(*args)
+
+        # Bind recv as a method of the file object via the descriptor
+        # protocol, so callers can invoke sfile.recv(...).
+        sfile.recv = recv.__get__(sfile)
+
+        conn = sfile
+        if debug_stream:
+            conn = _DebugStream(conn, debug_stream)
+
+        try:
+            yield conn
+        finally:
+            # Dump any captured bytes that were never attributed to a packet,
+            # so they aren't silently dropped from the debug output.
+            if debug_stream:
+                conn.flush_debug(prefix="? ")
+
+
+def _send(stream, cls, obj):
+    debugging = hasattr(stream, "flush_debug")
+    out = io.BytesIO()
+
+    # Ideally we would build directly to the passed stream, but because we need
+    # to reparse the generated output for the debugging case, build to an
+    # intermediate BytesIO and send it instead.
+    cls.build_stream(obj, out)
+    buf = out.getvalue()
+
+    stream.write(buf)
+    if debugging:
+        pkt = cls.parse(buf)
+        stream.end_packet(pkt)
+
+    stream.flush()
+
+
+def send(stream, packet_type, payload_data=None, **payloadkw):
+    """
+    Sends a packet on the given pq3 connection. type is the pq3.types member
+    that should be assigned to the packet. If payload_data is given, it will be
+    used as the packet payload; otherwise the key/value pairs in payloadkw will
+    be the payload contents.
+    """
+    data = payloadkw
+
+    if payload_data is not None:
+        if payloadkw:
+            raise ValueError(
+                "payload_data and payload keywords may not be used simultaneously"
+            )
+
+        data = payload_data
+
+    _send(stream, Pq3, dict(type=packet_type, payload=data))
+
+
+def send_startup(stream, proto=None, **kwargs):
+    """
+    Sends a startup packet on the given pq3 connection. In most cases you should
+    use the handshake functions instead, which will do this for you.
+
+    By default, a protocol version 3 packet will be sent. This can be overridden
+    with the proto parameter.
+    """
+    pkt = {}
+
+    if proto is not None:
+        pkt["proto"] = proto
+    if kwargs:
+        pkt["payload"] = kwargs
+
+    _send(stream, Startup, pkt)
+
+
+def recv1(stream, *, cls=Pq3):
+    """
+    Receives a single pq3 packet from the given stream and returns it.
+    """
+    resp = cls.parse_stream(stream)
+
+    debugging = hasattr(stream, "flush_debug")
+    if debugging:
+        stream.end_packet(resp, read=True)
+
+    return resp
+
+
+def handshake(stream, **kwargs):
+    """
+    Performs a libpq v3 startup handshake. kwargs should contain the key/value
+    parameters to send to the server in the startup packet.
+    """
+    # Send our startup parameters.
+    send_startup(stream, **kwargs)
+
+    # Receive and dump packets until the server indicates it's ready for our
+    # first query.
+    while True:
+        resp = recv1(stream)
+        if resp is None:
+            raise RuntimeError("server closed connection during handshake")
+
+        if resp.type == types.ReadyForQuery:
+            return
+        elif resp.type == types.ErrorResponse:
+            raise RuntimeError(
+                f"received error response from peer: {resp.payload.fields!r}"
+            )
+
+
+# TLS
+
+
+class _TLSStream(object):
+    """
+    A file-like object that performs TLS encryption/decryption on a wrapped
+    stream. Differs from ssl.SSLSocket in that we have full visibility and
+    control over the TLS layer.
+    """
+
+    def __init__(self, stream, context):
+        self._stream = stream
+        self._debugging = hasattr(stream, "flush_debug")
+
+        self._in = ssl.MemoryBIO()
+        self._out = ssl.MemoryBIO()
+        self._ssl = context.wrap_bio(self._in, self._out)
+
+    def handshake(self):
+        try:
+            self._pump(lambda: self._ssl.do_handshake())
+        finally:
+            self._flush_debug(prefix="? ")
+
+    def read(self, *args):
+        return self._pump(lambda: self._ssl.read(*args))
+
+    def write(self, *args):
+        return self._pump(lambda: self._ssl.write(*args))
+
+    def _decode(self, buf):
+        """
+        Attempts to decode a buffer of TLS data into a packet representation
+        that can be printed.
+
+        TODO: handle buffers (and record fragments) that don't align with packet
+        boundaries.
+        """
+        end = len(buf)
+        bio = io.BytesIO(buf)
+
+        ret = io.StringIO()
+
+        while bio.tell() < end:
+            record = tls.Plaintext.parse_stream(bio)
+
+            if ret.tell() > 0:
+                ret.write("\n")
+            ret.write("[Record] ")
+            ret.write(str(record))
+            ret.write("\n")
+
+            if record.type == tls.ContentType.handshake:
+                record_cls = tls.Handshake
+            else:
+                continue
+
+            innerlen = len(record.fragment)
+            inner = io.BytesIO(record.fragment)
+
+            while inner.tell() < innerlen:
+                msg = record_cls.parse_stream(inner)
+
+                indented = "[Message] " + str(msg)
+                indented = textwrap.indent(indented, "    ")
+
+                ret.write("\n")
+                ret.write(indented)
+                ret.write("\n")
+
+        return ret.getvalue()
+
+    def flush(self):
+        if not self._out.pending:
+            self._stream.flush()
+            return
+
+        buf = self._out.read()
+        self._stream.write(buf)
+
+        if self._debugging:
+            pkt = self._decode(buf)
+            self._stream.end_packet(pkt, prefix="  ")
+
+        self._stream.flush()
+
+    def _pump(self, operation):
+        while True:
+            try:
+                return operation()
+            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e:
+                want = e
+            self._read_write(want)
+
+    def _recv(self, maxsize):
+        buf = self._stream.recv(4096)
+        if not buf:
+            self._in.write_eof()
+            return
+
+        self._in.write(buf)
+
+        if not self._debugging:
+            return
+
+        pkt = self._decode(buf)
+        self._stream.end_packet(pkt, read=True, prefix="  ")
+
+    def _read_write(self, want):
+        # XXX This needs work. So many corner cases yet to handle. For one,
+        # doing blocking writes in flush may lead to distributed deadlock if the
+        # peer is already blocking on its writes.
+
+        if isinstance(want, ssl.SSLWantWriteError):
+            assert self._out.pending, "SSL backend wants write without data"
+
+        self.flush()
+
+        if isinstance(want, ssl.SSLWantReadError):
+            self._recv(4096)
+
+    def _flush_debug(self, prefix):
+        if not self._debugging:
+            return
+
+        self._stream.flush_debug(prefix=prefix)
+
+
+@contextlib.contextmanager
+def tls_handshake(stream, context):
+    """
+    Performs a TLS handshake over the given stream (which must have been created
+    via a call to wrap()), and returns a new stream which transparently tunnels
+    data over the TLS connection.
+
+    If the passed stream has debugging enabled, the returned stream will also
+    have debugging, using the same output IO.
+    """
+    debugging = hasattr(stream, "flush_debug")
+
+    # Send our startup parameters.
+    send_startup(stream, proto=protocol(1234, 5679))
+
+    # Look at the SSL response.
+    resp = stream.read(1)
+    if debugging:
+        stream.flush_debug(prefix="  ")
+
+    if resp == b"N":
+        raise RuntimeError("server does not support SSLRequest")
+    if resp != b"S":
+        raise RuntimeError(f"unexpected response of type {resp!r} during TLS startup")
+
+    tls = _TLSStream(stream, context)
+    tls.handshake()
+
+    if debugging:
+        tls = _DebugStream(tls, stream._out)
+
+    try:
+        yield tls
+        # TODO: teardown/unwrap the connection?
+    finally:
+        if debugging:
+            tls.flush_debug(prefix="? ")
diff --git a/src/test/python/pytest.ini b/src/test/python/pytest.ini
new file mode 100644
index 0000000000..ab7a6e7fb9
--- /dev/null
+++ b/src/test/python/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+
+markers =
+    slow: mark test as slow
diff --git a/src/test/python/requirements.txt b/src/test/python/requirements.txt
new file mode 100644
index 0000000000..32f105ea84
--- /dev/null
+++ b/src/test/python/requirements.txt
@@ -0,0 +1,7 @@
+black
+cryptography~=3.4.6
+construct~=2.10.61
+isort~=5.6
+psycopg2~=2.8.6
+pytest~=6.1
+pytest-asyncio~=0.14.0
diff --git a/src/test/python/server/__init__.py b/src/test/python/server/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/python/server/conftest.py b/src/test/python/server/conftest.py
new file mode 100644
index 0000000000..ba7342a453
--- /dev/null
+++ b/src/test/python/server/conftest.py
@@ -0,0 +1,45 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import socket
+import sys
+
+import pytest
+
+import pq3
+
+
+@pytest.fixture
+def connect():
+    """
+    A factory fixture that, when called, returns a socket connected to a
+    Postgres server, wrapped in a pq3 connection. The calling test will be
+    skipped automatically if a server is not running at PGHOST:PGPORT, so it's
+    best to connect as soon as possible after the test case begins, to avoid
+    doing unnecessary work.
+    """
+    # Set up an ExitStack to handle safe cleanup of all of the moving pieces.
+    with contextlib.ExitStack() as stack:
+
+        def conn_factory():
+            addr = (pq3.pghost(), pq3.pgport())
+
+            try:
+                sock = socket.create_connection(addr, timeout=2)
+            except ConnectionError as e:
+                pytest.skip(f"unable to connect to {addr}: {e}")
+
+            # Have ExitStack close our socket.
+            stack.enter_context(sock)
+
+            # Wrap the connection in a pq3 layer and have ExitStack clean it up
+            # too.
+            wrap_ctx = pq3.wrap(sock, debug_stream=sys.stdout)
+            conn = stack.enter_context(wrap_ctx)
+
+            return conn
+
+        yield conn_factory
diff --git a/src/test/python/server/test_oauth.py b/src/test/python/server/test_oauth.py
new file mode 100644
index 0000000000..cb5ca7fa23
--- /dev/null
+++ b/src/test/python/server/test_oauth.py
@@ -0,0 +1,1012 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import base64
+import contextlib
+import json
+import os
+import pathlib
+import secrets
+import shlex
+import shutil
+import socket
+import struct
+from multiprocessing import shared_memory
+
+import psycopg2
+import pytest
+from psycopg2 import sql
+
+import pq3
+
+MAX_SASL_MESSAGE_LENGTH = 65535
+
+INVALID_AUTHORIZATION_ERRCODE = b"28000"
+PROTOCOL_VIOLATION_ERRCODE = b"08P01"
+FEATURE_NOT_SUPPORTED_ERRCODE = b"0A000"
+
+SHARED_MEM_NAME = "oauth-pytest"
+MAX_TOKEN_SIZE = 4096
+MAX_UINT16 = 2 ** 16 - 1
+
+
+def skip_if_no_postgres():
+    """
+    Used by the oauth_ctx fixture to skip this test module if no Postgres server
+    is running.
+
+    This logic is nearly duplicated with the conn fixture. Ideally oauth_ctx
+    would depend on that, but a module-scope fixture can't depend on a
+    test-scope fixture, and we haven't reached the rule of three yet.
+    """
+    addr = (pq3.pghost(), pq3.pgport())
+
+    try:
+        with socket.create_connection(addr, timeout=2):
+            pass
+    except ConnectionError as e:
+        pytest.skip(f"unable to connect to {addr}: {e}")
+
+
+@contextlib.contextmanager
+def prepend_file(path, lines):
+    """
+    A context manager that prepends a file on disk with the desired lines of
+    text. When the context manager is exited, the file will be restored to its
+    original contents.
+    """
+    # First make a backup of the original file.
+    # NOTE(review): assumes `path + ".bak"` is free and no other process is
+    # writing the file concurrently — acceptable for scratch-cluster tests,
+    # but confirm before reusing elsewhere.
+    bak = path + ".bak"
+    shutil.copy2(path, bak)
+
+    try:
+        # Write the new lines, followed by the original file content.
+        with open(path, "w") as new, open(bak, "r") as orig:
+            new.writelines(lines)
+            shutil.copyfileobj(orig, new)
+
+        # Return control to the calling code.
+        yield
+
+    finally:
+        # Put the backup back into place. os.replace overwrites the modified
+        # file in a single rename, so readers never see a missing config file.
+        os.replace(bak, path)
+
+
+@pytest.fixture(scope="module")
+def oauth_ctx():
+    """
+    Creates a database and user that use the oauth auth method. The context
+    object contains the dbname and user attributes as strings to be used during
+    connection, as well as the issuer and scope that have been set in the HBA
+    configuration.
+
+    This fixture assumes that the standard PG* environment variables point to a
+    server running on a local machine, and that the PGUSER has rights to create
+    databases and roles.
+    """
+    skip_if_no_postgres()  # don't bother running these tests without a server
+
+    id = secrets.token_hex(4)
+
+    class Context:
+        dbname = "oauth_test_" + id
+
+        user = "oauth_user_" + id
+        map_user = "oauth_map_user_" + id
+        authz_user = "oauth_authz_user_" + id
+
+        issuer = "https://example.com/" + id
+        scope = "openid " + id
+
+    ctx = Context()
+    hba_lines = (
+        f'host {ctx.dbname} {ctx.map_user}   samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" map=oauth\n',
+        f'host {ctx.dbname} {ctx.authz_user} samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}" trust_validator_authz=1\n',
+        f'host {ctx.dbname} all              samehost oauth issuer="{ctx.issuer}" scope="{ctx.scope}"\n',
+    )
+    ident_lines = (r"oauth /^(.*)@example\.com$ \1",)
+
+    conn = psycopg2.connect("")
+    conn.autocommit = True
+
+    with contextlib.closing(conn):
+        c = conn.cursor()
+
+        # Create our roles and database.
+        user = sql.Identifier(ctx.user)
+        map_user = sql.Identifier(ctx.map_user)
+        authz_user = sql.Identifier(ctx.authz_user)
+        dbname = sql.Identifier(ctx.dbname)
+
+        c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(user))
+        c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(map_user))
+        c.execute(sql.SQL("CREATE ROLE {} LOGIN;").format(authz_user))
+        c.execute(sql.SQL("CREATE DATABASE {};").format(dbname))
+
+        # Make this test script the server's oauth_validator.
+        path = pathlib.Path(__file__).parent / "validate_bearer.py"
+        path = str(path.absolute())
+
+        cmd = f"{shlex.quote(path)} {SHARED_MEM_NAME} <&%f"
+        c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
+
+        # Replace pg_hba and pg_ident.
+        c.execute("SHOW hba_file;")
+        hba = c.fetchone()[0]
+
+        c.execute("SHOW ident_file;")
+        ident = c.fetchone()[0]
+
+        with prepend_file(hba, hba_lines), prepend_file(ident, ident_lines):
+            c.execute("SELECT pg_reload_conf();")
+
+            # Use the new database and user.
+            yield ctx
+
+        # Put things back the way they were.
+        c.execute("SELECT pg_reload_conf();")
+
+        c.execute("ALTER SYSTEM RESET oauth_validator_command;")
+        c.execute(sql.SQL("DROP DATABASE {};").format(dbname))
+        c.execute(sql.SQL("DROP ROLE {};").format(authz_user))
+        c.execute(sql.SQL("DROP ROLE {};").format(map_user))
+        c.execute(sql.SQL("DROP ROLE {};").format(user))
+
+
+@pytest.fixture()
+def conn(oauth_ctx, connect):
+    """
+    A convenience wrapper for connect(). The main purpose of this fixture is to
+    make sure oauth_ctx runs its setup code before the connection is made.
+    """
+    return connect()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def authn_id_extension(oauth_ctx):
+    """
+    Performs a `CREATE EXTENSION authn_id` in the test database. This fixture is
+    autoused, so tests don't need to rely on it.
+    """
+    conn = psycopg2.connect(database=oauth_ctx.dbname)
+    conn.autocommit = True
+
+    with contextlib.closing(conn):
+        c = conn.cursor()
+        c.execute("CREATE EXTENSION authn_id;")
+
+
+@pytest.fixture(scope="session")
+def shared_mem():
+    """
+    Yields a shared memory segment that can be used for communication between
+    the bearer_token fixture and ./validate_bearer.py.
+    """
+    size = MAX_TOKEN_SIZE + 2  # two byte length prefix
+    mem = shared_memory.SharedMemory(SHARED_MEM_NAME, create=True, size=size)
+
+    try:
+        with contextlib.closing(mem):
+            yield mem
+    finally:
+        mem.unlink()
+
+
+@pytest.fixture()
+def bearer_token(shared_mem):
+    """
+    Returns a factory function that, when called, will store a Bearer token in
+    shared_mem. If token is None (the default), a new token will be generated
+    using secrets.token_urlsafe() and returned; otherwise the passed token will
+    be used as-is.
+
+    When token is None, the generated token size in bytes may be specified as an
+    argument; if unset, a small 16-byte token will be generated. The token size
+    may not exceed MAX_TOKEN_SIZE in any case.
+
+    The return value is the token, converted to a bytes object.
+
+    As a special case for testing failure modes, accept_any may be set to True.
+    This signals to the validator command that any bearer token should be
+    accepted. The returned token in this case may be used or discarded as needed
+    by the test.
+    """
+
+    def set_token(token=None, *, size=16, accept_any=False):
+        # An explicit token overrides any requested size.
+        if token is not None:
+            size = len(token)
+
+        if size > MAX_TOKEN_SIZE:
+            raise ValueError(f"token size {size} exceeds maximum size {MAX_TOKEN_SIZE}")
+
+        if token is None:
+            # token_urlsafe(n) produces ceil(4n/3) base64url characters, so a
+            # multiple-of-4 size maps exactly onto n = size * 3/4 input bytes.
+            if size % 4:
+                raise ValueError(f"requested token size {size} is not a multiple of 4")
+
+            token = secrets.token_urlsafe(size // 4 * 3)
+            assert len(token) == size
+
+        try:
+            token = token.encode("ascii")
+        except AttributeError:
+            pass  # already encoded
+
+        # NOTE(review): struct.pack("H", ...) uses native byte order; this is
+        # only safe because validate_bearer.py unpacks on the same machine —
+        # confirm it uses the same format if either side changes.
+        if accept_any:
+            # Two-byte magic value.
+            shared_mem.buf[:2] = struct.pack("H", MAX_UINT16)
+        else:
+            # Two-byte length prefix, then the token data.
+            shared_mem.buf[:2] = struct.pack("H", len(token))
+            shared_mem.buf[2 : size + 2] = token
+
+        return token
+
+    return set_token
+
+
+def begin_oauth_handshake(conn, oauth_ctx, *, user=None):
+    if user is None:
+        user = oauth_ctx.authz_user
+
+    pq3.send_startup(conn, user=user, database=oauth_ctx.dbname)
+
+    resp = pq3.recv1(conn)
+    assert resp.type == pq3.types.AuthnRequest
+
+    # The server should advertise exactly one mechanism.
+    assert resp.payload.type == pq3.authn.SASL
+    assert resp.payload.body == [b"OAUTHBEARER", b""]
+
+
+def send_initial_response(conn, *, auth=None, bearer=None):
+    """
+    Sends the OAUTHBEARER initial response on the connection, using the given
+    bearer token. Alternatively to a bearer token, the initial response's auth
+    field may be explicitly specified to test corner cases.
+    """
+    if bearer is not None and auth is not None:
+        raise ValueError("exactly one of the auth and bearer kwargs must be set")
+
+    if bearer is not None:
+        auth = b"Bearer " + bearer
+
+    if auth is None:
+        raise ValueError("exactly one of the auth and bearer kwargs must be set")
+
+    initial = pq3.SASLInitialResponse.build(
+        dict(
+            name=b"OAUTHBEARER",
+            data=b"n,,\x01auth=" + auth + b"\x01\x01",
+        )
+    )
+    pq3.send(conn, pq3.types.PasswordMessage, initial)
+
+
+def expect_handshake_success(conn):
+    """
+    Validates that the server responds with an AuthnOK message, and then drains
+    the connection until a ReadyForQuery message is received.
+    """
+    resp = pq3.recv1(conn)
+
+    assert resp.type == pq3.types.AuthnRequest
+    assert resp.payload.type == pq3.authn.OK
+    assert not resp.payload.body
+
+    receive_until(conn, pq3.types.ReadyForQuery)
+
+
+def expect_handshake_failure(conn, oauth_ctx):
+    """
+    Performs the OAUTHBEARER SASL failure "handshake" and validates the server's
+    side of the conversation, including the final ErrorResponse.
+    """
+
+    # We expect a discovery "challenge" back from the server before the authn
+    # failure message.
+    resp = pq3.recv1(conn)
+    assert resp.type == pq3.types.AuthnRequest
+
+    req = resp.payload
+    assert req.type == pq3.authn.SASLContinue
+
+    body = json.loads(req.body)
+    assert body["status"] == "invalid_token"
+    assert body["scope"] == oauth_ctx.scope
+
+    expected_config = oauth_ctx.issuer + "/.well-known/openid-configuration"
+    assert body["openid-configuration"] == expected_config
+
+    # Send the dummy response to complete the failed handshake.
+    pq3.send(conn, pq3.types.PasswordMessage, b"\x01")
+    resp = pq3.recv1(conn)
+
+    err = ExpectedError(INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed")
+    err.match(resp)
+
+
+def receive_until(conn, type):
+    """
+    receive_until pulls packets off the pq3 connection until a packet with the
+    desired type is found, or an error response is received.
+    """
+    while True:
+        pkt = pq3.recv1(conn)
+
+        if pkt.type == type:
+            return pkt
+        elif pkt.type == pq3.types.ErrorResponse:
+            raise RuntimeError(
+                f"received error response from peer: {pkt.payload.fields!r}"
+            )
+
+
+@pytest.mark.parametrize("token_len", [16, 1024, 4096])
+@pytest.mark.parametrize(
+    "auth_prefix",
+    [
+        b"Bearer ",
+        b"bearer ",
+        b"Bearer    ",
+    ],
+)
+def test_oauth(conn, oauth_ctx, bearer_token, auth_prefix, token_len):
+    begin_oauth_handshake(conn, oauth_ctx)
+
+    # Generate our bearer token with the desired length.
+    token = bearer_token(size=token_len)
+    auth = auth_prefix + token
+
+    send_initial_response(conn, auth=auth)
+    expect_handshake_success(conn)
+
+    # Make sure that the server has not set an authenticated ID.
+    pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+    resp = receive_until(conn, pq3.types.DataRow)
+
+    row = resp.payload
+    assert row.columns == [None]
+
+
+@pytest.mark.parametrize(
+    "token_value",
+    [
+        "abcdzA==",
+        "123456M=",
+        "x-._~+/x",
+    ],
+)
+def test_oauth_bearer_corner_cases(conn, oauth_ctx, bearer_token, token_value):
+    begin_oauth_handshake(conn, oauth_ctx)
+
+    send_initial_response(conn, bearer=bearer_token(token_value))
+
+    expect_handshake_success(conn)
+
+
+@pytest.mark.parametrize(
+    "user,authn_id,should_succeed",
+    [
+        pytest.param(
+            lambda ctx: ctx.user,
+            lambda ctx: ctx.user,
+            True,
+            id="validator authn: succeeds when authn_id == username",
+        ),
+        pytest.param(
+            lambda ctx: ctx.user,
+            lambda ctx: None,
+            False,
+            id="validator authn: fails when authn_id is not set",
+        ),
+        pytest.param(
+            lambda ctx: ctx.user,
+            lambda ctx: ctx.authz_user,
+            False,
+            id="validator authn: fails when authn_id != username",
+        ),
+        pytest.param(
+            lambda ctx: ctx.map_user,
+            lambda ctx: ctx.map_user + "@example.com",
+            True,
+            id="validator with map: succeeds when authn_id matches map",
+        ),
+        pytest.param(
+            lambda ctx: ctx.map_user,
+            lambda ctx: None,
+            False,
+            id="validator with map: fails when authn_id is not set",
+        ),
+        pytest.param(
+            lambda ctx: ctx.map_user,
+            lambda ctx: ctx.map_user + "@example.net",
+            False,
+            id="validator with map: fails when authn_id doesn't match map",
+        ),
+        pytest.param(
+            lambda ctx: ctx.authz_user,
+            lambda ctx: None,
+            True,
+            id="validator authz: succeeds with no authn_id",
+        ),
+        pytest.param(
+            lambda ctx: ctx.authz_user,
+            lambda ctx: "",
+            True,
+            id="validator authz: succeeds with empty authn_id",
+        ),
+        pytest.param(
+            lambda ctx: ctx.authz_user,
+            lambda ctx: "postgres",
+            True,
+            id="validator authz: succeeds with basic username",
+        ),
+        pytest.param(
+            lambda ctx: ctx.authz_user,
+            lambda ctx: "me@example.com",
+            True,
+            id="validator authz: succeeds with email address",
+        ),
+    ],
+)
+def test_oauth_authn_id(conn, oauth_ctx, bearer_token, user, authn_id, should_succeed):
+    token = None
+
+    authn_id = authn_id(oauth_ctx)
+    if authn_id is not None:
+        authn_id = authn_id.encode("ascii")
+
+        # As a hack to get the validator to reflect arbitrary output from this
+        # test, encode the desired output as a base64 token. The validator will
+        # key on the leading "output=" to differentiate this from the random
+        # tokens generated by secrets.token_urlsafe().
+        output = b"output=" + authn_id + b"\n"
+        token = base64.urlsafe_b64encode(output)
+
+    token = bearer_token(token)
+    username = user(oauth_ctx)
+
+    begin_oauth_handshake(conn, oauth_ctx, user=username)
+    send_initial_response(conn, bearer=token)
+
+    if not should_succeed:
+        expect_handshake_failure(conn, oauth_ctx)
+        return
+
+    expect_handshake_success(conn)
+
+    # Check the reported authn_id.
+    pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
+    resp = receive_until(conn, pq3.types.DataRow)
+
+    row = resp.payload
+    assert row.columns == [authn_id]
+
+
+class ExpectedError(object):
+    """
+    Describes an ErrorResponse the server is expected to send: an exact
+    SQLSTATE error code (as bytes), plus optional substrings that must appear
+    in the message (M) and detail (D) fields. Use match() to assert that a
+    received packet satisfies the expectation.
+    """
+
+    def __init__(self, code, msg=None, detail=None):
+        self.code = code  # SQLSTATE bytes, compared for equality in match()
+        self.msg = msg  # substring expected in the M field, or None to skip
+        self.detail = detail  # substring expected in the D field, or None
+
+        # Protect against the footgun of an accidental empty string, which will
+        # "match" anything. If you don't want to match message or detail, just
+        # don't pass them.
+        if self.msg == "":
+            raise ValueError("msg must be non-empty or None")
+        if self.detail == "":
+            raise ValueError("detail must be non-empty or None")
+
+    def _getfield(self, resp, type):
+        """
+        Searches an ErrorResponse for a single field of the given type (e.g.
+        "M", "C", "D") and returns its value. Asserts if it doesn't find exactly
+        one field.
+        """
+        prefix = type.encode("ascii")
+        fields = [f for f in resp.payload.fields if f.startswith(prefix)]
+
+        assert len(fields) == 1
+        return fields[0][1:]  # strip off the type byte
+
+    def match(self, resp):
+        """
+        Checks that the given response matches the expected code, message, and
+        detail (if given). The error code must match exactly. The expected
+        message and detail must be contained within the actual strings.
+        """
+        assert resp.type == pq3.types.ErrorResponse
+
+        code = self._getfield(resp, "C")
+        assert code == self.code
+
+        if self.msg:
+            msg = self._getfield(resp, "M")
+            expected = self.msg.encode("utf-8")
+            assert expected in msg
+
+        if self.detail:
+            detail = self._getfield(resp, "D")
+            expected = self.detail.encode("utf-8")
+            assert expected in detail
+
+
+def test_oauth_rejected_bearer(conn, oauth_ctx, bearer_token):
+    # Generate a new bearer token, which we will proceed not to use.
+    _ = bearer_token()
+
+    begin_oauth_handshake(conn, oauth_ctx)
+
+    # Send a bearer token that doesn't match what the validator expects. It
+    # should fail the connection.
+    send_initial_response(conn, bearer=b"xxxxxx")
+
+    expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.parametrize(
+    "bad_bearer",
+    [
+        b"Bearer    ",
+        b"Bearer a===b",
+        b"Bearer hello!",
+        b"Bearer me@example.com",
+        b'OAuth realm="Example"',
+        b"",
+    ],
+)
+def test_oauth_invalid_bearer(conn, oauth_ctx, bearer_token, bad_bearer):
+    # Tell the validator to accept any token. This ensures that the invalid
+    # bearer tokens are rejected before the validation step.
+    _ = bearer_token(accept_any=True)
+
+    begin_oauth_handshake(conn, oauth_ctx)
+    send_initial_response(conn, auth=bad_bearer)
+
+    expect_handshake_failure(conn, oauth_ctx)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "resp_type,resp,err",
+    [
+        pytest.param(
+            None,
+            None,
+            None,
+            marks=pytest.mark.slow,
+            id="no response (expect timeout)",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            b"hello",
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "did not send a kvsep response",
+            ),
+            id="bad dummy response",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            b"\x01\x01",
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "did not send a kvsep response",
+            ),
+            id="multiple kvseps",
+        ),
+        pytest.param(
+            pq3.types.Query,
+            dict(query=b""),
+            ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+            id="bad response message type",
+        ),
+    ],
+)
+def test_oauth_bad_response_to_error_challenge(conn, oauth_ctx, resp_type, resp, err):
+    begin_oauth_handshake(conn, oauth_ctx)
+
+    # Send an empty auth initial response, which will force an authn failure.
+    send_initial_response(conn, auth=b"")
+
+    # We expect a discovery "challenge" back from the server before the authn
+    # failure message.
+    pkt = pq3.recv1(conn)
+    assert pkt.type == pq3.types.AuthnRequest
+
+    req = pkt.payload
+    assert req.type == pq3.authn.SASLContinue
+
+    body = json.loads(req.body)
+    assert body["status"] == "invalid_token"
+
+    if resp_type is None:
+        # Do not send the dummy response. We should time out and not get a
+        # response from the server.
+        with pytest.raises(socket.timeout):
+            conn.read(1)
+
+        # Done with the test.
+        return
+
+    # Send the bad response.
+    pq3.send(conn, resp_type, resp)
+
+    # Make sure the server fails the connection correctly.
+    pkt = pq3.recv1(conn)
+    err.match(pkt)
+
+
+@pytest.mark.parametrize(
+    "type,payload,err",
+    [
+        pytest.param(
+            pq3.types.ErrorResponse,
+            dict(fields=[b""]),
+            ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "expected SASL response"),
+            id="error response in initial message",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            b"x" * (MAX_SASL_MESSAGE_LENGTH + 1),
+            ExpectedError(
+                INVALID_AUTHORIZATION_ERRCODE, "bearer authentication failed"
+            ),
+            id="overlong initial response data",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"SCRAM-SHA-256")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE, "invalid SASL authentication mechanism"
+            ),
+            id="bad SASL mechanism selection",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=2, data=b"x")),
+            ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "insufficient data"),
+            id="SASL data underflow",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", len=0, data=b"x")),
+            ExpectedError(PROTOCOL_VIOLATION_ERRCODE, "invalid message format"),
+            id="SASL data overflow",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "message is empty",
+            ),
+            id="empty",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(
+                dict(name=b"OAUTHBEARER", data=b"n,,\x01auth=\x01\x01\0")
+            ),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "length does not match input length",
+            ),
+            id="contains null byte",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"\x01")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "Unexpected channel-binding flag",  # XXX this is a bit strange
+            ),
+            id="initial error response",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(
+                dict(name=b"OAUTHBEARER", data=b"p=tls-server-end-point,,\x01")
+            ),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "server does not support channel binding",
+            ),
+            id="uses channel binding",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"x,,\x01")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "Unexpected channel-binding flag",
+            ),
+            id="invalid channel binding specifier",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "Comma expected",
+            ),
+            id="bad GS2 header: missing channel binding terminator",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,a")),
+            ExpectedError(
+                FEATURE_NOT_SUPPORTED_ERRCODE,
+                "client uses authorization identity",
+            ),
+            id="bad GS2 header: authzid in use",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,b,")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "Unexpected attribute",
+            ),
+            id="bad GS2 header: extra attribute",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "Unexpected attribute 0x00",  # XXX this is a bit strange
+            ),
+            id="bad GS2 header: missing authzid terminator",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "Key-value separator expected",
+            ),
+            id="missing initial kvsep",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "Key-value separator expected",
+            ),
+            id="missing initial kvsep",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(
+                dict(name=b"OAUTHBEARER", data=b"y,,\x01\x01")
+            ),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "does not contain an auth value",
+            ),
+            id="missing auth value: empty key-value list",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(
+                dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com\x01\x01")
+            ),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "does not contain an auth value",
+            ),
+            id="missing auth value: other keys present",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(
+                dict(name=b"OAUTHBEARER", data=b"y,,\x01host=example.com")
+            ),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "unterminated key/value pair",
+            ),
+            id="missing value terminator",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER", data=b"y,,\x01")),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "did not contain a final terminator",
+            ),
+            id="missing list terminator: empty list",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(
+                dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01")
+            ),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "did not contain a final terminator",
+            ),
+            id="missing list terminator: with auth value",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(
+                dict(name=b"OAUTHBEARER", data=b"y,,\x01auth=Bearer 0\x01\x01blah")
+            ),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "additional data after the final terminator",
+            ),
+            id="additional key after terminator",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(
+                dict(name=b"OAUTHBEARER", data=b"y,,\x01key\x01\x01")
+            ),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "key without a value",
+            ),
+            id="key without value",
+        ),
+        pytest.param(
+            pq3.types.PasswordMessage,
+            pq3.SASLInitialResponse.build(
+                dict(
+                    name=b"OAUTHBEARER",
+                    data=b"y,,\x01auth=Bearer 0\x01auth=Bearer 1\x01\x01",
+                )
+            ),
+            ExpectedError(
+                PROTOCOL_VIOLATION_ERRCODE,
+                "malformed OAUTHBEARER message",
+                "contains multiple auth values",
+            ),
+            id="multiple auth values",
+        ),
+    ],
+)
def test_oauth_bad_initial_response(conn, oauth_ctx, type, payload, err):
    """
    Sends a malformed initial SASL response (parametrized above) and checks
    that the server answers with the expected ErrorResponse.
    """
    begin_oauth_handshake(conn, oauth_ctx)

    # The server expects a SASL response; give it something else instead.
    # Non-dict payloads are wrapped so pq3.send() treats them as raw bytes.
    if not isinstance(payload, dict):
        payload = dict(payload_data=payload)
    pq3.send(conn, type, **payload)

    # The next packet from the server should be the parametrized error.
    resp = pq3.recv1(conn)
    err.match(resp)
+
+
def test_oauth_empty_initial_response(conn, oauth_ctx, bearer_token):
    """
    A client may omit the OAUTHBEARER initial response entirely; the server
    must answer with an empty SASL challenge so the client can send the data
    afterwards.
    """
    begin_oauth_handshake(conn, oauth_ctx)

    # An initial response carrying no data at all.
    pq3.send(
        conn,
        pq3.types.PasswordMessage,
        pq3.SASLInitialResponse.build(dict(name=b"OAUTHBEARER")),
    )

    # Expect an empty SASLContinue challenge in return.
    challenge = pq3.recv1(conn)
    assert challenge.type == pq3.types.AuthnRequest
    assert challenge.payload.type == pq3.authn.SASLContinue
    assert not challenge.payload.body

    # Now that the server has asked for it, send the client-first message.
    client_first = b"n,,\x01auth=Bearer " + bearer_token() + b"\x01\x01"
    pq3.send(conn, pq3.types.PasswordMessage, client_first)

    # The server should complete the handshake.
    expect_handshake_success(conn)
+
+
@pytest.fixture()
def set_validator():
    """
    A per-test fixture that allows a test to override the setting of
    oauth_validator_command for the cluster. The setting will be reverted during
    teardown.

    Passing None will perform an ALTER SYSTEM RESET.
    """
    conn = psycopg2.connect("")
    conn.autocommit = True

    with contextlib.closing(conn):
        c = conn.cursor()

        # Save the previous value so it can be restored on teardown.
        c.execute("SHOW oauth_validator_command;")
        prev_cmd = c.fetchone()[0]

        def setter(cmd):
            # None means "reset to default", matching the docstring above.
            # (Previously None was interpolated into ALTER SYSTEM SET, which
            # is not valid.)
            if cmd is None:
                c.execute("ALTER SYSTEM RESET oauth_validator_command;")
            else:
                c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (cmd,))
            c.execute("SELECT pg_reload_conf();")

        try:
            yield setter
        finally:
            # Restore the previous value even if the test body raised.
            c.execute("ALTER SYSTEM SET oauth_validator_command TO %s;", (prev_cmd,))
            c.execute("SELECT pg_reload_conf();")
+
+
def test_oauth_no_validator(oauth_ctx, set_validator, connect, bearer_token):
    """With no validator command configured, every login must be rejected."""
    # Empty out the validator command before establishing a new connection.
    set_validator("")

    conn = connect()
    begin_oauth_handshake(conn, oauth_ctx)
    send_initial_response(conn, bearer=bearer_token())

    # Authentication cannot succeed without a validator.
    expect_handshake_failure(conn, oauth_ctx)
+
+
def test_oauth_validator_role(oauth_ctx, set_validator, connect):
    """
    The reflection validator logs the user in as whatever role name it is
    handed; the authenticated identity must match PGUSER.
    """
    # Point the cluster at the reflection validator. It ignores the bearer
    # token entirely.
    script = pathlib.Path(__file__).parent / "validate_reflect.py"
    set_validator(f"{shlex.quote(str(script.absolute()))} '%r' <&%f")

    conn = connect()

    # Log in. Note that the reflection validator ignores the bearer token.
    begin_oauth_handshake(conn, oauth_ctx, user=oauth_ctx.user)
    send_initial_response(conn, bearer=b"dontcare")
    expect_handshake_success(conn)

    # The server-side authenticated identity should be the requested role.
    pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();")
    resp = receive_until(conn, pq3.types.DataRow)

    assert resp.payload.columns == [oauth_ctx.user.encode("utf-8")]
+
+
def test_oauth_role_with_shell_unsafe_characters(oauth_ctx, set_validator, connect):
    """
    XXX This test pins undesirable behavior. We should be able to handle any
    valid Postgres role name.
    """
    # Switch the validator implementation. This validator will reflect the
    # PGUSER as the authenticated identity.
    path = pathlib.Path(__file__).parent / "validate_reflect.py"
    path = str(path.absolute())

    # NOTE(review): %r is substituted into a shell command line, so a role
    # name containing a single quote breaks the quoting — hence the expected
    # failure below.
    set_validator(f"{shlex.quote(path)} '%r' <&%f")
    conn = connect()

    unsafe_username = "hello'there"
    begin_oauth_handshake(conn, oauth_ctx, user=unsafe_username)

    # The server should reject the handshake.
    send_initial_response(conn, bearer=b"dontcare")
    expect_handshake_failure(conn, oauth_ctx)
diff --git a/src/test/python/server/test_server.py b/src/test/python/server/test_server.py
new file mode 100644
index 0000000000..02126dba79
--- /dev/null
+++ b/src/test/python/server/test_server.py
@@ -0,0 +1,21 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import pq3
+
+
def test_handshake(connect):
    """Basic sanity check: complete a handshake and run an empty query."""
    conn = connect()
    pq3.handshake(conn, user=pq3.pguser(), database=pq3.pgdatabase())

    # An empty query string elicits EmptyQueryResponse, then ReadyForQuery.
    pq3.send(conn, pq3.types.Query, query=b"")

    assert pq3.recv1(conn).type == pq3.types.EmptyQueryResponse
    assert pq3.recv1(conn).type == pq3.types.ReadyForQuery
diff --git a/src/test/python/server/validate_bearer.py b/src/test/python/server/validate_bearer.py
new file mode 100755
index 0000000000..2cc73ff154
--- /dev/null
+++ b/src/test/python/server/validate_bearer.py
@@ -0,0 +1,101 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It doesn't actually validate
+# anything, and it logs the bearer token data, which is sensitive.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. Memory is shared and communicated from that test module's
+# bearer_token() fixture.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
+
+import base64
+import binascii
+import contextlib
+import struct
+import sys
+from multiprocessing import shared_memory
+
# Sentinel length value: the largest value representable in an unsigned
# 16-bit length prefix.
MAX_UINT16 = 0xFFFF
+
+
def remove_shm_from_resource_tracker():
    """
    Monkey-patch multiprocessing.resource_tracker so SharedMemory won't be
    tracked. Pulled from this thread, where there are more details:

        https://bugs.python.org/issue38119

    TL;DR: all clients of shared memory segments automatically destroy them on
    process exit, which makes shared memory segments much less useful. This
    monkeypatch removes that behavior so that we can defer to the test to manage
    the segment lifetime.

    Ideally a future Python patch will pull in this fix and then the entire
    function can go away.
    """
    from multiprocessing import resource_tracker

    def fix_register(name, rtype):
        # Ignore shared memory segments; delegate everything else unchanged.
        # (The upstream snippet incorrectly passed `self` here, which raised
        # NameError for any non-shared_memory resource type.)
        if rtype == "shared_memory":
            return
        return resource_tracker._resource_tracker.register(name, rtype)

    resource_tracker.register = fix_register

    def fix_unregister(name, rtype):
        if rtype == "shared_memory":
            return
        return resource_tracker._resource_tracker.unregister(name, rtype)

    resource_tracker.unregister = fix_unregister

    if "shared_memory" in resource_tracker._CLEANUP_FUNCS:
        del resource_tracker._CLEANUP_FUNCS["shared_memory"]
+
+
def main(args):
    """
    Reads a Bearer token from stdin and compares it against the expected token
    that the running test has stashed in the shared memory segment named by
    args[0]. Exits nonzero (via sys.exit) if the tokens do not match.
    """
    remove_shm_from_resource_tracker()  # XXX remove some day

    # Get the expected token from the currently running test.
    shared_mem_name = args[0]

    mem = shared_memory.SharedMemory(shared_mem_name)
    with contextlib.closing(mem):
        # First two bytes are the token length.
        # NOTE(review): "H" uses native byte order/size; assumes the test
        # fixture packs the length the same way — confirm against
        # bearer_token() in test_oauth.py.
        size = struct.unpack("H", mem.buf[:2])[0]

        if size == MAX_UINT16:
            # Special case: the test wants us to accept any token.
            sys.stderr.write("accepting token without validation\n")
            return

        # The remainder of the buffer contains the expected token.
        assert size <= (mem.size - 2)
        expected_token = mem.buf[2 : size + 2].tobytes()

        mem.buf[:] = b"\0" * mem.size  # scribble over the token

    # The server hands us the client's token on stdin.
    token = sys.stdin.buffer.read()
    if token != expected_token:
        sys.exit(f"failed to match Bearer token ({token!r} != {expected_token!r})")

    # See if the test wants us to print anything. If so, it will have encoded
    # the desired output in the token with an "output=" prefix.
    try:
        # altchars="-_" corresponds to the urlsafe alphabet.
        data = base64.b64decode(token, altchars="-_", validate=True)

        if data.startswith(b"output="):
            sys.stdout.buffer.write(data[7:])

    except binascii.Error:
        # Not base64 at all: nothing to print.
        pass


if __name__ == "__main__":
    main(sys.argv[1:])
diff --git a/src/test/python/server/validate_reflect.py b/src/test/python/server/validate_reflect.py
new file mode 100755
index 0000000000..24c3a7e715
--- /dev/null
+++ b/src/test/python/server/validate_reflect.py
@@ -0,0 +1,34 @@
+#! /usr/bin/env python3
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+# DO NOT USE THIS OAUTH VALIDATOR IN PRODUCTION. It ignores the bearer token
+# entirely and automatically logs the user in.
+#
+# This executable is used as an oauth_validator_command in concert with
+# test_oauth.py. It expects the user's desired role name as an argument; the
+# actual token will be discarded and the user will be logged in with the role
+# name as the authenticated identity.
+#
+# This script must run under the Postgres server environment; keep the
+# dependency list fairly standard.
+
+import sys
+
+
def main(args):
    """Reflect the role name given as the only argument back to the server."""
    # Drain the token from stdin first; the server blocks until it has been
    # consumed. The token contents are deliberately ignored.
    sys.stdin.buffer.read()

    try:
        (role,) = args
    except ValueError:
        sys.exit("usage: ./validate_reflect.py ROLE")

    # Log the user in as the provided role.
    print(role)


if __name__ == "__main__":
    main(sys.argv[1:])
diff --git a/src/test/python/test_internals.py b/src/test/python/test_internals.py
new file mode 100644
index 0000000000..dee4855fc0
--- /dev/null
+++ b/src/test/python/test_internals.py
@@ -0,0 +1,138 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import io
+
+from pq3 import _DebugStream
+
+
def test_DebugStream_read():
    """
    Reads pass through to the wrapped stream, and each flush_debug() emits one
    hex-dump paragraph covering everything read since the last flush.
    """
    under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
    out = io.StringIO()

    stream = _DebugStream(under, out)

    res = stream.read(5)
    assert res == b"abcde"

    res = stream.read(16)
    assert res == b"fghijklmnopqrstu"

    stream.flush_debug()

    res = stream.read()
    assert res == b"vwxyz"

    stream.flush_debug()

    # Two flushes -> two "<" (read-direction) hex-dump paragraphs.
    expected = (
        "< 0000:\t61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70\tabcdefghijklmnop\n"
        "< 0010:\t71 72 73 74 75                                 \tqrstu\n"
        "\n"
        "< 0000:\t76 77 78 79 7a                                 \tvwxyz\n"
        "\n"
    )
    assert out.getvalue() == expected
+
+
def test_DebugStream_write():
    """Writes pass through to the wrapped stream and are hex-dumped together."""
    raw = io.BytesIO()
    log = io.StringIO()

    stream = _DebugStream(raw, log)

    # Two separately-flushed writes accumulate in the underlying buffer...
    stream.write(b"\x00\x01\x02")
    stream.flush()
    assert raw.getvalue() == b"\x00\x01\x02"

    stream.write(b"\xc0\xc1\xc2")
    stream.flush()
    assert raw.getvalue() == b"\x00\x01\x02\xc0\xc1\xc2"

    # ...but produce a single combined ">" (write-direction) dump when the
    # debug output is flushed.
    stream.flush_debug()

    expected = "> 0000:\t00 01 02 c0 c1 c2                              \t......\n\n"
    assert log.getvalue() == expected
+
+
def test_DebugStream_read_write():
    """
    Interleaved reads and writes are buffered independently: one flush_debug()
    produces one read paragraph and one write paragraph, each coalescing both
    operations in that direction.
    """
    under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
    out = io.StringIO()
    stream = _DebugStream(under, out)

    res = stream.read(5)
    assert res == b"abcde"

    stream.write(b"xxxxx")
    stream.flush()

    assert under.getvalue() == b"abcdexxxxxklmnopqrstuvwxyz"

    res = stream.read(5)
    assert res == b"klmno"

    stream.write(b"xxxxx")
    stream.flush()

    assert under.getvalue() == b"abcdexxxxxklmnoxxxxxuvwxyz"

    stream.flush_debug()

    # Both reads coalesce into one "<" dump, both writes into one ">" dump.
    expected = (
        "< 0000:\t61 62 63 64 65 6b 6c 6d 6e 6f                  \tabcdeklmno\n"
        "\n"
        "> 0000:\t78 78 78 78 78 78 78 78 78 78                  \txxxxxxxxxx\n"
        "\n"
    )
    assert out.getvalue() == expected
+
+
def test_DebugStream_end_packet():
    """
    end_packet() flushes the pending hex dumps (indented by the given prefix)
    and then prints the packet description. For a read-direction packet the
    read dump comes first; for a write-direction packet the description for
    the opposite direction is printed before this direction's dump.
    """
    under = io.BytesIO(b"abcdefghijklmnopqrstuvwxyz")
    out = io.StringIO()
    stream = _DebugStream(under, out)

    stream.read(5)
    stream.end_packet("read description", read=True, indent=" ")

    stream.write(b"xxxxx")
    stream.flush()
    stream.end_packet("write description", indent=" ")

    stream.read(5)
    stream.write(b"xxxxx")
    stream.flush()
    stream.end_packet("read/write combo for read", read=True, indent=" ")

    stream.read(5)
    stream.write(b"xxxxx")
    stream.flush()
    stream.end_packet("read/write combo for write", indent=" ")

    # Note the single-space indent before each hex-dump line, and the ordering
    # of dumps vs. descriptions for each direction.
    expected = (
        " < 0000:\t61 62 63 64 65                                 \tabcde\n"
        "\n"
        "< read description\n"
        "\n"
        "> write description\n"
        "\n"
        " > 0000:\t78 78 78 78 78                                 \txxxxx\n"
        "\n"
        " < 0000:\t6b 6c 6d 6e 6f                                 \tklmno\n"
        "\n"
        " > 0000:\t78 78 78 78 78                                 \txxxxx\n"
        "\n"
        "< read/write combo for read\n"
        "\n"
        "> read/write combo for write\n"
        "\n"
        " < 0000:\t75 76 77 78 79                                 \tuvwxy\n"
        "\n"
        " > 0000:\t78 78 78 78 78                                 \txxxxx\n"
        "\n"
    )
    assert out.getvalue() == expected
diff --git a/src/test/python/test_pq3.py b/src/test/python/test_pq3.py
new file mode 100644
index 0000000000..e0c0e0568d
--- /dev/null
+++ b/src/test/python/test_pq3.py
@@ -0,0 +1,558 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+import contextlib
+import getpass
+import io
+import struct
+import sys
+
+import pytest
+from construct import Container, PaddingError, StreamError, TerminatedError
+
+import pq3
+
+
@pytest.mark.parametrize(
    "raw,expected,extra",
    [
        pytest.param(
            b"\x00\x00\x00\x10\x00\x04\x00\x00abcdefgh",
            Container(len=16, proto=0x40000, payload=b"abcdefgh"),
            b"",
            id="8-byte payload",
        ),
        pytest.param(
            b"\x00\x00\x00\x08\x00\x04\x00\x00",
            Container(len=8, proto=0x40000, payload=b""),
            b"",
            id="no payload",
        ),
        pytest.param(
            b"\x00\x00\x00\x09\x00\x04\x00\x00abcde",
            Container(len=9, proto=0x40000, payload=b"a"),
            b"bcde",
            id="1-byte payload and extra padding",
        ),
        pytest.param(
            b"\x00\x00\x00\x0B\x00\x03\x00\x00hi\x00",
            Container(len=11, proto=pq3.protocol(3, 0), payload=[b"hi"]),
            b"",
            id="implied parameter list when using proto version 3.0",
        ),
    ],
)
def test_Startup_parse(raw, expected, extra):
    """
    Startup packets must parse into the expected Container, consuming only
    `len` bytes and leaving any trailing bytes unread on the stream.
    """
    with io.BytesIO(raw) as stream:
        actual = pq3.Startup.parse_stream(stream)

        assert actual == expected
        assert stream.read() == extra
+
+
@pytest.mark.parametrize(
    "packet,expected_bytes",
    [
        pytest.param(
            dict(),
            b"\x00\x00\x00\x08\x00\x00\x00\x00",
            id="nothing set",
        ),
        pytest.param(
            dict(len=10, proto=0x12345678),
            b"\x00\x00\x00\x0A\x12\x34\x56\x78\x00\x00",
            id="len and proto set explicitly",
        ),
        pytest.param(
            dict(proto=0x12345678),
            b"\x00\x00\x00\x08\x12\x34\x56\x78",
            id="implied len with no payload",
        ),
        pytest.param(
            dict(proto=0x12345678, payload=b"abcd"),
            b"\x00\x00\x00\x0C\x12\x34\x56\x78abcd",
            id="implied len with payload",
        ),
        pytest.param(
            dict(payload=[b""]),
            b"\x00\x00\x00\x09\x00\x03\x00\x00\x00",
            id="implied proto version 3 when sending parameters",
        ),
        pytest.param(
            dict(payload=[b"hi", b""]),
            b"\x00\x00\x00\x0C\x00\x03\x00\x00hi\x00\x00",
            id="implied proto version 3 and len when sending more than one parameter",
        ),
        pytest.param(
            dict(payload=dict(user="jsmith", database="postgres")),
            b"\x00\x00\x00\x27\x00\x03\x00\x00user\x00jsmith\x00database\x00postgres\x00\x00",
            id="auto-serialization of dict parameters",
        ),
    ],
)
def test_Startup_build(packet, expected_bytes):
    """
    Building a Startup packet must fill in defaults (len, proto) as described
    by each parametrized case and produce exactly the expected wire bytes.
    """
    actual = pq3.Startup.build(packet)
    assert actual == expected_bytes
+
+
@pytest.mark.parametrize(
    "raw,expected,extra",
    [
        pytest.param(
            b"*\x00\x00\x00\x08abcd",
            dict(type=b"*", len=8, payload=b"abcd"),
            b"",
            id="4-byte payload",
        ),
        pytest.param(
            b"*\x00\x00\x00\x04",
            dict(type=b"*", len=4, payload=b""),
            b"",
            id="no payload",
        ),
        pytest.param(
            b"*\x00\x00\x00\x05xabcd",
            dict(type=b"*", len=5, payload=b"x"),
            b"abcd",
            id="1-byte payload with extra padding",
        ),
        pytest.param(
            b"R\x00\x00\x00\x08\x00\x00\x00\x00",
            dict(
                type=pq3.types.AuthnRequest,
                len=8,
                payload=dict(type=pq3.authn.OK, body=None),
            ),
            b"",
            id="AuthenticationOk",
        ),
        pytest.param(
            b"R\x00\x00\x00\x12\x00\x00\x00\x0AEXTERNAL\x00\x00",
            dict(
                type=pq3.types.AuthnRequest,
                len=18,
                payload=dict(type=pq3.authn.SASL, body=[b"EXTERNAL", b""]),
            ),
            b"",
            id="AuthenticationSASL",
        ),
        pytest.param(
            b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
            dict(
                type=pq3.types.AuthnRequest,
                len=13,
                payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
            ),
            b"",
            id="AuthenticationSASLContinue",
        ),
        pytest.param(
            b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
            dict(
                type=pq3.types.AuthnRequest,
                len=13,
                payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
            ),
            b"",
            id="AuthenticationSASLFinal",
        ),
        pytest.param(
            b"p\x00\x00\x00\x0Bhunter2",
            dict(
                type=pq3.types.PasswordMessage,
                len=11,
                payload=b"hunter2",
            ),
            b"",
            id="PasswordMessage",
        ),
        pytest.param(
            b"K\x00\x00\x00\x0C\x00\x00\x00\x00\x12\x34\x56\x78",
            dict(
                type=pq3.types.BackendKeyData,
                len=12,
                payload=dict(pid=0, key=0x12345678),
            ),
            b"",
            id="BackendKeyData",
        ),
        pytest.param(
            b"C\x00\x00\x00\x08SET\x00",
            dict(
                type=pq3.types.CommandComplete,
                len=8,
                payload=dict(tag=b"SET"),
            ),
            b"",
            id="CommandComplete",
        ),
        pytest.param(
            b"E\x00\x00\x00\x11Mbad!\x00Mdog!\x00\x00",
            dict(type=b"E", len=17, payload=dict(fields=[b"Mbad!", b"Mdog!", b""])),
            b"",
            id="ErrorResponse",
        ),
        pytest.param(
            b"S\x00\x00\x00\x08a\x00b\x00",
            dict(
                type=pq3.types.ParameterStatus,
                len=8,
                payload=dict(name=b"a", value=b"b"),
            ),
            b"",
            id="ParameterStatus",
        ),
        pytest.param(
            b"Z\x00\x00\x00\x05x",
            dict(type=b"Z", len=5, payload=dict(status=b"x")),
            b"",
            id="ReadyForQuery",
        ),
        pytest.param(
            b"Q\x00\x00\x00\x06!\x00",
            dict(type=pq3.types.Query, len=6, payload=dict(query=b"!")),
            b"",
            id="Query",
        ),
        pytest.param(
            b"D\x00\x00\x00\x0B\x00\x01\x00\x00\x00\x01!",
            dict(type=pq3.types.DataRow, len=11, payload=dict(columns=[b"!"])),
            b"",
            id="DataRow",
        ),
        pytest.param(
            b"D\x00\x00\x00\x06\x00\x00extra",
            dict(type=pq3.types.DataRow, len=6, payload=dict(columns=[])),
            b"extra",
            id="DataRow with extra data",
        ),
        pytest.param(
            b"I\x00\x00\x00\x04",
            dict(type=pq3.types.EmptyQueryResponse, len=4, payload=None),
            b"",
            id="EmptyQueryResponse",
        ),
        pytest.param(
            b"I\x00\x00\x00\x04\xFF",
            dict(type=b"I", len=4, payload=None),
            b"\xFF",
            id="EmptyQueryResponse with extra bytes",
        ),
        pytest.param(
            b"X\x00\x00\x00\x04",
            dict(type=pq3.types.Terminate, len=4, payload=None),
            b"",
            id="Terminate",
        ),
    ],
)
def test_Pq3_parse(raw, expected, extra):
    """
    Standard (non-startup) protocol packets must parse into the expected
    type/len/payload structure, consuming exactly `len` bytes past the type
    byte and leaving any trailing bytes unread on the stream.
    """
    with io.BytesIO(raw) as stream:
        actual = pq3.Pq3.parse_stream(stream)

        assert actual == expected
        assert stream.read() == extra
+
+
@pytest.mark.parametrize(
    "fields,expected",
    [
        pytest.param(
            dict(type=b"*", len=5),
            b"*\x00\x00\x00\x05\x00",
            id="type and len set explicitly",
        ),
        pytest.param(
            dict(type=b"*"),
            b"*\x00\x00\x00\x04",
            id="implied len with no payload",
        ),
        pytest.param(
            dict(type=b"*", payload=b"1234"),
            b"*\x00\x00\x00\x081234",
            id="implied len with payload",
        ),
        pytest.param(
            dict(type=pq3.types.AuthnRequest, payload=dict(type=pq3.authn.OK)),
            b"R\x00\x00\x00\x08\x00\x00\x00\x00",
            id="implied len/type for AuthenticationOK",
        ),
        pytest.param(
            dict(
                type=pq3.types.AuthnRequest,
                payload=dict(
                    type=pq3.authn.SASL,
                    body=[b"SCRAM-SHA-256-PLUS", b"SCRAM-SHA-256", b""],
                ),
            ),
            b"R\x00\x00\x00\x2A\x00\x00\x00\x0ASCRAM-SHA-256-PLUS\x00SCRAM-SHA-256\x00\x00",
            id="implied len/type for AuthenticationSASL",
        ),
        pytest.param(
            dict(
                type=pq3.types.AuthnRequest,
                payload=dict(type=pq3.authn.SASLContinue, body=b"12345"),
            ),
            b"R\x00\x00\x00\x0D\x00\x00\x00\x0B12345",
            id="implied len/type for AuthenticationSASLContinue",
        ),
        pytest.param(
            dict(
                type=pq3.types.AuthnRequest,
                payload=dict(type=pq3.authn.SASLFinal, body=b"12345"),
            ),
            b"R\x00\x00\x00\x0D\x00\x00\x00\x0C12345",
            id="implied len/type for AuthenticationSASLFinal",
        ),
        pytest.param(
            dict(
                type=pq3.types.PasswordMessage,
                payload=b"hunter2",
            ),
            b"p\x00\x00\x00\x0Bhunter2",
            id="implied len/type for PasswordMessage",
        ),
        pytest.param(
            dict(type=pq3.types.BackendKeyData, payload=dict(pid=1, key=7)),
            b"K\x00\x00\x00\x0C\x00\x00\x00\x01\x00\x00\x00\x07",
            id="implied len/type for BackendKeyData",
        ),
        pytest.param(
            dict(type=pq3.types.CommandComplete, payload=dict(tag=b"SET")),
            b"C\x00\x00\x00\x08SET\x00",
            id="implied len/type for CommandComplete",
        ),
        pytest.param(
            dict(type=pq3.types.ErrorResponse, payload=dict(fields=[b"error", b""])),
            b"E\x00\x00\x00\x0Berror\x00\x00",
            id="implied len/type for ErrorResponse",
        ),
        pytest.param(
            dict(type=pq3.types.ParameterStatus, payload=dict(name=b"a", value=b"b")),
            b"S\x00\x00\x00\x08a\x00b\x00",
            id="implied len/type for ParameterStatus",
        ),
        pytest.param(
            dict(type=pq3.types.ReadyForQuery, payload=dict(status=b"I")),
            b"Z\x00\x00\x00\x05I",
            id="implied len/type for ReadyForQuery",
        ),
        pytest.param(
            dict(type=pq3.types.Query, payload=dict(query=b"SELECT 1;")),
            b"Q\x00\x00\x00\x0eSELECT 1;\x00",
            id="implied len/type for Query",
        ),
        pytest.param(
            dict(type=pq3.types.DataRow, payload=dict(columns=[b"abcd"])),
            b"D\x00\x00\x00\x0E\x00\x01\x00\x00\x00\x04abcd",
            id="implied len/type for DataRow",
        ),
        pytest.param(
            dict(type=pq3.types.EmptyQueryResponse),
            b"I\x00\x00\x00\x04",
            id="implied len for EmptyQueryResponse",
        ),
        pytest.param(
            dict(type=pq3.types.Terminate),
            b"X\x00\x00\x00\x04",
            id="implied len for Terminate",
        ),
    ],
)
def test_Pq3_build(fields, expected):
    """
    Building a packet must infer any omitted fields (len, and sometimes the
    type byte) from the payload, producing exactly the expected wire bytes.
    """
    actual = pq3.Pq3.build(fields)
    assert actual == expected
+
+
@pytest.mark.parametrize(
    "raw,expected,extra",
    [
        pytest.param(
            b"\x00\x00",
            dict(columns=[]),
            b"",
            id="no columns",
        ),
        pytest.param(
            b"\x00\x01\x00\x00\x00\x04abcd",
            dict(columns=[b"abcd"]),
            b"",
            id="one column",
        ),
        pytest.param(
            b"\x00\x02\x00\x00\x00\x04abcd\x00\x00\x00\x01x",
            dict(columns=[b"abcd", b"x"]),
            b"",
            id="multiple columns",
        ),
        pytest.param(
            b"\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01x",
            dict(columns=[b"", b"x"]),
            b"",
            id="empty column value",
        ),
        pytest.param(
            b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
            dict(columns=[None, None]),
            b"",
            id="null columns",
        ),
    ],
)
def test_DataRow_parse(raw, expected, extra):
    """
    DataRow payloads must parse column-by-column; a length of -1 (0xFFFFFFFF)
    denotes a NULL column. The raw payload is wrapped in a full packet here so
    it can go through the regular Pq3 parser.
    """
    # Prepend the 'D' type byte and the computed 4-byte length.
    pkt = b"D" + struct.pack("!i", len(raw) + 4) + raw
    with io.BytesIO(pkt) as stream:
        actual = pq3.Pq3.parse_stream(stream)

        assert actual.type == pq3.types.DataRow
        assert actual.payload == expected
        assert stream.read() == extra
+
+
@pytest.mark.parametrize(
    "fields,expected",
    [
        pytest.param(
            dict(),
            b"\x00\x00",
            id="no columns",
        ),
        pytest.param(
            dict(columns=[None, None]),
            b"\x00\x02\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
            id="null columns",
        ),
    ],
)
def test_DataRow_build(fields, expected):
    """
    Building a DataRow must serialize the column list (NULLs as length -1) and
    wrap it in a correctly-framed packet.
    """
    actual = pq3.Pq3.build(dict(type=pq3.types.DataRow, payload=fields))

    # The expected payload gets the same 'D' + length framing applied.
    expected = b"D" + struct.pack("!i", len(expected) + 4) + expected
    assert actual == expected
+
+
@pytest.mark.parametrize(
    "raw,expected,exception",
    [
        pytest.param(
            b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
            dict(name=b"EXTERNAL", len=-1, data=None),
            None,
            id="no initial response",
        ),
        pytest.param(
            b"EXTERNAL\x00\x00\x00\x00\x02me",
            dict(name=b"EXTERNAL", len=2, data=b"me"),
            None,
            id="initial response",
        ),
        pytest.param(
            b"EXTERNAL\x00\x00\x00\x00\x02meextra",
            None,
            TerminatedError,
            id="extra data",
        ),
        pytest.param(
            b"EXTERNAL\x00\x00\x00\x00\x02meextra",
            None,
            TerminatedError,
            id="extra data",
        ),
    ],
)
def test_SASLInitialResponse_parse(raw, expected, exception):
    """
    SASLInitialResponse must parse the mechanism name and length-prefixed
    data (len -1 meaning "no initial response"); trailing or missing bytes
    must raise the parametrized construct exception.
    """
    # Expect the parametrized exception, or no exception when it is None.
    ctx = contextlib.nullcontext()
    if exception:
        ctx = pytest.raises(exception)

    with ctx:
        actual = pq3.SASLInitialResponse.parse(raw)
        assert actual == expected
+
+
@pytest.mark.parametrize(
    "fields,expected",
    [
        pytest.param(
            dict(name=b"EXTERNAL"),
            b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
            id="no initial response",
        ),
        pytest.param(
            dict(name=b"EXTERNAL", data=None),
            b"EXTERNAL\x00\xFF\xFF\xFF\xFF",
            id="no initial response (explicit None)",
        ),
        pytest.param(
            dict(name=b"EXTERNAL", data=b""),
            b"EXTERNAL\x00\x00\x00\x00\x00",
            id="empty response",
        ),
        pytest.param(
            dict(name=b"EXTERNAL", data=b"me@example.com"),
            b"EXTERNAL\x00\x00\x00\x00\x0Eme@example.com",
            id="initial response",
        ),
        pytest.param(
            dict(name=b"EXTERNAL", len=2, data=b"me@example.com"),
            b"EXTERNAL\x00\x00\x00\x00\x02me@example.com",
            id="data overflow",
        ),
        pytest.param(
            dict(name=b"EXTERNAL", len=14, data=b"me"),
            b"EXTERNAL\x00\x00\x00\x00\x0Eme",
            id="data underflow",
        ),
    ],
)
def test_SASLInitialResponse_build(fields, expected):
    """
    Building a SASLInitialResponse must infer the data length when omitted
    (None data -> -1) and honor an explicit len even when it disagrees with
    the data, as the overflow/underflow cases show.
    """
    actual = pq3.SASLInitialResponse.build(fields)
    assert actual == expected
+
+
@pytest.mark.parametrize(
    "version,expected_bytes",
    [
        pytest.param((3, 0), b"\x00\x03\x00\x00", id="version 3"),
        pytest.param((1234, 5679), b"\x04\xd2\x16\x2f", id="SSLRequest"),
    ],
)
def test_protocol(version, expected_bytes):
    """The integer returned by protocol() must serialize correctly on the wire."""
    major, minor = version
    wire_format = struct.pack("!i", pq3.protocol(major, minor))
    assert wire_format == expected_bytes
+
+
@pytest.mark.parametrize(
    "envvar,func,expected",
    [
        ("PGHOST", pq3.pghost, "localhost"),
        ("PGPORT", pq3.pgport, 5432),
        ("PGUSER", pq3.pguser, getpass.getuser()),
        ("PGDATABASE", pq3.pgdatabase, "postgres"),
    ],
)
def test_env_defaults(monkeypatch, envvar, func, expected):
    """With its environment variable unset, each helper falls back to a default."""
    # Make sure the variable is absent, whether or not the caller has it set.
    monkeypatch.delenv(envvar, raising=False)

    assert func() == expected
+
+
@pytest.mark.parametrize(
    "envvars,func,expected",
    [
        (dict(PGHOST="otherhost"), pq3.pghost, "otherhost"),
        (dict(PGPORT="6789"), pq3.pgport, 6789),
        (dict(PGUSER="postgres"), pq3.pguser, "postgres"),
        (dict(PGDATABASE="template1"), pq3.pgdatabase, "template1"),
    ],
)
def test_env(monkeypatch, envvars, func, expected):
    """Each helper must honor its environment variable when it is set."""
    for name, value in envvars.items():
        monkeypatch.setenv(name, value)

    assert func() == expected
diff --git a/src/test/python/tls.py b/src/test/python/tls.py
new file mode 100644
index 0000000000..075c02c1ca
--- /dev/null
+++ b/src/test/python/tls.py
@@ -0,0 +1,195 @@
+#
+# Copyright 2021 VMware, Inc.
+# SPDX-License-Identifier: PostgreSQL
+#
+
+from construct import *
+
+#
+# TLS 1.3
+#
+# Most of the types below are transcribed from RFC 8446:
+#
+#     https://tools.ietf.org/html/rfc8446
+#
+
+
+def _Vector(size_field, element):
+    return Prefixed(size_field, GreedyRange(element))
+
+
+# Alerts (RFC 8446 sec. 6)
+
+# Alert severity. RFC 8446 retains both values for compatibility; the
+# *_RESERVED descriptions below are likewise kept from earlier TLS versions.
+AlertLevel = Enum(
+    Byte,
+    warning=1,
+    fatal=2,
+)
+
+# AlertDescription values, transcribed from the RFC's enum.
+AlertDescription = Enum(
+    Byte,
+    close_notify=0,
+    unexpected_message=10,
+    bad_record_mac=20,
+    decryption_failed_RESERVED=21,
+    record_overflow=22,
+    decompression_failure=30,
+    handshake_failure=40,
+    no_certificate_RESERVED=41,
+    bad_certificate=42,
+    unsupported_certificate=43,
+    certificate_revoked=44,
+    certificate_expired=45,
+    certificate_unknown=46,
+    illegal_parameter=47,
+    unknown_ca=48,
+    access_denied=49,
+    decode_error=50,
+    decrypt_error=51,
+    export_restriction_RESERVED=60,
+    protocol_version=70,
+    insufficient_security=71,
+    internal_error=80,
+    user_canceled=90,
+    no_renegotiation=100,
+    unsupported_extension=110,
+)
+
+# The two-byte Alert message body: a severity followed by a description.
+Alert = Struct(
+    "level" / AlertLevel,
+    "description" / AlertDescription,
+)
+
+
+# Extensions (RFC 8446 sec. 4.2)
+
+# ExtensionType values, transcribed from the RFC's enum.
+ExtensionType = Enum(
+    Int16ub,
+    server_name=0,
+    max_fragment_length=1,
+    status_request=5,
+    supported_groups=10,
+    signature_algorithms=13,
+    use_srtp=14,
+    heartbeat=15,
+    application_layer_protocol_negotiation=16,
+    signed_certificate_timestamp=18,
+    client_certificate_type=19,
+    server_certificate_type=20,
+    padding=21,
+    pre_shared_key=41,
+    early_data=42,
+    supported_versions=43,
+    cookie=44,
+    psk_key_exchange_modes=45,
+    certificate_authorities=47,
+    oid_filters=48,
+    post_handshake_auth=49,
+    signature_algorithms_cert=50,
+    key_share=51,
+)
+
+# A single extension: type plus opaque length-prefixed data. The data is left
+# uninterpreted here rather than parsed per extension type.
+Extension = Struct(
+    "extension_type" / ExtensionType,
+    "extension_data" / Prefixed(Int16ub, GreedyBytes),
+)
+
+
+# ClientHello
+
+
+class _CipherSuiteAdapter(Adapter):
+    """
+    Maps a cipher suite's wire form (two raw bytes) to and from a two-element
+    tuple whose repr() prints in the RFC's (0xXX, 0xXX) notation.
+    """
+
+    class _hextuple(tuple):
+        # tuple subclass purely for nicer debug output.
+        def __repr__(self):
+            return f"(0x{self[0]:02X}, 0x{self[1]:02X})"
+
+    def _encode(self, obj, context, path):
+        # build: accept any two-int sequence and emit the raw bytes.
+        return bytes(obj)
+
+    def _decode(self, obj, context, path):
+        # parse: the underlying subcon (Byte[2]) always yields two bytes.
+        assert len(obj) == 2
+        return self._hextuple(obj)
+
+
+# Basic handshake field types (RFC 8446 sec. 4.1.2), displayed in hex.
+ProtocolVersion = Hex(Int16ub)
+
+# 32 bytes of nonce material.
+Random = Hex(Bytes(32))
+
+# Two raw bytes, presented as a (0xXX, 0xXX) tuple; see _CipherSuiteAdapter.
+CipherSuite = _CipherSuiteAdapter(Byte[2])
+
+# The ClientHello handshake message body (RFC 8446 sec. 4.1.2).
+ClientHello = Struct(
+    "legacy_version" / ProtocolVersion,
+    "random" / Random,
+    "legacy_session_id" / Prefixed(Byte, Hex(GreedyBytes)),
+    "cipher_suites" / _Vector(Int16ub, CipherSuite),
+    "legacy_compression_methods" / Prefixed(Byte, GreedyBytes),
+    "extensions" / _Vector(Int16ub, Extension),
+)
+
+# ServerHello
+
+# The ServerHello handshake message body (RFC 8446 sec. 4.1.3). Unlike
+# ClientHello, it carries exactly one cipher suite and compression method.
+ServerHello = Struct(
+    "legacy_version" / ProtocolVersion,
+    "random" / Random,
+    "legacy_session_id_echo" / Prefixed(Byte, Hex(GreedyBytes)),
+    "cipher_suite" / CipherSuite,
+    "legacy_compression_method" / Hex(Byte),
+    "extensions" / _Vector(Int16ub, Extension),
+)
+
+# Handshake (RFC 8446 sec. 4)
+
+# HandshakeType values, transcribed from the RFC's enum.
+HandshakeType = Enum(
+    Byte,
+    client_hello=1,
+    server_hello=2,
+    new_session_ticket=4,
+    end_of_early_data=5,
+    encrypted_extensions=8,
+    certificate=11,
+    certificate_request=13,
+    certificate_verify=15,
+    finished=20,
+    key_update=24,
+    message_hash=254,
+)
+
+# A handshake message: type, 24-bit length, and a payload parsed per type.
+# Only ClientHello and ServerHello are currently parsed; the commented-out
+# entries are placeholders for types not yet transcribed, which fall through
+# to the default and are kept as raw bytes of the declared length.
+Handshake = Struct(
+    "msg_type" / HandshakeType,
+    "length" / Int24ub,
+    "payload"
+    / Switch(
+        this.msg_type,
+        {
+            HandshakeType.client_hello: ClientHello,
+            HandshakeType.server_hello: ServerHello,
+            # HandshakeType.end_of_early_data: EndOfEarlyData,
+            # HandshakeType.encrypted_extensions: EncryptedExtensions,
+            # HandshakeType.certificate_request: CertificateRequest,
+            # HandshakeType.certificate: Certificate,
+            # HandshakeType.certificate_verify: CertificateVerify,
+            # HandshakeType.finished: Finished,
+            # HandshakeType.new_session_ticket: NewSessionTicket,
+            # HandshakeType.key_update: KeyUpdate,
+        },
+        default=FixedSized(this.length, GreedyBytes),
+    ),
+)
+
+# Records (RFC 8446 sec. 5.1)
+
+# ContentType values, transcribed from the RFC's enum.
+ContentType = Enum(
+    Byte,
+    invalid=0,
+    change_cipher_spec=20,
+    alert=21,
+    handshake=22,
+    application_data=23,
+)
+
+# A TLSPlaintext record. The fragment is kept as raw bytes of the declared
+# length rather than being parsed per content type.
+Plaintext = Struct(
+    "type" / ContentType,
+    "legacy_record_version" / ProtocolVersion,
+    "length" / Int16ub,
+    "fragment" / FixedSized(this.length, GreedyBytes),
+)
-- 
2.25.1

