From 2a3061a95988c39f4654accc06205099713af6cc Mon Sep 17 00:00:00 2001
From: Ankit Kumar Pandey <itsankitkp@gmail.com>
Date: Wed, 23 Nov 2022 00:38:01 +0530
Subject: [PATCH] Implement distinct in Window Aggregates.

---
 src/backend/executor/nodeWindowAgg.c | 229 +++++++++++++++++++++++----
 src/backend/optimizer/util/clauses.c |   2 +
 src/backend/parser/parse_agg.c       |  45 ++++++
 src/backend/parser/parse_func.c      |  19 +--
 src/include/nodes/execnodes.h        |   1 +
 src/include/nodes/primnodes.h        |   2 +
 6 files changed, 258 insertions(+), 40 deletions(-)

diff --git a/src/backend/executor/nodeWindowAgg.c b/src/backend/executor/nodeWindowAgg.c
index 4f4aeb2883..1d67ba2c39 100644
--- a/src/backend/executor/nodeWindowAgg.c
+++ b/src/backend/executor/nodeWindowAgg.c
@@ -154,6 +154,14 @@ typedef struct WindowStatePerAggData
 
 	int64		transValueCount;	/* number of currently-aggregated rows */
 
+	/* For DISTINCT in Aggregates */
+	Datum		lastdatum;		/* used for single-column DISTINCT */
+	FmgrInfo	equalfnOne; /* single-column comparisons*/
+
+	Oid			*eq_ops;  /* used for equality check in DISTINCT */
+	Oid			*sort_ops; /* used for sorting distinct columns */
+	bool 		sort_in; /* FLAG set true if data is stored in tuplesort */
+
 	/* Data local to eval_windowaggregates() */
 	bool		restart;		/* need to restart this agg in this cycle? */
 } WindowStatePerAggData;
@@ -163,7 +171,7 @@ static void initialize_windowaggregate(WindowAggState *winstate,
 									   WindowStatePerAgg peraggstate);
 static void advance_windowaggregate(WindowAggState *winstate,
 									WindowStatePerFunc perfuncstate,
-									WindowStatePerAgg peraggstate);
+									WindowStatePerAgg peraggstate, Datum value, bool isNull);
 static bool advance_windowaggregate_base(WindowAggState *winstate,
 										 WindowStatePerFunc perfuncstate,
 										 WindowStatePerAgg peraggstate);
@@ -173,6 +181,9 @@ static void finalize_windowaggregate(WindowAggState *winstate,
 									 Datum *result, bool *isnull);
 
 static void eval_windowaggregates(WindowAggState *winstate);
+static void process_ordered_windowaggregate_single(WindowAggState *winstate, 
+											 WindowStatePerFunc perfuncstate,
+								 			 WindowStatePerAgg peraggstate);
 static void eval_windowfunction(WindowAggState *winstate,
 								WindowStatePerFunc perfuncstate,
 								Datum *result, bool *isnull);
@@ -230,6 +241,7 @@ initialize_windowaggregate(WindowAggState *winstate,
 	peraggstate->transValueIsNull = peraggstate->initValueIsNull;
 	peraggstate->transValueCount = 0;
 	peraggstate->resultValue = (Datum) 0;
+	peraggstate->lastdatum = (Datum) 0;
 	peraggstate->resultValueIsNull = true;
 }
 
