"""
AIVA Surprise Memory System
============================
Production implementation of surprise-based memory routing.

Replaces all prior stubs with the full SurpriseDetector from:
    AIVA/queen_outputs/beta/beta_06_surprise_detector.py

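The composite_score drives memory tier routing:
    > 0.7   → episodic (long-term storage, axiom candidate)
    0.5–0.7 → notable (working memory with elevated priority); note that
              MemorySystem.evaluate() reports this band as tier "working"
              and flags it via score["should_promote_memory"] == True
    < 0.5   → working (short-term, standard eviction policy)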

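Public API (backward-compatible with all existing callers):
    MemorySystem.evaluate(content, source, domain) -> Dict
    MemorySystem.observe(event_type, actual_outcome, context) -> SurpriseScore
    MemorySystem.reflect(actual) -> List[Dict]
    MemorySystem.get_stats() -> Dict
    MemorySystem.make_prediction(domain, expected_outcome, ...) -> str
    MemorySystem.resolve_prediction(prediction_id, actual_outcome)
        -> Tuple[float, SurpriseScore]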

Queen AIVA Integration:
    This module is the primary surprise signal for:
    - Titan Memory axiom extraction
    - RLM Gateway interaction scoring
    - Memory promotion decisions
    - Learning rate modulation

Author: Genesis System (coronation integration 2026-02-20)
Version: 2.0.0  (real engine — no stubs)
"""

import hashlib
import json
import math
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple


# ---------------------------------------------------------------------------
# Enumerations
# ---------------------------------------------------------------------------

class SurpriseLevel(Enum):
    """
    Surprise intensity classification.

    Maps to the beta_06 taxonomy with backward-compatible aliases added
    so callers that check for LOW/MEDIUM/HIGH/CRITICAL still work.

    The score bands noted on each member match the thresholds applied in
    SurpriseScore.compute_total() (the lower bound of each band is
    inclusive).

    Each legacy name reuses the value of a primary member, so Python's
    Enum makes it an alias rather than a new member: for example
    ``SurpriseLevel.LOW is SurpriseLevel.MUNDANE`` and ``.name`` reports
    the primary name.  Iterating the enum yields only the five primary
    members.  Definition order therefore matters: the first member defined
    with a given value is the canonical one.
    """
    # Primary taxonomy (from beta_06_surprise_detector)
    MUNDANE         = "mundane"         # < 0.30 — expected, routine
    NOTABLE         = "notable"         # 0.30–0.50 — worth noting
    SURPRISING      = "surprising"      # 0.50–0.70 — violated expectations
    SHOCKING        = "shocking"        # 0.70–0.90 — major deviation
    PARADIGM_SHIFT  = "paradigm"        # >= 0.90 — fundamental change

    # Aliases for backward compatibility (same values → Enum aliases, not
    # new members).  PARADIGM_SHIFT has no legacy alias.
    LOW             = "mundane"         # alias → MUNDANE
    MEDIUM          = "notable"         # alias → NOTABLE
    HIGH            = "surprising"      # alias → SURPRISING
    CRITICAL        = "shocking"        # alias → SHOCKING


# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------

@dataclass
class MemoryItem:
    """
    A single item stored in the memory system.

    MemorySystem.evaluate() appends one of these to ``MemorySystem.memories``
    for every input that is long enough to be scored.
    """
    content: str  # the raw text that was evaluated
    source: str  # caller-supplied origin label (e.g. "test", "system")
    domain: str  # domain key; the same key drives surprise baselines/rarity
    # Creation time as an ISO-8601 string from datetime.now(), i.e. naive
    # local time with no tzinfo.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())


@dataclass
class Prediction:
    """
    A prediction registered for later resolution.

    Created by _SurpriseDetector.make_prediction() and completed in place by
    _SurpriseDetector.resolve_prediction(), which fills in the three
    resolution fields at the bottom.
    """
    prediction_id: str  # 12-hex-char id returned to the caller
    domain: str  # domain the resolved outcome is scored under
    expected_outcome: str  # compared word-by-word with the actual outcome
    # Scales the prediction error by (0.5 + 0.5 * confidence); presumably
    # in [0, 1] — not validated here.
    confidence: float
    context: Dict[str, Any]  # passed through to the surprise evaluation
    created_at: str  # ISO-8601 creation time (naive local time)
    # ISO-8601 expiry (created_at + ttl_minutes).
    # NOTE(review): nothing in the code visible here checks this field;
    # expired predictions can still be resolved and count as active.
    expires_at: Optional[str] = None
    resolved: bool = False  # True once resolve_prediction() has run
    actual_outcome: Optional[str] = None  # outcome supplied at resolution
    prediction_error: Optional[float] = None  # error in [0, 1] at resolution


@dataclass
class SurpriseScore:
    """
    Multi-dimensional surprise evaluation.

    All four dimension fields are present for callers that reference them
    by name (the beta_06 set: prediction_error, novelty, impact, rarity)
    as well as the legacy set (violation, novelty, impact, rarity).
    'total' and 'composite_score' are synonyms for the weighted sum.

    ``violation`` is kept in sync with ``prediction_error``.  A legacy
    caller that sets only ``violation`` at construction has that value
    adopted as ``prediction_error`` (see ``__post_init__``), so
    ``compute_total()`` does not silently score the primary dimension as 0.
    """
    # Primary dimensions (beta_06)
    prediction_error: float = 0.0
    novelty: float = 0.0
    impact: float = 0.0
    rarity: float = 0.0

    # Legacy alias for 'prediction_error' (old MemorySystem callers used this)
    violation: float = 0.0

    # Computed by compute_total()
    total: float = 0.0
    composite_score: float = 0.0       # synonym for total
    level: SurpriseLevel = SurpriseLevel.MUNDANE

    # Memory routing flags (set by compute_total())
    should_promote_memory: bool = False   # total >= 0.50
    should_generate_axiom: bool = False   # total >= 0.70

    def __post_init__(self) -> None:
        """Adopt a legacy-only ``violation`` value as ``prediction_error``."""
        # Without this, compute_total() would overwrite the caller's
        # violation with the default prediction_error of 0.0.
        if self.prediction_error == 0.0 and self.violation != 0.0:
            self.prediction_error = self.violation

    def compute_total(self, weights: Optional[Dict[str, float]] = None) -> float:
        """
        Compute the weighted composite and populate all derived fields.

        Args:
            weights: Optional per-dimension weights keyed by
                "prediction_error", "novelty", "impact" and "rarity".
                Missing keys fall back to the defaults
                (0.35 / 0.25 / 0.25 / 0.15), so a partial override does
                not raise KeyError.

        Returns:
            The composite score capped at 1.0.  It is also stored in
            ``total`` and ``composite_score``.  As side effects, ``level``,
            ``violation`` and both routing flags are updated.
        """
        effective = {
            "prediction_error": 0.35,
            "novelty":          0.25,
            "impact":           0.25,
            "rarity":           0.15,
        }
        if weights:
            effective.update(weights)

        # Ensure violation alias stays in sync
        self.violation = self.prediction_error

        self.total = min(1.0, (
            self.prediction_error * effective["prediction_error"]
            + self.novelty        * effective["novelty"]
            + self.impact         * effective["impact"]
            + self.rarity         * effective["rarity"]
        ))
        self.composite_score = self.total

        # Classify (lower bound of each band is inclusive)
        if self.total >= 0.90:
            self.level = SurpriseLevel.PARADIGM_SHIFT
        elif self.total >= 0.70:
            self.level = SurpriseLevel.SHOCKING
        elif self.total >= 0.50:
            self.level = SurpriseLevel.SURPRISING
        elif self.total >= 0.30:
            self.level = SurpriseLevel.NOTABLE
        else:
            self.level = SurpriseLevel.MUNDANE

        self.should_promote_memory = self.total >= 0.50
        self.should_generate_axiom = self.total >= 0.70

        return self.total

    def to_dict(self) -> Dict:
        """Return all fields as a plain dict, with ``level`` as its string value."""
        d = asdict(self)
        d["level"] = self.level.value
        return d


# ---------------------------------------------------------------------------
# Surprise History (persistence + baselines)
# ---------------------------------------------------------------------------

class SurpriseHistory:
    """
    Tracks historical surprise events.

    Persists to a JSON sidecar so novelty calculations survive session
    restarts.  File location defaults to the genesis data directory but
    gracefully degrades to in-memory-only if the path is not writable.

    The in-memory event list and content-hash table are capped at the same
    sizes that are persisted (``_MAX_EVENTS`` / ``_MAX_HASHES``).  This keeps
    memory bounded and makes a long-running process behave like one that
    was just reloaded from the sidecar file.
    """

    _DEFAULT_PATH = "/mnt/e/genesis-system/data/surprise_history.json"
    _MAX_EVENTS = 1000   # events retained in memory and on disk
    _MAX_HASHES = 5000   # content hashes retained in memory and on disk

    def __init__(self, history_path: Optional[str] = None):
        self.history_path = Path(history_path or self._DEFAULT_PATH)
        self.events: List[Dict] = []
        self.domain_baselines: Dict[str, float] = {}
        self.content_hashes: Dict[str, int] = {}
        self._persistent = True
        self._load()

    # -- persistence --------------------------------------------------------

    def _load(self):
        """Best-effort load of the sidecar; missing or corrupt file → start fresh."""
        try:
            with open(self.history_path, encoding="utf-8") as f:
                data = json.load(f)
            events = data.get("events", [])[-self._MAX_EVENTS:]
            baselines = data.get("domain_baselines", {})
            hashes = data.get("content_hashes", {})
        except (OSError, ValueError, TypeError, AttributeError):
            # Missing/unreadable file, invalid JSON, or unexpected structure.
            return
        self.events = events
        self.domain_baselines = baselines
        self.content_hashes = hashes
        self._trim_hashes()

    def _save(self):
        """Persist state atomically; on failure, degrade to in-memory only."""
        if not self._persistent:
            return
        # Write to a sibling temp file, then rename over the target.  A crash
        # mid-write therefore never leaves a truncated history file (which
        # _load would silently discard).
        tmp_path = self.history_path.with_name(self.history_path.name + ".tmp")
        try:
            self.history_path.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump({
                    "events":           self.events[-self._MAX_EVENTS:],
                    "domain_baselines": self.domain_baselines,
                    "content_hashes":   dict(
                        list(self.content_hashes.items())[-self._MAX_HASHES:]
                    ),
                }, f, indent=2)
            tmp_path.replace(self.history_path)
        except (OSError, TypeError, ValueError):
            self._persistent = False  # Fall back to in-memory

    def _trim_hashes(self):
        """Drop the oldest-inserted content hashes beyond ``_MAX_HASHES``."""
        excess = len(self.content_hashes) - self._MAX_HASHES
        if excess <= 0:
            return
        stale = [key for _, key in zip(range(excess), self.content_hashes)]
        for key in stale:
            del self.content_hashes[key]

    # -- public API ---------------------------------------------------------

    def add_event(self, event: Dict):
        """
        Record an event, update its domain's EMA baseline, and persist.

        ``event["domain"]`` defaults to "general" and
        ``event["surprise_score"]`` defaults to 0.5.
        """
        self.events.append(event)
        if len(self.events) > self._MAX_EVENTS:
            del self.events[:-self._MAX_EVENTS]
        domain = event.get("domain", "general")
        score  = event.get("surprise_score", 0.5)
        alpha  = 0.1  # EMA factor
        if domain not in self.domain_baselines:
            self.domain_baselines[domain] = score
        else:
            self.domain_baselines[domain] = (
                alpha * score + (1 - alpha) * self.domain_baselines[domain]
            )
        self._save()

    def get_novelty(self, content: str) -> float:
        """
        Return a novelty score in [0.1, 0.9] based on prior sightings.

        First sighting → 0.9.  A repeat seen ``count`` times before scores
        ``max(0.1, 1 / (1 + ln(count + 1)))``.  The sighting count is
        updated in memory; it is persisted on the next ``add_event``.
        """
        ch = hashlib.md5(content.encode()).hexdigest()[:16]
        if ch in self.content_hashes:
            count = self.content_hashes[ch]
            self.content_hashes[ch] = count + 1
            return max(0.1, 1.0 / (1 + math.log(count + 1)))
        self.content_hashes[ch] = 1
        self._trim_hashes()
        return 0.9

    def get_domain_baseline(self, domain: str) -> float:
        """Return the EMA surprise baseline for *domain* (0.5 if unseen)."""
        return self.domain_baselines.get(domain, 0.5)


# ---------------------------------------------------------------------------
# Core Surprise Detector (beta_06 logic)
# ---------------------------------------------------------------------------

class _SurpriseDetector:
    """
    Internal full-fidelity surprise detector.

    Implements the beta_06 multi-dimensional algorithm with prediction
    resolution, domain baselines, and rarity tracking.  Every call to
    evaluate() records one event in the SurpriseHistory, whether it is made
    directly or via resolve_prediction().  That history feeds the domain
    baselines used for unpredicted content and the per-domain counts used
    for rarity.
    """

    # Impact signal tables: keyword → impact.  _calc_impact uses a
    # case-insensitive substring test; the highest matching value wins and
    # content with no match scores 0.30.
    _IMPACT_INDICATORS = {
        "error":        0.80, "critical":    0.90, "success":  0.30,
        "failure":      0.70, "unexpected":  0.80, "warning":  0.50,
        "breakthrough": 0.85, "discovered":  0.70, "patent":   0.60,
        "revenue":      0.70, "validated":   0.40, "invalid":  0.60,
        # Genesis-specific
        "kinan":        0.65, "aiva":        0.70, "queen":    0.75,
        "deploy":       0.65, "shutdown":    0.85, "restart":  0.75,
        "payment":      0.70, "stripe":      0.70, "launch":   0.65,
    }

    # Additional rarity signals.
    # NOTE(review): no method in this class reads this table; _calc_rarity
    # currently uses only per-domain event counts.
    _RARITY_SIGNALS = [
        "unexpected", "unknown", "unhandled", "assertion",
        "timeout", "retry", "fallback", "override", "bypass",
        "workaround", "hack", "todo", "fixme",
    ]

    def __init__(self, history_path: Optional[str] = None):
        self.history     = SurpriseHistory(history_path)
        # Registered predictions keyed by id.  Resolved entries stay in the
        # dict (flagged resolved=True); nothing here prunes them.
        self.predictions: Dict[str, Prediction] = {}

    # -- predictions --------------------------------------------------------

    def make_prediction(
        self,
        domain: str,
        expected_outcome: str,
        confidence: float = 0.7,
        context: Optional[Dict] = None,
        ttl_minutes: int = 60,
    ) -> str:
        """
        Register a prediction and return its 12-hex-char id.

        The id is derived from domain, expected outcome and creation time.
        If it collides with an existing id, a salt is appended until it is
        unique.  This can happen with two identical predictions made within
        the clock's resolution.  An earlier prediction is therefore never
        silently overwritten.

        ``expires_at`` is recorded (now + ttl_minutes) but is not enforced
        by this class.
        """
        now = datetime.now()
        seed = f"{domain}:{expected_outcome}:{now.isoformat()}"
        pred_id = hashlib.md5(seed.encode()).hexdigest()[:12]
        salt = 0
        while pred_id in self.predictions:
            salt += 1
            pred_id = hashlib.md5(f"{seed}:{salt}".encode()).hexdigest()[:12]

        self.predictions[pred_id] = Prediction(
            prediction_id=pred_id,
            domain=domain,
            expected_outcome=expected_outcome,
            confidence=confidence,
            context=context or {},
            created_at=now.isoformat(),
            expires_at=(now + timedelta(minutes=ttl_minutes)).isoformat(),
        )
        return pred_id

    def resolve_prediction(
        self,
        prediction_id: str,
        actual_outcome: str,
    ) -> Tuple[float, SurpriseScore]:
        """
        Score *actual_outcome* against a registered prediction.

        Returns ``(prediction_error, SurpriseScore)``.  For an unknown id,
        the outcome is scored under domain "unknown" using the baseline
        fallback, and a placeholder error of 0.5 is returned.  Resolving
        the same id again re-scores it and records another history event.
        """
        if prediction_id not in self.predictions:
            score = self.evaluate(actual_outcome, "unknown", {})
            return 0.5, score

        pred = self.predictions[prediction_id]
        pred_error = self._calc_prediction_error(
            pred.expected_outcome, actual_outcome, pred.confidence
        )
        pred.resolved       = True
        pred.actual_outcome = actual_outcome
        pred.prediction_error = pred_error

        score = self.evaluate(
            actual_outcome, pred.domain, pred.context,
            prediction_error=pred_error
        )
        return pred_error, score

    def _calc_prediction_error(
        self,
        expected: str,
        actual: str,
        confidence: float,
    ) -> float:
        """
        Word-level Jaccard distance scaled by confidence, capped at 1.0.

        A case-insensitive exact match returns 0.0.  If either side has no
        words, the result is 0.5.  Otherwise the result is
        ``(1 - |E∩A| / |E∪A|) * (0.5 + 0.5 * confidence)``.
        """
        exp_lower = expected.lower()
        act_lower = actual.lower()
        if exp_lower == act_lower:
            return 0.0
        exp_words = set(exp_lower.split())
        act_words = set(act_lower.split())
        if not exp_words or not act_words:
            return 0.5
        overlap    = len(exp_words & act_words)
        total      = len(exp_words | act_words)
        similarity = overlap / total
        base_error = 1.0 - similarity
        return min(1.0, base_error * (0.5 + 0.5 * confidence))

    # -- scoring ------------------------------------------------------------

    def evaluate(
        self,
        content: str,
        domain: str,
        context: Optional[Dict] = None,
        prediction_error: Optional[float] = None,
    ) -> SurpriseScore:
        """
        Score *content*, record the event in history, and return the score.

        Without an explicit *prediction_error*, the error is estimated from
        the domain baseline as ``|0.5 - baseline| + 0.30``.
        """
        context = context or {}

        if prediction_error is None:
            baseline        = self.history.get_domain_baseline(domain)
            prediction_error = abs(0.5 - baseline) + 0.30

        novelty = self.history.get_novelty(content)
        impact  = self._calc_impact(content)
        rarity  = self._calc_rarity(domain)

        score = SurpriseScore(
            prediction_error=prediction_error,
            violation=prediction_error,  # alias kept in sync
            novelty=novelty,
            impact=impact,
            rarity=rarity,
        )
        score.compute_total()

        self.history.add_event({
            "timestamp":       datetime.now().isoformat(),
            "domain":          domain,
            "content_preview": content[:100],
            "surprise_score":  score.total,
            "level":           score.level.value,
        })
        return score

    def _calc_impact(self, content: str) -> float:
        """Highest impact value of any indicator keyword in *content* (floor 0.30)."""
        cl = content.lower()
        max_impact = 0.30
        for kw, val in self._IMPACT_INDICATORS.items():
            if kw in cl:
                max_impact = max(max_impact, val)
        return max_impact

    def _calc_rarity(self, domain: str) -> float:
        """
        Rarity from the domain's share of recorded events.

        Returns 0.50 with fewer than 10 events for the domain.  Otherwise
        returns ``1 - count / 100``, floored at 0.20.
        """
        domain_events = [
            e for e in self.history.events if e.get("domain") == domain
        ]
        if len(domain_events) < 10:
            return 0.50
        return max(0.20, 1.0 - len(domain_events) / 100)

    def get_stats(self) -> Dict:
        """
        Summarize recorded events and open predictions.

        The returned dict always has the same keys, even before any event
        has been recorded.  ``domain_baselines`` is a copy, so callers
        cannot mutate internal state through it.
        """
        events = self.history.events
        level_counts: Dict[str, int] = defaultdict(int)
        for e in events:
            level_counts[e.get("level", "mundane")] += 1
        if events:
            avg = sum(e.get("surprise_score", 0.5) for e in events) / len(events)
        else:
            avg = 0.5
        baselines = dict(self.history.domain_baselines)
        return {
            "total_events":       len(events),
            "avg_surprise":       avg,
            "domains":            list(baselines.keys()),
            "domain_baselines":   baselines,
            "level_distribution": dict(level_counts),
            "active_predictions": sum(
                1 for p in self.predictions.values() if not p.resolved
            ),
        }


# ---------------------------------------------------------------------------
# Public MemorySystem facade (backward-compatible with all callers)
# ---------------------------------------------------------------------------

class MemorySystem:
    """
    Surprise-based memory routing system.

    Drop-in replacement for the previous stub.  Existing callers
    (GenesisKernel, RLM Gateway, Titan Memory, orchestrator) keep working
    through the same four-method API:

        evaluate(content, source, domain) → Dict
        observe(event_type, actual_outcome, context) → SurpriseScore
        reflect(actual) → List[Dict]
        get_stats() → Dict

    make_prediction() / resolve_prediction() are also exposed.  They forward
    to the underlying beta_06 SurpriseDetector engine.
    """

    _MIN_CONTENT_LENGTH = 10  # snippets shorter than this are never scored

    def __init__(self, persistence_path: str = None):
        self._engine = _SurpriseDetector(persistence_path)
        # Legacy attributes that some callers read directly.
        self.memories: List[MemoryItem] = []
        self._seen_content_hashes: set = set()

    # -- primary API --------------------------------------------------------

    def evaluate(self, content: str, source: str, domain: str) -> Dict:
        """
        Score *content* for surprise and pick its memory tier.

        Empty content, or content shorter than ``_MIN_CONTENT_LENGTH``, is
        not sent to the engine.  It gets a fixed 0.10 MUNDANE score and the
        "working" tier.  Everything else is scored, remembered as a
        MemoryItem, and routed to "episodic" when the total exceeds 0.70.

        Returns:
            {"score": SurpriseScore.to_dict(), "tier": "episodic" | "working"}
        """
        if not content or len(content) < self._MIN_CONTENT_LENGTH:
            placeholder = SurpriseScore(total=0.10, composite_score=0.10)
            return {"score": placeholder.to_dict(), "tier": "working"}

        scored = self._engine.evaluate(content, domain)

        # Legacy bookkeeping: full-digest de-dup set plus the item list.
        digest = hashlib.md5(content.encode()).hexdigest()
        self._seen_content_hashes.add(digest)
        self.memories.append(
            MemoryItem(content=content, source=source, domain=domain)
        )

        return {
            "score": scored.to_dict(),
            "tier": "episodic" if scored.total > 0.70 else "working",
        }

    def observe(
        self,
        event_type: str,
        actual_outcome: str,
        context: Dict = None,
    ) -> SurpriseScore:
        """
        Score an observed event (GenesisKernel.observe()-compatible).

        The outcome text is scored, falling back to *event_type* when the
        outcome is empty.  The domain comes from ``context["domain"]`` and
        defaults to *event_type*.
        """
        event_domain = (context or {}).get("domain", event_type)
        return self._engine.evaluate(
            actual_outcome or event_type, event_domain, context
        )

    def reflect(self, actual: str) -> List[Dict]:
        """No-op retained so legacy callers of reflect() keep working."""
        return []

    def get_stats(self) -> Dict:
        """Delegate to the engine's statistics summary."""
        return self._engine.get_stats()

    # -- prediction API (forwarded to engine) --------------------------------

    def make_prediction(
        self,
        domain: str,
        expected_outcome: str,
        confidence: float = 0.7,
        context: Dict = None,
        ttl_minutes: int = 60,
    ) -> str:
        """Register a prediction with the engine and return its id."""
        return self._engine.make_prediction(
            domain=domain,
            expected_outcome=expected_outcome,
            confidence=confidence,
            context=context,
            ttl_minutes=ttl_minutes,
        )

    def resolve_prediction(
        self,
        prediction_id: str,
        actual_outcome: str,
    ) -> Tuple[float, SurpriseScore]:
        """Resolve a prediction; yields (error_score, SurpriseScore)."""
        return self._engine.resolve_prediction(
            prediction_id=prediction_id,
            actual_outcome=actual_outcome,
        )


# ---------------------------------------------------------------------------
# Module-level singleton (convenience)
# ---------------------------------------------------------------------------

# Lazily created process-wide instance handed out by get_memory_system().
_default_memory_system: Optional[MemorySystem] = None


def get_memory_system(persistence_path: str = None) -> MemorySystem:
    """
    Return the shared module-level MemorySystem, creating it on first use.

    Args:
        persistence_path: History file path used only when the instance is
            first created.  Later calls return the existing instance and
            ignore this argument.
    """
    global _default_memory_system
    if _default_memory_system is not None:
        return _default_memory_system
    _default_memory_system = MemorySystem(persistence_path)
    return _default_memory_system


# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import tempfile

    print("=== AIVA Surprise Memory System — Self-Test ===\n")

    # Run against a throwaway history file.  With the default path, the
    # self-test would read (making results depend on prior runs) and
    # append to the production sidecar at SurpriseHistory._DEFAULT_PATH.
    with tempfile.TemporaryDirectory() as _tmp_dir:
        ms = MemorySystem(str(Path(_tmp_dir) / "surprise_history.json"))

        # Test 1: Routine content
        result = ms.evaluate(
            "The sky is blue today.", "test", "general"
        )
        print(f"Routine:  score={result['score']['composite_score']:.3f}  tier={result['tier']}")
        assert result["tier"] == "working", result

        # Test 2: High-impact content
        result = ms.evaluate(
            "CRITICAL error: AIVA server deployment failed unexpectedly during revenue call.",
            "system", "operations"
        )
        print(f"Critical: score={result['score']['composite_score']:.3f}  tier={result['tier']}")
        assert result["tier"] in ("episodic", "working"), result

        # Test 3: Prediction workflow
        pred_id = ms.make_prediction(
            domain="deployment",
            expected_outcome="Deployment completed successfully",
            confidence=0.85,
        )
        error, score = ms.resolve_prediction(
            pred_id,
            "Critical failure: deployment rolled back due to unexpected database migration error"
        )
        print(f"Predict:  error={error:.3f}  score={score.total:.3f}  level={score.level.value}")
        print(f"          promote_memory={score.should_promote_memory}  gen_axiom={score.should_generate_axiom}")
        assert 0.0 <= error <= 1.0
        assert score.prediction_error == error
        assert score.should_promote_memory

        # Test 4: observe() compatibility
        obs = ms.observe("call_ended", "Caller booked a demo appointment", {"domain": "sales"})
        print(f"Observe:  score={obs.composite_score:.3f}  level={obs.level.value}")
        assert 0.0 <= obs.composite_score <= 1.0

        # Stats: exactly one event per scored call above (2 evaluate,
        # 1 resolve, 1 observe) in the fresh history.
        stats = ms.get_stats()
        print(f"\nStats:    total_events={stats['total_events']}  avg_surprise={stats['avg_surprise']:.3f}")
        assert stats["total_events"] == 4, stats

    print("\nSelf-test PASSED")
