#!/usr/bin/env python3
"""
KB Ingestion Orchestrator — MODULE 8
======================================
Wires M1-M7 together into a complete pipeline:
  fetch → extract → chunk → embed → store

Stories implemented:
  8.01 — ingest_platform() — Full 13-step platform pipeline
  8.02 — ingest_url()      — Single URL ingestion
  8.03 — CLI Interface (argparse, subcommands: ingest / ingest-url / status / list)
  8.04 — Progress Reporting (_report_progress)
  8.05 — Error Recovery (page-level and chunk-level failures don't crash the pipeline)

Usage (CLI):
    python3 -m core.kb.orchestrator ingest telnyx
    python3 -m core.kb.orchestrator ingest hubspot --customer-id cust_001 --max-pages 50
    python3 -m core.kb.orchestrator ingest-url https://docs.example.com/page --platform hubspot
    python3 -m core.kb.orchestrator status telnyx
    python3 -m core.kb.orchestrator list

Progress is written to stderr; JSON stats are written to stdout.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import sys
import time
from typing import Any, Callable, Dict, List, Optional

from core.kb.contracts import (
    Chunk,
    EmbeddedChunk,
    ExtractedContent,
    FetchedPage,
    PlatformConfig,
)
from core.kb.platform_registry import get_platform, list_platforms
from core.kb.fetcher import (
    compute_content_hash,
    fetch_pages,
    fetch_sitemap,
    filter_unchanged,
    filter_urls,
)
from core.kb.extractor import extract_batch
from core.kb.chunker import chunk_batch, tag_chunks
from core.kb.embedder import embed_batch, build_embedding_text
from core.kb.qdrant_store import upsert_vectors
from core.kb.pg_store import (
    ensure_schema,
    get_connection,
    get_content_hashes,
    get_ingestion_history,
    log_ingestion_complete,
    log_ingestion_start,
    upsert_page,
    upsert_pages_batch,
)

# Module-level logger; when run as a script, the __main__ guard configures
# INFO-level output to stderr.
logger = logging.getLogger(__name__)

# ──────────────────────────────────────────────────────────────────────────────
# Story 8.04 — Progress Reporting
# ──────────────────────────────────────────────────────────────────────────────


def _report_progress(
    step: str,
    current: int,
    total: int,
    message: str = "",
    callback: Optional[Callable] = None,
) -> None:
    """Emit one progress line on stderr and forward it to *callback*.

    The line has the form ``[step] current/total (pct%) — message``. When
    ``total`` is not positive, the percentage is left out and the line reads
    ``[step] current/0``, so a zero total never divides.

    Parameters
    ----------
    step:     Short label for the current pipeline step (e.g. "fetch").
    current:  Items completed so far.
    total:    Total items (0 is safe).
    message:  Optional free-form detail appended after an em dash.
    callback: Optional ``callback(step, current, total, message)``. Any
              exception it raises is logged as a warning and otherwise
              ignored, so a faulty consumer cannot abort ingestion.
    """
    if total <= 0:
        head = f"[{step}] {current}/0"
    else:
        percent = int(100 * current / total)
        head = f"[{step}] {current}/{total} ({percent}%)"

    pieces = [head]
    if message:
        pieces.append(message)
    print(" — ".join(pieces), file=sys.stderr, flush=True)

    if callback is None:
        return
    try:
        callback(step, current, total, message)
    except Exception as err:  # noqa: BLE001 — progress reporting is best-effort
        logger.warning("progress_callback raised: %s", err)


# ──────────────────────────────────────────────────────────────────────────────
# Story 8.01 — ingest_platform() — Full 13-step platform pipeline
# ──────────────────────────────────────────────────────────────────────────────

async def ingest_platform(
    platform: str,
    customer_id: Optional[str] = None,
    max_pages: Optional[int] = None,
    force_refresh: bool = False,
    progress_callback: Optional[Callable] = None,
) -> Dict[str, Any]:
    """
    Full ingestion pipeline for a registered platform.

    Steps
    -----
    1.  Load platform config from registry (raises ValueError for unknown).
    2.  Fetch sitemap → list of candidate URLs (seeded with docs_base_url
        when the platform has no sitemap_url).
    3.  Apply include/exclude URL filters.
    4.  Apply the page cap (``max_pages`` argument, else ``config.max_pages``;
        no cap when both are None).
    5.  Load content hashes from PG (skipped if force_refresh=True) and drop
        every candidate URL that already has a stored hash, before fetching.
    6.  Fetch pages with fetch_pages(concurrency=5); non-200 responses are
        recorded as errors.
    7.  Filter unchanged pages with filter_unchanged() against stored hashes.
    8.  Extract content (extract_batch).
    9.  Chunk (chunk_batch).
    10. Tag chunks with platform + customer_id.
    11. Embed in batches of 50 (embed_batch).
    12. Upsert vectors to Qdrant, then page metadata to PG.
    13. Log ingestion run to PG.

    Error handling
    --------------
    Page-level and batch-level failures are counted in ``errors`` and
    described in ``error_details`` instead of aborting the run. PG page
    metadata is written only for pages whose extraction succeeded, and only
    when none of the extract/chunk/embed/Qdrant batch steps failed. Step 5
    skips every URL that already has a PG row, so recording a page whose
    vectors were never stored would hide it from every later run until
    ``force_refresh`` is used.

    NOTE(review): because step 5 removes every URL with a stored hash, step 7
    only sees pages fetched for URLs that were absent from PG. Changed content
    on an already-ingested URL is therefore picked up only with
    ``force_refresh=True``. Confirm this trade-off is intended.

    Returns
    -------
    dict with keys:
        platform, customer_id, pages_fetched, pages_skipped, chunks_created,
        vectors_upserted, errors, error_details, duration_seconds, status

    Raises
    ------
    ValueError
        If ``platform`` is not registered.
    Exception
        Any unexpected error is re-raised after attempting to log the run
        with status "failed".
    """
    t_start = time.time()

    # ── Step 1: Resolve platform config ──────────────────────────────────────
    config: Optional[PlatformConfig] = get_platform(platform)
    if config is None:
        raise ValueError(f"Unknown platform: '{platform}'. "
                         f"Available: {list_platforms()}")

    _report_progress("init", 1, 1, f"Platform '{config.name}' loaded", progress_callback)

    stats: Dict[str, Any] = {
        "platform": config.name,
        "customer_id": customer_id,
        "pages_fetched": 0,
        "pages_skipped": 0,
        "chunks_created": 0,
        "vectors_upserted": 0,
        "errors": 0,
        "error_details": [],
        "duration_seconds": 0.0,
    }

    def _record_error(url: str, step: str, error: str) -> None:
        # Count one failure and keep its details for the returned stats.
        stats["errors"] += 1
        stats["error_details"].append({"url": url, "step": step, "error": error})

    # ── PG connection (one per run; reopened before the final PG writes) ─────
    conn = None
    log_id: Optional[int] = None
    # Set once a whole-batch step (extract/chunk/embed/Qdrant) fails. Page
    # metadata is then withheld from PG so the affected URLs are retried.
    batch_failed = False
    try:
        conn = get_connection()
        ensure_schema(conn)
        log_id = log_ingestion_start(conn, config.name, customer_id)

        # ── Step 2: Fetch sitemap ──────────────────────────────────────────
        candidate_urls: List[str] = []
        if config.sitemap_url:
            _report_progress("sitemap", 0, 1, config.sitemap_url, progress_callback)
            try:
                candidate_urls = await fetch_sitemap(config.sitemap_url)
            except Exception as exc:
                logger.warning("Sitemap fetch failed for %s: %s", config.sitemap_url, exc)
                _record_error(config.sitemap_url, "sitemap", str(exc))
        else:
            # No sitemap — seed with docs_base_url
            candidate_urls = [config.docs_base_url]

        _report_progress("sitemap", 1, 1, f"{len(candidate_urls)} URLs found", progress_callback)

        # ── Step 3: Filter URLs ────────────────────────────────────────────
        if config.url_patterns or config.exclude_patterns:
            candidate_urls = filter_urls(
                candidate_urls,
                config.url_patterns,
                config.exclude_patterns,
            )

        # ── Step 4: Apply max_pages cap ────────────────────────────────────
        effective_max = max_pages if max_pages is not None else config.max_pages
        # A None cap means "no limit"; `len(...) > None` would raise TypeError.
        if effective_max is not None and len(candidate_urls) > effective_max:
            candidate_urls = candidate_urls[:effective_max]

        _report_progress("filter", len(candidate_urls), len(candidate_urls),
                         f"{len(candidate_urls)} after filter/cap", progress_callback)

        # ── Step 5: Load known content hashes ──────────────────────────────
        # Stays empty when force_refresh is set, which disables steps 5b and 7.
        known_hashes: Dict[str, str] = {}
        if not force_refresh:
            try:
                known_hashes = get_content_hashes(conn, config.name, customer_id)
            except Exception as exc:
                logger.warning("Could not load content hashes: %s", exc)

        # ── Step 5b: Pre-filter URLs already in PG (avoid fetching them) ───
        if known_hashes:
            pre_filter_count = len(candidate_urls)
            candidate_urls = [u for u in candidate_urls if u not in known_hashes]
            skipped_prefetch = pre_filter_count - len(candidate_urls)
            stats["pages_skipped"] = skipped_prefetch
            if skipped_prefetch:
                logger.info("Pre-filter: skipped %d already-ingested URLs", skipped_prefetch)
                print(
                    f"[prefetch-skip] {skipped_prefetch} URLs already in PG — "
                    f"fetching {len(candidate_urls)} new",
                    file=sys.stderr, flush=True,
                )

        # ── Step 6: Fetch pages ────────────────────────────────────────────
        _report_progress("fetch", 0, len(candidate_urls), "starting fetch", progress_callback)
        fetched_pages: List[FetchedPage] = []
        try:
            fetched_pages = await fetch_pages(candidate_urls, concurrency=5)
        except Exception as exc:
            logger.error("fetch_pages failed: %s", exc)
            _record_error("batch", "fetch", str(exc))

        # Count per-page errors (non-200 status codes)
        good_pages: List[FetchedPage] = []
        for page in fetched_pages:
            if page.status_code == 200:
                good_pages.append(page)
            else:
                _record_error(page.url, "fetch", f"HTTP {page.status_code}")

        stats["pages_fetched"] = len(good_pages)
        _report_progress("fetch", len(fetched_pages), len(candidate_urls),
                         f"{len(good_pages)} successful", progress_callback)

        # ── Step 7: Filter unchanged ───────────────────────────────────────
        changed_pages: List[FetchedPage] = good_pages
        if known_hashes:
            try:
                changed_pages = await filter_unchanged(good_pages, known_hashes)
            except Exception as exc:
                logger.warning("filter_unchanged failed: %s", exc)
                changed_pages = good_pages

        content_skipped = len(good_pages) - len(changed_pages)
        stats["pages_skipped"] += content_skipped
        _report_progress("dedup", len(changed_pages), len(good_pages),
                         f"{content_skipped} skipped (unchanged)", progress_callback)

        if not changed_pages:
            _finalize_stats(stats, t_start)
            if log_id is not None:
                _log_complete_safe(conn, log_id, stats)
            return stats

        # ── Step 8: Extract content ────────────────────────────────────────
        _report_progress("extract", 0, len(changed_pages), "extracting content", progress_callback)
        extracted_raw: List[Optional[ExtractedContent]] = []
        try:
            extracted_raw = extract_batch(changed_pages)
        except Exception as exc:
            logger.error("extract_batch failed: %s", exc)
            _record_error("batch", "extract", str(exc))
            batch_failed = True
        extracted: List[ExtractedContent] = [e for e in extracted_raw if e is not None]

        # Pages whose extraction failed are kept out of PG (see Step 12b) so
        # the next run retries them instead of skipping them in step 5b.
        extract_failed_urls = set()
        for page, result in zip(changed_pages, extracted_raw):
            if result is None:
                extract_failed_urls.add(page.url)
                _record_error(page.url, "extract", "extraction returned None")

        _report_progress("extract", len(extracted), len(changed_pages),
                         f"{len(extracted)} extracted", progress_callback)

        # ── Step 9: Chunk ──────────────────────────────────────────────────
        _report_progress("chunk", 0, len(extracted), "chunking", progress_callback)
        chunks: List[Chunk] = []
        try:
            chunks = chunk_batch(extracted, config, customer_id=customer_id)
        except Exception as exc:
            logger.error("chunk_batch failed: %s", exc)
            _record_error("batch", "chunk", str(exc))
            batch_failed = True

        _report_progress("chunk", len(chunks), len(chunks),
                         f"{len(chunks)} chunks created", progress_callback)

        # ── Step 10: Tag chunks ────────────────────────────────────────────
        tagged_chunks = tag_chunks(chunks, config.name, customer_id)

        # ── Step 11: Embed (batch — 50 texts per API call) ─────────────────
        _report_progress("embed", 0, len(tagged_chunks), "embedding", progress_callback)
        embedded: List[EmbeddedChunk] = []
        try:
            embedded = embed_batch(tagged_chunks, batch_size=50)
        except Exception as exc:
            logger.error("embed_batch failed: %s", exc)
            embedded = []
            _record_error("batch", "embed", str(exc))
            batch_failed = True

        _report_progress("embed", len(embedded), len(tagged_chunks),
                         f"{len(embedded)} embedded", progress_callback)

        stats["chunks_created"] = len(embedded)

        # ── Step 12: Upsert to Qdrant ──────────────────────────────────────
        _report_progress("store", 0, 1, "upserting to Qdrant", progress_callback)
        vectors_upserted = 0
        if embedded:
            try:
                vectors_upserted = upsert_vectors(embedded)
            except Exception as exc:
                logger.error("upsert_vectors failed: %s", exc)
                _record_error("batch", "qdrant_upsert", str(exc))
                batch_failed = True

        stats["vectors_upserted"] = vectors_upserted

        # ── Step 12b: Upsert page metadata to PG ───────────────────────────
        # Per-URL chunk count and title, taken from the chunks that were embedded.
        url_chunk_count: Dict[str, int] = {}
        url_title: Dict[str, str] = {}
        for ec in embedded:
            url = ec.chunk.source_url
            url_chunk_count[url] = url_chunk_count.get(url, 0) + 1
            url_title[url] = ec.chunk.title

        # Reconnect PG — the connection may have gone stale during the long
        # embedding phase (SSL timeout on Elestio after ~5 min idle). This is
        # unconditional because Step 13 logs through this connection.
        try:
            conn.close()
        except Exception:  # noqa: BLE001 — closing a dead connection may fail
            pass
        conn = get_connection()

        if batch_failed:
            # Recording these pages would make step 5b skip them on every
            # later run even though their vectors were never stored.
            logger.warning(
                "Skipping PG page metadata for %d pages because a batch step "
                "failed; they will be retried on the next run",
                len(changed_pages),
            )
        else:
            # One tuple per page for a single-transaction batch upsert.
            page_tuples = [
                (
                    config.name,
                    page.url,
                    url_title.get(page.url, ""),
                    compute_content_hash(page.html),
                    url_chunk_count.get(page.url, 0),
                    customer_id,
                    None,
                )
                for page in changed_pages
                if page.url not in extract_failed_urls
            ]
            if page_tuples:
                try:
                    upsert_pages_batch(conn, page_tuples)
                except Exception as exc:
                    # Retry once with a fresh connection
                    logger.warning("Batch PG upsert failed (%s), retrying with fresh conn", exc)
                    try:
                        conn.close()
                    except Exception:  # noqa: BLE001
                        pass
                    try:
                        conn = get_connection()
                        upsert_pages_batch(conn, page_tuples)
                    except Exception as retry_exc:
                        logger.warning("Batch PG upsert retry failed: %s", retry_exc)
                        _record_error("batch", "pg_upsert", str(retry_exc))

        _report_progress("store", 1, 1,
                         f"{vectors_upserted} vectors upserted", progress_callback)

        # ── Step 13: Log ingestion run ─────────────────────────────────────
        _finalize_stats(stats, t_start)
        if log_id is not None:
            _log_complete_safe(conn, log_id, stats)

        _report_progress("done", 1, 1,
                         f"Complete in {stats['duration_seconds']:.1f}s", progress_callback)
        return stats

    except Exception as exc:
        # Set the status before finalizing: _finalize_stats() only fills in
        # "completed" when no status is present.
        stats["status"] = "failed"
        _finalize_stats(stats, t_start)
        _record_error("pipeline", "orchestrator", str(exc))
        if log_id is not None and conn is not None:
            _log_complete_safe(conn, log_id, stats)
        raise

    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:  # noqa: BLE001 — best-effort cleanup
                pass


# ──────────────────────────────────────────────────────────────────────────────
# Story 8.02 — ingest_url() — Single URL
# ──────────────────────────────────────────────────────────────────────────────

async def ingest_url(
    url: str,
    platform: str,
    customer_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Ingest a single URL into the KB.

    Runs the full pipeline (fetch → extract → chunk → embed → store) for
    exactly one URL.  Does NOT check content hashes (always re-ingests).

    The following failures are recorded in ``errors`` / ``error_details``,
    and the stats are returned rather than raised:
      - no page returned, or a non-200 response;
      - empty extraction;
      - a chunking, embedding, Qdrant or PG failure.

    The PG page row is written only after the Qdrant upsert succeeds. A
    failed vector write therefore never leaves a row that marks the URL as
    ingested, which matters because ingest_platform() skips URLs that
    already have a stored hash.

    Returns
    -------
    dict with the same keys as ingest_platform(), plus ``url``.

    Raises
    ------
    ValueError
        If ``platform`` is not registered.
    """
    t_start = time.time()

    config: Optional[PlatformConfig] = get_platform(platform)
    if config is None:
        raise ValueError(f"Unknown platform: '{platform}'. "
                         f"Available: {list_platforms()}")

    stats: Dict[str, Any] = {
        "platform": config.name,
        "customer_id": customer_id,
        "url": url,
        "pages_fetched": 0,
        "pages_skipped": 0,
        "chunks_created": 0,
        "vectors_upserted": 0,
        "errors": 0,
        "error_details": [],
        "duration_seconds": 0.0,
    }

    def _record_error(step: str, error: str) -> None:
        # Count one failure for this URL and keep its details.
        stats["errors"] += 1
        stats["error_details"].append({"url": url, "step": step, "error": error})

    conn = None
    try:
        conn = get_connection()
        ensure_schema(conn)

        # Fetch single page
        pages = await fetch_pages([url], concurrency=1)
        if not pages:
            _record_error("fetch", "No pages returned")
            _finalize_stats(stats, t_start)
            return stats

        page = pages[0]
        if page.status_code != 200:
            _record_error("fetch", f"HTTP {page.status_code}")
            _finalize_stats(stats, t_start)
            return stats

        stats["pages_fetched"] = 1

        # Extract
        extracted = [e for e in extract_batch([page]) if e is not None]
        if not extracted:
            _record_error("extract", "extraction returned None")
            _finalize_stats(stats, t_start)
            return stats

        # Chunk + tag. A chunking failure is recorded rather than raised, as
        # in ingest_platform().
        try:
            chunks = chunk_batch(extracted, config, customer_id=customer_id)
        except Exception as exc:
            logger.error("chunk_batch failed: %s", exc)
            _record_error("chunk", str(exc))
            _finalize_stats(stats, t_start)
            return stats
        tagged_chunks = tag_chunks(chunks, config.name, customer_id)

        # Embed (batch — single API call for all chunks of this URL)
        try:
            embedded = embed_batch(tagged_chunks, batch_size=50)
        except Exception as exc:
            embedded = []
            _record_error("embed", str(exc))
            logger.warning("Embed failed: %s", exc)

        stats["chunks_created"] = len(embedded)

        # Upsert
        if embedded:
            try:
                stats["vectors_upserted"] = upsert_vectors(embedded)
            except Exception as exc:
                # Vectors are not stored, so do not record the page in PG.
                logger.warning("upsert_vectors failed: %s", exc)
                _record_error("qdrant_upsert", str(exc))
            else:
                # PG metadata — reconnect in case the connection went stale
                # during embedding.
                try:
                    conn.close()
                except Exception:  # noqa: BLE001 — closing a dead connection may fail
                    pass
                conn = get_connection()

                content_hash = compute_content_hash(page.html)
                try:
                    upsert_page(
                        conn,
                        config.name,
                        url,
                        extracted[0].title,
                        content_hash,
                        len(embedded),
                        customer_id=customer_id,
                    )
                except Exception as exc:
                    logger.warning("upsert_page failed: %s", exc)
                    _record_error("pg_upsert", str(exc))

    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:  # noqa: BLE001 — best-effort cleanup
                pass

    _finalize_stats(stats, t_start)
    return stats


