From 76a9d3dde944a41d61bf7e448d1eaf7db5d55f39 Mon Sep 17 00:00:00 2001
From: Sami Imseih <simseih@amazon.com>
Date: Mon, 26 May 2025 22:11:46 -0500
Subject: [PATCH v9 2/2] Fix Normalization for squashed query texts

62d712ec added the ability to squash constants from an
IN list/ArrayExpr for queryId computation purposes. However,
in certain cases, this broke normalization. For example,
"IN (1, 2, int4(1))" is normalized to "IN ($2 /*, ... */))",
which leaves an extra parenthesis at the end of the normalized string.

To correct this, the start and end boundaries of an expr_list are
now tracked by the various nodes used during parsing and are made
available to the ArrayExpr node for query jumbling. Having these
boundaries allows normalization to precisely identify the locations
in the query text that should be squashed.
---
 .../pg_stat_statements/expected/select.out    |  30 ++++
 .../pg_stat_statements/expected/squashing.out |  22 +--
 .../pg_stat_statements/pg_stat_statements.c   |  83 +++-------
 contrib/pg_stat_statements/sql/select.sql     |   8 +
 contrib/pg_stat_statements/sql/squashing.sql  |   2 +-
 src/backend/nodes/gen_node_support.pl         |   2 +-
 src/backend/nodes/queryjumblefuncs.c          | 153 ++++++++++--------
 src/backend/parser/gram.y                     | 102 ++++++------
 src/backend/parser/parse_expr.c               |   4 +
 src/include/nodes/parsenodes.h                |  10 ++
 src/include/nodes/primnodes.h                 |   4 +
 11 files changed, 222 insertions(+), 198 deletions(-)

diff --git a/contrib/pg_stat_statements/expected/select.out b/contrib/pg_stat_statements/expected/select.out
index 038ae110364..a57e11ef5b1 100644
--- a/contrib/pg_stat_statements/expected/select.out
+++ b/contrib/pg_stat_statements/expected/select.out
@@ -267,6 +267,36 @@ SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
  SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C" |     0
 (4 rows)
 
+-- with the last element being an explicit function call with an argument, ensure
+-- the normalization of the squashing interval is correct.
+SELECT pg_stat_statements_reset() IS NOT NULL AS t;
+ t 
+---
+ t
+(1 row)
+
+SELECT pg_stat_statements_reset() IS NOT NULL AS t;
+ t 
+---
+ t
+(1 row)
+
+SELECT WHERE 1 IN (1, int4(1), int4(2));
+--
+(1 row)
+
+SELECT WHERE 1 = ANY (ARRAY[1, int4(1), int4(2)]);
+--
+(1 row)
+
+SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
+                                 query                                  | calls 
+------------------------------------------------------------------------+-------
+ SELECT WHERE $1 IN ($2 /*, ... */)                                     |     2
+ SELECT pg_stat_statements_reset() IS NOT NULL AS t                     |     1
+ SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C" |     0
+(3 rows)
+
 --
 -- queries with locking clauses
 --
diff --git a/contrib/pg_stat_statements/expected/squashing.out b/contrib/pg_stat_statements/expected/squashing.out
index b8724e3356c..9cbd308ec43 100644
--- a/contrib/pg_stat_statements/expected/squashing.out
+++ b/contrib/pg_stat_statements/expected/squashing.out
@@ -453,7 +453,7 @@ SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
                        query                        | calls 
 ----------------------------------------------------+-------
  SELECT * FROM test_squash_bigint WHERE data IN    +|     2
-         ($1 /*, ... */::bigint)                    | 
+         ($1 /*, ... */)                            | 
  SELECT pg_stat_statements_reset() IS NOT NULL AS t |     1
 (2 rows)
 
@@ -503,10 +503,10 @@ SELECT WHERE 1 = ANY(ARRAY[1::int::bigint::int, 2::int::bigint::int]);
 (1 row)
 
 SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
-                              query                              | calls 
------------------------------------------------------------------+-------
- SELECT WHERE $1 IN ($2::int::bigint::int, $3::int::bigint::int) |     2
- SELECT pg_stat_statements_reset() IS NOT NULL AS t              |     1
+                       query                        | calls 
+----------------------------------------------------+-------
+ SELECT WHERE $1 IN ($2 /*, ... */)                 |     2
+ SELECT pg_stat_statements_reset() IS NOT NULL AS t |     1
 (2 rows)
 
 --
