#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

#include <x86intrin.h>

/*#define DEBUG1*/

/*
 * Short type aliases and helper macros used by the code below.
 * NOTE(review): these presumably mirror names expected by
 * "lib/sort_template.h" so that it compiles in this standalone file --
 * confirm against that header.
 */
typedef uint8_t uint8;
typedef uint64_t uint64;
/* Datum is a 64-bit unsigned integer here; SIZEOF_DATUM must stay in sync. */
typedef uint64_t Datum;
typedef int64_t int64;
/* Width of Datum in bytes; used by extract_key() to locate a byte. */
#define SIZEOF_DATUM 8
/* Element count of a true array (gives a wrong answer for a pointer). */
#define lengthof(array) (sizeof (array) / sizeof ((array)[0]))
/* Assertions are compiled out: the argument is discarded, not evaluated. */
#define Assert(x)
/* Paste two tokens together. */
#define CppConcat(x, y)                 x##y
/* Expands to nothing in this file. */
#define CHECK_FOR_INTERRUPTS()
/* Smaller of x and y; each argument may be evaluated twice. */
#define Min(x, y)		((x) < (y) ? (x) : (y))

/*
 * pg_noinline: ask the compiler not to inline a function.  Expands to the
 * GCC attribute or the MSVC declspec, and to nothing on other compilers.
 */
#if defined(__GNUC__)
#define pg_noinline __attribute__((noinline))
/* MSVC spells it with __declspec */
#elif defined(_MSC_VER)
#define pg_noinline __declspec(noinline)
#else
#define pg_noinline
#endif

/*
 * pg_attribute_always_inline: force inlining where the compiler allows it.
 * The GCC branch is taken only when __OPTIMIZE__ is also defined; otherwise
 * the later branches apply (plain "inline" unless _MSC_VER is defined).
 */
#if defined(__GNUC__) && defined(__OPTIMIZE__)
/* GCC supports always_inline via __attribute__ */
#define pg_attribute_always_inline __attribute__((always_inline)) inline
#elif defined(_MSC_VER)
/* MSVC has a special keyword for this */
#define pg_attribute_always_inline __forceinline
#else
/* Otherwise, the best we can do is to say "inline" */
#define pg_attribute_always_inline inline
#endif


/*
 * One sortable element.
 *
 * The anonymous union overlays cond_datum1 on the storage of isnull1 and
 * srctape, so writing cond_datum1 overwrites those two fields and vice
 * versa.  The comparator and the radix code in this file read only
 * cond_datum1.
 */
typedef struct
{
	void	   *tuple;			/* the tuple itself */
	Datum		datum1;			/* value of first key column */
	union
	{
		struct
		{
			bool		isnull1;		/* is first key column NULL? */
			int			srctape;		/* source tape number */
		};
		/* conditioned sort key, compared as a plain unsigned integer */
		Datum		cond_datum1;
	};
} SortTuple;


/*
 * Three-way comparison of two tuples on their conditioned datum.
 *
 * The conditioned value is compared as a plain unsigned integer, so NULLs
 * and sort direction need no special handling here.  Returns -1 if a sorts
 * before b, 1 if a sorts after b, and 0 if the keys are equal.  No tiebreak
 * comparator is consulted for equal keys.
 */
static pg_attribute_always_inline int
qsort_tuple_conditioned_compare(SortTuple *a, SortTuple *b)
{
	const Datum lhs = a->cond_datum1;
	const Datum rhs = b->cond_datum1;

	/* each relational test yields 0 or 1, so the difference is -1, 0 or 1 */
	return (lhs > rhs) - (lhs < rhs);
}

/*
 * Instantiate the sort template for SortTuple.  The generated function is
 * named qsort_tuple_conditioned and is called below as
 * qsort_tuple_conditioned(array, nelements); elements are ordered by
 * qsort_tuple_conditioned_compare().
 *
 * NOTE(review): the exact meaning of ST_CHECK_FOR_INTERRUPTS, ST_SCOPE and
 * ST_DEFINE is defined by lib/sort_template.h, which is not part of this
 * file -- presumably they enable interrupt checks, set the linkage and
 * request the function definition respectively; confirm there.
 */
#define ST_SORT qsort_tuple_conditioned
#define ST_ELEMENT_TYPE SortTuple
#define ST_COMPARE(a, b) qsort_tuple_conditioned_compare(a, b)
#define ST_CHECK_FOR_INTERRUPTS
#define ST_SCOPE static
#define ST_DEFINE
#include "lib/sort_template.h"