@@ -240,43 +252,21 @@ initialize_windowaggregate(WindowAggState *winstate,
 static void
 advance_windowaggregate(WindowAggState *winstate,
 						WindowStatePerFunc perfuncstate,
-						WindowStatePerAgg peraggstate)
+						WindowStatePerAgg peraggstate, Datum value, bool isNull)
 {
 	LOCAL_FCINFO(fcinfo, FUNC_MAX_ARGS);
-	WindowFuncExprState *wfuncstate = perfuncstate->wfuncstate;
 	int			numArguments = perfuncstate->numArguments;
 	Datum		newVal;
-	ListCell   *arg;
 	int			i;
 	MemoryContext oldContext;
 	ExprContext *econtext = winstate->tmpcontext;
-	ExprState  *filter = wfuncstate->aggfilter;
 
 	oldContext = MemoryContextSwitchTo(econtext->ecxt_per_tuple_memory);
 
-	/* Skip anything FILTERed out */
-	if (filter)
-	{
-		bool		isnull;
-		Datum		res = ExecEvalExpr(filter, econtext, &isnull);
-
-		if (isnull || !DatumGetBool(res))
-		{
-			MemoryContextSwitchTo(oldContext);
-			return;
-		}
-	}
-
 	/* We start from 1, since the 0th arg will be the transition value */
-	i = 1;
-	foreach(arg, wfuncstate->args)
-	{
-		ExprState  *argstate = (ExprState *) lfirst(arg);
 
-		fcinfo->args[i].value = ExecEvalExpr(argstate, econtext,
-											 &fcinfo->args[i].isnull);
-		i++;
-	}
+	fcinfo->args[1].value = value;
+	fcinfo->args[1].isnull = isNull;
 
 	if (peraggstate->transfn.fn_strict)
 	{
@@ -585,6 +575,10 @@ finalize_windowaggregate(WindowAggState *winstate,
 
 	oldContext = MemoryContextSwitchTo(winstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory);
 
+	/* Run transition function for distinct agg */
+	if (perfuncstate->wfunc->aggdistinct)
+		process_ordered_windowaggregate_single(winstate,  perfuncstate,  peraggstate);
+
 	/*
 	 * Apply the agg's finalfn if one is provided, else return transValue.
 	 */
@@ -666,6 +660,16 @@ eval_windowaggregates(WindowAggState *winstate)
 	TupleTableSlot *agg_row_slot;
 	TupleTableSlot *temp_slot;
 
+	ExprState 	*filter;
+	bool		isnull;
+	WindowFuncExprState *wfuncstate;
+	ListCell 	*arg;
+	Datum 		tuple;
+	ExprContext *aggecontext;
+	ListCell 	*lc;
+	Oid			inputTypes[FUNC_MAX_ARGS];
+	WindowStatePerFunc perfuncstate;
+
 	numaggs = winstate->numaggs;
 	if (numaggs == 0)
 		return;					/* nothing to do */
@@ -893,6 +897,23 @@ eval_windowaggregates(WindowAggState *winstate)
 		}
 	}
 
+	perfuncstate = &winstate->perfunc[wfuncno];
+	/* Initialize tuplesort for new partition */
+	if (perfuncstate->wfunc->aggdistinct)
+	{
+		i = 0;
+		foreach(lc, perfuncstate->wfunc->args)
+		{
+			inputTypes[i++] = exprType((Node *) lfirst(lc));
+		}
+		winstate->sortstates =
+					tuplesort_begin_datum(inputTypes[0],
+										peraggstate->sort_ops[0],
+										perfuncstate->wfunc->inputcollid,
+										true,
+										work_mem, NULL, TUPLESORT_NONE);
+	}
+
 	/*
 	 * Non-restarted aggregates now contain the rows between aggregatedbase
 	 * (i.e., frameheadpos) and aggregatedupto, while restarted aggregates
@@ -927,7 +948,8 @@ eval_windowaggregates(WindowAggState *winstate)
 		{
 			if (!window_gettupleslot(agg_winobj, winstate->aggregatedupto,
 									 agg_row_slot))
-				break;			/* must be end of partition */
+			break;			/* must be end of partition */
+				
 		}
 
 		/*
@@ -935,14 +957,16 @@ eval_windowaggregates(WindowAggState *winstate)
 		 * current row is not in frame but there might be more in the frame.
 		 */
 		ret = row_is_in_frame(winstate, winstate->aggregatedupto, agg_row_slot);
+
 		if (ret < 0)
 			break;
+
 		if (ret == 0)
 			goto next_tuple;
 
 		/* Set tuple context for evaluation of aggregate arguments */
 		winstate->tmpcontext->ecxt_outertuple = agg_row_slot;
