#!/usr/bin/env python3
"""
GENESIS CACHING LAYER
======================
Multi-tier caching with LRU, TTL, and persistence support.

Features:
    - In-memory LRU cache
    - TTL-based expiration
    - File-based persistence
    - Cache decorators
    - Statistics tracking
    - Namespacing

Usage:
    cache = Cache()
    cache.set("key", value, ttl=300)
    value = cache.get("key")

    @cached(ttl=60)
    def expensive_function(x):
        return compute(x)
"""

"""
RULE 7 COMPLIANT: Uses Elestio PostgreSQL via genesis_db module.
"""
import hashlib
import json
import pickle
import threading
import time
import logging
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from functools import wraps
from pathlib import Path
from typing import Dict, List, Any, Optional, Callable, TypeVar, Generic, Union
from enum import Enum

# RULE 7: Use PostgreSQL via genesis_db (no sqlite3)
from core.genesis_db import connection, ensure_table

logger = logging.getLogger(__name__)


T = TypeVar('T')


class CachePolicy(Enum):
    """Cache eviction policies.

    Selects which entry ``MemoryCache._evict`` removes once the cache
    reaches ``max_size``.
    """
    LRU = "lru"       # Least Recently Used
    LFU = "lfu"       # Least Frequently Used
    FIFO = "fifo"     # First In First Out
    TTL = "ttl"       # Time To Live only


@dataclass
class CacheEntry:
    """A single cached value plus its bookkeeping metadata."""
    key: str
    value: Any
    created_at: float
    expires_at: Optional[float]
    access_count: int = 0
    last_accessed: float = field(default_factory=time.time)
    size_bytes: int = 0

    def is_expired(self) -> bool:
        """Return True once the entry's deadline (if any) has passed."""
        deadline = self.expires_at
        return deadline is not None and time.time() > deadline

    def touch(self):
        """Record one more access at the current time."""
        self.access_count = self.access_count + 1
        self.last_accessed = time.time()


