/*
 * riscv-des.c
 *
 * Demonstrates Clang RISC-V vector extension bug affecting DES implementation
 * and tests compiler barrier workaround.
 *
 * Clang 20.1.2, and possibly earlier versions, miscompiles scatter/gather
 * write patterns when auto-vectorizing at -O2, causing incorrect DES
 * encryption results.
 *
 * Build and test:
 *   gcc -O2 riscv-des.c -o des-gcc
 *   gcc -O2 -march=rv64gcv riscv-des.c -o des-gcc-vec
 *   clang-20 -O2 riscv-des.c -o des-clang
 *   clang-20 -O2 -march=rv64gcv riscv-des.c -o des-clang-vec
 *   clang-20 -O1 -march=rv64gcv riscv-des.c -o des-clang-o1-vec
 *
 * All GCC-compiled versions should print "PASS".  Clang builds with the
 * vector extension enabled and an optimization level above -O1 fail, because
 * initialization produces the wrong permutation tables.
 */

#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <time.h>

/*
 * MEMORY_BARRIER(): compiler-level barrier.
 *
 * Under Clang this is an inline-asm statement with an empty template and a
 * "memory" clobber.  For every other compiler it expands to a no-op
 * expression.
 */
#if !defined(__clang__)
#define MEMORY_BARRIER() ((void)0)
#else
#define MEMORY_BARRIER() __asm__ volatile("" ::: "memory")
#endif

/*
 * DES P-box permutation (32 bits).
 * Entries are 1-based positions in the range 1..32, each appearing once.
 */
static const uint8_t pbox[32] = {
	16,  7, 20, 21, 29, 12, 28, 17,
	 1, 15, 23, 26,  5, 18, 31, 10,
	 2,  8, 24, 14, 32, 27,  3,  9,
	19, 13, 30,  6, 22, 11,  4, 25
};

/*
 * DES Initial Permutation (64 bits).
 * Entries are 1-based positions in the range 1..64, each appearing once.
 */
static const uint8_t IP[64] = {
	58, 50, 42, 34, 26, 18, 10,  2,
	60, 52, 44, 36, 28, 20, 12,  4,
	62, 54, 46, 38, 30, 22, 14,  6,
	64, 56, 48, 40, 32, 24, 16,  8,
	57, 49, 41, 33, 25, 17,  9,  1,
	59, 51, 43, 35, 27, 19, 11,  3,
	61, 53, 45, 37, 29, 21, 13,  5,
	63, 55, 47, 39, 31, 23, 15,  7
};

/* Tables filled in at run time by the des_init_*() functions (0-based). */
static uint8_t un_pbox[32] = { 0 };		/* inverse of pbox */
static uint8_t init_perm[64] = { 0 };	/* inverse of final_perm */
static uint8_t final_perm[64] = { 0 };	/* IP entries, converted to 0-based */

/*
 * Initialize DES permutation tables.
 * This function contains scatter/gather patterns that trigger the Clang bug.
 *
 * Fills un_pbox, init_perm and final_perm from the constant pbox and IP
 * tables.  The source tables hold 1-based positions while the derived tables
 * are 0-based, hence the "- 1" in both loops.
 *
 * NOTE: this is the reproducer.  The loops are the code pattern under test,
 * so do not restructure them.
 */
static void
des_init_buggy(void)
{
	int i;

	/*
	 * Invert the P-box permutation - BUGGY with Clang -march=rv64gcv
	 *
	 * The store index is read from pbox, so consecutive iterations write to
	 * non-consecutive elements of un_pbox.
	 */
	for (i = 0; i < 32; i++)
		un_pbox[pbox[i] - 1] = i;

	/*
	 * Set up initial & final permutations - BUGGY with Clang -march=rv64gcv
	 *
	 * The inner assignment stores the 0-based IP entry in final_perm[i]; that
	 * same value is then used as the index into init_perm, which receives i.
	 * As a result init_perm[final_perm[i]] == i for every i.
	 */
	for (i = 0; i < 64; i++)
		init_perm[final_perm[i] = IP[i] - 1] = i;
}

