From 3735861eb120587105676c4df5818d16bac2ad2e Mon Sep 17 00:00:00 2001
From: David Geier <david.geier@servicenow.com>
Date: Mon, 27 Jun 2022 12:28:29 +0200
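Subject: [PATCH] Lazily JIT-compile expressions on first evaluation

llvm_compile_expr() no longer generates code while the executor is being
initialized. Instead it installs ExecCompileExpr() as the expression's
evalfunc, which generates, emits and runs the code the first time the
expression is evaluated. Expressions that are never evaluated, e.g. in
children of a Parallel Append that a given worker never executes, are
thus never compiled.

Because code is now generated per expression at first evaluation, a
query may end up creating several modules. EXPLAIN now reports how many
were created ("Modules").

To keep repeated inlining affordable, bitcode files used for inlining are
now fully parsed and kept in the module cache. A clone of the cached
module is handed to the IR mover instead of consuming the cached one.

With PGJIT_OPT3, the pass manager builder no longer uses an inliner with
a fixed threshold. Instead, the always-inliner pass is always added, and
the function inlining pass is added whenever PGJIT_INLINE is set.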
---
 src/backend/commands/explain.c          |  1 +
 src/backend/jit/jit.c                   |  1 +
 src/backend/jit/llvm/llvmjit.c          | 16 ++---
 src/backend/jit/llvm/llvmjit_expr.c     | 95 +++++++++++++++----------
 src/backend/jit/llvm/llvmjit_inline.cpp | 14 ++--
 src/include/jit/jit.h                   |  3 +
 6 files changed, 74 insertions(+), 56 deletions(-)

diff --git a/src/backend/commands/explain.c b/src/backend/commands/explain.c
index 81a227d8b8..2cb201c093 100644
--- a/src/backend/commands/explain.c
+++ b/src/backend/commands/explain.c
@@ -890,6 +890,7 @@ ExplainPrintJIT(ExplainState *es, int jit_flags, JitInstrumentation *ji)
 		appendStringInfoString(es->str, "JIT:\n");
 		es->indent++;
 
+		ExplainPropertyInteger("Modules", NULL, ji->created_modules, es);
 		ExplainPropertyInteger("Functions", NULL, ji->created_functions, es);
 
 		ExplainIndentText(es);
diff --git a/src/backend/jit/jit.c b/src/backend/jit/jit.c
index 2da300e000..b574357125 100644
--- a/src/backend/jit/jit.c
+++ b/src/backend/jit/jit.c
@@ -183,6 +183,7 @@ jit_compile_expr(struct ExprState *state)
 void
 InstrJitAgg(JitInstrumentation *dst, JitInstrumentation *add)
 {
+	dst->created_modules += add->created_modules;
 	dst->created_functions += add->created_functions;
 	INSTR_TIME_ADD(dst->generation_counter, add->generation_counter);
 	INSTR_TIME_ADD(dst->inlining_counter, add->inlining_counter);
diff --git a/src/backend/jit/llvm/llvmjit.c b/src/backend/jit/llvm/llvmjit.c
index fb29449573..3e3340992d 100644
--- a/src/backend/jit/llvm/llvmjit.c
+++ b/src/backend/jit/llvm/llvmjit.c
@@ -236,6 +236,7 @@ llvm_mutable_module(LLVMJitContext *context)
 	 */
 	if (!context->module)
 	{
+		context->base.instr.created_modules++;
 		context->compiled = false;
 		context->module_generation = llvm_generation++;
 		context->module = LLVMModuleCreateWithName("pg");
@@ -578,12 +579,7 @@ llvm_optimize_module(LLVMJitContext *context, LLVMModuleRef module)
 	LLVMPassManagerBuilderSetOptLevel(llvm_pmb, compile_optlevel);
 	llvm_fpm = LLVMCreateFunctionPassManagerForModule(module);
 
-	if (context->base.flags & PGJIT_OPT3)
-	{
-		/* TODO: Unscientifically determined threshold */
-		LLVMPassManagerBuilderUseInlinerWithThreshold(llvm_pmb, 512);
-	}
-	else
+	if (!(context->base.flags & PGJIT_OPT3))
 	{
 		/* we rely on mem2reg heavily, so emit even in the O0 case */
 		LLVMAddPromoteMemoryToRegisterPass(llvm_fpm);
@@ -611,11 +607,9 @@ llvm_optimize_module(LLVMJitContext *context, LLVMModuleRef module)
 	LLVMPassManagerBuilderPopulateModulePassManager(llvm_pmb,
 													llvm_mpm);
 	/* always use always-inliner pass */
-	if (!(context->base.flags & PGJIT_OPT3))
-		LLVMAddAlwaysInlinerPass(llvm_mpm);
-	/* if doing inlining, but no expensive optimization, add inlining pass */
-	if (context->base.flags & PGJIT_INLINE
-		&& !(context->base.flags & PGJIT_OPT3))
+	LLVMAddAlwaysInlinerPass(llvm_mpm);
+	/* if doing inlining, add inlining pass */
+	if (context->base.flags & PGJIT_INLINE)
 		LLVMAddFunctionInliningPass(llvm_mpm);
 	LLVMRunPassManager(llvm_mpm, context->module);
 	LLVMDisposePassManager(llvm_mpm);
diff --git a/src/backend/jit/llvm/llvmjit_expr.c b/src/backend/jit/llvm/llvmjit_expr.c
index 8a4075bdaf..df27c0dd4b 100644
--- a/src/backend/jit/llvm/llvmjit_expr.c
+++ b/src/backend/jit/llvm/llvmjit_expr.c
@@ -52,6 +52,7 @@ typedef struct CompiledExprState
 } CompiledExprState;
 
 
+static Datum ExecCompileExpr(ExprState *state, ExprContext *econtext, bool *isNull);
 static Datum ExecRunCompiledExpr(ExprState *state, ExprContext *econtext, bool *isNull);
 
 static LLVMValueRef BuildV1Call(LLVMJitContext *context, LLVMBuilderRef b,
@@ -72,16 +73,63 @@ static LLVMValueRef create_LifetimeEnd(LLVMModuleRef mod);
 
 
 /*
- * JIT compile expression.
+ * Prepare an expression for lazy JIT compilation: install ExecCompileExpr()
+ * as its evalfunc, which generates and emits code on first evaluation.
  */
 bool
 llvm_compile_expr(ExprState *state)
 {
 	PlanState  *parent = state->parent;
-	char	   *funcname;
-
 	LLVMJitContext *context = NULL;
 
+	/*
+	 * Right now we don't support compiling expressions without a parent, as
+	 * we need access to the EState.
+	 */
+	Assert(parent);
+
+	llvm_enter_fatal_on_oom();
+
+	/* get or create JIT context */
+	if (parent->state->es_jit)
+		context = (LLVMJitContext *) parent->state->es_jit;
+	else
+	{
+		context = llvm_create_context(parent->state->es_jit_flags);
+		parent->state->es_jit = &context->base;
+	}
+
+	/*
+	 * Neither generate nor emit the function now; ExecCompileExpr() does
+	 * both the first time the expression is actually evaluated. This avoids
+	 * compiling expressions that are never evaluated, as can happen e.g.
+	 * when a Parallel Append node distributes workers among its child
+	 * plans, so that some children are never run by a given worker.
+	 */
+	{
+		/* cstate->funcname stays NULL until ExecCompileExpr() sets it */
+		CompiledExprState *cstate = palloc0(sizeof(CompiledExprState));
+
+		cstate->context = context;
+
+		state->evalfunc = ExecCompileExpr;
+		state->evalfunc_private = cstate;
+	}
+
+	llvm_leave_fatal_on_oom();
+
+	return true;
+}
+
+/*
+ * JIT compile the expression on its first evaluation, then evaluate it.
+ */
+static Datum
+ExecCompileExpr(ExprState *state, ExprContext *econtext, bool *isNull)
+{
+	CompiledExprState *cstate = state->evalfunc_private;
+	LLVMJitContext *context = cstate->context;
+
 	LLVMBuilderRef b;
 	LLVMModuleRef mod;
 	LLVMValueRef eval_fn;
@@ -125,31 +173,16 @@ llvm_compile_expr(ExprState *state)
 
 	llvm_enter_fatal_on_oom();
 
-	/*
-	 * Right now we don't support compiling expressions without a parent, as
-	 * we need access to the EState.
-	 */
-	Assert(parent);
-
-	/* get or create JIT context */
-	if (parent->state->es_jit)
-		context = (LLVMJitContext *) parent->state->es_jit;
-	else
-	{
-		context = llvm_create_context(parent->state->es_jit_flags);
-		parent->state->es_jit = &context->base;
-	}
-
 	INSTR_TIME_SET_CURRENT(starttime);
 
 	mod = llvm_mutable_module(context);
 
 	b = LLVMCreateBuilder();
 
-	funcname = llvm_expand_funcname(context, "evalexpr");
+	cstate->funcname = llvm_expand_funcname(context, "evalexpr");
 
 	/* create function */
-	eval_fn = LLVMAddFunction(mod, funcname,
+	eval_fn = LLVMAddFunction(mod, cstate->funcname,
 							  llvm_pg_var_func_type("TypeExprStateEvalFunc"));
 	LLVMSetLinkage(eval_fn, LLVMExternalLinkage);
 	LLVMSetVisibility(eval_fn, LLVMDefaultVisibility);
@@ -2356,30 +2389,16 @@ llvm_compile_expr(ExprState *state)
 
 	LLVMDisposeBuilder(b);
 
-	/*
-	 * Don't immediately emit function, instead do so the first time the
-	 * expression is actually evaluated. That allows to emit a lot of
-	 * functions together, avoiding a lot of repeated llvm and memory
-	 * remapping overhead.
-	 */
-	{
-
-		CompiledExprState *cstate = palloc0(sizeof(CompiledExprState));
-
-		cstate->context = context;
-		cstate->funcname = funcname;
-
-		state->evalfunc = ExecRunCompiledExpr;
-		state->evalfunc_private = cstate;
-	}
-
 	llvm_leave_fatal_on_oom();
 
 	INSTR_TIME_SET_CURRENT(endtime);
 	INSTR_TIME_ACCUM_DIFF(context->base.instr.generation_counter,
 						  endtime, starttime);
 
-	return true;
+	/* remove indirection via this function for future calls */
+	state->evalfunc = ExecRunCompiledExpr;
+
+	return ExecRunCompiledExpr(state, econtext, isNull);
 }
 
 /*
diff --git a/src/backend/jit/llvm/llvmjit_inline.cpp b/src/backend/jit/llvm/llvmjit_inline.cpp
index 6f03595db5..43c3f7ddd3 100644
--- a/src/backend/jit/llvm/llvmjit_inline.cpp
+++ b/src/backend/jit/llvm/llvmjit_inline.cpp
@@ -62,6 +62,7 @@ extern "C"
 #include <llvm/IR/ModuleSummaryIndex.h>
 #include <llvm/Linker/IRMover.h>
 #include <llvm/Support/ManagedStatic.h>
+#include <llvm/Transforms/Utils/Cloning.h>
 
 
 /*
@@ -287,7 +288,6 @@ llvm_build_inline_plan(llvm::Module *mod)
 				elog(FATAL, "failed to materialize metadata");
 
 			Assert(!funcDef->isDeclaration());
-			Assert(funcDef->hasExternalLinkage());
 
 			llvm::StringSet<> importVars;
 			llvm::SmallPtrSet<const llvm::Function *, 8> visitedFunctions;
@@ -377,13 +377,13 @@ llvm_execute_inline_plan(llvm::Module *mod, ImportMapTy *globalsToInline)
 		const llvm::StringSet<>& modGlobalsToInline = toInline.second;
 		llvm::SetVector<llvm::GlobalValue *> GlobalsToImport;
 
-		Assert(module_cache->count(modPath));
-		std::unique_ptr<llvm::Module> importMod(std::move((*module_cache)[modPath]));
-		module_cache->erase(modPath);
-
 		if (modGlobalsToInline.empty())
 			continue;
 
+		auto iter = module_cache->find(modPath);
+		Assert(iter != module_cache->end());
+		auto &importMod = iter->second;
+
 		for (auto &glob: modGlobalsToInline)
 		{
 			llvm::StringRef SymbolName = glob.first();
@@ -451,7 +451,7 @@ llvm_execute_inline_plan(llvm::Module *mod, ImportMapTy *globalsToInline)
 #else
 #define IRMOVE_PARAMS
 #endif
-		if (Mover.move(std::move(importMod), GlobalsToImport.getArrayRef(),
+		if (Mover.move(llvm::CloneModule(*importMod), GlobalsToImport.getArrayRef(),
 					   [](llvm::GlobalValue &, llvm::IRMover::ValueAdder) {}
 					   IRMOVE_PARAMS))
 			elog(FATAL, "function import failed with linker error");
@@ -490,7 +490,7 @@ load_module(llvm::StringRef Identifier)
 	if (LLVMCreateMemoryBufferWithContentsOfFile(path, &buf, &msg))
 		elog(FATAL, "failed to open bitcode file \"%s\": %s",
 			 path, msg);
-	if (LLVMGetBitcodeModuleInContext2(LLVMGetGlobalContext(), buf, &mod))
+	if (LLVMParseBitcodeInContext2(LLVMGetGlobalContext(), buf, &mod))
 		elog(FATAL, "failed to parse bitcode in file \"%s\"", path);
 
 	/*
diff --git a/src/include/jit/jit.h b/src/include/jit/jit.h
index b634df30b9..7a080074c0 100644
--- a/src/include/jit/jit.h
+++ b/src/include/jit/jit.h
@@ -26,6 +26,9 @@
 
 typedef struct JitInstrumentation
 {
+	/* number of emitted modules */
+	size_t		created_modules;
+
 	/* number of emitted functions */
 	size_t		created_functions;
 
-- 
2.32.0

