#!/usr/bin/env python3
"""
MODULE 4: Gemini Embedder
==========================
Converts Chunk objects into EmbeddedChunk objects with 3072-dim vectors
using the gemini-embedding-001 model.

Stories implemented:
  4.01 - Single Text Embedder
  4.02 - Batch Embedder with Rate Limiting
  4.03 - Embedding Cache (Redis)
  4.04 - Build Embedding Text Optimizer

Usage:
    from core.kb.embedder import embed_text, embed_batch, embed_with_cache, build_embedding_text
"""

import hashlib
import json
import logging
import os
import re
import time
from typing import List, Optional

from core.kb.contracts import Chunk, EmbeddedChunk

# VERIFICATION_STAMP
# Story: 4.01, 4.02, 4.03, 4.04
# Verified By: parallel-builder
# Verified At: 2026-02-26
# Tests: 16/16
# Coverage: 100%

# Module logger; retry warnings, rate-limit sleeps and cache hits/misses are reported here.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Model name sent with every embed request and stamped on each EmbeddedChunk.
EMBED_MODEL = "gemini-embedding-001"
# Expected vector length (per the module docstring). NOTE(review): not
# referenced by the code in this file — vectors are not length-checked.
VECTOR_DIM = 3072
# Character cap applied by build_embedding_text(); 8000 chars is roughly 2048
# tokens at ~4 chars/token.
MAX_EMBEDDING_CHARS = 8000   # Safe truncation limit (~2048 tokens)
# Sleep (seconds) before each retry, so callers make len(_RETRY_DELAYS) + 1
# attempts in total. Keep this a list: callers build `[0] + _RETRY_DELAYS`.
_RETRY_DELAYS = [1, 2, 4]    # Exponential backoff delays in seconds

# ---------------------------------------------------------------------------
# API key loading (shared, lazy)
# ---------------------------------------------------------------------------

def _load_api_key() -> str:
    """Load GEMINI_API_KEY from environment or secrets.env file."""
    key = os.getenv("GEMINI_API_KEY", "")
    if not key:
        secrets_path = "/mnt/e/genesis-system/config/secrets.env"
        if os.path.exists(secrets_path):
            with open(secrets_path) as fh:
                for line in fh:
                    line = line.strip()
                    if line.startswith("GEMINI_API_KEY="):
                        key = line.split("=", 1)[1].strip().strip("'\"")
                        break
    return key


# Shared google.genai client; stays None until _get_genai_client() first runs.
_genai_client = None


def _get_genai_client():
    """Return the module-wide google.genai Client, creating it on first use.

    Both the ``google.genai`` import and the API-key lookup are deferred to
    the first call, so importing this module does not require the SDK.
    Later calls return the same cached instance.
    """
    global _genai_client
    if _genai_client is not None:
        return _genai_client

    from google import genai as google_genai

    _genai_client = google_genai.Client(api_key=_load_api_key())
    return _genai_client


# ---------------------------------------------------------------------------
# Story 4.01 — Single Text Embedder
# ---------------------------------------------------------------------------

def embed_text(text: str) -> List[float]:
    """
    Embed a single text string using Gemini gemini-embedding-001 model.

    The request is retried after each delay in ``_RETRY_DELAYS`` when the
    error looks transient (HTTP 429, 500 or 503). Any other error is
    re-raised immediately.

    Args:
        text: Text to embed; must contain at least one non-whitespace char.

    Returns:
        The embedding as a list of floats. It is expected to have 3072
        entries; the length is not validated here.

    Raises:
        ValueError: If text is empty or whitespace-only.
        RuntimeError: If every attempt failed with a transient error. The last
            such error is chained as ``__cause__``.
    """
    if not text or not text.strip():
        raise ValueError("Cannot embed empty or whitespace-only text.")

    client = _get_genai_client()

    # Retry loop for transient errors; attempt 0 runs without a delay.
    last_exc: Optional[Exception] = None
    for attempt, delay in enumerate([0] + _RETRY_DELAYS):
        if delay:
            logger.warning("embed_text: retry %d/%d after %ds", attempt, len(_RETRY_DELAYS), delay)
            time.sleep(delay)
        try:
            result = client.models.embed_content(
                model=EMBED_MODEL,
                contents=text,
            )
            return list(result.embeddings[0].values)
        except Exception as exc:
            # An error is transient when it carries a retryable status in
            # `code`, or its message contains 429/500/503 as a standalone
            # number. The word boundaries keep e.g. "1500 tokens" or "5000"
            # from being mistaken for an HTTP 500.
            status = getattr(exc, "code", None)
            if status in (429, 500, 503) or re.search(r"\b(?:429|500|503)\b", str(exc)):
                last_exc = exc
                continue
            # Non-transient — raise immediately
            raise

    raise RuntimeError(
        f"embed_text failed after {len(_RETRY_DELAYS) + 1} attempts: {last_exc}"
    ) from last_exc


