"""
AIVA Decision Context Builder
===============================

Before any AIVA decision, this module queries all 3 memory tiers for
relevant context, ranks results by relevance and recency, and returns
a structured DecisionContext object the decision engine can consume.

Implements the "3-layer validation" pattern from MEMORY_ARCHITECTURE.md:
  Layer 1: Working Memory (Redis)   -- what is happening NOW
  Layer 2: Semantic Memory (Qdrant) -- what SIMILAR things happened before
  Layer 3: Episodic Memory (PG)     -- what was the OUTCOME last time

VERIFICATION_STAMP
Story: AIVA-MEMGATE-002
Verified By: Claude Opus 4.6
Verified At: 2026-02-10
Component: Decision Context Builder

NO SQLITE. All storage uses Elestio PostgreSQL/Qdrant/Redis.
"""

import time
import logging
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from datetime import datetime, timedelta

from .memory_gate import MemoryGate, MemoryTier, MemoryResult, get_memory_gate

logger = logging.getLogger("AIVA.DecisionContext")


@dataclass
class ContextItem:
    """One piece of context pulled from a memory tier, with ranking scores."""
    source_tier: str            # Originating tier: "working", "episodic", "semantic"
    content: Any                # The actual content payload
    relevance_score: float      # 0.0-1.0 composite ranking score
    recency_score: float        # 0.0-1.0 how recent this item is
    importance: float           # 0.0-1.0 stored importance
    created_at: Optional[str] = None
    metadata: Dict = field(default_factory=dict)


@dataclass
class DecisionContext:
    """
    Structured context package for AIVA's decision engine.

    Bundles ranked context items from every reachable memory tier together
    with retrieval metadata (latency, degraded-mode flag, per-layer status).
    """
    # Context items, already sorted best-first by composite score
    items: List[ContextItem] = field(default_factory=list)

    # Retrieval metadata
    query_text: str = ""
    total_items: int = 0
    tiers_queried: List[str] = field(default_factory=list)
    tiers_available: List[str] = field(default_factory=list)
    tiers_failed: List[str] = field(default_factory=list)
    total_latency_ms: float = 0.0
    is_degraded: bool = False

    # 3-layer validation flags
    layer1_working: bool = False    # Working memory contributed context
    layer2_semantic: bool = False   # Semantic similarity contributed context
    layer3_episodic: bool = False   # Historical outcomes contributed context

    def has_context(self) -> bool:
        """Return True when at least one context item was retrieved."""
        return bool(self.items)

    def top_items(self, n: int = 5) -> List[ContextItem]:
        """Return the N highest-ranked context items."""
        return self.items[:n]

    def format_for_prompt(self, max_tokens: int = 4000) -> str:
        """
        Render the context as a markdown block for LLM prompt injection.

        Args:
            max_tokens: Approximate token budget (4 chars per token)

        Returns:
            Formatted context string; empty when there are no items
        """
        if not self.items:
            return ""

        lines: List[str] = ["## AIVA Memory Context", f"Query: {self.query_text}"]
        lines.append(
            f"Tiers: {', '.join(self.tiers_queried)} | "
            f"Items: {self.total_items} | "
            f"Latency: {self.total_latency_ms:.0f}ms"
        )
        lines.append("")

        # One status entry per validation layer that produced context
        active_layers = [
            label
            for flag, label in (
                (self.layer1_working, "L1-Working: ACTIVE"),
                (self.layer2_semantic, "L2-Semantic: ACTIVE"),
                (self.layer3_episodic, "L3-Episodic: ACTIVE"),
            )
            if flag
        ]
        if active_layers:
            lines.extend(["Validation: " + " | ".join(active_layers), ""])

        budget_used = self._estimate_tokens("\n".join(lines))

        # Emit items until the token budget would overflow, then truncate
        for idx, ctx_item in enumerate(self.items):
            rendered = self._format_item(idx + 1, ctx_item)
            cost = self._estimate_tokens(rendered)
            if budget_used + cost > max_tokens:
                lines.append(f"\n[Truncated: {len(self.items) - idx} more items]")
                break
            budget_used += cost
            lines.append(rendered)

        return "\n".join(lines)

    def _format_item(self, index: int, item: ContextItem) -> str:
        """Render one context item as a numbered markdown section."""
        body = str(item.content)
        # Cap each item's content at 300 characters to bound prompt size
        if len(body) > 300:
            body = body[:300] + "..."

        header = (
            f"### [{index}] {item.source_tier.upper()} "
            f"(relevance={item.relevance_score:.2f})\n"
        )
        return header + f"{body}\n"

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Rough token count using the ~4 chars/token heuristic."""
        return len(text) // 4

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the context (metadata plus all items) to a plain dict."""
        serialized_items: List[Dict[str, Any]] = []
        for item in self.items:
            serialized_items.append({
                "source_tier": item.source_tier,
                "content": item.content,
                "relevance_score": item.relevance_score,
                "recency_score": item.recency_score,
                "importance": item.importance,
                "created_at": item.created_at,
                "metadata": item.metadata,
            })

        return {
            "query_text": self.query_text,
            "total_items": self.total_items,
            "tiers_queried": self.tiers_queried,
            "tiers_available": self.tiers_available,
            "tiers_failed": self.tiers_failed,
            "total_latency_ms": self.total_latency_ms,
            "is_degraded": self.is_degraded,
            "layer1_working": self.layer1_working,
            "layer2_semantic": self.layer2_semantic,
            "layer3_episodic": self.layer3_episodic,
            "items": serialized_items,
        }


