#!/usr/bin/env python3
"""
Qdrant Store — KB Ingestion Pipeline Module 5
==============================================
Upserts, deletes, searches, and reports stats for embedded KB chunks
in the genesis_memories Qdrant collection.

Stories implemented:
  5.01 — upsert_vectors (batch upsert, deterministic UUID5 IDs)
  5.02 — delete_platform (delete by platform, optionally customer-scoped)
  5.03 — search_platform (scoped semantic search)
  5.04 — get_platform_stats (collection + per-platform counts)
  5.05 — Integration test coverage in tests/kb/test_m5_qdrant_integration.py

Usage:
    from core.kb.qdrant_store import upsert_vectors, delete_platform, search_platform, get_platform_stats
"""

import os
import uuid
from typing import Optional

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

from core.kb.contracts import Chunk, EmbeddedChunk

# ──────────────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────────────

# Qdrant endpoint; the QDRANT_URL environment variable overrides the fallback.
QDRANT_URL = os.getenv(
    "QDRANT_URL",
    "https://qdrant-b3knu-u50607.vm.elestio.app:6333",
)
# NOTE(review): the fallback below looks like a real API key committed to
# source. Anyone with read access to this file can use it. Consider rotating
# the key and requiring QDRANT_API_KEY to be set explicitly — confirm with the
# service owner before changing, since callers may rely on the fallback.
QDRANT_API_KEY = os.getenv(
    "QDRANT_API_KEY",
    "7b74e6621bd0e6650789f6662bca4cbf4143d3d1d710a0002b3b563973ca6876",
)
# Default collection for the upsert / delete / search / stats functions.
COLLECTION = "genesis_memories"

# Known platforms for stats enumeration (extend as new platforms are added).
# get_platform_stats only issues per-platform counts for names listed here.
_KNOWN_PLATFORMS = [
    "hubspot",
    "ghl",
    "xero",
    "telnyx",
    "stripe",
    "notion",
    "airtable",
    "zapier",
    "monday",
    "salesforce",
]

_BATCH_SIZE = 100  # points per upsert batch
# Fixed uuid5 namespace: the same chunk_id always yields the same point ID.
_UUID5_NAMESPACE = uuid.UUID("12345678-1234-5678-1234-567812345678")

# ──────────────────────────────────────────────────────────────────────────────
# Story 5.01 — Client singleton
# ──────────────────────────────────────────────────────────────────────────────

_qdrant_client: Optional[QdrantClient] = None


def _get_client() -> QdrantClient:
    """Return the shared Qdrant client, constructing it on first use."""
    global _qdrant_client
    if _qdrant_client is not None:
        return _qdrant_client

    # First call: build the client from module config and cache it for reuse.
    _qdrant_client = QdrantClient(
        url=QDRANT_URL,
        api_key=QDRANT_API_KEY,
        timeout=60,
    )
    return _qdrant_client


def _chunk_id_to_uuid(chunk_id: str) -> str:
    """
    Map a chunk_id to a deterministic Qdrant point ID.

    The ID is a UUID5 of chunk_id under the module's fixed namespace, so the
    same chunk_id always produces the same value.

    Returns:
        The UUID in its canonical hyphenated string form.
    """
    point_uuid = uuid.uuid5(_UUID5_NAMESPACE, chunk_id)
    return str(point_uuid)


def _build_payload(ec: EmbeddedChunk) -> dict:
    """
    Assemble the Qdrant point payload for one EmbeddedChunk.

    Copies the KB metadata fields from ec.chunk, records the embedding model,
    tags the point with type "PLATFORM_KB", and appends two convenience fields
    (chunk_id, total_chunks).

    Returns:
        Payload dict ready to attach to a PointStruct.
    """
    source = ec.chunk

    # Fields copied verbatim from the chunk, in payload order.
    copied_fields = (
        "platform",
        "customer_id",
        "title",
        "source_url",
        "text",
        "heading_context",
        "chunk_index",
    )
    payload = {name: getattr(source, name) for name in copied_fields}

    payload["embedding_model"] = ec.embedding_model
    payload["type"] = "PLATFORM_KB"

    # Extra convenience fields
    payload["chunk_id"] = source.chunk_id
    payload["total_chunks"] = source.total_chunks
    return payload


# ──────────────────────────────────────────────────────────────────────────────
# Story 5.01 — upsert_vectors
# ──────────────────────────────────────────────────────────────────────────────

def upsert_vectors(
    embedded_chunks: list[EmbeddedChunk],
    collection: str = COLLECTION,
) -> int:
    """
    Write embedded chunks to Qdrant, inserting or overwriting by point ID.

    - Each point ID is the deterministic UUID5 of the chunk's chunk_id, so
      re-ingesting the same chunk targets the same point.
    - The payload carries the KB metadata produced by _build_payload.
    - Points are sent in slices of _BATCH_SIZE, each waiting for completion.

    Returns:
        Number of points sent (0 for empty input).
    """
    if not embedded_chunks:
        return 0

    client = _get_client()

    # Materialise every point up front, before the first network call.
    points = []
    for item in embedded_chunks:
        points.append(
            PointStruct(
                id=_chunk_id_to_uuid(item.chunk.chunk_id),
                vector=item.vector,
                payload=_build_payload(item),
            )
        )

    # Walk the list one fixed-size window at a time.
    written = 0
    start = 0
    while start < len(points):
        window = points[start : start + _BATCH_SIZE]
        client.upsert(collection_name=collection, points=window, wait=True)
        written += len(window)
        start += _BATCH_SIZE

    return written


