#!/usr/bin/env python3
"""
embed_repair.py — embed all bloodstream_knowledge rows where embedding_id IS NULL.
Uses gemini-embedding-001 (3072d) → Qdrant genesis_memories collection.

Usage:
    python3 core/embed_repair.py [--limit N] [--batch-size N]
"""
import os
import sys
import uuid
import time
import argparse
import logging
from pathlib import Path

# Root logging: INFO and above, each line prefixed with timestamp and level.
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)
log = logging.getLogger(name="embed_repair")

# Project root; the secrets file is looked up under <root>/config/secrets.env.
GENESIS_ROOT = Path("/mnt") / "e" / "genesis-system"
# Namespace for deterministic (uuid5) Qdrant point ids derived from the embedded text.
GENESIS_NS = uuid.UUID(hex="e2a6d7f8-b3c4-4d5e-8f90-a1b2c3d4e5f6")
# Target Qdrant collection.
COLLECTION = "genesis_memories"
# Embedding width of gemini-embedding-001 as described in the module docstring (3072d).
VECTOR_DIM = 3072


def load_api_key(secrets_path=None) -> str:
    """Return the Gemini API key.

    Lookup order:
      1. Environment variables GEMINI_API_KEY_NEW, GEMINI_API_KEY,
         GOOGLE_API_KEY -- the first non-empty, non-placeholder value wins.
      2. The KEY=VALUE secrets file (``secrets_path``, default
         ``GENESIS_ROOT/config/secrets.env``). GEMINI_API_KEY_NEW is preferred
         over GEMINI_API_KEY regardless of which line comes first in the file,
         matching the environment-variable priority.

    Values equal to the placeholder "YOUR_GEMINI_API_KEY" are ignored. Values
    in the file may be wrapped in single or double quotes and the line may
    carry a leading ``export``.

    Args:
        secrets_path: Optional path to the secrets file. Defaults to the
            project's config/secrets.env.

    Returns:
        The API key string.

    Raises:
        RuntimeError: if no usable key is found in the environment or the file.
    """
    placeholder = "YOUR_GEMINI_API_KEY"

    for var in ("GEMINI_API_KEY_NEW", "GEMINI_API_KEY", "GOOGLE_API_KEY"):
        value = os.environ.get(var, "")
        if value and value != placeholder:
            return value

    if secrets_path is None:
        secrets_path = GENESIS_ROOT / "config" / "secrets.env"
    secrets = Path(secrets_path)
    if secrets.exists():
        file_keys = ("GEMINI_API_KEY_NEW", "GEMINI_API_KEY")
        found = {}
        for raw in secrets.read_text(encoding="utf-8").splitlines():
            line = raw.strip()
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            name, sep, rest = line.partition("=")
            name = name.strip()
            # Keep the first usable occurrence of each key; ignore everything else.
            if not sep or name not in file_keys or name in found:
                continue
            val = rest.strip().strip('"').strip("'")
            if val and val != placeholder:
                found[name] = val
        # Resolve by key priority, not by the order lines appear in the file.
        for name in file_keys:
            if name in found:
                return found[name]
    raise RuntimeError("No Gemini API key found")


def get_pg_conn():
    """Open an autocommit psycopg2 connection to the Genesis PostgreSQL database.

    Each connection parameter can be overridden via an environment variable
    (GENESIS_PG_HOST, GENESIS_PG_PORT, GENESIS_PG_USER, GENESIS_PG_PASSWORD,
    GENESIS_PG_DBNAME). Unset variables fall back to the previously
    hard-coded values, so existing deployments keep working.

    Returns:
        A psycopg2 connection with ``autocommit`` enabled, so every statement
        executed on it is committed immediately.
    """
    import psycopg2

    conn = psycopg2.connect(
        host=os.environ.get("GENESIS_PG_HOST", "postgresql-genesis-u50607.vm.elestio.app"),
        port=int(os.environ.get("GENESIS_PG_PORT", "25432")),
        user=os.environ.get("GENESIS_PG_USER", "postgres"),
        # NOTE(review): a plaintext password is committed to source control.
        # Rotate it and remove this fallback once GENESIS_PG_PASSWORD is set
        # wherever this script runs.
        password=os.environ.get("GENESIS_PG_PASSWORD", "CiBjh6LM7Yuqkq-jo2r7eQDw"),
        dbname=os.environ.get("GENESIS_PG_DBNAME", "postgres"),
        sslmode="require",
    )
    conn.autocommit = True
    return conn


def get_qdrant():
    """Return a QdrantClient for the Genesis Qdrant instance (30 s request timeout).

    The URL and API key can be overridden with GENESIS_QDRANT_URL and
    GENESIS_QDRANT_API_KEY. Unset variables fall back to the previously
    hard-coded values.
    """
    from qdrant_client import QdrantClient

    return QdrantClient(
        url=os.environ.get("GENESIS_QDRANT_URL", "https://qdrant-b3knu-u50607.vm.elestio.app:6333"),
        # NOTE(review): API key committed to source control -- rotate it and drop
        # this fallback once GENESIS_QDRANT_API_KEY is provisioned.
        api_key=os.environ.get(
            "GENESIS_QDRANT_API_KEY",
            "7b74e6621bd0e6650789f6662bca4cbf4143d3d1d710a0002b3b563973ca6876",
        ),
        timeout=30,
    )