-
+		
 		/* Accumulate row into the aggregates */
 		for (i = 0; i < numaggs; i++)
 		{
@@ -954,9 +978,51 @@ eval_windowaggregates(WindowAggState *winstate)
 				continue;
 
 			wfuncno = peraggstate->wfuncno;
-			advance_windowaggregate(winstate,
-									&winstate->perfunc[wfuncno],
-									peraggstate);
+			perfuncstate = &winstate->perfunc[wfuncno];
+
+			aggecontext = winstate->tmpcontext;
+
+			wfuncstate = perfuncstate->wfuncstate;
+			filter = wfuncstate->aggfilter;
+
+			oldContext = MemoryContextSwitchTo(aggecontext->ecxt_per_tuple_memory);
+
+			/* Skip anything FILTERed out for aggregates */
+			if (perfuncstate->plain_agg && wfuncstate->aggfilter)
+			{
+				Datum	res = ExecEvalExpr(filter, aggecontext, &isnull);
+
+				if (isnull || !DatumGetBool(res))
+				{
+					MemoryContextSwitchTo(oldContext);
+					continue;
+				}
+			}
+
+			/* Fetch tuple and either put them in tuplesort for removal
+			 * of duplicates and running partition later or run transition
+			 * function right away
+			 */
+			foreach(arg, wfuncstate->args)
+			{
+				
+				ExprState  *argstate = (ExprState *) lfirst(arg);
+				tuple = ExecEvalExpr(argstate, aggecontext, &isnull);
+				
+				/* Store in tuplesort */
+				if (perfuncstate->wfunc->aggdistinct)
+				{
+					tuplesort_putdatum(winstate->sortstates, tuple, isnull);
+					peraggstate->sort_in = true;
+				}
+				else
+				{
+					advance_windowaggregate(winstate, &winstate->perfunc[wfuncno], 
+											peraggstate, tuple, isnull);
+				}
+				
+			}
+			MemoryContextSwitchTo(oldContext);
 		}
 
 next_tuple:
@@ -1012,6 +1078,75 @@ next_tuple:
 	}
 }
 
