#!/usr/bin/env python3
"""
Genesis Surprise-Based Memory System
=====================================
Implements intelligent memory retention based on surprise metrics.

Usage:
    from surprise_memory import MemorySystem

    memory = MemorySystem()
    result = memory.evaluate("Claude Code completed Task 007")  # score only
    result = memory.process("Claude Code completed Task 007")   # score, store if worthy
"""

import json
import re
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict
from pathlib import Path
import hashlib


@dataclass
class MemoryItem:
    """A candidate piece of information submitted for surprise evaluation.

    Only ``content``, ``source`` and ``domain`` are required. An empty
    ``timestamp`` is replaced with the current local time in ISO-8601 form,
    and a ``metadata`` of ``None`` is replaced with a fresh empty dict, so
    instances never share a default dict.
    """
    content: str    # the text being evaluated
    source: str     # label for where the information came from
    domain: str     # category label, e.g. "technical" or "general"
    timestamp: str = ""
    metadata: Dict[str, Any] = None

    def __post_init__(self):
        """Fill in the defaults that depend on construction time."""
        # Any falsy timestamp (the default "") means "stamp it now".
        self.timestamp = self.timestamp or datetime.now().isoformat()
        # Only None is replaced: a caller-supplied dict (even an empty one)
        # is kept as the very same object.
        self.metadata = {} if self.metadata is None else self.metadata


@dataclass
class SurpriseScore:
    """Detailed surprise score breakdown.

    Carries the four component scores alongside the combined ``total``.
    The class itself performs no validation or arithmetic; it simply holds
    whatever values it is constructed with.
    """
    violation: float  # component score: contradiction of expectations
    novelty: float    # component score: newness of the information
    impact: float     # component score: influence on future decisions
    rarity: float     # component score: how uncommon the pattern is
    total: float      # combined score supplied by the caller

    def to_dict(self) -> dict:
        """Return the fields as a plain dict (field name -> value)."""
        return asdict(self)


