import math
from typing import Callable, Any, Tuple, List, Dict

class RecursiveRefinementEngine:
    """
    A core engine for Queen AIVA's recursive refinement loop, enabling infinite depth reasoning.
    This engine iteratively refines a given problem or query, aiming to improve answer quality
    until convergence or diminishing returns are detected.

    Improvement between two iterations is measured *relatively* as
    ``(new_quality - previous_quality) / previous_quality`` (with an absolute fallback when the
    previous quality is not positive; see ``_calculate_quality_delta``). Every threshold below
    is compared against that delta.
    """

    def __init__(
        self,
        refinement_model: Callable[[Any, int, float], Tuple[Any, float]],
        max_iterations: int = 20,
        convergence_threshold: float = 0.005,
        diminishing_returns_threshold: float = 0.001,
        min_iterations_for_convergence: int = 3,
        plateau_window: int = 3,
        degradation_threshold: float = 0.1,
    ):
        """
        Initializes the RecursiveRefinementEngine.

        Args:
            refinement_model (Callable[[Any, int, float], Tuple[Any, float]]):
                A callable representing AIVA's refinement capability. It takes:
                - current_state (Any): The current refined answer or state.
                - iteration (int): The current iteration number (starting at 1).
                - current_quality (float): The quality score of the current_state.
                It must return a tuple:
                - new_state (Any): The refined answer/state for the next iteration.
                - new_quality_score (float): The quality score of the new_state.
                Negative scores are clamped to 0.0; non-finite scores (NaN/inf) stop the run.
            max_iterations (int): Maximum number of refinement iterations to prevent infinite loops.
            convergence_threshold (float): If the relative improvement in quality is at most this
                value for 'plateau_window' consecutive iterations, the run is considered converged.
            diminishing_returns_threshold (float): The same plateau test with a stricter (smaller)
                threshold. It is checked before convergence, so a flat enough plateau is reported
                as diminishing returns.
            min_iterations_for_convergence (int): Iteration number from which convergence or
                diminishing returns can be triggered. Degradation stops are only considered
                strictly after this many iterations, allowing early exploration.
            plateau_window (int): The number of consecutive low-improvement iterations that form a
                plateau (convergence or diminishing returns). Must be at least 1.
            degradation_threshold (float): Relative quality drop (0.1 == 10%) beyond which the run
                is stopped as degrading. Must be non-negative.

        Raises:
            ValueError: If refinement_model is not callable; if max_iterations or either plateau
                threshold is not positive; if diminishing_returns_threshold is not smaller than
                convergence_threshold; if plateau_window < 1; or if degradation_threshold < 0.
        """
        if not callable(refinement_model):
            raise ValueError("refinement_model must be a callable function or object.")
        if max_iterations <= 0 or convergence_threshold <= 0 or diminishing_returns_threshold <= 0:
            raise ValueError("Thresholds and max_iterations must be positive.")
        if diminishing_returns_threshold >= convergence_threshold:
            raise ValueError("diminishing_returns_threshold should typically be stricter (smaller) than convergence_threshold.")
        # A window of 0 would make _check_plateau vacuously true and stop the run immediately.
        if plateau_window < 1:
            raise ValueError("plateau_window must be at least 1.")
        if degradation_threshold < 0:
            raise ValueError("degradation_threshold must be non-negative.")

        self.refinement_model = refinement_model
        self.max_iterations = max_iterations
        self.convergence_threshold = convergence_threshold
        self.diminishing_returns_threshold = diminishing_returns_threshold
        self.min_iterations_for_convergence = min_iterations_for_convergence
        self.plateau_window = plateau_window
        self.degradation_threshold = degradation_threshold

    def _calculate_quality_delta(self, current_quality: float, previous_quality: float) -> float:
        """
        Return the relative improvement from previous_quality to current_quality.

        When previous_quality is zero or negative, a relative change is undefined (division by
        zero), so current_quality itself is returned as an absolute delta. Any positive quality
        after a zero baseline therefore counts as a significant improvement.
        """
        if previous_quality <= 0:
            return current_quality
        return (current_quality - previous_quality) / previous_quality

    def _check_plateau(self, history: List[Dict[str, Any]], threshold: float) -> bool:
        """
        Return True if each of the last 'plateau_window' deltas in history is <= threshold.

        Comparing 'plateau_window' deltas needs 'plateau_window + 1' history entries. With fewer
        entries, there is not enough evidence and False is returned.
        """
        window = self.plateau_window
        if len(history) < window + 1:
            return False

        recent = history[-(window + 1):]
        return all(
            self._calculate_quality_delta(curr["quality"], prev["quality"]) <= threshold
            for prev, curr in zip(recent, recent[1:])
        )

    def run(self, initial_problem: Any, initial_quality: float = 0.0) -> Dict[str, Any]:
        """
        Executes the recursive refinement loop for a given initial problem.

        Args:
            initial_problem (Any): The initial state or query to be refined.
            initial_quality (float): The initial quality score of the problem (defaults to 0.0).

        Returns:
            Dict[str, Any]: A dictionary with keys "final_state", "final_quality",
                "iterations_run" (number of completed refinements), "stop_reason", and
                "history". "history" is a list of per-iteration records, with iteration 0
                being the initial input.
        """
        current_state = initial_problem
        current_quality = initial_quality
        stop_reason = "Max iterations reached"

        # Iteration 0 is the starting point; it has no predecessor to compare against.
        history: List[Dict[str, Any]] = [{
            "iteration": 0,
            "state": current_state,
            "quality": current_quality,
            "delta_from_previous": 0.0,
        }]

        for i in range(1, self.max_iterations + 1):
            try:
                # AIVA's core reasoning step: refine the current state
                new_state, new_quality = self.refinement_model(current_state, i, current_quality)
            except Exception as e:  # deliberate boundary: report model failures via stop_reason
                stop_reason = f"Refinement model error at iteration {i}: {e}"
                break

            # NaN compares False against every threshold, which the plateau test would misread
            # as "no significant improvement"; inf makes later relative deltas NaN.
            if not math.isfinite(new_quality):
                stop_reason = f"Refinement model returned non-finite quality at iteration {i}: {new_quality}"
                break

            # Quality scores are defined as non-negative.
            if new_quality < 0:
                new_quality = 0.0

            delta_from_previous = self._calculate_quality_delta(new_quality, current_quality)
            history.append({
                "iteration": i,
                "state": new_state,
                "quality": new_quality,
                "delta_from_previous": delta_from_previous,
            })

            # Update for next iteration
            current_state = new_state
            current_quality = new_quality

            # A sharp relative drop suggests a divergent path. Early iterations are exempt so
            # the model may explore.
            if (
                i > self.min_iterations_for_convergence
                and delta_from_previous < -self.degradation_threshold
            ):
                stop_reason = f"Significant quality degradation detected at iteration {i}"
                break

            if i >= self.min_iterations_for_convergence:
                # Diminishing returns is the stricter test, so it is checked first; otherwise
                # every diminishing-returns plateau would be reported as plain convergence.
                if self._check_plateau(history, self.diminishing_returns_threshold):
                    stop_reason = "Diminishing returns detected"
                    break

                if self._check_plateau(history, self.convergence_threshold):
                    stop_reason = "Convergence detected"
                    break

        return {
            "final_state": current_state,
            "final_quality": current_quality,
            "iterations_run": len(history) - 1,
            "stop_reason": stop_reason,
            "history": history
        }

