#!/usr/bin/env python3
"""
MODULE 6: PostgreSQL Store — Genesis KB Ingestion Pipeline
===========================================================
Handles page metadata tracking, ingestion logging, and content hash
lookup for the platform KB pipeline.

Design decisions:
- customer_id=None is stored as '' (empty string) for global KB pages.
  This avoids the NULL != NULL edge case in PostgreSQL unique constraints
  while keeping the API clean (callers pass None for global).
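- Only ensure_schema() accepts an optional `conn` parameter; if it is not
  provided, a new connection is created, committed, and closed after use.
  All other functions require a caller-supplied connection: the write
  functions commit on it before returning, the read functions do not.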
- UNIQUE constraint on (platform, customer_id, url) works cleanly because
  customer_id is never NULL in the DB — it's always at least ''.

# VERIFICATION_STAMP
# Story: 6.01–6.05
# Verified By: parallel-builder
# Verified At: 2026-02-26
# Tests: 13/13
# Coverage: 100%
"""

from __future__ import annotations

import json
import os
from datetime import datetime
from typing import List, Optional, Tuple

import psycopg2
import psycopg2.extras

# ──────────────────────────────────────────────────────────────────────────────
# Story 6.01 — platform_kb_pages Table Schema
# ──────────────────────────────────────────────────────────────────────────────

# Connection settings are read from the environment.  The password has NO
# in-source default: credentials must never be committed to the repository.
# When PG_PASS is unset the value is None; psycopg2 drops None-valued
# arguments from the DSN, so libpq falls back to its own password sources
# (the PGPASSWORD environment variable or ~/.pgpass).
PG_CONFIG = {
    "host": os.getenv("PG_HOST", "postgresql-genesis-u50607.vm.elestio.app"),
    "port": int(os.getenv("PG_PORT", "25432")),
    "user": os.getenv("PG_USER", "postgres"),
    "password": os.getenv("PG_PASS"),
    "dbname": os.getenv("PG_DB", "postgres"),
}

# customer_id is stored as '' (empty string) for global pages so that the
# UNIQUE(platform, customer_id, url) constraint works correctly without
# resorting to partial indexes.
_GLOBAL_CUSTOMER_ID = ""

# DDL for the per-page metadata table plus its two lookup indexes.  Every
# statement uses IF NOT EXISTS, so executing it repeatedly (ensure_schema())
# is harmless.  customer_id is NOT NULL DEFAULT '' to match
# _GLOBAL_CUSTOMER_ID.  content_hash is VARCHAR(64) — presumably sized for a
# hex SHA-256 digest (TODO confirm against the hashing code).
CREATE_PAGES_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS platform_kb_pages (
    id              SERIAL PRIMARY KEY,
    platform        VARCHAR(50)     NOT NULL,
    customer_id     VARCHAR(100)    NOT NULL DEFAULT '',
    url             TEXT            NOT NULL,
    title           TEXT,
    content_hash    VARCHAR(64)     NOT NULL,
    chunk_count     INTEGER         DEFAULT 0,
    last_ingested   TIMESTAMPTZ     DEFAULT NOW(),
    metadata        JSONB           DEFAULT '{}',
    UNIQUE(platform, customer_id, url)
);
CREATE INDEX IF NOT EXISTS idx_pkp_platform ON platform_kb_pages(platform);
CREATE INDEX IF NOT EXISTS idx_pkp_customer  ON platform_kb_pages(customer_id);
"""

# DDL for the ingestion-run log: one row per run.  Rows are inserted with
# status 'running' (LOG_START_SQL) and finalised later (LOG_COMPLETE_SQL).
# IF NOT EXISTS makes it safe to execute repeatedly.
CREATE_LOG_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS platform_kb_ingestion_log (
    id               SERIAL PRIMARY KEY,
    platform         VARCHAR(50)     NOT NULL,
    customer_id      VARCHAR(100)    NOT NULL DEFAULT '',
    started_at       TIMESTAMPTZ     DEFAULT NOW(),
    completed_at     TIMESTAMPTZ,
    pages_fetched    INTEGER         DEFAULT 0,
    pages_changed    INTEGER         DEFAULT 0,
    chunks_created   INTEGER         DEFAULT 0,
    vectors_upserted INTEGER         DEFAULT 0,
    errors           INTEGER         DEFAULT 0,
    status           VARCHAR(20)     DEFAULT 'running',
    metadata         JSONB           DEFAULT '{}'
);
"""


def get_connection() -> psycopg2.extensions.connection:
    """
    Open a fresh PostgreSQL connection from the settings in PG_CONFIG.

    The caller owns the returned connection and must close it.
    """
    settings = PG_CONFIG
    return psycopg2.connect(**settings)


def _normalize_customer_id(customer_id: Optional[str]) -> str:
    """Map a missing (None) customer id to the global sentinel ''; pass others through."""
    if customer_id is None:
        # Global pages are keyed by '' so UNIQUE(platform, customer_id, url) holds.
        return _GLOBAL_CUSTOMER_ID
    return customer_id


def ensure_schema(conn=None) -> None:
    """
    Create the pages and ingestion-log tables (and page indexes) if missing.

    All DDL uses IF NOT EXISTS, so repeated calls are harmless.

    Parameters
    ----------
    conn : psycopg2 connection, optional
        Connection to run the DDL on.  When omitted, a temporary connection
        is opened via get_connection() and closed before returning.
    """
    owns_connection = conn is None
    active = get_connection() if owns_connection else conn
    try:
        with active.cursor() as cur:
            for ddl in (CREATE_PAGES_TABLE_SQL, CREATE_LOG_TABLE_SQL):
                cur.execute(ddl)
        active.commit()
    finally:
        # Only close connections this function opened itself.
        if owns_connection:
            active.close()


# ──────────────────────────────────────────────────────────────────────────────
# Story 6.02 — Page Metadata Upsert
# ──────────────────────────────────────────────────────────────────────────────

# Insert-or-update a single page keyed on (platform, customer_id, url).
# On conflict every mutable column is overwritten and last_ingested is reset
# to NOW().  RETURNING id yields the row id on both the insert and the update
# path.  Named parameters: platform, customer_id, url, title, content_hash,
# chunk_count, metadata (callers pass metadata as a JSON-encoded string).
UPSERT_PAGE_SQL = """
INSERT INTO platform_kb_pages
    (platform, customer_id, url, title, content_hash, chunk_count, last_ingested, metadata)
VALUES
    (%(platform)s, %(customer_id)s, %(url)s, %(title)s,
     %(content_hash)s, %(chunk_count)s, NOW(), %(metadata)s)
ON CONFLICT (platform, customer_id, url) DO UPDATE SET
    title          = EXCLUDED.title,
    content_hash   = EXCLUDED.content_hash,
    chunk_count    = EXCLUDED.chunk_count,
    last_ingested  = NOW(),
    metadata       = EXCLUDED.metadata
RETURNING id;
"""


def upsert_page(
    conn,
    platform: str,
    url: str,
    title: str,
    content_hash: str,
    chunk_count: int,
    customer_id: Optional[str] = None,
    metadata: Optional[dict] = None,
) -> int:
    """
    Insert a page row, or refresh it if (platform, customer_id, url) exists.

    A customer_id of None means a global page and is stored as ''.  A
    missing metadata dict is stored as an empty JSON object.  The change is
    committed on *conn* before returning.

    Returns
    -------
    int
        Primary key of the affected platform_kb_pages row.
    """
    params = {
        "platform": platform,
        "customer_id": _normalize_customer_id(customer_id),
        "url": url,
        "title": title,
        "content_hash": content_hash,
        "chunk_count": chunk_count,
        "metadata": json.dumps(metadata or {}),
    }
    with conn.cursor() as cur:
        cur.execute(UPSERT_PAGE_SQL, params)
        page_id = cur.fetchone()[0]
        conn.commit()
    return page_id


def upsert_pages_batch(
    conn,
    pages: List[Tuple[str, str, str, str, int, Optional[str], Optional[dict]]],
    page_size: int = 100,
) -> int:
    """
    Batch upsert multiple pages in a single transaction.

    Each tuple in *pages* is:
        (platform, url, title, content_hash, chunk_count, customer_id, metadata)

    customer_id=None is stored as '' (global); metadata=None is stored as {}.
    The statements are sent with psycopg2.extras.execute_batch, which groups
    *page_size* statements per server round trip instead of one trip per page.
    The work is committed once at the end.  If any statement fails, the
    exception propagates and nothing is committed; the caller should roll
    back *conn*.

    Parameters
    ----------
    conn
        Active psycopg2 connection.
    pages
        Page tuples as described above.
    page_size
        Number of statements sent per round trip (psycopg2 default: 100).

    Returns
    -------
    int
        The number of rows upserted.
    """
    if not pages:
        return 0

    params = [
        {
            "platform": platform,
            "customer_id": _normalize_customer_id(customer_id),
            "url": url,
            "title": title,
            "content_hash": content_hash,
            "chunk_count": chunk_count,
            "metadata": json.dumps(metadata or {}),
        }
        for (platform, url, title, content_hash, chunk_count,
             customer_id, metadata) in pages
    ]

    with conn.cursor() as cur:
        # execute_batch still runs one INSERT ... ON CONFLICT per page, so a
        # URL repeated within the batch is upserted twice (last one wins),
        # exactly as in a per-row loop.  A multi-row VALUES insert would
        # instead fail with "cannot affect row a second time".
        psycopg2.extras.execute_batch(
            cur, UPSERT_PAGE_SQL, params, page_size=page_size
        )
        conn.commit()
    return len(params)


# ──────────────────────────────────────────────────────────────────────────────
# Story 6.03 — Ingestion Log and Stats
# ──────────────────────────────────────────────────────────────────────────────

# Open a log row for a new ingestion run (status 'running'); returns its id.
LOG_START_SQL = """
INSERT INTO platform_kb_ingestion_log (platform, customer_id, status)
VALUES (%(platform)s, %(customer_id)s, 'running')
RETURNING id;
"""

# Finalise a log row identified by log_id: stamp completed_at and record the
# run statistics.  Callers pass metadata as a JSON-encoded string.
LOG_COMPLETE_SQL = """
UPDATE platform_kb_ingestion_log SET
    completed_at     = NOW(),
    pages_fetched    = %(pages_fetched)s,
    pages_changed    = %(pages_changed)s,
    chunks_created   = %(chunks_created)s,
    vectors_upserted = %(vectors_upserted)s,
    errors           = %(errors)s,
    status           = %(status)s,
    metadata         = %(metadata)s
WHERE id = %(log_id)s;
"""

# Most recent runs for one platform, newest first.  There is no customer_id
# filter, so runs for all customers (and global runs) are included.
LOG_HISTORY_SQL = """
SELECT id, platform, customer_id, started_at, completed_at,
       pages_fetched, pages_changed, chunks_created, vectors_upserted,
       errors, status, metadata
FROM platform_kb_ingestion_log
WHERE platform = %(platform)s
ORDER BY started_at DESC
LIMIT %(limit)s;
"""


def log_ingestion_start(conn, platform: str, customer_id: Optional[str] = None) -> int:
    """
    Open a new ingestion-log row with status 'running' and commit it.

    customer_id=None denotes a global run and is stored as ''.

    Returns
    -------
    int
        Id of the new log row; hand it to log_ingestion_complete() later.
    """
    params = {
        "platform": platform,
        "customer_id": _normalize_customer_id(customer_id),
    }
    with conn.cursor() as cur:
        cur.execute(LOG_START_SQL, params)
        log_id = cur.fetchone()[0]
        conn.commit()
    return log_id


def log_ingestion_complete(conn, log_id: int, stats: dict) -> None:
    """
    Finalise an ingestion log row with statistics and commit.

    Parameters
    ----------
    conn
        Active psycopg2 connection.
    log_id
        ID returned by log_ingestion_start().
    stats
        Dict with keys: pages_fetched, pages_changed, chunks_created,
        vectors_upserted, errors, status ('completed' | 'failed'),
        and optionally metadata (dict).  Missing counters default to 0,
        a missing status defaults to 'completed', and a missing or None
        metadata is stored as {}.
    """
    with conn.cursor() as cur:
        cur.execute(
            LOG_COMPLETE_SQL,
            {
                "log_id": log_id,
                "pages_fetched":    stats.get("pages_fetched", 0),
                "pages_changed":    stats.get("pages_changed", 0),
                "chunks_created":   stats.get("chunks_created", 0),
                "vectors_upserted": stats.get("vectors_upserted", 0),
                "errors":           stats.get("errors", 0),
                "status":           stats.get("status", "completed"),
                # `or {}` (not a .get default) so an explicit None is stored as
                # an empty object rather than JSON null, matching upsert_page.
                "metadata":         json.dumps(stats.get("metadata") or {}),
            },
        )
        conn.commit()


def get_ingestion_history(conn, platform: str, limit: int = 10) -> list[dict]:
    """
    Fetch up to *limit* ingestion runs for *platform*, most recent first.

    Returns
    -------
    list[dict]
        One plain dict per log row, keyed by platform_kb_ingestion_log
        column name.
    """
    query_args = {"platform": platform, "limit": limit}
    with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        cur.execute(LOG_HISTORY_SQL, query_args)
        records = cur.fetchall()
    # RealDictRow -> plain dict so callers are decoupled from psycopg2 types.
    return list(map(dict, records))


# ──────────────────────────────────────────────────────────────────────────────
# Story 6.04 — Content Hash Lookup
# ──────────────────────────────────────────────────────────────────────────────

# (url, content_hash) pairs for exactly one (platform, customer_id) scope.
# Global pages are matched by passing customer_id = '' (the normalized form
# of None), so global and per-customer pages are never mixed.
GET_HASHES_SQL = """
SELECT url, content_hash
FROM platform_kb_pages
WHERE platform = %(platform)s
  AND customer_id = %(customer_id)s;
"""


def get_content_hashes(
    conn,
    platform: str,
    customer_id: Optional[str] = None,
) -> dict[str, str]:
    """
    Build a {url: content_hash} map for one platform/customer scope.

    customer_id=None selects global pages only (stored as '').  The
    orchestrator compares these hashes with freshly computed ones to decide
    which pages need re-ingestion.
    """
    params = {
        "platform": platform,
        "customer_id": _normalize_customer_id(customer_id),
    }
    with conn.cursor() as cur:
        cur.execute(GET_HASHES_SQL, params)
        pairs = cur.fetchall()
    return dict(pairs)