class SurpriseCalculator:
    """
    Calculates surprise scores for new information.

    Surprise = weighted combination of:
    - Violation: Does this contradict expectations?
    - Novelty: Is this information new?
    - Impact: Will this change future decisions?
    - Rarity: How uncommon is this pattern?

    Keyword and expectation checks are plain lowercase substring matches,
    so a keyword such as "fix" also matches "fixed" and "prefix".
    """

    # Relative contribution of each component to the total (sums to 1.0).
    WEIGHTS = {
        "violation": 0.30,
        "novelty": 0.30,
        "impact": 0.25,
        "rarity": 0.15
    }

    # Keywords that indicate high-impact information
    IMPACT_KEYWORDS = [
        "capability", "limitation", "strategy", "decision", "architecture",
        "critical", "breakthrough", "failure", "success", "learned",
        "discovered", "realized", "changed", "important", "must",
        "always", "never", "bug", "fix", "error", "solution"
    ]

    # Domains with their expected patterns
    EXPECTATIONS = {
        "technical": ["works", "implemented", "completed", "tested"],
        "learning": ["understood", "realized", "learned"],
        "error": ["failed", "error", "bug", "issue"],
        "decision": ["decided", "chose", "selected", "adopted"]
    }

    # Phrases that negate an expected word when they appear directly before it.
    NEGATION_PREFIXES = ("not ", "didn't ", "failed to ", "couldn't ", "won't ")

    # Words ignored by _tokenize; built once instead of on every call.
    STOPWORDS = frozenset({
        'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been',
        'to', 'of', 'and', 'in', 'that', 'it', 'for', 'on', 'with'
    })

    # Runs of ASCII letters; _tokenize applies it to lowercased text.
    _WORD_PATTERN = re.compile(r'\b[a-z]+\b')

    def __init__(self, memory_store: Optional['MemoryStore'] = None):
        """Bind the calculator to a memory store.

        Args:
            memory_store: Store queried for similar memories. A default
                ``MemoryStore()`` is created only when this is ``None``.
        """
        # Explicit None check: a store object that happens to be falsy must
        # not be silently swapped for a brand-new default store.
        self.memory = memory_store if memory_store is not None else MemoryStore()

    def calculate(self, item: MemoryItem) -> SurpriseScore:
        """Calculate complete surprise score for an item.

        Returns:
            A ``SurpriseScore`` whose components and weighted total are each
            rounded to three decimals. The total is computed from the
            unrounded components.
        """
        violation = self._calc_violation(item)
        novelty = self._calc_novelty(item)
        impact = self._calc_impact(item)
        rarity = self._calc_rarity(item)

        total = (
            violation * self.WEIGHTS["violation"] +
            novelty * self.WEIGHTS["novelty"] +
            impact * self.WEIGHTS["impact"] +
            rarity * self.WEIGHTS["rarity"]
        )

        return SurpriseScore(
            violation=round(violation, 3),
            novelty=round(novelty, 3),
            impact=round(impact, 3),
            rarity=round(rarity, 3),
            total=round(total, 3)
        )

    def _is_negated(self, content_lower: str, word: str) -> bool:
        """Return True if *word* occurs negated in *content_lower*.

        Negation means either ``"<word> not"`` or one of
        ``NEGATION_PREFIXES`` immediately followed by the word.
        """
        if word + " not" in content_lower:
            return True
        return any(neg + word in content_lower for neg in self.NEGATION_PREFIXES)

    def _calc_violation(self, item: MemoryItem) -> float:
        """Check if information violates domain expectations.

        Returns:
            0.5 when the domain has no expectations or none of its expected
            words occur; 0.9 when any occurring expected word is negated;
            0.1 when expected words occur and none is negated.
        """
        content_lower = item.content.lower()
        expected = self.EXPECTATIONS.get(item.domain, [])

        if not expected:
            return 0.5  # Neutral for unknown domains

        present = [exp for exp in expected if exp in content_lower]
        if not present:
            # No expected patterns found - mildly surprising
            return 0.5

        # Inspect every expected word that occurs, not just the first one:
        # "implemented but not tested" must count as a violation even though
        # "implemented" appears un-negated earlier in the expectation list.
        if any(self._is_negated(content_lower, exp) for exp in present):
            return 0.9  # Strong violation
        return 0.1  # Expected pattern, low surprise

    def _calc_novelty(self, item: MemoryItem) -> float:
        """Calculate novelty via similarity to existing memories.

        Asks the store for up to five similar memories and returns one minus
        the highest word-set overlap (intersection / union of tokenized
        words) with any of them. Returns 1.0 when the store has no
        candidates or nothing survives tokenization.
        """
        existing = self.memory.search_similar(item.content, limit=5)

        if not existing:
            return 1.0  # Completely novel

        item_words = set(self._tokenize(item.content))
        if not item_words:
            # Nothing to compare against, so no similarity can be measured.
            return 1.0

        max_similarity = 0.0
        for mem in existing:
            mem_words = set(self._tokenize(mem.get("content", "")))
            if mem_words:
                # Both sets are non-empty here, so the union cannot be empty.
                similarity = len(item_words & mem_words) / len(item_words | mem_words)
                max_similarity = max(max_similarity, similarity)

        return 1.0 - max_similarity

    def _calc_impact(self, item: MemoryItem) -> float:
        """Estimate impact on future decisions.

        Counts how many ``IMPACT_KEYWORDS`` occur as substrings of the
        lowercased content; three or more distinct keywords give 1.0.
        """
        content_lower = item.content.lower()

        matches = sum(1 for k in self.IMPACT_KEYWORDS if k in content_lower)

        # Normalize: 3+ keywords = max impact
        return min(matches / 3, 1.0)

    def _calc_rarity(self, item: MemoryItem) -> float:
        """Calculate rarity based on pattern frequency.

        Returns ``1 / (1 + n)`` where ``n`` is the store's count of similar
        memories, so an item with no similar memories scores 1.0.
        """
        similar_count = self.memory.count_similar(item.content)

        # Inverse frequency: more similar items = less rare
        return 1.0 / (1 + similar_count)

    def _tokenize(self, text: str) -> List[str]:
        """Simple word tokenization.

        Lowercases the text, extracts runs of ASCII letters, and drops
        stopwords and words of two characters or fewer. Order and
        duplicates are preserved.
        """
        words = self._WORD_PATTERN.findall(text.lower())
        return [w for w in words if w not in self.STOPWORDS and len(w) > 2]


