#!/usr/bin/env python3
"""
SURPRISE DETECTOR - Titans Architecture Implementation
=======================================================
Implements Google's Titans surprise-based learning mechanism.

When the model encounters unexpected outcomes (high surprise), it triggers
larger learning gradients, causing memory updates.

Reference: arXiv:2501.00663 "Titans: Learning to Memorize at Test Time"
"""

import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any

logger = logging.getLogger(__name__)

GENESIS_ROOT = Path(__file__).parent.parent.parent


@dataclass
class Expectation:
    """Baseline values a task run is predicted to produce.

    Every field has a default, so ``Expectation()`` describes the default
    prediction for a run.

    Attributes:
        execution_time: Predicted wall-clock duration in seconds (default 30).
        test_pass_rate: Predicted fraction of passing tests (default 1.0,
            i.e. 100% pass).
        error_count: Predicted number of errors (default none).
        output_tokens: Predicted size of the output in tokens (default 500).
        complexity_score: Predicted complexity (default 0.5).
    """

    execution_time: float = 30.0
    test_pass_rate: float = 1.0
    error_count: int = 0
    output_tokens: int = 500
    complexity_score: float = 0.5


@dataclass
class Outcome:
    """Measured result of one execution.

    The five metric fields are required and mirror the fields of
    ``Expectation``; the remaining two are optional extras.

    Attributes:
        execution_time: Measured duration.
        test_pass_rate: Measured test pass rate.
        error_count: Number of errors observed.
        output_tokens: Measured output size.
        complexity_score: Measured complexity.
        raw_output: Raw output text; empty string when not supplied.
        errors: Error messages; each instance gets its own fresh list when
            none is supplied.
    """

    execution_time: float
    test_pass_rate: float
    error_count: int
    output_tokens: int
    complexity_score: float
    raw_output: str = ""
    errors: List[str] = field(default_factory=list)


@dataclass
class SurpriseEvent:
    """Record of one surprise that was large enough to trigger learning.

    Attributes:
        timestamp: When the event was detected, as a string.
        surprise_score: Surprise score computed for the event.
        expectation: Expected metric values, keyed by metric name.
        outcome: Observed metric values, keyed by metric name.
        trigger_reason: Textual reason(s) the outcome counted as surprising.
        learning_action: Name of the learning action chosen for the event.
        cycle_id: Cycle the event belongs to; 0 when not supplied.
    """

    timestamp: str
    surprise_score: float
    expectation: Dict[str, Any]
    outcome: Dict[str, Any]
    trigger_reason: str
    learning_action: str
    cycle_id: int = 0