@dataclass
class CacheStats:
    """Hit/miss/eviction counters for a cache instance."""
    hits: int = 0
    misses: int = 0
    evictions: int = 0
    size: int = 0
    max_size: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups served from cache; 0.0 before any lookup."""
        lookups = self.hits + self.misses
        if lookups == 0:
            return 0.0
        return self.hits / lookups

    def to_dict(self) -> Dict:
        """Serialize the counters (hit_rate rounded to 4 places)."""
        return dict(
            hits=self.hits,
            misses=self.misses,
            evictions=self.evictions,
            hit_rate=round(self.hit_rate, 4),
            size=self.size,
            max_size=self.max_size,
        )


class MemoryCache:
    """
    In-memory LRU cache with TTL support.

    Thread-safe (all public operations hold an RLock). Eviction is
    governed by ``policy`` -- see CachePolicy. TTL semantics for set():
    ``ttl=None`` falls back to ``default_ttl``; ``ttl=0`` means the
    entry never expires.
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: int = 3600,
        policy: CachePolicy = CachePolicy.LRU
    ):
        """
        Args:
            max_size: Entry count at which eviction begins.
            default_ttl: Fallback TTL (seconds) for set(..., ttl=None).
            policy: Eviction policy to apply when the cache is full.
        """
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.policy = policy

        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._stats = CacheStats(max_size=max_size)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the cached value for ``key``, or ``default`` on miss/expiry."""
        with self._lock:
            entry = self._cache.get(key)

            if entry is None:
                self._stats.misses += 1
                return default

            if entry.is_expired():
                # Lazy expiration: purge the stale entry on access.
                self._delete(key)
                self._stats.misses += 1
                return default

            # Update access metadata (count + timestamp)
            entry.touch()

            # For LRU, the most recently used entry lives at the end
            if self.policy == CachePolicy.LRU:
                self._cache.move_to_end(key)

            self._stats.hits += 1
            return entry.value

    def set(self, key: str, value: Any, ttl: int = None) -> bool:
        """Store ``value`` under ``key``.

        Args:
            key: Cache key.
            value: Value to store (any object; size estimate uses pickle).
            ttl: Seconds until expiry; None -> default_ttl, 0 -> never expire.

        Returns:
            True (insertion cannot fail).
        """
        with self._lock:
            # Best-effort size estimate; unpicklable values count as 0 bytes.
            try:
                size_bytes = len(pickle.dumps(value))
            except Exception:
                size_bytes = 0

            # ttl=0 means "no expiry"; otherwise fall back to default_ttl.
            expires_at = time.time() + (ttl or self.default_ttl) if ttl != 0 else None

            entry = CacheEntry(
                key=key,
                value=value,
                created_at=time.time(),
                expires_at=expires_at,
                size_bytes=size_bytes
            )

            # Replace semantics: drop any previous entry first.
            if key in self._cache:
                self._delete(key)

            # Evict until there is room. BUGFIX: also require a non-empty
            # cache so a max_size <= 0 configuration cannot loop forever
            # (_evict() is a no-op on an empty cache and would never break
            # the "0 >= 0" condition).
            while self._cache and len(self._cache) >= self.max_size:
                self._evict()

            self._cache[key] = entry
            self._stats.size = len(self._cache)

            return True

    def delete(self, key: str) -> bool:
        """Delete entry from cache. Returns True if the key existed."""
        with self._lock:
            return self._delete(key)

    def _delete(self, key: str) -> bool:
        """Internal delete; caller must already hold the lock."""
        if key in self._cache:
            del self._cache[key]
            self._stats.size = len(self._cache)
            return True
        return False

    def _evict(self):
        """Evict one entry per the configured policy (caller holds the lock)."""
        if not self._cache:
            return

        if self.policy == CachePolicy.LRU:
            # Least recently used = first item (get() moves hits to the end)
            self._cache.popitem(last=False)

        elif self.policy == CachePolicy.LFU:
            # Least frequently used; O(n) scan of access counts
            min_key = min(self._cache.keys(), key=lambda k: self._cache[k].access_count)
            del self._cache[min_key]

        elif self.policy == CachePolicy.FIFO:
            # Oldest insertion = first item
            self._cache.popitem(last=False)

        elif self.policy == CachePolicy.TTL:
            # Prefer an already-expired entry, else fall back to the oldest
            now = time.time()
            expired = [k for k, v in self._cache.items() if v.expires_at and v.expires_at < now]
            if expired:
                del self._cache[expired[0]]
            else:
                self._cache.popitem(last=False)

        self._stats.evictions += 1
        self._stats.size = len(self._cache)

    def clear(self):
        """Remove every entry (hit/miss/eviction counters are kept)."""
        with self._lock:
            self._cache.clear()
            self._stats.size = 0

    def cleanup_expired(self) -> int:
        """Remove all expired entries; returns how many were dropped."""
        with self._lock:
            now = time.time()
            expired = [k for k, v in self._cache.items() if v.expires_at and v.expires_at < now]
            for key in expired:
                self._delete(key)
            return len(expired)

    def keys(self) -> List[str]:
        """Snapshot of all keys (may include not-yet-purged expired ones)."""
        with self._lock:
            return list(self._cache.keys())

    def contains(self, key: str) -> bool:
        """Check if key exists and is not expired (purges it if expired)."""
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return False
            if entry.is_expired():
                self._delete(key)
                return False
            return True

    def get_stats(self) -> CacheStats:
        """Return the live CacheStats object (not a snapshot copy)."""
        return self._stats


class PersistentCache:
    """
    Persistent cache using PostgreSQL (RULE 7).

    Values are serialized with pickle into a BYTEA column.
    SECURITY NOTE: unpickling is only safe because this table is
    written and read exclusively by this application; never point
    it at data produced by an untrusted party.
    """

    def __init__(self, db_path: Path = None, default_ttl: int = 86400):
        """
        Args:
            db_path: Ignored (RULE 7); kept for backward compatibility.
            default_ttl: Fallback TTL (seconds) for set(..., ttl=None).
        """
        # RULE 7: db_path is ignored - uses PostgreSQL via genesis_db
        self.default_ttl = default_ttl
        self._init_db()

    def _init_db(self):
        """Create the cache table and expiry index (RULE 7: PostgreSQL)."""
        ensure_table('cache_store', '''
            key TEXT PRIMARY KEY,
            value BYTEA NOT NULL,
            created_at REAL NOT NULL,
            expires_at REAL,
            access_count INTEGER DEFAULT 0
        ''')
        try:
            with connection() as conn:
                cursor = conn.cursor()
                # Speeds up cleanup_expired()'s range scan on expires_at.
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_cache_expires ON cache_store(expires_at)")
        except Exception as e:
            logger.warning(f"Index creation warning: {e}")

    def get(self, key: str, default: Any = None) -> Any:
        """Return the cached value, or ``default`` on miss, expiry, or DB error."""
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT value, expires_at FROM cache_store WHERE key = %s",
                    (key,)
                )
                row = cursor.fetchone()

                if row is None:
                    return default

                value_blob, expires_at = row

                # Lazy expiration: drop the stale row and report a miss.
                if expires_at and time.time() > expires_at:
                    self.delete(key)
                    return default

                # Track popularity for observability.
                cursor.execute(
                    "UPDATE cache_store SET access_count = access_count + 1 WHERE key = %s",
                    (key,)
                )

                return pickle.loads(bytes(value_blob))

        except Exception as e:
            logger.warning(f"Cache get failed: {e}")
            return default

    def set(self, key: str, value: Any, ttl: int = None) -> bool:
        """Upsert ``value`` under ``key``.

        Args:
            ttl: Seconds until expiry; None -> default_ttl, 0 -> never expire.

        Returns:
            True on success, False if serialization or the DB write failed.
        """
        try:
            value_blob = pickle.dumps(value)
            expires_at = time.time() + (ttl or self.default_ttl) if ttl != 0 else None

            with connection() as conn:
                cursor = conn.cursor()
                # BUGFIX: also refresh created_at on overwrite so the row's
                # metadata reflects the latest write (previously the original
                # insertion time was kept forever).
                cursor.execute("""
                    INSERT INTO cache_store (key, value, created_at, expires_at, access_count)
                    VALUES (%s, %s, %s, %s, 0)
                    ON CONFLICT (key) DO UPDATE SET
                        value = EXCLUDED.value,
                        created_at = EXCLUDED.created_at,
                        expires_at = EXCLUDED.expires_at,
                        access_count = 0
                """, (key, value_blob, time.time(), expires_at))

            return True
        except Exception as e:
            logger.warning(f"Cache set failed: {e}")
            return False

    def delete(self, key: str) -> bool:
        """Delete entry from cache; True unless the DB call failed."""
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("DELETE FROM cache_store WHERE key = %s", (key,))
            return True
        except Exception as e:
            logger.warning(f"Cache delete failed: {e}")
            return False

    def clear(self):
        """Delete every row in the cache table (best-effort; errors logged)."""
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("DELETE FROM cache_store")
        except Exception as e:
            logger.warning(f"Cache clear failed: {e}")

    def cleanup_expired(self) -> int:
        """Remove expired rows; returns how many were deleted (0 on error)."""
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "DELETE FROM cache_store WHERE expires_at IS NOT NULL AND expires_at < %s",
                    (time.time(),)
                )
                return cursor.rowcount
        except Exception as e:
            logger.warning(f"Cache cleanup failed: {e}")
            return 0


class TieredCache:
    """
    Multi-tier cache with memory and persistent layers.

    Reads hit the fast in-memory tier first, then fall back to the
    persistent tier; persistent hits are promoted back into memory.
    """

    # Sentinel so a legitimately cached None is distinguishable from a miss.
    _MISS = object()

    def __init__(
        self,
        memory_size: int = 1000,
        memory_ttl: int = 300,
        persist: bool = True,
        persist_ttl: int = 86400
    ):
        """
        Args:
            memory_size: Max entries in the memory tier.
            memory_ttl: Default TTL (seconds) for the memory tier.
            persist: Whether to back the cache with PostgreSQL.
            persist_ttl: Default TTL (seconds) for the persistent tier.
        """
        self.memory = MemoryCache(max_size=memory_size, default_ttl=memory_ttl)
        self.persistent = PersistentCache(default_ttl=persist_ttl) if persist else None

    def get(self, key: str, default: Any = None) -> Any:
        """Get value, checking memory first then persistent."""
        # BUGFIX: use a sentinel instead of "is not None" so a cached
        # value of None is returned rather than treated as a miss.
        value = self.memory.get(key, self._MISS)
        if value is not self._MISS:
            return value

        if self.persistent:
            value = self.persistent.get(key, self._MISS)
            if value is not self._MISS:
                # Promote to the faster memory tier
                self.memory.set(key, value)
                return value

        return default

    def set(self, key: str, value: Any, ttl: int = None, persist: bool = True) -> bool:
        """Set value in memory and (when ``persist``) the persistent tier."""
        success = self.memory.set(key, value, ttl)

        if persist and self.persistent:
            self.persistent.set(key, value, ttl)

        return success

    def delete(self, key: str) -> bool:
        """Delete from both tiers; returns the memory tier's result."""
        success = self.memory.delete(key)
        if self.persistent:
            self.persistent.delete(key)
        return success

    def clear(self):
        """Clear both tiers."""
        self.memory.clear()
        if self.persistent:
            self.persistent.clear()