class MemoryStore:
    """
    Local memory store for surprise-based memory system.
    Stores memories as JSON for persistence.

    The whole store is kept in ``self.memories`` (a list of dicts) and is
    rewritten to ``storage_path`` after every mutation.
    """

    # Location used when the constructor is given no storage_path.
    DEFAULT_STORAGE_PATH = "/mnt/e/genesis-system/memory_store.json"

    def __init__(self, storage_path: Optional[str] = None):
        """Open (or lazily create) the store at *storage_path*.

        Args:
            storage_path: JSON file backing the store. Defaults to
                ``DEFAULT_STORAGE_PATH``. The file is only created on the
                first save.
        """
        if storage_path is None:
            storage_path = self.DEFAULT_STORAGE_PATH
        self.storage_path = Path(storage_path)
        self.memories: List[Dict] = []
        self._load()

    def _load(self):
        """Load memories from disk.

        A missing file leaves ``self.memories`` untouched. An unreadable or
        malformed file resets it to an empty list instead of raising.
        """
        try:
            with open(self.storage_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and undecodable bytes.
            self.memories = []
            return

        # Tolerate valid JSON of the wrong shape (e.g. a top-level list)
        # rather than crashing on data.get().
        memories = data.get("memories", []) if isinstance(data, dict) else []
        if not isinstance(memories, list):
            memories = []
        # Every other method treats entries as dicts; drop anything else.
        self.memories = [m for m in memories if isinstance(m, dict)]

    def _save(self):
        """Save memories to disk.

        Raises:
            TypeError / ValueError: if a memory is not JSON-serializable.
                The existing file is left untouched in that case.
            OSError: if the file cannot be written.
        """
        # Serialize BEFORE opening the file: opening with 'w' truncates it,
        # so a serialization error after that point would wipe every
        # previously stored memory.
        payload = json.dumps({
            "memories": self.memories,
            "last_updated": datetime.now().isoformat()
        }, indent=2)
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.storage_path, 'w', encoding='utf-8') as f:
            f.write(payload)

    def store(self, item: MemoryItem, score: SurpriseScore, tier: str):
        """Store a memory item with its score.

        Args:
            item: Supplies content, source, domain, timestamp and metadata.
            score: Its ``to_dict()`` result is stored under ``"score"``.
            tier: Tier label recorded with the memory.

        Returns:
            The generated id of the new memory.

        Raises:
            Whatever ``_save`` raises. The new memory is removed from
            ``self.memories`` again before the exception propagates.
        """
        memory = {
            "id": self._generate_id(item.content),
            "content": item.content,
            "source": item.source,
            "domain": item.domain,
            "timestamp": item.timestamp,
            "metadata": item.metadata,
            "score": score.to_dict(),
            "tier": tier,
            "access_count": 0,
            "created": datetime.now().isoformat()
        }
        self.memories.append(memory)
        try:
            self._save()
        except Exception:
            # Keep memory and disk consistent: an unsavable record left in
            # the list would make every later save fail as well.
            self.memories.pop()
            raise
        return memory["id"]

    def search_similar(self, content: str, limit: int = 5) -> List[Dict]:
        """Find memories similar to the given content.

        Similarity is the overlap of whitespace-split lowercase word sets
        (intersection / union). Only memories scoring above 0.1 qualify.

        Returns:
            Up to *limit* dicts ``{"score": float, "content": str}``, most
            similar first.
        """
        query_words = set(content.lower().split())
        if not query_words:
            return []

        scored = []
        for mem in self.memories:
            mem_words = set(mem.get("content", "").lower().split())
            if not mem_words:
                continue
            similarity = len(query_words & mem_words) / len(query_words | mem_words)
            if similarity > 0.1:  # Minimum threshold
                scored.append((similarity, mem))

        scored.sort(key=lambda x: x[0], reverse=True)
        return [{"score": s, "content": m["content"]} for s, m in scored[:limit]]

    def count_similar(self, content: str, threshold: float = 0.3) -> int:
        """Count memories similar to the given content.

        Only the 100 best matches are considered, so the result never
        exceeds 100.
        """
        similar = self.search_similar(content, limit=100)
        return sum(1 for s in similar if s["score"] > threshold)

    def get_by_tier(self, tier: str) -> List[Dict]:
        """Get all memories in a specific tier."""
        return [m for m in self.memories if m.get("tier") == tier]

    def get_recent(self, limit: int = 10) -> List[Dict]:
        """Get most recent memories.

        Ordered by the ``"created"`` ISO timestamp string, newest first.
        Memories without one sort last.
        """
        sorted_mems = sorted(
            self.memories,
            key=lambda x: x.get("created", ""),
            reverse=True
        )
        return sorted_mems[:limit]

    def increment_access(self, memory_id: str):
        """Increment access count for a memory.

        Updates the first memory whose id matches, stamps
        ``"last_accessed"`` and saves. Unknown ids are ignored.
        """
        for mem in self.memories:
            if mem.get("id") == memory_id:
                mem["access_count"] = mem.get("access_count", 0) + 1
                mem["last_accessed"] = datetime.now().isoformat()
                self._save()
                break

    def _generate_id(self, content: str) -> str:
        """Generate an ID for a memory.

        The id is the first 16 hex characters of the SHA-256 of the content
        concatenated with the current timestamp.
        """
        timestamp = datetime.now().isoformat()
        hash_input = f"{content}{timestamp}".encode()
        return hashlib.sha256(hash_input).hexdigest()[:16]


