From 4067bb1fc2973ce95ff4b82ace0f2d400b45a657 Mon Sep 17 00:00:00 2001
From: Jelte Fennema-Nio <postgres@jeltef.nl>
Date: Thu, 23 Oct 2025 14:31:52 +0200
Subject: [PATCH v4 2/2] Add pytest based tests for GoAway message

I used this patchset as a trial for the new pytest suite that Jacob is
trying to introduce. Feel free to look at it, but I'd say don't review
this test in detail until we have the pytest changes merged or at least
in a more agreed upon state. This patch is built on top of that v3
patchset. This test is not applied by cfbot.

Testing this any other way is actually quite difficult with the
infrastructure we currently have (of course I could change that, but I'd
much rather spend that energy/time on making the pytest test suite a
thing):
- pg_regress and Perl tests don't work because we need to call a new
  libpq function that is not exposed in psql (I guess I could expose it
  with some \goawayreceived command, but it doesn't seem very useful).
- libpq_pipeline cannot test this because it would need to restart
  the Postgres server, and all it has is a connection string.
---
 src/interfaces/libpq/meson.build        |  1 +
 src/interfaces/libpq/pyt/test_goaway.py | 54 +++++++++++++++++++++++++
 src/test/pytest/libpq/_core.py          | 18 +++++++++
 src/test/pytest/pypg/fixtures.py        | 24 +++++++++++
 4 files changed, 97 insertions(+)
 create mode 100644 src/interfaces/libpq/pyt/test_goaway.py

diff --git a/src/interfaces/libpq/meson.build b/src/interfaces/libpq/meson.build
index 56790dd92a9..983af1d5bea 100644
--- a/src/interfaces/libpq/meson.build
+++ b/src/interfaces/libpq/meson.build
@@ -163,6 +163,7 @@ tests += {
   'pytest': {
     'tests': [
       'pyt/test_load_balance.py',
+      'pyt/test_goaway.py',
     ],
   },
 }
diff --git a/src/interfaces/libpq/pyt/test_goaway.py b/src/interfaces/libpq/pyt/test_goaway.py
new file mode 100644
index 00000000000..3941e5666af
--- /dev/null
+++ b/src/interfaces/libpq/pyt/test_goaway.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2025, PostgreSQL Global Development Group
+
+"""
+Tests for the GoAway protocol message during smart shutdown.
+
+The GoAway message is sent by the server during smart shutdown to politely
+request that clients disconnect when convenient. The connection remains
+functional after receiving the message.
+"""
+
+
+def test_goaway_smart_shutdown(pg, wait_until):
+    """
+    Test that GoAway message is sent during smart shutdown.
+
+    This test:
+    1. Connects to a running PostgreSQL server via Unix socket
+    2. Verifies GoAway is not received initially
+    3. Initiates a smart shutdown
+    4. Waits until GoAway is received, without sending any queries
+    5. Verifies that queries still work after GoAway
+    6. Verifies that the GoAway flag stays set after those queries
+    """
+
+    # Connect to the server via Unix socket, libpq will request the
+    # _pq_.goaway protocol extension
+    conn = pg.connect(max_protocol_version="latest")
+
+    # Initially, GoAway should not be received
+    assert not conn.goaway_received(), "GoAway should not be received initially"
+
+    # Execute a simple query to ensure connection is working
+    conn.sql("SELECT 1")
+
+    pg.pg_ctl("stop", "--mode", "smart", "--no-wait")
+
+    for _ in wait_until("Did not receive GoAway after smart shutdown"):
+        # Read any data the backend sent unprompted (like GoAway). The call is
+        # kept out of the assert so it still runs when asserts are stripped.
+        consumed = conn.consume_input()
+        assert consumed, "PQconsumeInput failed while waiting for GoAway"
+        if conn.goaway_received():
+            break
+
+    # GoAway was received by the loop above. It is only a polite request to
+    # disconnect, so the connection must remain fully usable.
+    conn.sql("SELECT 2")
+
+    # Running a query must not clear the GoAway flag
+    assert conn.goaway_received(), "GoAway should be received after smart shutdown"
+
+    # Connection should still be functional - try one more query
+    conn.sql("SELECT 3")
+    assert conn.goaway_received(), "GoAway flag should remain set"
diff --git a/src/test/pytest/libpq/_core.py b/src/test/pytest/libpq/_core.py
index 1c059b9b446..c137688a1aa 100644
--- a/src/test/pytest/libpq/_core.py
+++ b/src/test/pytest/libpq/_core.py
@@ -147,6 +147,12 @@ def load_libpq_handle(libdir, bindir):
     lib.PQresultErrorField.restype = ctypes.c_char_p
     lib.PQresultErrorField.argtypes = [_PGresult_p, ctypes.c_int]
 
+    lib.PQgoAwayReceived.restype = ctypes.c_int
+    lib.PQgoAwayReceived.argtypes = [_PGconn_p]
+
+    lib.PQconsumeInput.restype = ctypes.c_int
+    lib.PQconsumeInput.argtypes = [_PGconn_p]
+
     return lib
 
 
@@ -419,6 +425,18 @@ class PGconn(contextlib.AbstractContextManager):
         else:
             res.raise_error()
 
+    def consume_input(self) -> bool:
+        """
+        Reads pending server input via PQconsumeInput(); True on success.
+        """
+        return self._lib.PQconsumeInput(self._handle) != 0
+
+    def goaway_received(self) -> bool:
+        """
+        Reports whether PQgoAwayReceived() says the server sent GoAway.
+        """
+        return self._lib.PQgoAwayReceived(self._handle) != 0
+
 
 def connstr(opts: Dict[str, Any]) -> str:
     """
diff --git a/src/test/pytest/pypg/fixtures.py b/src/test/pytest/pypg/fixtures.py
index 8c0cb60daa5..4aa6c73349d 100644
--- a/src/test/pytest/pypg/fixtures.py
+++ b/src/test/pytest/pypg/fixtures.py
@@ -57,6 +57,50 @@ def remaining_timeout_module():
     return lambda: max(deadline - time.monotonic(), 0)
 
 
+@pytest.fixture
+def wait_until(remaining_timeout):
+    """
+    Returns a generator function for writing polling loops:
+
+        for _ in wait_until("thing did not happen"):
+            if thing_happened():
+                break
+
+    The generator yields once per attempt, sleeping `interval` seconds
+    between attempts, until the caller breaks out of the loop. If `timeout`
+    seconds (default: whatever remaining_timeout() reports) pass first, it
+    raises TimeoutError with the given error message.
+    """
+
+    def _wait_until(error_message="Did not complete in time", timeout=None, interval=1):
+        if timeout is None:
+            timeout = remaining_timeout()
+
+        # Use the monotonic clock, as remaining_timeout_module() does, so
+        # wall-clock adjustments can't stretch or shrink the deadline.
+        start = time.monotonic()
+        end = start + timeout
+
+        # Only report progress for long waits, and at most every 4 seconds.
+        # Start counting from now so the first attempt isn't reported as a
+        # retry before anything has been tried.
+        print_progress = timeout / 10 > 4
+        last_printed_progress = start
+
+        while True:
+            now = time.monotonic()
+            if now >= end:
+                raise TimeoutError(error_message)
+            if print_progress and now - last_printed_progress > 4:
+                last_printed_progress = now
+                print(f"{error_message} - will retry")
+            yield
+            # Never sleep past the deadline, so we don't overrun the timeout.
+            time.sleep(min(interval, max(end - time.monotonic(), 0)))
+
+    return _wait_until
+
+
 @pytest.fixture(scope="session")
 def libpq_handle(libdir, bindir):
     """
-- 
2.52.0

