#include "postgres.h"

#include "fmgr.h"
#include "miscadmin.h"
#include "pgstat.h"
#include "portability/instr_time.h"
#include "port/atomics.h"
#include "storage/ipc.h"
#include "storage/latch.h"
#include "storage/lwlock.h"
#include "storage/procsignal.h"
#include "storage/s_lock.h"
#include "storage/shmem.h"
#include "storage/spin.h"
#include "utils/elog.h"

/* Module magic block, emitted only when the server headers define it. */
#if defined(PG_MODULE_MAGIC)
PG_MODULE_MAGIC;
#endif

/*
 * One shared-memory slot through which a signalled backend hands its user
 * OID back to the backend that asked for it.
 */
typedef struct UserIdSlot
{
	slock_t		mutex;			/* spinlock taken around slot reads/writes */
	Oid			userid;			/* OID written by the signalled backend */
	bool		isSet;			/* true once userid holds a reply */
	Latch	   *callerLatch;	/* requester's latch, set when a reply is stored */
} UserIdSlot;

/* Bytes requested via RequestAddinShmemSpace() and passed to ShmemInitStruct(). */
const Size UserIdSlotSize = BUFFERALIGN(sizeof(UserIdSlot));

/* ProcSignal reason returned by RegisterCustomProcSignalHandler() in _PG_init. */
static ProcSignalReason extractEffectiveUserReason;
/* Shared request/reply slot; stays NULL until foo_shmem_startup() attaches it. */
static UserIdSlot *userIdSlot = NULL;
/* shmem_startup_hook that was installed before ours; chained and restored. */
static shmem_startup_hook_type prev_shmem_startup_hook = NULL;

/* Use real prototypes: an empty "()" declarator leaves arguments unchecked. */
static void foo_shmem_startup(void);
static void sendEffectiveUserId(ProcSignalReason reason);
static Oid	extractEffectiveUserId(pid_t remoteSessionId);

/*
 * _PG_init
 *		Library load callback.
 *
 * Registers sendEffectiveUserId() as a custom ProcSignal handler.  Only when
 * that succeeds does it request shared memory for the slot and install
 * foo_shmem_startup() as the shmem_startup_hook; otherwise it warns and
 * leaves everything else untouched.
 */
void
_PG_init(void)
{
	elog(LOG, "Load. My PID = %d", MyProcPid);

	extractEffectiveUserReason = RegisterCustomProcSignalHandler(sendEffectiveUserId);

	if (extractEffectiveUserReason != INVALID_PROCSIGNAL)
	{
		RequestAddinShmemSpace(UserIdSlotSize);

		/* Remember the previous hook so foo_shmem_startup() can chain to it. */
		prev_shmem_startup_hook = shmem_startup_hook;
		shmem_startup_hook = foo_shmem_startup;
	}
	else
		elog(WARNING, "Insufficient custom ProcSignal slots");
}

void
foo_shmem_startup()
{
	bool	found;

    elog(LOG, "Stand out shmem. My PID = %d", MyProcPid);

	userIdSlot = ShmemInitStruct("foo userid slot", UserIdSlotSize, &found);

	if (prev_shmem_startup_hook)
		prev_shmem_startup_hook();
}

void
_PG_fini()
{
    elog(LOG, "Unload. My PID = %d", MyProcPid);

    UnregisterCustomProcSignal(extractEffectiveUserReason);
	shmem_startup_hook = prev_shmem_startup_hook;
}

PG_FUNCTION_INFO_V1(remote_effective_user);

Datum
remote_effective_user(PG_FUNCTION_ARGS)
{
    pid_t	pid = PG_GETARG_INT32(0);

    PG_RETURN_INT32(extractEffectiveUserId(pid));
}

/*
 * extractEffectiveUserId
 *		Ask the backend whose PID is remoteSessionId for its current user OID.
 *
 * Publishes this backend's latch in the shared slot and sends the custom
 * ProcSignal to the target.  It then waits up to 5000 ms for the target's
 * sendEffectiveUserId() to fill the slot and set our latch.
 *
 * Returns the reported OID.  Returns InvalidOid, after a WARNING, when:
 *  - the shared slot is not attached,
 *  - the signal cannot be sent, or
 *  - no reply arrives in time.
 * Exits the process if the postmaster dies while we wait.
 *
 * NOTE(review): there is a single shared slot.  Concurrent callers in
 * different backends overwrite each other's request, and a late reply to a
 * timed-out request can be taken as the answer to a later one.  Fixing that
 * needs a request counter or per-backend slots.
 */
