"""
Genesis Memory System - Value of Information (VoI) Scoring
==========================================================
Implements utility-based memory management for optimal retention.

Based on 2025 research:
- UserCentrix (Saleh et al.)
- Mem0 Priority Scoring
- GAM Architecture patterns

VoI Score = w1*Recency + w2*Relevance + w3*Importance + w4*Outcome
"""

import math
from datetime import datetime, timezone, timedelta
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
import json


class OutcomeType(Enum):
    """Types of memory usage outcomes.

    Each member's value is the string persisted in exported state; each
    member maps to a base score in ``VoIScorer.OUTCOME_SCORES`` that feeds
    the outcome component of the VoI formula.
    """
    SUCCESS = "success"          # Memory led to successful action
    PARTIAL = "partial"          # Memory partially helpful
    IRRELEVANT = "irrelevant"    # Memory retrieved but not useful
    FAILURE = "failure"          # Memory led to wrong action
    NOT_USED = "not_used"        # Memory never retrieved


@dataclass
class MemoryOutcome:
    """Record of memory usage outcome.

    Created by ``VoIScorer.record_outcome`` and kept (per memory id) in
    ``VoIScorer.outcome_history``; serialized by ``export_state`` and
    reconstructed by ``import_state``.
    """
    memory_id: str               # id of the memory that was used
    outcome_type: OutcomeType    # result of using the memory
    context: str                 # free-text description of the usage context
    # Recorded as an ISO-8601 UTC string at construction time.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    # NOTE(review): score_impact is never read or written by the code in this
    # module — presumably reserved for callers; confirm before removing.
    score_impact: float = 0.0


@dataclass
class VoIScore:
    """Complete VoI score breakdown.

    Produced by ``VoIScorer.calculate_voi``; all component scores are in
    [0, 1] and rounded to 4 decimal places.
    """
    memory_id: str               # memory this score describes
    total_score: float           # weighted combination of the 4 components
    recency_score: float         # exponential time-decay component
    relevance_score: float       # context relevance (from retriever/caller)
    importance_score: float      # pre-assigned importance, clamped to [0, 1]
    outcome_score: float         # usage-history component (0.5 = neutral)
    usage_count: int             # times this memory was retrieved/used
    last_used: Optional[str]     # ISO timestamp of last use, None if never
    created_at: str              # creation timestamp the recency is based on
    decay_applied: bool = False  # True once temporal decay has been factored in