/*
 * Initialize DES permutation tables with compiler barriers.
 * This version uses MEMORY_BARRIER() to prevent auto-vectorization.
 *
 * The stores are the same as in des_init_buggy(); the only difference is a
 * MEMORY_BARRIER() after the store statement in each loop iteration.
 * MEMORY_BARRIER() expands to a no-op unless compiling with Clang, so with
 * other compilers the two functions contain the same statements.
 */
static void
des_init_fixed(void)
{
	int i;

	/*
	 * Invert the P-box permutation - with barrier
	 *
	 * pbox holds 1-based positions; un_pbox is indexed from 0.
	 */
	for (i = 0; i < 32; i++)
	{
		un_pbox[pbox[i] - 1] = i;
		MEMORY_BARRIER();
	}

	/*
	 * Set up initial & final permutations - with barriers
	 *
	 * final_perm[i] receives the 0-based IP entry, and that value is then
	 * used as the index into init_perm, which receives i.
	 */
	for (i = 0; i < 64; i++)
	{
		init_perm[final_perm[i] = IP[i] - 1] = i;
		MEMORY_BARRIER();
	}
}

/*
 * Verify that permutation tables are correct.
 *
 * Two checks are run against the tables filled in by des_init_*():
 *   1. un_pbox must match the precomputed 0-based inverse of pbox.
 *   2. init_perm and final_perm must be inverses of each other, i.e.
 *      init_perm[final_perm[i]] == i for every i.  An entry of final_perm
 *      that is not a valid index into init_perm counts as a mismatch and is
 *      never used as an index.
 *
 * Each check prints its own header on its first mismatch, followed by at
 * most MAX_DETAILS detail lines.  Mismatches beyond that are only counted
 * and are summarized in a single trailing line.
 *
 * Returns the total number of mismatches; 0 means all tables are correct.
 */
static int
verify_permutations(void)
{
	enum { MAX_DETAILS = 5 };

	/* Expected un_pbox values: the 0-based inverse of pbox */
	static const uint8_t expected_un_pbox[32] = {
		8, 16, 22, 30, 12, 27, 1, 17,
		23, 15, 29, 5, 25, 19, 9, 0,
		7, 13, 24, 2, 3, 28, 10, 18,
		31, 11, 21, 6, 4, 26, 14, 20
	};

	int pbox_errors = 0;
	int perm_errors = 0;
	int shown = 0;
	int total;
	int i;

	/* Check un_pbox */
	for (i = 0; i < 32; i++)
	{
		if (un_pbox[i] == expected_un_pbox[i])
			continue;

		if (pbox_errors == 0)
			printf("ERROR: un_pbox mismatch:\n");
		if (pbox_errors < MAX_DETAILS)
		{
			printf("\tun_pbox[%d] = %d, expected %d\n",
				i, un_pbox[i], expected_un_pbox[i]);
			shown++;
		}
		pbox_errors++;
	}

	/* Check that init_perm and final_perm are inverses */
	for (i = 0; i < 64; i++)
	{
		unsigned int idx = final_perm[i];
		int in_range = idx < sizeof(init_perm);

		if (in_range && init_perm[idx] == i)
			continue;

		/* The header depends only on this section's own error count. */
		if (perm_errors == 0)
			printf("ERROR: init_perm/final_perm not inverses\n");
		if (perm_errors < MAX_DETAILS)
		{
			if (in_range)
				printf("\tinit_perm[final_perm[%d]] = %d, expected %d\n",
					i, init_perm[idx], i);
			else
				printf("\tfinal_perm[%d] = %u is out of range\n", i, idx);
			shown++;
		}
		perm_errors++;
	}

	total = pbox_errors + perm_errors;
	if (total > shown)
		printf("  ... and %d more errors\n", total - shown);

	return total;
}

/*
 * Time repeated table initialization.
 *
 * Calls init_func `iterations` times, zeroing un_pbox, init_perm and
 * final_perm before each call.  Returns the CLOCK_MONOTONIC time, in
 * seconds, taken by the whole loop (clearing included).
 */