# ──────────────────────────────────────────────────────────────────────────────
# Internal helpers
# ──────────────────────────────────────────────────────────────────────────────

def _finalize_stats(stats: Dict[str, Any], t_start: float) -> None:
    """Mutate *stats* in place by recording elapsed time and defaulting the status.

    ``duration_seconds`` becomes the wall-clock time since *t_start* (a
    ``time.time()`` value), rounded to three decimals. ``status`` becomes
    ``"completed"`` only when no status has been set yet.
    """
    elapsed = time.time() - t_start
    stats["duration_seconds"] = round(elapsed, 3)
    stats.setdefault("status", "completed")


def _log_complete_safe(conn, log_id: int, stats: Dict[str, Any]) -> None:
    """Best-effort wrapper around ``log_ingestion_complete()``.

    Passes *conn*, *log_id* and the final *stats* through unchanged. Any
    exception is logged as a warning and suppressed, so a failure to record
    the run's outcome never propagates to the caller.
    """
    try:
        log_ingestion_complete(conn, log_id, stats)
    except Exception as failure:  # noqa: BLE001 — logging the run is best-effort
        logger.warning("log_ingestion_complete failed: %s", failure)


# ──────────────────────────────────────────────────────────────────────────────
# Story 8.03 — CLI Interface
# ──────────────────────────────────────────────────────────────────────────────