+/*
+ * process_ordered_windowaggregate_single
+ * Sort the DISTINCT inputs collected in winstate->sortstates and run the
+ * transition function once per distinct value (parallel to
+ * process_ordered_aggregate_single in nodeAgg.c).  The tuplesort is ended
+ * and its pointer cleared before returning.
+ */
+static void
+process_ordered_windowaggregate_single(WindowAggState *winstate, WindowStatePerFunc perfuncstate,
+								 WindowStatePerAgg peraggstate)
+{
+	/* Scratch memory; advance_windowaggregate works in this context too */
+	MemoryContext workcontext = winstate->tmpcontext->ecxt_per_tuple_memory;
+	MemoryContext oldContext;
+	Datum		newVal;
+	bool		isNull;
+	Datum		oldVal = (Datum) 0;
+	bool		oldIsNull = true;
+	bool		haveOldVal = false;
+	int16		inputLen;
+	bool		inputByVal;
+
+	if (peraggstate->sort_in && winstate->sortstates != NULL)
+	{
+		/* oldVal holds an input value, so copy it by the argument's type */
+		get_typlenbyval(exprType((Node *) linitial(perfuncstate->wfunc->args)),
+						&inputLen, &inputByVal);
+		tuplesort_performsort(winstate->sortstates);
+		/* copy = false: newVal is only valid until the next fetch */
+		while (tuplesort_getdatum(winstate->sortstates,
+								  true, false, &newVal, &isNull, NULL))
+		{
+			MemoryContextReset(workcontext);
+			oldContext = MemoryContextSwitchTo(workcontext);
+			/* Skip duplicates; never pass NULLs to the equality function */
+			if (haveOldVal &&
+				((oldIsNull && isNull) ||
+				 (!oldIsNull && !isNull &&
+				  DatumGetBool(FunctionCall2Coll(&peraggstate->equalfnOne,
+												 perfuncstate->winCollation,
+												 oldVal, newVal)))))
+			{
+				MemoryContextSwitchTo(oldContext);
+				continue;
+			}
+
+			/* Run transition function over each distinct value */
+			advance_windowaggregate(winstate, perfuncstate,
+									peraggstate, newVal, isNull);
+			MemoryContextSwitchTo(oldContext);
+
+			/* Remember this value, replacing the previous private copy */
+			if (!inputByVal && !oldIsNull)
+				pfree(DatumGetPointer(oldVal));
+			if (!isNull)
+				oldVal = datumCopy(newVal, inputByVal, inputLen);
+			oldIsNull = isNull;
+			haveOldVal = true;
+		}
+	}
+
+	/* Release the tuplesort; clearing the pointer prevents a second end */
+	if (winstate->sortstates != NULL)
+		tuplesort_end(winstate->sortstates);
+	winstate->sortstates = NULL;
+	peraggstate->sort_in = false;
+}
+
+
 /*
  * eval_windowfunction
  *
@@ -2947,6 +3082,9 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
 	get_typlenbyval(aggtranstype,
 					&peraggstate->transtypeLen,
 					&peraggstate->transtypeByVal);
+	get_typlenbyval(wfunc->wintype,
+					&peraggstate->inputtypeLen,
+					&peraggstate->inputtypeByVal);
 
 	/*
 	 * initval is potentially null, so don't try to access it as a struct
@@ -3014,6 +3152,35 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
 	else
 		peraggstate->aggcontext = winstate->aggcontext;
 
+	/* Handle distinct operation in agg */
+	if (wfunc->aggdistinct)
+	{
+		int		numDistinctCols = list_length(wfunc->distinctargs);
+		peraggstate->eq_ops = palloc(numDistinctCols * sizeof(Oid));
+		peraggstate->sort_ops =  palloc(numDistinctCols * sizeof(Oid));
+		/* Use single tuplesort for all partitions by rinsing it again and again */
+		winstate->sortstates = (Tuplesortstate *)
+								palloc0(sizeof(Tuplesortstate *) * 1);
+
+		/* Initialize tuplesort operators namely sort operator to sort tuples 
+		 * before running equality op to remove/skip duplicates
+		 */
+
+		i=0;
+		foreach(lc, wfunc->distinctargs)
+		{
+			peraggstate->eq_ops[i] = ((SortGroupClause *) lfirst(lc))->eqop;
+			peraggstate->sort_ops[i] = ((SortGroupClause *) lfirst(lc))->sortop;
+			i++;
+		}
+		fmgr_info(get_opcode(peraggstate->eq_ops[0]), &peraggstate->equalfnOne);
+		winstate->sortstates = tuplesort_begin_datum(inputTypes[0],
+											peraggstate->sort_ops[0],
+											wfunc->inputcollid,
+											true,
+											work_mem, NULL, TUPLESORT_NONE);
+	}
+	
 	ReleaseSysCache(aggTuple);
 
 	return peraggstate;
diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c
index bffc8112aa..1e2aa897df 100644
--- a/src/backend/optimizer/util/clauses.c
+++ b/src/backend/optimizer/util/clauses.c
@@ -2443,6 +2443,8 @@ eval_const_expressions_mutator(Node *node,
 				newexpr->winref = expr->winref;
 				newexpr->winstar = expr->winstar;
 				newexpr->winagg = expr->winagg;
+				newexpr->aggdistinct = expr->aggdistinct;
+				newexpr->distinctargs = expr->distinctargs;
 				newexpr->location = expr->location;
 
 				return (Node *) newexpr;
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index 8eec2088aa..03f21b37bf 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -1047,6 +1047,51 @@ transformWindowFuncCall(ParseState *pstate, WindowFunc *wfunc,
 		}
 	}
 
