From 98f505ba81e8ef317d2a9b764348b523346d7f24 Mon Sep 17 00:00:00 2001
From: steve-chavez <stevechavezast@gmail.com>
Date: Mon, 8 Jan 2024 18:57:26 -0500
Subject: [PATCH] psql: add \create_function command

Currently a function definition must include its body inline, as a
string literal inside the CREATE FUNCTION statement. Because of this,
when function definitions are stored in files, linters and syntax
highlighters for non-SQL languages (Python, Perl, Tcl, etc.) won't work
on the function bodies.

This patch adds a psql command that creates a function whose body is
read from a separate file. It is used as:

\create_function from ./data/max.py max(int,int) returns int LANGUAGE plpython3u

Its design is similar to the `\copy` command, which is a frontend
version of the COPY statement.

Includes regression tests for SQL, plpython3u, pltcl and plperl
functions, and tab completion for the new command.
---
 src/bin/psql/Makefile                         |   1 +
 src/bin/psql/command.c                        |  26 ++++
 src/bin/psql/create_function.c                | 128 ++++++++++++++++++
 src/bin/psql/create_function.h                |  15 ++
 src/bin/psql/meson.build                      |   1 +
 src/bin/psql/tab-complete.c                   |  17 ++-
 src/pl/plperl/data/max.pl                     |   2 +
 src/pl/plperl/expected/plperl.out             |   7 +
 src/pl/plperl/sql/plperl.sql                  |   4 +
 src/pl/plpython/data/max.py                   |   3 +
 src/pl/plpython/expected/plpython_test.out    |   7 +
 src/pl/plpython/sql/plpython_test.sql         |   4 +
 src/pl/tcl/data/max.tcl                       |   2 +
 src/pl/tcl/expected/pltcl_setup.out           |   7 +
 src/pl/tcl/sql/pltcl_setup.sql                |   4 +
 src/test/regress/data/max.sql                 |   1 +
 .../regress/expected/create_function_sql.out  |  10 +-
 src/test/regress/sql/create_function_sql.sql  |   4 +
 18 files changed, 241 insertions(+), 2 deletions(-)
 create mode 100644 src/bin/psql/create_function.c
 create mode 100644 src/bin/psql/create_function.h
 create mode 100644 src/pl/plperl/data/max.pl
 create mode 100644 src/pl/plpython/data/max.py
 create mode 100644 src/pl/tcl/data/max.tcl
 create mode 100644 src/test/regress/data/max.sql

diff --git a/src/bin/psql/Makefile b/src/bin/psql/Makefile
index 374c4c3ab8..285291b8ab 100644
--- a/src/bin/psql/Makefile
+++ b/src/bin/psql/Makefile
@@ -29,6 +29,7 @@ OBJS = \
 	command.o \
 	common.o \
 	copy.o \
+	create_function.o \
 	crosstabview.o \
 	describe.o \
 	help.o \
diff --git a/src/bin/psql/command.c b/src/bin/psql/command.c
index 5c906e4806..d2c5799ed0 100644
--- a/src/bin/psql/command.c
+++ b/src/bin/psql/command.c
@@ -30,6 +30,7 @@
 #include "common/logging.h"
 #include "common/string.h"
 #include "copy.h"
+#include "create_function.h"
 #include "crosstabview.h"
 #include "describe.h"
 #include "fe_utils/cancel.h"
