From edba1bcefa85e995ec3e9df6c9e8d30adcd940b9 Mon Sep 17 00:00:00 2001
From: Ilia Evdokimov <ilya.evdokimov@tantorlabs.com>
Date: Tue, 26 Nov 2024 00:58:24 +0300
Subject: [PATCH] Add time-based sampling to pg_stat_statements

Add a new configuration parameter, pg_stat_statements.sample_exectime_threshold,
which sets an execution time threshold (in msec) used to decide which queries are tracked.
Queries with execution times above the threshold are always tracked,
while shorter queries are sampled probabilistically.
This helps reduce the overhead of tracking frequent, fast queries
while preserving data for longer-running ones.
---
 .../pg_stat_statements/pg_stat_statements.c   | 57 ++++++++++++++-----
 doc/src/sgml/pgstatstatements.sgml            | 20 +++++++
 2 files changed, 64 insertions(+), 13 deletions(-)

diff --git a/contrib/pg_stat_statements/pg_stat_statements.c b/contrib/pg_stat_statements/pg_stat_statements.c
index 49c657b3e0..6ec1e3e7ed 100644
--- a/contrib/pg_stat_statements/pg_stat_statements.c
+++ b/contrib/pg_stat_statements/pg_stat_statements.c
@@ -49,6 +49,7 @@
 
 #include "access/parallel.h"
 #include "catalog/pg_authid.h"
+#include "common/pg_prng.h"
 #include "common/int.h"
 #include "executor/instrument.h"
 #include "funcapi.h"
@@ -289,6 +290,7 @@ static const struct config_enum_entry track_options[] =
 };
 
 static int	pgss_max = 5000;	/* max # statements to track */
+static int	pgss_sample_exectime_threshold = 0; /* Threshold for query execution sampling (msec) */
 static int	pgss_track = PGSS_TRACK_TOP;	/* tracking level */
 static bool pgss_track_utility = true;	/* whether to track utility commands */
 static bool pgss_track_planning = false;	/* whether to track planning
@@ -414,6 +416,19 @@ _PG_init(void)
 							NULL,
 							NULL);
 
+	DefineCustomIntVariable("pg_stat_statements.sample_exectime_threshold",
+							"Sets the threshold (in msec) for query execution time sampling.",
+							NULL,
+							&pgss_sample_exectime_threshold,
+							0,
+							0,
+							INT_MAX / 2,
+							PGC_SUSET,
+							0,
+							NULL,
+							NULL,
+							NULL);
+
 	DefineCustomEnumVariable("pg_stat_statements.track",
 							 "Selects which statements are tracked by pg_stat_statements.",
 							 NULL,
@@ -1071,25 +1086,41 @@ pgss_ExecutorEnd(QueryDesc *queryDesc)
 	if (queryId != UINT64CONST(0) && queryDesc->totaltime &&
 		pgss_enabled(nesting_level))
 	{
+		double total_time = 0.0;
 		/*
 		 * Make sure stats accumulation is done.  (Note: it's okay if several
 		 * levels of hook all do this.)
 		 */
 		InstrEndLoop(queryDesc->totaltime);
 
-		pgss_store(queryDesc->sourceText,
-				   queryId,
-				   queryDesc->plannedstmt->stmt_location,
-				   queryDesc->plannedstmt->stmt_len,
-				   PGSS_EXEC,
-				   queryDesc->totaltime->total * 1000.0,	/* convert to msec */
-				   queryDesc->estate->es_total_processed,
-				   &queryDesc->totaltime->bufusage,
-				   &queryDesc->totaltime->walusage,
-				   queryDesc->estate->es_jit ? &queryDesc->estate->es_jit->instr : NULL,
-				   NULL,
-				   queryDesc->estate->es_parallel_workers_to_launch,
-				   queryDesc->estate->es_parallel_workers_launched);
+		/* convert to msec */
+		total_time = queryDesc->totaltime->total * 1000.0;
+
+		/*
+		 * Sampling is implemented by comparing the execution time of the query
+		 * to a random value scaled by pg_stat_statements.sample_exectime_threshold.
+		 * Queries running at least as long as the threshold are always tracked.
+		 * A shorter query is tracked only if its execution time exceeds that random
+		 * value, so the closer it is to the threshold, the more likely it is tracked.
+		 */
+		if (total_time >= pgss_sample_exectime_threshold ||
+		   (total_time < pgss_sample_exectime_threshold &&
+		   (pgss_sample_exectime_threshold * pg_prng_double(&pg_global_prng_state) < total_time)))
+		{
+			pgss_store(queryDesc->sourceText,
+				   queryId,
+				   queryDesc->plannedstmt->stmt_location,
+				   queryDesc->plannedstmt->stmt_len,
+				   PGSS_EXEC,
+				   total_time,
+				   queryDesc->estate->es_total_processed,
+				   &queryDesc->totaltime->bufusage,
+				   &queryDesc->totaltime->walusage,
+				   queryDesc->estate->es_jit ? &queryDesc->estate->es_jit->instr : NULL,
+				   NULL,
+				   queryDesc->estate->es_parallel_workers_to_launch,
+				   queryDesc->estate->es_parallel_workers_launched);
+		}
 	}
 
 	if (prev_ExecutorEnd)
diff --git a/doc/src/sgml/pgstatstatements.sgml b/doc/src/sgml/pgstatstatements.sgml
index 501b468e9a..7b998a77f3 100644
--- a/doc/src/sgml/pgstatstatements.sgml
+++ b/doc/src/sgml/pgstatstatements.sgml
@@ -872,6 +872,26 @@
     </listitem>
    </varlistentry>
 
+
+   <varlistentry>
+    <term>
+     <varname>pg_stat_statements.sample_exectime_threshold</varname> (<type>integer</type>)
+     <indexterm>
+      <primary><varname>pg_stat_statements.sample_exectime_threshold</varname> configuration parameter</primary>
+     </indexterm>
+    </term>
+
+    <listitem>
+     <para>
+      <varname>pg_stat_statements.sample_exectime_threshold</varname> is the threshold (in msec)
+      for query execution time sampling. Queries whose execution time is at or above this threshold
+      are always tracked. Shorter queries are tracked only probabilistically, to reduce overhead from frequent short queries.
+      The default value is 0, which means that all queries are tracked.
+      Only superusers can change this setting.
+     </para>
+    </listitem>
+   </varlistentry>
+
    <varlistentry>
     <term>
      <varname>pg_stat_statements.track</varname> (<type>enum</type>)
-- 
2.34.1