+	if (wfunc->aggdistinct){
+		List	   *argtypes = NIL;
+		List	   *tlist = NIL;
+		List	   *torder = NIL;
+		List	   *tdistinct = NIL;
+		AttrNumber	attno = 1;
+		ListCell   *lc;
+
+		foreach(lc, wfunc->args)
+		{
+			Expr	   *arg = (Expr *) lfirst(lc);
+			TargetEntry *tle;
+
+			/* We don't bother to assign column names to the entries */
+			tle = makeTargetEntry(arg, attno++, NULL, false);
+			tlist = lappend(tlist, tle);
+		}
+		torder = transformSortClause(pstate,
+									 NIL,
+									 &tlist,
+									 EXPR_KIND_ORDER_BY,
+									 true /* force SQL99 rules */ );
+
+		tdistinct = transformDistinctClause(pstate, &tlist, torder, true);
+
+		foreach(lc, tdistinct)
+		{
+			SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
+
+			if (!OidIsValid(sortcl->sortop))
+			{
+				Node	   *expr = get_sortgroupclause_expr(sortcl, tlist);
+
+				ereport(ERROR,
+						(errcode(ERRCODE_UNDEFINED_FUNCTION),
+							errmsg("could not identify an ordering operator for type %s",
+								format_type_be(exprType(expr))),
+							errdetail("Aggregates with DISTINCT must be able to sort their inputs."),
+							parser_errposition(pstate, exprLocation(expr))));
+			}
+		}
+		wfunc->distinctargs = tdistinct;
+	}
+	
+
 	pstate->p_hasWindowFuncs = true;
 }
 
diff --git a/src/backend/parser/parse_func.c b/src/backend/parser/parse_func.c
index 827989f379..f536a1411a 100644
--- a/src/backend/parser/parse_func.c
+++ b/src/backend/parser/parse_func.c
@@ -835,15 +835,7 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
 		wfunc->winagg = (fdresult == FUNCDETAIL_AGGREGATE);
 		wfunc->aggfilter = agg_filter;
 		wfunc->location = location;
-
-		/*
-		 * agg_star is allowed for aggregate functions but distinct isn't
-		 */
-		if (agg_distinct)
-			ereport(ERROR,
-					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
-					 errmsg("DISTINCT is not implemented for window functions"),
-					 parser_errposition(pstate, location)));
+		wfunc->aggdistinct = agg_distinct;
 
 		/*
 		 * Reject attempt to call a parameterless aggregate without (*)
@@ -856,6 +848,15 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
 							NameListToString(funcname)),
 					 parser_errposition(pstate, location)));
 
+		/*
+		 * DISTINCT is not implemented when the window has an ORDER BY clause
+		 */
+		if (agg_distinct && over->orderClause)
+			ereport(ERROR,
+					(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+					 errmsg("DISTINCT is not implemented for aggregate functions with ORDER BY"),
+					 parser_errposition(pstate, location)));
+
 		/*
 		 * ordered aggs not allowed in windows yet
 		 */
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index 9a64a830a2..63188a9565 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -2523,6 +2523,7 @@ typedef struct WindowAggState
 									 * date for current row */
 	bool		grouptail_valid;	/* true if grouptailpos is known up to
 									 * date for current row */
+	Tuplesortstate *sortstates;
 
 	TupleTableSlot *first_part_slot;	/* first tuple of current or next
 										 * partition */
diff --git a/src/include/nodes/primnodes.h b/src/include/nodes/primnodes.h
index 74f228d959..d7f84a40fd 100644
--- a/src/include/nodes/primnodes.h
+++ b/src/include/nodes/primnodes.h
@@ -495,6 +495,8 @@ typedef struct WindowFunc
 	Index		winref;			/* index of associated WindowClause */
 	bool		winstar;		/* true if argument list was really '*' */
 	bool		winagg;			/* is function a simple aggregate? */
+	bool		aggdistinct;    /* do we need distinct values for aggregation? */
+	List		*distinctargs;
 	int			location;		/* token location, or -1 if unknown */
 } WindowFunc;
 
-- 
2.37.2

