"""
AIVA Memory Gate - Unified Memory Access Layer
================================================

Single entry point for all memory operations across the 3-tier architecture:
  1. Working Memory (Redis)   -- fast session state, <5ms
  2. Episodic Memory (PostgreSQL) -- audit trail, decisions, 20-50ms
  3. Semantic Memory (Qdrant)  -- vector similarity over past knowledge, 50-100ms

Provides:
  - query_memory(query_text, memory_tier) -- unified query interface
  - store_memory(content, tier) -- unified storage interface
  - promote(key, from_tier, to_tier) -- move an item up the tier chain: working -> episodic -> semantic
  - Graceful degradation when backends are unavailable
  - Health checking per backend

VERIFICATION_STAMP
Story: AIVA-MEMGATE-001
Verified By: Claude Opus 4.6
Verified At: 2026-02-10
Component: Memory Gate (unified access layer)

NO SQLITE. All storage uses Elestio PostgreSQL/Qdrant/Redis.
"""

import json
import logging
import sys
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

# Elestio config path: appended to sys.path at import time so modules stored
# there become importable. NOTE(review): hard-coded absolute (WSL-style) path;
# presumably the backend modules read their connection settings from it -- confirm.
sys.path.append('/mnt/e/genesis-system/data/genesis-memory')

# Dotted name makes this a child of the "AIVA" logger in the logging hierarchy.
logger = logging.getLogger("AIVA.MemoryGate")


class MemoryTier(Enum):
    """Selects which memory backend(s) an operation targets.

    WORKING is Redis, EPISODIC is PostgreSQL, SEMANTIC is Qdrant. ALL is a
    selector that asks query_memory to search every tier.
    """
    WORKING = "working"    # Redis: fast, ephemeral session state
    EPISODIC = "episodic"  # PostgreSQL: durable, structured records
    SEMANTIC = "semantic"  # Qdrant: vector-similarity search
    ALL = "all"            # every tier at once


class BackendStatus(Enum):
    """Health of a single memory backend as last observed by the gate."""
    HEALTHY = "healthy"          # reachable and answering
    DEGRADED = "degraded"        # connected, but a recent operation failed
    UNAVAILABLE = "unavailable"  # not connected / failed its health check


@dataclass
class MemoryResult:
    """Outcome of querying one memory tier.

    Field order is also the positional-constructor order.
    """
    tier: str                    # tier name, e.g. "working"
    items: List[Dict]            # records returned by the backend
    latency_ms: float            # time spent on the backend call
    status: str                  # one of "ok", "degraded", "error"
    error: Optional[str] = None  # failure reason, when there is one


@dataclass
class GateStatus:
    """Per-backend health snapshot for the memory gate.

    Each backend starts out UNAVAILABLE until it has been checked.
    """
    working: BackendStatus = BackendStatus.UNAVAILABLE   # Redis
    episodic: BackendStatus = BackendStatus.UNAVAILABLE  # PostgreSQL
    semantic: BackendStatus = BackendStatus.UNAVAILABLE  # Qdrant
    last_check: Optional[float] = None                   # epoch seconds of the last check; None if never checked


