#!/usr/bin/env python3
from itertools import product
import math
import psycopg2
import random
import time

def biased_random_int(min_val, max_val):
    """
    Draw an integer in [min_val, max_val] that favours the middle of the range.

    A Beta(2, 2) sample (density 6x(1 - x): zero at both ends, peaking at
    0.5) is mapped linearly onto the range and rounded to the nearest
    integer, so values close to either bound come up less often than
    values near the centre.

    Args:
        min_val: Lowest value that can be returned.
        max_val: Highest value that can be returned.

    Returns:
        int: The rounded, centre-weighted sample.
    """
    fraction = random.betavariate(2, 2)
    span = max_val - min_val
    return round(min_val + fraction * span)

class PostgreSQLSkipScanTester:
    """
    Fuzz tester comparing sequential-scan and index-scan results.

    The tester builds a table with a serial ``id`` plus integer columns a, b,
    c and d, creates a composite index on (a, b, c, d), and loads random data
    (with some NULLs). It then runs randomly generated WHERE clauses twice:
    once with index and bitmap scans disabled, and once with sequential and
    bitmap scans disabled. Any difference between the two row sets is
    reported as a failure.
    """

    # Inclusive (low, high) range for constants generated per column. The
    # ranges reach slightly beyond the inserted data (see
    # setup_test_environment) so out-of-range constants are exercised too.
    # Columns not listed here fall back to the 'd' range.
    _VALUE_RANGES = {
        'a': (-1, 21),
        'b': (-1, 21),
        'c': (-1, 101),
        'd': (-1, 10_001),
    }

    def __init__(self, conn_params, table_name, num_rows, num_samples, num_tests):
        """
        Args:
            conn_params: Keyword arguments passed to psycopg2.connect().
            table_name: Table to (re)create. It is interpolated into SQL
                unquoted, so it must be a trusted identifier.
            num_rows: Number of rows to insert into the test table.
            num_samples: Number of queries whose plans dump_plan_samples prints.
            num_tests: Number of random queries run_fuzzing_queries executes.
        """
        self.conn_params = conn_params
        self.table_name = table_name
        self.num_rows = num_rows
        self.num_samples = num_samples
        self.num_tests = num_tests
        self.columns = ['a', 'b', 'c', 'd']
        self.equality_operators = ['=', 'IN', 'IS NULL']
        self.inequality_operators = ['<', '<=', '>=', '>', 'IS NOT NULL']
        self.conn = None

    def connect(self):
        """Establish connection to PostgreSQL database (re-raises on failure)."""
        try:
            self.conn = psycopg2.connect(**self.conn_params)
            print("Successfully connected")
        except Exception as e:
            print(f"Connection error: {e}")
            raise

    def setup_test_environment(self):
        """
        Create the test table and composite index, then load random data.

        Each of a, b, c and d is NULL with probability 0.05. Otherwise a and
        b are drawn uniformly from 1..20, c from 1..100 and d from 1..10,000.
        After loading, the connection is switched to autocommit so that
        VACUUM can run. The table is then vacuumed and analyzed, and parallel
        query is disabled for the rest of the session.

        Raises:
            Whatever the driver raises. The connection is rolled back and the
            error printed first.
        """
        try:
            cursor = self.conn.cursor()

            # Create test table
            cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
            cursor.execute(f"""
                create table {self.table_name} (
                    id serial primary key,
                    a integer,
                    b integer,
                    c integer,
                    d integer)
            """)

            # Create the composite index first (makes suffix truncation
            # effective)
            cursor.execute(f"""
                CREATE INDEX idx_{self.table_name}_abcd
                ON {self.table_name} (a, b, c, d)
            """)
            # Insert random data (with some NULL values)
            for _ in range(self.num_rows):
                # Randomly decide whether to include NULL values for each column
                a_val = None if random.random() < 0.05 else random.randint(1, 20)
                b_val = None if random.random() < 0.05 else random.randint(1, 20)
                c_val = None if random.random() < 0.05 else random.randint(1, 100)
                d_val = None if random.random() < 0.05 else random.randint(1, 10_000)

                cursor.execute(f"""
                    INSERT INTO {self.table_name} (a, b, c, d)
                    VALUES (%s, %s, %s, %s)
                """, (a_val, b_val, c_val, d_val))

            self.conn.commit()

            # VACUUM cannot run inside a transaction block, hence autocommit.
            # Autocommit stays on for the rest of the session.
            self.conn.set_session(autocommit=True)
            cursor.execute(f"""
                vacuum analyze {self.table_name}
            """)

            print(f"Created test table {self.table_name} with {self.num_rows} rows")
            cursor.execute(f"""
                set max_parallel_workers_per_gather to 0;
            """)

        except Exception as e:
            self.conn.rollback()
            print(f"Setup error: {e}")
            raise

    def generate_cond(self, column, operator_type=None,
                      equality_weights=(0.78, 0.20, 0.02),  # Use IS NULL much less often
                      inequality_weights=(0.20, 0.20, 0.20, 0.17, 0.03)):  # Use IS NOT NULL much less often
        """
        Generate a single condition for a column using the specified operator type.

        Args:
            column: The column name to generate condition for.
            operator_type: 'equality', 'inequality', or None (uniformly random
                over all operators).
            equality_weights: Weights aligned with self.equality_operators.
            inequality_weights: Weights aligned with self.inequality_operators.
                Callers zero entries to forbid operators.

        Returns:
            A condition string like "a > 5", "b IN (1, 4)" or "b IS NULL".
        """
        if operator_type == 'equality':
            op = random.choices(self.equality_operators, weights=equality_weights)[0]
        elif operator_type == 'inequality':
            op = random.choices(self.inequality_operators, weights=inequality_weights)[0]
        else:
            op = random.choice(self.equality_operators + self.inequality_operators)

        if op in ('IS NULL', 'IS NOT NULL'):
            return f"{column} {op}"

        low, high = self._VALUE_RANGES.get(column, self._VALUE_RANGES['d'])

        if op == 'IN':
            nelements = random.randint(2, 20)
            # Duplicates collapse, so the list may be shorter than nelements
            elements = {biased_random_int(low, high) for _ in range(nelements)}
            values_str = ", ".join(str(x) for x in sorted(elements))
            return f"{column} IN ({values_str})"

        return f"{column} {op} {biased_random_int(low, high)}"

    def find_matching_operator_indices(self, all_conditions):
        """
        Return indices into self.inequality_operators already used by conditions.

        An operator counts as used if its text occurs anywhere in any of the
        given condition strings. Because a second bound on the same side
        would be redundant, '<'/'<=' and '>='/'>' are always reported as
        pairs.

        Args:
            all_conditions: Condition strings such as "a > 5". Pass only the
                conditions for a single column to get per-column results.

        Returns:
            set[int]: Indices that callers should stop generating.
        """
        # Index partners within ['<', '<=', '>=', '>', 'IS NOT NULL']
        partner = {0: 1, 1: 0, 2: 3, 3: 2}

        matching_indices = set()
        for index, operator in enumerate(self.inequality_operators):
            if any(operator in condition for condition in all_conditions):
                matching_indices.add(index)
                if index in partner:
                    matching_indices.add(partner[index])
        return matching_indices

    def sort_constraints(self, constraints):
        """
        Sort constraints by column name, then by operator.

        Within a column, lower-bound operators (>, >=) sort before
        upper-bound operators (<, <=). All other operators sort last, keeping
        their original relative order. This gives a consistent, more
        readable presentation.

        Args:
            constraints: List of constraint strings (e.g., ["a <= 5", "a > 8", "a >= 4", "b < 5"])

        Returns:
            list: Sorted list of constraints
        """
        operator_priority = {">": 0, ">=": 1, "<": 2, "<=": 3}

        def get_sort_key(constraint):
            parts = constraint.split()
            if len(parts) < 2:
                return (constraint, 999)  # Fallback for unparseable constraints
            return (parts[0], operator_priority.get(parts[1], 999))

        return sorted(constraints, key=get_sort_key)

    @staticmethod
    def _bounds_contradict(lower_op, lower_val, upper_op, upper_val):
        """
        Return True if the lower bound excludes every value the upper bound allows.

        Only `>= x AND <= x` is satisfiable when the two constants are equal.
        Integer-only gaps such as `> 9 AND < 10` are not detected.
        """
        if lower_op == '>=' and upper_op == '<=':
            return lower_val > upper_val
        return lower_val >= upper_val

    def resolve_contradictions(self, conditions):
        """
        Rewrite impossible same-column range conditions into satisfiable ones.

        Allowing contradictory quals seems to be a poor use of available test
        cycles. For each column, every lower bound (>, >=) is checked against
        every upper bound (<, <=). A contradictory pair has its integer
        constants swapped. When the constants are equal, a swap would change
        nothing, so both bounds are widened by one instead. Updated constants
        are used when checking later pairs, so a bound that conflicts with
        several others is handled consistently.

        Args:
            conditions: List of condition strings (e.g., ["b > 16", "b < 9", "d IS NOT NULL"])

        Returns:
            list: A new list in the same order, with contradictory conditions
            rewritten. The input list is not modified.
        """
        resolved = list(conditions)

        # column -> {'lower': [[index, op, value], ...], 'upper': [...]}
        # Entries are mutable so swapped values are seen by later pairs.
        bounds = {}
        for index, condition in enumerate(resolved):
            parts = condition.split()
            # Only plain "<col> <op> <int>" comparisons are considered
            if len(parts) != 3 or parts[1] not in ('>', '<', '>=', '<='):
                continue
            try:
                value = int(parts[2])
            except ValueError:
                continue
            column, op = parts[0], parts[1]
            side = 'lower' if op in ('>', '>=') else 'upper'
            bounds.setdefault(column, {'lower': [], 'upper': []})[side].append(
                [index, op, value])

        for column, sides in bounds.items():
            for lower in sides['lower']:
                for upper in sides['upper']:
                    lower_idx, lower_op, lower_val = lower
                    upper_idx, upper_op, upper_val = upper
                    if not self._bounds_contradict(lower_op, lower_val,
                                                   upper_op, upper_val):
                        continue
                    if lower_val == upper_val:
                        # Swapping equal constants is a no-op; widen instead
                        new_lower, new_upper = lower_val - 1, upper_val + 1
                    else:
                        new_lower, new_upper = upper_val, lower_val
                    lower[2], upper[2] = new_lower, new_upper
                    resolved[lower_idx] = f"{column} {lower_op} {new_lower}"
                    resolved[upper_idx] = f"{column} {upper_op} {new_upper}"

        return resolved

    def generate_random_query(self):
        """
        Generate a random WHERE clause (without the WHERE keyword).

        Each chosen column gets either one equality-type condition (=, IN,
        IS NULL) or one to three inequality conditions. Redundant operators
        are suppressed per column: once a column has '<' or '<=', neither is
        generated again for that column (likewise '>'/'>=' and IS NOT NULL).
        Contradictory ranges are then rewritten by resolve_contradictions.

        Returns:
            str: Conditions joined with " AND ".
        """
        assert len(self.columns) == 4
        column_count_weights = [0.05, 0.05, 0.05, 0.85]  # P(1, 2, 3 or 4 columns)
        num_columns_with_conditions = random.choices(
            [1, 2, 3, 4], weights=column_count_weights)[0]
        columns_with_conditions = random.sample(self.columns, num_columns_with_conditions)

        # Avoid just having one column with conditions when that column is "a";
        # make it on "b", instead
        if num_columns_with_conditions == 1 and columns_with_conditions[0] == 'a':
            columns_with_conditions[0] = 'b'

        columns_with_equality = set()
        all_conditions = []

        # First pass: each column has a 50% chance of an equality condition
        for col in columns_with_conditions:
            if random.random() < 0.5:
                all_conditions.append(self.generate_cond(col, 'equality'))
                columns_with_equality.add(col)

        # Second pass: remaining columns get 1-3 inequality conditions
        all_operator_indices = set(range(len(self.inequality_operators)))
        for col in columns_with_conditions:
            if col in columns_with_equality:
                continue

            num_conditions = random.choices([1, 2, 3], weights=[0.25, 0.7, 0.05])[0]
            # Aligned with self.inequality_operators; zeroed as operators get used
            inequality_weights = [0.20, 0.20, 0.20, 0.17, 0.03]
            col_conditions = []

            for _ in range(num_conditions):
                # Only this column's conditions count: bounds on other
                # columns are not redundant with bounds on this one.
                blocked = self.find_matching_operator_indices(col_conditions)
                if blocked == all_operator_indices:
                    break
                for index in blocked:
                    inequality_weights[index] = 0
                col_conditions.append(self.generate_cond(
                    col, 'inequality', inequality_weights=inequality_weights))

            all_conditions.extend(col_conditions)

        # Defensive fallback; every chosen column always gets a condition
        if not all_conditions:
            col = random.choice(self.columns)
            all_conditions.append(self.generate_cond(col))

        all_conditions = self.sort_constraints(all_conditions)
        all_conditions = self.resolve_contradictions(all_conditions)
        return " AND ".join(all_conditions)

    def execute_test_query(self, where_clause):
        """
        Run a query with scans forced first to sequential, then to index.

        For each mode, EXPLAIN ANALYZE is captured and the query is executed.
        The planner settings are RESET afterwards, even if a query fails.

        Args:
            where_clause: Condition text placed after WHERE (trusted input).

        Returns:
            dict: Plans, result rows, row counts and a 'results_match' flag
            comparing the two sorted row sets.
        """
        select = f"SELECT * FROM {self.table_name} WHERE {where_clause}"

        with self.conn.cursor() as cursor:
            try:
                # Force sequential scan
                cursor.execute("SET enable_indexscan = off; SET enable_bitmapscan = off;")
                cursor.execute(f"EXPLAIN ANALYZE {select}")
                seq_plan = cursor.fetchall()
                cursor.execute(select)
                seq_results = cursor.fetchall()

                # Force index scan
                cursor.execute("SET enable_indexscan = on; SET enable_seqscan = off; SET enable_bitmapscan = off;")
                cursor.execute(f"EXPLAIN ANALYZE {select}")
                idx_plan = cursor.fetchall()
                cursor.execute(select)
                idx_results = cursor.fetchall()
            finally:
                # Reset scan settings so a failure can't leak them into later queries
                cursor.execute("RESET enable_indexscan; RESET enable_seqscan; RESET enable_bitmapscan;")

        return {
            'where_clause': where_clause,
            'seq_plan': seq_plan,
            'idx_plan': idx_plan,
            'seq_results': seq_results,
            'idx_results': idx_results,
            # Rows start with the unique non-NULL id, so sorting never has
            # to compare NULLs in later columns
            'results_match': sorted(seq_results) == sorted(idx_results),
            'seq_count': len(seq_results),
            'idx_count': len(idx_results)
        }

    def verify_scan_results(self, test_result):
        """Return True if both scans agree; otherwise print details and return False."""
        if not test_result['results_match']:
            print("\n❌ TEST FAILED: Results do not match!")
            print(f"Query: SELECT * FROM {self.table_name} WHERE {test_result['where_clause']}")
            print(f"Sequential scan found {test_result['seq_count']} rows")
            print(f"Index scan found {test_result['idx_count']} rows")
            return False
        return True

    def _has_multiple_inequalities(self, where_clause):
        """
        Return True if any column has more than one inequality condition.

        Conditions are split on " AND " and parsed by token, so "a <= 5"
        counts once (a substring test would also match "a <"). IS NOT NULL
        counts as an inequality.
        """
        counts = {}
        for condition in where_clause.split(" AND "):
            parts = condition.split()
            if len(parts) < 2 or parts[0] not in self.columns:
                continue
            is_inequality = (parts[1] in ('<', '<=', '>=', '>')
                             or parts[1:] == ['IS', 'NOT', 'NULL'])
            if is_inequality:
                counts[parts[0]] = counts.get(parts[0], 0) + 1
                if counts[parts[0]] > 1:
                    return True
        return False

    def run_fuzzing_queries(self):
        """Run num_tests random queries; return True if every one matched."""
        print(f"\nRunning {self.num_tests} random test queries...")

        start_time = time.time()
        failures = 0
        multiple_inequality_count = 0

        for i in range(1, self.num_tests + 1):
            where_clause = self.generate_random_query()

            if self._has_multiple_inequalities(where_clause):
                multiple_inequality_count += 1

            test_result = self.execute_test_query(where_clause)

            if not self.verify_scan_results(test_result):
                failures += 1

            if i % 10 == 0:
                print(f"Completed {i} tests. Failures: {failures}")

        duration = time.time() - start_time

        print(f"\nCompleted {self.num_tests} tests in {duration:.2f} seconds")
        print(f"Queries with multiple inequalities on the same column: {multiple_inequality_count}")
        print(f"Total failures: {failures}")

        if failures == 0:
            print("✅ All tests passed!")
        else:
            print(f"❌ {failures} tests failed!")

        return failures == 0

    def dump_plan_samples(self):
        """Print both execution plans and the match status for num_samples random queries."""
        print(f"\nAnalyzing execution plans for {self.num_samples} sample queries...")

        for i in range(self.num_samples):
            where_clause = self.generate_random_query()
            test_result = self.execute_test_query(where_clause)

            print(f"\nQuery {i+1}: SELECT * FROM {self.table_name} WHERE {where_clause}")
            print("\nSequential scan plan:")
            for line in test_result['seq_plan']:
                print(line[0])

            print("\nIndex scan plan:")
            for line in test_result['idx_plan']:
                print(line[0])

            print(f"\nResults match: {test_result['results_match']}")
            print(f"Row count: {test_result['seq_count']}")
            print("-" * 80)

    def cleanup(self):
        """
        Commit pending work and close the connection, if one was opened.

        Errors are printed rather than raised because this runs from the
        finally clause of run_all_tests(). The connection is closed even if
        the commit fails.
        """
        if not self.conn:
            return
        try:
            self.conn.commit()
        except Exception as e:
            print(f"Cleanup error: {e}")
        try:
            self.conn.close()
            print("Closed connection")
        except Exception as e:
            print(f"Cleanup error: {e}")

    def run_all_tests(self):
        """Connect, set up, preview plans and fuzz. Return True on success."""
        try:
            self.connect()
            self.setup_test_environment()

            # Analyze some sample execution plans first, to preview the work
            # that run_fuzzing_queries() will do (runs quickly)
            self.dump_plan_samples()

            # The real work happens in run_fuzzing_queries() (takes a while)
            return self.run_fuzzing_queries()

        except Exception as e:
            print(f"Test error: {e}")
            return False
        finally:
            self.cleanup()


if __name__ == "__main__":
    import sys

    # Connection parameters - adjust as needed
    conn_params = {
        "host": "localhost",
        "database": "regression",
        "user": "pg",
    }

    # Create and run the tester
    tester = PostgreSQLSkipScanTester(
        conn_params=conn_params,
        table_name="skip_scan_test",
        num_rows=100_000, # rows in `table_name` table
        num_samples=10, # Number of plan samples to dump (previews test query structure)
        num_tests=5_000 # Number of test queries
    )

    success = tester.run_all_tests()

    if success:
        print("\n✅ All tests completed successfully")
    else:
        print("\n❌ Test failures detected")

    # Exit non-zero on failure so shells and CI jobs can detect mismatches
    sys.exit(0 if success else 1)