@@ -572,7 +572,7 @@ SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
                        query                        | calls 
 ----------------------------------------------------+-------
  SELECT * FROM test_squash_cast WHERE data IN      +|     2
-         ($1 /*, ... */::int4::casttesttype)        | 
+         ($1 /*, ... */)                            | 
  SELECT pg_stat_statements_reset() IS NOT NULL AS t |     1
 (2 rows)
 
@@ -604,7 +604,7 @@ SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
                        query                        | calls 
 ----------------------------------------------------+-------
  SELECT * FROM test_squash_jsonb WHERE data IN     +|     2
-         (($1 /*, ... */)::jsonb)                   | 
+         ($1 /*, ... */)                            | 
  SELECT pg_stat_statements_reset() IS NOT NULL AS t |     1
 (2 rows)
 
@@ -704,9 +704,9 @@ SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
                                query                                | calls 
 --------------------------------------------------------------------+-------
  SELECT * FROM test_squash WHERE id IN                             +|     1
-         ($1 /*, ... */::oid)                                       | 
+         ($1 /*, ... */)                                            | 
  SELECT * FROM test_squash WHERE id IN ($1::oid, $2::oid::int::oid) |     2
- SELECT ARRAY[$1 /*, ... */::oid]                                   |     1
+ SELECT ARRAY[$1 /*, ... */]                                        |     1
  SELECT pg_stat_statements_reset() IS NOT NULL AS t                 |     1
 (4 rows)
 
@@ -773,7 +773,7 @@ SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
                        query                        | calls 
 ----------------------------------------------------+-------
  SELECT pg_stat_statements_reset() IS NOT NULL AS t |     1
- select where $1 IN ($2 /*, ... */::int)            |     1
+ select where $1 IN ($2 /*, ... */)                 |     1
  select where $1 IN ($2::int, $3::int::text)        |     1
 (3 rows)
 
@@ -797,7 +797,7 @@ SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
                        query                        | calls 
 ----------------------------------------------------+-------
  SELECT pg_stat_statements_reset() IS NOT NULL AS t |     1
- select where $1 IN ($2 /*, ... */::int::text)      |     2
+ select where $1 IN ($2 /*, ... */)                 |     2
 (2 rows)
 
 --
