// Build: CFLAGS="-O3" make test
//
// PostgreSQL's shared_buffers setting (counted in buffers) has a maximum of
// 1073741823, a minimum of 16, and a default of 16384.
//
// Example run: ./test 16 31 64 100 128 512 1023 1024 16384 1073741823

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <time.h>
#include <pthread.h>
#include <assert.h>

// Minimal stand-in for PostgreSQL's 64-bit atomic, built on the GCC/Clang
// __atomic builtins. Reads and additions are sequentially consistent.
typedef struct {
    volatile uint64_t value;
} pg_atomic_uint64;

// Store an initial value. This is a plain (non-atomic) store, so presumably
// it should only be used while no other thread can touch *ptr.
static inline void
pg_atomic_init_u64(pg_atomic_uint64 *ptr, uint64_t val) {
    volatile uint64_t *slot = &ptr->value;
    *slot = val;
}

// Atomically load the current value.
static inline uint64_t
pg_atomic_read_u64(pg_atomic_uint64 *ptr) {
    volatile uint64_t *slot = &ptr->value;
    uint64_t snapshot = __atomic_load_n(slot, __ATOMIC_SEQ_CST);
    return snapshot;
}

// Atomically add add_ (wrapping modulo 2^64) and return the value held
// immediately before the addition.
static inline uint64_t
pg_atomic_fetch_add_u64(pg_atomic_uint64 *ptr, uint64_t add_) {
    volatile uint64_t *slot = &ptr->value;
    uint64_t previous = __atomic_fetch_add(slot, add_, __ATOMIC_SEQ_CST);
    return previous;
}

// State for one clock-sweep hand over `size` slots.
//
// ClockSweepTick() advances `counter` by 1 per call; the slot is derived from
// the counter's low 32 bits with `mask` (power-of-two sizes) or with
// fast_mod() and `magic` (all other sizes).
typedef struct {
    pg_atomic_uint64 counter;  // total ticks; updated atomically, may be shared by threads
    uint32_t size;             // number of slots
    uint32_t mask;             // size - 1 for power-of-2 sizes, 0 otherwise
    uint64_t magic;            // ceil(2^32 / size) for non-power-of-2 sizes, 0 otherwise
    int use_mask;              // 1: slot = counter & mask; 0: slot via fast_mod()
} ClockSweep;

// Compute n % divisor without a hardware divide, given the precomputed
// reciprocal magic = ceil(2^32 / divisor) (the value ClockSweepInit stores).
//
// Because magic rounds 2^32 / divisor *up*, the estimated quotient
// floor(n * magic / 2^32) is never too small, and since n < 2^32 it is at most
// one too large. When it is one too large, n - quotient * divisor is negative,
// and adding divisor back gives the exact remainder. The subtraction is done
// in signed 64-bit arithmetic so the sign test is valid for every divisor in
// [1, UINT32_MAX]. (The previous version let the value wrap in uint32_t and
// then subtracted divisor, which produced wrong results for large divisors.)
//
//   n       - dividend, any 32-bit value
//   divisor - must be non-zero
//   magic   - ((1ULL << 32) + divisor - 1) / divisor
// Returns n % divisor.
static inline uint32_t
fast_mod(uint32_t n, uint32_t divisor, uint64_t magic) {
    // n < 2^32 and magic <= 2^32, so the product fits in 64 bits.
    uint32_t quotient = (uint32_t)(((uint64_t)n * magic) >> 32);

    // quotient * divisor <= n + divisor < 2^33, so int64_t cannot overflow.
    int64_t remainder = (int64_t)n - (int64_t)quotient * (int64_t)divisor;

    // Quotient was overestimated by exactly one: undo it.
    if (remainder < 0)
	remainder += divisor;

    return (uint32_t)remainder;
}

// Initialize a sweep over `size` slots with its counter at zero.
//
// Power-of-two sizes map the counter to a slot with a bit mask; every other
// size uses fast_mod() with magic = ceil(2^32 / size).
//
// size must be non-zero. For size 0, (0 & (0 - 1)) == 0 would otherwise select
// the mask path with mask 0xFFFFFFFF (unbounded slots), and
// ClockSweepCycles() would divide by zero.
static void
ClockSweepInit(ClockSweep *sweep, uint32_t size) {
    assert(size > 0);

    pg_atomic_init_u64(&sweep->counter, 0);
    sweep->size = size;

    if ((size & (size - 1)) == 0) {
	// Power of 2: slot = counter & (size - 1).
	sweep->mask = size - 1;
	sweep->magic = 0;  // unused on this path
	sweep->use_mask = 1;
    } else {
	// Non-power of 2: magic = ceil(2^32 / size) for fast_mod().
	sweep->mask = 0;   // unused on this path
	sweep->magic = ((1ULL << 32) + size - 1) / size;
	sweep->use_mask = 0;
    }
}