class NamespacedCache:
    """
    Cache with namespace support for isolation.

    Wraps a MemoryCache or TieredCache and prefixes every key with
    "<namespace>:" so multiple components can share one cache safely.
    """

    def __init__(self, cache: Union[MemoryCache, TieredCache], namespace: str):
        """
        Args:
            cache: Underlying cache instance to delegate to.
            namespace: Prefix applied to every key.
        """
        self._cache = cache
        self._namespace = namespace

    def _make_key(self, key: str) -> str:
        """Return ``key`` prefixed with this instance's namespace."""
        return f"{self._namespace}:{key}"

    def get(self, key: str, default: Any = None) -> Any:
        """Get a namespaced value."""
        return self._cache.get(self._make_key(key), default)

    def set(self, key: str, value: Any, ttl: int = None, **kwargs) -> bool:
        """Set a namespaced value (extra kwargs pass through, e.g. persist=)."""
        return self._cache.set(self._make_key(key), value, ttl, **kwargs)

    def delete(self, key: str) -> bool:
        """Delete a namespaced key."""
        return self._cache.delete(self._make_key(key))

    def clear_namespace(self):
        """Clear all keys in this namespace.

        BUGFIX: this was a silent no-op when wrapping a TieredCache;
        it now enumerates via the memory tier and delete() removes the
        key from both tiers. Keys present only in the persistent tier
        cannot be enumerated and remain until they expire.
        """
        prefix = f"{self._namespace}:"
        if isinstance(self._cache, MemoryCache):
            for key in self._cache.keys():
                if key.startswith(prefix):
                    self._cache.delete(key)
        elif isinstance(self._cache, TieredCache):
            for key in self._cache.memory.keys():
                if key.startswith(prefix):
                    self._cache.delete(key)


