#!/usr/bin/env python3
"""
GENESIS MEMORY INTEGRATION LAYER
=================================
Unified interface to all memory tiers with intelligent routing.

Memory Tiers:
    - Working Memory: Fast, transient (Redis/in-memory)
    - Episodic Memory: Task history, outcomes (SQLite/PostgreSQL)
    - Semantic Memory: Knowledge, patterns (Vector DB)
    - Procedural Memory: Skills, how-to (Skill files)

Usage:
    memory = MemoryIntegration()
    memory.store("key", data, tier="auto")
    result = memory.recall("query", context=context)
"""

"""
RULE 7 COMPLIANT: Uses Elestio PostgreSQL via genesis_db module.
"""
import json
import hashlib
import threading
import time
import logging
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple, Union
from enum import Enum

# RULE 7: Use PostgreSQL via genesis_db (no sqlite3)
from core.genesis_db import connection, ensure_table

logger = logging.getLogger(__name__)


class MemoryTier(Enum):
    """The four memory tiers.

    WORKING    - fast, volatile cache
    EPISODIC   - persisted event/task history
    SEMANTIC   - knowledge, facts, patterns
    PROCEDURAL - skills and how-to procedures
    """

    WORKING = "working"
    EPISODIC = "episodic"
    SEMANTIC = "semantic"
    PROCEDURAL = "procedural"


@dataclass
class MemoryEntry:
    """A single stored memory item plus its bookkeeping metadata.

    Timestamps are ISO-8601 strings.  ``associations`` is a list of
    related keys (not consumed elsewhere in this module).
    """
    key: str
    value: Any
    tier: MemoryTier
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    accessed_at: str = field(default_factory=lambda: datetime.now().isoformat())
    access_count: int = 0
    importance: float = 0.5
    decay_rate: float = 0.1
    associations: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict:
        """Serialize to a plain dict; the tier enum collapses to its string value."""
        serialized = {name: getattr(self, name) for name in (
            "key", "value", "tier", "created_at", "accessed_at",
            "access_count", "importance", "decay_rate", "associations")}
        serialized["tier"] = self.tier.value
        return serialized


@dataclass
class RecallResult:
    """Result of a memory recall operation.

    The three lists are parallel: entries[i] was found in source_tiers[i]
    with relevance relevance_scores[i].
    """
    entries: List[MemoryEntry]        # matched entries
    relevance_scores: List[float]     # per-entry relevance (higher is better)
    source_tiers: List[MemoryTier]    # tier each entry came from
    total_time: float                 # wall-clock seconds spent on the recall