// Report the slot the hand currently points at, without moving it.
//
// Only the low 32 bits of the 64-bit counter feed the mapping.
// NOTE(review): for non-power-of-two sizes this makes the reported slot jump
// whenever the counter crosses a multiple of 2^32.
static inline uint32_t
ClockSweepPosition(ClockSweep *sweep) {
    uint32_t low_bits = (uint32_t)pg_atomic_read_u64(&sweep->counter);

    if (!sweep->use_mask)
	return fast_mod(low_bits, sweep->size, sweep->magic);

    return low_bits & sweep->mask;
}

// Advance the hand by one slot and return the slot it pointed at *before*
// the advance (the fetch-and-add yields the pre-increment counter value), so
// the first tick on a freshly initialized sweep returns 0.
//
// Only the low 32 bits of the counter feed the mapping.
// NOTE(review): for non-power-of-two sizes the returned sequence is therefore
// discontinuous each time the counter crosses a multiple of 2^32.
static inline uint32_t
ClockSweepTick(ClockSweep *sweep) {
    uint32_t low_bits = (uint32_t)pg_atomic_fetch_add_u64(&sweep->counter, 1);

    return sweep->use_mask
	? low_bits & sweep->mask
	: fast_mod(low_bits, sweep->size, sweep->magic);
}

// Number of complete revolutions so far: the full 64-bit tick counter divided
// by the slot count (integer division, remainder discarded).
static inline uint64_t
ClockSweepCycles(ClockSweep *sweep) {
    const uint64_t ticks = pg_atomic_read_u64(&sweep->counter);
    const uint64_t slots = sweep->size;

    return ticks / slots;
}

// Reference sweep that maps the counter to a slot with the plain '%'
// operator; benchmark_performance times it against ClockSweep.
typedef struct {
    pg_atomic_uint64 counter;
    uint32_t size;
} ClockSweepBaseline;

// Start the reference sweep at counter 0 over `size` slots.
static void
ClockSweepBaselineInit(ClockSweepBaseline *sweep, uint32_t size) {
    sweep->size = size;
    pg_atomic_init_u64(&sweep->counter, 0);
}

// Bump the counter and return the slot for the pre-increment value, computed
// from the counter's low 32 bits with a hardware modulo.
static inline uint32_t
ClockSweepBaselineTick(ClockSweepBaseline *sweep) {
    const uint32_t low_bits =
	(uint32_t)pg_atomic_fetch_add_u64(&sweep->counter, 1);

    return low_bits % sweep->size;
}

// Per-thread arguments and output buffer for thread_test_function().
typedef struct {
    ClockSweep *sweep;      // sweep ticked by the worker (test_multithreaded shares one)
    int thread_id;          // worker index; not read by thread_test_function
    int iterations;         // number of ClockSweepTick() calls to make
    uint32_t *results;      // receives the first result_count returned slots
    int result_count;       // how many slots to record (capacity of results)
} ThreadData;

// pthread entry point: tick the shared sweep `iterations` times and record
// the first `result_count` returned slots into `results`. Always returns NULL.
static void *
thread_test_function(void *arg) {
    ThreadData *td = arg;
    int recorded = 0;

    for (int n = 0; n < td->iterations; n++) {
	uint32_t slot = ClockSweepTick(td->sweep);
	if (recorded < td->result_count)
	    td->results[recorded++] = slot;
    }

    return NULL;
}

