#include "postgres.h"


/*
 * LinearInterval's alternative definition for environments without
 * int128 arithmetic.  See interval_cmp_value for details.
 *
 * The two members together represent one signed 128-bit quantity whose
 * value is hi * 2^64 + lo.
 */
typedef struct
{
	uint64		lo;				/* holds the lower 64 bits without sign */
	int64		hi;				/* holds the most significant 64 bits,
								 * including the sign bit */
} LinearInterval;

/*
 * Overlay of a native int128 and the two-word LinearInterval, so that the
 * same 16 bytes can be viewed either way.
 *
 * NOTE(review): LinearInterval declares .lo before .hi, so .lo/.hi line up
 * with the low/high halves of .i128 only where the low-order 64 bits of an
 * int128 are stored first (little-endian layout) -- confirm before relying
 * on this on other platforms.
 */
typedef union LI
{
	int128		i128;			/* native 128-bit view */
	LinearInterval li;			/* two 64-bit words view */
} LI;


/*
 * INT64_AU32 extracts the most significant 32 bits of int64 as int64, while
 * INT64_AL32 extracts the least significant 32 bits as uint64.
 *
 * INT64_AU32 right-shifts a possibly negative value, so it yields the
 * sign-extended upper half only if the compiler implements >> on signed
 * operands as an arithmetic shift; interval_times() verifies that with a
 * static assertion.
 */
#define INT64_AU32(i64) ((i64) >> 32)
#define INT64_AL32(i64) ((i64) & UINT64CONST(0xFFFFFFFF))

/*
 * Add an unsigned 64-bit value into a LinearInterval variable.
 *
 * The addend is added to the .lo word with ordinary unsigned (modulo 2^64)
 * arithmetic.  The sum wrapped around -- that is, a carry has to be
 * propagated into the .hi word -- exactly when the new .lo is smaller than
 * the old one.
 *
 * "li" must be an lvalue with .lo and .hi members and is evaluated more
 * than once; "v" is evaluated once and converted to uint64.
 */
#define LINEARINTERVAL_ADD_UINT64(li, v) \
do { \
	uint64		addend = (uint64) (v); \
	uint64		before = (li).lo; \
	uint64		after = before + addend; \
	(li).lo = after; \
	if (after < before) \
		(li).hi++; \
} while (0)

/*
 * interval_times
 *		Compute the full signed 128-bit product of two int64 values using
 *		only 64-bit arithmetic, tracing each partial term on stdout.
 *
 * x, y: the signed 64-bit multiplicands.
 *
 * Returns the product as a LinearInterval: .hi holds the upper 64 bits
 * (including the sign) and .lo the lower 64 bits as an unsigned quantity.
 * The result is {0, 0}, with nothing traced, when either input is zero.
 *
 * Traced values are cast to unsigned long long and printed with %llX so
 * that the output does not depend on int64/uint64 being typedefs of "long".
 */
static inline LinearInterval
interval_times(int64 x, int64 y)
{
	LinearInterval span = {0, 0};

	/*----------
	 * Form the 128-bit product x * y using 64-bit arithmetic.
	 * Considering each 64-bit input as having 32-bit high and low parts,
	 * we can compute
	 *
	 *	 x * y = ((x.hi << 32) + x.lo) * ((y.hi << 32) + y.lo)
	 *		   = (x.hi * y.hi) << 64 +
	 *			 (x.hi * y.lo) << 32 +
	 *			 (x.lo * y.hi) << 32 +
	 *			 x.lo * y.lo
	 *
	 * Each individual product is of 32-bit terms so it won't overflow when
	 * computed in 64-bit arithmetic.  Then we just have to shift it to the
	 * correct position while adding into the 128-bit result.  We must also
	 * keep in mind that the "lo" parts must be treated as unsigned.
	 *----------
	 */

	/* INT64_AU32 must use arithmetic right shift */
	StaticAssertStmt(((int64) -1 >> 1) == (int64) -1,
					 "arithmetic right shift is needed");

	/* No need to work hard if product must be zero */
	if (x != 0 && y != 0)
	{
		int64		x_u32 = INT64_AU32(x);
		uint64		x_l32 = INT64_AL32(x);
		int64		y_u32 = INT64_AU32(y);
		uint64		y_l32 = INT64_AL32(y);
		int64		tmp;
		uint64		low_product;

		/* the first term: lands entirely in the high word */
		span.hi = x_u32 * y_u32;
		printf("first term  = %016llX\n", (unsigned long long) span.hi);

		/*
		 * the second term: sign-extend it only if x is negative.  Its upper
		 * 32 bits go into .hi; its lower 32 bits are shifted into the top
		 * half of .lo, with any carry propagated by the macro.
		 */
		tmp = x_u32 * y_l32;
		printf("second term =         %016llX\n", (unsigned long long) tmp);
		if (x < 0)
			span.hi += INT64_AU32(tmp);
		else
			span.hi += ((uint64) tmp) >> 32;
		LINEARINTERVAL_ADD_UINT64(span, ((uint64) INT64_AL32(tmp)) << 32);
		printf("partial sum = %016llX%016llX\n",
			   (unsigned long long) span.hi, (unsigned long long) span.lo);

		/* the third term: sign-extend it only if y is negative */
		tmp = x_l32 * y_u32;
		printf("third term  =         %016llX\n", (unsigned long long) tmp);
		if (y < 0)
			span.hi += INT64_AU32(tmp);
		else
			span.hi += ((uint64) tmp) >> 32;
		LINEARINTERVAL_ADD_UINT64(span, ((uint64) INT64_AL32(tmp)) << 32);
		printf("partial sum = %016llX%016llX\n",
			   (unsigned long long) span.hi, (unsigned long long) span.lo);

		/* the fourth term: always unsigned; compute it once for both uses */
		low_product = x_l32 * y_l32;
		printf("fourth term =                 %016llX\n",
			   (unsigned long long) low_product);
		LINEARINTERVAL_ADD_UINT64(span, low_product);
	}

	return span;
}

int
main(int argc, char **argv)
{
	int64		x = strtol(argv[1], NULL, 0);
	int64		y = strtol(argv[2], NULL, 0);
	LI			li;
	LI			li2;

	printf("%lX * %lX\n", x, y);

	li.li = interval_times(x, y);

	printf("result      = %016lX%016lX\n", li.li.hi, li.li.lo);
	printf("result = %ld %lu\n", li.li.hi, li.li.lo);

	li2.i128 = (int128) x * (int128) y;

	if (li.li.hi != li2.li.hi || li.li.lo != li2.li.lo)
	{
		printf("MISMATCH!\n");
		printf("result      = %016lX%016lX\n", li2.li.hi, li2.li.lo);
		printf("result = %ld %lu\n", li2.li.hi, li2.li.lo);
	}

	return 0;
}
