import logging
import subprocess
import time

import psycopg2

PGDATA = "pgdata-multixid-repro"

log = logging.getLogger(__name__)

def connect():
    return psycopg2.connect(
        database="postgres",
        user="postgres",
        host="/tmp/",
    )

# Constants and macros copied from PostgreSQL multixact.c and headers. These are needed to
# calculate the SLRU segments that a particular multixid or multixid-offsets falls into.
BLCKSZ = 8192
MULTIXACT_OFFSETS_PER_PAGE = BLCKSZ // 4
SLRU_PAGES_PER_SEGMENT = 32
MXACT_MEMBER_BITS_PER_XACT = 8
MXACT_MEMBER_FLAGS_PER_BYTE = 1
MULTIXACT_FLAGBYTES_PER_GROUP = 4
MULTIXACT_MEMBERS_PER_MEMBERGROUP = MULTIXACT_FLAGBYTES_PER_GROUP * MXACT_MEMBER_FLAGS_PER_BYTE
MULTIXACT_MEMBERGROUP_SIZE = 4 * MULTIXACT_MEMBERS_PER_MEMBERGROUP + MULTIXACT_FLAGBYTES_PER_GROUP
MULTIXACT_MEMBERGROUPS_PER_PAGE = BLCKSZ // MULTIXACT_MEMBERGROUP_SIZE
MULTIXACT_MEMBERS_PER_PAGE = MULTIXACT_MEMBERGROUPS_PER_PAGE * MULTIXACT_MEMBERS_PER_MEMBERGROUP
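
# Sanity checks on the derived constants. With the default 8 KB block size, these
# must match the values PostgreSQL computes internally: a member group is 20 bytes
# (four 4-byte member XIDs plus four flag bytes), so a page holds 409 groups, i.e.
# 1636 members.
assert MULTIXACT_OFFSETS_PER_PAGE == 2048
assert MULTIXACT_MEMBERGROUP_SIZE == 20
assert MULTIXACT_MEMBERGROUPS_PER_PAGE == 409
assert MULTIXACT_MEMBERS_PER_PAGE == 1636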

def MultiXactIdToOffsetSegment(xid: int) -> int:
    return xid // (SLRU_PAGES_PER_SEGMENT * MULTIXACT_OFFSETS_PER_PAGE)


def MXOffsetToMemberSegment(off: int) -> int:
    return off // (SLRU_PAGES_PER_SEGMENT * MULTIXACT_MEMBERS_PER_PAGE)
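

# Worked examples of the segment mapping, as sanity checks: each offsets segment
# covers 32 * 2048 = 65536 multixids, so multixid 0x40000000 lands in segment
# 0x4000. Each members segment covers 32 * 1636 = 52352 member offsets; note that
# member segment numbers can exceed four hex digits, which is why the "%04X"
# format used below is a *minimum* width.
assert MultiXactIdToOffsetSegment(0x40000000) == 0x4000
assert MXOffsetToMemberSegment(0xFFFFFF00) == 82040  # "%04X" % 82040 == "14078"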

def advance_multixid_to(next_multi_xid: int, next_multi_offset: int):
    """
    Use pg_resetwal to advance the nextMulti and nextMultiOffset values in a stand-alone
    Postgres cluster. This is useful to get close to wraparound or some other interesting
    value, without having to burn a lot of time consuming the (multi-)XIDs one by one.

    The new values should be higher than the old ones, in a wraparound-aware sense.

    On entry, the server should be running. It will be shut down and restarted.
    """

    # Read the old values from the last checkpoint. We will pass the old oldestMultiXid
    # value back to pg_resetwal; there's no option to leave it alone.
    with connect() as conn:
        with conn.cursor() as cur:
            # Make sure the oldest-multi-xid value in the control file is up-to-date
            cur.execute("checkpoint")
            cur.execute("select oldest_multi_xid, next_multixact_id from pg_control_checkpoint()")
            (ckpt_oldest_multi_xid, ckpt_next_multi_xid) = cur.fetchone()
    log.info(f"oldestMultiXid was {ckpt_oldest_multi_xid}, nextMultiXid was {ckpt_next_multi_xid}")
    log.info(f"Resetting to {next_multi_xid}")

    # Use pg_resetwal to reset the next multiXid and multiOffset to given values.
    subprocess.check_call(["pg_ctl", "-D", PGDATA, "stop"])
    cmd = [
        "pg_resetwal",
        f"--multixact-ids={next_multi_xid},{ckpt_oldest_multi_xid}",
        f"--multixact-offset={next_multi_offset}",
        "-D",
        PGDATA,
    ]
    subprocess.check_call(cmd)

    # Because we skip over a lot of values, Postgres hasn't created the SLRU segments for
    # the new values yet. Create them manually, to allow Postgres to start up.
    #
    # This leaves "gaps" in the SLRU where segments between old value and new value are
    # missing. That's OK for our purposes. Autovacuum will print some warnings about the
    # missing segments, but will clean it up by truncating the SLRUs up to the new value,
    # closing the gap.
    segname = "%04X" % MultiXactIdToOffsetSegment(next_multi_xid)
    log.info(f"Creating dummy segment pg_multixact/offsets/{segname}")
    with open(f"{PGDATA}/pg_multixact/offsets/{segname}", "w") as of:
        of.write("\0" * SLRU_PAGES_PER_SEGMENT * BLCKSZ)
        of.flush()

    segname = "%04X" % MXOffsetToMemberSegment(next_multi_offset)
    log.info(f"Creating dummy segment pg_multixact/members/{segname}")
    with open(f"{PGDATA}/pg_multixact/members/{segname}", "w") as of:
        of.write("\0" * SLRU_PAGES_PER_SEGMENT * BLCKSZ)
        of.flush()

    # Start Postgres again and wait until autovacuum has processed all the databases
    #
    # This allows truncating the SLRUs, fixing the gaps with missing segments.
    subprocess.check_call(["pg_ctl", "-D", PGDATA, "start"])
    with connect().cursor() as cur:
        for _ in range(1000):
            cur.execute("select min(datminmxid::text::int8) from pg_database")
            datminmxid = int(cur.fetchone()[0])
            log.info(f"datminmxid {datminmxid}")
            if next_multi_xid - datminmxid < 1_000_000:  # not wraparound-aware! see sketch below
                break
            time.sleep(0.5)
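

# The "< 1_000_000" check above is not wraparound-aware. PostgreSQL itself compares
# multixids using modulo-2^32 arithmetic (see MultiXactIdPrecedes in multixact.c).
# A sketch of that comparison in Python, for reference; it is not used by this
# script:
def multixid_precedes(multi1: int, multi2: int) -> bool:
    """Return True if multi1 logically precedes multi2, in a wraparound-aware sense."""
    diff = (multi1 - multi2) & 0xFFFFFFFF
    if diff >= 0x80000000:
        diff -= 0x100000000  # reinterpret as a signed 32-bit difference
    return diff < 0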


def main():
    # In order to test multixid wraparound, we need to first advance the counter to
    # within spitting distance of the wraparound, that is, 2^32 multi-XIDs. We could
    # simply run a workload that consumes a lot of multi-XIDs until we approach that,
    # but that takes a very long time. So we cheat.
    #
    # Our strategy is to create a Postgres cluster and use pg_resetwal to set the
    # multi-xid counter to a higher value. However, we cannot directly set it to just
    # before 2^32 (~ 4 billion), because that would make the existing 'relminmxid'
    # values look like they're in the future. It's not clear how the system would
    # behave in that situation. So instead, we bump it up ~ 1 billion multi-XIDs at a
    # time, and let autovacuum process all the relations and update 'relminmxid'
    # between each run.
    subprocess.check_call(["initdb", "-D", PGDATA, "-U", "postgres"])
    with open(f"{PGDATA}/postgresql.conf", "a") as file1:
        file1.writelines([
            "log_autovacuum_min_duration = 0\n",
            # Perform anti-wraparound vacuuming aggressively
            "autovacuum_naptime='1 s'\n",
            "autovacuum_freeze_max_age = 1000000\n",
            "autovacuum_multixact_freeze_max_age = 1000000\n",
            "shared_buffers='1 MB'",
        ])

    subprocess.check_call(["pg_ctl", "-D", PGDATA, "start"])
    advance_multixid_to(0x40000000, 0x10000000)
    advance_multixid_to(0x80000000, 0x20000000)
    advance_multixid_to(0xC0000000, 0x30000000)
    advance_multixid_to(0xFFFFFF00, 0xFFFFFF00)
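

# Hypothetical follow-up, not called anywhere in this script: after advancing close
# to 2^32, the last multi-XIDs have to be consumed for real to cross the wraparound
# boundary. A multixact is created when two transactions hold row locks on the same
# row at the same time. This sketch assumes a table "t" with at least one row
# already exists.
def consume_one_multixid():
    conn1 = connect()
    conn2 = connect()
    try:
        with conn1.cursor() as cur1, conn2.cursor() as cur2:
            # The second locker upgrades the row's xmax from a plain XID to a new
            # multixact containing both lockers.
            cur1.execute("select * from t for key share")
            cur2.execute("select * from t for key share")
        conn1.commit()
        conn2.commit()
    finally:
        conn1.close()
        conn2.close()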

if __name__ == "__main__":
    main()