def cached(
    ttl: int = 300,
    key_fn: Optional[Callable[..., str]] = None,
    cache: Any = None
):
    """
    Decorator for caching function results.

    Args:
        ttl: Cache time-to-live in seconds.
        key_fn: Custom function to generate the cache key from the call
            arguments; defaults to hashing name/args/kwargs.
        cache: Any object with get(key, default)/set(key, value, ttl)
            -- MemoryCache, TieredCache, etc. (default: global cache).
    """
    _cache = cache or _default_cache
    # BUGFIX: sentinel lets us cache functions that return None; the old
    # "is not None" check recomputed those results on every call.
    _miss = object()

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Generate key
            if key_fn:
                key = key_fn(*args, **kwargs)
            else:
                key = _generate_cache_key(func.__name__, args, kwargs)

            # Try cache
            result = _cache.get(key, _miss)
            if result is not _miss:
                return result

            # Compute and cache
            result = func(*args, **kwargs)
            _cache.set(key, result, ttl)
            return result

        wrapper.cache_clear = lambda: None  # Compatibility no-op
        return wrapper

    return decorator


def _generate_cache_key(func_name: str, args: tuple, kwargs: dict) -> str:
    """Generate cache key from function call."""
    key_parts = [func_name]
    key_parts.extend(str(arg) for arg in args)
    key_parts.extend(f"{k}={v}" for k, v in sorted(kwargs.items()))
    key_str = ":".join(key_parts)
    return hashlib.md5(key_str.encode()).hexdigest()