class MemoryRouter:
    """
    Routes information to appropriate memory tier based on surprise score.

    Tiers by total score:
      - below ``THRESHOLDS["discard"]``  -> "discarded" (not stored)
      - below ``THRESHOLDS["working"]``  -> "working"   (not stored)
      - below ``THRESHOLDS["episodic"]`` -> "episodic"  (stored)
      - anything else                    -> "semantic"  (stored)
    """

    # NOTE(review): the "semantic" entry is never consulted. Every score at
    # or above the "episodic" threshold is routed to the semantic tier.
    # Confirm whether a 0.6-0.8 band was meant to behave differently.
    THRESHOLDS = {
        "discard": 0.3,
        "working": 0.5,
        "episodic": 0.6,
        "semantic": 0.8
    }

    # Tiers whose items are written to the store by route().
    _PERSISTED_TIERS = ("episodic", "semantic")

    # Labels evaluate_only() reports for each tier.
    _DRY_RUN_LABELS = {
        "discarded": "would_discard",
        "working": "would_keep_working",
        "episodic": "would_store_episodic",
        "semantic": "would_store_semantic"
    }

    def __init__(self, calculator: Optional[SurpriseCalculator] = None,
                 store: Optional[MemoryStore] = None):
        """Wire the router to a calculator and a store.

        Args:
            calculator: Scores items. Defaults to a ``SurpriseCalculator``
                bound to the router's store.
            store: Receives persisted items. When omitted, the calculator's
                own ``memory`` store is reused if a calculator was given;
                otherwise a default ``MemoryStore()`` is created.
        """
        if store is None and calculator is not None:
            # Share the calculator's store. A second, independent store
            # would mean routed items never show up in the novelty/rarity
            # lookups the calculator performs.
            store = getattr(calculator, "memory", None)
        self.store = store if store is not None else MemoryStore()
        self.calculator = (
            calculator if calculator is not None
            else SurpriseCalculator(self.store)
        )

    def _tier_for(self, total: float) -> str:
        """Map a total surprise score to its tier name."""
        if total < self.THRESHOLDS["discard"]:
            return "discarded"
        if total < self.THRESHOLDS["working"]:
            return "working"
        if total < self.THRESHOLDS["episodic"]:
            return "episodic"
        return "semantic"

    def route(self, item: MemoryItem) -> Dict[str, Any]:
        """Evaluate and route information to appropriate tier.

        Returns:
            Dict with ``"tier"``, ``"score"`` (score as a dict),
            ``"memory_id"`` (``None`` unless stored) and ``"stored"``.
        """
        score = self.calculator.calculate(item)
        tier = self._tier_for(score.total)

        # "discarded" and "working" items are never persisted.
        memory_id = None
        if tier in self._PERSISTED_TIERS:
            memory_id = self.store.store(item, score, tier)

        return {
            "tier": tier,
            "score": score.to_dict(),
            "memory_id": memory_id,
            "stored": memory_id is not None
        }

    def evaluate_only(self, content: str, source: str = "unknown",
                      domain: str = "general") -> Dict[str, Any]:
        """Evaluate surprise without storing.

        Returns:
            Dict with ``"tier"`` (a ``would_*`` label describing what
            ``route`` would do), ``"score"`` and a human-readable
            ``"recommendation"``.
        """
        item = MemoryItem(content=content, source=source, domain=domain)
        score = self.calculator.calculate(item)

        # Same classification as route(), reported as a dry-run label.
        tier = self._DRY_RUN_LABELS[self._tier_for(score.total)]

        return {
            "tier": tier,
            "score": score.to_dict(),
            "recommendation": f"Score {score.total:.2f} → {tier}"
        }


