#!/bin/bash
# Parallel tuple cost calibration benchmark
#
# Reverse-engineers the actual per-byte IPC cost of parallel tuple transfer
# by varying tuple width through a Gather node and measuring runtime.
#
# Methodology:
# - Create tables with an integer id and a fixed-width text column
# - Populate with lpad(val, W, 'x') for each width W
# - Run: SELECT txt, count(*) FROM ptc_bench_W GROUP BY txt
#   This produces partial aggregates in workers, then Gather passes
#   (txt, partial_count) tuples whose width ≈ W + 8 bytes.
# - Set work_mem high enough that partial aggregation succeeds at all widths
# - Measure runtime at each width, parallel ON and OFF, best of 5
# - Compare predicted vs actual overhead scaling

set -euo pipefail

# --- Fixed configuration: installation paths, connection, dataset shape ---
readonly PGBIN=/home/andrew/pgl/ptest/inst/bin
readonly PORT=5679
readonly DB=ptc_calibrate
readonly OUTDIR=/home/andrew/pgl/ptest/bench/ptc_calibrate_results
readonly NROWS=10000000    # 10M rows
readonly NDISTINCT=100000  # 100K distinct text values

mkdir -p "$OUTDIR"
# Truncate the per-run detail log so each invocation starts fresh.
: > "$OUTDIR/agg_detail.txt"

# psql wrapper for the benchmark database.
# Quiet (-q), unaligned (-A), tuples-only (-t); extra args pass through.
psql_cmd() {
    "$PGBIN/psql" -h /tmp -p "$PORT" -d "$DB" -qAt "$@"
}

# Run queries with client-side \timing enabled and print only the final
# elapsed time in milliseconds.  psql emits one "Time: N.NNN ms" line per
# -c command, so `tail -1` keeps the timing of the last (real) query;
# result rows themselves are discarded via -o /dev/null.
psql_cmd_timing() {
    "$PGBIN/psql" -h /tmp -p "$PORT" -d "$DB" -o /dev/null -qAt \
        -c '\timing on' "$@" 2>&1 \
        | grep -oP 'Time: \K[\d.]+' \
        | tail -1
}

# Create the benchmark database if it does not already exist.
# NOTE: avoid `psql | grep -q` here — under `set -o pipefail`, grep -q can
# exit as soon as it matches, SIGPIPE-ing psql and failing the pipeline,
# which would spuriously re-run CREATE DATABASE (and abort the script when
# that fails because the DB already exists).  Capture and compare instead.
db_exists=$("$PGBIN/psql" -h /tmp -p "$PORT" -d postgres -qAt \
    -c "SELECT 1 FROM pg_database WHERE datname = '$DB'")
if [ "$db_exists" != "1" ]; then
    "$PGBIN/psql" -h /tmp -p "$PORT" -d postgres -qAt \
        -c "CREATE DATABASE $DB;"
fi

# Tuple widths under test: powers of two, plus 384/768 so the 512-1024
# range is not a single giant step.
readonly WIDTHS="8 16 32 64 128 256 384 512 768 1024"
readonly NWORKERS=2
readonly NRUNS=5
# work_mem must be large enough for partial aggregation at width=1024:
# 100K groups × 1024 bytes ≈ 100MB, plus hash overhead → use 256MB.
readonly WORKMEM="256MB"

echo "=== Creating and populating tables ==="
echo "Rows: $NROWS, Distinct values: $NDISTINCT, work_mem: $WORKMEM"

for W in $WIDTHS; do
    # Skip the expensive load when a previous run left a complete table
    # behind (row count matches exactly).
    existing=$(psql_cmd -c "SELECT count(*) FROM ptc_bench_${W};" 2>/dev/null || echo "0")
    if [ "$existing" != "$NROWS" ]; then
        echo "  Loading width=$W..."
        # lpad pads each distinct value out to exactly W characters.
        psql_cmd <<EOF
DROP TABLE IF EXISTS ptc_bench_${W};
CREATE TABLE ptc_bench_${W} (
    id  integer NOT NULL,
    txt text    NOT NULL
);
INSERT INTO ptc_bench_${W} (id, txt)
SELECT g, lpad((g % $NDISTINCT)::text, $W, 'x')
FROM generate_series(1, $NROWS) g;
ANALYZE ptc_bench_${W};
EOF
    else
        echo "  width=$W: reusing existing table ($existing rows)"
    fi
done

