#include "postgres.h"

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>


/*
 * LinearInterval's alternative definition for environments without
 * int128 arithmetic. See interval_cmp_value for details.
 *
 * The pair represents the 128-bit signed value hi * 2^64 + lo: lo is the
 * unsigned low word and hi carries the sign.
 */
typedef struct
{
	uint64	lo; /* holds the lower 64 bits, unsigned (no sign) */
	int64	hi;	/* holds the upper 64 bits, including the sign bit */
} LinearInterval;

/*
 * Overlays a native int128 with a LinearInterval so that the bits of one can
 * be read back as the other.
 *
 * NOTE(review): li.lo / li.hi line up with the low / high halves of i128 only
 * where int128 stores its low 64 bits first (little-endian targets); on a
 * big-endian target the two halves would appear swapped.
 */
typedef union LI
{
	int128 i128;
	LinearInterval li;
} LI;


/*
 * Arithmetic 32-bit extraction from int64
 *
 * INT64_AU32 extracts the significant (upper) 32 bits of an int64 as an
 * int64, rounding toward zero.  INT64_AL32 extracts the remaining
 * non-significant (lower) 32 bits, carrying the sign of the argument.  Both
 * take their sign from the given value, and together they satisfy
 *
 *   i64 = (int64) INT64_AU32(i64) * (2^32) + (int64) INT64_AL32(i64)
 *
 * with -2^32 < (int64) INT64_AL32(i64) < 2^32.  In other words, they are the
 * quotient and remainder of C's truncating division by 2^32.  Using % (rather
 * than masking and OR-ing in sign bits) keeps the invariant exact even when
 * i64 is a negative multiple of 2^32, where the remainder must be 0.
 *
 * INT64_AL32 yields the remainder's two's-complement bit pattern as a uint64.
 * Convert it to int64 to get the signed value.  Keeping it unsigned means
 * that shifting it left stays well-defined even when the remainder is
 * negative.
 */
#define INT64_AU32(i64) ((i64) / (1LL << 32))
#define INT64_AL32(i64) ((uint64) ((i64) % (1LL << 32)))

/*
 * Adds a signed 65-bit integer into the LinearInterval lvalue li.
 *
 * The 65-bit addend is given as its low 64 bits v plus a separate sign.  If s
 * is nonzero, the addend is negative iff s < 0.  If s is zero, v itself
 * (reinterpreted as int64) supplies the sign.  A negative addend is
 * sign-extended by decrementing li.hi, and an unsigned carry out of li.lo
 * increments li.hi.
 *
 * Wrapped in do { } while (0) so that it behaves as a single statement, which
 * keeps it safe in an unbraced if/else.  v and s are each evaluated exactly
 * once.  li is evaluated several times, so it must be a side-effect-free
 * lvalue.
 */
#define LINEARINTERVAL_ADD_INT65(li, v, s) \
	do { \
		int64		li_add65_sign_ = (int64) (s); \
		uint64		li_add65_val_ = (uint64) (v); \
		uint64		li_add65_prev_ = (li).lo; \
		(li).lo += li_add65_val_; \
		if (li_add65_sign_ < 0 || \
			(li_add65_sign_ == 0 && (int64) li_add65_val_ < 0)) \
			(li).hi--; \
		if ((li).lo < li_add65_prev_) \
			(li).hi++; \
	} while (0)

/*
 * Computes the exact 128-bit signed product x * y as a LinearInterval, using
 * only 64-bit arithmetic (for platforms without int128).
 *
 * The operands are first multiplied as unsigned 64-bit numbers, split into
 * 32-bit halves so that every partial product fits in a uint64:
 *
 *   ux * uy = (xh * yh) << 64 + (xh * yl + xl * yh) << 32 + xl * yl
 *
 * Everything is unsigned, so no intermediate step can hit signed-overflow
 * undefined behavior.  The signed product differs from the unsigned one only
 * in its upper 64 bits.  When x < 0, x == ux - 2^64, so the signed product is
 * smaller by 2^64 * uy; therefore uy is subtracted from the high word
 * (mod 2^64).  The same applies symmetrically when y < 0.
 */
static inline LinearInterval
interval_times(int64 x, int64 y)
{
	const uint64 mask32 = 0xFFFFFFFF;
	uint64		ux = (uint64) x;
	uint64		uy = (uint64) y;
	uint64		xl = ux & mask32;
	uint64		xh = ux >> 32;
	uint64		yl = uy & mask32;
	uint64		yh = uy >> 32;

	/* partial products of 32-bit halves; each is < 2^64 */
	uint64		ll = xl * yl;
	uint64		lh = xl * yh;
	uint64		hl = xh * yl;
	uint64		hh = xh * yh;

	/* bits 32..63 of the product plus the carry into bit 64 (< 3 * 2^32) */
	uint64		mid = (ll >> 32) + (lh & mask32) + (hl & mask32);

	uint64		lo = (mid << 32) | (ll & mask32);
	uint64		hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
	LinearInterval span;

	/* turn the unsigned product into the signed one (mod 2^128) */
	if (x < 0)
		hi -= uy;
	if (y < 0)
		hi -= ux;

	span.lo = lo;
	/* assumes two's complement int64, as PostgreSQL does throughout */
	span.hi = (int64) hi;

	return span;
}

int
main(int argc, char **argv)
{
	int64 x = strtol(argv[1], NULL, 0);
	int64 y = strtol(argv[2], NULL, 0);
	LI li;
	LI li2;

	printf("%lX * %lX\n", x, y);

	li.li = interval_times(x, y);

	printf("result = %ld %lu\n", li.li.hi, li.li.lo);
	printf("result = %lX %lX\n", li.li.hi, li.li.lo);

	li2.i128 = (int128) x * (int128) y;

	if (li.li.hi != li2.li.hi || li.li.lo != li2.li.lo)
	{
	    printf("MISMATCH!\n");
	    printf("result = %ld %lu\n", li2.li.hi, li2.li.lo);
	    printf("result = %lX %lX\n", li2.li.hi, li2.li.lo);
	}

	return 0;
}