class WorkingMemory:
    """
    Fast, volatile working memory.

    LRU cache with a per-entry TTL, guarded by an RLock so it is safe to
    share across threads.  Note: a stored value of None is
    indistinguishable from a missing key on recall().
    """

    def __init__(self, max_size: int = 1000, default_ttl: int = 3600):
        self.max_size = max_size        # hard cap on entry count
        self.default_ttl = default_ttl  # seconds; used when store() gets no ttl
        # key -> (value, absolute expiry as a time.time() float)
        self._cache: OrderedDict[str, Tuple[Any, float]] = OrderedDict()
        self._lock = threading.RLock()

    def store(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store value under key, expiring after ttl seconds.

        Args:
            key: Cache key.
            value: Arbitrary value to cache.
            ttl: Lifetime in seconds; None means use default_ttl.  A ttl
                of 0 or less expires immediately.  (BUGFIX: the previous
                ``ttl or self.default_ttl`` truthiness check silently
                turned ttl=0 into the default TTL.)

        Returns:
            Always True.
        """
        with self._lock:
            lifetime = self.default_ttl if ttl is None else ttl
            expiry = time.time() + lifetime

            # Delete-then-insert so an updated key moves to the MRU end.
            if key in self._cache:
                del self._cache[key]
            self._cache[key] = (value, expiry)

            # Purge dead entries first so the LRU eviction below never
            # discards live data while expired entries occupy slots.
            now = time.time()
            for stale in [k for k, (_, exp) in self._cache.items() if now > exp]:
                del self._cache[stale]

            # Enforce size limit by evicting least-recently-used entries.
            while len(self._cache) > self.max_size:
                self._cache.popitem(last=False)

            return True

    def recall(self, key: str) -> Optional[Any]:
        """Return the value for key, or None if absent or expired."""
        with self._lock:
            if key not in self._cache:
                return None

            value, expiry = self._cache[key]

            if time.time() > expiry:
                # Lazy expiry: drop the entry on first post-expiry access.
                del self._cache[key]
                return None

            # Touch for LRU: most-recently-used entries live at the end.
            self._cache.move_to_end(key)
            return value

    def exists(self, key: str) -> bool:
        """Check if key exists and is not expired (touches LRU order)."""
        return self.recall(key) is not None

    def delete(self, key: str) -> bool:
        """Delete key; return True if it was present."""
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    def clear_expired(self):
        """Remove all expired entries."""
        with self._lock:
            now = time.time()
            expired = [k for k, (_, exp) in self._cache.items() if now > exp]
            for k in expired:
                del self._cache[k]

    def get_stats(self) -> Dict:
        """Return total/valid entry counts and the configured size cap."""
        with self._lock:
            now = time.time()
            valid = sum(1 for _, (_, exp) in self._cache.items() if now <= exp)
            return {
                "total_entries": len(self._cache),
                "valid_entries": valid,
                "max_size": self.max_size
            }


class EpisodicMemory:
    """
    Episodic memory for task history and events.

    Persists episodes in PostgreSQL through the project's genesis_db
    module (RULE 7).  All public methods are best-effort: database errors
    are logged and surfaced as False/None/empty rather than raised.
    """

    def __init__(self, db_path: Path = None):
        # RULE 7: db_path is accepted only for backward compatibility and
        # is ignored - storage always uses PostgreSQL via genesis_db.
        self._init_db()

    def _init_db(self):
        """Create the episodes table and lookup indexes if missing (RULE 7: PostgreSQL)."""
        ensure_table('episodes', '''
            id SERIAL PRIMARY KEY,
            key TEXT UNIQUE NOT NULL,
            value JSONB NOT NULL,
            created_at TIMESTAMPTZ NOT NULL,
            accessed_at TIMESTAMPTZ NOT NULL,
            access_count INTEGER DEFAULT 0,
            importance REAL DEFAULT 0.5,
            episode_type TEXT,
            context JSONB
        ''')
        try:
            with connection() as conn:
                cursor = conn.cursor()
                # Cover the three query paths used below: key lookup,
                # type filter, importance-ordered scans.
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_episodes_key ON episodes(key)")
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_episodes_type ON episodes(episode_type)")
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_episodes_importance ON episodes(importance)")
        except Exception as e:
            logger.warning(f"Index creation warning: {e}")

    def store(
        self,
        key: str,
        value: Any,
        episode_type: str = "general",
        importance: float = 0.5,
        context: Dict = None
    ) -> bool:
        """Insert or update an episode (upsert on key).

        On conflict the row's created_at and access_count are preserved;
        value, accessed_at, importance, episode_type and context are
        overwritten.

        Returns:
            True on success, False on any database error.
        """
        try:
            with connection() as conn:
                cursor = conn.cursor()
                now = datetime.now().isoformat()
                cursor.execute("""
                    INSERT INTO episodes
                    (key, value, created_at, accessed_at, access_count, importance, episode_type, context)
                    VALUES (%s, %s, %s, %s, 0, %s, %s, %s)
                    ON CONFLICT (key) DO UPDATE SET
                        value = EXCLUDED.value,
                        accessed_at = EXCLUDED.accessed_at,
                        importance = EXCLUDED.importance,
                        episode_type = EXCLUDED.episode_type,
                        context = EXCLUDED.context
                """, (
                    key,
                    json.dumps(value),
                    now,
                    now,
                    importance,
                    episode_type,
                    json.dumps(context) if context else None
                ))
                return True
        except Exception as e:
            logger.warning(f"Failed to store episode: {e}")
            return False

    def recall(self, key: str) -> Optional[Any]:
        """Fetch an episode's value by exact key, updating its access stats.

        NOTE(review): assumes the connection() context manager commits on
        exit so the accessed_at/access_count update persists - confirm
        against genesis_db.
        """
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT value FROM episodes WHERE key = %s", (key,)
                )
                row = cursor.fetchone()
                if row:
                    # Touch access stats so consolidation can spot hot episodes.
                    cursor.execute("""
                        UPDATE episodes
                        SET accessed_at = %s, access_count = access_count + 1
                        WHERE key = %s
                    """, (datetime.now().isoformat(), key))
                    # JSONB columns may arrive already decoded by the driver;
                    # only json.loads when we got a raw string.
                    value = row[0]
                    if isinstance(value, str):
                        return json.loads(value)
                    return value
                return None
        except Exception as e:
            logger.warning(f"Failed to recall episode: {e}")
            return None

    def search(
        self,
        episode_type: str = None,
        min_importance: float = None,
        since: str = None,
        limit: int = 100
    ) -> List[Dict]:
        """Search episodes with optional filters (RULE 7: PostgreSQL).

        Args:
            episode_type: Exact episode_type to match.
            min_importance: Minimum importance (inclusive).
            since: ISO timestamp lower bound on created_at.
            limit: Maximum rows returned.

        Returns:
            Dicts with key, value, created_at, importance, episode_type
            and access_count, ordered by importance then recency.
        """
        conditions = []
        params = []

        if episode_type:
            conditions.append("episode_type = %s")
            params.append(episode_type)

        if min_importance is not None:
            conditions.append("importance >= %s")
            params.append(min_importance)

        if since:
            conditions.append("created_at >= %s")
            params.append(since)

        # BUGFIX: include access_count in the projection - consolidation
        # logic reads it from these results and previously always saw the
        # .get() default of 0, so promotion never triggered.
        query = "SELECT key, value, created_at, importance, episode_type, access_count FROM episodes"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " ORDER BY importance DESC, created_at DESC LIMIT %s"
        params.append(limit)

        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute(query, params)
                results = []
                for row in cursor.fetchall():
                    value = row[1]
                    if isinstance(value, str):
                        value = json.loads(value)
                    results.append({
                        "key": row[0],
                        "value": value,
                        "created_at": row[2],
                        "importance": row[3],
                        "episode_type": row[4],
                        "access_count": row[5]
                    })
                return results
        except Exception as e:
            logger.warning(f"Failed to search episodes: {e}")
            return []

    def get_recent(self, limit: int = 50) -> List[Dict]:
        """Get top episodes with no filters.

        NOTE: despite the name, ordering follows search(): importance
        first, then recency.
        """
        return self.search(limit=limit)

    def get_stats(self) -> Dict:
        """Return total count, per-type counts and average importance."""
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT COUNT(*) FROM episodes")
                total = cursor.fetchone()[0]
                cursor.execute("""
                    SELECT episode_type, COUNT(*)
                    FROM episodes
                    GROUP BY episode_type
                """)
                by_type = cursor.fetchall()
                cursor.execute("SELECT AVG(importance) FROM episodes")
                avg_importance = cursor.fetchone()[0]

                return {
                    "total_episodes": total,
                    "by_type": dict(by_type),
                    # AVG() is NULL on an empty table; report 0 instead.
                    "avg_importance": avg_importance or 0
                }
        except Exception as e:
            logger.warning(f"Failed to get episode stats: {e}")
            return {"total_episodes": 0, "by_type": {}, "avg_importance": 0}

class SemanticMemory:
    """
    Semantic memory for knowledge and patterns.

    Keyword-overlap retrieval over a JSON file on disk (can be upgraded
    to a vector store later).  Persistence is best-effort: load/save
    failures are logged and otherwise ignored.
    """

    def __init__(self, storage_path: Path = None):
        self.storage_path = storage_path or Path(__file__).parent.parent / "data" / "semantic_memory.json"
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        # key (md5 of "category:concept") -> entry dict
        self._knowledge: Dict[str, Dict] = {}
        self._load()

    def _load(self):
        """Load semantic memory from disk; start empty on any failure."""
        if self.storage_path.exists():
            try:
                # BUGFIX: pin the encoding instead of relying on the platform
                # default, which breaks UTF-8 round-tripping e.g. on Windows.
                with open(self.storage_path, 'r', encoding='utf-8') as f:
                    self._knowledge = json.load(f)
            except Exception as e:
                # Deliberate best-effort, but leave a trace in the logs.
                logger.warning(f"Failed to load semantic memory: {e}")
                self._knowledge = {}

    def _save(self):
        """Persist semantic memory to disk (best-effort)."""
        try:
            with open(self.storage_path, 'w', encoding='utf-8') as f:
                # ensure_ascii=False keeps non-ASCII concepts human-readable.
                json.dump(self._knowledge, f, indent=2, ensure_ascii=False)
        except Exception as e:
            # BUGFIX: was a bare silent `pass`; log so disk-full/permission
            # problems are not invisible.
            logger.warning(f"Failed to save semantic memory: {e}")

    def store(
        self,
        concept: str,
        knowledge: Any,
        category: str = "general",
        keywords: List[str] = None
    ) -> bool:
        """Store (or overwrite) knowledge under a concept.

        Args:
            concept: Human-readable concept name; also the keyword source
                when no explicit keywords are given.
            knowledge: Arbitrary JSON-serializable payload.
            category: Grouping bucket; part of the storage key.
            keywords: Optional explicit keywords for retrieval.

        Returns:
            Always True (save failures are logged, not raised).
        """
        # md5 is a cheap stable key here, not a security measure.
        key = hashlib.md5(f"{category}:{concept}".encode()).hexdigest()[:12]

        self._knowledge[key] = {
            "concept": concept,
            "knowledge": knowledge,
            "category": category,
            "keywords": keywords or self._extract_keywords(concept),
            "created_at": datetime.now().isoformat(),
            "access_count": 0
        }

        self._save()
        return True

    def _extract_keywords(self, text: str) -> List[str]:
        """Extract lowercase keywords: words longer than 2 chars, minus stop words."""
        words = text.lower().split()
        stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'to', 'of', 'and', 'or', 'in', 'on', 'at', 'for'}
        return [w for w in words if len(w) > 2 and w not in stop_words]

    def recall(self, query: str, limit: int = 10) -> List[Dict]:
        """Recall knowledge whose keywords overlap the query's keywords.

        Returns:
            Up to `limit` entry dicts, each with an added "relevance" key
            (fraction of query keywords matched), best first.
        """
        query_keywords = set(self._extract_keywords(query))

        scored = []
        for key, entry in self._knowledge.items():
            entry_keywords = set(entry.get("keywords", []))
            overlap = len(query_keywords & entry_keywords)
            if overlap > 0:
                # max(..., 1) guards against an all-stop-word query.
                score = overlap / max(len(query_keywords), 1)
                scored.append((score, entry))

        # Best match first.
        scored.sort(key=lambda x: -x[0])

        return [
            {**entry, "relevance": score}
            for score, entry in scored[:limit]
        ]

    def get_by_category(self, category: str) -> List[Dict]:
        """Get all knowledge entries in a category."""
        return [
            entry for entry in self._knowledge.values()
            if entry.get("category") == category
        ]

    def get_stats(self) -> Dict:
        """Return the concept count and a per-category breakdown."""
        categories = {}
        for entry in self._knowledge.values():
            cat = entry.get("category", "unknown")
            categories[cat] = categories.get(cat, 0) + 1

        return {
            "total_concepts": len(self._knowledge),
            "by_category": categories
        }


class ProceduralMemory:
    """
    Procedural memory: skills and how-to knowledge.

    Indexes the *.py skill files found on disk and keeps ad-hoc
    procedures (ordered step lists) in an in-memory registry.
    """

    def __init__(self, skills_path: Path = None):
        self.skills_path = skills_path or Path(__file__).parent.parent / "skills"
        self.skills_path.mkdir(parents=True, exist_ok=True)
        self._procedures: Dict[str, Dict] = {}
        self._load_skills()

    def _load_skills(self):
        """Register every *.py file under skills_path as a python skill."""
        for path in self.skills_path.glob("*.py"):
            try:
                self._procedures[path.stem] = {
                    "name": path.stem,
                    "path": str(path),
                    "type": "python_skill",
                    "loaded_at": datetime.now().isoformat(),
                }
            except Exception:
                pass

    def store_procedure(
        self,
        name: str,
        steps: List[str],
        category: str = "general",
        examples: List[Dict] = None
    ) -> bool:
        """Register a how-to procedure under the given name (overwrites)."""
        record = {
            "name": name,
            "steps": steps,
            "category": category,
            "examples": examples or [],
            "type": "procedure",
            "created_at": datetime.now().isoformat(),
            "success_count": 0,
            "failure_count": 0,
        }
        self._procedures[name] = record
        return True

    def recall_procedure(self, name: str) -> Optional[Dict]:
        """Look up a procedure by exact name; None if unknown."""
        return self._procedures.get(name)

    def search_procedures(self, query: str) -> List[Dict]:
        """Find procedures whose name or any step contains query (case-insensitive)."""
        needle = query.lower()
        matches = []
        for proc_name, proc in self._procedures.items():
            in_name = needle in proc_name.lower()
            in_steps = any(needle in step.lower() for step in proc.get("steps") or [])
            if in_name or in_steps:
                matches.append(proc)
        return matches

    def record_outcome(self, name: str, success: bool):
        """Bump the success/failure counter for a known procedure; unknown names are ignored."""
        proc = self._procedures.get(name)
        if proc is None:
            return
        counter = "success_count" if success else "failure_count"
        proc[counter] = proc.get(counter, 0) + 1

    def get_stats(self) -> Dict:
        """Count registered procedures, grouped by their type field."""
        type_counts: Dict[str, int] = {}
        for proc in self._procedures.values():
            kind = proc.get("type", "unknown")
            type_counts[kind] = type_counts.get(kind, 0) + 1

        return {
            "total_procedures": len(self._procedures),
            "by_type": type_counts,
        }


class MemoryIntegration:
    """
    Unified memory interface integrating all tiers.

    Stores route to a tier either explicitly or via heuristics ("auto");
    recall fans out across tiers and merges hits by relevance score.
    """

    def __init__(self):
        self.working = WorkingMemory()
        self.episodic = EpisodicMemory()
        self.semantic = SemanticMemory()
        self.procedural = ProceduralMemory()

        # Memory consolidation settings
        self.consolidation_threshold = 3  # min access count to promote
        self.importance_threshold = 0.7   # min importance to promote

    def store(
        self,
        key: str,
        value: Any,
        tier: Union[MemoryTier, str] = "auto",
        **kwargs
    ) -> bool:
        """
        Store value in the appropriate memory tier.

        Args:
            key: Memory key
            value: Value to store
            tier: Target tier (enum or its string value), or "auto" for
                heuristic routing via _determine_tier()
            **kwargs: Tier-specific parameters (ttl; episode_type,
                importance, context; category, keywords; examples)

        Returns:
            True on success; False on failure or unrecognized tier.

        Raises:
            ValueError: if tier is a string that is neither "auto" nor a
                valid MemoryTier value.
        """
        if isinstance(tier, str):
            if tier == "auto":
                tier = self._determine_tier(key, value)
            else:
                tier = MemoryTier(tier)

        if tier == MemoryTier.WORKING:
            return self.working.store(key, value, kwargs.get("ttl"))
        elif tier == MemoryTier.EPISODIC:
            return self.episodic.store(
                key, value,
                episode_type=kwargs.get("episode_type", "general"),
                importance=kwargs.get("importance", 0.5),
                context=kwargs.get("context")
            )
        elif tier == MemoryTier.SEMANTIC:
            return self.semantic.store(
                key, value,
                category=kwargs.get("category", "general"),
                keywords=kwargs.get("keywords")
            )
        elif tier == MemoryTier.PROCEDURAL:
            # Non-list values are wrapped as a single-step procedure.
            return self.procedural.store_procedure(
                key,
                steps=value if isinstance(value, list) else [str(value)],
                category=kwargs.get("category", "general"),
                examples=kwargs.get("examples")
            )

        return False

    def _determine_tier(self, key: str, value: Any) -> MemoryTier:
        """Heuristically pick a tier from the value's shape and the key's prefix."""
        # Lists of strings that look like numbered/bulleted steps -> procedural.
        if isinstance(value, list) and all(isinstance(s, str) for s in value):
            if any(step.startswith(("1.", "Step", "-")) for step in value[:3]):
                return MemoryTier.PROCEDURAL

        # Dicts carrying task/event fields -> episodic; knowledge fields -> semantic.
        if isinstance(value, dict):
            if any(k in value for k in ["task_id", "event", "outcome", "result", "error"]):
                return MemoryTier.EPISODIC
            if any(k in value for k in ["concept", "definition", "knowledge", "fact"]):
                return MemoryTier.SEMANTIC

        # Fall back to conventional key prefixes.
        if any(p in key.lower() for p in ["task:", "event:", "log:", "result:"]):
            return MemoryTier.EPISODIC
        if any(p in key.lower() for p in ["concept:", "knowledge:", "fact:"]):
            return MemoryTier.SEMANTIC
        if any(p in key.lower() for p in ["skill:", "procedure:", "howto:"]):
            return MemoryTier.PROCEDURAL

        # Default to working memory.
        return MemoryTier.WORKING

    def recall(
        self,
        query: str,
        tiers: List[MemoryTier] = None,
        limit: int = 10
    ) -> RecallResult:
        """
        Recall from memory across tiers.

        Args:
            query: Exact key (working/episodic/procedural) or free-text
                query (semantic; episodic falls back to type search)
            tiers: Specific tiers to search (default: all)
            limit: Max merged results

        Returns:
            RecallResult with entries sorted by descending relevance.
        """
        start_time = time.time()
        entries = []
        scores = []
        source_tiers = []

        tiers = tiers or [MemoryTier.WORKING, MemoryTier.EPISODIC,
                          MemoryTier.SEMANTIC, MemoryTier.PROCEDURAL]

        # Working memory (exact key match).
        if MemoryTier.WORKING in tiers:
            value = self.working.recall(query)
            # BUGFIX: compare with None so legitimately falsy values
            # (0, False, "", {}) stored in working memory are still found.
            if value is not None:
                entries.append(MemoryEntry(
                    key=query,
                    value=value,
                    tier=MemoryTier.WORKING
                ))
                scores.append(1.0)
                source_tiers.append(MemoryTier.WORKING)

        # Episodic memory: exact key first, then a type-based search.
        if MemoryTier.EPISODIC in tiers:
            value = self.episodic.recall(query)
            # BUGFIX: None check here too - episodes can hold falsy payloads.
            if value is not None:
                entries.append(MemoryEntry(
                    key=query,
                    value=value,
                    tier=MemoryTier.EPISODIC
                ))
                scores.append(1.0)
                source_tiers.append(MemoryTier.EPISODIC)
            else:
                results = self.episodic.search(episode_type=query, limit=limit)
                for r in results:
                    entries.append(MemoryEntry(
                        key=r["key"],
                        value=r["value"],
                        tier=MemoryTier.EPISODIC,
                        importance=r["importance"]
                    ))
                    # Importance doubles as the relevance proxy here.
                    scores.append(r["importance"])
                    source_tiers.append(MemoryTier.EPISODIC)

        # Semantic memory (keyword-overlap search).
        if MemoryTier.SEMANTIC in tiers:
            results = self.semantic.recall(query, limit=limit)
            for r in results:
                entries.append(MemoryEntry(
                    key=r["concept"],
                    value=r["knowledge"],
                    tier=MemoryTier.SEMANTIC
                ))
                scores.append(r.get("relevance", 0.5))
                source_tiers.append(MemoryTier.SEMANTIC)

        # Procedural memory: exact name first, then substring search.
        if MemoryTier.PROCEDURAL in tiers:
            proc = self.procedural.recall_procedure(query)
            if proc:
                entries.append(MemoryEntry(
                    key=query,
                    value=proc,
                    tier=MemoryTier.PROCEDURAL
                ))
                scores.append(1.0)
                source_tiers.append(MemoryTier.PROCEDURAL)
            else:
                results = self.procedural.search_procedures(query)
                for r in results[:limit]:
                    entries.append(MemoryEntry(
                        key=r["name"],
                        value=r,
                        tier=MemoryTier.PROCEDURAL
                    ))
                    # Fixed score for fuzzy procedure matches.
                    scores.append(0.7)
                    source_tiers.append(MemoryTier.PROCEDURAL)

        # Merge: sort by score descending and truncate to limit.
        if entries:
            combined = list(zip(entries, scores, source_tiers))
            combined.sort(key=lambda x: -x[1])
            combined = combined[:limit]
            entries, scores, source_tiers = zip(*combined)
            entries, scores, source_tiers = list(entries), list(scores), list(source_tiers)

        return RecallResult(
            entries=entries,
            relevance_scores=scores,
            source_tiers=source_tiers,
            total_time=time.time() - start_time
        )

    def consolidate(self):
        """
        Consolidate memories: promote important, frequently accessed
        episodes into semantic knowledge.

        NOTE(review): promotion depends on EpisodicMemory.search() results
        exposing an "access_count" field; if they omit it, the .get()
        default of 0 means nothing is ever promoted - verify.
        Working -> Episodic promotion is not implemented (would need
        access tracking in WorkingMemory).
        """
        high_importance = self.episodic.search(min_importance=self.importance_threshold)
        for episode in high_importance:
            if episode.get("access_count", 0) >= self.consolidation_threshold:
                # Promote: re-store the episode as semantic knowledge,
                # categorized by its episode type.
                self.semantic.store(
                    concept=episode["key"],
                    knowledge=episode["value"],
                    category=episode.get("episode_type", "learned")
                )

    def get_stats(self) -> Dict:
        """Get memory statistics across all tiers."""
        return {
            "working": self.working.get_stats(),
            "episodic": self.episodic.get_stats(),
            "semantic": self.semantic.get_stats(),
            "procedural": self.procedural.get_stats()
        }

    def health_check(self) -> Dict:
        """Probe each tier; report "healthy" or the error text per tier."""
        health = {}

        # Working memory: do a real store with a short TTL.
        try:
            self.working.store("_health_check", True, ttl=1)
            health["working"] = "healthy"
        except Exception as e:
            health["working"] = f"error: {e}"

        # Episodic memory: stats query exercises the DB connection.
        try:
            self.episodic.get_stats()
            health["episodic"] = "healthy"
        except Exception as e:
            health["episodic"] = f"error: {e}"

        # Semantic memory
        try:
            self.semantic.get_stats()
            health["semantic"] = "healthy"
        except Exception as e:
            health["semantic"] = f"error: {e}"

        # Procedural memory
        try:
            self.procedural.get_stats()
            health["procedural"] = "healthy"
        except Exception as e:
            health["procedural"] = f"error: {e}"

        return health


# Global instance
# Lazily-created module-level singleton; prefer get_memory() over direct
# construction, since building MemoryIntegration touches the database and disk.
_memory: Optional[MemoryIntegration] = None


def get_memory() -> MemoryIntegration:
    """Get the global memory integration instance, creating it on first call.

    Not thread-safe: two threads racing here may each construct an
    instance (the last assignment wins).
    """
    global _memory
    if _memory is None:
        _memory = MemoryIntegration()
    return _memory


def main():
    """CLI entry point: demo plus stats/health/store/recall commands."""
    import argparse

    parser = argparse.ArgumentParser(description="Genesis Memory Integration")
    parser.add_argument("command", choices=["demo", "stats", "health", "store", "recall"])
    parser.add_argument("--key", help="Memory key")
    parser.add_argument("--value", help="Value to store")
    parser.add_argument("--tier", choices=["working", "episodic", "semantic", "procedural", "auto"], default="auto")
    opts = parser.parse_args()

    memory = MemoryIntegration()
    cmd = opts.command

    if cmd == "stats":
        print(json.dumps(memory.get_stats(), indent=2))
        return

    if cmd == "health":
        print(json.dumps(memory.health_check(), indent=2))
        return

    if cmd == "store":
        if not (opts.key and opts.value):
            print("--key and --value required")
            return
        success = memory.store(opts.key, opts.value, tier=opts.tier)
        print(f"Stored: {success}")
        return

    if cmd == "recall":
        if not opts.key:
            print("--key required")
            return
        result = memory.recall(opts.key)
        print(f"Found {len(result.entries)} results in {result.total_time:.3f}s")
        for entry, score in zip(result.entries, result.relevance_scores):
            print(f"  [{entry.tier.value}] {entry.key} (score: {score:.2f})")
            print(f"    {entry.value}")
        return

    if cmd == "demo":
        print("Memory Integration Demo")
        print("=" * 40)

        # Exercise one explicit store per tier.
        memory.store("task:001", {"id": "001", "status": "complete"}, tier="episodic")
        memory.store("concept:genesis", "Self-evolving AI system", tier="semantic")
        memory.store("cache:temp", {"temp": 42}, tier="working")
        memory.store("howto:test", ["Step 1: Run tests", "Step 2: Check output"], tier="procedural")

        result = memory.recall("genesis")
        print(f"Recall 'genesis': {len(result.entries)} results")
        for entry in result.entries:
            print(f"  [{entry.tier.value}] {entry.key}: {entry.value}")

        print(f"\nStats: {json.dumps(memory.get_stats(), indent=2)}")


if __name__ == "__main__":
    main()
