"""
Neural 02: Chain-of-Thought Reasoning System for AIVA Queen

A comprehensive reasoning engine implementing:
1. ChainOfThought - Generate step-by-step reasoning traces
2. ThoughtDecomposer - Break complex problems into atomic steps
3. IntermediateVerifier - Verify correctness at each reasoning step
4. ReasoningTree - Tree-structured reasoning with branching paths
5. SelfConsistency - Sample multiple chains and aggregate
6. FinalAggregator - Combine reasoning paths into final answer

This module provides advanced cognitive capabilities for the AIVA Queen
architecture, enabling transparent, verifiable, and self-correcting reasoning.

Author: Genesis System
Version: 1.0.0
"""

import logging
import json
import hashlib
import time
import random
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, asdict
from typing import (
    Any, Callable, Dict, List, Optional, Tuple,
    Union, Set, TypeVar, Generic, Iterator
)
from enum import Enum, auto
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
import uuid

# Configure logging for this module.
# NOTE(review): logging.basicConfig at import time mutates the root logger's
# global configuration — fine for an application entry point, surprising if
# this module is imported as a library; confirm intended usage.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# =============================================================================
# ENUMS AND TYPE DEFINITIONS
# =============================================================================

class ThoughtType(Enum):
    """Classification of thought types in the reasoning chain.

    Member NAMES (not the auto() values) are serialized by
    ``Thought.to_dict``/``from_dict``, so renaming a member is a breaking
    change for previously stored chains.
    """
    OBSERVATION = auto()      # Direct observation from input
    INFERENCE = auto()        # Logical inference from prior thoughts
    HYPOTHESIS = auto()       # Tentative conclusion requiring verification
    CALCULATION = auto()      # Mathematical or computational step
    RETRIEVAL = auto()        # Information retrieved from memory/knowledge
    DECOMPOSITION = auto()    # Breaking down into sub-problems
    SYNTHESIS = auto()        # Combining multiple conclusions
    VERIFICATION = auto()     # Checking validity of prior step
    CORRECTION = auto()       # Fixing an error in reasoning
    CONCLUSION = auto()       # Final answer or result; terminates a chain


class VerificationStatus(Enum):
    """Status of intermediate verification.

    Serialized by name in ``Thought.to_dict``; ``ReasoningChain.__str__``
    renders each status as a one-character icon.
    """
    UNVERIFIED = auto()           # No verifier has examined the thought yet
    VERIFIED_VALID = auto()       # Verifier accepted the step
    VERIFIED_INVALID = auto()     # Verifier rejected the step (may trigger backtracking)
    VERIFICATION_FAILED = auto()  # presumably: verification itself errored — not set in this chunk
    SKIPPED = auto()              # presumably: verification deliberately bypassed — not set in this chunk


class AggregationStrategy(Enum):
    """Strategy for aggregating multiple reasoning paths.

    NOTE(review): per the module docstring this is consumed by the
    SelfConsistency/FinalAggregator components; no usage is visible in
    this chunk.
    """
    MAJORITY_VOTE = auto()
    WEIGHTED_CONFIDENCE = auto()
    BEST_PATH = auto()
    ENSEMBLE = auto()
    CONSENSUS = auto()


# Generic type variable; not referenced by any definition visible in this
# chunk (presumably reserved for generic helpers elsewhere in the file).
T = TypeVar('T')


# =============================================================================
# DATA CLASSES
# =============================================================================

