#!/usr/bin/env python3
"""
Genesis Predictive Pre-Loading System
======================================
Fast Model Anticipates Context Needs

Before the main LLM processes a query, this system:
1. Uses Claude Haiku to predict needed context
2. Pre-loads relevant memories to Redis L1 cache
3. Main retrieval benefits from cache hits

Target: 2.3x faster retrieval (800ms → 350ms)
Cost: ~$0.0002/query × 500 queries/day = $0.10/day
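
Quick usage sketch (illustrative; assumes an async caller):

    result = await preload_for_query("What errors occurred recently?")
    cached = await check_cache("What errors occurred recently?")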
"""

import asyncio
import hashlib
import json
import os
import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
import psycopg2
from psycopg2.extras import RealDictCursor

try:
    import redis
    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False

try:
    import anthropic
    ANTHROPIC_AVAILABLE = True
except ImportError:
    ANTHROPIC_AVAILABLE = False


@dataclass
class PredictedContext:
    """Predicted context needs for a query."""
    topics: List[str]
    entities: List[str]
    event_types: List[str]
    time_range: Optional[str]  # "recent", "last_week", "last_month", None
    predicted_at: Optional[datetime] = None

    def __post_init__(self):
        if self.predicted_at is None:
            self.predicted_at = datetime.now()
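
# Illustrative instance (example values only):
#   PredictedContext(topics=["redis", "cache"], entities=["preloader"],
#                    event_types=["error"], time_range="recent")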


class PredictivePreloader:
    """
    Predicts and pre-loads context before main retrieval.

    Uses Claude Haiku (fast, cheap) to predict what memories
    will be needed, then pre-fetches them into Redis cache.
    """

    # Configuration
    HAIKU_MODEL = "claude-3-5-haiku-20241022"
    CACHE_TTL = 300  # 5 minutes
    CACHE_PREFIX = "cache:predicted:"
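    # -> keys look like "cache:predicted:<16-hex md5 prefix>" (see _query_hash)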
    MAX_PRELOAD_EVENTS = 50

    PREDICTION_PROMPT = """You are a context prediction engine for Genesis, an autonomous AI system.

Given a user query, predict what context from memory will be needed to answer it well.

OUTPUT FORMAT (JSON only):
{
  "topics": ["topic1", "topic2"],  // 1-5 relevant topics
  "entities": ["entity1", "entity2"],  // specific people, projects, files mentioned
  "event_types": ["tool_call", "decision", "error"],  // relevant event types
  "time_range": "recent" | "last_week" | "last_month" | null  // how far back to look
}

EVENT TYPES available:
- tool_call: Tool usage events
- decision: Decisions made
- error: Errors encountered
- plan: Plans created
- learning: Lessons learned
- conversation: User interactions

Be specific but not overly narrow. Return ONLY the JSON, no other text."""
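
    # Illustrative Haiku output for "Why did last week's deploy plan fail?":
    #   {"topics": ["deploy"], "entities": [],
    #    "event_types": ["plan", "error"], "time_range": "last_week"}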

    def __init__(
        self,
        postgres_config: Optional[Dict] = None,
        redis_config: Optional[Dict] = None,
        anthropic_api_key: Optional[str] = None
    ):
        self.postgres_config = postgres_config or {
            "host": os.getenv("GENESIS_PG_HOST", "postgresql-genesis-u50607.vm.elestio.app"),
            "port": int(os.getenv("GENESIS_PG_PORT", "25432")),
            "database": os.getenv("GENESIS_PG_DATABASE", "postgres"),
            "user": os.getenv("GENESIS_PG_USER", "postgres"),
            "password": os.getenv("GENESIS_PG_PASSWORD", "etY0eog17tD-dDuj--IRH")
        }

        self.redis_config = redis_config or {
            "host": os.getenv("GENESIS_REDIS_HOST", "Redis-genesis-u50607.vm.elestio.app"),
            "port": int(os.getenv("GENESIS_REDIS_PORT", "6379")),
            "password": os.getenv("GENESIS_REDIS_PASSWORD", ""),
            "db": 0
        }

        self.api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
        self._client = None
        self._conn = None
        self._redis = None

    def _get_conn(self):
        """Get PostgreSQL connection."""
        if self._conn is None or self._conn.closed:
            self._conn = psycopg2.connect(**self.postgres_config)
            self._conn.autocommit = True
        return self._conn

    def _get_redis(self):
        """Get Redis connection."""
        if self._redis is None and REDIS_AVAILABLE:
            try:
                self._redis = redis.Redis(**self.redis_config, decode_responses=True)
                self._redis.ping()
            except Exception as e:
                print(f"Redis connection failed: {e}")
                self._redis = None
        return self._redis

    def _get_client(self):
        """Get async Anthropic client (avoids blocking the event loop in predict_context)."""
        if self._client is None and ANTHROPIC_AVAILABLE:
            self._client = anthropic.AsyncAnthropic(api_key=self.api_key)
        return self._client

    def _query_hash(self, query: str) -> str:
        """Generate cache key for query."""
        return hashlib.md5(query.lower().strip().encode()).hexdigest()[:16]

    async def predict_context(self, query: str) -> PredictedContext:
        """
        Use Haiku to predict what context will be needed.

        This runs in parallel with initial query processing,
        so the prediction is ready before main retrieval starts.
        """
        if not ANTHROPIC_AVAILABLE:
            return self._fallback_predict(query)

        client = self._get_client()
        if not client:
            return self._fallback_predict(query)

        try:
            response = await client.messages.create(
                model=self.HAIKU_MODEL,
                max_tokens=256,
                system=self.PREDICTION_PROMPT,
                messages=[{"role": "user", "content": f"Query: {query}"}]
            )

            response_text = response.content[0].text.strip()

            # Strip a markdown code fence if the model wrapped the JSON
            if response_text.startswith("```"):
                response_text = response_text.split("```")[1]
                if response_text.startswith("json"):
                    response_text = response_text[4:]

            data = json.loads(response_text)

            return PredictedContext(
                topics=data.get("topics", []),
                entities=data.get("entities", []),
                event_types=data.get("event_types", []),
                time_range=data.get("time_range")
            )

        except Exception as e:
            print(f"Prediction error: {e}")
            return self._fallback_predict(query)

    def _fallback_predict(self, query: str) -> PredictedContext:
        """Simple keyword-based prediction when Haiku unavailable."""
        query_lower = query.lower()

        entities: List[str] = []  # the keyword fallback does no entity extraction
        event_types: List[str] = []
        time_range = None

        # Treat longer words as candidate topics (crude but cheap)
        topics = [w for w in query_lower.split() if len(w) > 4][:5]

        # Guess event types
        if any(w in query_lower for w in ["error", "bug", "fail", "crash"]):
            event_types.append("error")
        if any(w in query_lower for w in ["decide", "decision", "chose", "choice"]):
            event_types.append("decision")
        if any(w in query_lower for w in ["plan", "strategy", "approach"]):
            event_types.append("plan")
        if any(w in query_lower for w in ["learn", "discover", "realize"]):
            event_types.append("learning")

        # Guess time range
        if any(w in query_lower for w in ["today", "now", "recent", "latest"]):
            time_range = "recent"
        elif "week" in query_lower:
            time_range = "last_week"
        elif "month" in query_lower:
            time_range = "last_month"

        return PredictedContext(
            topics=topics,
            entities=entities,
            event_types=event_types or ["tool_call", "decision"],
            time_range=time_range or "recent"
        )

    async def preload_context(
        self,
        query: str,
        prediction: Optional[PredictedContext] = None
    ) -> Dict[str, Any]:
        """
        Pre-load predicted context to Redis cache.

        Returns info about what was preloaded.
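
        Illustrative return shape:

            {"query_hash": "ab12cd34ef56ab12",
             "prediction": {...},
             "events_preloaded": 12,
             "cache_key": "cache:predicted:ab12cd34ef56ab12",
             "ttl_seconds": 300}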
        """
        if prediction is None:
            prediction = await self.predict_context(query)

        # Build query for relevant events
        events = await self._fetch_predicted_events(prediction)

        # Cache them
        cache_key = f"{self.CACHE_PREFIX}{self._query_hash(query)}"
        cached = await self._cache_events(cache_key, events)

        return {
            "query_hash": self._query_hash(query),
            "prediction": {
                "topics": prediction.topics,
                "entities": prediction.entities,
                "event_types": prediction.event_types,
                "time_range": prediction.time_range
            },
            "events_preloaded": len(cached),
            "cache_key": cache_key,
            "ttl_seconds": self.CACHE_TTL
        }

    async def _fetch_predicted_events(
        self,
        prediction: PredictedContext
    ) -> List[Dict]:
        """Fetch events matching prediction."""
        try:
            conn = self._get_conn()
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                # Build time filter
                time_filter = ""
                if prediction.time_range == "recent":
                    time_filter = "AND timestamp > NOW() - INTERVAL '24 hours'"
                elif prediction.time_range == "last_week":
                    time_filter = "AND timestamp > NOW() - INTERVAL '7 days'"
                elif prediction.time_range == "last_month":
                    time_filter = "AND timestamp > NOW() - INTERVAL '30 days'"

                # Build event type filter (parameterized: event types come
                # from the model and must never be interpolated into SQL)
                type_filter = ""
                type_params: List[Any] = []
                if prediction.event_types:
                    type_filter = "AND event_type = ANY(%s)"
                    type_params = [prediction.event_types]

                # Search for matching events
                search_terms = prediction.topics + prediction.entities
                if search_terms:
                    search_pattern = "|".join(search_terms[:10])
                    query = f"""
                        SELECT event_id, event_type, content, surprise_score,
                               timestamp, tier, metadata
                        FROM ambient_events
                        WHERE tier IN ('episodic', 'semantic')
                          {time_filter}
                          {type_filter}
                          AND content::text ~* %s
                        ORDER BY surprise_score DESC, timestamp DESC
                        LIMIT %s
                    """
                    cur.execute(query, (*type_params, search_pattern, self.MAX_PRELOAD_EVENTS))
                else:
                    # No search terms: fall back to recent high-surprise events
                    query = f"""
                        SELECT event_id, event_type, content, surprise_score,
                               timestamp, tier, metadata
                        FROM ambient_events
                        WHERE tier IN ('episodic', 'semantic')
                          {time_filter}
                          {type_filter}
                        ORDER BY surprise_score DESC, timestamp DESC
                        LIMIT %s
                    """
                    cur.execute(query, (*type_params, self.MAX_PRELOAD_EVENTS))

                events = [dict(row) for row in cur.fetchall()]

                # Track access counts for VoI (value of information)
                if events:
                    event_ids = [e["event_id"] for e in events]
                    cur.execute("""
                        UPDATE ambient_events
                        SET access_count = COALESCE(access_count, 0) + 1,
                            last_accessed = NOW()
                        WHERE event_id = ANY(%s)
                    """, (event_ids,))

                return events

        except Exception as e:
            print(f"Error fetching predicted events: {e}")
            return []

    async def _cache_events(
        self,
        cache_key: str,
        events: List[Dict]
    ) -> List[Dict]:
        """Cache events in Redis."""
        r = self._get_redis()
        if not r or not events:
            return events

        try:
            # Convert datetimes to ISO strings for a JSON-safe payload
            # (note: this mutates the event dicts in place)
            for event in events:
                if "timestamp" in event and event["timestamp"]:
                    event["timestamp"] = event["timestamp"].isoformat()

            r.setex(
                cache_key,
                self.CACHE_TTL,
                json.dumps(events, default=str)
            )
            return events

        except Exception as e:
            print(f"Cache error: {e}")
            return events

    async def get_cached_context(self, query: str) -> Optional[List[Dict]]:
        """
        Check if context is already cached for this query.

        Called by main retrieval to check for cache hits.
        """
        r = self._get_redis()
        if not r:
            return None

        try:
            cache_key = f"{self.CACHE_PREFIX}{self._query_hash(query)}"
            cached = r.get(cache_key)
            if cached:
                return json.loads(cached)
            return None
        except Exception as e:
            print(f"Cache read error: {e}")
            return None

    async def parallel_preload(self, query: str) -> Dict[str, Any]:
        """
        Entry point for parallel pre-loading.

        This is called in parallel with the main query processing,
        e.g. (asyncio.TaskGroup requires Python 3.11+):

            async with asyncio.TaskGroup() as tg:
                preload_task = tg.create_task(preloader.parallel_preload(query))
                main_task = tg.create_task(process_query(query))

            preload_result = preload_task.result()
            main_result = main_task.result()
        """
        start = datetime.now()

        # Predict context
        prediction = await self.predict_context(query)

        # Pre-load to cache
        result = await self.preload_context(query, prediction)

        result["latency_ms"] = (datetime.now() - start).total_seconds() * 1000
        return result

    def close(self):
        """Close connections."""
        if self._conn:
            self._conn.close()
        if self._redis:
            self._redis.close()


# Singleton instance
_preloader_instance: Optional[PredictivePreloader] = None


def get_preloader() -> PredictivePreloader:
    """Get singleton preloader instance."""
    global _preloader_instance
    if _preloader_instance is None:
        _preloader_instance = PredictivePreloader()
    return _preloader_instance


async def preload_for_query(query: str) -> Dict[str, Any]:
    """Convenience function for pre-loading context."""
    preloader = get_preloader()
    return await preloader.parallel_preload(query)


async def check_cache(query: str) -> Optional[List[Dict]]:
    """Check if context is cached for query."""
    preloader = get_preloader()
    return await preloader.get_cached_context(query)
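
# Sketch of a cache-first retrieval path (illustrative; `full_retrieval` is a
# hypothetical fallback, not part of this module):
#
#     async def retrieve(query: str) -> List[Dict]:
#         events = await check_cache(query)
#         if events is not None:
#             return events  # cache hit: skip the expensive path
#         return await full_retrieval(query)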


# CLI Interface
async def main():
    import sys

    if len(sys.argv) < 2:
        print("Genesis Predictive Pre-Loading System")
        print("\nUsage:")
        print("  python predictive_preload.py predict 'query'   # Predict context needs")
        print("  python predictive_preload.py preload 'query'   # Pre-load context")
        print("  python predictive_preload.py cache 'query'     # Check cache")
        print("  python predictive_preload.py bench 'query'     # Benchmark preload")
        return

    preloader = PredictivePreloader()
    command = sys.argv[1]
    query = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "What errors occurred recently?"

    try:
        if command == "predict":
            prediction = await preloader.predict_context(query)
            print(json.dumps({
                "topics": prediction.topics,
                "entities": prediction.entities,
                "event_types": prediction.event_types,
                "time_range": prediction.time_range
            }, indent=2))

        elif command == "preload":
            result = await preloader.preload_context(query)
            print(json.dumps(result, indent=2, default=str))

        elif command == "cache":
            cached = await preloader.get_cached_context(query)
            if cached:
                print(f"Cache HIT: {len(cached)} events")
                print(json.dumps(cached[:3], indent=2, default=str))  # First 3
            else:
                print("Cache MISS")

        elif command == "bench":
            import time

            # Benchmark pre-load
            start = time.time()
            result = await preloader.parallel_preload(query)
            preload_time = (time.time() - start) * 1000

            # Benchmark cache hit
            start = time.time()
            cached = await preloader.get_cached_context(query)
            cache_time = (time.time() - start) * 1000

            print(f"Pre-load time: {preload_time:.1f}ms")
            print(f"Cache hit time: {cache_time:.1f}ms")
            print(f"Events preloaded: {result['events_preloaded']}")
            print(f"Speedup factor: {preload_time/max(cache_time, 0.1):.1f}x")

        else:
            print(f"Unknown command: {command}")

    finally:
        preloader.close()


if __name__ == "__main__":
    asyncio.run(main())