/*
 * Bookkeeping for one of the 256 byte-value partitions in ska_byte_sort().
 *
 * count and offset share storage: the field first accumulates the number of
 * tuples whose key byte selects this partition, and is then overwritten
 * with the index (relative to the start of the array) of the next slot to
 * fill in the partition.
 */
typedef struct PartitionInfo
{
    union
    {
        size_t count;           /* tuples in this partition (counting phase) */
        size_t offset;          /* next unfilled slot (permutation phase) */
    };
    size_t next_offset;         /* index one past the partition's last slot */
} PartitionInfo;

/*
 * Return one byte of "key".  Level 0 selects the most significant byte and
 * level SIZEOF_DATUM - 1 the least significant one.
 */
static inline uint8_t
extract_key(Datum key, int level)
{
	const int	shift = (SIZEOF_DATUM - 1 - level) * 8;

	return (uint8_t) ((key >> shift) & 0xFF);
}

/*
 * Exchange the contents of two tuples.  Safe when a and b point at the same
 * tuple, in which case it is left unchanged.
 */
static inline void
swap(SortTuple * a, SortTuple * b)
{
	const SortTuple saved_b = *b;

	*b = *a;
	*a = saved_b;
}


/*
 * ska_byte_sort
 *
 * Partition the tuples in [begin, end) in place by a single byte of
 * cond_datum1.  "level" selects the byte the same way extract_key() does:
 * 0 is the most significant byte and SIZEOF_DATUM - 1 the least significant
 * one; it must lie in that range.
 *
 * On return the tuples are grouped in ascending order of the selected byte.
 * There is no recursion into the individual partitions, so the array is
 * fully sorted only if the keys differ in nothing but that byte.
 */
static void
pg_noinline
ska_byte_sort(SortTuple *begin,
			  SortTuple *end, int level)
{
	size_t		counts1[256] = {0};
	size_t		counts2[256] = {0};
	size_t		counts3[256] = {0};
	PartitionInfo partitions[256] = {0};
	uint8_t		remaining_partitions[256] = {0};
	size_t		total = 0;
	int			num_partitions = 0;
	int			num_remaining;
	SortTuple  *ctup;

	/*
	 * Count key chunks, unrolled for speed.  Four independent counter arrays
	 * are used so that consecutive increments do not depend on each other.
	 *
	 * The bound is expressed as a pointer difference: forming "ctup + 4"
	 * when fewer than four tuples remain would create a pointer beyond
	 * one-past-the-end of the array, which is undefined behavior.
	 */
	for (ctup = begin; end - ctup >= 4; ctup += 4)
	{
		uint8		key_chunk0 = extract_key((ctup + 0)->cond_datum1, level);
		uint8		key_chunk1 = extract_key((ctup + 1)->cond_datum1, level);
		uint8		key_chunk2 = extract_key((ctup + 2)->cond_datum1, level);
		uint8		key_chunk3 = extract_key((ctup + 3)->cond_datum1, level);

		partitions[key_chunk0].count++;
		counts1[key_chunk1]++;
		counts2[key_chunk2]++;
		counts3[key_chunk3]++;
	}

	/* fold the auxiliary counters into the per-partition counts */
	for (size_t i = 0; i < 256; i++)
		partitions[i].count += counts1[i] + counts2[i] + counts3[i];

	/* count the (at most three) tuples left over by the unrolled loop */
	for (; ctup < end; ctup++)
	{
		uint8		key_chunk = extract_key(ctup->cond_datum1, level);

		partitions[key_chunk].count++;
	}

	/*
	 * Compute partition offsets.  "count" and "offset" share storage, so the
	 * count is read before it is replaced by the partition's start index.
	 * Non-empty partitions are recorded in remaining_partitions.
	 */
	for (int i = 0; i < 256; ++i)
	{
		size_t		count = partitions[i].count;

		if (count)
		{
			partitions[i].offset = total;
			total += count;
			remaining_partitions[num_partitions] = (uint8_t) i;
			++num_partitions;
		}
		partitions[i].next_offset = total;
	}

	num_remaining = num_partitions;

	/*
	 * Permute tuples to their correct partition.  If we started with one
	 * partition, there is nothing to do.  If a pass leaves at most one
	 * partition unfilled, that partition can only contain its own tuples, so
	 * we are done as well.
	 */
	while (num_remaining > 1)
	{
		/*
		 * We can only exit the loop when all partitions are filled, so the
		 * counter must be reset on every pass.
		 */
		num_remaining = num_partitions;

		for (int i = 0; i < num_partitions; i++)
		{
			uint8		idx = remaining_partitions[i];
			SortTuple  *part_end = begin + partitions[idx].next_offset;

			/*
			 * Walk the not-yet-filled tail of this partition as it stands
			 * now, moving each tuple found there to the next free slot of
			 * the partition it belongs to.  The start is evaluated once, so
			 * slots filled during the walk do not shorten it.
			 */
			for (SortTuple *st = begin + partitions[idx].offset;
				 st < part_end;
				 st++)
			{
				uint8		this_partition = extract_key(st->cond_datum1, level);
				size_t		offset = partitions[this_partition].offset++;

				Assert(begin + offset < end);
				swap(st, begin + offset);
			}

			/*
			 * Test the live offsets rather than a snapshot taken before the
			 * walk, so a partition that has just been filled is recognized
			 * immediately instead of one pass later.
			 */
			if (partitions[idx].offset == partitions[idx].next_offset)
				num_remaining--;
		}
	}
	/* no recursion */
}



