/*
 * riscv-popcnt.c
 *
 * RISC-V Zbb popcount optimization
 *
 *   gcc -O2 -o popcnt-wo-zbb riscv-popcnt.c
 *   gcc -O2 -march=rv64gc_zbb -o popcnt-zbb riscv-popcnt.c
 */

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <time.h>

#define TEST_SIZE (1024 * 1024)  /* 1 MB */
#define ITERATIONS 100

/*
 * Portable bit-counting routine (per the original note, taken from
 * pg_bitutils.h).
 *
 * Counts the set bits of a 64-bit word by summing neighbouring bit fields
 * in parallel: first 1-bit fields into 2-bit sums, then 2-bit into 4-bit,
 * then 4-bit into per-byte sums.  The final multiplication adds all eight
 * byte sums together into the most significant byte, which is then shifted
 * down to form the result.
 *
 * Returns the number of 1 bits in x (0..64).
 */
static int
popcount_sw(uint64_t x)
{
	const uint64_t pairs = 0x5555555555555555ULL;	/* every other bit */
	const uint64_t quads = 0x3333333333333333ULL;	/* low 2 bits of each nibble */
	const uint64_t nibbles = 0x0F0F0F0F0F0F0F0FULL; /* low nibble of each byte */
	const uint64_t byte_ones = 0x0101010101010101ULL;	/* 1 in each byte */
	uint64_t	lo;
	uint64_t	hi;

	/* 2-bit partial sums */
	lo = x & pairs;
	hi = (x >> 1) & pairs;
	x = lo + hi;

	/* 4-bit partial sums */
	lo = x & quads;
	hi = (x >> 2) & quads;
	x = lo + hi;

	/* 8-bit partial sums */
	lo = x & nibbles;
	hi = (x >> 4) & nibbles;
	x = lo + hi;

	/* horizontal add of the eight bytes; the total lands in bits 56..63 */
	x *= byte_ones;

	return (int) (x >> 56);
}

/*
 * Popcount via the compiler builtin.
 *
 * The intent is that, when built with Zbb enabled (see the build commands
 * in the file header), the compiler lowers the builtin to the cpop
 * instruction; without Zbb the compiler picks its own fallback.
 *
 * Returns the number of 1 bits in x.
 */
static int
popcount_hw(uint64_t x)
{
	const unsigned long long word = x;

	return __builtin_popcountll(word);
}

/*
 * Read CLOCK_MONOTONIC and return the reading as fractional seconds.
 *
 * Only differences between two return values are meaningful to the
 * callers in this file (they subtract a start reading from an end one).
 * The return value of clock_gettime() is not checked.
 */
static double
now(void)
{
	struct timespec tp;
	double		whole_secs;
	double		frac_secs;

	clock_gettime(CLOCK_MONOTONIC, &tp);

	whole_secs = (double) tp.tv_sec;
	frac_secs = (double) tp.tv_nsec / 1e9;	/* nanoseconds -> seconds */

	return whole_secs + frac_secs;
}

/*
 * Benchmark driver.
 *
 * Fills a TEST_SIZE-byte buffer with pseudo-random 64-bit words (fixed seed),
 * then makes ITERATIONS passes over it with popcount_sw() and again with
 * popcount_hw(), printing the elapsed time and throughput of each, followed
 * by the ratio of the two timings.  The two bit totals are compared as a
 * sanity check.
 *
 * Returns EXIT_SUCCESS when both totals agree; EXIT_FAILURE if the buffer
 * cannot be allocated or the totals differ.
 */
int
main(void)
{
	const size_t nwords = TEST_SIZE / sizeof(uint64_t);
	uint64_t   *data;
	uint64_t	count_sw = 0,
				count_hw = 0;
	double		start,
				elapsed_sw,
				elapsed_hw;
	double		mb_per_sec;
	size_t		i;
	int			status = EXIT_SUCCESS;

	data = malloc(nwords * sizeof *data);
	if (data == NULL)
	{
		fprintf(stderr, "could not allocate %zu bytes for test data\n",
				nwords * sizeof *data);
		return EXIT_FAILURE;
	}

	srand(42);

	for (i = 0; i < nwords; i++)
	{
		/*
		 * Draw the two halves in separate statements: the order in which the
		 * operands of '|' are evaluated is unspecified, so calling rand()
		 * twice in one expression would make the generated data depend on
		 * the compiler.
		 */
		uint64_t	hi = (uint64_t) rand();
		uint64_t	lo = (uint64_t) rand();

		data[i] = (hi << 32) | lo;
	}

	/* software implementation */
	start = now();
	for (int iter = 0; iter < ITERATIONS; iter++)
	{
		for (i = 0; i < nwords; i++)
			count_sw += popcount_sw(data[i]);
	}
	elapsed_sw = now() - start;
	mb_per_sec = (TEST_SIZE * ITERATIONS / (1024.0 * 1024.0)) / elapsed_sw;
	printf("sw popcount: %8.3f sec  (%10.2f MB/s)\n",
	       elapsed_sw, mb_per_sec);

	/* compiler builtin */
	start = now();
	for (int iter = 0; iter < ITERATIONS; iter++)
	{
		for (i = 0; i < nwords; i++)
			count_hw += popcount_hw(data[i]);
	}
	elapsed_hw = now() - start;
	mb_per_sec = (TEST_SIZE * ITERATIONS / (1024.0 * 1024.0)) / elapsed_hw;
	printf("hw popcount: %8.3f sec  (%10.2f MB/s)\n",
	       elapsed_hw, mb_per_sec);

	/* a value above 1 means the builtin version finished faster */
	printf("\ndiff: %.2fx\n", elapsed_sw / elapsed_hw);

	if (count_sw != count_hw)
	{
		printf("\n[ERROR] Results don't match!\n");
		printf("\tsw: %llu\n", (unsigned long long)count_sw);
		printf("\thw: %llu\n", (unsigned long long)count_hw);
		/* make the failure visible to scripts, not just in the output */
		status = EXIT_FAILURE;
	}
	else
	{
		printf("match: %llu bits counted\n", (unsigned long long)count_sw);
	}

	free(data);
	return status;
}
