From 83c8a41c7d9d41fb63a82fc14c2dd66f30753a48 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=81lvaro=20Herrera?= <alvherre@alvh.no-ip.org>
Date: Thu, 6 Mar 2025 17:01:22 +0100
Subject: [PATCH v2 1/2] Improve processCASbits API with a 'seen' struct

This allows ALTER TABLE .. ALTER CONSTRAINT to be more precise about
which operations are supported, and makes the error messages reported by
CREATE CONSTRAINT TRIGGER more sensible.
---
 src/backend/parser/gram.y                 | 152 +++++++++++++++++-----
 src/test/regress/expected/constraints.out |   4 +-
 src/test/regress/expected/foreign_key.out |   4 +-
 src/test/regress/expected/triggers.out    |  17 +++
 src/test/regress/sql/triggers.sql         |   6 +
 5 files changed, 148 insertions(+), 35 deletions(-)

diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index 271ae26cbaf..50024aabbca 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -146,6 +146,17 @@ typedef struct KeyActions
 #define CAS_NOT_ENFORCED			0x40
 #define CAS_ENFORCED				0x80
 
+/*
+ * Option kinds seen in a ConstraintAttributeSpec; filled by processCASbits().
+ */
+typedef struct CAS_flags
+{
+	bool	seen_deferrability;	/* any deferrability option */
+	bool	seen_enforced;		/* ENFORCED or NOT ENFORCED */
+	bool	seen_valid;			/* NOT VALID */
+	bool	seen_inherit;		/* NO INHERIT */
+} CAS_flags;
+
 
 #define parser_yyerror(msg)  scanner_yyerror(msg, yyscanner)
 #define parser_errposition(pos)  scanner_errposition(pos, yyscanner)
@@ -198,8 +209,9 @@ static void SplitColQualList(List *qualList,
 							 List **constraintList, CollateClause **collClause,
 							 core_yyscan_t yyscanner);
 static void processCASbits(int cas_bits, int location, const char *constrType,
-			   bool *deferrable, bool *initdeferred, bool *is_enforced,
-			   bool *not_valid, bool *no_inherit, core_yyscan_t yyscanner);
+						   bool *deferrable, bool *initdeferred, bool *is_enforced,
+						   bool *not_valid, bool *no_inherit, CAS_flags *seen,
+						   core_yyscan_t yyscanner);
 static PartitionStrategy parsePartitionStrategy(char *strategy, int location,
 												core_yyscan_t yyscanner);
 static void preprocess_pubobj_list(List *pubobjspec_list,
@@ -2658,15 +2670,35 @@ alter_table_cmd:
 				{
 					AlterTableCmd *n = makeNode(AlterTableCmd);
 					ATAlterConstraint *c = makeNode(ATAlterConstraint);
+					CAS_flags	seen;
 
 					n->subtype = AT_AlterConstraint;
 					n->def = (Node *) c;
 					c->conname = $3;
-					c->alterDeferrability = true;
-					processCASbits($4, @4, "FOREIGN KEY",
-									&c->deferrable,
-									&c->initdeferred,
-									NULL, NULL, NULL, yyscanner);
+					processCASbits($4, @4, NULL,
+								   &c->deferrable,
+								   &c->initdeferred,
+								   NULL, NULL, NULL,
+								   &seen,
+								   yyscanner);
+					if (seen.seen_deferrability)
+						c->alterDeferrability = true;
+					/* cannot (currently) be changed by this syntax: */
+					if (seen.seen_enforced)
+						ereport(ERROR,
+								errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+								errmsg("cannot alter constraint enforceability"),
+								parser_errposition(@4));
+					if (seen.seen_valid)
+						ereport(ERROR,
+								errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+								errmsg("cannot alter constraint validity"),
+								parser_errposition(@4));
+					if (seen.seen_inherit)
+						ereport(ERROR,
+								errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+								errmsg("cannot alter constraint inheritability"),
+								parser_errposition(@4));
 					$$ = (Node *) n;
 				}
 			/* ALTER TABLE <name> ALTER CONSTRAINT SET INHERIT */
@@ -4211,7 +4243,7 @@ ConstraintElem:
 					n->cooked_expr = NULL;
 					processCASbits($5, @5, "CHECK",
 								   NULL, NULL, &n->is_enforced, &n->skip_validation,
-								   &n->is_no_inherit, yyscanner);
+								   &n->is_no_inherit, NULL, yyscanner);
 					n->initially_valid = !n->skip_validation;
 					$$ = (Node *) n;
 				}
