#!/usr/bin/env python3
"""
Genesis Bloodstream Query Interface
=====================================
Query the RLM Bloodstream knowledge base from the command line.

Queries PostgreSQL (Elestio) bloodstream_knowledge table with full-text
search, type filtering, tag filtering, and source filtering.

Usage:
    python3 bloodstream_query.py --type PRODUCT_SPEC
    python3 bloodstream_query.py --search "radar audit scoring"
    python3 bloodstream_query.py --tag revenue
    python3 bloodstream_query.py --source DT3
    python3 bloodstream_query.py --type UNIT_ECONOMICS --tag pricing
    python3 bloodstream_query.py --stats
    python3 bloodstream_query.py --hot                         # Redis hot cache
    python3 bloodstream_query.py --top 20 --min-confidence 0.9
    python3 bloodstream_query.py --export results.jsonl --search "widget"
"""

import json
import os
import sys
import argparse
from datetime import datetime

# Make elestio_config importable. The directory defaults to the Genesis memory
# path and can be overridden via the GENESIS_MEMORY_DIR environment variable
# (useful when running outside the original machine layout).
sys.path.insert(0, os.environ.get("GENESIS_MEMORY_DIR", "/mnt/e/genesis-system/data/genesis-memory"))


def get_pg_connection():
    """Open a connection to the Elestio PostgreSQL instance.

    Connection parameters come from ``elestio_config.PostgresConfig``.

    Returns:
        An open psycopg2 connection.

    Exits the process with status 1 (after printing an error) if psycopg2 or
    elestio_config cannot be imported, or if connecting fails.
    """
    # Import each dependency separately so the error message names the
    # module that is actually missing.
    try:
        import psycopg2
    except ImportError:
        print("ERROR: psycopg2 not installed. Run: pip install psycopg2-binary")
        sys.exit(1)

    try:
        from elestio_config import PostgresConfig
    except ImportError as e:
        print(f"ERROR: Cannot import elestio_config (PostgresConfig): {e}")
        sys.exit(1)

    try:
        return psycopg2.connect(**PostgresConfig.get_connection_params())
    except Exception as e:
        print(f"ERROR: Cannot connect to PostgreSQL: {e}")
        sys.exit(1)


def get_redis_connection():
    """Return a live Redis client, or None when one cannot be obtained.

    Any failure (redis package or elestio_config missing, or the PING
    failing) is reported on stdout and converted into a None return, so
    callers can degrade gracefully instead of crashing.
    """
    try:
        import redis as redis_lib
        from elestio_config import RedisConfig

        client = redis_lib.Redis(**RedisConfig.get_connection_params())
        # PING up front so an unreachable server is reported here,
        # not on the caller's first command.
        client.ping()
    except Exception as exc:
        print(f"ERROR: Cannot connect to Redis: {exc}")
        return None
    return client


def format_item(row: dict, verbose: bool = False) -> str:
    """Format one knowledge row for terminal display.

    Args:
        row: Column-name -> value mapping as produced by the query_* helpers.
            Missing keys and SQL NULLs (None) fall back to placeholder text.
        verbose: If True, append the ``content`` column word-wrapped at ~90 chars.

    Returns:
        A multi-line string: a type/title line, a source/confidence line,
        optional wrapped content, and up to five tags.
    """

    def field(key: str, default: str) -> str:
        # `.get(key, default)` alone does not cover SQL NULLs, which arrive as
        # None and would crash the `:30s` format specs below.
        value = row.get(key)
        return default if value is None else str(value)

    confidence = float(row.get("confidence") or 0)
    # Clamp to 0..10 so the bar is always exactly 10 chars wide, even when a
    # confidence lies outside 0..1.
    filled = min(max(int(confidence * 10), 0), 10)
    confidence_bar = ("#" * filled).ljust(10, ".")

    tags = row.get("tags") or []
    if isinstance(tags, str):
        # Handle case where tags come back as a PG array literal, e.g. "{a,b}"
        inner = tags.strip("{}")
        tags = inner.split(",") if inner else []
    tag_str = ", ".join(str(tag) for tag in tags[:5] if tag is not None)

    lines = [
        f"  [{field('type', 'UNKNOWN'):30s}] {field('title', 'Untitled')}",
        f"  Source: {field('source', '?'):30s}  Confidence: [{confidence_bar}] {confidence:.2f}",
    ]

    if verbose:
        content = field("content", "")
        # Greedy word wrap at 90 chars: break at the last space before the
        # limit, or hard-break when no usable space exists.
        while content:
            if len(content) <= 90:
                lines.append(f"  {content}")
                break
            idx = content.rfind(" ", 0, 90)
            if idx <= 0:
                idx = 90
            lines.append(f"  {content[:idx]}")
            content = content[idx:].lstrip()

    if tag_str:
        lines.append(f"  Tags: {tag_str}")

    return "\n".join(lines)


