From 300b43180b287424b595c237b452d55f4d582bc5 Mon Sep 17 00:00:00 2001
From: Nikita Glukhov <n.gluhov@postgrespro.ru>
Date: Sat, 12 Nov 2016 18:59:43 +0300
Subject: [PATCH 1/6] Add pg_operator.oprstat for derived operator statistics
 estimation

---
 src/backend/catalog/pg_operator.c      | 11 +++++
 src/backend/commands/operatorcmds.c    | 61 ++++++++++++++++++++++++++
 src/backend/utils/adt/selfuncs.c       | 38 ++++++++++++++++
 src/backend/utils/cache/lsyscache.c    | 24 ++++++++++
 src/include/catalog/pg_operator.h      |  4 ++
 src/include/utils/lsyscache.h          |  1 +
 src/test/regress/expected/oidjoins.out |  1 +
 7 files changed, 140 insertions(+)

diff --git a/src/backend/catalog/pg_operator.c b/src/backend/catalog/pg_operator.c
index 630bf3e56cc..9205e62c0eb 100644
--- a/src/backend/catalog/pg_operator.c
+++ b/src/backend/catalog/pg_operator.c
@@ -256,6 +256,7 @@ OperatorShellMake(const char *operatorName,
 	values[Anum_pg_operator_oprcode - 1] = ObjectIdGetDatum(InvalidOid);
 	values[Anum_pg_operator_oprrest - 1] = ObjectIdGetDatum(InvalidOid);
 	values[Anum_pg_operator_oprjoin - 1] = ObjectIdGetDatum(InvalidOid);
+	values[Anum_pg_operator_oprstat - 1] = ObjectIdGetDatum(InvalidOid);
 
 	/*
 	 * create a new operator tuple
@@ -301,6 +302,7 @@ OperatorShellMake(const char *operatorName,
  *		negatorName				X negator operator
  *		restrictionId			X restriction selectivity procedure ID
  *		joinId					X join selectivity procedure ID
+ *		statsId					X statistics derivation procedure ID
  *		canMerge				merge join can be used with this operator
  *		canHash					hash join can be used with this operator
  *
@@ -333,6 +335,7 @@ OperatorCreate(const char *operatorName,
 			   List *negatorName,
 			   Oid restrictionId,
 			   Oid joinId,
+			   Oid statsId,
 			   bool canMerge,
 			   bool canHash)
 {
@@ -505,6 +508,7 @@ OperatorCreate(const char *operatorName,
 	values[Anum_pg_operator_oprcode - 1] = ObjectIdGetDatum(procedureId);
 	values[Anum_pg_operator_oprrest - 1] = ObjectIdGetDatum(restrictionId);
 	values[Anum_pg_operator_oprjoin - 1] = ObjectIdGetDatum(joinId);
+	values[Anum_pg_operator_oprstat - 1] = ObjectIdGetDatum(statsId);
 
 	pg_operator_desc = table_open(OperatorRelationId, RowExclusiveLock);
 
@@ -855,6 +859,13 @@ makeOperatorDependencies(HeapTuple tuple,
 		add_exact_object_address(&referenced, addrs);
 	}
 
+	/* Dependency on statistics derivation function */
+	if (OidIsValid(oper->oprstat))
+	{
+		ObjectAddressSet(referenced, ProcedureRelationId, oper->oprstat);
+		add_exact_object_address(&referenced, addrs);
+	}
+
 	record_object_address_dependencies(&myself, addrs, DEPENDENCY_NORMAL);
 	free_object_addresses(addrs);
 
diff --git a/src/backend/commands/operatorcmds.c b/src/backend/commands/operatorcmds.c
index a5924d7d564..adf13e648a6 100644
--- a/src/backend/commands/operatorcmds.c
+++ b/src/backend/commands/operatorcmds.c
@@ -52,6 +52,7 @@
 
 static Oid	ValidateRestrictionEstimator(List *restrictionName);
 static Oid	ValidateJoinEstimator(List *joinName);
+static Oid	ValidateStatisticsDerivator(List *statName);
 
 /*
  * DefineOperator
@@ -78,10 +79,12 @@ DefineOperator(List *names, List *parameters)
 	List	   *commutatorName = NIL;	/* optional commutator operator name */
 	List	   *negatorName = NIL;	/* optional negator operator name */
 	List	   *restrictionName = NIL;	/* optional restrict. sel. function */
+	List	   *statisticsName = NIL;	/* optional stats derivation function */
 	List	   *joinName = NIL; /* optional join sel. function */
 	Oid			functionOid;	/* functions converted to OID */
 	Oid			restrictionOid;
 	Oid			joinOid;
+	Oid			statisticsOid;
 	Oid			typeId[2];		/* to hold left and right arg */
 	int			nargs;
 	ListCell   *pl;
@@ -131,6 +134,8 @@ DefineOperator(List *names, List *parameters)
 			restrictionName = defGetQualifiedName(defel);
 		else if (strcmp(defel->defname, "join") == 0)
 			joinName = defGetQualifiedName(defel);
+		else if (strcmp(defel->defname, "statistics") == 0)
+			statisticsName = defGetQualifiedName(defel);
 		else if (strcmp(defel->defname, "hashes") == 0)
 			canHash = defGetBoolean(defel);
 		else if (strcmp(defel->defname, "merges") == 0)
@@ -246,6 +251,10 @@ DefineOperator(List *names, List *parameters)
 		joinOid = ValidateJoinEstimator(joinName);
 	else
 		joinOid = InvalidOid;
+	if (statisticsName)
+		statisticsOid = ValidateStatisticsDerivator(statisticsName);
+	else
+		statisticsOid = InvalidOid;
 
 	/*
 	 * now have OperatorCreate do all the work..
@@ -260,6 +269,7 @@ DefineOperator(List *names, List *parameters)
 					   negatorName, /* optional negator operator name */
 					   restrictionOid,	/* optional restrict. sel. function */
 					   joinOid, /* optional join sel. function name */
+					   statisticsOid, /* optional stats estimation procedure */
 					   canMerge,	/* operator merges */
 					   canHash);	/* operator hashes */
 }