@@ -71,6 +72,7 @@ static backslashResult exec_command_cd(PsqlScanState scan_state, bool active_bra
 static backslashResult exec_command_conninfo(PsqlScanState scan_state, bool active_branch);
 static backslashResult exec_command_copy(PsqlScanState scan_state, bool active_branch);
 static backslashResult exec_command_copyright(PsqlScanState scan_state, bool active_branch);
+static backslashResult exec_command_create_function(PsqlScanState scan_state, bool active_branch);
 static backslashResult exec_command_crosstabview(PsqlScanState scan_state, bool active_branch);
 static backslashResult exec_command_d(PsqlScanState scan_state, bool active_branch,
 									  const char *cmd);
@@ -323,6 +325,8 @@ exec_command(const char *cmd,
 		status = exec_command_copy(scan_state, active_branch);
 	else if (strcmp(cmd, "copyright") == 0)
 		status = exec_command_copyright(scan_state, active_branch);
+	else if (strcmp(cmd, "create_function") == 0)
+		status = exec_command_create_function(scan_state, active_branch);
 	else if (strcmp(cmd, "crosstabview") == 0)
 		status = exec_command_crosstabview(scan_state, active_branch);
 	else if (cmd[0] == 'd')
@@ -720,6 +724,28 @@ exec_command_copyright(PsqlScanState scan_state, bool active_branch)
 	return PSQL_CMD_SKIP_LINE;
 }
 
+/*
+ * \create_function -- create a function, reading its body from a file
+ */
+static backslashResult
+exec_command_create_function(PsqlScanState scan_state, bool active_branch)
+{
+	char	   *args;
+	bool		ok;
+
+	/* Inactive branch: skip over the rest of the line without running it. */
+	if (!active_branch)
+	{
+		ignore_slash_whole_line(scan_state);
+		return PSQL_CMD_SKIP_LINE;
+	}
+
+	args = psql_scan_slash_option(scan_state, OT_WHOLE_LINE, NULL, false);
+	ok = do_create_function(args);
+	free(args);
+	return ok ? PSQL_CMD_SKIP_LINE : PSQL_CMD_ERROR;
+}
+
 /*
  * \crosstabview -- execute a query and display result in crosstab
  */
diff --git a/src/bin/psql/create_function.c b/src/bin/psql/create_function.c
new file mode 100644
index 0000000000..4f4e8c42ed
--- /dev/null
+++ b/src/bin/psql/create_function.c
@@ -0,0 +1,193 @@
+/*
+ * psql - the PostgreSQL interactive terminal
+ *
+ * Copyright (c) 2000-2024, PostgreSQL Global Development Group
+ *
+ * src/bin/psql/create_function.c
+ *
+ * Implements \create_function, which runs CREATE OR REPLACE FUNCTION with
+ * a function body read from a client-side file.
+ */
+#include "postgres_fe.h"
+
+#include "common.h"
+#include "common/logging.h"
+#include "create_function.h"
+#include "fe_utils/string_utils.h"
+#include "libpq-fe.h"
+#include "pqexpbuffer.h"
+#include "settings.h"
+#include "stringutils.h"
+
+/* Parsed arguments of "\create_function FROM <file> <rest>" */
+struct create_function_options
+{
+	char	   *from_file;		/* file holding the function body */
+	char	   *after_from;		/* text following the file name */
+};
+
+static void
+free_create_function_options(struct create_function_options *ptr)
+{
+	if (!ptr)
+		return;
+	free(ptr->from_file);
+	free(ptr->after_from);
+	free(ptr);
+}
+
+/*
+ * Parse "FROM <file> <rest>".
+ *
+ * <rest> is everything after the file name up to an optional ';'; it is
+ * pasted verbatim after "CREATE OR REPLACE FUNCTION".  Returns NULL, after
+ * logging an error, if the arguments are missing or malformed.
+ */
+static struct create_function_options *
+parse_slash_create_function(const char *args)
+{
+	struct create_function_options *result;
+	char	   *token;
+	const char *whitespace = " \t\n\r";
+
+	if (!args)
+	{
+		pg_log_error("\\create_function: arguments required");
+		return NULL;
+	}
+
+	result = pg_malloc0(sizeof(struct create_function_options));
+
+	token = strtokx(args, whitespace, NULL, NULL,
+					0, false, false, pset.encoding);
+
+	/* strtokx returns NULL if there is no token at all */
+	if (!token || pg_strcasecmp(token, "from") != 0)
+		goto error;
+
+	token = strtokx(NULL, whitespace, NULL, NULL,
+					0, false, false, pset.encoding);
+
+	if (!token)
+		goto error;
+
+	result->from_file = pg_strdup(token);
+
+	/* no whitespace separators: take the rest of the line, up to a ';' */
+	token = strtokx(NULL, "", ";", NULL,
+					0, false, false, pset.encoding);
+
+	if (!token)
+		goto error;
+
+	result->after_from = pg_strdup(token);
+
+	return result;
+
+error:
+	if (token)
+		pg_log_error("\\create_function: parse error at \"%s\"", token);
+	else
+		pg_log_error("\\create_function: parse error at end of line");
+
+	free_create_function_options(result);
+
+	return NULL;
+}
+
+/*
+ * Append the entire contents of "filename" to "body".
+ *
+ * Returns false, after logging an error, if the file cannot be opened or
+ * read, or if it contains a zero byte: the query is passed on as a C
+ * string, so anything after such a byte would be silently dropped.
+ */
+static bool
+read_function_body(const char *filename, PQExpBuffer body)
+{
+	FILE	   *file;
+	char		buf[8192];
+	size_t		nread;
+	bool		success = true;
+
+	file = fopen(filename, PG_BINARY_R);
+	if (!file)
+	{
+		pg_log_error("%s: %m", filename);
+		return false;
+	}
+
+	while ((nread = fread(buf, 1, sizeof(buf), file)) > 0)
+		appendBinaryPQExpBuffer(body, buf, nread);
+
+	if (ferror(file))
+	{
+		pg_log_error("could not read from file \"%s\": %m", filename);
+		success = false;
+	}
+	else if (PQExpBufferBroken(body))
+	{
+		pg_log_error("out of memory");
+		success = false;
+	}
+	else if (strlen(body->data) != body->len)
+	{
+		pg_log_error("file \"%s\" contains a null byte", filename);
+		success = false;
+	}
+
+	fclose(file);
+	return success;
+}
+
+/*
+ * \create_function FROM <file> <signature and options>
+ *
+ * Sends "CREATE OR REPLACE FUNCTION <signature and options> AS <body>",
+ * where <body> is the contents of <file>, dollar-quoted with a delimiter
+ * that does not occur inside it.  Returns true if the command succeeded.
+ */
+bool
+do_create_function(const char *args)
+{
+	struct create_function_options *options;
+	PQExpBufferData body;
+	PQExpBufferData query;
+	bool		success;
+
+	options = parse_slash_create_function(args);
+	if (!options)
+		return false;
+
+	expand_tilde(&options->from_file);
+	canonicalize_path(options->from_file);
+
+	initPQExpBuffer(&body);
+	if (!read_function_body(options->from_file, &body))
+	{
+		termPQExpBuffer(&body);
+		free_create_function_options(options);
+		return false;
+	}
+
+	initPQExpBuffer(&query);
+	appendPQExpBuffer(&query, "CREATE OR REPLACE FUNCTION %s AS ",
+					  options->after_from);
+	/* a fixed delimiter such as $___$ would break if the body contained it */
+	appendStringLiteralDQ(&query, body.data, "function");
+	appendPQExpBufferChar(&query, ';');
+
+	if (PQExpBufferBroken(&query))
+	{
+		pg_log_error("out of memory");
+		success = false;
+	}
+	else
+		success = SendQuery(query.data);
+
+	termPQExpBuffer(&query);
+	termPQExpBuffer(&body);
+	free_create_function_options(options);
+
+	return success;
+}
diff --git a/src/bin/psql/create_function.h b/src/bin/psql/create_function.h
new file mode 100644
index 0000000000..c7aaec0bd7
--- /dev/null
+++ b/src/bin/psql/create_function.h
@@ -0,0 +1,17 @@
+/*
+ * psql - the PostgreSQL interactive terminal
+ *
+ * Copyright (c) 2000-2024, PostgreSQL Global Development Group
+ *
+ * src/bin/psql/create_function.h
+ */
+#ifndef CREATE_FUNCTION_H
+#define CREATE_FUNCTION_H
+
+/*
+ * Run \create_function.  "args" is the remainder of the command line after
+ * the command name (may be NULL).  Returns false on failure.
+ */
+extern bool do_create_function(const char *args);
+
+#endif							/* CREATE_FUNCTION_H */
diff --git a/src/bin/psql/meson.build b/src/bin/psql/meson.build
index f3a6392138..c3ef115ed1 100644
--- a/src/bin/psql/meson.build
+++ b/src/bin/psql/meson.build
@@ -4,6 +4,7 @@ psql_sources = files(
   'command.c',
   'common.c',
   'copy.c',
+  'create_function.c',
   'crosstabview.c',
   'describe.c',
   'help.c',
diff --git a/src/bin/psql/tab-complete.c b/src/bin/psql/tab-complete.c
index ada711d02f..d3b79684db 100644
--- a/src/bin/psql/tab-complete.c
+++ b/src/bin/psql/tab-complete.c
@@ -1714,7 +1714,7 @@ psql_completion(const char *text, int start, int end)
 		"\\a",
 		"\\bind",
 		"\\connect", "\\conninfo", "\\C", "\\cd", "\\copy",
-		"\\copyright", "\\crosstabview",
+		"\\copyright", "\\create_function", "\\crosstabview",
 		"\\d", "\\da", "\\dA", "\\dAc", "\\dAf", "\\dAo", "\\dAp",
 		"\\db", "\\dc", "\\dconfig", "\\dC", "\\dd", "\\ddp", "\\dD",
 		"\\des", "\\det", "\\deu", "\\dew", "\\dE", "\\df",
@@ -2913,6 +2913,21 @@ psql_completion(const char *text, int start, int end)
 	else if (Matches("COPY|\\copy", MatchAny, "FROM", MatchAny, "WITH", MatchAny))
 		COMPLETE_WITH("WHERE");
 
+/* \create_function */
+
+	else if (Matches("\\create_function"))
+		COMPLETE_WITH("FROM");
+	else if (Matches("\\create_function", "FROM"))
+	{
+		completion_charp = "";
+		completion_force_quote = false;
+		matches = rl_completion_matches(text, complete_from_files);
+	}
+	else if (Matches("\\create_function", "FROM", MatchAny))
+	{
+		COMPLETE_WITH_VERSIONED_SCHEMA_QUERY(Query_for_list_of_functions);
+	}
+
 	/* CREATE ACCESS METHOD */
 	/* Complete "CREATE ACCESS METHOD <name>" */
 	else if (Matches("CREATE", "ACCESS", "METHOD", MatchAny))
diff --git a/src/pl/plperl/data/max.pl b/src/pl/plperl/data/max.pl
new file mode 100644
index 0000000000..351b05e8bd
--- /dev/null
+++ b/src/pl/plperl/data/max.pl
@@ -0,0 +1,2 @@
+if ($_[0] > $_[1]) { return $_[0]; }
+return $_[1];
diff --git a/src/pl/plperl/expected/plperl.out b/src/pl/plperl/expected/plperl.out
index e3d7c8896a..9ec9a798ca 100644
--- a/src/pl/plperl/expected/plperl.out
+++ b/src/pl/plperl/expected/plperl.out
@@ -792,3 +792,10 @@ SELECT self_modify(42);
          126
 (1 row)
 
+\create_function from ./data/max.pl max(int,int) returns int LANGUAGE plperl
+select max(11, 22);
+ max 
+-----
+  22
+(1 row)
+
diff --git a/src/pl/plperl/sql/plperl.sql b/src/pl/plperl/sql/plperl.sql
index bb0b8ce4cb..9833934bcb 100644
--- a/src/pl/plperl/sql/plperl.sql
+++ b/src/pl/plperl/sql/plperl.sql
@@ -521,3 +521,7 @@ $$ LANGUAGE plperl;
 
 SELECT self_modify(42);
 SELECT self_modify(42);
+
+\create_function from ./data/max.pl max(int,int) returns int LANGUAGE plperl
+
+select max(11, 22);
diff --git a/src/pl/plpython/data/max.py b/src/pl/plpython/data/max.py
new file mode 100644
index 0000000000..108bfd10a6
--- /dev/null
+++ b/src/pl/plpython/data/max.py
@@ -0,0 +1,3 @@
+if args[0] > args[1]:
+    return args[0]
+return args[1]
diff --git a/src/pl/plpython/expected/plpython_test.out b/src/pl/plpython/expected/plpython_test.out
index 13c14119c0..28b698b09f 100644
--- a/src/pl/plpython/expected/plpython_test.out
+++ b/src/pl/plpython/expected/plpython_test.out
@@ -91,3 +91,10 @@ CONTEXT:  Traceback (most recent call last):
   PL/Python function "elog_test_basic", line 10, in <module>
     plpy.error('error')
 PL/Python function "elog_test_basic"
+\create_function from ./data/max.py max(int,int) returns int LANGUAGE plpython3u
+select max(11, 22);
+ max 
+-----
+  22
+(1 row)
+
diff --git a/src/pl/plpython/sql/plpython_test.sql b/src/pl/plpython/sql/plpython_test.sql
index aa22a27415..d860abee0a 100644
--- a/src/pl/plpython/sql/plpython_test.sql
+++ b/src/pl/plpython/sql/plpython_test.sql
@@ -50,3 +50,7 @@ plpy.error('error')
 $$ LANGUAGE plpython3u;
 
 SELECT elog_test_basic();
+
+\create_function from ./data/max.py max(int,int) returns int LANGUAGE plpython3u
+
+select max(11, 22);
diff --git a/src/pl/tcl/data/max.tcl b/src/pl/tcl/data/max.tcl
new file mode 100644
index 0000000000..292f77de56
--- /dev/null
+++ b/src/pl/tcl/data/max.tcl
@@ -0,0 +1,2 @@
+if {$1 > $2} {return $1}
+return $2
diff --git a/src/pl/tcl/expected/pltcl_setup.out b/src/pl/tcl/expected/pltcl_setup.out
index a8fdcf3125..cb45f1c78a 100644
--- a/src/pl/tcl/expected/pltcl_setup.out
+++ b/src/pl/tcl/expected/pltcl_setup.out
@@ -261,3 +261,10 @@ if {$1 == "t"} {
 }
 elog NOTICE "end of function"
 $function$;
+\create_function from ./data/max.tcl max(int,int) returns int LANGUAGE pltcl
+select max(11, 22);
+ max 
+-----
+  22
+(1 row)
+
diff --git a/src/pl/tcl/sql/pltcl_setup.sql b/src/pl/tcl/sql/pltcl_setup.sql
index b9892ea4f7..f80e2fee30 100644
--- a/src/pl/tcl/sql/pltcl_setup.sql
+++ b/src/pl/tcl/sql/pltcl_setup.sql
@@ -276,3 +276,7 @@ if {$1 == "t"} {
 }
 elog NOTICE "end of function"
 $function$;
+
+\create_function from ./data/max.tcl max(int,int) returns int LANGUAGE pltcl
+
+select max(11, 22);
diff --git a/src/test/regress/data/max.sql b/src/test/regress/data/max.sql
new file mode 100644
index 0000000000..cdd823900f
--- /dev/null
+++ b/src/test/regress/data/max.sql
@@ -0,0 +1 @@
+select max(x) from unnest(ARRAY[$1, $2]) x;
diff --git a/src/test/regress/expected/create_function_sql.out b/src/test/regress/expected/create_function_sql.out
index 50aca5940f..3a9ef050b1 100644
--- a/src/test/regress/expected/create_function_sql.out
+++ b/src/test/regress/expected/create_function_sql.out
@@ -706,9 +706,16 @@ LINE 2:     AS 'SELECT $2;';
 CREATE FUNCTION test1 (int) RETURNS int LANGUAGE SQL
     AS 'a', 'b';
 ERROR:  only one AS item needed for language "sql"
+\create_function from ./data/max.sql max(int,int) returns int LANGUAGE SQL
+select max(11, 22);
+ max 
+-----
+  22
+(1 row)
+
 -- Cleanup
 DROP SCHEMA temp_func_test CASCADE;
-NOTICE:  drop cascades to 30 other objects
+NOTICE:  drop cascades to 31 other objects
 DETAIL:  drop cascades to function functest_a_1(text,date)
 drop cascades to function functest_a_2(text[])
 drop cascades to function functest_a_3()
@@ -739,5 +746,6 @@ drop cascades to function voidtest3(integer)
 drop cascades to function voidtest4(integer)
 drop cascades to function voidtest5(integer)
 drop cascades to function double_append(anyarray,anyelement)
+drop cascades to function max(integer,integer)
 DROP USER regress_unpriv_user;
 RESET search_path;
diff --git a/src/test/regress/sql/create_function_sql.sql b/src/test/regress/sql/create_function_sql.sql
index 89e9af3a49..f69143aa4c 100644
--- a/src/test/regress/sql/create_function_sql.sql
+++ b/src/test/regress/sql/create_function_sql.sql
@@ -415,6 +415,10 @@ CREATE FUNCTION test1 (int) RETURNS int LANGUAGE SQL
 CREATE FUNCTION test1 (int) RETURNS int LANGUAGE SQL
     AS 'a', 'b';
 
+\create_function from ./data/max.sql max(int,int) returns int LANGUAGE SQL
+
+select max(11, 22);
+
 -- Cleanup
 DROP SCHEMA temp_func_test CASCADE;
 DROP USER regress_unpriv_user;
-- 
2.40.1