class DecisionContextBuilder:
    """
    Builds decision context by querying all 3 memory tiers.

    Ranking formula (aligned with context_injector.py):
      - Recency:    40% (newer is better)
      - Relevance:  40% (keyword overlap with query)
      - Importance:  20% (stored importance / surprise score)

    Usage:
        builder = DecisionContextBuilder()
        ctx = builder.build_context("Should we use Telnyx or Twilio?")
        prompt = f"{ctx.format_for_prompt()}\n\nDecide: ..."
    """

    # Ranking weights (sum to 1.0 so the composite stays in 0.0-1.0)
    WEIGHT_RECENCY = 0.4
    WEIGHT_RELEVANCE = 0.4
    WEIGHT_IMPORTANCE = 0.2

    # Max age for recency scoring (days); older items score 0.0
    MAX_AGE_DAYS = 90

    # Tier name -> DecisionContext validation-flag attribute (3-layer validation)
    _LAYER_FLAGS = {
        "working": "layer1_working",    # L1: what is happening NOW
        "semantic": "layer2_semantic",  # L2: similar past situations
        "episodic": "layer3_episodic",  # L3: historical outcomes
    }

    # Hoisted to class level so the set is built once, not per relevance call
    _STOP_WORDS = frozenset({
        "the", "a", "an", "and", "or", "but", "in", "on", "at",
        "to", "for", "is", "it", "of", "with", "as", "by", "this",
    })

    def __init__(self, gate: Optional["MemoryGate"] = None):
        """
        Initialize the decision context builder.

        Args:
            gate: Optional MemoryGate instance (uses singleton if not provided)
        """
        self.gate = gate or get_memory_gate()

    def build_context(
        self,
        query_text: str,
        limit_per_tier: int = 10,
        embedding: Optional[List[float]] = None,
        max_total_items: int = 20
    ) -> "DecisionContext":
        """
        Build a full decision context by querying all available tiers.

        Implements the 3-layer validation:
          L1: Check working memory for active/recent context
          L2: Check semantic memory for similar past situations
          L3: Check episodic memory for historical outcomes

        Args:
            query_text: The decision question or task description
            limit_per_tier: Max results to fetch per tier
            embedding: Pre-computed embedding for semantic search
            max_total_items: Max items in the final ranked result

        Returns:
            DecisionContext with ranked, merged context
        """
        start_time = time.time()

        ctx = DecisionContext(query_text=query_text)
        ctx.tiers_available = self.gate.available_tiers()
        ctx.is_degraded = self.gate.is_degraded()

        # Query all tiers through the gate in a single call
        raw_results = self.gate.query_memory(
            query_text=query_text,
            memory_tier=MemoryTier.ALL,
            limit=limit_per_tier,
            embedding=embedding
        )

        total_latency = 0.0
        all_items: List[ContextItem] = []

        # The per-tier handling is identical, so ingest in one loop instead
        # of three copy-pasted branches. Iteration order matches the layer
        # numbering (L1 working, L2 semantic, L3 episodic).
        for tier in ("working", "semantic", "episodic"):
            if tier not in raw_results:
                continue
            result = raw_results[tier]
            total_latency += result.latency_ms
            ctx.tiers_queried.append(tier)

            if result.status == "ok" and result.items:
                # Mark the corresponding validation layer as satisfied
                setattr(ctx, self._LAYER_FLAGS[tier], True)
                all_items.extend(
                    self._to_context_item(item, tier, query_text)
                    for item in result.items
                )
            elif result.status == "error":
                ctx.tiers_failed.append(tier)

        # Rank all items by composite score, best first
        all_items.sort(key=lambda item: item.relevance_score, reverse=True)

        ctx.items = all_items[:max_total_items]
        ctx.total_items = len(ctx.items)
        ctx.total_latency_ms = round(total_latency, 2)

        elapsed = (time.time() - start_time) * 1000
        # Lazy %-style args: formatting is skipped when INFO is disabled
        logger.info(
            "Decision context built: %d items from %d tiers in %.0fms",
            ctx.total_items, len(ctx.tiers_queried), elapsed
        )

        return ctx

    def _to_context_item(
        self,
        raw_item: Dict,
        source_tier: str,
        query_text: str
    ) -> "ContextItem":
        """
        Convert a raw memory result into a scored ContextItem.

        Args:
            raw_item: Raw dict from a memory tier query
            source_tier: "working", "episodic", or "semantic"
            query_text: The original query for relevance scoring

        Returns:
            ContextItem with computed scores
        """
        # Extract content based on tier format.
        # BUG FIX: the old `a or b or c` chain treated legitimate falsy
        # values (e.g. content == "") as missing and fell back to the whole
        # raw dict; use explicit None checks instead.
        content = self._first_not_none(
            raw_item.get("value"), raw_item.get("content"), default=raw_item
        )
        # `or {}` also guards an explicit metadata=None in the payload
        metadata = raw_item.get("metadata") or {}
        created_at = raw_item.get("created_at")

        # Compute recency score
        recency_score = self._compute_recency(created_at)

        # Compute relevance score
        if source_tier == "semantic" and "score" in raw_item:
            # Qdrant already gives us a similarity score; clamp into 0-1 so
            # the composite stays within the documented range regardless of
            # the collection's distance metric.
            relevance_score = min(1.0, max(0.0, float(raw_item["score"])))
        else:
            relevance_score = self._compute_keyword_relevance(
                query_text, self._content_to_string(content)
            )

        # Stored importance, falling back to surprise_score, then baseline.
        # BUG FIX: an explicitly stored importance of 0.0 used to be
        # discarded by the falsy `or` chain and replaced with 0.5.
        importance = float(self._first_not_none(
            raw_item.get("importance"),
            metadata.get("surprise_score"),
            default=0.5,
        ))

        # Composite score (weights sum to 1.0)
        composite = (
            recency_score * self.WEIGHT_RECENCY +
            relevance_score * self.WEIGHT_RELEVANCE +
            importance * self.WEIGHT_IMPORTANCE
        )

        return ContextItem(
            source_tier=source_tier,
            content=content,
            relevance_score=round(composite, 4),
            recency_score=round(recency_score, 4),
            importance=round(importance, 4),
            created_at=created_at,
            metadata=metadata
        )

    @staticmethod
    def _first_not_none(*candidates: Any, default: Any = None) -> Any:
        """Return the first candidate that is not None, else `default`."""
        for candidate in candidates:
            if candidate is not None:
                return candidate
        return default

    def _compute_recency(self, created_at: Optional[str]) -> float:
        """
        Compute recency score. Newer = higher score.

        Args:
            created_at: ISO timestamp string or epoch float

        Returns:
            Score 0.0-1.0 (0.5 baseline when age is unknown/unparseable)
        """
        # NOTE: epoch 0 is falsy and treated as unknown age — preserved
        # from the original behavior.
        if not created_at:
            return 0.5  # Unknown age gets baseline

        try:
            if isinstance(created_at, (int, float)):
                created = datetime.fromtimestamp(created_at)
            else:
                created = datetime.fromisoformat(str(created_at))

            # BUG FIX: a timezone-aware timestamp used to raise TypeError
            # against the naive datetime.now() and silently fell back to
            # 0.5; compare against a "now" in the same awareness instead.
            now = datetime.now(created.tzinfo) if created.tzinfo else datetime.now()
            age_days = (now - created).total_seconds() / 86400
            return max(0.0, 1.0 - (age_days / self.MAX_AGE_DAYS))
        except (ValueError, TypeError, OSError):
            return 0.5

    def _compute_keyword_relevance(self, query: str, text: str) -> float:
        """
        Compute keyword-based relevance (Jaccard similarity), ignoring
        common stop words.

        Args:
            query: Query string
            text: Content string to compare

        Returns:
            Score 0.0-1.0
        """
        if not query or not text:
            return 0.0

        query_words = set(query.lower().split()) - self._STOP_WORDS
        text_words = set(text.lower().split()) - self._STOP_WORDS

        # All-stop-word queries carry no signal
        if not query_words:
            return 0.0

        intersection = len(query_words & text_words)
        union = len(query_words | text_words)

        return intersection / union if union > 0 else 0.0

    @staticmethod
    def _content_to_string(content: Any) -> str:
        """Convert arbitrary content to a string for keyword matching."""
        if isinstance(content, str):
            return content
        if isinstance(content, dict):
            # Flatten dict values to a single searchable string
            return " ".join(str(v) for v in content.values())
        return str(content)


# Module-level convenience functions

def build_decision_context(
    query_text: str,
    embedding: Optional[List[float]] = None,
    limit: int = 10
) -> DecisionContext:
    """
    Convenience wrapper: build a decision context with a fresh builder.

    Args:
        query_text: Decision question or task description
        embedding: Optional embedding for semantic search
        limit: Max results per tier

    Returns:
        DecisionContext ready for use
    """
    return DecisionContextBuilder().build_context(
        query_text=query_text,
        limit_per_tier=limit,
        embedding=embedding,
    )


def get_context_for_prompt(
    query_text: str,
    embedding: Optional[List[float]] = None,
    max_tokens: int = 4000
) -> str:
    """
    Convenience wrapper: build context and render it for prompt injection.

    Args:
        query_text: Decision question
        embedding: Optional embedding
        max_tokens: Token budget for context

    Returns:
        Formatted context string (empty string if no context available)
    """
    return build_decision_context(
        query_text, embedding=embedding
    ).format_for_prompt(max_tokens=max_tokens)
