#!/usr/bin/env python3
"""
GeminiAgent — Persistent Gemini Agent with Session Memory and CTM Hooks
========================================================================
A stateful Gemini agent that:
  - Maintains conversation history across turns
  - Persists session state to JSONL files (E: drive)
  - Extracts and commits [CTM] blocks to the Knowledge Graph
  - Tracks token usage and compacts history when approaching context limits

Uses google.genai (new SDK, NOT deprecated google.generativeai).

Author: Genesis Parallel Builder
Created: 2026-02-26
"""

import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4

from google import genai
from google.genai import types as genai_types

from core.gemini_command_centres.ctm import ctm_to_kg, process_ctm_blocks

logger = logging.getLogger(__name__)

# ─── Storage locations (E: drive only) ────────────────────────────────────────
# Everything this module persists lives under the Genesis root on the E: drive,
# addressed here through its /mnt/e mount point.
GENESIS_ROOT = Path("/mnt/e/genesis-system")
GCC_SESSIONS_DIR = GENESIS_ROOT.joinpath("data", "gcc_sessions")

# ─── Context-window budget ────────────────────────────────────────────────────
# Gemini 3.x models accept up to 1,048,576 input tokens; history compaction
# kicks in once the estimated context reaches this many tokens.
COMPACTION_TOKEN_THRESHOLD = 800_000
# Number of most-recent turns (one user + one model message each) that
# survive a compaction.
COMPACTION_KEEP_TURNS = 25