diff --git a/contrib/pg_stat_statements/pg_stat_statements.c b/contrib/pg_stat_statements/pg_stat_statements.c
index 129001c70c8..ecc7f2fb266 100644
--- a/contrib/pg_stat_statements/pg_stat_statements.c
+++ b/contrib/pg_stat_statements/pg_stat_statements.c
@@ -2810,14 +2810,12 @@ generate_normalized_query(JumbleState *jstate, const char *query,
 {
 	char	   *norm_query;
 	int			query_len = *query_len_p;
-	int			i,
-				norm_query_buflen,	/* Space allowed for norm_query */
+	int			norm_query_buflen,	/* Space allowed for norm_query */
 				len_to_wrt,		/* Length (in bytes) to write */
 				quer_loc = 0,	/* Source query byte location */
 				n_quer_loc = 0, /* Normalized query byte location */
 				last_off = 0,	/* Offset from start for previous tok */
 				last_tok_len = 0;	/* Length (in bytes) of that tok */
-	bool		in_squashed = false;	/* in a run of squashed consts? */
 	int			num_constants_replaced = 0;
 
 	/*
@@ -2832,16 +2830,13 @@ generate_normalized_query(JumbleState *jstate, const char *query,
 	 * certainly isn't more than 11 bytes, even if n reaches INT_MAX.  We
 	 * could refine that limit based on the max value of n for the current
 	 * query, but it hardly seems worth any extra effort to do so.
-	 *
-	 * Note this also gives enough room for the commented-out ", ..." list
-	 * syntax used by constant squashing.
 	 */
 	norm_query_buflen = query_len + jstate->clocations_count * 10;
 
 	/* Allocate result buffer */
 	norm_query = palloc(norm_query_buflen + 1);
 
-	for (i = 0; i < jstate->clocations_count; i++)
+	for (int i = 0; i < jstate->clocations_count; i++)
 	{
 		int			off,		/* Offset from start for cur tok */
 					tok_len;	/* Length (in bytes) of that tok */
@@ -2856,65 +2851,24 @@ generate_normalized_query(JumbleState *jstate, const char *query,
 		if (tok_len < 0)
 			continue;			/* ignore any duplicates */
 
+		/* Copy next chunk (what precedes the next constant) */
+		len_to_wrt = off - last_off;
+		len_to_wrt -= last_tok_len;
+		Assert(len_to_wrt >= 0);
+		memcpy(norm_query + n_quer_loc, query + quer_loc, len_to_wrt);
+		n_quer_loc += len_to_wrt;
+
 		/*
-		 * What to do next depends on whether we're squashing constant lists,
-		 * and whether we're already in a run of such constants.
+		 * Insert a param symbol in place of the constant token.  If this
+		 * location is a squashable list, also append a placeholder comment
+		 * standing in for the list's second and subsequent values.
 		 */
-		if (!jstate->clocations[i].squashed)
-		{
-			/*
-			 * This location corresponds to a constant not to be squashed.
-			 * Print what comes before the constant ...
-			 */
-			len_to_wrt = off - last_off;
-			len_to_wrt -= last_tok_len;
+		n_quer_loc += sprintf(norm_query + n_quer_loc, "$%d%s",
+							  num_constants_replaced + 1 + jstate->highest_extern_param_id,
+							  jstate->clocations[i].squashed ? " /*, ... */" : "");
+		num_constants_replaced++;
 
-			Assert(len_to_wrt >= 0);
-
-			memcpy(norm_query + n_quer_loc, query + quer_loc, len_to_wrt);
-			n_quer_loc += len_to_wrt;
-
-			/* ... and then a param symbol replacing the constant itself */
-			n_quer_loc += sprintf(norm_query + n_quer_loc, "$%d",
-								  num_constants_replaced++ + 1 + jstate->highest_extern_param_id);
-
-			/* In case previous constants were merged away, stop doing that */
-			in_squashed = false;
-		}
-		else if (!in_squashed)
-		{
-			/*
-			 * This location is the start position of a run of constants to be
-			 * squashed, so we need to print the representation of starting a
-			 * group of stashed constants.
-			 *
-			 * Print what comes before the constant ...
-			 */
-			len_to_wrt = off - last_off;
-			len_to_wrt -= last_tok_len;
-			Assert(len_to_wrt >= 0);
-			Assert(i + 1 < jstate->clocations_count);
-			Assert(jstate->clocations[i + 1].squashed);
-			memcpy(norm_query + n_quer_loc, query + quer_loc, len_to_wrt);
-			n_quer_loc += len_to_wrt;
-
-			/* ... and then start a run of squashed constants */
-			n_quer_loc += sprintf(norm_query + n_quer_loc, "$%d /*, ... */",
-								  num_constants_replaced++ + 1 + jstate->highest_extern_param_id);
-
-			/* The next location will match the block below, to end the run */
-			in_squashed = true;
-		}
-		else
-		{
-			/*
-			 * The second location of a run of squashable elements; this
-			 * indicates its end.
-			 */
-			in_squashed = false;
-		}
-
-		/* Otherwise the constant is squashed away -- move forward */
+		/* move forward */
 		quer_loc = off + tok_len;
 		last_off = off;
 		last_tok_len = tok_len;
@@ -3005,6 +2959,9 @@ fill_in_constant_lengths(JumbleState *jstate, const char *query,
 
 		Assert(loc >= 0);
 
+		if (locs[i].squashed)
+			continue;			/* squashable list, ignore */
+
 		if (loc <= last_loc)
 			continue;			/* Duplicate constant, ignore */
 
diff --git a/contrib/pg_stat_statements/sql/select.sql b/contrib/pg_stat_statements/sql/select.sql
index 189d405512f..11662cde08c 100644
--- a/contrib/pg_stat_statements/sql/select.sql
+++ b/contrib/pg_stat_statements/sql/select.sql
@@ -87,6 +87,14 @@ SELECT WHERE (1, 2) IN ((1, 2), (2, 3));
 SELECT WHERE (3, 4) IN ((5, 6), (8, 7));
 SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
 
+-- with the last element being an explicit function call with an argument, ensure
+-- the normalization of the squashing interval is correct.
+SELECT pg_stat_statements_reset() IS NOT NULL AS t;
+SELECT pg_stat_statements_reset() IS NOT NULL AS t;
+SELECT WHERE 1 IN (1, int4(1), int4(2));
+SELECT WHERE 1 = ANY (ARRAY[1, int4(1), int4(2)]);
+SELECT query, calls FROM pg_stat_statements ORDER BY query COLLATE "C";
+
 --
 -- queries with locking clauses
 --
diff --git a/contrib/pg_stat_statements/sql/squashing.sql b/contrib/pg_stat_statements/sql/squashing.sql
index 85aae152da8..ba0a72b6529 100644
--- a/contrib/pg_stat_statements/sql/squashing.sql
+++ b/contrib/pg_stat_statements/sql/squashing.sql
@@ -298,4 +298,4 @@ DROP TABLE test_float;
 DROP TABLE test_squash_numeric;
 DROP TABLE test_squash_bigint;
 DROP TABLE test_squash_cast CASCADE;
-DROP TABLE test_squash_jsonb;
\ No newline at end of file
+DROP TABLE test_squash_jsonb;
diff --git a/src/backend/nodes/gen_node_support.pl b/src/backend/nodes/gen_node_support.pl
index c8595109b0e..9ecddb14231 100644
--- a/src/backend/nodes/gen_node_support.pl
+++ b/src/backend/nodes/gen_node_support.pl
@@ -1329,7 +1329,7 @@ _jumble${n}(JumbleState *jstate, Node *node)
 			# Node type.  Squash constants if requested.
 			if ($query_jumble_squash)
 			{
-				print $jff "\tJUMBLE_ELEMENTS($f);\n"
+				print $jff "\tJUMBLE_ELEMENTS($f, node);\n"
 				  unless $query_jumble_ignore;
 			}
 			else
diff --git a/src/backend/nodes/queryjumblefuncs.c b/src/backend/nodes/queryjumblefuncs.c
index ac3cb3d9caf..fb33e6931ad 100644
--- a/src/backend/nodes/queryjumblefuncs.c
+++ b/src/backend/nodes/queryjumblefuncs.c
@@ -61,9 +61,9 @@ static void AppendJumble(JumbleState *jstate,
 						 const unsigned char *value, Size size);
 static void FlushPendingNulls(JumbleState *jstate);
 static void RecordConstLocation(JumbleState *jstate,
-								int location, bool squashed);
+								int location, int len);
 static void _jumbleNode(JumbleState *jstate, Node *node);
-static void _jumbleElements(JumbleState *jstate, List *elements);
+static void _jumbleElements(JumbleState *jstate, List *elements, Node *node);
 static void _jumbleA_Const(JumbleState *jstate, Node *node);
 static void _jumbleList(JumbleState *jstate, Node *node);
 static void _jumbleVariableSetStmt(JumbleState *jstate, Node *node);
@@ -373,15 +373,17 @@ FlushPendingNulls(JumbleState *jstate)
 
 
 /*
- * Record location of constant within query string of query tree that is
- * currently being walked.
+ * Record the location of some kind of constant within a query string.
+ * These are not only bare constants but also expressions that ultimately
+ * constitute a constant, such as those inside casts and simple function
+ * calls.
  *
- * 'squashed' signals that the constant represents the first or the last
- * element in a series of merged constants, and everything but the first/last
- * element contributes nothing to the jumble hash.
+ * If 'len' is -1, the location is that of a single such constant element.
+ * If it's a positive integer, it is the length of the query text spanned
+ * by a squashable list of them.
  */
 static void
-RecordConstLocation(JumbleState *jstate, int location, bool squashed)
+RecordConstLocation(JumbleState *jstate, int location, int len)
 {
 	/* -1 indicates unknown or undefined location */
 	if (location >= 0)
@@ -396,9 +398,14 @@ RecordConstLocation(JumbleState *jstate, int location, bool squashed)
 						 sizeof(LocationLen));
 		}
 		jstate->clocations[jstate->clocations_count].location = location;
-		/* initialize lengths to -1 to simplify third-party module usage */
-		jstate->clocations[jstate->clocations_count].squashed = squashed;
-		jstate->clocations[jstate->clocations_count].length = -1;
+
+		/*
+		 * Lengths are either positive integers (indicating a squashable
+		 * list), or -1.
+		 */
+		Assert(len > -1 || len == -1);
+		jstate->clocations[jstate->clocations_count].length = len;
+		jstate->clocations[jstate->clocations_count].squashed = (len > -1);
 		jstate->clocations_count++;
 	}
 }
@@ -408,12 +415,12 @@ RecordConstLocation(JumbleState *jstate, int location, bool squashed)
  * deduce that the expression is a constant:
  *
  * - Ignore a possible wrapping RelabelType and CoerceViaIO.
- * - If it's a FuncExpr, check that the function is an implicit
+ * - If it's a FuncExpr, check that the function is a builtin
  *   cast and its arguments are Const.
  * - Otherwise test if the expression is a simple Const.
  */
 static bool
-IsSquashableConst(Node *element)
+IsSquashableConstant(Node *element)
 {
 	if (IsA(element, RelabelType))
 		element = (Node *) ((RelabelType *) element)->arg;
@@ -421,32 +428,50 @@ IsSquashableConst(Node *element)
 	if (IsA(element, CoerceViaIO))
 		element = (Node *) ((CoerceViaIO *) element)->arg;
 
-	if (IsA(element, FuncExpr))
+	switch (nodeTag(element))
 	{
-		FuncExpr   *func = (FuncExpr *) element;
-		ListCell   *temp;
+		case T_FuncExpr:
+			{
+				FuncExpr   *func = (FuncExpr *) element;
+				ListCell   *temp;
 
-		if (func->funcformat != COERCE_IMPLICIT_CAST &&
-			func->funcformat != COERCE_EXPLICIT_CAST)
-			return false;
+				if (func->funcformat != COERCE_IMPLICIT_CAST &&
+					func->funcformat != COERCE_EXPLICIT_CAST)
+					return false;
 
-		if (func->funcid > FirstGenbkiObjectId)
-			return false;
+				if (func->funcid > FirstGenbkiObjectId)
+					return false;
 
-		foreach(temp, func->args)
-		{
-			Node	   *arg = lfirst(temp);
+				/*
+				 * We can check function arguments recursively, being careful
+				 * about recursing too deep.  At each recursion level it's
+				 * enough to test the stack on the first element.  (Note that
+				 * I wasn't able to hit this without bloating the stack
+				 * artificially in this function: the parser errors out before
+				 * stack size becomes a problem here.)
+				 */
+				foreach(temp, func->args)
+				{
+					Node	   *arg = lfirst(temp);
 
-			if (!IsA(arg, Const))	/* XXX we could recurse here instead */
+					if (!IsA(arg, Const))
+					{
+						if (foreach_current_index(temp) == 0 &&
+							stack_is_too_deep())
+							return false;
+						else if (!IsSquashableConstant(arg))
+							return false;
+					}
+				}
+
+				return true;
+			}
+
+		default:
+			if (!IsA(element, Const))
 				return false;
-		}
-
-		return true;
 	}
 
-	if (!IsA(element, Const))
-		return false;
-
 	return true;
 }
 
@@ -461,35 +486,29 @@ IsSquashableConst(Node *element)
  * expressions.
  */
 static bool
-IsSquashableConstList(List *elements, Node **firstExpr, Node **lastExpr)
+IsSquashableConstantList(List *elements)
 {
 	ListCell   *temp;
 
-	/*
-	 * If squashing is disabled, or the list is too short, we don't try to
-	 * squash it.
-	 */
+	/* If the list is too short, we don't try to squash it. */
 	if (list_length(elements) < 2)
 		return false;
 
 	foreach(temp, elements)
 	{
-		if (!IsSquashableConst(lfirst(temp)))
+		if (!IsSquashableConstant(lfirst(temp)))
 			return false;
 	}
 
-	*firstExpr = linitial(elements);
-	*lastExpr = llast(elements);
-
 	return true;
 }
 
 #define JUMBLE_NODE(item) \
 	_jumbleNode(jstate, (Node *) expr->item)
-#define JUMBLE_ELEMENTS(list) \
-	_jumbleElements(jstate, (List *) expr->list)
+#define JUMBLE_ELEMENTS(list, node) \
+	_jumbleElements(jstate, (List *) expr->list, node)
 #define JUMBLE_LOCATION(location) \
-	RecordConstLocation(jstate, expr->location, false)
+	RecordConstLocation(jstate, expr->location, -1)
 #define JUMBLE_FIELD(item) \
 do { \
 	if (sizeof(expr->item) == 8) \
@@ -517,36 +536,36 @@ do { \
 #include "queryjumblefuncs.funcs.c"
 
 /*
- * We jumble lists of constant elements as one individual item regardless
- * of how many elements are in the list.  This means different queries
- * jumble to the same query_id, if the only difference is the number of
- * elements in the list.
+ * We try to jumble lists of expressions as one individual item regardless
+ * of how many elements are in the list. This is known as squashing, which
+ * results in different queries jumbling to the same query_id, if the only
+ * difference is the number of elements in the list.
+ *
+ * We allow constants to be squashed. To normalize such queries, we use
+ * the start and end locations of the element list in the query text.
  */
 static void
-_jumbleElements(JumbleState *jstate, List *elements)
+_jumbleElements(JumbleState *jstate, List *elements, Node *node)
 {
-	Node	   *first,
-			   *last;
+	bool		normalize_list = false;
 
-	if (IsSquashableConstList(elements, &first, &last))
+	if (IsSquashableConstantList(elements))
 	{
-		/*
-		 * If this list of elements is squashable, keep track of the location
-		 * of its first and last elements.  When reading back the locations
-		 * array, we'll see two consecutive locations with ->squashed set to
-		 * true, indicating the location of initial and final elements of this
-		 * list.
-		 *
-		 * For the limited set of cases we support now (implicit coerce via
-		 * FuncExpr, Const) it's fine to use exprLocation of the 'last'
-		 * expression, but if more complex composite expressions are to be
-		 * supported (e.g., OpExpr or FuncExpr as an explicit call), more
-		 * sophisticated tracking will be needed.
-		 */
-		RecordConstLocation(jstate, exprLocation(first), true);
-		RecordConstLocation(jstate, exprLocation(last), true);
+		if (IsA(node, ArrayExpr))
+		{
+			ArrayExpr  *aexpr = (ArrayExpr *) node;
+
+			if (aexpr->list_start > 0 && aexpr->list_end > 0)
+			{
+				RecordConstLocation(jstate,
+									aexpr->list_start + 1,
+									(aexpr->list_end - aexpr->list_start) - 1);
+				normalize_list = true;
+			}
+		}
 	}
-	else
+
+	if (!normalize_list)
 	{
 		_jumbleNode(jstate, (Node *) elements);
 	}
diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index 0b5652071d1..41f03e77149 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -184,7 +184,7 @@ static void doNegateFloat(Float *v);
 static Node *makeAndExpr(Node *lexpr, Node *rexpr, int location);
 static Node *makeOrExpr(Node *lexpr, Node *rexpr, int location);
 static Node *makeNotExpr(Node *expr, int location);
-static Node *makeAArrayExpr(List *elements, int location);
+static Node *makeAArrayExpr(List *elements, int location, int end_location);
 static Node *makeSQLValueFunction(SQLValueFunctionOp op, int32 typmod,
 								  int location);
 static Node *makeXmlExpr(XmlExprOp op, char *name, List *named_args,
@@ -523,7 +523,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query);
 %type <defelt>	def_elem reloption_elem old_aggr_elem operator_def_elem
 %type <node>	def_arg columnElem where_clause where_or_current_clause
 				a_expr b_expr c_expr AexprConst indirection_el opt_slice_bound
-				columnref in_expr having_clause func_table xmltable array_expr
+				columnref having_clause func_table xmltable array_expr
 				OptWhereClause operator_def_arg
 %type <list>	opt_column_and_period_list
 %type <list>	rowsfrom_item rowsfrom_list opt_col_def_list
@@ -15287,49 +15287,50 @@ a_expr:		c_expr									{ $$ = $1; }
 												   (Node *) list_make2($5, $7),
 												   @2);
 				}
-			| a_expr IN_P in_expr
+			| a_expr IN_P select_with_parens
 				{
-					/* in_expr returns a SubLink or a list of a_exprs */
-					if (IsA($3, SubLink))
-					{
-						/* generate foo = ANY (subquery) */
-						SubLink	   *n = (SubLink *) $3;
+					/* generate foo = ANY (subquery) */
+					SubLink	   *n = makeNode(SubLink);
 
-						n->subLinkType = ANY_SUBLINK;
-						n->subLinkId = 0;
-						n->testexpr = $1;
-						n->operName = NIL;		/* show it's IN not = ANY */
-						n->location = @2;
-						$$ = (Node *) n;
-					}
-					else
-					{
-						/* generate scalar IN expression */
-						$$ = (Node *) makeSimpleA_Expr(AEXPR_IN, "=", $1, $3, @2);
-					}
+					n->subselect = $3;
+					n->subLinkType = ANY_SUBLINK;
+					n->subLinkId = 0;
+					n->testexpr = $1;
+					n->operName = NIL;		/* show it's IN not = ANY */
+					n->location = @2;
+					$$ = (Node *) n;
 				}
-			| a_expr NOT_LA IN_P in_expr						%prec NOT_LA
+			| a_expr IN_P '(' expr_list ')'
 				{
-					/* in_expr returns a SubLink or a list of a_exprs */
-					if (IsA($4, SubLink))
-					{
-						/* generate NOT (foo = ANY (subquery)) */
-						/* Make an = ANY node */
-						SubLink	   *n = (SubLink *) $4;
+					/* generate scalar IN expression */
+					A_Expr *n = makeSimpleA_Expr(AEXPR_IN, "=", $1, (Node *) $4, @2);
 
-						n->subLinkType = ANY_SUBLINK;
-						n->subLinkId = 0;
-						n->testexpr = $1;
-						n->operName = NIL;		/* show it's IN not = ANY */
-						n->location = @2;
-						/* Stick a NOT on top; must have same parse location */
-						$$ = makeNotExpr((Node *) n, @2);
-					}
-					else
-					{
-						/* generate scalar NOT IN expression */
-						$$ = (Node *) makeSimpleA_Expr(AEXPR_IN, "<>", $1, $4, @2);
-					}
+					n->rexpr_list_start = @3;
+					n->rexpr_list_end = @5;
+					$$ = (Node *) n;
+				}
+			| a_expr NOT_LA IN_P select_with_parens			%prec NOT_LA
+				{
+					/* generate NOT (foo = ANY (subquery)) */
+					SubLink	   *n = makeNode(SubLink);
+
+					n->subselect = $4;
+					n->subLinkType = ANY_SUBLINK;
+					n->subLinkId = 0;
+					n->testexpr = $1;
+					n->operName = NIL;		/* show it's IN not = ANY */
+					n->location = @2;
+					/* Stick a NOT on top; must have same parse location */
+					$$ = makeNotExpr((Node *) n, @2);
+				}
+			| a_expr NOT_LA IN_P '(' expr_list ')'
+				{
+					/* generate scalar NOT IN expression */
+					A_Expr *n = makeSimpleA_Expr(AEXPR_IN, "<>", $1, (Node *) $5, @2);
+
+					n->rexpr_list_start = @4;
+					n->rexpr_list_end = @6;
+					$$ = (Node *) n;
 				}
 			| a_expr subquery_Op sub_type select_with_parens	%prec Op
 				{
@@ -16764,15 +16765,15 @@ type_list:	Typename								{ $$ = list_make1($1); }
 
 array_expr: '[' expr_list ']'
 				{
-					$$ = makeAArrayExpr($2, @1);
+					$$ = makeAArrayExpr($2, @1, @3);
 				}
 			| '[' array_expr_list ']'
 				{
-					$$ = makeAArrayExpr($2, @1);
+					$$ = makeAArrayExpr($2, @1, @3);
 				}
 			| '[' ']'
 				{
-					$$ = makeAArrayExpr(NIL, @1);
+					$$ = makeAArrayExpr(NIL, @1, @2);
 				}
 		;
 
@@ -16894,17 +16895,6 @@ trim_list:	a_expr FROM expr_list					{ $$ = lappend($3, $1); }
 			| expr_list								{ $$ = $1; }
 		;
 
-in_expr:	select_with_parens
-				{
-					SubLink	   *n = makeNode(SubLink);
-
-					n->subselect = $1;
-					/* other fields will be filled later */
-					$$ = (Node *) n;
-				}
-			| '(' expr_list ')'						{ $$ = (Node *) $2; }
-		;
-
 /*
  * Define SQL-style CASE clause.
  * - Full specification
@@ -19300,12 +19290,14 @@ makeNotExpr(Node *expr, int location)
 }
 
 static Node *
-makeAArrayExpr(List *elements, int location)
+makeAArrayExpr(List *elements, int location, int location_end)
 {
 	A_ArrayExpr *n = makeNode(A_ArrayExpr);
 
 	n->elements = elements;
 	n->location = location;
+	n->list_start = location;
+	n->list_end = location_end;
 	return (Node *) n;
 }
 
diff --git a/src/backend/parser/parse_expr.c b/src/backend/parser/parse_expr.c
index 1f8e2d54673..d66276801c6 100644
--- a/src/backend/parser/parse_expr.c
+++ b/src/backend/parser/parse_expr.c
@@ -1223,6 +1223,8 @@ transformAExprIn(ParseState *pstate, A_Expr *a)
 			newa->element_typeid = scalar_type;
 			newa->elements = aexprs;
 			newa->multidims = false;
+			newa->list_start = a->rexpr_list_start;
+			newa->list_end = a->rexpr_list_end;
 			newa->location = -1;
 
 			result = (Node *) make_scalar_array_op(pstate,
@@ -2165,6 +2167,8 @@ transformArrayExpr(ParseState *pstate, A_ArrayExpr *a,
 	/* array_collid will be set by parse_collate.c */
 	newa->element_typeid = element_type;
 	newa->elements = newcoercedelems;
+	newa->list_start = a->list_start;
+	newa->list_end = a->list_end;
 	newa->location = a->location;
 
 	return (Node *) newa;
diff --git a/src/include/nodes/parsenodes.h b/src/include/nodes/parsenodes.h
index dd00ab420b8..71a9768fe2f 100644
--- a/src/include/nodes/parsenodes.h
+++ b/src/include/nodes/parsenodes.h
@@ -351,6 +351,14 @@ typedef struct A_Expr
 	List	   *name;			/* possibly-qualified name of operator */
 	Node	   *lexpr;			/* left argument, or NULL if none */
 	Node	   *rexpr;			/* right argument, or NULL if none */
+
+	/*
+	 * If rexpr is a list of some kind, we separately track its starting and
+	 * ending location; it's not the same as the starting and ending location
+	 * of the token itself.
+	 */
+	ParseLoc	rexpr_list_start;
+	ParseLoc	rexpr_list_end;
 	ParseLoc	location;		/* token location, or -1 if unknown */
 } A_Expr;
 
@@ -506,6 +514,8 @@ typedef struct A_ArrayExpr
 {
 	NodeTag		type;
 	List	   *elements;		/* array element expressions */
+	ParseLoc	list_start;		/* start of the element list */
+	ParseLoc	list_end;		/* end of the elements list */
 	ParseLoc	location;		/* token location, or -1 if unknown */
 } A_ArrayExpr;
 
diff --git a/src/include/nodes/primnodes.h b/src/include/nodes/primnodes.h
index 7d3b4198f26..01510b01b64 100644
--- a/src/include/nodes/primnodes.h
+++ b/src/include/nodes/primnodes.h
@@ -1397,6 +1397,10 @@ typedef struct ArrayExpr
 	List	   *elements pg_node_attr(query_jumble_squash);
 	/* true if elements are sub-arrays */
 	bool		multidims pg_node_attr(query_jumble_ignore);
+	/* location of the start of the elements list */
+	ParseLoc	list_start;
+	/* location of the end of the elements list */
+	ParseLoc	list_end;
 	/* token location, or -1 if unknown */
 	ParseLoc	location;
 } ArrayExpr;
-- 
2.39.5