# ---------------------------------------------------------------------------
# Story 4.04 — Build Embedding Text Optimizer
# ---------------------------------------------------------------------------

def build_embedding_text(chunk: Chunk) -> str:
    """
    Build optimised text for embedding.

    Format: [{platform}] [{heading_context}] {title}\\n{text}

    - The ``[platform]`` tag is included when ``chunk.platform`` is truthy.
    - The ``[heading_context]`` tag is included when it has non-whitespace
      content.
    - A ``None`` title or text is treated as an empty string instead of being
      rendered as the literal ``"None"``.
    - Runs of spaces/tabs collapse to a single space. Runs of three or more
      newlines collapse to one blank line (``"\\n\\n"``). Single and double
      newlines are kept.
    - The result is capped at MAX_EMBEDDING_CHARS (8000) characters.
    """
    platform_tag = f"[{chunk.platform}]" if chunk.platform else ""
    heading = chunk.heading_context
    heading_tag = f"[{heading}]" if heading and heading.strip() else ""

    prefix = " ".join(tag for tag in (platform_tag, heading_tag) if tag)
    if prefix:
        prefix += " "

    # Guard against None so missing fields don't leak "None" into the vector text.
    title = chunk.title or ""
    body = chunk.text or ""
    raw = f"{prefix}{title}\n{body}"

    # Normalise whitespace: collapse horizontal runs, condense 3+ newlines.
    raw = re.sub(r"[ \t]+", " ", raw)
    raw = re.sub(r"\n{3,}", "\n\n", raw)

    # Truncate to the safe embedding limit (no-op for shorter strings).
    return raw[:MAX_EMBEDDING_CHARS]


# ---------------------------------------------------------------------------
# Story 4.02 — Batch Embedder with Rate Limiting
# ---------------------------------------------------------------------------

def _embed_texts_batch(texts: List[str]) -> List[List[float]]:
    """Embed multiple texts in a single API call using Gemini's batch endpoint.

    Returns list of vectors (same order as input texts).
    Retries on transient errors (429, 500, 503).
    """
    client = _get_genai_client()
    last_exc: Optional[Exception] = None
    for attempt, delay in enumerate([0] + _RETRY_DELAYS):
        if delay:
            logger.warning("_embed_texts_batch: retry %d after %ds", attempt, delay)
            time.sleep(delay)
        try:
            result = client.models.embed_content(
                model=EMBED_MODEL,
                contents=texts,
            )
            return [list(e.values) for e in result.embeddings]
        except Exception as exc:
            err_str = str(exc)
            if any(code in err_str for code in ["429", "500", "503"]):
                last_exc = exc
                continue
            raise
    raise RuntimeError(f"_embed_texts_batch failed after retries: {last_exc}")