@dataclass
class Thought:
    """
    One atomic step in a chain-of-thought trace.

    Attributes:
        id: Short unique identifier (first 8 hex chars of a UUID4).
        content: Free-text reasoning content for this step.
        thought_type: Category of the step (observation, inference, ...).
        confidence: Confidence score in [0, 1].
        parent_ids: IDs of the thoughts this step builds on.
        metadata: Arbitrary extra context for the step.
        verification_status: Outcome of intermediate verification, if any.
        verification_details: Explanation produced by the verifier.
        timestamp: Unix time at which the thought was created.
    """
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    content: str = ""
    thought_type: ThoughtType = ThoughtType.INFERENCE
    confidence: float = 1.0
    parent_ids: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    verification_status: VerificationStatus = VerificationStatus.UNVERIFIED
    verification_details: Optional[str] = None
    timestamp: float = field(default_factory=time.time)

    def __hash__(self):
        # Identity is the short UUID, so thoughts can live in sets/dicts.
        return hash(self.id)

    def __eq__(self, other):
        # Equal iff the other object is a Thought sharing this id; any other
        # type compares unequal (mirrors the hash contract above).
        return isinstance(other, Thought) and self.id == other.id

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enums become their names)."""
        payload = dict(
            id=self.id,
            content=self.content,
            thought_type=self.thought_type.name,
            confidence=self.confidence,
            parent_ids=self.parent_ids,
            metadata=self.metadata,
            verification_status=self.verification_status.name,
            verification_details=self.verification_details,
            timestamp=self.timestamp,
        )
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Thought':
        """Rebuild a Thought from the dict produced by :meth:`to_dict`."""
        type_name = data.get("thought_type", "INFERENCE")
        status_name = data.get("verification_status", "UNVERIFIED")
        return cls(
            id=data.get("id", str(uuid.uuid4())[:8]),
            content=data.get("content", ""),
            thought_type=ThoughtType[type_name],
            confidence=data.get("confidence", 1.0),
            parent_ids=data.get("parent_ids", []),
            metadata=data.get("metadata", {}),
            verification_status=VerificationStatus[status_name],
            verification_details=data.get("verification_details"),
            timestamp=data.get("timestamp", time.time()),
        )


@dataclass
class ReasoningChain:
    """
    An ordered sequence of thoughts answering a single problem.

    Attributes:
        id: Short unique identifier for the chain.
        problem: The original problem or query text.
        thoughts: Thoughts in generation order.
        final_answer: Concluded answer, once one is reached.
        total_confidence: Geometric mean of per-step confidences.
        metadata: Chain-level context (e.g. the reasoning context string).
        created_at: Unix time at which the chain was created.
    """
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    problem: str = ""
    thoughts: List[Thought] = field(default_factory=list)
    final_answer: Optional[str] = None
    total_confidence: float = 0.0
    metadata: Dict[str, Any] = field(default_factory=dict)
    created_at: float = field(default_factory=time.time)

    def add_thought(self, thought: Thought) -> None:
        """Append a thought and refresh the aggregate confidence."""
        self.thoughts.append(thought)
        self._recalculate_confidence()

    def _recalculate_confidence(self) -> None:
        """Set total_confidence to the geometric mean of step confidences."""
        count = len(self.thoughts)
        if count == 0:
            self.total_confidence = 0.0
            return
        # Geometric mean penalizes one weak step more than an average would.
        running = 1.0
        for step in self.thoughts:
            running *= step.confidence
        self.total_confidence = running ** (1 / count)

    def get_thought_by_id(self, thought_id: str) -> Optional[Thought]:
        """Return the thought with the given id, or None if absent."""
        return next((t for t in self.thoughts if t.id == thought_id), None)

    def get_verification_summary(self) -> Dict[str, int]:
        """Count thoughts per verification-status name."""
        summary: Dict[str, int] = {}
        for t in self.thoughts:
            name = t.verification_status.name
            summary[name] = summary.get(name, 0) + 1
        return summary

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the chain (and all of its thoughts) to a plain dict."""
        return dict(
            id=self.id,
            problem=self.problem,
            thoughts=[t.to_dict() for t in self.thoughts],
            final_answer=self.final_answer,
            total_confidence=self.total_confidence,
            metadata=self.metadata,
            created_at=self.created_at,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningChain':
        """Rebuild a chain from :meth:`to_dict` output.

        The stored total_confidence is kept verbatim rather than recomputed.
        """
        chain = cls(
            id=data.get("id", str(uuid.uuid4())[:8]),
            problem=data.get("problem", ""),
            final_answer=data.get("final_answer"),
            total_confidence=data.get("total_confidence", 0.0),
            metadata=data.get("metadata", {}),
            created_at=data.get("created_at", time.time()),
        )
        chain.thoughts = [Thought.from_dict(t) for t in data.get("thoughts", [])]
        return chain

    def __str__(self) -> str:
        """Multi-line, human-readable rendering of the chain."""
        icons = {
            VerificationStatus.UNVERIFIED: "?",
            VerificationStatus.VERIFIED_VALID: "V",
            VerificationStatus.VERIFIED_INVALID: "X",
            VerificationStatus.VERIFICATION_FAILED: "!",
            VerificationStatus.SKIPPED: "-",
        }
        divider = "-" * 40
        out = [
            f"=== Reasoning Chain [{self.id}] ===",
            f"Problem: {self.problem}",
            f"Total Confidence: {self.total_confidence:.2%}",
            divider,
        ]
        for index, step in enumerate(self.thoughts, 1):
            icon = icons.get(step.verification_status, "?")
            out.append(f"[{index}] [{icon}] {step.thought_type.name}")
            out.append(f"    {step.content}")
            out.append(f"    Confidence: {step.confidence:.2%}")
        out.append(divider)
        out.append(f"Final Answer: {self.final_answer}")
        return "\n".join(out)


@dataclass
class TreeNode:
    """
    Node in a branching reasoning tree.

    Attributes:
        thought: The thought held at this node.
        children: Alternative or subsequent reasoning branches.
        parent: The node this one branched from (None at the root).
        depth: Distance from the root (root is 0).
        path_confidence: Product of confidences from the root to here.
    """
    thought: Thought
    children: List['TreeNode'] = field(default_factory=list)
    parent: Optional['TreeNode'] = None
    depth: int = 0
    path_confidence: float = 1.0

    def add_child(self, thought: Thought) -> 'TreeNode':
        """Attach and return a new child node holding *thought*."""
        node = TreeNode(
            thought=thought,
            parent=self,
            depth=self.depth + 1,
            path_confidence=self.path_confidence * thought.confidence,
        )
        self.children.append(node)
        return node

    def is_leaf(self) -> bool:
        """True when this node has no children."""
        return not self.children

    def get_path_to_root(self) -> List['TreeNode']:
        """Return the nodes from the root down to this node, inclusive."""
        path: List['TreeNode'] = []
        node: Optional['TreeNode'] = self
        while node is not None:
            path.insert(0, node)  # prepend so the root ends up first
            node = node.parent
        return path

    def get_path_thoughts(self) -> List[Thought]:
        """Thoughts along the root-to-here path, root first."""
        return [n.thought for n in self.get_path_to_root()]


# =============================================================================
# ABSTRACT BASE CLASSES
# =============================================================================

class ThoughtGenerator(ABC):
    """Interface for strategies that produce the next reasoning step."""

    @abstractmethod
    def generate(self, context: str, prior_thoughts: List[Thought]) -> Thought:
        """Return the next thought, given the context and the trace so far."""
        ...


class Verifier(ABC):
    """Interface for checking the validity of a single thought."""

    @abstractmethod
    def verify(self, thought: Thought, context: str) -> Tuple[VerificationStatus, str]:
        """Return a (status, explanation) pair for *thought* in *context*."""
        ...


class Decomposer(ABC):
    """Interface for splitting a problem into smaller sub-problems."""

    @abstractmethod
    def decompose(self, problem: str) -> List[str]:
        """Return the list of sub-problems extracted from *problem*."""
        ...


# =============================================================================
# CHAIN OF THOUGHT GENERATOR
# =============================================================================

class ChainOfThought:
    """
    Core chain-of-thought reasoning engine.

    Generates step-by-step reasoning traces for complex problems, with
    pluggable generation strategies, optional per-step verification, and
    bounded backtracking (corrections) on verification failures.

    Attributes:
        thought_generator: Strategy used to produce each next thought.
        verifier: Optional verifier applied to every intermediate step.
        max_steps: Hard cap on the number of generated reasoning steps.
        confidence_threshold: Chain stops early if its aggregate confidence
            drops below this value.
        enable_backtracking: Whether rejected steps trigger a correction
            thought instead of being appended to the chain.
    """

    def __init__(
        self,
        thought_generator: Optional[ThoughtGenerator] = None,
        verifier: Optional[Verifier] = None,
        max_steps: int = 20,
        confidence_threshold: float = 0.3,
        enable_backtracking: bool = True,
        llm_callable: Optional[Callable[[str], str]] = None
    ):
        """
        Initialize the ChainOfThought engine.

        Args:
            thought_generator: Custom thought generator; when None, a
                DefaultThoughtGenerator built around ``llm_callable`` is used.
            verifier: Custom verifier for intermediate steps (skipped if None).
            max_steps: Maximum reasoning steps allowed per chain.
            confidence_threshold: Stop if chain confidence drops below this.
            enable_backtracking: Allow corrections on verification failures.
            llm_callable: Function mapping a prompt string to an LLM response.
        """
        self.thought_generator = thought_generator or DefaultThoughtGenerator(llm_callable)
        self.verifier = verifier
        self.max_steps = max_steps
        self.confidence_threshold = confidence_threshold
        self.enable_backtracking = enable_backtracking
        self.llm_callable = llm_callable

        # Completed chains, in the order they were produced.
        self._chain_history: List[ReasoningChain] = []

        logger.info(f"ChainOfThought initialized with max_steps={max_steps}")

    def reason(
        self,
        problem: str,
        context: Optional[str] = None,
        initial_thoughts: Optional[List[Thought]] = None
    ) -> ReasoningChain:
        """
        Generate a complete chain of reasoning for a problem.

        Args:
            problem: The problem/query to reason about.
            context: Optional additional context prepended to the prompt.
            initial_thoughts: Optional seed thoughts added before the
                automatic observation step.

        Returns:
            Complete ReasoningChain with all steps and a final answer
            (synthesized from the chain if no CONCLUSION step was produced).
        """
        logger.info(f"Starting reasoning for: {problem[:100]}...")

        chain = ReasoningChain(problem=problem)
        chain.metadata["context"] = context

        # Seed the chain with caller-provided thoughts, if any.
        if initial_thoughts:
            for thought in initial_thoughts:
                chain.add_thought(thought)

        # Always anchor the chain with an observation of the raw problem.
        observation = Thought(
            content=f"Problem to solve: {problem}",
            thought_type=ThoughtType.OBSERVATION,
            confidence=1.0,
            metadata={"source": "input"}
        )
        chain.add_thought(observation)

        # Generate steps until a conclusion, the step cap, or a confidence
        # collapse terminates the loop.
        full_context = f"{context or ''}\n\nProblem: {problem}"
        step_count = 0
        backtrack_count = 0
        max_backtracks = 3  # cap corrections so a strict verifier cannot loop forever

        while step_count < self.max_steps:
            step_count += 1

            # Generate next thought
            thought = self.thought_generator.generate(full_context, chain.thoughts)

            # A CONCLUSION thought terminates the chain immediately.
            if thought.thought_type == ThoughtType.CONCLUSION:
                chain.add_thought(thought)
                chain.final_answer = thought.content
                break

            # Verify if a verifier is available.
            if self.verifier:
                status, details = self.verifier.verify(thought, full_context)
                thought.verification_status = status
                thought.verification_details = details

                if status == VerificationStatus.VERIFIED_INVALID:
                    logger.warning(f"Step {step_count} verification failed: {details}")

                    if self.enable_backtracking and backtrack_count < max_backtracks:
                        backtrack_count += 1
                        # BUGFIX: the rejected thought is dropped (never added to
                        # the chain), so the correction must not reference it via
                        # parent_ids — that would leave a dangling id that
                        # get_thought_by_id cannot resolve. Anchor the correction
                        # to the last thought actually in the chain and keep the
                        # rejected step's id/content in metadata for auditing.
                        anchor_ids = [chain.thoughts[-1].id] if chain.thoughts else []
                        correction = Thought(
                            content=f"Correcting previous reasoning: {details}",
                            thought_type=ThoughtType.CORRECTION,
                            confidence=0.7,
                            parent_ids=anchor_ids,
                            metadata={
                                "correction_for": thought.id,
                                "rejected_content": thought.content,
                            }
                        )
                        chain.add_thought(correction)
                        continue

            chain.add_thought(thought)

            # Abort early if the chain's aggregate confidence collapses.
            if chain.total_confidence < self.confidence_threshold:
                logger.warning(f"Chain confidence {chain.total_confidence:.2%} below threshold")
                break

        # Guarantee a final answer even when no CONCLUSION step appeared.
        if chain.final_answer is None:
            chain.final_answer = self._synthesize_answer(chain)

        self._chain_history.append(chain)
        logger.info(f"Completed reasoning with {len(chain.thoughts)} steps")

        return chain

    def _synthesize_answer(self, chain: ReasoningChain) -> str:
        """Derive a best-effort answer when the chain never concluded.

        Prefers the highest-confidence SYNTHESIS/INFERENCE thought; falls
        back to joining the contents of the last (up to three) thoughts.
        """
        synthesis_thoughts = [
            t for t in chain.thoughts
            if t.thought_type in (ThoughtType.SYNTHESIS, ThoughtType.INFERENCE)
        ]

        if synthesis_thoughts:
            # Take the highest-confidence candidate.
            best = max(synthesis_thoughts, key=lambda t: t.confidence)
            return best.content

        # Fallback: stitch together the tail of the chain. A [-3:] slice is
        # already safe for chains shorter than three thoughts.
        recent = chain.thoughts[-3:]
        return " -> ".join([t.content for t in recent])

    def get_history(self) -> List[ReasoningChain]:
        """Return a shallow copy of all completed reasoning chains."""
        return self._chain_history.copy()

    def clear_history(self) -> None:
        """Forget all previously completed chains."""
        self._chain_history.clear()


class DefaultThoughtGenerator(ThoughtGenerator):
    """
    Fallback thought generator.

    Delegates to an LLM callable when one is supplied; otherwise (or when
    the LLM call raises) it follows a deterministic rule-based schedule of
    thought types driven by the length of the trace so far.
    """

    # Keyword pairs per thought type; not consulted by the methods below.
    REASONING_PATTERNS = [
        ("if", "then", ThoughtType.INFERENCE),
        ("because", "therefore", ThoughtType.INFERENCE),
        ("calculate", "result", ThoughtType.CALCULATION),
        ("observe", "notice", ThoughtType.OBSERVATION),
        ("assume", "suppose", ThoughtType.HYPOTHESIS),
        ("combine", "together", ThoughtType.SYNTHESIS),
        ("verify", "check", ThoughtType.VERIFICATION),
        ("conclude", "finally", ThoughtType.CONCLUSION),
    ]

    def __init__(self, llm_callable: Optional[Callable[[str], str]] = None):
        """
        Initialize the generator.

        Args:
            llm_callable: Optional prompt -> response function; when absent,
                generation is purely rule based.
        """
        self.llm_callable = llm_callable
        self._step_counter = 0  # total thoughts produced by this generator

    def generate(self, context: str, prior_thoughts: List[Thought]) -> Thought:
        """Produce the next reasoning step (LLM-backed when possible)."""
        self._step_counter += 1
        if not self.llm_callable:
            return self._generate_rule_based(context, prior_thoughts)
        return self._generate_with_llm(context, prior_thoughts)

    def _generate_with_llm(self, context: str, prior_thoughts: List[Thought]) -> Thought:
        """Ask the LLM for the next step and classify its response."""
        # Render the trace so far for the prompt.
        thought_trace = "\n".join(
            f"Step {i+1}: [{t.thought_type.name}] {t.content}"
            for i, t in enumerate(prior_thoughts)
        )

        prompt = f"""Given the following context and reasoning so far, generate the next logical reasoning step.

Context:
{context}

Reasoning so far:
{thought_trace}

Generate the next step in the reasoning process. If you have reached a conclusion, start with "CONCLUSION:".
Otherwise, explain your reasoning for this step.

Next step:"""

        try:
            response = self.llm_callable(prompt)

            # Classify the response by keyword heuristics.
            lowered = response.lower()
            thought_type = ThoughtType.INFERENCE
            confidence = 0.85

            if response.upper().startswith("CONCLUSION:"):
                thought_type = ThoughtType.CONCLUSION
                confidence = 0.9
                response = response[len("CONCLUSION:"):].strip()
            elif "calculate" in lowered or "=" in response:
                thought_type = ThoughtType.CALCULATION
            elif "assume" in lowered or "suppose" in lowered:
                thought_type = ThoughtType.HYPOTHESIS
                confidence = 0.7
            elif "verify" in lowered or "check" in lowered:
                thought_type = ThoughtType.VERIFICATION

            parents = [prior_thoughts[-1].id] if prior_thoughts else []

            return Thought(
                content=response,
                thought_type=thought_type,
                confidence=confidence,
                parent_ids=parents,
                metadata={"generation_method": "llm", "step": self._step_counter}
            )

        except Exception as e:
            # Degrade gracefully to the rule-based schedule.
            logger.error(f"LLM generation failed: {e}")
            return self._generate_rule_based(context, prior_thoughts)

    def _generate_rule_based(self, context: str, prior_thoughts: List[Thought]) -> Thought:
        """Produce the next step from a fixed schedule over trace length."""
        num_thoughts = len(prior_thoughts)

        # Schedule: observe -> decompose -> infer -> synthesize -> conclude,
        # chosen purely by how many thoughts already exist.
        if num_thoughts >= 8:
            thought_type = ThoughtType.CONCLUSION
            content = f"Based on the analysis, the conclusion is derived from steps 1-{num_thoughts}"
        elif num_thoughts >= 6:
            thought_type = ThoughtType.SYNTHESIS
            content = f"Synthesizing findings from prior steps (step {self._step_counter})"
        elif num_thoughts >= 3:
            thought_type = ThoughtType.INFERENCE
            content = f"Drawing inference from previous observations (step {self._step_counter})"
        elif num_thoughts >= 1:
            thought_type = ThoughtType.DECOMPOSITION
            content = f"Breaking down the problem into components (step {self._step_counter})"
        else:
            thought_type = ThoughtType.OBSERVATION
            content = "Initial analysis of the problem context"

        # Confidence decays with each generated step, floored at 0.5.
        confidence = max(0.5, 1.0 - (self._step_counter * 0.05))
        parents = [prior_thoughts[-1].id] if prior_thoughts else []

        return Thought(
            content=content,
            thought_type=thought_type,
            confidence=confidence,
            parent_ids=parents,
            metadata={"generation_method": "rule_based", "step": self._step_counter}
        )


# =============================================================================
# THOUGHT DECOMPOSER
# =============================================================================

class ThoughtDecomposer(Decomposer):
    """
    Decomposes complex problems into simpler sub-problems.

    Strategies:
    - SYNTACTIC: split on sentence boundaries and conjunctions
    - SEMANTIC: LLM-extracted meaning units (falls back to syntactic)
    - LOGICAL: regex extraction of premises and conclusions
    - HIERARCHICAL: recursive syntactic refinement (top-down)
    - HYBRID: union of the above, de-duplicated in a deterministic order

    Attributes:
        strategy: Decomposition strategy to use.
        max_depth: Maximum recursion depth.
        min_granularity: Minimum character length of sub-problems.
        llm_callable: Optional LLM (prompt -> response) for semantic mode.
    """

    class Strategy(Enum):
        """Available decomposition strategies."""
        SYNTACTIC = auto()
        SEMANTIC = auto()
        LOGICAL = auto()
        HIERARCHICAL = auto()
        HYBRID = auto()

    def __init__(
        self,
        strategy: Strategy = Strategy.HYBRID,
        max_depth: int = 5,
        min_granularity: int = 10,
        llm_callable: Optional[Callable[[str], str]] = None
    ):
        """
        Initialize the decomposer.

        Args:
            strategy: Which decomposition strategy to use.
            max_depth: Maximum recursion depth.
            min_granularity: Minimum character length for sub-problems;
                inputs at or below this length are returned unchanged.
            llm_callable: Optional LLM for advanced (semantic) decomposition.
        """
        self.strategy = strategy
        self.max_depth = max_depth
        self.min_granularity = min_granularity
        self.llm_callable = llm_callable

        logger.info(f"ThoughtDecomposer initialized with {strategy.name} strategy")

    def decompose(self, problem: str, depth: int = 0) -> List[str]:
        """
        Decompose a problem into sub-problems.

        Args:
            problem: The problem to decompose.
            depth: Current recursion depth (used by recursive strategies).

        Returns:
            List of sub-problems; ``[problem]`` when the input is already at
            or below min_granularity, or max_depth has been reached.
        """
        if depth >= self.max_depth or len(problem) <= self.min_granularity:
            return [problem]

        if self.strategy == self.Strategy.SYNTACTIC:
            return self._syntactic_decompose(problem)
        elif self.strategy == self.Strategy.SEMANTIC:
            return self._semantic_decompose(problem)
        elif self.strategy == self.Strategy.LOGICAL:
            return self._logical_decompose(problem)
        elif self.strategy == self.Strategy.HIERARCHICAL:
            return self._hierarchical_decompose(problem, depth)
        else:  # HYBRID
            return self._hybrid_decompose(problem, depth)

    def _syntactic_decompose(self, problem: str) -> List[str]:
        """Split on sentence boundaries and common conjunctions."""
        delimiters = ['. ', '? ', '! ', ' and ', ' but ', ' or ', '; ', ', which ', ', that ']

        # Apply each delimiter over the accumulated fragments in turn.
        parts = [problem]
        for delimiter in delimiters:
            new_parts = []
            for part in parts:
                splits = part.split(delimiter)
                new_parts.extend([s.strip() for s in splits if s.strip()])
            parts = new_parts

        # If nothing split, return the problem unchanged.
        return parts if len(parts) > 1 else [problem]

    def _semantic_decompose(self, problem: str) -> List[str]:
        """LLM-based decomposition into meaning units.

        Falls back to syntactic decomposition when no LLM is configured or
        the LLM call/parse fails.
        """
        if self.llm_callable:
            prompt = f"""Break down the following problem into distinct semantic components.
Each component should represent a single concept or question to address.

Problem: {problem}

List each component on a new line, numbered 1, 2, 3, etc.
Components:"""

            try:
                response = self.llm_callable(prompt)
                # Parse the numbered list the prompt asked for.
                components = []
                for line in response.split('\n'):
                    line = line.strip()
                    if line and line[0].isdigit():
                        # Strip the "1." / "1)" numbering prefix.
                        content = re.sub(r'^\d+[\.\)]\s*', '', line)
                        if content:
                            components.append(content)
                return components if components else [problem]
            except Exception as e:
                logger.error(f"Semantic decomposition failed: {e}")

        return self._syntactic_decompose(problem)

    def _logical_decompose(self, problem: str) -> List[str]:
        """Extract logical premises and conclusions via regex markers."""
        components = []

        # Premise markers: if, given, assuming, since, because.
        premise_patterns = [
            r'[Ii]f\s+([^,]+)',
            r'[Gg]iven\s+([^,]+)',
            r'[Aa]ssuming\s+([^,]+)',
            r'[Ss]ince\s+([^,]+)',
            r'[Bb]ecause\s+([^,]+)'
        ]

        for pattern in premise_patterns:
            matches = re.findall(pattern, problem)
            for match in matches:
                components.append(f"Premise: {match.strip()}")

        # Conclusion markers: then, therefore, hence, so, "what is".
        conclusion_patterns = [
            r'[Tt]hen\s+([^,\.]+)',
            r'[Tt]herefore\s+([^,\.]+)',
            r'[Hh]ence\s+([^,\.]+)',
            r'[Ss]o\s+([^,\.]+)',
            r'[Ww]hat\s+is\s+([^,\?]+)'
        ]

        for pattern in conclusion_patterns:
            matches = re.findall(pattern, problem)
            for match in matches:
                components.append(f"Conclusion to derive: {match.strip()}")

        if not components:
            components = [f"Analyze: {problem}"]

        return components

    def _hierarchical_decompose(self, problem: str, depth: int) -> List[str]:
        """Recursively refine syntactic fragments until granularity/depth limits."""
        initial_parts = self._syntactic_decompose(problem)

        if depth < self.max_depth - 1:
            # Recursively decompose each fragment that is still large enough.
            all_parts = []
            for part in initial_parts:
                if len(part) > self.min_granularity:
                    all_parts.extend(self._hierarchical_decompose(part, depth + 1))
                else:
                    all_parts.append(part)
            return all_parts

        return initial_parts

    def _hybrid_decompose(self, problem: str, depth: int) -> List[str]:
        """Combine syntactic, logical and (optionally) semantic strategies.

        BUGFIX: components are de-duplicated while preserving first-seen
        order, so the output is deterministic across runs. The previous
        implementation returned ``list(set)``, whose ordering depends on
        string-hash randomization (PYTHONHASHSEED).
        """
        ordered: Dict[str, None] = {}

        for component in self._syntactic_decompose(problem):
            ordered.setdefault(component, None)

        for component in self._logical_decompose(problem):
            ordered.setdefault(component, None)

        # Semantic decomposition only at the top level, and only with an LLM.
        if self.llm_callable and depth == 0:
            for component in self._semantic_decompose(problem):
                ordered.setdefault(component, None)

        return list(ordered) if ordered else [problem]

    def get_decomposition_tree(self, problem: str) -> Dict[str, Any]:
        """
        Generate a full decomposition tree.

        Returns a nested dictionary showing the hierarchical structure.
        Every node (including leaves) carries "text", "depth" and "children"
        keys, so consumers can traverse it with a uniform schema.
        """
        def build_tree(text: str, depth: int) -> Dict[str, Any]:
            if depth >= self.max_depth or len(text) <= self.min_granularity:
                # Leaf node; include "depth" for schema consistency.
                return {"text": text, "depth": depth, "children": []}

            children = self.decompose(text, depth)
            return {
                "text": text,
                "depth": depth,
                # Skip self-referential children to avoid infinite recursion
                # when decompose() returns the input unchanged.
                "children": [build_tree(child, depth + 1) for child in children if child != text]
            }

        return build_tree(problem, 0)


# =============================================================================
# INTERMEDIATE VERIFIER
# =============================================================================

class IntermediateVerifier(Verifier):
    """
    Gatekeeper that validates individual reasoning steps.

    Each thought passes through a fixed pipeline of checks:
    heuristic rules (absolute-claim detection, confidence bounds,
    dependency structure), any caller-registered verification
    functions, and an optional LLM-based review. Results may be
    memoized per (thought, context) pair.

    Attributes:
        verification_functions: Custom verification functions
        strict_mode: Whether to fail on any issue
        cache_verifications: Whether to cache results
        llm_callable: Optional LLM for complex verification
    """

    def __init__(
        self,
        verification_functions: Optional[List[Callable[[Thought, str], bool]]] = None,
        strict_mode: bool = False,
        cache_verifications: bool = True,
        llm_callable: Optional[Callable[[str], str]] = None
    ):
        """
        Initialize the verifier.

        Args:
            verification_functions: Predicates of (thought, context) that
                return False when the thought fails the check.
            strict_mode: Treat every reported issue as fatal.
            cache_verifications: Memoize results per thought/context pair.
            llm_callable: Optional prompt -> response function for review.
        """
        self.verification_functions = verification_functions or []
        self.strict_mode = strict_mode
        self.cache_verifications = cache_verifications
        self.llm_callable = llm_callable

        # Keyed by "<thought id>:<md5 prefix of context>".
        self._cache: Dict[str, Tuple[VerificationStatus, str]] = {}
        self._verification_stats = defaultdict(int)

        logger.info("IntermediateVerifier initialized")

    def verify(self, thought: Thought, context: str) -> Tuple[VerificationStatus, str]:
        """
        Verify a thought against context.

        Args:
            thought: The thought to verify
            context: The reasoning context

        Returns:
            Tuple of (verification status, explanation)
        """
        key = f"{thought.id}:{hashlib.md5(context.encode()).hexdigest()[:8]}"
        if self.cache_verifications and key in self._cache:
            self._verification_stats["cache_hits"] += 1
            return self._cache[key]

        self._verification_stats["total_verifications"] += 1

        # Accumulate issues from every stage, in a fixed order:
        # built-ins first, then custom functions, then the LLM pass.
        findings: List[str] = []
        findings += self._check_logical_consistency(thought, context)
        findings += self._check_confidence_validity(thought)
        findings += self._check_dependency_validity(thought)

        for check in self.verification_functions:
            try:
                passed = check(thought, context)
            except Exception as e:
                findings.append(f"Verification error in {check.__name__}: {e}")
            else:
                if not passed:
                    findings.append(f"Custom verification failed: {check.__name__}")

        if self.llm_callable:
            findings += self._llm_verify(thought, context)

        result = self._resolve_status(findings)

        if self.cache_verifications:
            self._cache[key] = result

        return result

    def _resolve_status(self, findings: List[str]) -> Tuple[VerificationStatus, str]:
        """Map the accumulated issue list to a final (status, details) pair."""
        if not findings:
            self._verification_stats["valid"] += 1
            return VerificationStatus.VERIFIED_VALID, "All verification checks passed"

        if self.strict_mode:
            self._verification_stats["invalid"] += 1
            return VerificationStatus.VERIFIED_INVALID, "; ".join(findings)

        # Lenient mode: only findings flagged critical/invalid are fatal.
        blockers = [f for f in findings if "critical" in f.lower() or "invalid" in f.lower()]
        if blockers:
            self._verification_stats["invalid"] += 1
            return VerificationStatus.VERIFIED_INVALID, "; ".join(blockers)

        self._verification_stats["valid_with_warnings"] += 1
        return VerificationStatus.VERIFIED_VALID, f"Passed with warnings: {'; '.join(findings)}"

    def _check_logical_consistency(self, thought: Thought, context: str) -> List[str]:
        """Heuristic scan for unsupported claims and circular reasoning."""
        text = thought.content.lower()

        # NOTE: genuine contradiction detection would require NLP; only
        # shallow lexical heuristics are applied here.
        findings = [
            f"Contains absolute term '{term}' - may need qualification"
            for term in ("always", "never", "all", "none", "impossible", "certain")
            if term in text
        ]

        # A derived thought whose text already appears verbatim in the
        # context is likely just restating a premise.
        if thought.parent_ids and thought.content in context:
            findings.append("Potential circular reasoning - conclusion restates premise")

        return findings

    def _check_confidence_validity(self, thought: Thought) -> List[str]:
        """Sanity-check the confidence score against the thought type."""
        findings = []

        if not 0 <= thought.confidence <= 1:
            findings.append(f"Invalid confidence {thought.confidence} - must be in [0,1]")

        # A hypothesis is by definition tentative.
        if thought.thought_type == ThoughtType.HYPOTHESIS and thought.confidence > 0.8:
            findings.append("Hypothesis has high confidence - consider reducing")

        # A near-certain conclusion must cite the steps it rests on.
        if (thought.thought_type == ThoughtType.CONCLUSION
                and thought.confidence > 0.95
                and not thought.parent_ids):
            findings.append("High-confidence conclusion without supporting steps")

        return findings

    def _check_dependency_validity(self, thought: Thought) -> List[str]:
        """Derived thought types must reference at least one parent."""
        derived = (ThoughtType.INFERENCE, ThoughtType.SYNTHESIS, ThoughtType.CONCLUSION)
        if thought.thought_type in derived and not thought.parent_ids:
            return [f"{thought.thought_type.name} thought has no parent references"]
        return []

    def _llm_verify(self, thought: Thought, context: str) -> List[str]:
        """Ask the configured LLM to critique the step; parse ISSUE: lines."""
        findings: List[str] = []

        prompt = f"""Verify the following reasoning step for logical validity and factual accuracy.

Context: {context}

Reasoning step: {thought.content}
Type: {thought.thought_type.name}
Confidence: {thought.confidence}

Identify any issues with:
1. Logical validity - does the reasoning follow?
2. Factual accuracy - are the claims supportable?
3. Completeness - are key considerations missing?

If valid, respond with "VALID".
If there are issues, list each issue on a new line starting with "ISSUE:".

Verification:"""

        try:
            reply = self.llm_callable(prompt)

            upper = reply.upper()
            if "VALID" in upper and "ISSUE:" not in upper:
                return []

            for line in reply.split('\n'):
                if line.strip().upper().startswith("ISSUE:"):
                    issue = line.split(":", 1)[1].strip()
                    if issue:
                        findings.append(issue)

        except Exception as e:
            logger.error(f"LLM verification failed: {e}")
            findings.append(f"LLM verification unavailable: {e}")

        return findings

    def get_stats(self) -> Dict[str, int]:
        """Return a plain-dict snapshot of the verification counters."""
        return dict(self._verification_stats)

    def clear_cache(self) -> None:
        """Drop all memoized verification results."""
        self._cache.clear()


# =============================================================================
# REASONING TREE
# =============================================================================

class ReasoningTree:
    """
    Tree-structured reasoning with branching paths.

    Enables exploration of multiple reasoning paths simultaneously,
    with pruning and selection of the most promising branches.

    Attributes:
        root: Root node of the tree
        thought_generator: Generator for creating thoughts
        verifier: Optional verifier for thoughts
        max_depth: Maximum tree depth
        max_branches: Maximum branches per node
        pruning_threshold: Confidence below which to prune
    """

    def __init__(
        self,
        thought_generator: Optional[ThoughtGenerator] = None,
        verifier: Optional[Verifier] = None,
        max_depth: int = 10,
        max_branches: int = 3,
        pruning_threshold: float = 0.3,
        llm_callable: Optional[Callable[[str], str]] = None
    ):
        """
        Initialize the reasoning tree.

        Args:
            thought_generator: Strategy for generating thoughts
            verifier: Verifier for intermediate steps
            max_depth: Maximum tree depth
            max_branches: Maximum branches at each node
            pruning_threshold: Prune branches below this confidence
            llm_callable: Optional LLM function
        """
        self.thought_generator = thought_generator or DefaultThoughtGenerator(llm_callable)
        self.verifier = verifier
        self.max_depth = max_depth
        self.max_branches = max_branches
        self.pruning_threshold = pruning_threshold
        self.llm_callable = llm_callable

        self.root: Optional[TreeNode] = None
        # Flat indexes over the tree, maintained by build_tree/prune.
        self._all_nodes: List[TreeNode] = []
        self._leaf_nodes: List[TreeNode] = []

        logger.info(f"ReasoningTree initialized with depth={max_depth}, branches={max_branches}")

    def build_tree(
        self,
        problem: str,
        context: Optional[str] = None,
        num_samples: int = 3
    ) -> TreeNode:
        """
        Build a reasoning tree for a problem.

        Expands level by level. Leaves on a low-confidence path and leaves
        holding a CONCLUSION thought are not expanded further; they remain
        in the tree as terminal nodes.

        Args:
            problem: The problem to reason about
            context: Optional additional context
            num_samples: Number of alternative thoughts per node

        Returns:
            Root node of the constructed tree
        """
        logger.info(f"Building reasoning tree for: {problem[:100]}...")

        # Create root node
        root_thought = Thought(
            content=f"Root: {problem}",
            thought_type=ThoughtType.OBSERVATION,
            confidence=1.0,
            metadata={"is_root": True}
        )
        self.root = TreeNode(thought=root_thought)
        self._all_nodes = [self.root]
        self._leaf_nodes = [self.root]

        full_context = f"{context or ''}\n\nProblem: {problem}"

        # Build tree level by level
        for depth in range(self.max_depth):
            if not self._leaf_nodes:
                break

            new_leaves = []

            for leaf in self._leaf_nodes:
                if leaf.path_confidence < self.pruning_threshold:
                    continue  # Prune low-confidence branches

                # A conclusion is terminal: do not reason past it.
                # (Previously conclusion leaves were re-expanded on later
                # iterations, contradicting the stop-on-conclusion intent.)
                if leaf.thought.thought_type == ThoughtType.CONCLUSION:
                    continue

                # Generate multiple alternative thoughts
                prior_thoughts = leaf.get_path_thoughts()
                branches_created = 0

                for _ in range(min(num_samples, self.max_branches)):
                    thought = self.thought_generator.generate(full_context, prior_thoughts)

                    # Verify if available
                    if self.verifier:
                        status, details = self.verifier.verify(thought, full_context)
                        thought.verification_status = status
                        thought.verification_details = details

                        if status == VerificationStatus.VERIFIED_INVALID:
                            thought.confidence *= 0.5  # Penalize invalid thoughts

                    # Add as child
                    child = leaf.add_child(thought)
                    self._all_nodes.append(child)
                    new_leaves.append(child)
                    branches_created += 1

                    # Stop branching on conclusions
                    if thought.thought_type == ThoughtType.CONCLUSION:
                        break

            self._leaf_nodes = new_leaves
            logger.debug(f"Depth {depth+1}: {len(new_leaves)} leaf nodes")

        logger.info(f"Tree built with {len(self._all_nodes)} total nodes")
        return self.root

    def get_best_path(self) -> List[Thought]:
        """
        Get the highest-confidence path through the tree.

        Prefers paths ending in a leaf or CONCLUSION node; returns just the
        root thought if the tree has no expandable nodes, or [] if empty.
        """
        if not self._all_nodes:
            return []

        # Candidates: terminal nodes, or any node holding a conclusion.
        leaves_with_conclusions = [
            n for n in self._all_nodes
            if n.is_leaf() or n.thought.thought_type == ThoughtType.CONCLUSION
        ]

        if not leaves_with_conclusions:
            leaves_with_conclusions = [n for n in self._all_nodes if n.is_leaf()]

        if not leaves_with_conclusions:
            return [self.root.thought] if self.root else []

        best_leaf = max(leaves_with_conclusions, key=lambda n: n.path_confidence)
        return best_leaf.get_path_thoughts()

    def get_all_paths(self) -> List[List[Thought]]:
        """Get all paths from root to leaves as lists of thoughts."""
        paths = []

        def collect_paths(node: TreeNode, current_path: List[Thought]):
            # Copy-extend so sibling branches do not share list state.
            current_path = current_path + [node.thought]

            if node.is_leaf():
                paths.append(current_path)
            else:
                for child in node.children:
                    collect_paths(child, current_path)

        if self.root:
            collect_paths(self.root, [])

        return paths

    def prune_low_confidence(self, threshold: Optional[float] = None) -> int:
        """
        Prune branches below confidence threshold.

        Args:
            threshold: Confidence cutoff; defaults to self.pruning_threshold.
                An explicit 0.0 is honored (previously it was treated as
                "unset" because of a truthiness check).

        Returns:
            Total number of nodes removed, counting every descendant of
            each pruned branch root (previously only the roots of pruned
            subtrees were counted).
        """
        if threshold is None:
            threshold = self.pruning_threshold
        pruned_count = 0

        def subtree_size(node: TreeNode) -> int:
            """Number of nodes in the subtree rooted at node."""
            return 1 + sum(subtree_size(child) for child in node.children)

        def prune_recursive(node: TreeNode) -> bool:
            """Trim low-confidence children; True if node itself should go."""
            nonlocal pruned_count
            if node.path_confidence < threshold:
                return True

            survivors = []
            for child in node.children:
                if prune_recursive(child):
                    # Count the whole discarded subtree, not just its root.
                    pruned_count += subtree_size(child)
                else:
                    survivors.append(child)

            node.children = survivors
            return False

        if self.root:
            # NOTE: the root is never removed, even if below threshold.
            prune_recursive(self.root)

        # Rebuild the flat node indexes from the surviving tree.
        self._all_nodes = []
        self._leaf_nodes = []

        def collect_nodes(node: TreeNode):
            self._all_nodes.append(node)
            if node.is_leaf():
                self._leaf_nodes.append(node)
            for child in node.children:
                collect_nodes(child)

        if self.root:
            collect_nodes(self.root)

        logger.info(f"Pruned {pruned_count} nodes below confidence {threshold}")
        return pruned_count

    def to_dict(self) -> Dict[str, Any]:
        """Convert tree to a nested dictionary representation ({} if empty)."""
        def node_to_dict(node: TreeNode) -> Dict[str, Any]:
            return {
                "thought": node.thought.to_dict(),
                "depth": node.depth,
                "path_confidence": node.path_confidence,
                "children": [node_to_dict(c) for c in node.children]
            }

        return node_to_dict(self.root) if self.root else {}

    def get_statistics(self) -> Dict[str, Any]:
        """Get summary statistics over all nodes currently in the tree."""
        depths = [n.depth for n in self._all_nodes]
        confidences = [n.path_confidence for n in self._all_nodes]

        return {
            "total_nodes": len(self._all_nodes),
            "leaf_nodes": len(self._leaf_nodes),
            "max_depth": max(depths) if depths else 0,
            "avg_confidence": sum(confidences) / len(confidences) if confidences else 0,
            "thought_type_distribution": self._get_type_distribution()
        }

    def _get_type_distribution(self) -> Dict[str, int]:
        """Get distribution of thought types in tree."""
        distribution = defaultdict(int)
        for node in self._all_nodes:
            distribution[node.thought.thought_type.name] += 1
        return dict(distribution)


# =============================================================================
# SELF-CONSISTENCY SAMPLER
# =============================================================================

class SelfConsistency:
    """
    Robust reasoning via repeated, independent chain sampling.

    Runs the underlying ChainOfThought generator several times
    (optionally in parallel), then exposes the resulting chains plus
    answer-frequency and confidence-weighted views over their
    conclusions.

    Attributes:
        chain_generator: ChainOfThought instance for generation
        num_samples: Number of chains to sample
        temperature_range: Range for sampling temperature variation
        aggregation_strategy: How to combine results
    """

    def __init__(
        self,
        chain_generator: Optional[ChainOfThought] = None,
        num_samples: int = 5,
        temperature_range: Tuple[float, float] = (0.7, 1.0),
        aggregation_strategy: AggregationStrategy = AggregationStrategy.WEIGHTED_CONFIDENCE,
        parallel: bool = True,
        llm_callable: Optional[Callable[[str], str]] = None
    ):
        """
        Initialize self-consistency sampler.

        Args:
            chain_generator: ChainOfThought instance to use
            num_samples: Number of chains to sample
            temperature_range: Temperature variation for diversity
            aggregation_strategy: Strategy for combining results
            parallel: Whether to sample in parallel
            llm_callable: Optional LLM function
        """
        self.num_samples = num_samples
        self.temperature_range = temperature_range
        self.aggregation_strategy = aggregation_strategy
        self.parallel = parallel
        self.llm_callable = llm_callable
        self.chain_generator = chain_generator or ChainOfThought(llm_callable=llm_callable)

        self._sampled_chains: List[ReasoningChain] = []

        logger.info(f"SelfConsistency initialized with {num_samples} samples")

    def sample(
        self,
        problem: str,
        context: Optional[str] = None
    ) -> List[ReasoningChain]:
        """
        Sample multiple reasoning chains.

        Args:
            problem: The problem to reason about
            context: Optional additional context

        Returns:
            List of sampled reasoning chains
        """
        logger.info(f"Sampling {self.num_samples} chains for: {problem[:100]}...")

        self._sampled_chains = []

        if self.parallel:
            self._sample_parallel(problem, context)
        else:
            self._sample_sequential(problem, context)

        logger.info(f"Sampled {len(self._sampled_chains)} chains")
        return self._sampled_chains

    def _sample_parallel(self, problem: str, context: Optional[str]) -> None:
        """Fan chain generation out over a small thread pool."""
        # Cap workers at 4 so a large num_samples does not flood the pool.
        with ThreadPoolExecutor(max_workers=min(self.num_samples, 4)) as executor:
            pending = [
                executor.submit(self._generate_chain, problem, context, idx)
                for idx in range(self.num_samples)
            ]
            for done in as_completed(pending):
                try:
                    self._sampled_chains.append(done.result())
                except Exception as e:
                    logger.error(f"Chain generation failed: {e}")

    def _sample_sequential(self, problem: str, context: Optional[str]) -> None:
        """Generate chains one at a time, tolerating individual failures."""
        for i in range(self.num_samples):
            try:
                self._sampled_chains.append(self._generate_chain(problem, context, i))
            except Exception as e:
                logger.error(f"Chain {i} generation failed: {e}")

    def _generate_chain(
        self,
        problem: str,
        context: Optional[str],
        sample_idx: int
    ) -> ReasoningChain:
        """Generate a single chain with temperature variation."""
        # Tag every sample after the first so contexts differ slightly,
        # encouraging diverse reasoning paths.
        varied_context = context or ""
        if sample_idx > 0:
            varied_context = f"{varied_context}\n[Approach {sample_idx + 1}]"

        chain = self.chain_generator.reason(problem, varied_context)
        chain.metadata["sample_index"] = sample_idx
        chain.metadata["temperature"] = random.uniform(*self.temperature_range)

        return chain

    def get_answer_distribution(self) -> Dict[str, float]:
        """
        Get distribution of final answers across chains.

        Returns:
            Dictionary mapping normalized answers to their frequency.
        """
        if not self._sampled_chains:
            return {}

        counts: Dict[str, int] = {}
        for chain in self._sampled_chains:
            if chain.final_answer:
                # Normalize answer for comparison
                key = chain.final_answer.strip().lower()
                counts[key] = counts.get(key, 0) + 1

        total = sum(counts.values())
        return {answer: n / total for answer, n in counts.items()}

    def get_confidence_weighted_answers(self) -> Dict[str, float]:
        """
        Get confidence-weighted distribution of answers.

        Returns:
            Dictionary mapping normalized answers to weighted scores.
        """
        if not self._sampled_chains:
            return {}

        scores: Dict[str, float] = {}
        for chain in self._sampled_chains:
            if chain.final_answer:
                key = chain.final_answer.strip().lower()
                scores[key] = scores.get(key, 0.0) + chain.total_confidence

        # Normalize only when there is weight to normalize by.
        total = sum(scores.values())
        if total > 0:
            return {answer: s / total for answer, s in scores.items()}
        return scores

    def get_consensus(self, threshold: float = 0.6) -> Optional[str]:
        """
        Get consensus answer if one exists above threshold.

        Args:
            threshold: Minimum agreement for consensus

        Returns:
            Consensus answer or None
        """
        distribution = self.get_answer_distribution()
        return next(
            (answer for answer, freq in distribution.items() if freq >= threshold),
            None,
        )

    def get_sampled_chains(self) -> List[ReasoningChain]:
        """Return a shallow copy of all sampled chains."""
        return list(self._sampled_chains)


# =============================================================================
# FINAL AGGREGATOR
# =============================================================================

class FinalAggregator:
    """
    Aggregates multiple reasoning paths into a final answer.

    Implements various aggregation strategies to combine
    insights from multiple reasoning approaches.

    Attributes:
        strategy: Aggregation strategy to use
        min_confidence: Minimum confidence to include in aggregation
        llm_callable: Optional LLM for sophisticated aggregation
    """

    def __init__(
        self,
        strategy: AggregationStrategy = AggregationStrategy.WEIGHTED_CONFIDENCE,
        min_confidence: float = 0.3,
        llm_callable: Optional[Callable[[str], str]] = None
    ):
        """
        Initialize the aggregator.

        Args:
            strategy: Aggregation strategy to use
            min_confidence: Minimum chain confidence to participate
            llm_callable: Optional LLM for sophisticated aggregation
        """
        self.strategy = strategy
        self.min_confidence = min_confidence
        self.llm_callable = llm_callable

        # One record per aggregate() call, for auditability.
        self._aggregation_history: List[Dict[str, Any]] = []

        logger.info(f"FinalAggregator initialized with {strategy.name} strategy")

    def aggregate(
        self,
        chains: List[ReasoningChain],
        problem: Optional[str] = None
    ) -> Tuple[str, float, Dict[str, Any]]:
        """
        Aggregate multiple chains into a final answer.

        Chains below min_confidence are filtered out first; if none
        survive, all chains are used as a best-effort fallback.

        Args:
            chains: List of reasoning chains to aggregate
            problem: Original problem (for context)

        Returns:
            Tuple of (final_answer, confidence, metadata)
        """
        if not chains:
            return "No reasoning chains to aggregate", 0.0, {}

        # Filter chains by confidence
        valid_chains = [c for c in chains if c.total_confidence >= self.min_confidence]

        if not valid_chains:
            logger.warning("No chains above confidence threshold, using all chains")
            valid_chains = chains

        logger.info(f"Aggregating {len(valid_chains)} chains using {self.strategy.name}")

        if self.strategy == AggregationStrategy.MAJORITY_VOTE:
            result = self._majority_vote(valid_chains)
        elif self.strategy == AggregationStrategy.WEIGHTED_CONFIDENCE:
            result = self._weighted_confidence(valid_chains)
        elif self.strategy == AggregationStrategy.BEST_PATH:
            result = self._best_path(valid_chains)
        elif self.strategy == AggregationStrategy.ENSEMBLE:
            result = self._ensemble(valid_chains, problem)
        elif self.strategy == AggregationStrategy.CONSENSUS:
            result = self._consensus(valid_chains)
        else:
            # Unknown strategy: fall back to the default weighting scheme.
            result = self._weighted_confidence(valid_chains)

        # Record aggregation
        self._aggregation_history.append({
            "num_chains": len(chains),
            "valid_chains": len(valid_chains),
            "strategy": self.strategy.name,
            "result": result,
            "timestamp": time.time()
        })

        return result

    def _majority_vote(self, chains: List[ReasoningChain]) -> Tuple[str, float, Dict[str, Any]]:
        """Pick the most frequent normalized answer; confidence = vote share."""
        answer_counts = defaultdict(int)

        for chain in chains:
            if chain.final_answer:
                normalized = chain.final_answer.strip().lower()
                answer_counts[normalized] += 1

        if not answer_counts:
            return "No answers available", 0.0, {}

        best_answer = max(answer_counts.items(), key=lambda x: x[1])
        confidence = best_answer[1] / len(chains)

        metadata = {
            "method": "majority_vote",
            "vote_distribution": dict(answer_counts),
            "winning_votes": best_answer[1]
        }

        return best_answer[0], confidence, metadata

    def _weighted_confidence(self, chains: List[ReasoningChain]) -> Tuple[str, float, Dict[str, Any]]:
        """
        Confidence-weighted aggregation.

        Each chain votes for its normalized answer with weight equal to its
        total confidence; the answer with the largest summed weight wins.
        """
        answer_scores = defaultdict(float)
        answer_chains = defaultdict(list)

        for chain in chains:
            if chain.final_answer:
                normalized = chain.final_answer.strip().lower()
                answer_scores[normalized] += chain.total_confidence
                answer_chains[normalized].append(chain.id)

        if not answer_scores:
            return "No answers available", 0.0, {}

        total_weight = sum(answer_scores.values())
        best_answer = max(answer_scores.items(), key=lambda x: x[1])

        # Guard against all-zero confidences (possible when the caller fell
        # back to unfiltered chains); previously the score-distribution
        # comprehension below raised ZeroDivisionError in that case.
        if total_weight > 0:
            confidence = best_answer[1] / total_weight
            score_distribution = {k: v / total_weight for k, v in answer_scores.items()}
        else:
            confidence = 0.0
            score_distribution = {k: 0.0 for k in answer_scores}

        metadata = {
            "method": "weighted_confidence",
            "score_distribution": score_distribution,
            "supporting_chains": answer_chains[best_answer[0]]
        }

        return best_answer[0], confidence, metadata

    def _best_path(self, chains: List[ReasoningChain]) -> Tuple[str, float, Dict[str, Any]]:
        """Select the single highest-confidence reasoning path verbatim."""
        best_chain = max(chains, key=lambda c: c.total_confidence)

        metadata = {
            "method": "best_path",
            "selected_chain": best_chain.id,
            "num_steps": len(best_chain.thoughts),
            "verification_summary": best_chain.get_verification_summary()
        }

        return best_chain.final_answer or "No conclusion", best_chain.total_confidence, metadata

    def _ensemble(
        self,
        chains: List[ReasoningChain],
        problem: Optional[str]
    ) -> Tuple[str, float, Dict[str, Any]]:
        """Ensemble aggregation; uses the LLM when available, else weights."""
        if self.llm_callable and problem:
            return self._llm_ensemble(chains, problem)

        # Fallback to weighted confidence
        return self._weighted_confidence(chains)

    def _llm_ensemble(
        self,
        chains: List[ReasoningChain],
        problem: str
    ) -> Tuple[str, float, Dict[str, Any]]:
        """Use LLM to synthesize answers; falls back to weighting on error."""
        chain_summaries = []
        for i, chain in enumerate(chains):
            summary = f"Path {i+1} (confidence {chain.total_confidence:.2%}): {chain.final_answer}"
            chain_summaries.append(summary)

        prompt = f"""Given the following problem and multiple reasoning paths with their conclusions,
synthesize the best final answer.

Problem: {problem}

Reasoning paths:
{chr(10).join(chain_summaries)}

Consider the confidence levels and any patterns across the paths.
Provide a synthesized final answer that best addresses the problem.

Synthesized answer:"""

        try:
            response = self.llm_callable(prompt)

            # Estimate confidence from agreement
            avg_confidence = sum(c.total_confidence for c in chains) / len(chains)

            metadata = {
                "method": "llm_ensemble",
                "num_paths_considered": len(chains),
                "individual_conclusions": [c.final_answer for c in chains]
            }

            return response.strip(), avg_confidence, metadata

        except Exception as e:
            logger.error(f"LLM ensemble failed: {e}")
            return self._weighted_confidence(chains)

    def _consensus(self, chains: List[ReasoningChain]) -> Tuple[str, float, Dict[str, Any]]:
        """Pick the answer with widest agreement; weight by avg confidence."""
        answer_counts = defaultdict(list)

        for chain in chains:
            if chain.final_answer:
                normalized = chain.final_answer.strip().lower()
                answer_counts[normalized].append(chain)

        if not answer_counts:
            return "No consensus possible", 0.0, {}

        # Find answer with most agreement
        best_answer = max(answer_counts.items(), key=lambda x: len(x[1]))
        agreement_ratio = len(best_answer[1]) / len(chains)

        # Average confidence of agreeing chains
        avg_confidence = sum(c.total_confidence for c in best_answer[1]) / len(best_answer[1])

        # Combine agreement and confidence
        final_confidence = agreement_ratio * avg_confidence

        metadata = {
            "method": "consensus",
            "agreement_ratio": agreement_ratio,
            "agreeing_chains": [c.id for c in best_answer[1]],
            "average_chain_confidence": avg_confidence
        }

        return best_answer[0], final_confidence, metadata

    def get_aggregation_history(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the history of all aggregations."""
        return self._aggregation_history.copy()


# =============================================================================
# COMPLETE REASONING SYSTEM
# =============================================================================

class AIVAQueenReasoning:
    """
    Complete chain-of-thought reasoning system for AIVA Queen.

    Integrates all components:
    - ChainOfThought for basic reasoning
    - ThoughtDecomposer for problem breakdown
    - IntermediateVerifier for step validation
    - ReasoningTree for branching exploration
    - SelfConsistency for robust sampling
    - FinalAggregator for answer synthesis

    This provides a unified interface for sophisticated reasoning.
    """

    def __init__(
        self,
        llm_callable: Optional[Callable[[str], str]] = None,
        enable_verification: bool = True,
        enable_decomposition: bool = True,
        enable_tree_reasoning: bool = True,
        enable_self_consistency: bool = True,
        num_consistency_samples: int = 3,
        aggregation_strategy: AggregationStrategy = AggregationStrategy.WEIGHTED_CONFIDENCE
    ):
        """
        Initialize the complete reasoning system.

        Args:
            llm_callable: Function to call LLM (prompt -> response)
            enable_verification: Enable intermediate verification
            enable_decomposition: Enable problem decomposition
            enable_tree_reasoning: Enable tree-structured reasoning
            enable_self_consistency: Enable self-consistency sampling
            num_consistency_samples: Number of samples for self-consistency
            aggregation_strategy: Strategy for aggregating results
        """
        self.llm_callable = llm_callable

        # Optional components are None when disabled; every consumer
        # below checks for None before use.
        self.verifier = IntermediateVerifier(
            llm_callable=llm_callable
        ) if enable_verification else None

        self.decomposer = ThoughtDecomposer(
            llm_callable=llm_callable
        ) if enable_decomposition else None

        # Always available: the fallback engine for every other mode.
        self.chain_generator = ChainOfThought(
            verifier=self.verifier,
            llm_callable=llm_callable
        )

        self.tree_reasoner = ReasoningTree(
            verifier=self.verifier,
            llm_callable=llm_callable
        ) if enable_tree_reasoning else None

        self.self_consistency = SelfConsistency(
            chain_generator=self.chain_generator,
            num_samples=num_consistency_samples,
            llm_callable=llm_callable
        ) if enable_self_consistency else None

        self.aggregator = FinalAggregator(
            strategy=aggregation_strategy,
            llm_callable=llm_callable
        )

        # One entry appended per completed reason() call.
        self._reasoning_history: List[Dict[str, Any]] = []

        logger.info("AIVAQueenReasoning system initialized")

    def reason(
        self,
        problem: str,
        context: Optional[str] = None,
        mode: str = "comprehensive"
    ) -> Dict[str, Any]:
        """
        Perform reasoning on a problem.

        Args:
            problem: The problem/query to reason about
            context: Optional additional context
            mode: Reasoning mode - "simple", "tree", "consistent", "comprehensive"

        Returns:
            Dictionary containing reasoning results and metadata

        Raises:
            ValueError: If mode is not one of the supported modes.
        """
        logger.info(f"Starting reasoning in {mode} mode for: {problem[:100]}...")
        start_time = time.time()

        result = {
            "problem": problem,
            "context": context,
            "mode": mode,
            "timestamp": start_time
        }

        # Decompose if enabled and the problem looks complex.
        # NOTE(review): the >100-character heuristic is a rough proxy for
        # complexity — confirm it matches upstream expectations.
        sub_problems = [problem]
        if self.decomposer and len(problem) > 100:
            sub_problems = self.decomposer.decompose(problem)
            result["decomposition"] = sub_problems

        if mode == "simple":
            self._run_simple(problem, context, result)
        elif mode == "tree":
            if self.tree_reasoner is None:
                # Fallback: a fresh simple-mode call with its own result
                # dict and history entry (this call's result is discarded).
                return self.reason(problem, context, mode="simple")
            self._run_tree(problem, context, result)
        elif mode == "consistent":
            if self.self_consistency is None:
                return self.reason(problem, context, mode="simple")
            self._run_consistent(problem, context, result)
        elif mode == "comprehensive":
            self._run_comprehensive(problem, context, sub_problems, result)
        else:
            raise ValueError(f"Unknown reasoning mode: {mode}")

        result["duration_seconds"] = time.time() - start_time

        self._reasoning_history.append(result)

        logger.info(f"Reasoning complete in {result['duration_seconds']:.2f}s")
        return result

    def _run_simple(
        self,
        problem: str,
        context: Optional[str],
        result: Dict[str, Any]
    ) -> None:
        """Single-chain reasoning: one ChainOfThought pass, results in-place."""
        chain = self.chain_generator.reason(problem, context)
        result["chains"] = [chain.to_dict()]
        result["final_answer"] = chain.final_answer
        result["confidence"] = chain.total_confidence

    def _run_tree(
        self,
        problem: str,
        context: Optional[str],
        result: Dict[str, Any]
    ) -> None:
        """Tree-structured reasoning: build a tree, keep only the best path."""
        self.tree_reasoner.build_tree(problem, context)
        best_path = self.tree_reasoner.get_best_path()

        # Flatten the winning path into a single chain for uniform output.
        chain = ReasoningChain(problem=problem)
        for thought in best_path:
            chain.add_thought(thought)
        chain.final_answer = best_path[-1].content if best_path else None

        result["tree_stats"] = self.tree_reasoner.get_statistics()
        result["chains"] = [chain.to_dict()]
        result["final_answer"] = chain.final_answer
        result["confidence"] = chain.total_confidence

    def _run_consistent(
        self,
        problem: str,
        context: Optional[str],
        result: Dict[str, Any]
    ) -> None:
        """Self-consistency: sample multiple chains and aggregate them."""
        chains = self.self_consistency.sample(problem, context)
        answer, confidence, meta = self.aggregator.aggregate(chains, problem)

        result["chains"] = [c.to_dict() for c in chains]
        result["answer_distribution"] = self.self_consistency.get_answer_distribution()
        result["final_answer"] = answer
        result["confidence"] = confidence
        result["aggregation_metadata"] = meta

    def _run_comprehensive(
        self,
        problem: str,
        context: Optional[str],
        sub_problems: List[str],
        result: Dict[str, Any]
    ) -> None:
        """Full pipeline: chains over sub-problems plus tree paths, aggregated."""
        all_chains = []

        # Generate one simple chain per sub-problem (capped at 3).
        for i, sub_problem in enumerate(sub_problems[:3]):
            chain = self.chain_generator.reason(sub_problem, context)
            chain.metadata["sub_problem_index"] = i
            all_chains.append(chain)

        # Add up to 3 paths from tree exploration, if enabled.
        if self.tree_reasoner:
            self.tree_reasoner.build_tree(problem, context, num_samples=2)
            for path in self.tree_reasoner.get_all_paths()[:3]:
                chain = ReasoningChain(problem=problem)
                for thought in path:
                    chain.add_thought(thought)
                if path:
                    chain.final_answer = path[-1].content
                chain.metadata["source"] = "tree"
                all_chains.append(chain)

        # Aggregate everything into one final answer.
        answer, confidence, meta = self.aggregator.aggregate(all_chains, problem)

        result["chains"] = [c.to_dict() for c in all_chains]
        result["final_answer"] = answer
        result["confidence"] = confidence
        result["aggregation_metadata"] = meta

    def get_reasoning_history(self) -> List[Dict[str, Any]]:
        """Get history of all reasoning sessions (shallow copy)."""
        return self._reasoning_history.copy()

    def export_session(self, filepath: str) -> None:
        """Export reasoning history to a JSON file.

        Args:
            filepath: Destination path; overwritten if it exists.
        """
        # default=str stringifies any non-JSON-serializable values
        # (e.g. enums, timestamps) rather than raising.
        with open(filepath, 'w') as f:
            json.dump(self._reasoning_history, f, indent=2, default=str)
        logger.info(f"Exported reasoning history to {filepath}")

    def get_statistics(self) -> Dict[str, Any]:
        """Get overall system statistics from all active components."""
        return {
            "total_sessions": len(self._reasoning_history),
            "verifier_stats": self.verifier.get_stats() if self.verifier else {},
            "aggregator_history_count": len(self.aggregator.get_aggregation_history()),
            "tree_stats": self.tree_reasoner.get_statistics() if self.tree_reasoner else {}
        }


# =============================================================================
# MAIN EXECUTION
# =============================================================================

def main() -> None:
    """Run a smoke-test demo of every reasoning component.

    Uses a keyword-routing mock LLM so the demo is deterministic and
    requires no external model.
    """
    print("=" * 60)
    print("AIVA Queen Chain-of-Thought Reasoning System")
    print("=" * 60)

    # Example LLM callable (placeholder - replace with actual implementation)
    def mock_llm(prompt: str) -> str:
        """Mock LLM for testing: routes on keywords in the prompt."""
        if "next step" in prompt.lower():
            return "Based on the prior analysis, I can infer that the solution involves systematic decomposition."
        elif "verify" in prompt.lower():
            return "VALID"
        elif "break down" in prompt.lower() or "components" in prompt.lower():
            return "1. Identify the core variables\n2. Establish relationships\n3. Apply constraints"
        elif "synthesize" in prompt.lower():
            return "The synthesized answer combines insights from multiple reasoning paths."
        else:
            return f"Response to: {prompt[:50]}..."

    # Test 1: Basic Chain of Thought
    print("\n--- Test 1: Basic Chain of Thought ---")
    cot = ChainOfThought(llm_callable=mock_llm, max_steps=5)
    chain = cot.reason("What is the capital of France?")
    print(chain)

    # Test 2: Thought Decomposer
    print("\n--- Test 2: Thought Decomposer ---")
    decomposer = ThoughtDecomposer(llm_callable=mock_llm)
    problem = "If John has 5 apples and gives 2 to Mary, then Mary gives 1 to Bob, how many apples does each person have?"
    sub_problems = decomposer.decompose(problem)
    print(f"Original: {problem}")
    print(f"Sub-problems: {sub_problems}")

    # Test 3: Intermediate Verifier
    print("\n--- Test 3: Intermediate Verifier ---")
    verifier = IntermediateVerifier(llm_callable=mock_llm)
    test_thought = Thought(
        content="The answer is always 42",
        thought_type=ThoughtType.HYPOTHESIS,
        confidence=0.9
    )
    status, details = verifier.verify(test_thought, "Mathematical problem")
    print(f"Verification: {status.name} - {details}")

    # Test 4: Reasoning Tree
    print("\n--- Test 4: Reasoning Tree ---")
    tree = ReasoningTree(llm_callable=mock_llm, max_depth=3, max_branches=2)
    tree.build_tree("Solve: 2x + 5 = 15")
    stats = tree.get_statistics()
    print(f"Tree stats: {stats}")
    best_path = tree.get_best_path()
    print(f"Best path length: {len(best_path)}")

    # Test 5: Self-Consistency
    print("\n--- Test 5: Self-Consistency ---")
    consistency = SelfConsistency(num_samples=3, llm_callable=mock_llm, parallel=False)
    chains = consistency.sample("What is 2 + 2?")
    print(f"Sampled {len(chains)} chains")
    distribution = consistency.get_answer_distribution()
    print(f"Answer distribution: {distribution}")

    # Test 6: Final Aggregator (reuses the chains sampled in Test 5)
    print("\n--- Test 6: Final Aggregator ---")
    aggregator = FinalAggregator(
        strategy=AggregationStrategy.WEIGHTED_CONFIDENCE,
        llm_callable=mock_llm
    )
    answer, confidence, meta = aggregator.aggregate(chains, "What is 2 + 2?")
    print(f"Aggregated answer: {answer} (confidence: {confidence:.2%})")
    print(f"Aggregation metadata: {meta}")

    # Test 7: Complete System
    print("\n--- Test 7: Complete AIVAQueen Reasoning System ---")
    queen_reasoning = AIVAQueenReasoning(
        llm_callable=mock_llm,
        enable_verification=True,
        enable_decomposition=True,
        enable_tree_reasoning=True,
        enable_self_consistency=True,
        num_consistency_samples=3
    )

    result = queen_reasoning.reason(
        problem="If a train travels at 60 mph for 2 hours, how far does it travel?",
        context="This is a physics problem involving distance, speed, and time.",
        mode="comprehensive"
    )

    print(f"Final answer: {result['final_answer']}")
    print(f"Confidence: {result['confidence']:.2%}")
    print(f"Duration: {result['duration_seconds']:.2f}s")
    print(f"Number of chains: {len(result.get('chains', []))}")

    # Get overall statistics
    print("\n--- System Statistics ---")
    stats = queen_reasoning.get_statistics()
    print(json.dumps(stats, indent=2, default=str))

    print("\n" + "=" * 60)
    print("All tests completed successfully!")
    print("=" * 60)


if __name__ == "__main__":
    main()