Oid
extractEffectiveUserId(pid_t remoteSessionId)
{
	Oid			result = InvalidOid;
	int			sendSignalStatus;
	long		timeout = 5000;		/* remaining wait budget, in ms */
	int			rc = 0;

	/*
	 * Only foo_shmem_startup() attaches the slot.  Without it (library not
	 * preloaded, or _PG_init bailed out) the slot is NULL and using it would
	 * crash the backend.
	 */
	if (userIdSlot == NULL)
	{
		elog(WARNING, "Shared user id slot is not initialized");
		return InvalidOid;
	}

	/*
	 * Publish the request under the spinlock, the same way the responder and
	 * the polling loop below access the slot.  Releasing the lock also orders
	 * these stores before the signal is sent.
	 */
	SpinLockAcquire(&userIdSlot->mutex);
	userIdSlot->isSet = false;
	userIdSlot->callerLatch = MyLatch;
	SpinLockRelease(&userIdSlot->mutex);

	sendSignalStatus = SendProcSignal(remoteSessionId,
									  extractEffectiveUserReason,
									  InvalidBackendId);
	if (sendSignalStatus == -1)
	{
		switch (errno)
		{
			case ESRCH:
				elog(WARNING, "Process not found");
				break;
			default:
				elog(WARNING, "Error with sending signal");
				break;
		}
		return InvalidOid;
	}

	for (;;)
	{
		bool		isSet;
		instr_time	start_time;
		instr_time	end_time;

		SpinLockAcquire(&userIdSlot->mutex);
		result = userIdSlot->userid;
		isSet = userIdSlot->isSet;
		SpinLockRelease(&userIdSlot->mutex);

		if (isSet)
			break;
		if ((rc & WL_TIMEOUT) || timeout <= 0)
		{
			elog(WARNING, "Remote session is not retry");
			return InvalidOid;
		}

		INSTR_TIME_SET_CURRENT(start_time);
		rc = WaitLatch(MyLatch,
					   WL_LATCH_SET | WL_TIMEOUT | WL_POSTMASTER_DEATH,
					   timeout, PG_WAIT_EXTENSION);
		INSTR_TIME_SET_CURRENT(end_time);
		INSTR_TIME_SUBTRACT(end_time, start_time);

		/* Don't keep waiting once the postmaster is gone. */
		if (rc & WL_POSTMASTER_DEATH)
			proc_exit(1);

		timeout -= (long) INSTR_TIME_GET_MILLISEC(end_time);

		/* Reset before the slot is re-checked, so a later SetLatch isn't lost. */
		ResetLatch(MyLatch);
		CHECK_FOR_INTERRUPTS();
	}

	return result;
}

/*
 * sendEffectiveUserId
 *		Custom ProcSignal handler: report this backend's user OID.
 *
 * Stores the current user OID in the shared slot, marks it as set, and wakes
 * the requester whose latch was published in the slot.  If no requester latch
 * is present (e.g. a stray signal), the reply is still stored but nobody is
 * woken.
 */
void
sendEffectiveUserId(ProcSignalReason reason)
{
	Oid			userid;
	bool		inSecDefContext;	/* required out-param; unused */
	Latch	   *callerLatch;

	AssertArg(reason == extractEffectiveUserReason);

	elog(LOG, "Extract effective user. My PID = %d", MyProcPid);

	/* No shared slot attached: nowhere to put the answer. */
	if (userIdSlot == NULL)
		return;

	/* Fetch the user id before taking the spinlock; keep the critical section to plain stores. */
	GetUserIdAndContext(&userid, &inSecDefContext);

	SpinLockAcquire(&userIdSlot->mutex);
	userIdSlot->userid = userid;
	userIdSlot->isSet = true;
	callerLatch = userIdSlot->callerLatch;	/* read under the lock, with the reply */
	SpinLockRelease(&userIdSlot->mutex);

	if (callerLatch != NULL)
		SetLatch(callerLatch);
}