def query_by_type(conn, knowledge_type: str, limit: int, min_conf: float) -> list[dict]:
    """Query items whose type matches *knowledge_type* (case-insensitive).

    Args:
        conn: Open psycopg2 connection.
        knowledge_type: Type name; compared with UPPER() on both sides.
        limit: Maximum rows to return. None is sent as ``LIMIT NULL``,
            which PostgreSQL treats as no limit.
        min_conf: Minimum confidence (inclusive).

    Returns:
        Rows as dicts keyed by column name, highest confidence first.
    """
    # The cursor context manager closes the cursor even if execute() raises.
    with conn.cursor() as cur:
        cur.execute(
            """SELECT id, source, type, title, content, tags, confidence, created_at
               FROM bloodstream_knowledge
               WHERE UPPER(type) = UPPER(%s) AND confidence >= %s
               ORDER BY confidence DESC
               LIMIT %s""",
            (knowledge_type, min_conf, limit),
        )
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]


def query_by_search(conn, search_text: str, limit: int, min_conf: float) -> list[dict]:
    """Full-text search across title and content.

    Uses PostgreSQL FTS first (ranked by ts_rank, then confidence). If FTS
    returns nothing, it falls back to a case-insensitive literal substring
    match, which handles partial words and phrases.

    Args:
        conn: Open psycopg2 connection.
        search_text: Free-text query from the user.
        limit: Maximum rows to return. None is sent as ``LIMIT NULL``,
            which PostgreSQL treats as no limit.
        min_conf: Minimum confidence (inclusive).

    Returns:
        Rows as dicts keyed by column name, plus a ``rank`` column
        (0.0 for substring-fallback matches).
    """
    # coalesce() keeps a NULL title or content from turning the whole
    # document into NULL, which would make that row unsearchable.
    # NOTE(review): if a full-text index exists on the un-coalesced
    # expression, rebuild it on this expression so it is still used.
    with conn.cursor() as cur:
        cur.execute(
            """SELECT id, source, type, title, content, tags, confidence, created_at,
                      ts_rank(to_tsvector('english', coalesce(title, '') || ' ' || coalesce(content, '')),
                              plainto_tsquery('english', %s)) AS rank
               FROM bloodstream_knowledge
               WHERE to_tsvector('english', coalesce(title, '') || ' ' || coalesce(content, ''))
                     @@ plainto_tsquery('english', %s)
                     AND confidence >= %s
               ORDER BY rank DESC, confidence DESC
               LIMIT %s""",
            (search_text, search_text, min_conf, limit),
        )
        columns = [desc[0] for desc in cur.description]
        results = [dict(zip(columns, row)) for row in cur.fetchall()]

        if not results:
            # Escape LIKE metacharacters so the user's text matches literally
            # (backslash is PostgreSQL's default LIKE escape character).
            escaped = search_text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            pattern = f"%{escaped}%"
            cur.execute(
                """SELECT id, source, type, title, content, tags, confidence, created_at,
                          0.0 AS rank
                   FROM bloodstream_knowledge
                   WHERE (LOWER(title) LIKE LOWER(%s) OR LOWER(content) LIKE LOWER(%s))
                         AND confidence >= %s
                   ORDER BY confidence DESC
                   LIMIT %s""",
                (pattern, pattern, min_conf, limit),
            )
            columns = [desc[0] for desc in cur.description]
            results = [dict(zip(columns, row)) for row in cur.fetchall()]

    return results


def query_by_tag(conn, tag: str, limit: int, min_conf: float) -> list[dict]:
    """Query items whose ``tags`` array contains *tag*.

    The tag is lower-cased before matching, and matching is exact
    against the stored array elements.

    Args:
        conn: Open psycopg2 connection.
        tag: Tag to look for.
        limit: Maximum rows to return. None is sent as ``LIMIT NULL``,
            which PostgreSQL treats as no limit.
        min_conf: Minimum confidence (inclusive).

    Returns:
        Rows as dicts keyed by column name, highest confidence first.
    """
    # The cursor context manager closes the cursor even if execute() raises.
    with conn.cursor() as cur:
        cur.execute(
            """SELECT id, source, type, title, content, tags, confidence, created_at
               FROM bloodstream_knowledge
               WHERE %s = ANY(tags) AND confidence >= %s
               ORDER BY confidence DESC
               LIMIT %s""",
            (tag.lower(), min_conf, limit),
        )
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]