// Check fast_mod(n, divisor, magic) against the '%' operator for every n in
// [0, 3 * divisor) (capped at the 32-bit range) and for a few larger probe
// values. Prints FAIL and returns at the first mismatch; prints PASS if every
// check succeeds. divisor must be non-zero.
static void
verify_magic_mod(uint32_t divisor, uint64_t magic) {
    // uint64_t is not necessarily unsigned long long; passing it to %llu
    // without a cast is undefined behavior.
    printf("Testing divisor %u with magic %llu:\n", divisor,
	   (unsigned long long)magic);

    // Test the first few multiples of divisor. The bound is computed in 64
    // bits because divisor * 3 wraps in uint32_t once divisor exceeds
    // UINT32_MAX / 3, which would silently shorten this loop.
    uint64_t limit = 3 * (uint64_t)divisor;
    if (limit > (uint64_t)UINT32_MAX + 1)
	limit = (uint64_t)UINT32_MAX + 1;

    for (uint64_t i = 0; i < limit; i++) {
	uint32_t n = (uint32_t)i;
	uint32_t expected = n % divisor;
	uint32_t actual = fast_mod(n, divisor, magic);

	if (actual != expected) {
	    printf("FAIL: %u %% %u = %u, got %u\n", n, divisor, expected, actual);
	    return;
	}
    }

    // Test some larger values. Large dividends are where the quotient
    // estimate is most likely to be one too high.
    static const uint32_t test_values[] = {
	1000, 10000, 100000, 1000000,
	0x80000000u, 0xAAAAAAAAu, 0xFFFFFFFFu
    };
    for (size_t i = 0; i < sizeof test_values / sizeof test_values[0]; i++) {
	uint32_t n = test_values[i];
	uint32_t expected = n % divisor;
	uint32_t actual = fast_mod(n, divisor, magic);

	if (actual != expected) {
	    printf("FAIL: %u %% %u = %u, got %u\n", n, divisor, expected, actual);
	    return;
	}
    }

    printf("PASS: All tests passed for divisor %u\n", divisor);
}

// For each size: initialize a sweep, check fast_mod() against '%' (for
// non-power-of-two sizes), tick through two full revolutions checking every
// returned slot, and check that ClockSweepCycles() then reports 2.
//
// ClockSweepTick() returns the slot for the counter value *before* the
// increment, so tick number j (0-based) must return j % size. (The previous
// expectation, (j + 1) % size, failed on the very first tick.)
// On the first failure a FAIL line is printed and the remaining sizes are
// skipped.
static void
test_basic_functionality(uint32_t *test_sizes, int num_sizes) {
    printf("=== Basic Functionality Test ===\n");

    for (int i = 0; i < num_sizes; i++) {
	uint32_t size = test_sizes[i];
	ClockSweep sweep;
	ClockSweepInit(&sweep, size);

	// Cast: uint64_t need not be unsigned long long (%llu).
	printf("Testing size %u (use_mask=%d, magic=%llu):\n",
	       size, sweep.use_mask, (unsigned long long)sweep.magic);

	// Verify magic constant if not using mask
	if (!sweep.use_mask) {
	    verify_magic_mod(size, sweep.magic);
	}

	// 64-bit bound: size * 2 would wrap in uint32_t for size > 2^31.
	const uint64_t total_ticks = 2 * (uint64_t)size;
	for (uint64_t j = 0; j < total_ticks; j++) {
	    uint32_t pos = ClockSweepTick(&sweep);
	    // The sweep maps the low 32 bits of the counter to a slot.
	    uint32_t expected = (uint32_t)j % size;

	    if (pos != expected) {
		printf("FAIL: At iteration %llu, expected %u, got %u\n",
		       (unsigned long long)j, expected, pos);
		return;
	    }
	}

	// Two revolutions were made above.
	uint64_t cycles = ClockSweepCycles(&sweep);
	if (cycles != 2) {
	    printf("FAIL: Expected 2 cycles, got %llu\n",
		   (unsigned long long)cycles);
	    return;
	}

	printf("PASS: Size %u\n", size);
    }

    printf("=== Basic Functionality Test PASSED ===\n\n");
}

// Check ClockSweepPosition() for counters at and beyond 2^32. The position is
// computed from the low 32 bits of the counter, so the expected slot is
// (uint32_t)counter % size. Covers a small non-power-of-two size and
// PostgreSQL's maximum shared_buffers (1073741823, whose magic constant is 5).
// On the first mismatch a FAIL line is printed and the test stops.
static void
test_edge_cases(void) {
    printf("=== Edge Cases Test ===\n");

    static const uint32_t sizes[] = { 1023, 1073741823 };
    static const uint64_t large_values[] = {
	0xFFFFFFFFULL,
	0x100000000ULL,
	0x123456789ABCDEFULL,
	// For size 1073741823 the magic-multiply quotient for this value is
	// one too high (3 instead of 2), exercising the remainder correction.
	0xAAAAAAAAULL,
    };

    for (size_t s = 0; s < sizeof sizes / sizeof sizes[0]; s++) {
	uint32_t size = sizes[s];
	ClockSweep sweep;
	ClockSweepInit(&sweep, size);

	for (size_t i = 0; i < sizeof large_values / sizeof large_values[0]; i++) {
	    pg_atomic_init_u64(&sweep.counter, large_values[i]);

	    uint32_t pos = ClockSweepPosition(&sweep);
	    uint32_t expected = (uint32_t)large_values[i] % size;

	    // Casts: uint64_t need not be unsigned long long (%llu).
	    if (pos != expected) {
		printf("FAIL: Size %u, counter %llu: expected %u, got %u\n",
		       size, (unsigned long long)large_values[i], expected, pos);
		return;
	    }

	    printf("PASS: Size %u, large counter %llu -> position %u\n",
		   size, (unsigned long long)large_values[i], pos);
	}
    }

    printf("=== Edge Cases Test PASSED ===\n\n");
}