class SurpriseDetector:
    """
    Detects surprise in execution outcomes and triggers learning.

    Based on Titans paper:
    - High surprise = unexpected outcome
    - Surprise triggers memory updates
    - Learning gradient proportional to surprise

    Every detected event is appended, one JSON object per line, to
    ``events_file`` and can be read back with ``get_recent_events``.
    """

    # Surprise thresholds
    LOW_SURPRISE = 0.2
    MEDIUM_SURPRISE = 0.5
    HIGH_SURPRISE = 0.8

    # Metric weights (sum to 1.0)
    WEIGHTS = {
        'execution_time': 0.2,
        'test_pass_rate': 0.3,
        'error_count': 0.2,
        'output_tokens': 0.15,
        'complexity_score': 0.15,
    }

    # Names of the metrics copied into a logged event, in serialization order.
    _METRIC_FIELDS = tuple(WEIGHTS)

    # Number of unexpected errors at which the error metric saturates when
    # zero errors were expected.
    ERROR_SATURATION_COUNT = 5

    def __init__(self, events_file: Optional[Path] = None):
        """Initialize surprise detector.

        Creates the parent directory of the events file if it is missing.

        Args:
            events_file: JSONL file that receives detected events. A plain
                string path is accepted as well. Defaults to
                ``GENESIS_ROOT / "data" / "surprise_events.jsonl"``.
        """
        if not events_file:
            events_file = GENESIS_ROOT / "data" / "surprise_events.jsonl"
        elif not isinstance(events_file, Path):
            # Accept str / os.PathLike so `.parent` below cannot fail.
            events_file = Path(events_file)
        self.events_file = events_file
        self.events_file.parent.mkdir(parents=True, exist_ok=True)
        self._event_count = 0

    @staticmethod
    def _relative_deviation(actual: float, expected: float) -> float:
        """Return ``|actual - expected| / expected`` capped at 1.0.

        ``expected`` must be greater than zero.
        """
        return min(abs(actual - expected) / expected, 1.0)

    @classmethod
    def _metric_snapshot(cls, record: Any) -> Dict[str, Any]:
        """Copy the weighted metric attributes of ``record`` into a dict."""
        return {name: getattr(record, name) for name in cls._METRIC_FIELDS}

    def calculate_surprise(
        self,
        expectation: Expectation,
        outcome: Outcome
    ) -> float:
        """
        Calculate surprise score between expected and actual outcomes.

        Each metric yields a deviation capped at 1.0, which is multiplied by
        its entry in ``WEIGHTS`` and summed. Capping every metric keeps a
        single out-of-scale value from saturating the whole score.

        Args:
            expectation: Expected outcome
            outcome: Actual outcome

        Returns:
            Surprise score from 0.0 (expected) to 1.0 (max surprise)
        """
        surprise = 0.0

        # Execution time surprise (skipped when no positive time is expected,
        # because the deviation is measured relative to the expectation)
        if expectation.execution_time > 0:
            surprise += self.WEIGHTS['execution_time'] * self._relative_deviation(
                outcome.execution_time, expectation.execution_time
            )

        # Test pass rate surprise (absolute deviation in either direction)
        pass_deviation = abs(outcome.test_pass_rate - expectation.test_pass_rate)
        surprise += self.WEIGHTS['test_pass_rate'] * min(pass_deviation, 1.0)

        # Error count surprise
        if expectation.error_count == 0 and outcome.error_count > 0:
            # No baseline to divide by: scale against a fixed error count.
            error_surprise = min(
                outcome.error_count / self.ERROR_SATURATION_COUNT, 1.0
            )
        elif expectation.error_count > 0:
            error_surprise = self._relative_deviation(
                outcome.error_count, expectation.error_count
            )
        else:
            error_surprise = 0.0
        surprise += self.WEIGHTS['error_count'] * error_surprise

        # Output tokens surprise (skipped when no positive size is expected)
        if expectation.output_tokens > 0:
            surprise += self.WEIGHTS['output_tokens'] * self._relative_deviation(
                outcome.output_tokens, expectation.output_tokens
            )

        # Complexity surprise (absolute deviation)
        complexity_deviation = abs(outcome.complexity_score - expectation.complexity_score)
        surprise += self.WEIGHTS['complexity_score'] * min(complexity_deviation, 1.0)

        return min(surprise, 1.0)

    def detect_surprise(
        self,
        expectation: Expectation,
        outcome: Outcome,
        cycle_id: int = 0,
        threshold: Optional[float] = None
    ) -> Optional[SurpriseEvent]:
        """
        Detect if outcome is surprising enough to trigger learning.

        A detected event is appended to the events file and counted before it
        is returned.

        Args:
            expectation: Expected outcome
            outcome: Actual outcome
            cycle_id: Evolution cycle ID
            threshold: Override default threshold (``MEDIUM_SURPRISE``). An
                explicit ``0.0`` is honoured and makes every outcome an event.

        Returns:
            SurpriseEvent if the score is at or above the threshold, else None
        """
        # `is None` rather than `or`: a caller-supplied 0.0 is a valid threshold.
        if threshold is None:
            threshold = self.MEDIUM_SURPRISE
        score = self.calculate_surprise(expectation, outcome)

        if score < threshold:
            logger.debug("No surprise: score=%.3f < threshold=%s", score, threshold)
            return None

        # Determine trigger reason
        trigger_reason = self._determine_trigger_reason(expectation, outcome, score)

        # Determine learning action
        learning_action = self._determine_learning_action(score)

        event = SurpriseEvent(
            timestamp=datetime.now().isoformat(),
            surprise_score=score,
            expectation=self._metric_snapshot(expectation),
            outcome=self._metric_snapshot(outcome),
            trigger_reason=trigger_reason,
            learning_action=learning_action,
            cycle_id=cycle_id,
        )

        self._log_event(event)
        self._event_count += 1

        logger.warning(
            "SURPRISE_EVENT detected: score=%.3f, reason=%s", score, trigger_reason
        )

        return event

    def _determine_trigger_reason(
        self,
        expectation: Expectation,
        outcome: Outcome,
        score: float
    ) -> str:
        """Determine the primary reason for surprise.

        Args:
            expectation: Expected outcome
            outcome: Actual outcome
            score: Surprise score (accepted for interface stability; unused)

        Returns:
            Comma-separated reason tags, or ``"general_deviation"`` when no
            individual metric crosses its tag threshold.
        """
        reasons = []

        if outcome.test_pass_rate < expectation.test_pass_rate * 0.8:
            reasons.append("test_failures")

        if outcome.error_count > expectation.error_count + 2:
            reasons.append("unexpected_errors")

        if outcome.execution_time > expectation.execution_time * 2:
            reasons.append("slow_execution")
        elif outcome.execution_time < expectation.execution_time * 0.3:
            reasons.append("fast_execution")

        if outcome.output_tokens > expectation.output_tokens * 3:
            reasons.append("verbose_output")
        elif outcome.output_tokens < expectation.output_tokens * 0.2:
            reasons.append("sparse_output")

        if not reasons:
            reasons.append("general_deviation")

        return ", ".join(reasons)

    def _determine_learning_action(self, score: float) -> str:
        """Determine what learning action to take based on surprise score.

        Returns:
            ``"major_axiom_update"`` at or above ``HIGH_SURPRISE``,
            ``"minor_axiom_update"`` at or above ``MEDIUM_SURPRISE``,
            otherwise ``"observation_logged"``.
        """
        if score >= self.HIGH_SURPRISE:
            return "major_axiom_update"
        if score >= self.MEDIUM_SURPRISE:
            return "minor_axiom_update"
        return "observation_logged"

    def _log_event(self, event: SurpriseEvent) -> None:
        """Append the surprise event to the JSONL events file."""
        record = {
            'timestamp': event.timestamp,
            'surprise_score': event.surprise_score,
            'expectation': event.expectation,
            'outcome': event.outcome,
            'trigger_reason': event.trigger_reason,
            'learning_action': event.learning_action,
            'cycle_id': event.cycle_id,
        }
        # Explicit encoding keeps the file readable regardless of locale.
        with open(self.events_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(record) + '\n')

    def get_recent_events(self, limit: int = 10) -> List[SurpriseEvent]:
        """Get recent surprise events.

        Args:
            limit: Maximum number of events to return, taken from the end of
                the events file. Zero or a negative value yields no events.

        Returns:
            Events in file order (oldest first). Blank lines are ignored and
            malformed records are skipped with a warning. Returns an empty
            list when the events file does not exist.
        """
        # Guard first: `lines[-0:]` would otherwise return the whole file.
        if limit <= 0:
            return []

        try:
            text = self.events_file.read_text(encoding='utf-8')
        except FileNotFoundError:
            return []

        lines = [line for line in text.split('\n') if line.strip()]

        events = []
        for line in lines[-limit:]:
            try:
                data = json.loads(line)
                event = SurpriseEvent(
                    timestamp=data['timestamp'],
                    surprise_score=data['surprise_score'],
                    expectation=data['expectation'],
                    outcome=data['outcome'],
                    trigger_reason=data['trigger_reason'],
                    learning_action=data['learning_action'],
                    cycle_id=data.get('cycle_id', 0),
                )
            except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as exc:
                # A partially written or hand-edited line must not make the
                # whole history unreadable.
                logger.warning("Skipping malformed surprise event record: %s", exc)
                continue
            events.append(event)

        return events

    def get_statistics(self) -> Dict[str, Any]:
        """Get surprise detection statistics.

        Statistics cover at most the 1000 most recent events in the events
        file, so ``total_events`` is capped at 1000.

        Returns:
            ``{'total_events': 0}`` when there are no events; otherwise the
            event count, average / max / min surprise score and the number of
            events per learning action.
        """
        events = self.get_recent_events(limit=1000)

        if not events:
            return {'total_events': 0}

        scores = [e.surprise_score for e in events]
        actions = [e.learning_action for e in events]

        return {
            'total_events': len(events),
            'avg_surprise': sum(scores) / len(scores),
            'max_surprise': max(scores),
            'min_surprise': min(scores),
            'major_updates': actions.count('major_axiom_update'),
            'minor_updates': actions.count('minor_axiom_update'),
            'observations': actions.count('observation_logged'),
        }


