#include <string.h>

#include <cerrno>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <thread>
#include <vector>

#include <zlib.h>

/* Branch-prediction hints for GCC/Clang: each yields the value of x (as a
 * long) while telling the compiler which outcome is expected. */
#define likely(x)   __builtin_expect((x), 1)
#define unlikely(x) __builtin_expect((x), 0)

/* Size in bytes of the shared buffer that every critical section fills and
 * checksums. The buffer below is sized from this constant so the two can
 * never drift apart. */
constexpr size_t data_block_size = 64 * 1024;
/* Shared scratch buffer; workload_execute() writes and reads it only while
 * holding the spin-lock. */
char data_block[data_block_size];

/* Number of lock/compute/unlock iterations each worker thread performs. */
size_t k_iter = 100;
/* Running sum of per-iteration CRCs; updated only under the spin-lock. */
unsigned long global_crcval = 0;

/* ---------------- spin-lock ---------------- */
/* Lock word: 0 = free, 1 = held. It has static storage duration, so it is
 * zero-initialized and the lock starts out free. */
uint32_t lock_unit;

/*
 * Acquire the spin-lock by busy-waiting on lock_unit.
 *
 * The atomic primitive is selected at compile time:
 *   -DCAS     : __atomic_compare_exchange_n (0 -> 1) with acquire ordering.
 *   otherwise : __sync_lock_test_and_set (TAS), an acquire barrier.
 * TAS is the default when CAS is not defined, matching the default used by
 * unlock(), so lock() and unlock() always use the same primitive family.
 */
void lock()
{
  while (true) {
#if defined(CAS)
    /* A failed CAS overwrites `expected` with the current value, so it must
     * be reset to 0 before every attempt. */
    uint32_t expected = 0;
    const bool acquired = __atomic_compare_exchange_n(&lock_unit, &expected, 1, false,
                                                      __ATOMIC_ACQUIRE, __ATOMIC_ACQUIRE);
#else
    /* TAS returns the previous value: 0 means the lock was free and is now ours. */
    const bool acquired = __sync_lock_test_and_set(&lock_unit, 1) == 0;
#endif
    if (acquired) {
      break;
    }
    /* Spin-and-Retry as the lock is held by some other thread; the empty asm
     * with a "memory" clobber acts as a compiler barrier. */
    __asm__ __volatile__("" ::: "memory");
  }
}

/*
 * Release the spin-lock. Uses the release primitive that pairs with the one
 * lock() selected: an __ATOMIC_RELEASE store for CAS, __sync_lock_release
 * (which writes 0) for the default TAS build.
 */
void unlock()
{
#if defined(CAS)
  __atomic_store_n(&lock_unit, 0, __ATOMIC_RELEASE);
#else
  __sync_lock_release(&lock_unit);
#endif
}

/* ---------------- workload ---------------- */
void workload_execute()
{
  for (size_t i = 0; i < k_iter; ++i) {
    lock();
    /* Each thread try to take lock -> execute critical section -> unlock */
    memset(data_block, rand() % 255, data_block_size);
    unsigned long crcval = 0;
    crc32(crcval, (const unsigned char *)data_block, data_block_size);
    global_crcval += crcval;
    unlock();
  }
}

/*
 * Benchmark driver: spawn N threads running workload_execute(), join them
 * all, and print the wall-clock time taken.
 *
 * argv[1] : number of worker threads; must be a positive decimal integer.
 * Returns 0 on success, 1 on wrong argument count or an invalid thread count.
 */
int main(int argc, char *argv[]) {
  if (argc != 2) {
    std::cerr << "usage: <program> <number-of-threads/parallelism>" << std::endl;
    return 1;
  }

  /* Parse strictly. Reject:
   *   - non-numeric text and trailing junk;
   *   - out-of-range values;
   *   - non-positive counts, since a negative value would otherwise wrap to a
   *     huge size_t. */
  char *end = nullptr;
  errno = 0;
  const long requested = std::strtol(argv[1], &end, 10);
  if (end == argv[1] || *end != '\0' || errno == ERANGE || requested <= 0) {
    std::cerr << "invalid number of threads: " << argv[1] << std::endl;
    return 1;
  }
  const size_t num_of_threads = static_cast<size_t>(requested);

  /* The vector owns the threads (no manual new/delete). Reserving up front
   * keeps the allocation outside the timed region. */
  std::vector<std::thread> handles;
  handles.reserve(num_of_threads);

  auto start = std::chrono::high_resolution_clock::now();

  for (size_t i = 0; i < num_of_threads; ++i) {
    handles.emplace_back(workload_execute);
  }
  for (std::thread &worker : handles) {
    worker.join();
  }

  auto finish = std::chrono::high_resolution_clock::now();

  std::chrono::duration<double> elapsed = finish - start;
  std::cout << "Elapsed time: " << elapsed.count() << " s\n";
}