class VoIScorer:
    """
    Value of Information scoring engine.

    Calculates utility-based scores for memory retention decisions:

        VoI = w1*recency + w2*relevance + w3*importance + w4*outcome

    Every component lies in [0, 1], so the total is also in [0, 1] whenever
    the weights sum to 1.0 (the defaults do, and adjust_weights()
    re-normalizes).

    Supports:
    - Temporal decay (exponential)
    - Outcome tracking (success/failure)
    - Dynamic weight adjustment
    - Auto-pruning recommendations
    """

    # Default weights (can be tuned via learning); sum to 1.0.
    DEFAULT_WEIGHTS = {
        "recency": 0.25,
        "relevance": 0.35,
        "importance": 0.20,
        "outcome": 0.20
    }

    # Decay parameters.  With lambda = 0.01/hour the true half-life is
    # ln(2)/0.01 ~ 69.3 hours; DECAY_HALF_LIFE_HOURS is the rounded figure
    # kept for reference only — the math below reads DECAY_LAMBDA.
    DECAY_LAMBDA = 0.01  # Decay rate per hour
    DECAY_HALF_LIFE_HOURS = 72  # Approximate (~3 days); informational only

    # Base score contributed by each recorded outcome type.
    OUTCOME_SCORES = {
        OutcomeType.SUCCESS: 1.0,
        OutcomeType.PARTIAL: 0.6,
        OutcomeType.IRRELEVANT: 0.2,
        OutcomeType.FAILURE: 0.0,
        OutcomeType.NOT_USED: 0.5  # Neutral
    }

    # Pruning thresholds
    PRUNE_THRESHOLD = 0.15  # Below this = recommend deletion
    ARCHIVE_THRESHOLD = 0.30  # Below this = recommend archive

    def __init__(self, weights: Optional[Dict[str, float]] = None):
        """
        Args:
            weights: Optional component weights keyed by "recency",
                "relevance", "importance", "outcome".  Defensively copied so
                later adjust_weights() calls never mutate the caller's dict.
        """
        # FIX: copy the caller-supplied dict; previously it was aliased and
        # adjust_weights() mutated the caller's object in place.
        self.weights = dict(weights) if weights else self.DEFAULT_WEIGHTS.copy()
        # memory_id -> chronological list of recorded outcomes
        self.outcome_history: Dict[str, List[MemoryOutcome]] = {}
        # memory_id -> number of retrievals/uses observed
        self.usage_counts: Dict[str, int] = {}
        # memory_id -> ISO timestamp of most recent retrieval/use
        self.last_used: Dict[str, str] = {}

    # =========================================================================
    # CORE VoI CALCULATION
    # =========================================================================

    def calculate_voi(
        self,
        memory_id: str,
        content: str,
        created_at: str,
        importance: float,
        relevance_score: float = 0.5,
        query_context: Optional[str] = None
    ) -> VoIScore:
        """
        Calculate complete Value of Information score.

        Args:
            memory_id: Unique memory identifier
            content: Memory content (reserved for additional analysis;
                currently unused by the scoring math)
            created_at: ISO timestamp of creation
            importance: Pre-assigned importance (0.0-1.0); clamped to range
            relevance_score: Context relevance from search (0.0-1.0)
            query_context: Current query for dynamic relevance (currently
                unused; reserved for future relevance recomputation)

        Returns:
            VoIScore with complete breakdown (components rounded to 4 dp)
        """
        # 1. Recency Score (exponential decay)
        recency = self._calculate_recency(created_at)

        # 2. Relevance Score (passed from retriever or calculated)
        relevance = relevance_score

        # 3. Importance Score (pre-assigned); clamp so out-of-range callers
        # cannot push the total outside [0, 1].
        importance_normalized = max(0.0, min(1.0, importance))

        # 4. Outcome Score (based on usage history)
        outcome = self._calculate_outcome_score(memory_id)

        # Combine with weights
        total = (
            self.weights["recency"] * recency +
            self.weights["relevance"] * relevance +
            self.weights["importance"] * importance_normalized +
            self.weights["outcome"] * outcome
        )

        return VoIScore(
            memory_id=memory_id,
            total_score=round(total, 4),
            recency_score=round(recency, 4),
            relevance_score=round(relevance, 4),
            importance_score=round(importance_normalized, 4),
            outcome_score=round(outcome, 4),
            usage_count=self.usage_counts.get(memory_id, 0),
            last_used=self.last_used.get(memory_id),
            created_at=created_at,
            decay_applied=True
        )

    def _calculate_recency(self, created_at: str) -> float:
        """
        Calculate recency score with exponential decay.

        Uses: score = exp(-lambda * age_in_hours)

        Accepts an ISO-8601 string (a trailing 'Z' is normalized to +00:00)
        or an already-parsed datetime.  Returns a neutral 0.5 when the
        timestamp cannot be parsed or compared.
        """
        try:
            if isinstance(created_at, str):
                created = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
            else:
                created = created_at

            now = datetime.now(timezone.utc)
            age_hours = (now - created).total_seconds() / 3600

            # Exponential decay
            # With default lambda=0.01, half-life is ~69 hours
            decay_score = math.exp(-self.DECAY_LAMBDA * age_hours)

            # Clamp: a future-dated timestamp gives negative age and a raw
            # score above 1.0.
            return max(0.0, min(1.0, decay_score))

        except (ValueError, TypeError):
            # ValueError: malformed ISO string.  TypeError: naive/aware
            # datetime mismatch in the subtraction, or a non-datetime input.
            # (Narrowed from a blanket `except Exception` so genuine bugs
            # are no longer silently swallowed.)
            return 0.5  # Neutral if parsing fails

    def _calculate_outcome_score(self, memory_id: str) -> float:
        """
        Calculate outcome score based on usage history.

        Only the last 10 outcomes are considered, weighted 1/(k+1) by how
        many outcomes ago they occurred (most recent = weight 1.0).
        """
        outcomes = self.outcome_history.get(memory_id, [])

        if not outcomes:
            return 0.5  # Neutral for unused memories

        # Weight recent outcomes more heavily
        weighted_sum = 0.0
        weight_total = 0.0

        for i, outcome in enumerate(reversed(outcomes[-10:])):  # Last 10 outcomes
            weight = 1.0 / (i + 1)  # More recent = higher weight
            score = self.OUTCOME_SCORES.get(outcome.outcome_type, 0.5)
            weighted_sum += weight * score
            weight_total += weight

        if weight_total > 0:
            return weighted_sum / weight_total
        return 0.5

    # =========================================================================
    # OUTCOME TRACKING
    # =========================================================================

    def record_outcome(
        self,
        memory_id: str,
        outcome_type: OutcomeType,
        context: str = ""
    ) -> MemoryOutcome:
        """
        Record an outcome when a memory is used.

        Also bumps the usage counter and refreshes the last-used timestamp.

        Args:
            memory_id: Memory that was used
            outcome_type: Result of using the memory
            context: Description of usage context

        Returns:
            MemoryOutcome record
        """
        outcome = MemoryOutcome(
            memory_id=memory_id,
            outcome_type=outcome_type,
            context=context
        )

        # Store in history (setdefault creates the list on first outcome)
        self.outcome_history.setdefault(memory_id, []).append(outcome)

        # Update usage tracking
        self.usage_counts[memory_id] = self.usage_counts.get(memory_id, 0) + 1
        self.last_used[memory_id] = outcome.timestamp

        return outcome

    def record_retrieval(self, memory_id: str):
        """Record that a memory was retrieved (not necessarily used)."""
        self.usage_counts[memory_id] = self.usage_counts.get(memory_id, 0) + 1
        self.last_used[memory_id] = datetime.now(timezone.utc).isoformat()

    # =========================================================================
    # BATCH OPERATIONS
    # =========================================================================

    def score_batch(
        self,
        memories: List[Dict[str, Any]],
        relevance_scores: Optional[Dict[str, float]] = None
    ) -> List[VoIScore]:
        """
        Calculate VoI scores for a batch of memories.

        Args:
            memories: List of memory dicts with id (or memory_id), content,
                created_at, importance; missing fields fall back to neutral
                defaults
            relevance_scores: Optional dict of memory_id -> relevance score

        Returns:
            List of VoIScore sorted by total_score descending
        """
        scores = []
        relevance_scores = relevance_scores or {}

        for mem in memories:
            # Resolve the id once so the relevance lookup uses the same key
            # as the score record.  FIX: previously the relevance lookup
            # used only mem["id"], so memories keyed by "memory_id" never
            # picked up their relevance scores.
            mem_id = mem.get("id", mem.get("memory_id", ""))
            score = self.calculate_voi(
                memory_id=mem_id,
                content=mem.get("content", ""),
                created_at=mem.get("created_at", datetime.now(timezone.utc).isoformat()),
                importance=mem.get("importance", 0.5),
                relevance_score=relevance_scores.get(mem_id, 0.5)
            )
            scores.append(score)

        # Sort by total score descending
        scores.sort(key=lambda x: x.total_score, reverse=True)
        return scores

    def get_prune_candidates(
        self,
        memories: List[Dict[str, Any]],
        threshold: Optional[float] = None
    ) -> Tuple[List[str], List[str]]:
        """
        Identify memories that should be pruned or archived.

        Args:
            memories: Memory dicts as accepted by score_batch
            threshold: Optional deletion threshold overriding
                PRUNE_THRESHOLD (0.0 is honored; only None means "default")

        Returns:
            Tuple of (delete_ids, archive_ids)
        """
        # FIX: `threshold or PRUNE_THRESHOLD` treated an explicit 0.0 as
        # "use the default"; test against None instead.
        prune_threshold = self.PRUNE_THRESHOLD if threshold is None else threshold
        archive_threshold = self.ARCHIVE_THRESHOLD

        scores = self.score_batch(memories)

        delete_ids = []
        archive_ids = []

        for score in scores:
            if score.total_score < prune_threshold:
                delete_ids.append(score.memory_id)
            elif score.total_score < archive_threshold:
                archive_ids.append(score.memory_id)

        return delete_ids, archive_ids

    # =========================================================================
    # WEIGHT ADJUSTMENT
    # =========================================================================

    def adjust_weights(
        self,
        feedback: Dict[str, float],
        learning_rate: float = 0.1
    ):
        """
        Adjust weights based on feedback.

        Each adjustment is scaled by learning_rate, clamped to [0, 1], and
        the whole weight vector is then re-normalized to sum to 1.0.
        Unknown component names are ignored.

        Args:
            feedback: Dict of component -> adjustment (+/- value)
            learning_rate: How much to adjust
        """
        for component, adjustment in feedback.items():
            if component in self.weights:
                new_weight = self.weights[component] + (learning_rate * adjustment)
                self.weights[component] = max(0.0, min(1.0, new_weight))

        # Normalize to sum to 1.0
        total = sum(self.weights.values())
        if total > 0:
            self.weights = {k: v / total for k, v in self.weights.items()}

    def get_weight_config(self) -> Dict[str, float]:
        """Get a copy of the current weight configuration."""
        return self.weights.copy()

    # =========================================================================
    # REPORTING
    # =========================================================================

    def get_statistics(self) -> Dict[str, Any]:
        """Get scoring statistics (counts, outcome distribution, weights)."""
        total_outcomes = sum(len(v) for v in self.outcome_history.values())
        outcome_counts = {}

        for outcomes in self.outcome_history.values():
            for outcome in outcomes:
                ot = outcome.outcome_type.value
                outcome_counts[ot] = outcome_counts.get(ot, 0) + 1

        return {
            "total_memories_tracked": len(self.usage_counts),
            "total_outcomes_recorded": total_outcomes,
            "outcome_distribution": outcome_counts,
            "average_usage_count": (
                sum(self.usage_counts.values()) / len(self.usage_counts)
                if self.usage_counts else 0
            ),
            "current_weights": self.weights,
            "prune_threshold": self.PRUNE_THRESHOLD,
            "archive_threshold": self.ARCHIVE_THRESHOLD
        }

    def export_state(self) -> Dict[str, Any]:
        """Export scorer state as a JSON-serializable dict for persistence."""
        return {
            "weights": self.weights,
            "usage_counts": self.usage_counts,
            "last_used": self.last_used,
            "outcome_history": {
                k: [
                    {
                        "memory_id": o.memory_id,
                        "outcome_type": o.outcome_type.value,
                        "context": o.context,
                        "timestamp": o.timestamp
                    }
                    for o in v
                ]
                for k, v in self.outcome_history.items()
            }
        }

    def import_state(self, state: Dict[str, Any]):
        """
        Import scorer state from persistence (inverse of export_state).

        Raises:
            ValueError: if an outcome_type string is not a valid OutcomeType
            KeyError: if an outcome record is missing a required field
        """
        # Defensive copies so later mutation never writes through to the
        # caller's state dict.
        self.weights = dict(state.get("weights", self.DEFAULT_WEIGHTS))
        self.usage_counts = dict(state.get("usage_counts", {}))
        self.last_used = dict(state.get("last_used", {}))

        # Reconstruct outcome history
        self.outcome_history = {}
        for mem_id, outcomes in state.get("outcome_history", {}).items():
            self.outcome_history[mem_id] = [
                MemoryOutcome(
                    memory_id=o["memory_id"],
                    outcome_type=OutcomeType(o["outcome_type"]),
                    context=o["context"],
                    timestamp=o["timestamp"]
                )
                for o in outcomes
            ]