def _run_demo():
    """Smoke-test the detector with one mild and one extreme outcome.

    Uses the default events file, so a detected event is appended to it.
    """
    logging.basicConfig(level=logging.INFO)

    detector = SurpriseDetector()
    baseline = Expectation()

    # Outcome that stays close to the default expectation.
    mild_outcome = Outcome(
        execution_time=35.0,
        test_pass_rate=0.95,
        error_count=1,
        output_tokens=550,
        complexity_score=0.55
    )
    mild_score = detector.calculate_surprise(baseline, mild_outcome)
    print(f"Normal case surprise: {mild_score:.3f}")

    # Outcome that deviates strongly on every metric.
    extreme_outcome = Outcome(
        execution_time=120.0,  # 4x expected
        test_pass_rate=0.5,    # Half expected
        error_count=10,        # Many errors
        output_tokens=50,      # Very sparse
        complexity_score=0.9   # High complexity
    )
    event = detector.detect_surprise(
        baseline, extreme_outcome, cycle_id=1, threshold=0.3
    )
    if event is not None:
        print(f"Surprise detected! Score: {event.surprise_score:.3f}")
        print(f"Reason: {event.trigger_reason}")
        print(f"Action: {event.learning_action}")

    print(f"\nStatistics: {detector.get_statistics()}")


if __name__ == '__main__':
    # Quick test
    _run_demo()