# ──────────────────────────────────────────────────────────────────────────────
# Story 5.02 — delete_platform
# ──────────────────────────────────────────────────────────────────────────────

def _build_platform_filter(
    platform: str,
    customer_id: Optional[str] = None,
) -> Filter:
    """
    Build a Qdrant filter matching a platform, and a customer when given.

    With customer_id=None the filter constrains only the "platform" payload
    field; otherwise both "platform" and "customer_id" must match.
    """
    platform_match = FieldCondition(key="platform", match=MatchValue(value=platform))
    if customer_id is None:
        return Filter(must=[platform_match])

    customer_match = FieldCondition(
        key="customer_id", match=MatchValue(value=customer_id)
    )
    return Filter(must=[platform_match, customer_match])


def delete_platform(
    platform: str,
    customer_id: Optional[str] = None,
    collection: str = COLLECTION,
) -> int:
    """
    Remove every vector belonging to a platform (optionally one customer only).

    The number of matching points is obtained with an exact count issued
    before the delete; that pre-delete count is what gets returned. The count
    and the delete are separate requests, so writes landing between them are
    not reflected in the result.

    Returns:
        Count of points that matched the filter before deletion.
    """
    client = _get_client()
    scope = _build_platform_filter(platform, customer_id)

    # Exact count of matching points, taken before anything is removed.
    matched = client.count(
        collection_name=collection,
        count_filter=scope,
        exact=True,
    ).count

    # Only issue the delete when there is something to remove.
    if matched != 0:
        client.delete(collection_name=collection, points_selector=scope, wait=True)

    return matched


# ──────────────────────────────────────────────────────────────────────────────
# Story 5.03 — search_platform
# ──────────────────────────────────────────────────────────────────────────────

def search_platform(
    query_vector: list[float],
    platform: str,
    customer_id: Optional[str] = None,
    top_k: int = 5,
    score_threshold: float = 0.3,
    collection: str = COLLECTION,
) -> list[dict]:
    """
    Run a vector search restricted to one platform's KB points.

    When customer_id is given the search is additionally restricted to that
    customer. top_k and score_threshold are forwarded to Qdrant as the result
    limit and score threshold.

    Returns:
        List of result dicts with keys:
            id, score, title, text, source_url, platform, heading_context
        Scores are rounded to 4 decimal places; payload fields missing from a
        point default to "".
    """
    client = _get_client()
    scope = _build_platform_filter(platform, customer_id)

    response = client.query_points(
        collection_name=collection,
        query=query_vector,
        query_filter=scope,
        limit=top_k,
        score_threshold=score_threshold,
        with_payload=True,
    )

    # Payload fields surfaced to the caller, in output order.
    text_fields = ("title", "text", "source_url", "platform", "heading_context")

    def _as_hit(point) -> dict:
        """Flatten one scored point into the public result shape."""
        stored = point.payload or {}
        hit = {"id": str(point.id), "score": round(point.score, 4)}
        for name in text_fields:
            hit[name] = stored.get(name, "")
        return hit

    return [_as_hit(point) for point in response.points]


# ──────────────────────────────────────────────────────────────────────────────
# Story 5.04 — get_platform_stats
# ──────────────────────────────────────────────────────────────────────────────

def get_platform_stats(collection: str = COLLECTION) -> dict:
    """
    Report the collection's total point count and a count per known platform.

    Only platforms listed in _KNOWN_PLATFORMS are counted, and platforms with
    zero points are left out of the result. The "dimension" value is a fixed
    constant in this function, not read from the collection.

    Returns:
        {
            "total": int,
            "collection": str,
            "dimension": 3072,
            "platforms": {"hubspot": N, "ghl": N, ...}
        }
    """
    client = _get_client()

    def _points_for(name: str) -> int:
        """Exact number of points whose payload platform equals name."""
        only_this_platform = Filter(
            must=[FieldCondition(key="platform", match=MatchValue(value=name))]
        )
        counted = client.count(
            collection_name=collection,
            count_filter=only_this_platform,
            exact=True,
        )
        return counted.count

    # Overall size comes from collection info; a missing count is treated as 0.
    overall = client.get_collection(collection_name=collection).points_count or 0

    # One exact count per known platform, keeping only the non-empty ones.
    per_platform: dict[str, int] = {
        name: found
        for name in _KNOWN_PLATFORMS
        if (found := _points_for(name)) > 0
    }

    return {
        "total": overall,
        "collection": collection,
        "dimension": 3072,
        "platforms": per_platform,
    }


# ──────────────────────────────────────────────────────────────────────────────
# Helpers (for tests and external use)
# ──────────────────────────────────────────────────────────────────────────────

def create_test_collection(collection: str, client: Optional[QdrantClient] = None) -> None:
    """
    Create a collection configured for 3072-dimensional cosine vectors.

    Uses the supplied client when one is passed, otherwise the module's
    shared client.
    """
    target = client or _get_client()
    vector_config = VectorParams(size=3072, distance=Distance.COSINE)
    target.create_collection(
        collection_name=collection,
        vectors_config=vector_config,
    )


def drop_test_collection(collection: str, client: Optional[QdrantClient] = None) -> None:
    """
    Drop the named collection.

    Uses the supplied client when one is passed, otherwise the module's
    shared client.
    """
    target = client or _get_client()
    target.delete_collection(collection_name=collection)


# VERIFICATION_STAMP
# Story: M5 — Qdrant Store (Stories 5.01–5.04)
# Verified By: parallel-builder (claude-sonnet-4-6)
# Verified At: 2026-02-26
# Tests: see tests/kb/test_m5_qdrant_integration.py
# Coverage: 100% of stories implemented