def _parse_args(argv=None):
    """Parse and validate the command-line options (--limit, --batch-size)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--limit", type=int, default=0, help="Max rows to embed (0=all)")
    parser.add_argument("--batch-size", type=int, default=50, help="Rows per Qdrant upsert batch")
    args = parser.parse_args(argv)
    if args.limit < 0:
        parser.error("--limit must be >= 0")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    return args


def _fetch_pending_rows(cur, limit):
    """Return (id, embedding_text, source, type, title) rows still lacking an embedding_id.

    ``limit`` of 0 means "no limit". It is passed as a bound parameter;
    PostgreSQL treats ``LIMIT NULL`` the same as omitting LIMIT.
    """
    cur.execute(
        """
        SELECT id, embedding_text, source, type, title
        FROM bloodstream_knowledge
        WHERE embedding_id IS NULL AND embedding_text IS NOT NULL AND embedding_text != ''
        ORDER BY id
        LIMIT %s
        """,
        (limit or None,),
    )
    return cur.fetchall()


def _flush_batch(qdrant, cur, batch):
    """Persist one buffered batch and return how many rows were stored.

    Args:
        qdrant: Qdrant client used for the upsert.
        cur: Postgres cursor used to record the embedding ids.
        batch: list of ``(db_id, point_id, PointStruct)`` tuples.

    The Qdrant upsert runs first, and Postgres is updated only if it
    succeeds. A failed upsert therefore leaves the rows with
    ``embedding_id IS NULL`` so the next run retries them. Errors are
    logged and reported as 0 stored rather than raised.
    """
    if not batch:
        return 0
    try:
        qdrant.upsert(collection_name=COLLECTION, points=[point for _, _, point in batch])
    except Exception as exc:
        log.warning("  Qdrant upsert of %d points failed: %s", len(batch), exc)
        return 0
    try:
        cur.executemany(
            "UPDATE bloodstream_knowledge SET embedding_id = %s WHERE id = %s",
            [(point_id, db_id) for db_id, point_id, _ in batch],
        )
    except Exception as exc:
        # The points are already in Qdrant, and their ids are deterministic
        # (uuid5 of the text), so a re-run just re-upserts the same points.
        # With autocommit, some rows of this batch may already be updated.
        log.warning("  Recording embedding ids for %d rows failed: %s", len(batch), exc)
        return 0
    return len(batch)


def main():
    """Embed every pending bloodstream_knowledge row and store it in Qdrant.

    Each selected row's embedding_text is embedded with gemini-embedding-001
    and checked against VECTOR_DIM. It is then wrapped in a Qdrant point whose
    id is a uuid5 of the text, and buffered. Every ``--batch-size`` points the
    buffer is upserted into the COLLECTION, and only after that succeeds is
    ``embedding_id`` written back to Postgres.
    """
    args = _parse_args()

    api_key = load_api_key()
    from google import genai
    from qdrant_client.models import PointStruct

    gclient = genai.Client(api_key=api_key)
    log.info("Gemini client ready (gemini-embedding-001, 3072d)")

    pg = get_pg_conn()
    try:
        qdrant = get_qdrant()
        log.info("PG + Qdrant connected")

        with pg.cursor() as cur:
            rows = _fetch_pending_rows(cur, args.limit)
            total = len(rows)
            log.info("Found %d rows with embedding_id IS NULL", total)
            if not rows:
                log.info("Nothing to embed. Done.")
                return

            embedded = 0
            failed = 0
            batch = []

            for db_id, emb_text, source, etype, title in rows:
                try:
                    result = gclient.models.embed_content(model="gemini-embedding-001", contents=emb_text)
                    vector = list(result.embeddings[0].values)
                    if len(vector) != VECTOR_DIM:
                        raise ValueError(f"expected {VECTOR_DIM}-d embedding, got {len(vector)}-d")
                except Exception as exc:
                    failed += 1
                    log.warning("  Failed row %s (%s): %s", db_id, title[:40] if title else "?", exc)
                    time.sleep(2)  # back off in case the failure was a rate limit
                    continue

                # NOTE(review): the id depends only on the text, so rows with identical
                # embedding_text map to the same Qdrant point. A later upsert overwrites
                # the earlier one, and the payload's db_id reflects just one row.
                point_id = str(uuid.uuid5(GENESIS_NS, emb_text))
                batch.append((db_id, point_id, PointStruct(
                    id=point_id,
                    vector=vector,
                    payload={"source": source or "", "type": etype or "", "title": title or "", "db_id": db_id},
                )))

                if len(batch) >= args.batch_size:
                    stored = _flush_batch(qdrant, cur, batch)
                    embedded += stored
                    failed += len(batch) - stored
                    batch = []
                    log.info("  Embedded %d/%d (%.1f%%)", embedded, total, 100 * embedded / total)

            # Flush the final partial batch (no-op when empty).
            stored = _flush_batch(qdrant, cur, batch)
            embedded += stored
            failed += len(batch) - stored

        log.info("Done. Embedded: %d, Failed: %d, Total: %d", embedded, failed, total)
        info = qdrant.get_collection(COLLECTION)
        # %s rather than %d: points_count may not be a plain int in every client version.
        log.info("Qdrant genesis_memories: %s total points", info.points_count)
    finally:
        pg.close()


if __name__ == "__main__":
    main()
