#define _POSIX_C_SOURCE 199309L
#define _XOPEN_SOURCE 500

#include <float.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/*
 * Baseline width_bucket computation, evaluated exactly as originally written:
 * (count * (oper - low)) / (high - low) + 1, truncated on conversion to int.
 *
 * The difference "high - low" overflows to +inf when "high" and "low" are
 * more than DBL_MAX apart. Likewise, "oper - low" overflows when "oper" and
 * "low" differ by more than DBL_MAX, and the product with "count" can
 * overflow too. Any of these produces inf or NaN. Converting a non-finite
 * or out-of-range double to int is undefined behavior, so the result is
 * meaningless for such inputs.
 */
int width_bucket_orig(double oper, double low, double high, int count) {
  const double numerator = (double)count * (oper - low);
  const double position = numerator / (high - low) + 1;
  return (int)position;
}

/*
 * Do-nothing variant: ignores all of its arguments and always answers
 * bucket 1. It is timed alongside the real variants in the benchmark, so its
 * timing approximates the cost of the loop, argument generation and the call.
 */
int width_bucket_null(double oper, double low, double high, int count) {
  (void)oper;
  (void)low;
  (void)high;
  (void)count;
  return 1;
}

/*
 * Overflow-aware variant 1.
 *
 * When "high - low" overflows, the range and the operand's offset from "low"
 * are recomputed from halved operands. This keeps both finite while leaving
 * their ratio essentially unchanged. The result is count * offset / range + 1,
 * truncated toward zero when converted to int.
 */
int width_bucket_new1(double oper, double low, double high, int count) {
  double range = high - low;
  double offset = oper - low;

  if (isinf(range)) {
    range = high / 2 - low / 2;
    offset = oper / 2 - low / 2;
  }

  return (int)(count * (offset / range) + 1);
}

/*
 * Overflow-aware variant 2.
 *
 * If "high - low" overflows, all three operands are halved up front and the
 * usual formula count * (oper - low) / (high - low) + 1 is then applied.
 * This keeps the differences finite, and the result is truncated toward zero
 * on conversion to int.
 */
int width_bucket_new2(double oper, double low, double high, int count) {
  const int halve = isinf(high - low);
  const double op = halve ? oper / 2 : oper;
  const double lo = halve ? low / 2 : low;
  const double hi = halve ? high / 2 : high;

  return (int)(count * ((op - lo) / (hi - lo)) + 1);
}

/*
 * Overflow-aware variant 3.
 *
 * All operands are halved when "high - low" overflows. The result is then
 * count * (oper - low) / (high - low) truncated to int, with NO +1. It is
 * bumped by one only when it equals "count" while the (possibly halved)
 * "oper" is greater than "high".
 *
 * NOTE(review): for in-range operands this returns one less than the other
 * variants (e.g. 2 instead of 3 for oper=5.35, low=0.024, high=10.06,
 * count=5). Confirm whether this 0-based result is intended.
 */
int width_bucket_new3(double oper, double low, double high, int count) {
  const int halve = isinf(high - low);
  const double op = halve ? oper / 2 : oper;
  const double lo = halve ? low / 2 : low;
  const double hi = halve ? high / 2 : high;

  const int bucket = (int)(count * ((op - lo) / (hi - lo)));
  return (bucket == count && op > hi) ? bucket + 1 : bucket;
}

/*
 * Sample inputs evaluated by main() when it is run without arguments.
 *
 * - Rows spanning [-DBL_MAX, DBL_MAX] make "high - low" overflow to +inf.
 * - DBL_MIN rows use ranges of the smallest normalized magnitude.
 * - The last row has a very wide but finite range.
 *
 * Nothing writes to this table, so it is const.
 */
static const struct value {
  double operand;
  double low;
  double high;
  int count;
} values[] = {
    {10.4, -DBL_MAX, DBL_MAX, 10},
    {-DBL_MAX / 2, -DBL_MAX, DBL_MAX, 10},
    {DBL_MAX / 2, -DBL_MAX, DBL_MAX, 10},
    {10.4, -DBL_MAX, DBL_MAX, 12},
    {-DBL_MAX / 2, -DBL_MAX, DBL_MAX, 12},
    {DBL_MAX / 2, -DBL_MAX, DBL_MAX, 12},
    {10.4, -DBL_MAX, DBL_MAX, 1},
    {-DBL_MAX / 2, -DBL_MAX, DBL_MAX, 1},
    {DBL_MAX / 2, -DBL_MAX, DBL_MAX, 1},
    {10.4, -DBL_MAX, DBL_MAX, 2},
    {-DBL_MAX / 2, -DBL_MAX, DBL_MAX, 2},
    {DBL_MAX / 2, -DBL_MAX, DBL_MAX, 2},
    {5.35, 0.024, 10.06, 5},
    {DBL_MIN, -2 * DBL_MIN, 2 * DBL_MIN, 4},
    {-DBL_MIN, -2 * DBL_MIN, 2 * DBL_MIN, 4},
    {DBL_MIN, -3 * DBL_MIN, 3 * DBL_MIN, 4},
    {-DBL_MIN, -3 * DBL_MIN, 3 * DBL_MIN, 4},
    {DBL_MIN, -3 * DBL_MIN, 3 * DBL_MIN, 6},
    {-DBL_MIN, -3 * DBL_MIN, 3 * DBL_MIN, 6},
    {0, -1e100, 1, 10},
};