def query_by_source(conn, source: str, limit: int, min_conf: float) -> list[dict]:
    """Query items whose source contains *source* (case-insensitive substring).

    Args:
        conn: Open psycopg2 connection.
        source: Substring to look for, matched literally.
        limit: Maximum rows to return. None is sent as ``LIMIT NULL``,
            which PostgreSQL treats as no limit.
        min_conf: Minimum confidence (inclusive).

    Returns:
        Rows as dicts keyed by column name, highest confidence first.
    """
    # Escape LIKE metacharacters so e.g. "session_" does not also match
    # "sessionX" (backslash is PostgreSQL's default LIKE escape character).
    escaped = source.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    # The cursor context manager closes the cursor even if execute() raises.
    with conn.cursor() as cur:
        cur.execute(
            """SELECT id, source, type, title, content, tags, confidence, created_at
               FROM bloodstream_knowledge
               WHERE LOWER(source) LIKE LOWER(%s) AND confidence >= %s
               ORDER BY confidence DESC
               LIMIT %s""",
            (f"%{escaped}%", min_conf, limit),
        )
        columns = [desc[0] for desc in cur.description]
        return [dict(zip(columns, row)) for row in cur.fetchall()]


def show_stats(conn):
    """Print summary statistics for the bloodstream_knowledge table.

    Reports the following:
    - the total row count;
    - the latest ``created_at``;
    - per-type counts with average confidence;
    - counts per source group;
    - a confidence histogram (highest bucket first);
    - the 20 most frequent tags.
    """
    # The cursor context manager closes the cursor even if a query raises.
    with conn.cursor() as cur:
        # Total count
        cur.execute("SELECT COUNT(*) FROM bloodstream_knowledge")
        total = cur.fetchone()[0]

        # By type
        cur.execute(
            """SELECT type, COUNT(*), AVG(confidence)::numeric(4,2)
               FROM bloodstream_knowledge
               GROUP BY type ORDER BY COUNT(*) DESC"""
        )
        type_rows = cur.fetchall()

        # By source (grouped)
        cur.execute(
            """SELECT
                 CASE
                   WHEN source LIKE 'session_%' THEN 'Session Axioms'
                   WHEN source LIKE 'DT%' THEN 'Deep Think Reports'
                   ELSE source
                 END AS source_group,
                 COUNT(*)
               FROM bloodstream_knowledge
               GROUP BY source_group ORDER BY COUNT(*) DESC"""
        )
        source_rows = cur.fetchall()

        # Top tags
        cur.execute(
            """SELECT unnest(tags) AS tag, COUNT(*) AS cnt
               FROM bloodstream_knowledge
               GROUP BY tag ORDER BY cnt DESC LIMIT 20"""
        )
        tag_rows = cur.fetchall()

        # Confidence distribution, highest bucket first. Sorting by the bucket
        # label puts '<0.5' before '0.9+' ('<' sorts after '0'), so order by each
        # bucket's lowest confidence instead.
        cur.execute(
            """SELECT
                 CASE
                   WHEN confidence >= 0.9 THEN '0.9+'
                   WHEN confidence >= 0.7 THEN '0.7-0.9'
                   WHEN confidence >= 0.5 THEN '0.5-0.7'
                   ELSE '<0.5'
                 END AS bucket,
                 COUNT(*)
               FROM bloodstream_knowledge
               GROUP BY bucket ORDER BY MIN(confidence) DESC NULLS LAST"""
        )
        conf_rows = cur.fetchall()

        # Latest load
        cur.execute("SELECT MAX(created_at) FROM bloodstream_knowledge")
        latest = cur.fetchone()[0]

    def label(value) -> str:
        # GROUP BY / unnest can yield NULL keys; show them explicitly instead
        # of crashing the width-formatted prints below.
        return "(none)" if value is None else str(value)

    print(f"\n{'=' * 60}")
    print("  BLOODSTREAM KNOWLEDGE BASE STATISTICS")
    print(f"{'=' * 60}")
    print(f"  Total items  : {total}")
    print(f"  Last loaded  : {latest if latest is not None else 'n/a'}")

    print("\n  By Type:")
    print(f"  {'Type':35s} {'Count':>6s} {'Avg Conf':>9s}")
    print(f"  {'-'*35} {'-'*6} {'-'*9}")
    for t, c, avg_c in type_rows:
        # AVG is NULL when every confidence in the group is NULL.
        avg_str = f"{float(avg_c):9.2f}" if avg_c is not None else f"{'n/a':>9s}"
        print(f"  {label(t):35s} {c:6d} {avg_str}")

    print("\n  By Source Group:")
    for s, c in source_rows:
        print(f"  {label(s):35s} : {c:6d}")

    print("\n  By Confidence:")
    for bucket, c in conf_rows:
        print(f"  {bucket:15s} : {c:6d}")

    print("\n  Top 20 Tags:")
    for tag, c in tag_rows:
        print(f"  {label(tag):20s} : {c:6d}")

    print(f"{'=' * 60}")