@@ -4225,7 +4257,7 @@ ConstraintElem:
 					/* no NOT VALID support yet */
 					processCASbits($4, @4, "NOT NULL",
 								   NULL, NULL, NULL, NULL,
-								   &n->is_no_inherit, yyscanner);
+								   &n->is_no_inherit, NULL, yyscanner);
 					n->initially_valid = true;
 					$$ = (Node *) n;
 				}
@@ -4245,7 +4277,7 @@ ConstraintElem:
 					n->indexspace = $9;
 					processCASbits($10, @10, "UNIQUE",
 								   &n->deferrable, &n->initdeferred, NULL,
-								   NULL, NULL, yyscanner);
+								   NULL, NULL, NULL, yyscanner);
 					$$ = (Node *) n;
 				}
 			| UNIQUE ExistingIndex ConstraintAttributeSpec
@@ -4261,7 +4293,7 @@ ConstraintElem:
 					n->indexspace = NULL;
 					processCASbits($3, @3, "UNIQUE",
 								   &n->deferrable, &n->initdeferred, NULL,
-								   NULL, NULL, yyscanner);
+								   NULL, NULL, NULL, yyscanner);
 					$$ = (Node *) n;
 				}
 			| PRIMARY KEY '(' columnList opt_without_overlaps ')' opt_c_include opt_definition OptConsTableSpace
@@ -4279,7 +4311,7 @@ ConstraintElem:
 					n->indexspace = $9;
 					processCASbits($10, @10, "PRIMARY KEY",
 								   &n->deferrable, &n->initdeferred, NULL,
-								   NULL, NULL, yyscanner);
+								   NULL, NULL, NULL, yyscanner);
 					$$ = (Node *) n;
 				}
 			| PRIMARY KEY ExistingIndex ConstraintAttributeSpec
@@ -4295,7 +4327,7 @@ ConstraintElem:
 					n->indexspace = NULL;
 					processCASbits($4, @4, "PRIMARY KEY",
 								   &n->deferrable, &n->initdeferred, NULL,
-								   NULL, NULL, yyscanner);
+								   NULL, NULL, NULL, yyscanner);
 					$$ = (Node *) n;
 				}
 			| EXCLUDE access_method_clause '(' ExclusionConstraintList ')'
@@ -4315,7 +4347,7 @@ ConstraintElem:
 					n->where_clause = $9;
 					processCASbits($10, @10, "EXCLUDE",
 								   &n->deferrable, &n->initdeferred, NULL,
-								   NULL, NULL, yyscanner);
+								   NULL, NULL, NULL, yyscanner);
 					$$ = (Node *) n;
 				}
 			| FOREIGN KEY '(' columnList optionalPeriodName ')' REFERENCES qualified_name
@@ -4345,7 +4377,7 @@ ConstraintElem:
 					processCASbits($12, @12, "FOREIGN KEY",
 								   &n->deferrable, &n->initdeferred,
 								   NULL, &n->skip_validation, NULL,
-								   yyscanner);
+								   NULL, yyscanner);
 					n->initially_valid = !n->skip_validation;
 					$$ = (Node *) n;
 				}
@@ -4383,9 +4415,9 @@ DomainConstraintElem:
 					n->location = @1;
 					n->raw_expr = $3;
 					n->cooked_expr = NULL;
-					processCASbits($5, @5, "CHECK",
+					processCASbits($5, @5, "CHECK", /* FIXME */
 								   NULL, NULL, NULL, &n->skip_validation,
-								   &n->is_no_inherit, yyscanner);
+								   &n->is_no_inherit, NULL, yyscanner);
 					n->is_enforced = true;
 					n->initially_valid = !n->skip_validation;
 					$$ = (Node *) n;
@@ -4398,9 +4430,9 @@ DomainConstraintElem:
 					n->location = @1;
 					n->keys = list_make1(makeString("value"));
 					/* no NOT VALID, NO INHERIT support */