# Global default cache
# Module-level singleton shared by get_cache() and by @cached when no
# explicit cache instance is supplied.
_default_cache = MemoryCache()


def get_cache() -> MemoryCache:
    """Get the default global cache."""
    return _default_cache


def main():
    """CLI and demo for the cache layer.

    Commands:
        demo  -- walk through cache features with printed output
        stats -- dump default-cache statistics as JSON
        clear -- empty the default global cache
    """
    import argparse
    parser = argparse.ArgumentParser(description="Genesis Cache Layer")
    parser.add_argument("command", choices=["demo", "stats", "clear"])
    args = parser.parse_args()

    if args.command == "demo":
        print("Cache Layer Demo")
        print("=" * 40)

        # Small memory cache so LRU eviction is easy to observe below.
        cache = MemoryCache(max_size=5)

        print("\n1. Basic operations:")
        cache.set("user:1", {"name": "Alice", "score": 100})
        cache.set("user:2", {"name": "Bob", "score": 85})
        # FIX: these were f-strings with no placeholders (lint F541).
        print("  Set user:1 and user:2")

        print(f"  Get user:1: {cache.get('user:1')}")
        print(f"  Get user:3: {cache.get('user:3', 'NOT FOUND')}")

        # TTL
        print("\n2. TTL expiration:")
        cache.set("temp", "expires soon", ttl=1)
        print("  Set temp with 1s TTL")
        print(f"  Get temp now: {cache.get('temp')}")
        time.sleep(1.5)
        print(f"  Get temp after 1.5s: {cache.get('temp', 'EXPIRED')}")

        # LRU eviction
        print("\n3. LRU eviction (max_size=5):")
        for i in range(10):
            cache.set(f"key{i}", f"value{i}")
            print(f"  Set key{i}, cache size: {len(cache.keys())}")

        # Decorator
        print("\n4. @cached decorator:")

        @cached(ttl=60)
        def expensive_computation(n):
            print(f"  Computing for n={n}...")
            time.sleep(0.1)
            return n * n

        print(f"  First call: {expensive_computation(5)}")
        print(f"  Second call (cached): {expensive_computation(5)}")
        print(f"  Different arg: {expensive_computation(10)}")

        # Stats
        print(f"\n5. Statistics: {json.dumps(cache.get_stats().to_dict(), indent=2)}")

        # Tiered cache
        print("\n6. Tiered cache:")
        tiered = TieredCache(memory_size=100, persist=True)
        tiered.set("persistent_key", {"data": "will survive restart"})
        print("  Stored in tiered cache")
        print(f"  Get: {tiered.get('persistent_key')}")

        # Namespaced
        print("\n7. Namespaced cache:")
        ns_cache = NamespacedCache(cache, "user_data")
        ns_cache.set("preferences", {"theme": "dark"})
        print("  Set with namespace 'user_data'")
        print(f"  Get preferences: {ns_cache.get('preferences')}")

        print("\nDemo complete!")

    elif args.command == "stats":
        stats = _default_cache.get_stats()
        print(json.dumps(stats.to_dict(), indent=2))

    elif args.command == "clear":
        _default_cache.clear()
        print("Cache cleared")


if __name__ == "__main__":
    main()
