#!/usr/bin/env python3

import sys
import re
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass

@dataclass
class PallocLine:
    """A single added or removed patch line that contains a palloc-family call."""

    line_num: int  # line number the caller assigns (the parser uses the 1-based patch line)
    content: str  # text of the line as stored by the caller
    file_path: str  # file the line is attributed to
    palloc_type: str  # 'palloc', 'palloc0', 'palloc_object', 'palloc_array', 'palloc0_object', 'palloc0_array'
    is_deletion: bool  # True for a removed line, False for an added one

    def __str__(self):
        """Render as '<marker> Line <n>: <stripped content>', marker being '-' or '+'."""
        marker = "+"
        if self.is_deletion:
            marker = "-"
        stripped = self.content.strip()
        return "{} Line {}: {}".format(marker, self.line_num, stripped)

def extract_palloc_type(line: str) -> Optional[str]:
    """Return the name of the first palloc-family call found in ``line``.

    Recognised names are ``palloc`` and ``palloc0``, each optionally followed
    by ``_object`` or ``_array``. The name must start at a word boundary and
    be followed (after optional whitespace) by an opening parenthesis, so
    identifiers such as ``repalloc(`` or ``palloc_extended(`` do not match.

    Args:
        line: One line of text, typically a ``+``/``-`` line from a patch.

    Returns:
        The matched function name, or ``None`` when the line has no such call.
    """
    # One pattern covers all six variants; a separate per-variant pre-check
    # would only repeat the same scan without changing the outcome.
    match = re.search(r'\b(palloc0?(?:_(?:object|array))?)\s*\(', line)
    if match is None:
        return None
    return match.group(1)

def is_zero_initialized(palloc_type: str) -> bool:
    """Tell whether ``palloc_type`` is one of the ``palloc0*`` names this script treats as zero-initializing."""
    zeroing_variants = ('palloc0', 'palloc0_object', 'palloc0_array')
    return any(palloc_type == variant for variant in zeroing_variants)