-					processCASbits($3, @3, "NOT NULL",
+					processCASbits($3, @3, "NOT NULL", /* FIXME */
 								   NULL, NULL, NULL,
-								   NULL, NULL, yyscanner);
+								   NULL, NULL, NULL, yyscanner);
 					n->initially_valid = true;
 					$$ = (Node *) n;
 				}
@@ -6043,6 +6075,7 @@ CreateTrigStmt:
 			EXECUTE FUNCTION_or_PROCEDURE func_name '(' TriggerFuncArgs ')'
 				{
 					CreateTrigStmt *n = makeNode(CreateTrigStmt);
+					CAS_flags	seen;
 
 					n->replace = $2;
 					if (n->replace) /* not supported, see CreateTrigger */
@@ -6061,9 +6094,27 @@ CreateTrigStmt:
 					n->columns = (List *) lsecond($7);
 					n->whenClause = $15;
 					n->transitionRels = NIL;
-					processCASbits($11, @11, "TRIGGER",
+					processCASbits($11, @11, NULL,
 								   &n->deferrable, &n->initdeferred, NULL,
-								   NULL, NULL, yyscanner);
+								   NULL, NULL, &seen, yyscanner);
+					if (seen.seen_valid)
+						ereport(ERROR,
+								errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+								errmsg("constraint triggers cannot be marked %s",
+									   "NOT VALID"),
+								parser_errposition(@11));
+					if (seen.seen_inherit)
+						ereport(ERROR,
+								errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+								errmsg("constraint triggers cannot be marked %s",
+									   "INHERIT/NO INHERIT"),
+								parser_errposition(@11));
+					if (seen.seen_enforced)
+						ereport(ERROR,
+								errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+								errmsg("constraint triggers cannot be marked %s",
+									   "ENFORCED/NOT ENFORCED"),
+								parser_errposition(@11));
 					n->constrrel = $10;
 					$$ = (Node *) n;
 				}
@@ -19505,14 +19556,31 @@ SplitColQualList(List *qualList,
 }
 
 /*
- * Process result of ConstraintAttributeSpec, and set appropriate bool flags
- * in the output command node.  Pass NULL for any flags the particular
- * command doesn't support.
+ * Process result of ConstraintAttributeSpec, and set appropriate bool flags.
+ * Any of those flags can be given as NULL pointers, for options that are
+ * unsupported by the particular production being parsed.  The flags that
+ * are given are first reset to their default values.
+ *
+ * Unsupported options for a particular command can be handled in one of two
+ * ways.  Productions that require ad-hoc error reporting (those that don't
+ * know which type of constraint is being parsed, or that simply require
+ * different phrasing than what this routine provides) can pass a non-NULL
+ * 'seen' pointer; each flag in that struct is then set when the
+ * corresponding kind of option appears in the command, and no error is
+ * raised here for unsupported options.  The production can then inspect the
+ * 'seen' flags and complain appropriately about any option the command
+ * doesn't support.  In this case, 'constrType' must be given as NULL.
+ *
+ * The other way is to pass a NULL 'seen' pointer.  In this case, an
+ * unsupported option gives rise to an error report using 'constrType',
+ * which must then be non-NULL; it fills in messages such as
+ * "%s constraints cannot be marked DEFERRABLE".
  */
 static void
 processCASbits(int cas_bits, int location, const char *constrType,
 			   bool *deferrable, bool *initdeferred, bool *is_enforced,
-			   bool *not_valid, bool *no_inherit, core_yyscan_t yyscanner)
+			   bool *not_valid, bool *no_inherit,
+			   CAS_flags *seen, core_yyscan_t yyscanner)
 {
 	/* defaults */
 	if (deferrable)
@@ -19523,70 +19591,90 @@ processCASbits(int cas_bits, int location, const char *constrType,
 		*not_valid = false;
 	if (is_enforced)
 		*is_enforced = true;
+	if (no_inherit)
+		*no_inherit = false;
+
+	Assert((constrType == NULL) ^ (seen == NULL));
+	if (seen)
+		memset(seen, 0, sizeof(CAS_flags));
 
 	if (cas_bits & (CAS_DEFERRABLE | CAS_INITIALLY_DEFERRED))
 	{
 		if (deferrable)
 			*deferrable = true;
-		else
+		else if (!seen)
 			ereport(ERROR,
 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
 			/* translator: %s is CHECK, UNIQUE, or similar */
 					 errmsg("%s constraints cannot be marked DEFERRABLE",
 							constrType),
 					 parser_errposition(location)));
+		if (seen)
+			seen->seen_deferrability = true;
 	}
 
 	if (cas_bits & CAS_INITIALLY_DEFERRED)
 	{
 		if (initdeferred)
 			*initdeferred = true;
-		else
+		else if (!seen)
 			ereport(ERROR,
 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
 			/* translator: %s is CHECK, UNIQUE, or similar */
 					 errmsg("%s constraints cannot be marked DEFERRABLE",
 							constrType),
 					 parser_errposition(location)));
+		if (seen)
+			seen->seen_deferrability = true;
 	}
 