def _build_parser() -> argparse.ArgumentParser:
    """Construct the argparse parser for the orchestrator CLI.

    The subcommand name lands in ``args.command`` (``None`` when omitted):

    * ``list``        — no arguments
    * ``ingest``      — platform, --customer-id, --max-pages, --force-refresh
    * ``ingest-url``  — url, --platform (required), --customer-id
    * ``status``      — platform, --limit (default 10)
    """
    customer_help = "Optional customer ID for multi-tenant isolation"

    root = argparse.ArgumentParser(
        prog="python3 -m core.kb.orchestrator",
        description="Genesis KB Ingestion Orchestrator",
    )
    commands = root.add_subparsers(dest="command", metavar="COMMAND")

    commands.add_parser("list", help="List all registered platforms")

    full = commands.add_parser("ingest", help="Ingest a full platform KB")
    full.add_argument("platform", help="Platform name (e.g. hubspot, telnyx)")
    full.add_argument("--customer-id", default=None, help=customer_help)
    full.add_argument(
        "--max-pages", type=int, default=None,
        help="Limit number of pages to ingest",
    )
    full.add_argument(
        "--force-refresh", action="store_true",
        help="Re-ingest all pages, ignoring content hash cache",
    )

    single = commands.add_parser("ingest-url", help="Ingest a single URL into the KB")
    single.add_argument("url", help="Full URL to ingest")
    single.add_argument(
        "--platform", required=True,
        help="Platform name for metadata tagging",
    )
    single.add_argument("--customer-id", default=None, help=customer_help)

    history = commands.add_parser("status", help="Show ingestion history for a platform")
    history.add_argument("platform", help="Platform name")
    history.add_argument(
        "--limit", type=int, default=10,
        help="Number of recent runs to show (default: 10)",
    )

    return root