// Run NUM_THREADS workers that each tick one shared 16384-slot sweep
// ITERATIONS_PER_THREAD times. Then check that no increment was lost (the
// final counter equals the total number of ticks) and that every recorded
// slot is in range. Prints FAIL on the first problem and PASS otherwise.
//
// All result buffers are released in a single cleanup pass. The previous
// version freed each buffer after validating it and then freed *all* buffers
// on a later failure, which was a double free.
static void
test_multithreaded(void) {
    enum {
	NUM_THREADS = 4,
	ITERATIONS_PER_THREAD = 10000,
	RESULTS_PER_THREAD = 1000,
    };
    const uint32_t size = 16384;  // Default size

    printf("=== Multi-threaded Test ===\n");

    ClockSweep sweep;
    ClockSweepInit(&sweep, size);

    pthread_t threads[NUM_THREADS];
    ThreadData thread_data[NUM_THREADS] = {{0}};  // results stay NULL until allocated
    const uint64_t expected_counter =
	(uint64_t)NUM_THREADS * ITERATIONS_PER_THREAD;
    uint64_t final_counter = 0;
    int started = 0;
    int setup_ok = 1;

    // Create threads
    for (int i = 0; i < NUM_THREADS; i++) {
	thread_data[i].sweep = &sweep;
	thread_data[i].thread_id = i;
	thread_data[i].iterations = ITERATIONS_PER_THREAD;
	thread_data[i].result_count = RESULTS_PER_THREAD;
	thread_data[i].results =
	    malloc(RESULTS_PER_THREAD * sizeof *thread_data[i].results);
	if (thread_data[i].results == NULL) {
	    printf("FAIL: Out of memory for thread %d results\n", i);
	    setup_ok = 0;
	    break;
	}

	int rc = pthread_create(&threads[i], NULL, thread_test_function,
				&thread_data[i]);
	if (rc != 0) {
	    // threads[i] is not a valid thread on failure; never join it.
	    printf("FAIL: pthread_create failed for thread %d (error %d)\n", i, rc);
	    setup_ok = 0;
	    break;
	}
	started++;
    }

    // Join every thread that was actually started, even after a setup
    // failure, so none is still using the local `sweep` when we return.
    for (int i = 0; i < started; i++) {
	pthread_join(threads[i], NULL);
    }

    if (!setup_ok)
	goto cleanup;

    // Verify final counter value (casts: uint64_t need not be unsigned long long)
    final_counter = pg_atomic_read_u64(&sweep.counter);
    if (final_counter != expected_counter) {
	printf("FAIL: Expected final counter %llu, got %llu\n",
	       (unsigned long long)expected_counter,
	       (unsigned long long)final_counter);
	goto cleanup;
    }

    // Verify all positions are valid
    for (int i = 0; i < NUM_THREADS; i++) {
	for (int j = 0; j < thread_data[i].result_count; j++) {
	    uint32_t pos = thread_data[i].results[j];
	    if (pos >= size) {
		printf("FAIL: Thread %d got invalid position %u (size=%u)\n",
		       i, pos, size);
		goto cleanup;
	    }
	}
    }

    printf("PASS: Multi-threaded test with %d threads, %d iterations each\n",
	   NUM_THREADS, ITERATIONS_PER_THREAD);
    printf("Final counter: %llu\n", (unsigned long long)final_counter);
    printf("=== Multi-threaded Test PASSED ===\n\n");

cleanup:
    // Each buffer is freed exactly once; free(NULL) is a no-op for entries
    // whose allocation failed or was never attempted.
    for (int i = 0; i < NUM_THREADS; i++) {
	free(thread_data[i].results);
    }
}