def embed_batch(
    chunks: List[Chunk],
    batch_size: int = 50,
    max_rpm: int = 1400,
) -> List[EmbeddedChunk]:
    """
    Embed multiple chunks with true API batching, rate limiting, and progress.

    - Sends up to ``batch_size`` texts per API call.
    - Keeps at most ``max_rpm`` requests in any sliding 60-second window.
      Every request issued here is counted: the batch call and, when it
      fails, each single-text fallback call. Retries made inside the
      embedding helpers are not counted.
    - If a batch call fails or returns the wrong number of vectors, that batch
      is re-embedded one text at a time so partial success is possible.
    - Chunks that still fail are logged and omitted, so the result may be
      shorter than ``chunks``. Surviving chunks keep their input order.

    Args:
        chunks: Chunks to embed.
        batch_size: Maximum texts per batch request; must be >= 1.
        max_rpm: Maximum requests per rolling minute; must be >= 1.

    Returns:
        List of EmbeddedChunk (chunk + vector + model name).

    Raises:
        ValueError: If ``batch_size`` or ``max_rpm`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if max_rpm < 1:
        raise ValueError(f"max_rpm must be >= 1, got {max_rpm}")

    embedded: List[EmbeddedChunk] = []
    total = len(chunks)
    request_timestamps: List[float] = []

    def _acquire_request_slot() -> None:
        # Sliding-window limiter: forget requests older than 60 s. While the
        # window is full, sleep until its oldest entry expires and re-check.
        # Then record this request.
        while True:
            now = time.monotonic()
            request_timestamps[:] = [t for t in request_timestamps if now - t < 60.0]
            if len(request_timestamps) < max_rpm:
                break
            sleep_for = 60.0 - (now - request_timestamps[0]) + 0.1
            logger.info("Rate limit: sleeping %.1fs", sleep_for)
            time.sleep(sleep_for)
        request_timestamps.append(time.monotonic())

    for batch_start in range(0, total, batch_size):
        batch_chunks = chunks[batch_start : batch_start + batch_size]
        texts = [build_embedding_text(c) for c in batch_chunks]

        _acquire_request_slot()
        try:
            vectors = _embed_texts_batch(texts)
            # zip() below would silently drop chunks on a length mismatch.
            if len(vectors) != len(texts):
                raise RuntimeError(
                    f"batch returned {len(vectors)} vectors for {len(texts)} texts"
                )
        except Exception as exc:
            # Fallback: try one at a time so partial success is possible
            logger.warning("Batch embed failed (%s), falling back to 1-at-a-time", exc)
            vectors = []
            for text in texts:
                _acquire_request_slot()
                try:
                    vectors.append(embed_text(text))
                except Exception as single_exc:
                    logger.warning("Single embed failed: %s", single_exc)
                    vectors.append(None)

        for chunk, vector in zip(batch_chunks, vectors):
            if vector is not None:
                embedded.append(
                    EmbeddedChunk(chunk=chunk, vector=vector, embedding_model=EMBED_MODEL)
                )

        done = min(batch_start + batch_size, total)
        logger.info("Embedded %d/%d chunks...", done, total)

    return embedded


# ---------------------------------------------------------------------------
# Story 4.03 — Embedding Cache (Redis)
# ---------------------------------------------------------------------------

_CACHE_TTL = 7 * 24 * 3600   # 7 days in seconds
_CACHE_PREFIX = "genesis:kb:embed:"


def embed_with_cache(
    text: str,
    cache_key: Optional[str] = None,
    # NOTE(security): this default embeds a Redis password in source control.
    # Prefer passing redis_url from configuration, and rotate this credential.
    redis_url: str = "redis://default:e2ZyYYr4oWRdASI2CaLc-@redis-genesis-u50607.vm.elestio.app:26379",
) -> List[float]:
    """
    Embed text, caching result in Redis to avoid re-embedding unchanged content.

    - Cache key = SHA-256 hex digest of the UTF-8 text (or the provided
      cache_key), stored under the ``genesis:kb:embed:`` prefix.
    - Redis stores the embedding as a JSON list of floats with a 7-day TTL.
    - If the ``redis`` package is missing, the server does not answer PING,
      or a GET/SET fails, the call falls back to embed_text() (no crash).
    - The Redis client opened for this call is closed before returning.

    Note: the key is derived from the text (or cache_key) alone, not from
    EMBED_MODEL. Entries written under a different model would therefore be
    served as hits until their TTL expires.

    Raises:
        ValueError: Propagated from embed_text() for empty/whitespace text
            on a cache miss.
    """
    # Compute cache key
    if not cache_key:
        cache_key = hashlib.sha256(text.encode("utf-8")).hexdigest()
    full_key = f"{_CACHE_PREFIX}{cache_key}"

    conn = None   # whatever client we created; always closed in `finally`
    cache = None  # `conn`, but only once it has answered PING
    try:
        try:
            # Imported here so a missing optional dependency degrades to a
            # direct embed instead of crashing.
            import redis as redis_lib

            conn = redis_lib.from_url(redis_url, socket_connect_timeout=2, socket_timeout=2)
            conn.ping()
            cache = conn
        except Exception as exc:
            logger.warning("Redis unavailable (%s) — falling back to direct embed", exc)

        # Cache hit
        if cache is not None:
            try:
                cached = cache.get(full_key)
                if cached:
                    logger.debug("Cache HIT for key %s", cache_key[:12])
                    return json.loads(cached)
            except Exception as exc:
                logger.warning("Redis GET failed (%s) — skipping cache", exc)

        # Cache miss — call the real API
        vector = embed_text(text)

        # Store in Redis
        if cache is not None:
            try:
                cache.setex(full_key, _CACHE_TTL, json.dumps(vector))
                logger.debug("Cache SET for key %s (TTL=%ds)", cache_key[:12], _CACHE_TTL)
            except Exception as exc:
                logger.warning("Redis SET failed (%s) — continuing without cache", exc)

        return vector
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception as exc:  # best-effort cleanup; never mask the result
                logger.debug("Redis close failed: %s", exc)