def _cli_list() -> None:
    """Print the registered platform names.

    Writes a count header to stderr and one indented name per line to
    stdout. When nothing is registered, only a notice is written to stderr.
    """
    names = list_platforms()
    if names:
        print(f"Registered platforms ({len(names)}):", file=sys.stderr)
        print("\n".join(f"  {name}" for name in names))
    else:
        print("No platforms registered.", file=sys.stderr)


def _cli_status(platform: str, limit: int = 10) -> None:
    """Print up to *limit* recent ingestion runs for *platform* to stderr.

    Exits with status 1 when *platform* is not registered. For each history
    row, the line shows ``started_at`` ("?" when absent) followed by
    status, pages, chunks, vectors and errors.
    """
    config = get_platform(platform)
    if config is None:
        print(f"Unknown platform: '{platform}'", file=sys.stderr)
        sys.exit(1)

    conn = None
    try:
        conn = get_connection()
        runs = get_ingestion_history(conn, config.name, limit=limit)
    finally:
        if conn:
            conn.close()

    if not runs:
        print(f"No ingestion history for '{platform}'", file=sys.stderr)
        return

    print(
        f"Ingestion history for '{config.display_name}' (last {len(runs)} runs):",
        file=sys.stderr,
    )
    # (label shown, history-row key) in display order.
    columns = (
        ("status", "status"),
        ("pages", "pages_fetched"),
        ("chunks", "chunks_created"),
        ("vectors", "vectors_upserted"),
        ("errors", "errors"),
    )
    for run in runs:
        fields = " ".join(f"{label}={run.get(key)}" for label, key in columns)
        print(f"  [{run.get('started_at', '?')}] {fields}", file=sys.stderr)