+	/* NOT DEFERRABLE and INITIALLY IMMEDIATE are defaults; just note them */
+	if ((cas_bits & (CAS_NOT_DEFERRABLE | CAS_INITIALLY_IMMEDIATE)) && seen)
+		seen->seen_deferrability = true;
+
 	if (cas_bits & CAS_NOT_VALID)
 	{
 		if (not_valid)
 			*not_valid = true;
-		else
+		else if (!seen)
 			ereport(ERROR,
 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
 			/* translator: %s is CHECK, UNIQUE, or similar */
 					 errmsg("%s constraints cannot be marked NOT VALID",
 							constrType),
 					 parser_errposition(location)));
+		if (seen)
+			seen->seen_valid = true;
 	}
 
 	if (cas_bits & CAS_NO_INHERIT)
 	{
 		if (no_inherit)
 			*no_inherit = true;
-		else
+		else if (!seen)
 			ereport(ERROR,
 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
 			/* translator: %s is CHECK, UNIQUE, or similar */
 					 errmsg("%s constraints cannot be marked NO INHERIT",
 							constrType),
 					 parser_errposition(location)));
+		if (seen)
+			seen->seen_inherit = true;
 	}
 
 	if (cas_bits & CAS_NOT_ENFORCED)
 	{
 		if (is_enforced)
 			*is_enforced = false;
-		else
+		else if (!seen)
 			ereport(ERROR,
 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
 					 /* translator: %s is CHECK, UNIQUE, or similar */
 					 errmsg("%s constraints cannot be marked NOT ENFORCED",
 							constrType),
 					 parser_errposition(location)));
+		if (seen)
+			seen->seen_enforced = true;
 
 		/*
 		 * NB: The validated status is irrelevant when the constraint is set to
@@ -19602,13 +19690,15 @@ processCASbits(int cas_bits, int location, const char *constrType,
 	{
 		if (is_enforced)
 			*is_enforced = true;
-		else
+		else if (!seen)
 			ereport(ERROR,
 					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
 					 /* translator: %s is CHECK, UNIQUE, or similar */
 					 errmsg("%s constraints cannot be marked ENFORCED",
 							constrType),
 					 parser_errposition(location)));
+		if (seen)
+			seen->seen_enforced = true;
 	}
 }
 
diff --git a/src/test/regress/expected/constraints.out b/src/test/regress/expected/constraints.out
index 4f39100fcdf..df2c27dd7e7 100644
--- a/src/test/regress/expected/constraints.out
+++ b/src/test/regress/expected/constraints.out
@@ -745,11 +745,11 @@ ERROR:  misplaced NOT ENFORCED clause
 LINE 1: CREATE TABLE UNIQUE_NOTEN_TBL(i int UNIQUE NOT ENFORCED);
                                                    ^
 ALTER TABLE unique_tbl ALTER CONSTRAINT unique_tbl_i_key ENFORCED;
-ERROR:  FOREIGN KEY constraints cannot be marked ENFORCED
+ERROR:  cannot alter constraint enforceability
 LINE 1: ...TABLE unique_tbl ALTER CONSTRAINT unique_tbl_i_key ENFORCED;
                                                               ^
 ALTER TABLE unique_tbl ALTER CONSTRAINT unique_tbl_i_key NOT ENFORCED;
