From 58926a4546b3918af8f6e6691956731d8c902701 Mon Sep 17 00:00:00 2001
From: Antonin Houska <ah@cybertec.at>
Date: Wed, 8 Apr 2020 15:03:20 +0200
Subject: [PATCH 2/4] Changed ri_GenerateQual() so it generates the whole
 qualifier.

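The former single-comparison function is renamed to ri_GenerateQualComponent();
the new ri_GenerateQual() loops over all the key columns itself. This way we
can use the function to reduce the amount of copy&pasted code a bit.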
---
 src/backend/utils/adt/ri_triggers.c | 288 +++++++++++++++-------------
 1 file changed, 159 insertions(+), 129 deletions(-)

diff --git a/src/backend/utils/adt/ri_triggers.c b/src/backend/utils/adt/ri_triggers.c
index 6220872126..3bedb75846 100644
--- a/src/backend/utils/adt/ri_triggers.c
+++ b/src/backend/utils/adt/ri_triggers.c
@@ -180,11 +180,17 @@ static Datum ri_restrict(TriggerData *trigdata, bool is_no_action);
 static Datum ri_set(TriggerData *trigdata, bool is_set_null);
 static void quoteOneName(char *buffer, const char *name);
 static void quoteRelationName(char *buffer, Relation rel);
-static void ri_GenerateQual(StringInfo buf,
-							const char *sep,
-							const char *leftop, Oid leftoptype,
-							Oid opoid,
-							const char *rightop, Oid rightoptype);
+static char *ri_ColNameQuoted(const char *tabname, const char *attname);
+static void ri_GenerateQual(StringInfo buf, char *sep, int nkeys,
+							const char *ltabname, Relation lrel,
+							const int16 *lattnums,
+							const char *rtabname, Relation rrel,
+							const int16 *rattnums, const Oid *eq_oprs);
+static void ri_GenerateQualComponent(StringInfo buf,
+									 const char *sep,
+									 const char *leftop, Oid leftoptype,
+									 Oid opoid,
+									 const char *rightop, Oid rightoptype);
 static void ri_GenerateQualCollation(StringInfo buf, Oid collation);
 static int	ri_NullCheck(TupleDesc tupdesc, TupleTableSlot *slot,
 						 const RI_ConstraintInfo *riinfo, bool rel_is_pk);
@@ -372,10 +378,10 @@ RI_FKey_check(TriggerData *trigdata)
 			quoteOneName(attname,
 						 RIAttName(pk_rel, riinfo->pk_attnums[i]));
 			sprintf(paramname, "$%d", i + 1);
-			ri_GenerateQual(&querybuf, querysep,
-							attname, pk_type,
-							riinfo->pf_eq_oprs[i],
-							paramname, fk_type);
+			ri_GenerateQualComponent(&querybuf, querysep,
+									 attname, pk_type,
+									 riinfo->pf_eq_oprs[i],
+									 paramname, fk_type);
 			querysep = "AND";
 			queryoids[i] = fk_type;
 		}
@@ -504,10 +510,10 @@ ri_Check_Pk_Match(Relation pk_rel, Relation fk_rel,
 			quoteOneName(attname,
 						 RIAttName(pk_rel, riinfo->pk_attnums[i]));
 			sprintf(paramname, "$%d", i + 1);
-			ri_GenerateQual(&querybuf, querysep,
-							attname, pk_type,
-							riinfo->pp_eq_oprs[i],
-							paramname, pk_type);
+			ri_GenerateQualComponent(&querybuf, querysep,
+									 attname, pk_type,
+									 riinfo->pp_eq_oprs[i],
+									 paramname, pk_type);
 			querysep = "AND";
 			queryoids[i] = pk_type;
 		}
@@ -694,10 +700,10 @@ ri_restrict(TriggerData *trigdata, bool is_no_action)
 			quoteOneName(attname,
 						 RIAttName(fk_rel, riinfo->fk_attnums[i]));
 			sprintf(paramname, "$%d", i + 1);
-			ri_GenerateQual(&querybuf, querysep,
-							paramname, pk_type,
-							riinfo->pf_eq_oprs[i],
-							attname, fk_type);
+			ri_GenerateQualComponent(&querybuf, querysep,
+									 paramname, pk_type,
+									 riinfo->pf_eq_oprs[i],
+									 attname, fk_type);
 			if (pk_coll != fk_coll && !get_collation_isdeterministic(pk_coll))
 				ri_GenerateQualCollation(&querybuf, pk_coll);
 			querysep = "AND";
@@ -805,10 +811,10 @@ RI_FKey_cascade_del(PG_FUNCTION_ARGS)
 			quoteOneName(attname,
 						 RIAttName(fk_rel, riinfo->fk_attnums[i]));
 			sprintf(paramname, "$%d", i + 1);
-			ri_GenerateQual(&querybuf, querysep,
-							paramname, pk_type,
-							riinfo->pf_eq_oprs[i],
-							attname, fk_type);
+			ri_GenerateQualComponent(&querybuf, querysep,
+									 paramname, pk_type,
+									 riinfo->pf_eq_oprs[i],
+									 attname, fk_type);
 			if (pk_coll != fk_coll && !get_collation_isdeterministic(pk_coll))
 				ri_GenerateQualCollation(&querybuf, pk_coll);
 			querysep = "AND";
@@ -924,10 +930,10 @@ RI_FKey_cascade_upd(PG_FUNCTION_ARGS)
 							 "%s %s = $%d",
 							 querysep, attname, i + 1);
 			sprintf(paramname, "$%d", j + 1);
-			ri_GenerateQual(&qualbuf, qualsep,
-							paramname, pk_type,
-							riinfo->pf_eq_oprs[i],
-							attname, fk_type);
+			ri_GenerateQualComponent(&qualbuf, qualsep,
+									 paramname, pk_type,
+									 riinfo->pf_eq_oprs[i],
+									 attname, fk_type);
 			if (pk_coll != fk_coll && !get_collation_isdeterministic(pk_coll))
 				ri_GenerateQualCollation(&querybuf, pk_coll);
 			querysep = ",";
@@ -1104,10 +1110,10 @@ ri_set(TriggerData *trigdata, bool is_set_null)
 							 querysep, attname,
 							 is_set_null ? "NULL" : "DEFAULT");
 			sprintf(paramname, "$%d", i + 1);
-			ri_GenerateQual(&qualbuf, qualsep,
-							paramname, pk_type,
-							riinfo->pf_eq_oprs[i],
-							attname, fk_type);
+			ri_GenerateQualComponent(&qualbuf, qualsep,
+									 paramname, pk_type,
+									 riinfo->pf_eq_oprs[i],
+									 attname, fk_type);
 			if (pk_coll != fk_coll && !get_collation_isdeterministic(pk_coll))
 				ri_GenerateQualCollation(&querybuf, pk_coll);
 			querysep = ",";
@@ -1402,31 +1408,13 @@ RI_Initial_Check(Trigger *trigger, Relation fk_rel, Relation pk_rel)
 	pk_only = pk_rel->rd_rel->relkind == RELKIND_PARTITIONED_TABLE ?
 		"" : "ONLY ";
 	appendStringInfo(&querybuf,
-					 " FROM %s%s fk LEFT OUTER JOIN %s%s pk ON",
+					 " FROM %s%s fk LEFT OUTER JOIN %s%s pk ON (",
 					 fk_only, fkrelname, pk_only, pkrelname);
 
-	strcpy(pkattname, "pk.");
-	strcpy(fkattname, "fk.");
-	sep = "(";
-	for (int i = 0; i < riinfo->nkeys; i++)
-	{
-		Oid			pk_type = RIAttType(pk_rel, riinfo->pk_attnums[i]);
-		Oid			fk_type = RIAttType(fk_rel, riinfo->fk_attnums[i]);
-		Oid			pk_coll = RIAttCollation(pk_rel, riinfo->pk_attnums[i]);
-		Oid			fk_coll = RIAttCollation(fk_rel, riinfo->fk_attnums[i]);
-
-		quoteOneName(pkattname + 3,
-					 RIAttName(pk_rel, riinfo->pk_attnums[i]));
-		quoteOneName(fkattname + 3,
-					 RIAttName(fk_rel, riinfo->fk_attnums[i]));
-		ri_GenerateQual(&querybuf, sep,
-						pkattname, pk_type,
-						riinfo->pf_eq_oprs[i],
-						fkattname, fk_type);
-		if (pk_coll != fk_coll)
-			ri_GenerateQualCollation(&querybuf, pk_coll);
-		sep = "AND";
-	}
+	ri_GenerateQual(&querybuf, "AND", riinfo->nkeys,
+					"pk", pk_rel, riinfo->pk_attnums,
+					"fk", fk_rel, riinfo->fk_attnums,
+					riinfo->pf_eq_oprs);
 
 	/*
 	 * It's sufficient to test any one pk attribute for null to detect a join
@@ -1584,7 +1572,6 @@ RI_PartitionRemove_Check(Trigger *trigger, Relation fk_rel, Relation pk_rel)
 	char	   *constraintDef;
 	char		pkrelname[MAX_QUOTED_REL_NAME_LEN];
 	char		fkrelname[MAX_QUOTED_REL_NAME_LEN];
-	char		pkattname[MAX_QUOTED_NAME_LEN + 3];
 	char		fkattname[MAX_QUOTED_NAME_LEN + 3];
 	const char *sep;
 	const char *fk_only;
@@ -1633,30 +1620,13 @@ RI_PartitionRemove_Check(Trigger *trigger, Relation fk_rel, Relation pk_rel)
 	fk_only = fk_rel->rd_rel->relkind == RELKIND_PARTITIONED_TABLE ?
 		"" : "ONLY ";
 	appendStringInfo(&querybuf,
-					 " FROM %s%s fk JOIN %s pk ON",
+					 " FROM %s%s fk JOIN %s pk ON (",
 					 fk_only, fkrelname, pkrelname);
-	strcpy(pkattname, "pk.");
-	strcpy(fkattname, "fk.");
-	sep = "(";
-	for (i = 0; i < riinfo->nkeys; i++)
-	{
-		Oid			pk_type = RIAttType(pk_rel, riinfo->pk_attnums[i]);
-		Oid			fk_type = RIAttType(fk_rel, riinfo->fk_attnums[i]);
-		Oid			pk_coll = RIAttCollation(pk_rel, riinfo->pk_attnums[i]);
-		Oid			fk_coll = RIAttCollation(fk_rel, riinfo->fk_attnums[i]);
-
-		quoteOneName(pkattname + 3,
-					 RIAttName(pk_rel, riinfo->pk_attnums[i]));
-		quoteOneName(fkattname + 3,
-					 RIAttName(fk_rel, riinfo->fk_attnums[i]));
-		ri_GenerateQual(&querybuf, sep,
-						pkattname, pk_type,
-						riinfo->pf_eq_oprs[i],
-						fkattname, fk_type);
-		if (pk_coll != fk_coll)
-			ri_GenerateQualCollation(&querybuf, pk_coll);
-		sep = "AND";
-	}
+
+	ri_GenerateQual(&querybuf, "AND", riinfo->nkeys,
+					"pk", pk_rel, riinfo->pk_attnums,
+					"fk", fk_rel, riinfo->fk_attnums,
+					riinfo->pf_eq_oprs);
 
 	/*
 	 * Start the WHERE clause with the partition constraint (except if this is
@@ -1820,7 +1790,44 @@ quoteRelationName(char *buffer, Relation rel)
 }
 
 /*
- * ri_GenerateQual --- generate a WHERE clause equating two variables
+ * ri_GenerateQual --- append to buf one "latt op ratt" equality per key
+ * column (nkeys of them), joined by sep; columns are prefixed by the given
+ * table aliases and a COLLATE clause is added where collations differ.
+ * Note: to avoid unnecessary explicit casts, pass the left and right operands
+ * in the order eq_oprs expects (i.e. don't swap them accidentally).
+ */
+static void
+ri_GenerateQual(StringInfo buf, char *sep, int nkeys,
+				const char *ltabname, Relation lrel,
+				const int16 *lattnums,
+				const char *rtabname, Relation rrel,
+				const int16 *rattnums,
+				const Oid *eq_oprs)
+{
+	for (int i = 0; i < nkeys; i++)
+	{
+		Oid			ltype = RIAttType(lrel, lattnums[i]);
+		Oid			rtype = RIAttType(rrel, rattnums[i]);
+		Oid			lcoll = RIAttCollation(lrel, lattnums[i]);
+		Oid			rcoll = RIAttCollation(rrel, rattnums[i]);
+		char	   *latt,
+				   *ratt;
+		char	   *sep_current = i > 0 ? sep : NULL;
+
+		latt = ri_ColNameQuoted(ltabname, RIAttName(lrel, lattnums[i]));
+		ratt = ri_ColNameQuoted(rtabname, RIAttName(rrel, rattnums[i]));
+
+		ri_GenerateQualComponent(buf, sep_current, latt, ltype, eq_oprs[i],
+								 ratt, rtype);
+
+		if (lcoll != rcoll)
+			ri_GenerateQualCollation(buf, lcoll);
+	}
+}
+
+/*
+ * ri_GenerateQualComponent --- generate one component of a WHERE/ON clause
+ * equating two variables; a NULL sep means no separator is emitted.
  *
  * This basically appends " sep leftop op rightop" to buf, adding casts
  * and schema qualification as needed to ensure that the parser will select
@@ -1828,17 +1835,86 @@ quoteRelationName(char *buffer, Relation rel)
  * if they aren't variables or parameters.
  */
 static void
-ri_GenerateQual(StringInfo buf,
-				const char *sep,
-				const char *leftop, Oid leftoptype,
-				Oid opoid,
-				const char *rightop, Oid rightoptype)
+ri_GenerateQualComponent(StringInfo buf,
+						 const char *sep,
+						 const char *leftop, Oid leftoptype,
+						 Oid opoid,
+						 const char *rightop, Oid rightoptype)
 {
-	appendStringInfo(buf, " %s ", sep);
+	if (sep)
+		appendStringInfo(buf, " %s ", sep);
 	generate_operator_clause(buf, leftop, leftoptype, opoid,
 							 rightop, rightoptype);
 }
 
+/*
+ * ri_ColNameQuoted() --- return palloc'd "tabname.attname", each name quoted
+ * by quoteOneName(); the "tabname." prefix is omitted if tabname is NULL/empty.
+ */
+static char *
+ri_ColNameQuoted(const char *tabname, const char *attname)
+{
+	char		quoted[MAX_QUOTED_NAME_LEN];
+	StringInfoData result;
+
+	initStringInfo(&result);
+	if (tabname && tabname[0] != '\0')
+	{
+		quoteOneName(quoted, tabname);
+		appendStringInfo(&result, "%s.", quoted);
+	}
+	quoteOneName(quoted, attname);
+	appendStringInfoString(&result, quoted);
+
+	return result.data;
+}
+
+/*
+ * Verify an RI trigger function's call context; ereport(ERROR) on mismatch
+ */
+static void
+ri_CheckTrigger(FunctionCallInfo fcinfo, const char *funcname, int tgkind)
+{
+	TriggerData *trigdata = (TriggerData *) fcinfo->context;
+
+	if (!CALLED_AS_TRIGGER(fcinfo))
+		ereport(ERROR,
+				(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
+				 errmsg("function \"%s\" was not called by trigger manager", funcname)));
+
+	/*
+	 * RI triggers must fire AFTER the event and FOR EACH ROW
+	 */
+	if (!TRIGGER_FIRED_AFTER(trigdata->tg_event) ||
+		!TRIGGER_FIRED_FOR_ROW(trigdata->tg_event))
+		ereport(ERROR,
+				(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
+				 errmsg("function \"%s\" must be fired AFTER ROW", funcname)));
+
+	switch (tgkind)
+	{
+		case RI_TRIGTYPE_INSERT:
+			if (!TRIGGER_FIRED_BY_INSERT(trigdata->tg_event))
+				ereport(ERROR,
+						(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
+						 errmsg("function \"%s\" must be fired for INSERT", funcname)));
+			break;
+		case RI_TRIGTYPE_UPDATE:
+			if (!TRIGGER_FIRED_BY_UPDATE(trigdata->tg_event))
+				ereport(ERROR,
+						(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
+						 errmsg("function \"%s\" must be fired for UPDATE", funcname)));
+			break;
+
+		case RI_TRIGTYPE_DELETE:
+			if (!TRIGGER_FIRED_BY_DELETE(trigdata->tg_event))
+				ereport(ERROR,
+						(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
+						 errmsg("function \"%s\" must be fired for DELETE", funcname)));
+			break;
+	}
+}
+
 /*
  * ri_GenerateQualCollation --- add a COLLATE spec to a WHERE clause
  *
@@ -1909,52 +1985,6 @@ ri_BuildQueryKey(RI_QueryKey *key, const RI_ConstraintInfo *riinfo,
 	key->constr_queryno = constr_queryno;
 }
 
-/*
- * Check that RI trigger function was called in expected context
- */
-static void
-ri_CheckTrigger(FunctionCallInfo fcinfo, const char *funcname, int tgkind)
-{
-	TriggerData *trigdata = (TriggerData *) fcinfo->context;
-
-	if (!CALLED_AS_TRIGGER(fcinfo))
-		ereport(ERROR,
-				(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
-				 errmsg("function \"%s\" was not called by trigger manager", funcname)));
-
-	/*
-	 * Check proper event
-	 */
-	if (!TRIGGER_FIRED_AFTER(trigdata->tg_event) ||
-		!TRIGGER_FIRED_FOR_ROW(trigdata->tg_event))
-		ereport(ERROR,
-				(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
-				 errmsg("function \"%s\" must be fired AFTER ROW", funcname)));
-
-	switch (tgkind)
-	{
-		case RI_TRIGTYPE_INSERT:
-			if (!TRIGGER_FIRED_BY_INSERT(trigdata->tg_event))
-				ereport(ERROR,
-						(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
-						 errmsg("function \"%s\" must be fired for INSERT", funcname)));
-			break;
-		case RI_TRIGTYPE_UPDATE:
-			if (!TRIGGER_FIRED_BY_UPDATE(trigdata->tg_event))
-				ereport(ERROR,
-						(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
-						 errmsg("function \"%s\" must be fired for UPDATE", funcname)));
-			break;
-		case RI_TRIGTYPE_DELETE:
-			if (!TRIGGER_FIRED_BY_DELETE(trigdata->tg_event))
-				ereport(ERROR,
-						(errcode(ERRCODE_E_R_I_E_TRIGGER_PROTOCOL_VIOLATED),
-						 errmsg("function \"%s\" must be fired for DELETE", funcname)));
-			break;
-	}
-}
-
-
 /*
  * Fetch the RI_ConstraintInfo struct for the trigger's FK constraint.
  */
-- 
2.20.1