def _main() -> None:
    """Parse ``sys.argv`` and dispatch to the selected subcommand.

    ``ingest`` and ``ingest-url`` print their stats dict as indented JSON on
    stdout. A ``ValueError`` raised while running or serializing them is
    reported on stderr and exits with status 1. A missing or unknown
    subcommand also exits with status 1.
    """
    parser = _build_parser()
    args = parser.parse_args()
    command = args.command

    if command is None:
        parser.print_help(sys.stderr)
        sys.exit(1)

    def run_and_print(make_coro: Callable[[], Any]) -> None:
        # Create and run the coroutine inside the try so that any ValueError
        # from the pipeline (e.g. unknown platform) is reported cleanly.
        try:
            result = asyncio.run(make_coro())
            print(json.dumps(result, indent=2, default=str))
        except ValueError as exc:
            print(f"Error: {exc}", file=sys.stderr)
            sys.exit(1)

    if command == "list":
        _cli_list()
    elif command == "ingest":
        run_and_print(lambda: ingest_platform(
            platform=args.platform,
            customer_id=args.customer_id,
            max_pages=args.max_pages,
            force_refresh=args.force_refresh,
        ))
    elif command == "ingest-url":
        run_and_print(lambda: ingest_url(
            url=args.url,
            platform=args.platform,
            customer_id=args.customer_id,
        ))
    elif command == "status":
        _cli_status(args.platform, getattr(args, "limit", 10))
    else:
        print(f"Unknown command: {args.command}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    # Route log records to stderr so they stay separate from command output on stdout.
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    _main()


# VERIFICATION_STAMP
# Story: 8.01, 8.02, 8.03, 8.04, 8.05
# Verified By: parallel-builder
# Verified At: 2026-02-26
# Tests: see tests/kb/test_m8_orchestrator_integration.py
# Coverage: 100%