-ERROR:  FOREIGN KEY constraints cannot be marked NOT ENFORCED
+ERROR:  cannot alter constraint enforceability
 LINE 1: ...ABLE unique_tbl ALTER CONSTRAINT unique_tbl_i_key NOT ENFORC...
                                                              ^
 DROP TABLE unique_tbl;
diff --git a/src/test/regress/expected/foreign_key.out b/src/test/regress/expected/foreign_key.out
index 6a3374d5152..9d0f91a9039 100644
--- a/src/test/regress/expected/foreign_key.out
+++ b/src/test/regress/expected/foreign_key.out
@@ -1284,11 +1284,11 @@ ERROR:  constraint declared INITIALLY DEFERRED must be DEFERRABLE
 LINE 1: ...e ALTER CONSTRAINT fktable_fk_fkey NOT DEFERRABLE INITIALLY ...
                                                              ^
 ALTER TABLE fktable ALTER CONSTRAINT fktable_fk_fkey NO INHERIT;
-ERROR:  FOREIGN KEY constraints cannot be marked NO INHERIT
+ERROR:  cannot alter constraint inheritability
 LINE 1: ...ER TABLE fktable ALTER CONSTRAINT fktable_fk_fkey NO INHERIT...
                                                              ^
 ALTER TABLE fktable ALTER CONSTRAINT fktable_fk_fkey NOT VALID;
-ERROR:  FOREIGN KEY constraints cannot be marked NOT VALID
+ERROR:  cannot alter constraint validity
 LINE 1: ...ER TABLE fktable ALTER CONSTRAINT fktable_fk_fkey NOT VALID;
                                                              ^
 -- test order of firing of FK triggers when several RI-induced changes need to
diff --git a/src/test/regress/expected/triggers.out b/src/test/regress/expected/triggers.out
index 247c67c32ae..15423749506 100644
--- a/src/test/regress/expected/triggers.out
+++ b/src/test/regress/expected/triggers.out
@@ -2547,6 +2547,23 @@ select * from parted;
 
 drop table parted;
 drop function parted_trigfunc();
+-- constraint triggers
+create constraint trigger foo after insert on pg_class not valid for each row execute procedure test();
+ERROR:  constraint triggers cannot be marked NOT VALID
+LINE 1: ...e constraint trigger foo after insert on pg_class not valid ...
+                                                             ^
+create constraint trigger foo after insert on pg_class no inherit for each row execute procedure test();
+ERROR:  constraint triggers cannot be marked INHERIT/NO INHERIT
+LINE 1: ...e constraint trigger foo after insert on pg_class no inherit...
+                                                             ^
+create constraint trigger foo after insert on pg_class enforced for each row execute procedure test();
+ERROR:  constraint triggers cannot be marked ENFORCED/NOT ENFORCED
+LINE 1: ...e constraint trigger foo after insert on pg_class enforced f...
+                                                             ^
+create constraint trigger foo after insert on pg_class not enforced for each row execute procedure test();
+ERROR:  constraint triggers cannot be marked ENFORCED/NOT ENFORCED
+LINE 1: ...e constraint trigger foo after insert on pg_class not enforc...
+                                                             ^
 --
 -- Constraint triggers and partitioned tables
 create table parted_constr_ancestor (a int, b text)
diff --git a/src/test/regress/sql/triggers.sql b/src/test/regress/sql/triggers.sql
index 659972f1135..c10f5eac74e 100644
--- a/src/test/regress/sql/triggers.sql
+++ b/src/test/regress/sql/triggers.sql
@@ -1764,6 +1764,12 @@ select * from parted;
 drop table parted;
 drop function parted_trigfunc();
 
+-- constraint triggers
+create constraint trigger foo after insert on pg_class not valid for each row execute procedure test();
+create constraint trigger foo after insert on pg_class no inherit for each row execute procedure test();
+create constraint trigger foo after insert on pg_class enforced for each row execute procedure test();
+create constraint trigger foo after insert on pg_class not enforced for each row execute procedure test();
+
 --
 -- Constraint triggers and partitioned tables
 create table parted_constr_ancestor (a int, b text)
-- 
2.39.5