# --- Example Usage (for testing and demonstration) ---
if __name__ == "__main__":
    # The mock models add small random noise. The `math` module has no random(), so use the
    # `random` module, seeded so the demonstration output is reproducible.
    import random

    random.seed(42)

    print("\n--- Simulating Queen AIVA's Recursive Refinement Loop ---")

    # A placeholder/mock refinement model for demonstration.
    # In AIVA's true form, this would be a sophisticated reasoning module.
    def mock_aiva_refinement_model(current_state: str, iteration: int, current_quality: float) -> Tuple[str, float]:
        # Simulate improving the answer and its quality
        new_state = f"Refinement {iteration}: {current_state.split(': ', 1)[-1]} -> Deeper insight into {iteration}"

        # Simulate quality improvement, with a tendency to plateau
        # Using a sigmoid-like curve for quality growth: fast initially, then slower
        quality_gain_factor = math.exp(-0.1 * (iteration - 1)) # Decreases as iteration increases
        # Ensure we don't just add, but use current quality as a base
        new_quality = current_quality + (1.0 / (1 + math.exp(-0.5 * iteration))) * quality_gain_factor * 10

        # Introduce some noise/variation to make detection more realistic
        new_quality += (random.random() - 0.5) * 0.1 # Small random fluctuation

        return new_state, max(0.0, new_quality) # Quality cannot be negative

    # Initialize the engine
    engine = RecursiveRefinementEngine(
        refinement_model=mock_aiva_refinement_model,
        max_iterations=30,
        convergence_threshold=0.01,         # 1% improvement considered significant
        diminishing_returns_threshold=0.002, # 0.2% improvement triggers diminishing returns
        min_iterations_for_convergence=5,   # Don't stop too early
        plateau_window=4                    # Need 4 consecutive low-improvement iterations
    )

    initial_problem_statement = "Initial query: Understand the nature of consciousness"
    print(f"Initial Problem: {initial_problem_statement}")

    # Run the refinement process
    results = engine.run(initial_problem_statement, initial_quality=1.0)

    print("\n--- Refinement History ---")
    for entry in results['history']:
        print(f"Iter {entry['iteration']:2d}: Quality={entry['quality']:.4f}, Delta={entry['delta_from_previous']:.4f}, State='{entry['state'][:80]}...'" )

    print("\n--- Refinement Summary ---")
    print(f"Final State: {results['final_state'][:120]}...")
    print(f"Final Quality: {results['final_quality']:.4f}")
    print(f"Iterations Run: {results['iterations_run']}")
    print(f"Stop Reason: {results['stop_reason']}")

    print("\n--- Verification of Acceptance Criteria (Simulated) ---")
    # 1. Answer quality improves with iterations
    initial_q = results['history'][0]['quality']
    final_q = results['final_quality']
    if final_q > initial_q:
        print(f"[MET] Answer quality improved from {initial_q:.4f} to {final_q:.4f}.")
    else:
        print(f"[FAILED] Answer quality did not improve or decreased from {initial_q:.4f} to {final_q:.4f}.")

    # 2. Convergence detection
    if "Convergence detected" in results['stop_reason'] or "Diminishing returns detected" in results['stop_reason']:
        print("[MET] Convergence/Diminishing returns detection was triggered.")
    else:
        print("[FAILED] Convergence/Diminishing returns detection was NOT triggered (stopped by max iterations or error).")

    # 3. Diminishing returns trigger
    if "Diminishing returns detected" in results['stop_reason']:
        print("[MET] Diminishing returns trigger was activated.")
    else:
        print("[NOT MET/N/A] Diminishing returns trigger was not the primary stop reason (might have converged first or hit max iterations).")

    print("\n--- Second Example: Faster Convergence ---")
    def fast_converge_model(current_state: str, iteration: int, current_quality: float) -> Tuple[str, float]:
        new_state = f"Fast Refinement {iteration}: {current_state.split(': ', 1)[-1]} -> Resolved aspect {iteration}"
        # Gains shrink geometrically (5 * e^(-0.5 * iteration)), so the relative improvement
        # falls below convergence_threshold well within the 15-iteration budget.
        new_quality = current_quality + 5.0 * math.exp(-0.5 * iteration)
        new_quality += (random.random() - 0.5) * 0.05 # Smaller noise
        return new_state, max(0.0, new_quality)

    engine_fast = RecursiveRefinementEngine(
        refinement_model=fast_converge_model,
        max_iterations=15,
        convergence_threshold=0.02,
        diminishing_returns_threshold=0.01,
        min_iterations_for_convergence=2,
        plateau_window=2
    )
    results_fast = engine_fast.run("Initial problem for fast convergence", initial_quality=10.0)
    print(f"Fast Convergence Stop Reason: {results_fast['stop_reason']}")
    print(f"Fast Convergence Final Quality: {results_fast['final_quality']:.4f}")
    print(f"Fast Convergence Iterations: {results_fast['iterations_run']}")
    if "Convergence detected" in results_fast['stop_reason'] or "Diminishing returns detected" in results_fast['stop_reason']:
        print("[MET] Fast convergence scenario triggered detection.")
    else:
        print("[FAILED] Fast convergence scenario did not trigger detection.")

    print("\n--- Third Example: Max Iterations Reached ---")
    def slow_progress_model(current_state: str, iteration: int, current_quality: float) -> Tuple[str, float]:
        new_state = f"Slow Refinement {iteration}: {current_state.split(': ', 1)[-1]} -> Minor insight {iteration}"
        new_quality = current_quality + 0.01 + (random.random() - 0.5) * 0.005 # Very slow, consistent gain
        return new_state, max(0.0, new_quality)

    engine_slow = RecursiveRefinementEngine(
        refinement_model=slow_progress_model,
        max_iterations=10,
        # A plateau is declared when the relative gain is AT OR BELOW the threshold. Both
        # thresholds therefore sit under this model's steady ~0.7-1% per-iteration gain, so the
        # run never plateaus and must stop at max_iterations.
        convergence_threshold=0.005,
        diminishing_returns_threshold=0.001,
        min_iterations_for_convergence=2,
        plateau_window=2
    )
    results_slow = engine_slow.run("Initial problem for max iterations", initial_quality=1.0)
    print(f"Max Iterations Stop Reason: {results_slow['stop_reason']}")
    print(f"Max Iterations Final Quality: {results_slow['final_quality']:.4f}")
    print(f"Max Iterations Iterations: {results_slow['iterations_run']}")
    if "Max iterations reached" in results_slow['stop_reason']:
        print("[MET] Max iterations correctly stopped the process.")
    else:
        print("[FAILED] Max iterations did not stop the process when expected.")
