"""
Genesis Enhanced Surprise Detection
===================================
Embedding-based novelty detection replacing simple keyword matching.

Uses cosine similarity against a memory bank to detect truly novel information.
Higher surprise = lower similarity to existing memories.

Usage:
    from enhanced_surprise import EnhancedSurpriseDetector

    detector = EnhancedSurpriseDetector()

    # Score new content
    result = detector.evaluate("Some new information", "source", "domain")
    print(result["score"]["novelty"])  # 0.0-1.0

    # Add to memory bank, then save to disk
    detector.add_to_memory("Known information", "source", "domain")
    detector.persist()
"""

import json
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field, asdict
import hashlib
import threading

# Try to import embedding models
try:
    from fastembed import TextEmbedding
    FASTEMBED_AVAILABLE = True
except ImportError:
    FASTEMBED_AVAILABLE = False
    TextEmbedding = None

# Fallback: try sentence-transformers (used when FastEmbed is missing or fails to initialize)
try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
    SENTENCE_TRANSFORMERS_AVAILABLE = False
    SentenceTransformer = None
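
# Backend priority at runtime: FastEmbed, then sentence-transformers, then a
# hash-based pseudo-embedding (see _init_embedder and _fallback_embed below).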


@dataclass
class SurpriseScore:
    """Detailed surprise score breakdown."""
    novelty: float = 0.0       # How different from known memories (embedding-based)
    violation: float = 0.0     # Expectation violation (prediction error)
    impact: float = 0.0        # Estimated importance/impact
    rarity: float = 0.0        # Domain rarity score
    total: float = 0.0         # Combined weighted score

    def to_dict(self) -> Dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> 'SurpriseScore':
        return cls(**data)


@dataclass
class MemoryVector:
    """A memory with its embedding vector."""
    id: str
    content: str
    source: str
    domain: str
    vector: List[float]
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    access_count: int = 0

    def to_dict(self) -> Dict:
        return {
            "id": self.id,
            "content": self.content,
            "source": self.source,
            "domain": self.domain,
            "vector": self.vector,
            "timestamp": self.timestamp,
            "access_count": self.access_count
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'MemoryVector':
        return cls(**data)


class EnhancedSurpriseDetector:
    """
    Embedding-based surprise detection for Genesis memory system.

    Features:
    - Cosine similarity against memory bank
    - Domain-specific thresholds
    - Prediction-error based violation detection
    - Adaptive novelty scaling
    """

    def __init__(
        self,
        memory_path: Optional[str] = None,
        model_name: str = "BAAI/bge-small-en-v1.5",
        similarity_threshold: float = 0.7,
        max_memories: int = 10000
    ):
        self.memory_path = Path(memory_path) if memory_path else Path("E:/genesis-system/data/surprise_vectors.json")
        self.similarity_threshold = similarity_threshold
        self.max_memories = max_memories
        self._lock = threading.RLock()

        # Memory storage
        self.memories: Dict[str, MemoryVector] = {}
        self.domain_stats: Dict[str, Dict] = {}

        # Initialize embedder
        self.embedder = None
        self.vector_size = 384  # default; _init_embedder may overwrite
        self._init_embedder(model_name)

        # Load existing memories
        self._load()

        # Weights for combined score
        self.weights = {
            "novelty": 0.4,
            "violation": 0.2,
            "impact": 0.2,
            "rarity": 0.2
        }
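        # With these defaults, evaluate() combines the components as:
        #   total = 0.4*novelty + 0.2*violation + 0.2*impact + 0.2*rarity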

    def _init_embedder(self, model_name: str) -> None:
        """Initialize the embedding model."""
        if FASTEMBED_AVAILABLE:
            try:
                self.embedder = TextEmbedding(model_name=model_name)
                self.vector_size = 384  # BAAI/bge-small-en-v1.5 produces 384-dim vectors
                print(f"[OK] EnhancedSurprise: Using FastEmbed ({model_name})")
                return
            except Exception as e:
                print(f"[!] FastEmbed init failed: {e}")

        if SENTENCE_TRANSFORMERS_AVAILABLE:
            try:
                self.embedder = SentenceTransformer(model_name)
                self.vector_size = self.embedder.get_sentence_embedding_dimension()
                print(f"[OK] EnhancedSurprise: Using SentenceTransformers ({model_name})")
                return
            except Exception as e:
                print(f"[!] SentenceTransformers init failed: {e}")

        print("[!] EnhancedSurprise: No embedding model available, using fallback")
        self.embedder = None

    def _embed(self, text: str) -> List[float]:
        """Generate embedding for text."""
        if self.embedder is None:
            # Fallback: simple hash-based pseudo-embedding
            return self._fallback_embed(text)

        try:
            if FASTEMBED_AVAILABLE and isinstance(self.embedder, TextEmbedding):
                # FastEmbed returns a generator of numpy arrays
                embeddings = list(self.embedder.embed([text]))
                return embeddings[0].tolist()
            if SENTENCE_TRANSFORMERS_AVAILABLE and isinstance(self.embedder, SentenceTransformer):
                embedding = self.embedder.encode(text)
                return embedding.tolist()
        except Exception as e:
            print(f"[!] Embedding failed: {e}")

        # No backend matched or embedding raised: fall back to hash features
        return self._fallback_embed(text)

    def _fallback_embed(self, text: str) -> List[float]:
        """Fallback embedding using hash-based features."""
        # Create deterministic pseudo-embedding from text
        words = text.lower().split()
        vector = [0.0] * self.vector_size

        for i, word in enumerate(words):
            # Hash each word into a fixed bucket; earlier words carry more weight
            hash_val = int(hashlib.md5(word.encode()).hexdigest(), 16)
            idx = hash_val % self.vector_size
            vector[idx] += 1.0 / (i + 1)

        # Normalize
        norm = sum(v**2 for v in vector) ** 0.5
        if norm > 0:
            vector = [v / norm for v in vector]

        return vector
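
    # Note: the fallback is a bag-of-hashed-words sketch; it captures lexical
    # overlap (shared words raise similarity) but not semantics, so novelty
    # scores degrade gracefully when no embedding model is installed.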

    def _cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
        """Calculate cosine similarity between two vectors."""
        a = np.array(v1)
        b = np.array(v2)
        dot = np.dot(a, b)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(dot / (norm_a * norm_b))
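
    # Sanity examples: _cosine_similarity([1.0, 0.0], [1.0, 0.0]) -> 1.0
    #                  _cosine_similarity([1.0, 0.0], [0.0, 1.0]) -> 0.0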

    def _find_most_similar(
        self,
        vector: List[float],
        domain: Optional[str] = None,
        top_k: int = 5
    ) -> List[Tuple[str, float]]:
        """Find most similar memories."""
        similarities = []

        for mem_id, memory in self.memories.items():
            if domain and memory.domain != domain:
                continue
            sim = self._cosine_similarity(vector, memory.vector)
            similarities.append((mem_id, sim))

        # Sort by similarity descending
        similarities.sort(key=lambda x: x[1], reverse=True)
        return similarities[:top_k]
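
    # Example return (ids and similarity values are illustrative):
    #   [("a1b2c3d4e5f6", 0.91), ("0f9e8d7c6b5a", 0.74), ...]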

    def evaluate(
        self,
        content: str,
        source: str,
        domain: str,
        expected: Optional[str] = None
    ) -> Dict:
        """
        Evaluate surprise/novelty of content.

        Args:
            content: The content to evaluate
            source: Source of the content
            domain: Domain/category
            expected: Optional expected content for violation detection

        Returns:
            Dict with score, tier recommendation, and details
        """
        with self._lock:
            # Generate embedding
            vector = self._embed(content)

            # Calculate novelty (inverse of max similarity)
            novelty_score = 1.0
            similar = self._find_most_similar(vector, domain=None, top_k=3)  # search across all domains
            if similar:
                max_sim = max(s[1] for s in similar)
                novelty_score = 1.0 - max_sim
                # Scale to emphasize truly novel content
                novelty_score = min(1.0, novelty_score * 1.5)

            # Calculate violation (prediction error)
            violation_score = 0.0
            if expected:
                expected_vector = self._embed(expected)
                violation_score = 1.0 - self._cosine_similarity(vector, expected_vector)

            # Calculate domain rarity
            rarity_score = self._calculate_rarity(domain)

            # Calculate impact (heuristic based on content)
            impact_score = self._estimate_impact(content, domain)

            # Combined score
            total = (
                self.weights["novelty"] * novelty_score +
                self.weights["violation"] * violation_score +
                self.weights["impact"] * impact_score +
                self.weights["rarity"] * rarity_score
            )

            score = SurpriseScore(
                novelty=round(novelty_score, 3),
                violation=round(violation_score, 3),
                impact=round(impact_score, 3),
                rarity=round(rarity_score, 3),
                total=round(total, 3)
            )

            # Determine memory tier
            if total > 0.7:
                tier = "semantic"  # Long-term conceptual
            elif total > 0.4:
                tier = "episodic"  # Experiential
            else:
                tier = "working"   # Short-term

            return {
                "score": score.to_dict(),
                "tier": tier,
                "similar_count": len(similar),
                "max_similarity": round(similar[0][1], 3) if similar else 0.0,
                "embedding_size": len(vector)
            }
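
    # Example result from evaluate() (values are illustrative):
    #   {"score": {"novelty": 0.81, "violation": 0.0, "impact": 0.3,
    #              "rarity": 0.75, "total": 0.534},
    #    "tier": "episodic", "similar_count": 3,
    #    "max_similarity": 0.46, "embedding_size": 384}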

    def _calculate_rarity(self, domain: str) -> float:
        """Calculate domain rarity based on frequency."""
        stats = self.domain_stats.get(domain, {"count": 0})
        total_memories = len(self.memories)
        if total_memories == 0:
            return 1.0  # First memory is maximally rare

        domain_count = stats.get("count", 0)
        frequency = domain_count / total_memories
        # Invert frequency to get rarity
        return 1.0 - min(frequency * 5, 1.0)  # Scale factor of 5
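
    # Worked example: a domain holding 4 of 100 memories has frequency 0.04,
    # so rarity = 1.0 - min(0.04 * 5, 1.0) = 0.8; any domain at or above 20%
    # of all memories bottoms out at 0.0.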

    def _estimate_impact(self, content: str, domain: str) -> float:
        """Estimate impact via keyword heuristics (domain is currently unused)."""
        content_lower = content.lower()
        impact = 0.3  # Base impact

        # High-impact keywords
        high_impact = ["critical", "error", "urgent", "breaking", "important",
                       "security", "vulnerability", "failure", "success"]
        medium_impact = ["update", "change", "new", "modified", "discovered",
                        "found", "learned", "insight"]

        for word in high_impact:
            if word in content_lower:
                impact = max(impact, 0.8)
                break

        for word in medium_impact:
            if word in content_lower:
                impact = max(impact, 0.5)

        # Length bonus (longer content often more impactful)
        if len(content) > 500:
            impact = min(impact + 0.1, 1.0)

        return impact
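
    # Illustrative scores: "Critical security vulnerability found" -> 0.8
    # (high-impact keyword hit), "Discovered a new library" -> 0.5 (medium),
    # "The sky is blue" -> 0.3 (base impact only).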

    def add_to_memory(
        self,
        content: str,
        source: str,
        domain: str,
        memory_id: Optional[str] = None
    ) -> str:
        """Add content to memory bank."""
        with self._lock:
            if memory_id is None:
                hash_input = f"{datetime.now().isoformat()}:{content[:100]}"
                memory_id = hashlib.sha256(hash_input.encode()).hexdigest()[:12]

            vector = self._embed(content)

            memory = MemoryVector(
                id=memory_id,
                content=content,
                source=source,
                domain=domain,
                vector=vector
            )

            self.memories[memory_id] = memory

            # Update domain stats
            if domain not in self.domain_stats:
                self.domain_stats[domain] = {"count": 0, "first_seen": datetime.now().isoformat()}
            self.domain_stats[domain]["count"] += 1

            # Prune if too many memories
            if len(self.memories) > self.max_memories:
                self._prune_old_memories()

            return memory_id

    def _prune_old_memories(self) -> None:
        """Remove oldest/least accessed memories."""
        if len(self.memories) <= self.max_memories:
            return

        # Sort by access count and timestamp
        sorted_mems = sorted(
            self.memories.items(),
            key=lambda x: (x[1].access_count, x[1].timestamp)
        )

        # Remove bottom 10%
        remove_count = len(sorted_mems) // 10
        for mem_id, _ in sorted_mems[:remove_count]:
            del self.memories[mem_id]

        print(f"[!] EnhancedSurprise: Pruned {remove_count} old memories")

    def _load(self) -> None:
        """Load memories from disk."""
        if not self.memory_path.exists():
            return

        try:
            with open(self.memory_path, 'r') as f:
                data = json.load(f)

            for mem_data in data.get("memories", []):
                memory = MemoryVector.from_dict(mem_data)
                self.memories[memory.id] = memory

            self.domain_stats = data.get("domain_stats", {})
            print(f"[OK] EnhancedSurprise: Loaded {len(self.memories)} memories")

        except Exception as e:
            print(f"[!] EnhancedSurprise load error: {e}")

    def persist(self) -> None:
        """Save memories to disk."""
        with self._lock:
            try:
                self.memory_path.parent.mkdir(parents=True, exist_ok=True)

                data = {
                    "memories": [m.to_dict() for m in self.memories.values()],
                    "domain_stats": self.domain_stats,
                    "saved_at": datetime.now().isoformat()
                }

                with open(self.memory_path, 'w') as f:
                    json.dump(data, f)

                print(f"[OK] EnhancedSurprise: Saved {len(self.memories)} memories")

            except Exception as e:
                print(f"[!] EnhancedSurprise persist error: {e}")

    def get_stats(self) -> Dict:
        """Get detector statistics."""
        return {
            "total_memories": len(self.memories),
            "domains": len(self.domain_stats),
            "domain_distribution": {
                k: v["count"] for k, v in self.domain_stats.items()
            },
            "embedder_available": self.embedder is not None,
            "vector_size": self.vector_size
        }


# Backwards compatibility: wrap as MemorySystem interface
class MemorySystem:
    """Backwards-compatible wrapper for enhanced surprise detection."""

    def __init__(self, persistence_path: Optional[str] = None):
        self.detector = EnhancedSurpriseDetector(
            memory_path=persistence_path if persistence_path else "E:/genesis-system/data/surprise_vectors.json"
        )
        self._last_expectation: Optional[str] = None

    def evaluate(self, content: str, source: str, domain: str) -> Dict:
        return self.detector.evaluate(content, source, domain)

    def observe(self, action: str, expected: str) -> None:
        """Store an expectation for future violation detection (action is currently unused)."""
        self._last_expectation = expected

    def reflect(self, actual: str) -> List[Dict]:
        """Compare the actual outcome against the last stored expectation."""
        if self._last_expectation is not None:
            result = self.detector.evaluate(
                actual, "reflection", "internal",
                expected=self._last_expectation
            )
            return [{"violation_score": result["score"]["violation"]}]
        return []

    def get_stats(self) -> Dict:
        return self.detector.get_stats()
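

# Example observe/reflect cycle with the wrapper (values are illustrative):
#   mem = MemorySystem()
#   mem.observe("deploy service", expected="deployment completes cleanly")
#   mem.reflect("deployment failed with a timeout")
#   # -> [{"violation_score": 0.62}]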


# CLI interface
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        cmd = sys.argv[1]

        detector = EnhancedSurpriseDetector()

        if cmd == "stats":
            stats = detector.get_stats()
            print(json.dumps(stats, indent=2))

        elif cmd == "evaluate":
            if len(sys.argv) > 2:
                text = " ".join(sys.argv[2:])
                result = detector.evaluate(text, "cli", "test")
                print(json.dumps(result, indent=2))
            else:
                print("Usage: python enhanced_surprise.py evaluate <text>")

        elif cmd == "demo":
            print("=== Enhanced Surprise Detection Demo ===\n")

            # Add some baseline memories
            detector.add_to_memory("The weather is sunny today", "demo", "weather")
            detector.add_to_memory("Python is a programming language", "demo", "tech")
            detector.add_to_memory("Machine learning uses neural networks", "demo", "tech")

            # Test novelty detection
            tests = [
                ("The weather is rainy", "weather"),  # Similar
                ("Python has great libraries", "tech"),  # Similar
                ("Quantum computers use qubits", "tech"),  # Novel
                ("Critical security vulnerability found", "security"),  # Novel + high impact
            ]

            for text, domain in tests:
                result = detector.evaluate(text, "demo", domain)
                print(f"Text: '{text[:50]}...'")
                print(f"  Novelty: {result['score']['novelty']:.2f}")
                print(f"  Impact: {result['score']['impact']:.2f}")
                print(f"  Total: {result['score']['total']:.2f}")
                print(f"  Tier: {result['tier']}\n")

        else:
            print(f"Unknown command: {cmd}")
            print("Usage: python enhanced_surprise.py [stats|evaluate|demo]")
    else:
        print("Genesis Enhanced Surprise Detection")
        print("Usage: python enhanced_surprise.py [stats|evaluate|demo]")