printf '\n=== Verifying table stats ===\n'
for W in $WIDTHS; do
    # One row per width: actual rowcount, distinct values, average length.
    psql_cmd -c "SELECT '$W' AS width,
        count(*) AS rows,
        count(DISTINCT txt) AS distinct_vals,
        avg(length(txt))::int AS avg_len
        FROM ptc_bench_${W};"
done

printf '\n=== Collecting EXPLAIN plans ===\n'
for W in $WIDTHS; do
    echo "--- width=$W ---"
    # Show the top of the plan with the same GUCs the benchmark will use.
    psql_cmd -c "SET max_parallel_workers_per_gather = $NWORKERS;
        SET work_mem = '$WORKMEM';
        EXPLAIN (VERBOSE, COSTS ON)
        SELECT txt, count(*) FROM ptc_bench_${W} GROUP BY txt;" \
        | head -15
    echo ""
done

cat <<EOF

=== Benchmark: GROUP BY txt, count(*) ===
  Gather passes ~240K-300K partial aggregate rows at width ≈ W + 8.
  work_mem=$WORKMEM to ensure consistent partial aggregation.
  Best of $NRUNS runs, $NWORKERS workers.

EOF

# Query template; the literal token WIDTH is substituted per table below.
AGG_QUERY_TEMPLATE="SELECT txt, count(*) FROM ptc_bench_WIDTH GROUP BY txt"

# Result table header.
printf "%-8s %12s %12s %12s %12s\n" "Width" "Parallel(ms)" "Serial(ms)" "Speedup" "Gather rows"
printf "%-8s %12s %12s %12s %12s\n" "-----" "------------" "----------" "-------" "-----------"

for W in $WIDTHS; do
    Q="${AGG_QUERY_TEMPLATE//WIDTH/$W}"
    SET_CMDS="SET max_parallel_workers_per_gather = $NWORKERS; SET work_mem = '$WORKMEM';"
    SET_SER="SET max_parallel_workers_per_gather = 0; SET work_mem = '$WORKMEM';"

    # Extract the Gather node's estimated row count from the plan.
    # BUGFIX: without `|| true`, a plan with no Gather node makes grep exit 1,
    # and under set -e/pipefail the failed command substitution aborts the
    # whole script — the ${gather_rows:-"?"} fallback was never reached.
    gather_rows=$(psql_cmd -c "$SET_CMDS EXPLAIN (COSTS ON) $Q;" \
        | grep -oP 'Gather.*rows=\K\d+' | head -1 || true)
    gather_rows=${gather_rows:-"?"}

    # Warm up (2 runs each mode) so cache state is comparable across widths.
    psql_cmd -c "$SET_CMDS $Q" > /dev/null 2>&1
    psql_cmd -c "$SET_CMDS $Q" > /dev/null 2>&1
    psql_cmd -c "$SET_SER $Q" > /dev/null 2>&1
    psql_cmd -c "$SET_SER $Q" > /dev/null 2>&1

    # Timed parallel runs
    par_times=()
    for r in $(seq 1 $NRUNS); do
        t=$(psql_cmd_timing \
            -c "SET max_parallel_workers_per_gather = $NWORKERS;" \
            -c "SET work_mem = '$WORKMEM';" \
            -c "$Q")
        par_times+=("$t")
    done

    # Timed serial runs
    ser_times=()
    for r in $(seq 1 $NRUNS); do
        t=$(psql_cmd_timing \
            -c "SET max_parallel_workers_per_gather = 0;" \
            -c "SET work_mem = '$WORKMEM';" \
            -c "$Q")
        ser_times+=("$t")
    done

    # Best (minimum) of N runs.  LC_ALL=C so `sort -n` parses the '.' decimal
    # point in psql's timing output regardless of the user's locale.
    par_best=$(printf '%s\n' "${par_times[@]}" | LC_ALL=C sort -n | head -1)
    ser_best=$(printf '%s\n' "${ser_times[@]}" | LC_ALL=C sort -n | head -1)

    speedup=$(python3 -c "print(f'{float(\"$ser_best\")/float(\"$par_best\"):.2f}x')" 2>/dev/null || echo "?")

    printf "%-8s %12s %12s %12s %12s\n" "$W" "$par_best" "$ser_best" "$speedup" "$gather_rows"

    # Full per-run timings for later offline analysis.
    echo "width=$W parallel: ${par_times[*]}" >> "$OUTDIR/agg_detail.txt"
    echo "width=$W serial:   ${ser_times[*]}" >> "$OUTDIR/agg_detail.txt"
done

# Final marker with an ISO-8601 completion timestamp.
printf '\n=== Done: %s ===\n' "$(date -Iseconds)"