@@ -357,6 +367,40 @@ ValidateJoinEstimator(List *joinName)
 	return joinOid;
 }
 
+/*
+ * Look up statName as a function (internal, internal, int4, internal)
+ * returning void, and check that we have EXECUTE permission on it.
+ */
+static Oid
+ValidateStatisticsDerivator(List *statName)
+{
+	Oid			argtypes[4] = {
+		INTERNALOID,			/* PlannerInfo */
+		INTERNALOID,			/* OpExpr */
+		INT4OID,				/* varRelid */
+		INTERNALOID				/* VariableStatData */
+	};
+	Oid			funcOid;
+	AclResult	aclresult;
+
+	funcOid = LookupFuncName(statName, 4, argtypes, false);
+
+	/* output goes through the VariableStatData argument, so require void */
+	if (get_func_rettype(funcOid) != VOIDOID)
+		ereport(ERROR,
+				(errcode(ERRCODE_INVALID_OBJECT_DEFINITION),
+				 errmsg("statistics estimator function %s must return type %s",
+						NameListToString(statName), "void")));
+
+	/* attaching the function to an operator requires EXECUTE permission */
+	aclresult = pg_proc_aclcheck(funcOid, GetUserId(), ACL_EXECUTE);
+	if (aclresult != ACLCHECK_OK)
+		aclcheck_error(aclresult, OBJECT_FUNCTION,
+					   NameListToString(statName));
+
+	return funcOid;
+}
+
 /*
  * Guts of operator deletion.
  */
@@ -424,6 +468,9 @@ AlterOperator(AlterOperatorStmt *stmt)
 	List	   *joinName = NIL; /* optional join sel. function */
 	bool		updateJoin = false;
 	Oid			joinOid;