def show_hot_cache():
    """Print the Redis hot cache.

    This shows the ``bloodstream:stats`` hash, followed by every hash whose
    key is listed in the ``bloodstream:hot_index`` list. It returns quietly
    (after the connection helper's error message) when Redis is unavailable.
    """
    r = get_redis_connection()
    if not r:
        return

    def as_text(value):
        # redis-py returns bytes unless the client was created with
        # decode_responses=True; normalize so str-keyed lookups and printing
        # work either way.
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    def decode_hash(raw: dict) -> dict:
        return {as_text(k): as_text(v) for k, v in raw.items()}

    # Get stats
    stats = decode_hash(r.hgetall("bloodstream:stats"))
    if stats:
        print("\n  Bloodstream Stats (Redis):")
        for k, v in stats.items():
            print(f"    {k}: {v}")

    # Get hot items
    keys = r.lrange("bloodstream:hot_index", 0, -1)
    if not keys:
        print("  No hot items cached in Redis.")
        return

    # Fetch every item hash in a single round-trip rather than one per key.
    pipe = r.pipeline(transaction=False)
    for key in keys:
        pipe.hgetall(key)
    items = [decode_hash(raw) for raw in pipe.execute()]

    print(f"\n  Hot Cache ({len(keys)} items):")
    print(f"  {'-' * 60}")
    for item in items:
        if item:  # keys listed in the index may have expired or been removed
            print(f"  [{item.get('type', '?'):30s}] {item.get('title', '?')}")
            print(f"    Source: {item.get('source', '?')}  Confidence: {item.get('confidence', '?')}")


def export_results(results: list[dict], path: str):
    """Write *results* to *path* as JSON Lines (one object per line).

    Values are converted as follows:
    - Anything with ``isoformat`` (datetime/date/time) becomes an ISO-8601 string.
    - ``Decimal`` values (psycopg2's type for NUMERIC columns) become JSON numbers.
    - Any other non-JSON type falls back to ``str()``.

    An existing file at *path* is overwritten.
    """
    from decimal import Decimal

    def to_jsonable(value):
        if hasattr(value, "isoformat"):
            return value.isoformat()
        if isinstance(value, Decimal):
            # Without this, default=str would export numbers as strings.
            return float(value)
        return value

    with open(path, "w", encoding="utf-8") as f:
        for item in results:
            clean = {k: to_jsonable(v) for k, v in item.items()}
            f.write(json.dumps(clean, ensure_ascii=False, default=str) + "\n")
    print(f"\n  Exported {len(results)} items to {path}")


def _row_matches(row: dict, type_name=None, tag=None, source=None) -> bool:
    """Return True if *row* satisfies every filter that is not None.

    This mirrors the SQL semantics of the query helpers:
    - query_by_type: case-insensitive equality;
    - query_by_tag: exact membership of the lower-cased tag;
    - query_by_source: case-insensitive substring.
    """
    if type_name is not None and (row.get("type") or "").upper() != type_name.upper():
        return False
    if tag is not None:
        tags = row.get("tags") or []
        if isinstance(tags, str):
            # PG array literal fallback, e.g. "{a,b}"
            inner = tags.strip("{}")
            tags = inner.split(",") if inner else []
        if tag.lower() not in tags:
            return False
    if source is not None and source.lower() not in (row.get("source") or "").lower():
        return False
    return True