/*
 * Print one input tuple followed by the bucket that every computing variant
 * (orig, new1, new2, new3) assigns to it.
 *
 * Each result is labelled with the name of the function that produced it.
 * Previously new1 was printed as "new" and new3 as "new2", and new2 was
 * omitted.
 */
void print_result(double oper, double low, double high, int count) {
  printf(
      "width_bucket(% 2.5e, % 2.5e, % 2.5e, % 2d): orig: % 2d, new1: % 2d, "
      "new2: % 2d, new3: % 2d\n",
      oper,
      low,
      high,
      count,
      width_bucket_orig(oper, low, high, count),
      width_bucket_new1(oper, low, high, count),
      width_bucket_new2(oper, low, high, count),
      width_bucket_new3(oper, low, high, count));
}

/*
 * Time a fixed number of calls to "width_bucket" on pseudo-random arguments
 * and print the elapsed time in milliseconds, right-aligned after "name".
 *
 * - oper, low and high come from drand48(), so each is in [0, 1).
 * - count is random() % 10, so it is in [0, 9].
 * - Generating the arguments is included in the measured time.
 *
 * Returns the sum of all results. It exists only so the calls have a
 * consumer; it is narrowed to int, which is implementation-defined when the
 * sum does not fit.
 */
int measure(const char *name,
            int (*width_bucket)(double, double, double, int)) {
  enum { ITERATIONS = 10000000 };
  struct timespec before, after;
  unsigned long sum = 0;

  /* CLOCK_MONOTONIC is not affected by wall-clock adjustments (NTP,
   * settimeofday), which can make CLOCK_REALTIME intervals jump. */
  clock_gettime(CLOCK_MONOTONIC, &before);
  for (int i = 0; i < ITERATIONS; ++i)
    sum += width_bucket(drand48(), drand48(), drand48(), random() % 10);
  clock_gettime(CLOCK_MONOTONIC, &after);

  const double elapsed_ms = 1000.0 * (double)(after.tv_sec - before.tv_sec) +
                            (double)(after.tv_nsec - before.tv_nsec) / 1e6;
  printf("%20s: %f\n", name, elapsed_ms);
  return (int)sum;
}

/* Print the double limits that frame the overflow discussion. */
static void print_limits(void) {
  printf("DBL_MAX: %e, DBL_MIN: %e\n", DBL_MAX, DBL_MIN);
  printf("2 * DBL_MAX: %e\n", 2 * DBL_MAX);
}

/* Describe every supported command-line form on stderr. */
static void print_usage(void) {
  fprintf(stderr,
          "Usage: <oper> <low> <high> <count>\n"
          "       --benchmark [<repetitions>]\n"
          "       (no arguments: evaluate the built-in sample table)\n");
}

/* Parse all of "text" as a double; return 0 on success, -1 if it is empty or
 * has trailing characters. */
static int parse_double_arg(const char *text, double *out) {
  char *end;
  *out = strtod(text, &end);
  return (end == text || *end != '\0') ? -1 : 0;
}

/* Parse all of "text" as an integer (base detected as by strtol with base 0)
 * that fits in an int; return 0 on success, -1 otherwise. */
static int parse_int_arg(const char *text, int *out) {
  char *end;
  const long value = strtol(text, &end, 0);
  if (end == text || *end != '\0' || value < INT_MIN || value > INT_MAX)
    return -1;
  *out = (int)value;
  return 0;
}

/*
 * Entry point. Supported invocations:
 *   (no arguments)                 print results for every row of values[]
 *   <oper> <low> <high> <count>    print results for a single input
 *   --benchmark [<repetitions>]    time every variant (default 20 rounds)
 * Malformed arguments print a usage message and exit with status 2.
 */
int main(int argc, char *argv[]) {
  /* argc may be 0, in which case argv[1] must not be read. */
  if (argc <= 1) {
    print_limits();
    for (size_t i = 0; i < sizeof(values) / sizeof(*values); ++i)
      print_result(
          values[i].operand, values[i].low, values[i].high, values[i].count);
    return 0;
  }

  if (argc <= 3 && strcmp(argv[1], "--benchmark") == 0) {
    int repetitions = 20;
    if (argc == 3 && parse_int_arg(argv[2], &repetitions) != 0) {
      print_usage();
      return 2;
    }
    for (int i = 0; i < repetitions; ++i) {
      measure("orig", width_bucket_orig);
      measure("null", width_bucket_null);
      measure("new1", width_bucket_new1);
      measure("new2", width_bucket_new2);
      measure("new3", width_bucket_new3);
    }
    return 0;
  }

  if (argc == 5) {
    double oper, low, high;
    int count;
    if (parse_double_arg(argv[1], &oper) == 0 &&
        parse_double_arg(argv[2], &low) == 0 &&
        parse_double_arg(argv[3], &high) == 0 &&
        parse_int_arg(argv[4], &count) == 0) {
      print_limits();
      print_result(oper, low, high, count);
      return 0;
    }
  }

  print_usage();
  return 2;
}