/*
 * Benchmark driver.
 *
 * Fills a master array of COUNT tuples whose cond_datum1 is
 * random() % CARDINALITY, then, for each prefix length in lengths[],
 * repeatedly copies that prefix into scratch buffers and times
 * qsort_tuple_conditioned() against ska_byte_sort() with __rdtsc().  For
 * each length it prints the average tick count per element of both.
 *
 * Returns EXIT_SUCCESS, or EXIT_FAILURE if the buffers cannot be allocated.
 */
int
main(void)
{
	uint64_t start;
	uint64_t finish;
	double qticks;
	double rticks;
#define COUNT 8000
	/* every length must be <= COUNT, the size of the buffers below */
	int lengths[] = { 100, 200, 400, 800, 1600, 3200, 6400 };
	SortTuple *st = malloc(COUNT * sizeof(SortTuple));
	SortTuple *test_radix = malloc(COUNT * sizeof(SortTuple));
	SortTuple *test_qsort = malloc(COUNT * sizeof(SortTuple));

	if (st == NULL || test_radix == NULL || test_qsort == NULL)
	{
		fprintf(stderr, "out of memory\n");
		free(st);
		free(test_radix);
		free(test_qsort);
		return EXIT_FAILURE;
	}

// 256 or less so that all entropy is in a single byte
#define CARDINALITY 256
	printf("cardinality: %d\n", CARDINALITY);
	for (int i = 0; i < COUNT; i++)
	{
		// only lowest byte is populated
		int64 val = random() % CARDINALITY;
		SortTuple x = { .cond_datum1 = val };

		st[i] = x;
	}

	for (size_t j = 0; j < lengthof(lengths); j++)
	{
		int len = lengths[j];
		size_t nbytes = (size_t) len * sizeof(SortTuple);

		qticks = rticks = 0;

		printf("number of elements: %4d   ", len);

#define NUM_MEASUREMENTS 1000000
		for (int k = 0; k < NUM_MEASUREMENTS; k++)
		{
			// repopulate test
			memcpy(test_qsort, st, nbytes);

			start = __rdtsc();
			// only sort lowest byte
			qsort_tuple_conditioned(test_qsort, len);
			finish = __rdtsc();
			qticks += finish - start;

			memcpy(test_radix, st, nbytes);

#ifdef DEBUG1
			printf("before:\n");
			for (int i = 0; i < len; i++)
				printf("%llu\n", (unsigned long long) test_radix[i].cond_datum1);
#endif

			start = __rdtsc();
			// level 7 is the least significant byte, the only one populated
			ska_byte_sort(test_radix, test_radix + len, 7);
			finish = __rdtsc();
			rticks += finish - start;
		}

		/* average ticks per element over all measurements */
		printf("qsort: %04.1f radix: %04.1f\n", qticks / NUM_MEASUREMENTS / len,
											rticks / NUM_MEASUREMENTS / len);

#ifdef DEBUG1
		/* shows the radix result of the last measurement only */
		printf("after:\n");
		for (int i = 0; i < len; i++)
			printf("%llu\n", (unsigned long long) test_radix[i].cond_datum1);
#endif
	}

	free(test_qsort);
	free(test_radix);
	free(st);

	return EXIT_SUCCESS;
}