+	List	   *statName = NIL; /* optional statistics estimation procedure */
+	bool		updateStat = false;
+	Oid			statOid;
 
 	/* Look up the operator */
 	oprId = LookupOperWithArgs(stmt->opername, false);
@@ -454,6 +501,11 @@ AlterOperator(AlterOperatorStmt *stmt)
 			joinName = param;
 			updateJoin = true;
 		}
+		else if (pg_strcasecmp(defel->defname, "statistics") == 0)
+		{
+			statName = param;	/* same keyword as CREATE OPERATOR */
+			updateStat = true;
+		}
 
 		/*
 		 * The rest of the options that CREATE accepts cannot be changed.
@@ -496,6 +548,10 @@ AlterOperator(AlterOperatorStmt *stmt)
 		joinOid = ValidateJoinEstimator(joinName);
 	else
 		joinOid = InvalidOid;
+	if (statName)
+		statOid = ValidateStatisticsDerivator(statName);
+	else
+		statOid = InvalidOid;
 
 	/* Perform additional checks, like OperatorCreate does */
 	if (!(OidIsValid(oprForm->oprleft) && OidIsValid(oprForm->oprright)))
@@ -536,6 +592,11 @@ AlterOperator(AlterOperatorStmt *stmt)
 		replaces[Anum_pg_operator_oprjoin - 1] = true;
 		values[Anum_pg_operator_oprjoin - 1] = joinOid;
 	}
+	if (updateStat)
+	{
+		replaces[Anum_pg_operator_oprstat - 1] = true;
+		values[Anum_pg_operator_oprstat - 1] = statOid;
+	}
 
 	tup = heap_modify_tuple(tup, RelationGetDescr(catalog),
 							values, nulls, replaces);
diff --git a/src/backend/utils/adt/selfuncs.c b/src/backend/utils/adt/selfuncs.c
index 1fbb0b28c3b..d1dd049f1ae 100644
--- a/src/backend/utils/adt/selfuncs.c
+++ b/src/backend/utils/adt/selfuncs.c
@@ -4917,6 +4917,30 @@ ReleaseDummy(HeapTuple tuple)
 	pfree(tuple);
 }
 
+/*
+ * examine_operator_expression
+ *		Let the operator's oprstat function, if any, derive statistics for
+ *		the OpExpr, passing it vardata for the result.
+ *
+ * An argument of the OpExpr may itself be an OpExpr.  The oprstat function
+ * then reaches this routine again through get_restriction_variable() =>
+ * examine_variable() => examine_operator_expression().
+ */
+static void
+examine_operator_expression(PlannerInfo *root, OpExpr *opexpr, int varRelid,
+							VariableStatData *vardata)
+{
+	RegProcedure statfunc = get_oprstat(opexpr->opno);
+
+	/* nothing to do for operators without a statistics derivation function */
+	if (!OidIsValid(statfunc))
+		return;
+
+	(void) OidFunctionCall4(statfunc, PointerGetDatum(root),
+							PointerGetDatum(opexpr), Int32GetDatum(varRelid),
+							PointerGetDatum(vardata));
+}
+
 /*
  * examine_variable
  *		Try to look up statistical data about an expression.
@@ -5332,6 +5356,20 @@ examine_variable(PlannerInfo *root, Node *node, int varRelid,
 				pos++;
 			}
 		}
+
+		/*
+		 * If there's no index or extended statistics matching the expression,
+		 * try deriving the statistics from statistics on arguments of the
+		 * operator expression (OpExpr). We do this last because it may be quite
+		 * expensive, and it's unclear how accurate the statistics will be.
+		 *
+		 * Any further restrictions on the OpExpr (e.g. that one of its
+		 * arguments must be a Const) are left for the operator itself to
+		 * enforce in its oprstat function.
+		 */
+		if (!vardata->statsTuple && IsA(basenode, OpExpr))
+			examine_operator_expression(root, (OpExpr *) basenode, varRelid,
+										vardata);
 	}
 }
 