// Compare single-threaded tick throughput of ClockSweepTick() (mask for
// power-of-two sizes, fast_mod() otherwise) against ClockSweepBaselineTick()
// (the '%' operator) for each size, printing one row per method. Times are
// processor time from clock(); "Speedup" is optimized rate / baseline rate.
static void
benchmark_performance(uint32_t *test_sizes, int num_sizes) {
    const int iterations = 10000000;
    // Each timed loop folds the returned slots into a checksum and stores it
    // here. Without an observable use, the optimizer may discard the slot
    // computations being compared, and only the atomic increment would be
    // timed.
    static volatile uint32_t sink;

    printf("=== Performance Benchmark ===\n");
    printf("Comparing optimized ClockSweep vs standard modulo operator\n\n");

    printf("Size     Method      Ops/sec (M)   Speedup\n");
    printf("----     ------      -----------   -------\n");

    for (int i = 0; i < num_sizes; i++) {
	uint32_t size = test_sizes[i];
	uint32_t checksum = 0;

	// Test optimized version
	ClockSweep sweep;
	ClockSweepInit(&sweep, size);

	clock_t start = clock();
	for (int j = 0; j < iterations; j++) {
	    checksum += ClockSweepTick(&sweep);
	}
	clock_t end = clock();
	sink = checksum;

	double elapsed_opt = ((double)(end - start)) / CLOCKS_PER_SEC;
	double ops_per_sec_opt = iterations / elapsed_opt;

	// Test baseline modulo version
	ClockSweepBaseline baseline;
	ClockSweepBaselineInit(&baseline, size);

	checksum = 0;
	start = clock();
	for (int j = 0; j < iterations; j++) {
	    checksum += ClockSweepBaselineTick(&baseline);
	}
	end = clock();
	sink = checksum;

	double elapsed_mod = ((double)(end - start)) / CLOCKS_PER_SEC;
	double ops_per_sec_mod = iterations / elapsed_mod;

	double speedup = ops_per_sec_opt / ops_per_sec_mod;

	printf("%4u     %-6s      %8.2f      %5.2fx\n",
	       size,
	       sweep.use_mask ? "mask" : "magic",
	       ops_per_sec_opt / 1000000.0,
	       speedup);

	printf("%4u     modulo      %8.2f      %5.2fx\n",
	       size,
	       ops_per_sec_mod / 1000000.0,
	       1.0);

	printf("\n");
    }

    printf("=== Performance Benchmark COMPLETED ===\n\n");
}

// Entry point: parse each argument as a sweep size, run the functional tests
// on those sizes (plus the fixed-size edge-case and multi-threaded tests),
// then benchmark the same sizes.
//
// Each argument must be a plain decimal integer in [1, 4294967295]. On bad
// input a message goes to stderr and the program exits with EXIT_FAILURE.
// The individual tests print their own PASS/FAIL lines and do not report back
// here, so the closing summary does not claim overall success.
int main(int argc, char *argv[]) {
    clock_t start_time = clock();

    int num_sizes = argc > 1 ? argc - 1 : 0;
    // +1 so a zero-byte allocation is never requested.
    uint32_t *sizes = malloc(sizeof *sizes * ((size_t)num_sizes + 1));
    if (sizes == NULL) {
	fprintf(stderr, "out of memory\n");
	return EXIT_FAILURE;
    }

    for (int i = 0; i < num_sizes; i++) {
	const char *arg = argv[i + 1];
	char *end = NULL;
	unsigned long long value = 0;

	// Unlike atoi, strtoull reports where parsing stopped, so trailing junk
	// is detectable. Requiring a leading digit rejects empty strings and
	// signs (strtoull would silently negate "-5" into a huge value).
	// Overflow yields ULLONG_MAX, which the range check rejects. Size 0 is
	// rejected because it would later cause a division by zero.
	if (arg[0] >= '0' && arg[0] <= '9')
	    value = strtoull(arg, &end, 10);

	if (end == NULL || *end != '\0' || value == 0 || value > UINT32_MAX) {
	    fprintf(stderr,
		    "invalid size '%s': expected an integer in [1, 4294967295]\n",
		    arg);
	    free(sizes);
	    return EXIT_FAILURE;
	}
	sizes[i] = (uint32_t)value;
    }

    test_basic_functionality(sizes, num_sizes);
    test_edge_cases();
    test_multithreaded();
    benchmark_performance(sizes, num_sizes);

    free(sizes);

    clock_t end_time = clock();
    double total_time = ((double)(end_time - start_time)) / CLOCKS_PER_SEC;

    printf("All tests completed (check the PASS/FAIL lines above).\n");
    printf("Total execution time: %.2f seconds\n", total_time);

    return 0;
}