class MemoryGate:
    """
    Unified memory access gateway for AIVA.

    Routes queries to the appropriate tier(s), handles fallbacks,
    and provides tier promotion (working -> episodic -> semantic).

    Every backend is optional. One that fails to initialize stays ``None``
    and is reported UNAVAILABLE. The public query/store/promote methods then
    return a structured error for that tier instead of raising.

    Usage:
        gate = MemoryGate()
        results = gate.query_memory("past decisions about pricing", MemoryTier.ALL)
        gate.store_memory({"decision": "...", "outcome": "success"}, MemoryTier.WORKING)
        gate.promote("decision_key_123", from_tier=MemoryTier.WORKING, to_tier=MemoryTier.EPISODIC)
    """

    # Position of each concrete tier in the promotion chain. promote() only
    # moves items strictly upward (WORKING -> EPISODIC -> SEMANTIC).
    _PROMOTION_RANK = {
        MemoryTier.WORKING: 0,
        MemoryTier.EPISODIC: 1,
        MemoryTier.SEMANTIC: 2,
    }

    def __init__(self):
        """Initialize the gate and immediately try to connect every backend."""
        self._working = None
        self._episodic = None
        self._semantic = None
        self._status = GateStatus()
        self._init_backends()

    def _init_backends(self):
        """
        Initialize backend connections with graceful fallback.
        Each backend is optional -- if one fails, the others still work.
        """
        # Working Memory (Redis)
        try:
            from .working_memory import WorkingMemory
            self._working = WorkingMemory()
            self._status.working = BackendStatus.HEALTHY
            logger.info("Memory Gate: Redis (working memory) connected")
        except Exception as e:
            self._status.working = BackendStatus.UNAVAILABLE
            logger.warning("Memory Gate: Redis unavailable - %s", e)

        # Episodic Memory (PostgreSQL)
        try:
            from .episodic_memory import EpisodicMemory
            self._episodic = EpisodicMemory()
            self._status.episodic = BackendStatus.HEALTHY
            logger.info("Memory Gate: PostgreSQL (episodic memory) connected")
        except Exception as e:
            self._status.episodic = BackendStatus.UNAVAILABLE
            logger.warning("Memory Gate: PostgreSQL unavailable - %s", e)

        # Semantic Memory (Qdrant)
        try:
            from .semantic_memory import SemanticMemory
            self._semantic = SemanticMemory()
            self._status.semantic = BackendStatus.HEALTHY
            logger.info("Memory Gate: Qdrant (semantic memory) connected")
        except Exception as e:
            self._status.semantic = BackendStatus.UNAVAILABLE
            logger.warning("Memory Gate: Qdrant unavailable - %s", e)

        self._status.last_check = time.time()

    # =========================================================================
    # INTERNAL HELPERS
    # =========================================================================

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Milliseconds since ``start`` (a ``time.perf_counter()`` reading), rounded to 0.01."""
        # perf_counter is monotonic, so wall-clock adjustments cannot skew latency.
        return round((time.perf_counter() - start) * 1000, 2)

    @staticmethod
    def _has_embedding(embedding: Optional[List[float]]) -> bool:
        """True if ``embedding`` is present and non-empty.

        Uses ``len()`` instead of truthiness, so array-like embeddings
        (e.g. numpy vectors) do not raise an "ambiguous truth value" error.
        """
        return embedding is not None and len(embedding) > 0

    @staticmethod
    def _generate_key(event_type: str) -> str:
        """Build a working-memory key: ``<event_type>_<epoch ms>_<8 random hex chars>``.

        The random suffix keeps two stores of the same event type within the
        same millisecond from landing on the same key.
        """
        return f"{event_type}_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"

    # =========================================================================
    # QUERY INTERFACE
    # =========================================================================

    def query_memory(
        self,
        query_text: str,
        memory_tier: MemoryTier = MemoryTier.ALL,
        limit: int = 10,
        embedding: Optional[List[float]] = None,
        score_threshold: float = 0.5
    ) -> Dict[str, MemoryResult]:
        """
        Query memory across specified tier(s).

        Follows the Recall Pulse hierarchy from MEMORY_ARCHITECTURE.md:
          1. Redis first (<5ms)
          2. Qdrant second (50-100ms)
          3. PostgreSQL third (20-50ms)

        A tier that is missing or fails yields a MemoryResult with
        ``status="error"``. It never raises, so one broken backend does not
        hide results from the others.

        Args:
            query_text: Natural language query
            memory_tier: Which tier(s) to search (default ALL)
            limit: Max results per tier
            embedding: Optional pre-computed embedding for semantic search.
                Without it the semantic tier reports ``status="degraded"``.
            score_threshold: Minimum similarity score for semantic results

        Returns:
            Dict mapping tier names ("working", "semantic", "episodic") to
            MemoryResult objects, containing only the tiers that were queried.
        """
        results = {}

        if memory_tier in (MemoryTier.WORKING, MemoryTier.ALL):
            results["working"] = self._query_working(query_text, limit)

        if memory_tier in (MemoryTier.SEMANTIC, MemoryTier.ALL):
            results["semantic"] = self._query_semantic(
                query_text, limit, embedding, score_threshold
            )

        if memory_tier in (MemoryTier.EPISODIC, MemoryTier.ALL):
            results["episodic"] = self._query_episodic(query_text, limit)

        return results

    def _query_working(self, query_text: str, limit: int) -> MemoryResult:
        """Query Redis working memory via ``WorkingMemory.search``.

        A failed query marks the backend DEGRADED. Later queries are still
        attempted; only UNAVAILABLE short-circuits.
        """
        if not self._working or self._status.working == BackendStatus.UNAVAILABLE:
            return MemoryResult(
                tier="working", items=[], latency_ms=0,
                status="error", error="Redis unavailable"
            )
        start = time.perf_counter()
        try:
            items = self._working.search(query_text, limit=limit)
        except Exception as e:
            self._status.working = BackendStatus.DEGRADED
            logger.error("Working memory query failed: %s", e)
            return MemoryResult(
                tier="working", items=[], latency_ms=self._elapsed_ms(start),
                status="error", error=str(e)
            )
        return MemoryResult(
            tier="working", items=items, latency_ms=self._elapsed_ms(start),
            status="ok"
        )

    def _query_episodic(self, query_text: str, limit: int) -> MemoryResult:
        """Query PostgreSQL episodic memory via ``EpisodicMemory.search_by_content``.

        A failed query marks the backend DEGRADED.
        """
        if not self._episodic or self._status.episodic == BackendStatus.UNAVAILABLE:
            return MemoryResult(
                tier="episodic", items=[], latency_ms=0,
                status="error", error="PostgreSQL unavailable"
            )
        start = time.perf_counter()
        try:
            items = self._episodic.search_by_content(query=query_text, limit=limit)
        except Exception as e:
            self._status.episodic = BackendStatus.DEGRADED
            logger.error("Episodic memory query failed: %s", e)
            return MemoryResult(
                tier="episodic", items=[], latency_ms=self._elapsed_ms(start),
                status="error", error=str(e)
            )
        return MemoryResult(
            tier="episodic", items=items, latency_ms=self._elapsed_ms(start),
            status="ok"
        )

    def _query_semantic(
        self, query_text: str, limit: int,
        embedding: Optional[List[float]] = None,
        score_threshold: float = 0.5
    ) -> MemoryResult:
        """Query Qdrant semantic memory via ``SemanticMemory.retrieve_similar``.

        Vector search needs ``embedding``. ``query_text`` is not embedded here,
        so a missing embedding yields ``status="degraded"`` with no items.
        """
        if not self._semantic or self._status.semantic == BackendStatus.UNAVAILABLE:
            return MemoryResult(
                tier="semantic", items=[], latency_ms=0,
                status="error", error="Qdrant unavailable"
            )
        if not self._has_embedding(embedding):
            return MemoryResult(
                tier="semantic", items=[], latency_ms=0,
                status="degraded", error="No embedding provided for semantic search"
            )
        start = time.perf_counter()
        try:
            items = self._semantic.retrieve_similar(
                query_embedding=embedding,
                limit=limit,
                score_threshold=score_threshold
            )
        except Exception as e:
            self._status.semantic = BackendStatus.DEGRADED
            logger.error("Semantic memory query failed: %s", e)
            return MemoryResult(
                tier="semantic", items=[], latency_ms=self._elapsed_ms(start),
                status="error", error=str(e)
            )
        return MemoryResult(
            tier="semantic", items=items, latency_ms=self._elapsed_ms(start),
            status="ok"
        )

    # =========================================================================
    # STORE INTERFACE
    # =========================================================================

    def store_memory(
        self,
        content: Any,
        tier: MemoryTier = MemoryTier.WORKING,
        key: Optional[str] = None,
        event_type: str = "general",
        session_id: Optional[str] = None,
        importance: float = 0.5,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Store content in the specified memory tier.

        Never raises for backend problems: failures are logged, the backend
        is marked DEGRADED, and the reason is returned in ``"error"``.

        Args:
            content: Data to store (dict or string). Episodic storage wraps
                non-dicts as ``{"data": content}``. Semantic storage JSON-encodes
                non-strings.
            tier: Target tier. Must be a single concrete tier; ALL is rejected.
            key: Key for working memory (auto-generated if not provided)
            event_type: Event type for episodic memory (also the key prefix
                for working memory and the knowledge_type for semantic)
            session_id: Session identifier (episodic only)
            importance: Importance score 0.0-1.0 (episodic only)
            embedding: Pre-computed embedding, required for the semantic tier
            metadata: Optional metadata (forwarded to working and semantic
                storage; not passed to ``store_episode``)

        Returns:
            Dict with keys ``tier``, ``stored`` (bool), ``id`` (key / episode id /
            point id, or None) and ``error`` (str or None).
        """
        result = {"tier": tier.value, "stored": False, "id": None, "error": None}

        if tier == MemoryTier.WORKING:
            if not self._working or self._status.working == BackendStatus.UNAVAILABLE:
                result["error"] = "Redis unavailable"
                return result
            auto_key = key or self._generate_key(event_type)
            try:
                self._working.add(auto_key, content, metadata=metadata)
            except Exception as e:
                self._status.working = BackendStatus.DEGRADED
                logger.error("Working memory store failed: %s", e)
                result["error"] = str(e)
            else:
                result["stored"] = True
                result["id"] = auto_key

        elif tier == MemoryTier.EPISODIC:
            if not self._episodic or self._status.episodic == BackendStatus.UNAVAILABLE:
                result["error"] = "PostgreSQL unavailable"
                return result
            content_dict = content if isinstance(content, dict) else {"data": content}
            try:
                episode_id = self._episodic.store_episode(
                    event_type=event_type,
                    content=content_dict,
                    session_id=session_id,
                    importance=importance
                )
            except Exception as e:
                self._status.episodic = BackendStatus.DEGRADED
                logger.error("Episodic memory store failed: %s", e)
                result["error"] = str(e)
            else:
                result["stored"] = True
                result["id"] = episode_id

        elif tier == MemoryTier.SEMANTIC:
            if not self._semantic or self._status.semantic == BackendStatus.UNAVAILABLE:
                result["error"] = "Qdrant unavailable"
                return result
            if not self._has_embedding(embedding):
                result["error"] = "Embedding required for semantic storage"
                return result
            try:
                # json.dumps may raise TypeError for non-serializable content.
                content_str = content if isinstance(content, str) else json.dumps(content)
                point_id = self._semantic.store(
                    content=content_str,
                    embedding=embedding,
                    metadata=metadata,
                    knowledge_type=event_type
                )
            except Exception as e:
                self._status.semantic = BackendStatus.DEGRADED
                logger.error("Semantic memory store failed: %s", e)
                result["error"] = str(e)
            else:
                result["stored"] = True
                result["id"] = point_id

        else:
            # MemoryTier.ALL (or any future selector) names no single backend.
            result["error"] = (
                f"Cannot store into tier '{tier.value}'; choose a single concrete tier"
            )

        return result

    # =========================================================================
    # TIER PROMOTION
    # =========================================================================

    def promote(
        self,
        key: str,
        from_tier: MemoryTier = MemoryTier.WORKING,
        to_tier: MemoryTier = MemoryTier.EPISODIC,
        importance: float = 0.7,
        event_type: str = "promoted",
        session_id: Optional[str] = None,
        embedding: Optional[List[float]] = None
    ) -> Dict[str, Any]:
        """
        Promote a memory item from one tier to a higher tier.

        Promotion chain: WORKING -> EPISODIC -> SEMANTIC

        Only WORKING and EPISODIC can be sources, because items are read back by
        key/id. ``to_tier`` must rank strictly above ``from_tier``. Promoting
        out of working memory is a move: the Redis copy is removed after a
        successful store. A failure to remove it is logged but does not undo
        the promotion.

        NOTE: provenance (``promoted_from`` / ``source_id``) is passed as
        metadata, which store_memory does not forward to episodic storage.

        Args:
            key: Key/ID of the item in the source tier
            from_tier: Source tier
            to_tier: Destination tier
            importance: Importance score for the promoted item
            event_type: Event type label
            session_id: Session ID
            embedding: Required when promoting to SEMANTIC

        Returns:
            Dict with keys ``from_tier``, ``to_tier``, ``promoted`` (bool),
            ``source_id``, ``dest_id`` and ``error`` (str or None).
        """
        result = {
            "from_tier": from_tier.value,
            "to_tier": to_tier.value,
            "promoted": False,
            "source_id": key,
            "dest_id": None,
            "error": None
        }

        if from_tier not in (MemoryTier.WORKING, MemoryTier.EPISODIC):
            result["error"] = f"Cannot promote from tier '{from_tier.value}'"
            return result
        if (to_tier not in self._PROMOTION_RANK
                or self._PROMOTION_RANK[to_tier] <= self._PROMOTION_RANK[from_tier]):
            result["error"] = (
                f"Cannot promote from '{from_tier.value}' to '{to_tier.value}': "
                "destination must be a higher tier"
            )
            return result

        content, error = self._fetch_for_promotion(key, from_tier)
        if error is not None:
            result["error"] = error
            return result

        store_result = self.store_memory(
            content=content,
            tier=to_tier,
            event_type=event_type,
            session_id=session_id,
            importance=importance,
            embedding=embedding,
            metadata={"promoted_from": from_tier.value, "source_id": key}
        )
        if not store_result.get("stored"):
            result["error"] = store_result.get("error")
            return result

        result["promoted"] = True
        result["dest_id"] = store_result["id"]

        # Any promotion out of working memory drops the Redis copy.
        if from_tier == MemoryTier.WORKING and self._working:
            try:
                self._working.remove(key)
            except Exception as e:
                logger.warning(
                    "Promoted '%s' but failed to remove it from working memory: %s",
                    key, e
                )

        return result

    def _fetch_for_promotion(self, key: str, from_tier: MemoryTier) -> tuple:
        """Read the item to promote from its source tier (WORKING or EPISODIC).

        Returns:
            ``(content, None)`` on success, or ``(None, error_message)`` when the
            backend is unavailable, the read fails, or the item does not exist.
        """
        if from_tier == MemoryTier.WORKING:
            if not self._working or self._status.working == BackendStatus.UNAVAILABLE:
                return None, "Redis unavailable"
            try:
                content = self._working.get(key)
            except Exception as e:
                self._status.working = BackendStatus.DEGRADED
                logger.error("Working memory read for promotion failed: %s", e)
                return None, str(e)
            if content is None:
                return None, f"Key '{key}' not found in working memory"
            return content, None

        # The caller has already restricted from_tier to WORKING/EPISODIC.
        if not self._episodic or self._status.episodic == BackendStatus.UNAVAILABLE:
            return None, "PostgreSQL unavailable"
        try:
            episode = self._episodic.recall(key)
        except Exception as e:
            self._status.episodic = BackendStatus.DEGRADED
            logger.error("Episodic memory read for promotion failed: %s", e)
            return None, str(e)
        if episode is None:
            return None, f"Episode '{key}' not found in episodic memory"
        return episode.get("content", {}), None

    # =========================================================================
    # HEALTH & STATUS
    # =========================================================================

    @staticmethod
    def _probe(label: str, check) -> BackendStatus:
        """Run ``check()``: HEALTHY if it completes, UNAVAILABLE if it raises."""
        try:
            check()
        except Exception as e:
            logger.debug("Memory Gate: %s health check failed - %s", label, e)
            return BackendStatus.UNAVAILABLE
        return BackendStatus.HEALTHY

    def _ping_postgres(self) -> None:
        """Run ``SELECT 1`` on the episodic connection; raise if it is closed."""
        conn = self._episodic.conn
        if not conn or conn.closed:
            raise ConnectionError("PostgreSQL connection is closed")
        with conn.cursor() as cur:
            cur.execute("SELECT 1")

    def health_check(self) -> GateStatus:
        """
        Check health of all backends and update status.

        Each live probe overwrites the stored status, so a DEGRADED or
        UNAVAILABLE backend that answers again returns to HEALTHY.
        Backends that never initialized stay UNAVAILABLE.

        Returns:
            GateStatus with per-backend health
        """
        self._status.working = (
            self._probe("Redis", lambda: self._working.redis_client.ping())
            if self._working else BackendStatus.UNAVAILABLE
        )
        self._status.episodic = (
            self._probe("PostgreSQL", self._ping_postgres)
            if self._episodic else BackendStatus.UNAVAILABLE
        )
        self._status.semantic = (
            self._probe("Qdrant", lambda: self._semantic.client.get_collections())
            if self._semantic else BackendStatus.UNAVAILABLE
        )
        self._status.last_check = time.time()
        return self._status

    def get_status_dict(self) -> Dict[str, Any]:
        """Run a fresh health check and return the gate status as a plain dict."""
        status = self.health_check()
        return {
            "working": status.working.value,
            "episodic": status.episodic.value,
            "semantic": status.semantic.value,
            "last_check": status.last_check,
            "available_tiers": self.available_tiers()
        }

    def available_tiers(self) -> List[str]:
        """Return the names of tiers whose last known status is HEALTHY."""
        tiers = []
        if self._status.working == BackendStatus.HEALTHY:
            tiers.append("working")
        if self._status.episodic == BackendStatus.HEALTHY:
            tiers.append("episodic")
        if self._status.semantic == BackendStatus.HEALTHY:
            tiers.append("semantic")
        return tiers

    def is_degraded(self) -> bool:
        """Return True if any backend is not HEALTHY (DEGRADED or UNAVAILABLE)."""
        return any(
            status != BackendStatus.HEALTHY
            for status in (self._status.working, self._status.episodic, self._status.semantic)
        )

    # =========================================================================
    # CLEANUP
    # =========================================================================

    def close(self):
        """
        Close backend connections, best effort.

        Calls ``close()`` on each initialized backend wrapper that defines one.
        Errors are logged, not raised. Each closed backend is marked
        UNAVAILABLE, so later calls get a clean "unavailable" error instead of
        using a closed connection.
        """
        backends = (
            ("working", "Redis", self._working),
            ("episodic", "PostgreSQL", self._episodic),
            ("semantic", "Qdrant", self._semantic),
        )
        for status_attr, label, backend in backends:
            closer = getattr(backend, "close", None) if backend else None
            if not callable(closer):
                continue
            try:
                closer()
            except Exception as e:
                logger.warning("Memory Gate: error closing %s - %s", label, e)
            setattr(self._status, status_attr, BackendStatus.UNAVAILABLE)
        logger.info("Memory Gate closed")


# Process-wide MemoryGate, created on first use by get_memory_gate().
_gate_instance: Optional[MemoryGate] = None


def get_memory_gate() -> MemoryGate:
    """
    Return the shared MemoryGate, constructing it the first time.

    Every later call returns that same object.

    Returns:
        MemoryGate instance
    """
    global _gate_instance
    if _gate_instance is not None:
        return _gate_instance
    _gate_instance = MemoryGate()
    return _gate_instance