def parse_patch_file(patch_file: str) -> List[Tuple[Optional[PallocLine], Optional[PallocLine]]]:
    """Scan a patch file and pair removed palloc calls with added ones.

    Every ``-`` or ``+`` line containing a palloc-family call becomes a
    ``PallocLine``. An added line is paired with the oldest pending removed
    line that belongs to the same file and lies fewer than 10 patch lines
    away. Removed lines are kept pending only while they are fewer than
    5 patch lines behind the most recent context (`` ``) or hunk (``@@``) line.

    Args:
        patch_file: Path of the patch to read (decoded as UTF-8).

    Returns:
        A list of ``(deletion, insertion)`` tuples:
        ``(deletion, insertion)`` for a matched conversion,
        ``(None, insertion)`` for an added call with no nearby removal, and
        ``(deletion, None)`` for a removed call that never found a partner.
        Matched and insertion-only pairs come first in patch order; all
        unmatched deletions follow, also in patch order.

    Exits:
        Prints an error message and calls ``sys.exit(1)`` if the file cannot
        be opened or read.
    """
    try:
        with open(patch_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except FileNotFoundError:
        print(f"Error: Patch file '{patch_file}' not found.")
        sys.exit(1)
    except Exception as e:
        print(f"Error reading patch file: {e}")
        sys.exit(1)

    pairs = []
    current_file = ""
    deletion_lines = []     # Deletions still waiting for a matching insertion
    expired_deletions = []  # Deletions that left the matching window unmatched

    for line_num, line in enumerate(lines, 1):
        line_content = line.rstrip()

        # Track current file being patched
        if line_content.startswith('+++'):
            match = re.search(r'\+\+\+ b/(.+)', line_content)
            if match:
                current_file = match.group(1)

        # Handle deletion lines
        elif line_content.startswith('-'):
            palloc_type = extract_palloc_type(line_content)
            if palloc_type:
                deletion_lines.append(PallocLine(
                    line_num=line_num,
                    content=line_content,
                    file_path=current_file,
                    palloc_type=palloc_type,
                    is_deletion=True
                ))

        # Handle insertion lines
        elif line_content.startswith('+'):
            palloc_type = extract_palloc_type(line_content)
            if palloc_type:
                insertion_line = PallocLine(
                    line_num=line_num,
                    content=line_content,
                    file_path=current_file,
                    palloc_type=palloc_type,
                    is_deletion=False
                )

                # Simple heuristic: take the first pending deletion that is
                # close by and belongs to the same file.
                matched_deletion = None
                for j, deletion in enumerate(deletion_lines):
                    if (deletion.file_path == current_file and
                            abs(deletion.line_num - line_num) < 10):
                        matched_deletion = deletion_lines.pop(j)
                        break

                pairs.append((matched_deletion, insertion_line))

        # Handle context lines: deletions that are now too far behind can no
        # longer be matched. They are set aside rather than discarded so that
        # they are still reported as unmatched, exactly like deletions that
        # remain pending when the patch ends.
        elif line_content.startswith((' ', '@@')):
            still_pending = []
            for deletion in deletion_lines:
                if abs(deletion.line_num - line_num) < 5:
                    still_pending.append(deletion)
                else:
                    expired_deletions.append(deletion)
            deletion_lines = still_pending

    # Report every deletion that never found an insertion. Expired entries
    # always precede the still-pending ones in the patch, so this keeps
    # patch order.
    for deletion in expired_deletions + deletion_lines:
        pairs.append((deletion, None))

    return pairs

def is_array_allocation(line: str) -> bool:
    """Heuristically detect a ``count * sizeof(...)`` style size expression.

    Returns True when ``line`` contains a ``sizeof`` that is multiplied by an
    identifier, a number or a parenthesized expression, on either side.
    """
    multiplication_shapes = (
        r'\w+\s*\*\s*sizeof\s*\(',                 # count * sizeof(
        r'sizeof\s*\([^)]+\)\s*\*\s*\w+',          # sizeof(...) * count
        r'\(\s*\w+[^)]*\)\s*\*\s*sizeof',          # (count) * sizeof
        r'sizeof\s*\([^)]+\)\s*\*\s*\(',           # sizeof(...) * (count)
        r'sizeof\s*\([^)]+\)\s*\*\s*\d+',          # sizeof(...) * number
        r'\d+\s*\*\s*sizeof\s*\(',                 # number * sizeof(
        r'\b\w+\s*\*\s*sizeof\s*\(\s*\w+\s*\)',    # lenq * sizeof(bool), nparts * sizeof(Type)
        r'\([^)]*\w+[^)]*\)\s*\*\s*sizeof',        # any parenthesized expression * sizeof
    )
    return any(re.search(shape, line) is not None for shape in multiplication_shapes)

def is_single_object_allocation(line: str) -> bool:
    """Heuristically detect a plain ``sizeof(...)`` size expression.

    Returns True when ``line`` has a ``sizeof(...)`` that is not directly
    followed by ``*``, or any ``sizeof(*name)`` form.
    """
    unmultiplied_sizeof = r'sizeof\s*\([^)]+\)(?!\s*\*)'  # sizeof(Type) not followed by '*'
    dereferenced_sizeof = r'sizeof\s*\(\s*\*\s*\w+\s*\)'  # sizeof(*ptr)
    found = re.search(unmultiplied_sizeof, line) or re.search(dereferenced_sizeof, line)
    return found is not None

def verify_pair(deletion: Optional[PallocLine], insertion: Optional[PallocLine]) -> Tuple[bool, str]:
    """Judge one (deletion, insertion) pair and describe the outcome.

    Checks, in order: that both sides exist, that zero-initialization is the
    same before and after, and that an array-shaped or single-object-shaped
    size expression on the removed line was not converted to the opposite
    ``_object`` / ``_array`` variant.

    Returns:
        ``(is_valid, message)``. A lone insertion is valid; an empty pair or
        a lone deletion is not.
    """
    if deletion is None:
        if insertion is None:
            return False, " Empty pair"
        return True, f"  New palloc addition: {insertion.palloc_type}"
    if insertion is None:
        return False, f" Unmatched deletion: {deletion.palloc_type} (no corresponding insertion found)"

    old_type = deletion.palloc_type
    new_type = insertion.palloc_type
    arrow = f"{old_type} -> {new_type}"

    # Zero-initialization must be the same on both sides.
    was_zeroed = is_zero_initialized(old_type)
    now_zeroed = is_zero_initialized(new_type)
    if was_zeroed and not now_zeroed:
        return False, f" ZERO-INIT LOST: {arrow}"
    if now_zeroed and not was_zeroed:
        return False, f" UNEXPECTED ZERO-INIT: {arrow}"

    # Classify the removed line. An array shape wins over an object shape,
    # because array expressions contain a sizeof as well.
    looks_like_array = is_array_allocation(deletion.content)
    looks_like_object = (not looks_like_array) and is_single_object_allocation(deletion.content)

    to_array = new_type.endswith('_array')
    to_object = new_type.endswith('_object')

    if looks_like_array and to_object:
        return False, f" ARRAY->OBJECT MISMATCH: Array allocation converted to _object: {arrow}"
    if looks_like_object and to_array:
        return False, f" OBJECT->ARRAY MISMATCH: Single object allocation converted to _array: {arrow}"

    # All checks passed; tag the message with the kind of conversion.
    if looks_like_array:
        suffix = " (ARRAY)" if to_array else ""
    elif looks_like_object:
        suffix = " (OBJECT)" if to_object else ""
    else:
        suffix = " (COMPLEX)"

    return True, f"CORRECT: {arrow}{suffix}"

def main():
    """Command-line entry point: verify the palloc pairs of one patch file.

    Expects exactly one argument, the patch path. Prints a per-pair report,
    a summary and a conversion breakdown. Exits with status 0 when every
    pair is valid and 1 on a usage error or when any pair is invalid;
    returns normally when the patch contains no palloc lines at all.
    """
    if len(sys.argv) != 2:
        print("Usage: python verify_palloc_pairs.py <patch_file>")
        sys.exit(1)

    patch_file = sys.argv[1]
    rule = "=" * 80
    print(f"Analyzing palloc pairs in patch file: {patch_file}")
    print(rule)

    pairs = parse_patch_file(patch_file)
    if not pairs:
        print(" No palloc conversions found in the patch file.")
        return

    print(f"Found {len(pairs)} palloc conversion pairs to verify:")
    print()

    # Per-pair report
    valid_count = 0
    issues = []
    for index, (deletion, insertion) in enumerate(pairs, 1):
        is_valid, message = verify_pair(deletion, insertion)
        print(f"{index:3d}. {message}")

        if deletion:
            print(f"     File: {deletion.file_path}")
            print(f"     Del:  {deletion.content.strip()}")
        if insertion:
            # Repeat the file name only when it was not already shown above.
            if not deletion or deletion.file_path != insertion.file_path:
                print(f"     File: {insertion.file_path}")
            print(f"     Add:  {insertion.content.strip()}")

        if is_valid:
            valid_count += 1
        else:
            issues.append((deletion, insertion, message))
        print()

    invalid_count = len(issues)

    print(rule)
    print(" SUMMARY:")
    print(f"    Valid pairs: {valid_count}")
    print(f"    Invalid pairs: {invalid_count}")
    print(f"    Total pairs: {len(pairs)}")

    # Breakdown over pairs that have both a deletion and an insertion
    converted = [added for removed, added in pairs if removed and added]
    array_conversions = sum(1 for added in converted if added.palloc_type.endswith('_array'))
    object_conversions = sum(1 for added in converted if added.palloc_type.endswith('_object'))
    complex_conversions = len(converted) - array_conversions - object_conversions
    zero_init_conversions = sum(1 for added in converted if is_zero_initialized(added.palloc_type))
    non_zero_conversions = len(converted) - zero_init_conversions

    print()
    print("CONVERSION BREAKDOWN:")
    print(f"   Array conversions (_array): {array_conversions}")
    print(f"   Object conversions (_object): {object_conversions}")
    print(f"   Complex conversions: {complex_conversions}")
    print(f"   Zero-initialized: {zero_init_conversions}")
    print(f"   Non-zero: {non_zero_conversions}")

    print()
    if not issues:
        print("SUCCESS: All palloc conversion pairs are semantically correct!")
        print("   Zero-initialization behavior preserved")
        print("   Array vs Object semantics correct")
        sys.exit(0)

    print("ISSUES FOUND:")
    for deletion, insertion, message in issues:
        print(f"   • {message}")
        # Point at the deletion when there is one, otherwise at the insertion.
        located = deletion if deletion else insertion
        if located:
            print(f"     File: {located.file_path} (line {located.line_num})")
    print()
    print(" FAILED: Some conversion pairs have semantic issues.")
    sys.exit(1)

# Run the checker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