# Lazily-created module-wide scorer (singleton)
_scorer: Optional[VoIScorer] = None


def get_voi_scorer() -> VoIScorer:
    """Return the shared VoIScorer, creating it on first use."""
    global _scorer
    scorer = _scorer
    if scorer is None:
        scorer = VoIScorer()
        _scorer = scorer
    return scorer


# ============================================================================
# CLI Interface
# ============================================================================

if __name__ == "__main__":
    # Demo/CLI entry point: scores a few sample memories, shows pruning
    # recommendations and aggregate statistics.
    # FIX: removed `import sys` — it was never used.
    scorer = get_voi_scorer()

    print("=" * 60)
    print("GENESIS VoI SCORING ENGINE")
    print("=" * 60)

    # Demo: Calculate VoI for sample memories
    sample_memories = [
        {
            "id": "mem_001",
            "content": "Important discovery: GHL has official MCP server",
            "created_at": datetime.now(timezone.utc).isoformat(),
            "importance": 0.9
        },
        {
            "id": "mem_002",
            "content": "Routine task completed",
            "created_at": (datetime.now(timezone.utc) - timedelta(days=7)).isoformat(),
            "importance": 0.3
        },
        {
            "id": "mem_003",
            "content": "Evolution protocol tested successfully",
            "created_at": (datetime.now(timezone.utc) - timedelta(hours=2)).isoformat(),
            "importance": 0.7
        }
    ]

    print("\n## Sample Memory VoI Scores")
    print("-" * 60)

    # Record some outcomes first so the outcome component is non-neutral
    scorer.record_outcome("mem_001", OutcomeType.SUCCESS, "Led to MCP integration")
    scorer.record_outcome("mem_003", OutcomeType.SUCCESS, "Protocol working")
    scorer.record_outcome("mem_002", OutcomeType.IRRELEVANT, "Not useful")

    scores = scorer.score_batch(sample_memories)

    for score in scores:
        print(f"\nMemory: {score.memory_id}")
        print(f"  Total VoI: {score.total_score:.4f}")
        print(f"  Recency:   {score.recency_score:.4f}")
        print(f"  Relevance: {score.relevance_score:.4f}")
        print(f"  Importance:{score.importance_score:.4f}")
        print(f"  Outcome:   {score.outcome_score:.4f}")
        print(f"  Usage:     {score.usage_count} times")

    # Demo: Prune recommendations
    print("\n## Prune Recommendations")
    print("-" * 60)
    delete_ids, archive_ids = scorer.get_prune_candidates(sample_memories)
    print(f"Delete: {delete_ids}")
    print(f"Archive: {archive_ids}")

    # Demo: Statistics
    print("\n## Statistics")
    print("-" * 60)
    stats = scorer.get_statistics()
    print(json.dumps(stats, indent=2))

    print("\n" + "=" * 60)
    print("VoI SCORING ENGINE READY")
    print("=" * 60)