class GeminiAgent:
    """
    A persistent Gemini agent with session memory and CTM hooks.

    Conversation history is stored as a list of dicts matching the
    google.genai Content format:
        {"role": "user" | "model", "parts": [{"text": "..."}]}

    Session state is persisted as JSONL to ``memory_file`` (by default a
    per-agent file in data/gcc_sessions/).
    """

    # Input context window assumed for all gcc models; used for utilisation %.
    MAX_CONTEXT_TOKENS = 1_048_576

    def __init__(
        self,
        name: str,
        model: str,
        system_prompt: str,
        memory_file: Optional[str] = None,
        max_history: int = 50,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Initialise a GeminiAgent.

        Args:
            name: Short identifier (e.g. "orchestrator", "builder").
            model: Gemini model ID (e.g. "gemini-3.1-pro-preview").
            system_prompt: System instruction sent with every API request.
            memory_file: Path to JSONL session file. Defaults to
                         GCC_SESSIONS_DIR / "<name>_session.jsonl".
            max_history: Soft turn limit stored on the instance.
                         NOTE: not currently consulted by this class —
                         compaction is driven by COMPACTION_TOKEN_THRESHOLD
                         and COMPACTION_KEEP_TURNS.
            api_key: Gemini API key. Falls back to the GEMINI_API_KEY env
                     var; if neither is set, None is passed to
                     ``genai.Client`` so the SDK applies its own
                     environment-based key lookup (presumably raising if it
                     finds no credentials).
        """
        self.name = name
        self.model = model
        self.system_prompt = system_prompt
        self.max_history = max_history
        self.session_id = str(uuid4())
        self.created_at = datetime.now(timezone.utc)
        self.turn_count = 0
        # Results returned by ctm_to_kg() / process_ctm_blocks() this session.
        self.ctm_buffer: list[dict] = []

        # Conversation history (google.genai Content format)
        self.history: list[dict] = []

        # Token tracking
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.estimated_context_tokens = 0

        # Session file — only create the directory we will actually write to,
        # so a custom memory_file does not require the default E: drive path.
        if memory_file:
            self.memory_file = Path(memory_file)
        else:
            self.memory_file = GCC_SESSIONS_DIR / f"{name}_session.jsonl"
        self.memory_file.parent.mkdir(parents=True, exist_ok=True)

        # Build Gemini client. Credentials must come from the caller or the
        # environment — API keys must never be hard-coded in source.
        resolved_key = api_key or os.environ.get("GEMINI_API_KEY")
        self.client = genai.Client(api_key=resolved_key)

        logger.info(
            "GeminiAgent[%s] initialised — model=%s session=%s",
            self.name,
            self.model,
            self.session_id,
        )

    # ─── Chat ─────────────────────────────────────────────────────────────────

    async def chat(self, message: str) -> str:
        """
        Send a user message and receive a model response.

        Automatically:
        - Appends message to history
        - Calls Gemini API with full history + system prompt
        - Extracts [CTM] blocks from response and writes to KG (best-effort)
        - Updates token counters
        - Triggers history compaction if context is near limit

        Args:
            message: User message text.

        Returns:
            Model response text.

        Raises:
            Whatever the API call raises; the user turn is rolled back first.
        """
        # Append user turn to history
        self.history.append({"role": "user", "parts": [{"text": message}]})

        # Build Contents list for API call (history already includes new message)
        contents = self._build_contents()

        try:
            response = await asyncio.to_thread(self._call_api, contents=contents)
        except asyncio.CancelledError:
            # CancelledError is a BaseException (3.8+), so it bypasses the
            # handler below; still roll back so history never ends on an
            # unanswered user turn.
            self.history.pop()
            raise
        except Exception as exc:
            logger.error("GeminiAgent[%s] API error: %s", self.name, exc)
            # Remove the user turn we just added so history stays consistent
            self.history.pop()
            raise

        response_text = self._extract_text(response)

        # Append model turn
        self.history.append({"role": "model", "parts": [{"text": response_text}]})

        self.turn_count += 1
        self._update_token_counts(response)

        # Auto-CTM is best-effort: the reply is already committed to history,
        # so a KG failure must not discard it or skip compaction below.
        try:
            ctm_results = process_ctm_blocks(response_text, agent_name=self.name)
        except Exception:
            logger.exception("GeminiAgent[%s] CTM processing failed", self.name)
            ctm_results = None
        if ctm_results:
            self.ctm_buffer.extend(ctm_results)
            logger.info(
                "GeminiAgent[%s] CTM: %d blocks extracted and saved",
                self.name,
                len(ctm_results),
            )

        # Compact if approaching context limit
        if self.estimated_context_tokens >= COMPACTION_TOKEN_THRESHOLD:
            self._compact_history()

        return response_text

    def _call_api(self, contents: list[dict]) -> Any:
        """
        Synchronous Gemini API call (run inside asyncio.to_thread).

        Args:
            contents: List of Content dicts for the conversation.

        Returns:
            GenerateContentResponse object.
        """
        config = genai_types.GenerateContentConfig(
            system_instruction=self.system_prompt,
        )
        return self.client.models.generate_content(
            model=self.model,
            contents=contents,
            config=config,
        )

    def _build_contents(self) -> list[dict]:
        """
        Build the contents list from current history.

        The history already includes the latest user message. A shallow copy
        is returned so later appends to self.history do not affect a request
        that is already in flight.
        """
        return list(self.history)

    def _extract_text(self, response: Any) -> str:
        """
        Extract text from a GenerateContentResponse.

        Prefers ``response.text``. If that is missing or None (e.g. no text
        parts), joins the non-thought text parts of the first candidate.
        As a last resort returns ``str(response)``, so the result is always
        a str.
        """
        try:
            text = response.text
        except (AttributeError, ValueError):
            text = None
        if isinstance(text, str):
            return text

        # Fallback: drill into candidates
        try:
            parts = response.candidates[0].content.parts
        except (AttributeError, IndexError, TypeError, KeyError):
            parts = None
        if parts:
            texts = [
                part.text
                for part in parts
                if isinstance(getattr(part, "text", None), str)
                and not getattr(part, "thought", False)
            ]
            if texts:
                return "".join(texts)
        return str(response)

    @staticmethod
    def _estimate_history_tokens(history: list[dict]) -> int:
        """Roughly estimate tokens in *history* at ~4 characters per token."""
        total_chars = 0
        for message in history:
            for part in message.get("parts") or []:
                text = part.get("text") if isinstance(part, dict) else None
                if isinstance(text, str):
                    total_chars += len(text)
        return total_chars // 4

    def _update_token_counts(self, response: Any) -> None:
        """
        Update token counters from response usage_metadata.

        When usage data is present, the context estimate is prompt tokens
        plus completion tokens: the reply has just been appended to history,
        so it will be part of the next request. When usage data is missing
        or None (the attribute may exist but be None), fall back to a
        character-based estimate of the current history.
        """
        usage = getattr(response, "usage_metadata", None)
        prompt_tokens = getattr(usage, "prompt_token_count", None) or 0
        completion_tokens = getattr(usage, "candidates_token_count", None) or 0

        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens

        if prompt_tokens:
            self.estimated_context_tokens = prompt_tokens + completion_tokens
        else:
            self.estimated_context_tokens = self._estimate_history_tokens(self.history)

    # ─── History Compaction ───────────────────────────────────────────────────

    def _compact_history(self) -> None:
        """
        Compact history when approaching context limit.

        Keeps the most recent COMPACTION_KEEP_TURNS conversation turns
        (each turn = 1 user msg + 1 model msg = 2 entries) and prepends a
        synthetic user/model pair noting the compaction. Afterwards the
        context estimate is recomputed from the remaining text rather than
        zeroed. If there is nothing older than the keep window, a warning is
        logged and history is left unchanged.
        """
        keep_messages = COMPACTION_KEEP_TURNS * 2  # user + model per turn
        if len(self.history) <= keep_messages:
            logger.warning(
                "GeminiAgent[%s] compaction skipped — only %d history messages "
                "(keep window %d); context estimate remains %d tokens",
                self.name,
                len(self.history),
                keep_messages,
                self.estimated_context_tokens,
            )
            return

        kept = self.history[-keep_messages:]
        # The synthetic pair below ends on a model turn, so the retained slice
        # must start on a user turn to keep roles alternating.
        while kept and kept[0].get("role") != "user":
            kept = kept[1:]
        dropped = len(self.history) - len(kept)

        # turn_count counts completed turns, so the last N turns are numbered
        # turn_count - N + 1 … turn_count.
        first_kept_turn = max(1, self.turn_count - COMPACTION_KEEP_TURNS + 1)
        summary_msg = (
            f"[CONTEXT COMPACTED: {dropped} earlier messages removed. "
            f"Continuing from turn {first_kept_turn}+]"
        )
        self.history = [
            {"role": "user", "parts": [{"text": summary_msg}]},
            {
                "role": "model",
                "parts": [{"text": "Understood. Resuming from compacted context."}],
            },
            *kept,
        ]

        self.estimated_context_tokens = self._estimate_history_tokens(self.history)
        logger.info(
            "GeminiAgent[%s] history compacted — dropped %d messages, kept %d",
            self.name,
            dropped,
            len(self.history),
        )

    # ─── CTM ──────────────────────────────────────────────────────────────────

    def ctm(self, content: str, category: str = "entity") -> dict:
        """
        Explicitly commit content to the Knowledge Graph.

        Args:
            content: Insight or fact to persist.
            category: "entity" or "axiom".

        Returns:
            Write result dict from ctm_to_kg() (also appended to ctm_buffer).
        """
        result = ctm_to_kg(
            agent_name=self.name,
            content=content,
            category=category,
        )
        self.ctm_buffer.append(result)
        return result

    # ─── Session Persistence ──────────────────────────────────────────────────

    def save_session(self) -> bool:
        """
        Persist session state to memory_file (JSONL format).

        The file contains a single JSON object (last line wins for loading).
        It is written to a sibling ".tmp" file and atomically swapped in with
        os.replace, so a crash mid-write cannot truncate an existing session.

        Returns:
            True on success, False on error.
        """
        tmp_file: Optional[Path] = None
        try:
            state = {
                "session_id": self.session_id,
                "name": self.name,
                "model": self.model,
                "created_at": self.created_at.isoformat(),
                "saved_at": datetime.now(timezone.utc).isoformat(),
                "turn_count": self.turn_count,
                "history": self.history,
                "total_prompt_tokens": self.total_prompt_tokens,
                "total_completion_tokens": self.total_completion_tokens,
                "estimated_context_tokens": self.estimated_context_tokens,
            }
            self.memory_file.parent.mkdir(parents=True, exist_ok=True)
            tmp_file = self.memory_file.with_name(self.memory_file.name + ".tmp")
            with open(tmp_file, "w", encoding="utf-8") as fh:
                fh.write(json.dumps(state) + "\n")
            os.replace(tmp_file, self.memory_file)
            logger.debug("GeminiAgent[%s] session saved to %s", self.name, self.memory_file)
            return True
        except Exception as exc:
            logger.error("GeminiAgent[%s] save_session error: %s", self.name, exc)
            if tmp_file is not None:
                try:
                    tmp_file.unlink(missing_ok=True)
                except OSError:
                    # Best-effort cleanup; the original error is already logged.
                    pass
            return False

    def load_session(self) -> bool:
        """
        Resume from last saved session state.

        Reads the most recent JSON object in memory_file and restores
        session_id, created_at, history, turn_count, and token counters.

        Returns:
            True if session was loaded; False if the file does not exist,
            cannot be read, or does not contain a valid session object.
        """
        if not self.memory_file.exists():
            logger.info(
                "GeminiAgent[%s] no saved session at %s — starting fresh",
                self.name,
                self.memory_file,
            )
            return False

        # Read last non-empty line (append mode could have multiple; overwrite has one)
        last_line = ""
        try:
            with open(self.memory_file, "r", encoding="utf-8") as fh:
                for raw_line in fh:
                    stripped = raw_line.strip()
                    if stripped:
                        last_line = stripped
        except (OSError, UnicodeDecodeError) as exc:
            logger.warning(
                "GeminiAgent[%s] load_session read error: %s — starting fresh",
                self.name,
                exc,
            )
            return False

        if not last_line:
            return False

        try:
            state = json.loads(last_line)
        except json.JSONDecodeError as exc:
            logger.warning(
                "GeminiAgent[%s] load_session parse error: %s — starting fresh",
                self.name,
                exc,
            )
            return False

        # Valid JSON is not enough: it must be an object with a list history.
        if not isinstance(state, dict) or not isinstance(state.get("history", []), list):
            logger.warning(
                "GeminiAgent[%s] load_session: unexpected session format — starting fresh",
                self.name,
            )
            return False

        self.session_id = state.get("session_id", self.session_id)
        created_raw = state.get("created_at")
        if isinstance(created_raw, str):
            try:
                self.created_at = datetime.fromisoformat(created_raw)
            except ValueError:
                pass  # keep the fresh timestamp if the stored one is malformed
        self.turn_count = state.get("turn_count", 0)
        self.history = state.get("history", [])
        self.total_prompt_tokens = state.get("total_prompt_tokens", 0)
        self.total_completion_tokens = state.get("total_completion_tokens", 0)
        self.estimated_context_tokens = state.get("estimated_context_tokens", 0)

        logger.info(
            "GeminiAgent[%s] session loaded — %d turns, %d history messages",
            self.name,
            self.turn_count,
            len(self.history),
        )
        return True

    # ─── Stats ────────────────────────────────────────────────────────────────

    def get_context_usage(self) -> dict:
        """
        Return token usage stats for this agent.

        Returns:
            dict with token counts and context utilisation percentage
            (relative to MAX_CONTEXT_TOKENS).
        """
        max_tokens = self.MAX_CONTEXT_TOKENS
        utilisation_pct = (
            (self.estimated_context_tokens / max_tokens) * 100
            if max_tokens > 0
            else 0.0
        )
        return {
            "agent": self.name,
            "model": self.model,
            "session_id": self.session_id,
            "turn_count": self.turn_count,
            "history_messages": len(self.history),
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "estimated_context_tokens": self.estimated_context_tokens,
            "max_context_tokens": max_tokens,
            "utilisation_pct": round(utilisation_pct, 2),
        }

    def __repr__(self) -> str:
        return (
            f"GeminiAgent(name={self.name!r}, model={self.model!r}, "
            f"turns={self.turn_count}, history={len(self.history)})"
        )