static double
benchmark_init(void (*init_func)(void), int iterations)
{
	struct timespec t_begin;
	struct timespec t_end;
	int remaining = iterations;

	clock_gettime(CLOCK_MONOTONIC, &t_begin);

	while (remaining > 0)
	{
		memset(un_pbox, 0, sizeof(un_pbox));
		memset(init_perm, 0, sizeof(init_perm));
		memset(final_perm, 0, sizeof(final_perm));
		init_func();
		remaining--;
	}

	clock_gettime(CLOCK_MONOTONIC, &t_end);

	/* Whole seconds plus the nanosecond remainder scaled to seconds. */
	return (t_end.tv_sec - t_begin.tv_sec) +
		(t_end.tv_nsec - t_begin.tv_nsec) / 1e9;
}

/*
 * Print compiler and target information to stdout.
 *
 * All values come from predefined compiler macros, so the output is fixed
 * at compile time.  A blank line is printed at the end.
 */
static void
print_compiler_info(void)
{
	/* Clang is tested first: the GCC branch is only reached without it. */
#if defined(__clang__)
	printf("Compiler: Clang %d.%d.%d\n",
		   __clang_major__, __clang_minor__, __clang_patchlevel__);
#elif defined(__GNUC__)
	printf("Compiler: GCC %d.%d.%d\n",
		   __GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__);
#else
	fputs("Compiler: Unknown\n", stdout);
#endif

#if !defined(__riscv)
	fputs("Target: Not RISC-V\n", stdout);
#else
	printf("Target: RISC-V %d-bit\n", __riscv_xlen);
#if defined(__riscv_vector)
	fputs("Vector extension: Enabled (RVV)\n", stdout);
#else
	fputs("Vector extension: Not enabled\n", stdout);
#endif
#endif

	fputs("\n", stdout);
}

/*
 * Program entry point.
 *
 * Runs three phases and reports each on stdout:
 *   1. des_init_buggy() on cleared tables, then verification;
 *   2. des_init_fixed() on cleared tables, then verification;
 *   3. a timing comparison of the two initializers.
 *
 * Always returns 0; verification results are reported as text only.
 */
int
main(void)
{
	const int iterations = 1000000;
	double buggy_time;
	double fixed_time;
	int mismatches;

	print_compiler_info();

	/* Phase 1: initializer without barriers */
	fputs("Testing WITHOUT compiler barriers:\n", stdout);
	memset(un_pbox, 0, sizeof(un_pbox));
	memset(init_perm, 0, sizeof(init_perm));
	memset(final_perm, 0, sizeof(final_perm));
	des_init_buggy();

	mismatches = verify_permutations();
	if (mismatches != 0)
		fputs("FAIL: Permutation tables are incorrect\n", stdout);
	else
		fputs("PASS: Permutation tables are correct\n", stdout);
	fputs("\n", stdout);

	/* Phase 2: initializer with barriers */
	fputs("Testing WITH compiler barriers:\n", stdout);
	memset(un_pbox, 0, sizeof(un_pbox));
	memset(init_perm, 0, sizeof(init_perm));
	memset(final_perm, 0, sizeof(final_perm));
	des_init_fixed();

	mismatches = verify_permutations();
	if (mismatches != 0)
		fputs("FAIL: Permutation tables are incorrect (fix didn't work)\n",
			  stdout);
	else
		fputs("PASS: Permutation tables are correct\n", stdout);
	fputs("\n", stdout);

	/* Phase 3: timing; each result is printed before the next run starts */
	printf("Performance Comparison (%d iterations):\n", iterations);

	buggy_time = benchmark_init(des_init_buggy, iterations);
	printf("Without barriers: %.3f seconds (%.0f ns/iter)\n",
		buggy_time, buggy_time * 1e9 / iterations);

	fixed_time = benchmark_init(des_init_fixed, iterations);
	printf("With barriers:    %.3f seconds (%.0f ns/iter)\n",
		fixed_time, fixed_time * 1e9 / iterations);

	/* Differences of 1% or less are reported as negligible. */
	if (fixed_time > buggy_time * 1.01)
	{
		printf("Overhead: %.1f%%\n",
			(fixed_time / buggy_time - 1.0) * 100.0);
	}
	else
	{
		fputs("Overhead: Negligible\n", stdout);
	}

	return 0;
}