diff --git a/src/backend/utils/cache/lsyscache.c b/src/backend/utils/cache/lsyscache.c
index feef9998634..b5440b596cd 100644
--- a/src/backend/utils/cache/lsyscache.c
+++ b/src/backend/utils/cache/lsyscache.c
@@ -1567,6 +1567,30 @@ get_oprjoin(Oid opno)
 		return (RegProcedure) InvalidOid;
 }
 
+/*
+ * get_oprstat
+ *
+ *		Returns the OID of the procedure that derives statistics for an
+ *		operator expression (pg_operator.oprstat), or InvalidOid if the
+ *		operator has none or does not exist.
+ */
+RegProcedure
+get_oprstat(Oid opno)
+{
+	HeapTuple	tp;
+	RegProcedure result;
+
+	tp = SearchSysCache1(OPEROID, ObjectIdGetDatum(opno));
+	if (!HeapTupleIsValid(tp))
+		return (RegProcedure) InvalidOid;
+
+	/* copy the value out before releasing the cache entry */
+	result = ((Form_pg_operator) GETSTRUCT(tp))->oprstat;
+	ReleaseSysCache(tp);
+
+	return result;
+}
+
 /*				---------- FUNCTION CACHE ----------					 */
 
 /*
diff --git a/src/include/catalog/pg_operator.h b/src/include/catalog/pg_operator.h
index 51263f550fe..ff1bb339f75 100644
--- a/src/include/catalog/pg_operator.h
+++ b/src/include/catalog/pg_operator.h
@@ -73,6 +73,9 @@ CATALOG(pg_operator,2617,OperatorRelationId)
 
 	/* OID of join estimator, or 0 */
 	regproc		oprjoin BKI_DEFAULT(-) BKI_LOOKUP_OPT(pg_proc);
+
+	/* OID of statistics estimator, or 0 */
+	regproc		oprstat BKI_DEFAULT(-) BKI_LOOKUP_OPT(pg_proc);
 } FormData_pg_operator;
 
 /* ----------------
@@ -95,6 +98,7 @@ extern ObjectAddress OperatorCreate(const char *operatorName,
 									List *negatorName,
 									Oid restrictionId,
 									Oid joinId,
+									Oid statisticsId,
 									bool canMerge,
 									bool canHash);
 
diff --git a/src/include/utils/lsyscache.h b/src/include/utils/lsyscache.h
index b8dd27d4a96..cc08fc50a50 100644
--- a/src/include/utils/lsyscache.h
+++ b/src/include/utils/lsyscache.h
@@ -118,6 +118,7 @@ extern Oid	get_commutator(Oid opno);
 extern Oid	get_negator(Oid opno);
 extern RegProcedure get_oprrest(Oid opno);
 extern RegProcedure get_oprjoin(Oid opno);
+extern RegProcedure get_oprstat(Oid opno);
 extern char *get_func_name(Oid funcid);
 extern Oid	get_func_namespace(Oid funcid);
 extern Oid	get_func_rettype(Oid funcid);
diff --git a/src/test/regress/expected/oidjoins.out b/src/test/regress/expected/oidjoins.out
index 215eb899be3..111ea99cdae 100644
--- a/src/test/regress/expected/oidjoins.out
+++ b/src/test/regress/expected/oidjoins.out
@@ -113,6 +113,7 @@ NOTICE:  checking pg_operator {oprnegate} => pg_operator {oid}
 NOTICE:  checking pg_operator {oprcode} => pg_proc {oid}
 NOTICE:  checking pg_operator {oprrest} => pg_proc {oid}
 NOTICE:  checking pg_operator {oprjoin} => pg_proc {oid}
+NOTICE:  checking pg_operator {oprstat} => pg_proc {oid}
 NOTICE:  checking pg_opfamily {opfmethod} => pg_am {oid}
 NOTICE:  checking pg_opfamily {opfnamespace} => pg_namespace {oid}
 NOTICE:  checking pg_opfamily {opfowner} => pg_authid {oid}
-- 
2.25.1