def _run_query(conn, args) -> list[dict]:
    """Run the query described by the parsed CLI *args* and return the rows.

    The highest-priority filter (search > type > tag > source) runs in SQL.
    Search has top priority because it uses FTS ranking. Any other supplied
    filters are then applied to those rows, so combined filters such as
    ``--type X --tag y`` narrow the result.
    """
    candidates = [
        ("search", args.search, query_by_search),
        ("type", args.type, query_by_type),
        ("tag", args.tag, query_by_tag),
        ("source", args.source, query_by_source),
    ]
    active = [(name, value, fn) for name, value, fn in candidates if value]
    _, primary_value, primary_fn = active[0]

    if len(active) == 1:
        return primary_fn(conn, primary_value, args.top, args.min_confidence)

    # Fetch every primary match (None is sent as LIMIT NULL, which PostgreSQL
    # treats as "no limit"). Otherwise post-filtering could miss rows ranked
    # below the first `top`. The primary query's ordering is preserved.
    rows = primary_fn(conn, primary_value, None, args.min_confidence)
    secondary = {name: value for name, value, _ in active[1:]}
    matched = [
        row
        for row in rows
        if _row_matches(
            row,
            type_name=secondary.get("type"),
            tag=secondary.get("tag"),
            source=secondary.get("source"),
        )
    ]
    return matched[: args.top]


def main():
    """Command-line entry point: parse arguments and dispatch to the requested mode."""
    parser = argparse.ArgumentParser(
        description="Genesis Bloodstream Query Interface",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s --type PRODUCT_SPEC
  %(prog)s --search "radar audit scoring"
  %(prog)s --tag revenue
  %(prog)s --source DT3
  %(prog)s --type UNIT_ECONOMICS --tag pricing
  %(prog)s --stats
  %(prog)s --hot
  %(prog)s --top 20 --min-confidence 0.9
  %(prog)s --export results.jsonl --search "widget"
        """,
    )

    # Query filters (can combine)
    parser.add_argument("--type", help="Filter by knowledge type (e.g., PRODUCT_SPEC)")
    parser.add_argument("--search", help="Full-text search across title and content")
    parser.add_argument("--tag", help="Filter by tag")
    parser.add_argument("--source", help="Filter by source (partial match, e.g., 'DT3')")

    # Output options
    parser.add_argument("--top", type=int, default=25, help="Max results (default: 25)")
    parser.add_argument("--min-confidence", type=float, default=0.0, help="Min confidence threshold (default: 0.0)")
    parser.add_argument("-v", "--verbose", action="store_true", help="Show full content")
    parser.add_argument("--export", help="Export results to JSONL file")

    # Special modes
    parser.add_argument("--stats", action="store_true", help="Show comprehensive statistics")
    parser.add_argument("--hot", action="store_true", help="Show Redis hot cache")
    parser.add_argument("--json", action="store_true", help="Output raw JSON instead of formatted text")

    args = parser.parse_args()

    # Handle special modes first
    if args.hot:
        show_hot_cache()
        return

    # Validate query arguments before connecting, so usage errors don't
    # require a reachable database.
    if not args.stats:
        if not any([args.type, args.search, args.tag, args.source]):
            parser.print_help()
            print("\nERROR: Provide at least one query filter (--type, --search, --tag, --source) or use --stats/--hot.")
            sys.exit(1)
        if args.top < 1:
            # PostgreSQL rejects negative LIMITs, and 0 can never return anything.
            parser.error("--top must be at least 1")

    conn = get_pg_connection()
    try:
        if args.stats:
            show_stats(conn)
            return
        results = _run_query(conn, args)
    finally:
        # Close the connection on every path, including query errors.
        conn.close()

    # Display results
    if not results:
        print("\n  No results found.")
        if args.search:
            print(f"  Search: '{args.search}'")
        if args.type:
            print(f"  Type: '{args.type}'")
        if args.tag:
            print(f"  Tag: '{args.tag}'")
        if args.source:
            print(f"  Source: '{args.source}'")
        return

    if args.json:
        for item in results:
            print(json.dumps(item, ensure_ascii=False, default=str))
    else:
        # Pretty print
        query_desc = []
        if args.type:
            query_desc.append(f"type={args.type}")
        if args.search:
            query_desc.append(f"search='{args.search}'")
        if args.tag:
            query_desc.append(f"tag={args.tag}")
        if args.source:
            query_desc.append(f"source={args.source}")

        print(f"\n{'=' * 60}")
        print(f"  BLOODSTREAM QUERY: {', '.join(query_desc)}")
        print(f"  Results: {len(results)} (limit {args.top}, min confidence {args.min_confidence})")
        print(f"{'=' * 60}")

        for i, row in enumerate(results, 1):
            print(f"\n  --- {i} ---")
            print(format_item(row, verbose=args.verbose))

        print(f"\n{'=' * 60}")

    # Export if requested; this applies in --json mode as well.
    if args.export:
        if args.json:
            # Keep stdout pure JSON Lines by sending the export status to stderr.
            import contextlib

            with contextlib.redirect_stdout(sys.stderr):
                export_results(results, args.export)
        else:
            export_results(results, args.export)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