class ReflectiveLoop:
    """
    Learns from the gap between what was expected and what happened.

    Workflow:
      1. ``observe`` queues an action together with its expected outcome.
      2. ``reflect`` compares every queued expectation with one actual
         outcome.
      3. Expectations that deviate by more than 0.4 are turned into
         "learning" items and sent through the router.
    """

    def __init__(self, router: MemoryRouter = None):
        """Use *router* for storing learnings, or a default ``MemoryRouter()``."""
        self.router = router if router else MemoryRouter()
        self.observations: List[Dict] = []

    def observe(self, action: str, expected_outcome: str, context: str = ""):
        """Queue an action and the outcome expected from it."""
        entry = {
            "action": action,
            "expected": expected_outcome,
            "context": context,
            "timestamp": datetime.now().isoformat()
        }
        self.observations.append(entry)

    def reflect(self, actual_outcome: str) -> List[Dict]:
        """Score all queued observations against *actual_outcome*.

        Returns:
            One dict per significantly deviating observation, holding the
            observation, the actual outcome, the deviation and the router's
            result. The observation queue is emptied afterwards.
        """
        learnings = []

        for obs in self.observations:
            expected = obs["expected"]
            deviation = self._measure_deviation(expected, actual_outcome)

            # Only deviations strictly above 0.4 are worth recording.
            if not deviation > 0.4:
                continue

            details = {
                "action": obs["action"],
                "deviation": deviation,
                "context": obs["context"],
                "original_expectation": expected,
                "actual_outcome": actual_outcome
            }
            learning = MemoryItem(
                content=f"Learning: Expected '{obs['expected']}' but got '{actual_outcome}'",
                source="reflective_loop",
                domain="learning",
                metadata=details
            )

            routing = self.router.route(learning)
            learnings.append({
                "observation": obs,
                "actual": actual_outcome,
                "deviation": deviation,
                "routing": routing
            })

        self.observations.clear()
        return learnings

    def _measure_deviation(self, expected: str, actual: str) -> float:
        """Return 1 minus the word-set overlap of the two strings.

        Words are whitespace-split and lowercased. If either string has no
        words the result is the neutral value 0.5.
        """
        exp_tokens = set(expected.lower().split())
        act_tokens = set(actual.lower().split())

        if not (exp_tokens and act_tokens):
            return 0.5

        shared = len(exp_tokens & act_tokens)
        combined = len(exp_tokens | act_tokens)
        similarity = shared / combined if combined > 0 else 0

        return 1.0 - similarity


class MemorySystem:
    """
    Facade over the Genesis surprise-based memory components.

    Builds one store, one calculator, one router and one reflective loop
    that all share the same store.

    Usage:
        memory = MemorySystem()

        # Score and, if the score is high enough, persist
        result = memory.process("Important discovery about memory architecture")

        # Score only
        score = memory.evaluate("Some information")

        # Queue an expectation, then compare it with reality
        memory.observe("Deployed new feature", "Users will love it")
        memory.reflect("Users reported bugs")
    """

    def __init__(self, storage_path: str = None):
        """Create the component graph backed by *storage_path*."""
        store = MemoryStore(storage_path)
        calculator = SurpriseCalculator(store)
        router = MemoryRouter(calculator, store)

        self.store = store
        self.calculator = calculator
        self.router = router
        self.reflector = ReflectiveLoop(router)

    def process(self, content: str, source: str = "claude_code",
                domain: str = "general", metadata: Dict = None) -> Dict:
        """Score *content* and let the router store it if it qualifies."""
        extra = metadata or {}
        item = MemoryItem(
            content=content,
            source=source,
            domain=domain,
            metadata=extra
        )
        return self.router.route(item)

    def evaluate(self, content: str, source: str = "unknown",
                 domain: str = "general") -> Dict:
        """Score *content* without storing anything."""
        return self.router.evaluate_only(content, source, domain)

    def observe(self, action: str, expected_outcome: str, context: str = ""):
        """Queue an action and its expected outcome for later reflection."""
        self.reflector.observe(action, expected_outcome, context)

    def reflect(self, actual_outcome: str) -> List[Dict]:
        """Compare queued observations with *actual_outcome*."""
        return self.reflector.reflect(actual_outcome)

    def get_stats(self) -> Dict:
        """Summarize the store.

        Returns:
            Dict with the memory count, a per-tier count (memories without
            a tier are counted as "unknown"), the mean total score rounded
            to three decimals (0 for an empty store) and the five most
            recent memories.
        """
        memories = self.store.memories

        tier_counts = {}
        for mem in memories:
            tier = mem.get("tier", "unknown")
            tier_counts.setdefault(tier, 0)
            tier_counts[tier] += 1

        if memories:
            # Memories lacking a score or total contribute 0 to the mean.
            totals = [m.get("score", {}).get("total", 0) for m in memories]
            avg_score = sum(totals) / len(totals)
        else:
            avg_score = 0

        summary = {
            "total_memories": len(memories),
            "by_tier": tier_counts,
            "average_score": round(avg_score, 3),
            "recent": self.store.get_recent(5)
        }
        return summary


# CLI Interface
def main(argv: Optional[List[str]] = None) -> int:
    """Run the command-line interface and return the process exit code.

    Args:
        argv: Full argument vector including the program name at index 0.
            Defaults to ``sys.argv``.

    Returns:
        0 after printing usage or a successful command, 1 for an unknown
        command or a command that is missing its text argument.
    """
    import sys

    args = sys.argv if argv is None else argv

    if len(args) < 2:
        print("Usage:")
        print("  python surprise_memory.py evaluate 'Some information'")
        print("  python surprise_memory.py process 'Important discovery'")
        print("  python surprise_memory.py stats")
        return 0

    command = args[1]

    # Validate before touching the store so bad invocations do no disk I/O.
    if command not in ("stats", "evaluate", "process"):
        print(f"Unknown command: {command}", file=sys.stderr)
        return 1
    if command != "stats" and len(args) < 3:
        # A known command without its text is not an "unknown command".
        print(f"Missing text: '{command}' needs the information to {command}",
              file=sys.stderr)
        return 1

    memory = MemorySystem()

    if command == "stats":
        result = memory.get_stats()
    else:
        # Unquoted multi-word input arrives as several arguments; rejoin it.
        content = " ".join(args[2:])
        if command == "evaluate":
            result = memory.evaluate(content)
        else:
            result = memory.process(content)

    print(json.dumps(result, indent=2))
    return 0


if __name__ == "__main__":
    import sys

    sys.exit(main(sys.argv))
