#!/usr/bin/env python3
"""
Genesis Multi-Agent Observability Query Tool (v3 - Correlation Tracing)
========================================================================
Query the observability event stream for insights into agent behavior,
lifecycle events, errors, session patterns, and cross-agent traces.

Usage:
    python3 tools/observability_query.py --summary       # Session summary
    python3 tools/observability_query.py --recent N      # Last N events
    python3 tools/observability_query.py --by-tool       # Breakdown by tool
    python3 tools/observability_query.py --by-category   # Breakdown by category
    python3 tools/observability_query.py --by-agent      # Breakdown by agent
    python3 tools/observability_query.py --orchestration # Only orchestration events
    python3 tools/observability_query.py --lifecycle     # Agent spawn/stop lifecycle
    python3 tools/observability_query.py --errors        # Failed tool calls only
    python3 tools/observability_query.py --sessions      # Session start/stop events
    python3 tools/observability_query.py --timeline      # Chronological timeline
    python3 tools/observability_query.py --metrics       # Current metrics
    python3 tools/observability_query.py --trace ID      # Trace correlation ID across agents

Source: Alpha Evolve Cycles 3-7 - Observability Pillar + Correlation Tracing
Axiom MAO-005: "You can't improve what you can't see"
"""

import json
import sys
import argparse
from datetime import datetime, timezone
from pathlib import Path
from collections import Counter, defaultdict

# Directory holding all observability data files read by this tool.
_OBS_DIR = Path("/mnt/e/genesis-system/data/observability")

EVENTS_FILE = _OBS_DIR / "events.jsonl"         # event log, one JSON object per line
SESSION_FILE_TOP = _OBS_DIR / "sessions.jsonl"  # session records, one JSON object per line
METRICS_FILE = _OBS_DIR / "metrics.json"        # single JSON metrics document


def load_events(path=None) -> list:
    """Load all events from the JSONL event log.

    Args:
        path: Optional file to read instead of ``EVENTS_FILE``.

    Returns:
        The event dicts in file order, or an empty list when the file does not
        exist. Blank lines, lines that are not valid JSON, and JSON values that
        are not objects are skipped, because every report in this tool reads
        events with ``dict.get``.
    """
    source = EVENTS_FILE if path is None else Path(path)
    events = []
    if not source.exists():
        return events
    # Explicit UTF-8 (instead of the locale default); errors="replace" keeps a
    # single corrupt byte from aborting the whole report.
    with open(source, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate malformed lines
            if isinstance(record, dict):
                events.append(record)
    return events


def cmd_summary(events: list):
    """Print an overview of the whole event log.

    Shows per-event-type totals, tool-call success/failure counts, the span
    between the first and last logged events, the five most-used tools (as a
    percentage of ``tool_call`` events, one ``#`` per two points), and event
    counts per category and per agent. Prints a hint when *events* is empty.
    """
    if not events:
        print("No events recorded yet. The observability hook is active and will")
        print("capture events from the next tool call onwards.")
        return

    span_start = events[0].get("timestamp", "?")[:19]
    span_end = events[-1].get("timestamp", "?")[:19]

    def of_type(kind):
        # Every event whose event_type equals *kind*, in log order.
        return [ev for ev in events if ev.get("event_type") == kind]

    calls = of_type("tool_call")
    failures = of_type("tool_failure")
    spawns = of_type("agent_spawn")
    stops = of_type("agent_stop")
    session_ends = of_type("session_stop")
    notes = of_type("notification")

    tool_counts = Counter(ev.get("tool", "?") for ev in calls)
    category_counts = Counter(ev.get("category", "?") for ev in events)
    agent_counts = Counter(ev.get("agent", {}).get("agent_name", "?") for ev in events)
    ok_calls = len([ev for ev in calls if ev.get("success", True)])
    # Failed tool_call records plus every standalone tool_failure record.
    failed_calls = len(calls) - ok_calls + len(failures)

    rule = "=" * 60
    rows = (
        ("Total events:", len(events)),
        ("Tool calls:", f"{len(calls)} ({ok_calls} ok / {failed_calls} fail)"),
        ("Agent spawns:", len(spawns)),
        ("Agent stops:", len(stops)),
        ("Session stops:", len(session_ends)),
        ("Notifications:", len(notes)),
        ("Tool failures:", len(failures)),
        ("Time range:", f"{span_start} -> {span_end}"),
        ("Unique tools:", len(tool_counts)),
        ("Unique agents:", len(agent_counts)),
    )

    print(rule)
    print("  GENESIS OBSERVABILITY SUMMARY (v2 - Full Lifecycle)")
    print(rule)
    for label, value in rows:
        # Labels share a 19-column field so the values line up.
        print(f"  {label:<19}{value}")

    print()
    print("  Top 5 tools:")
    denominator = len(calls) or 1  # avoid dividing by zero with no tool calls
    for name, hits in tool_counts.most_common(5):
        share = hits / denominator * 100
        print(f"    {name:20s} {hits:4d} ({share:5.1f}%) {'#' * int(share / 2)}")

    for heading, counts in (("  By category:", category_counts), ("  By agent:", agent_counts)):
        print()
        print(heading)
        for key, hits in counts.most_common():
            print(f"    {key:20s} {hits:4d}")
    print(rule)


def cmd_recent(events: list, n: int):
    """Print the last *n* events, oldest first, one line per event.

    Each line shows the timestamp (first 19 chars), OK/FAIL status, agent
    name, tool name and a short tool-specific detail (command preview, file
    basename, task description/subject/status, or message recipient).

    Args:
        events: Parsed event dicts in log order.
        n: Number of trailing events to show; ``n <= 0`` shows none.
    """
    # Guard explicitly: events[-0:] is the *whole* list, and a negative n would
    # turn events[-n:] into "skip the first |n| events".
    recent = events[-n:] if n > 0 else []
    for e in recent:
        ts = e.get("timestamp", "?")[:19]
        tool = e.get("tool", "?")
        ok = "OK" if e.get("success", True) else "FAIL"
        agent = e.get("agent", {}).get("agent_name", "?")

        detail = ""
        if tool == "Bash":
            detail = e.get("command_preview", "")[:60]
        elif tool in ("Read", "Write", "Edit"):
            fp = e.get("file_path", "")
            detail = fp.split("/")[-1] if fp else ""
        elif tool == "Task":
            detail = e.get("description", "")[:40]
        elif tool in ("TaskCreate", "TaskUpdate"):
            detail = e.get("task_subject", "") or e.get("task_status", "")
        elif tool == "SendMessage":
            detail = f"→ {e.get('recipient', '?')} ({e.get('message_type', '?')})"

        print(f"[{ts}] [{ok:4s}] {agent:12s} | {tool:15s} | {detail}")


def cmd_by_tool(events: list):
    """Print a table of how often each ``tool`` value occurs across *events*.

    Every event is counted (events without a ``tool`` key appear as "?"),
    together with its share of all events and a bar of one ``#`` per two
    percentage points.
    """
    tally = Counter(ev.get("tool", "?") for ev in events)
    grand_total = len(events)

    print(f"\n{'Tool':<20s} {'Count':>6s} {'%':>7s}  {'Bar'}")
    print("-" * 55)
    for name, hits in tally.most_common():
        share = hits / grand_total * 100
        bar = "#" * int(share / 2)
        print(f"{name:<20s} {hits:>6d} {share:>6.1f}%  {bar}")


def cmd_by_category(events: list):
    """Print how many events fall into each ``category`` and their share.

    Events without a ``category`` key are grouped under "?". Rows are ordered
    from most to least frequent.
    """
    per_category = Counter(ev.get("category", "?") for ev in events)
    n_events = len(events)

    header = f"\n{'Category':<20s} {'Count':>6s} {'%':>7s}"
    print(header)
    print("-" * 35)
    for label, hits in per_category.most_common():
        print(f"{label:<20s} {hits:>6d} {hits / n_events * 100:>6.1f}%")


def cmd_by_agent(events: list):
    """Print each agent's event count and its five most frequent tools.

    Agents are listed busiest first (ties keep first-seen order). Events with
    no agent name are attributed to "primary".
    """
    per_agent_total = Counter()
    per_agent_tools = defaultdict(Counter)
    for event in events:
        who = event.get("agent", {}).get("agent_name", "primary")
        per_agent_total[who] += 1
        per_agent_tools[who][event.get("tool", "?")] += 1

    for who, n_events in per_agent_total.most_common():
        print(f"\n  Agent: {who} ({n_events} tool calls)")
        for tool_name, hits in per_agent_tools[who].most_common(5):
            print(f"    {tool_name}: {hits}")


def cmd_orchestration(events: list):
    """Print only the events whose category is "orchestration".

    Each line shows the timestamp (first 19 chars), the emitting agent, the
    tool name and a tool-specific summary (team name, task subject/status,
    message recipient, or spawned sub-agent). Tools without a summary rule
    get an empty one.
    """
    selected = [ev for ev in events if ev.get("category") == "orchestration"]
    if len(selected) == 0:
        print("No orchestration events recorded yet.")
        return

    def summarize(ev, tool_name):
        # Human-readable gist of one orchestration event, keyed on its tool.
        if tool_name == "TeamCreate":
            return f"Team: {ev.get('team_name', '?')}"
        if tool_name in ("TaskCreate", "TaskUpdate"):
            subject = ev.get("task_subject", "")
            return subject or f"status→{ev.get('task_status', '?')}"
        if tool_name == "SendMessage":
            return f"→ {ev.get('recipient', '?')} ({ev.get('message_type', '?')})"
        if tool_name == "Task":
            return f"Spawn: {ev.get('subagent_type', '?')} - {ev.get('description', '?')}"
        return ""

    print(f"\nOrchestration Events: {len(selected)}")
    print("-" * 60)
    for ev in selected:
        stamp = ev.get("timestamp", "?")[:19]
        tool_name = ev.get("tool", "?")
        who = ev.get("agent", {}).get("agent_name", "?")
        print(f"[{stamp}] {who:12s} | {tool_name:15s} | {summarize(ev, tool_name)}")


def cmd_lifecycle(events: list):
    """Show agent spawn/stop lifecycle events.

    Prints spawn/stop totals and an "Active now" estimate, then every
    ``agent_spawn`` / ``agent_stop`` event in log order.

    "Active now" is spawns minus stops, floored at zero. If stop events were
    logged without their matching spawns (e.g. the log began part-way through
    a run), the raw difference would go negative, which is meaningless.
    """
    lifecycle_events = [e for e in events if e.get("event_type") in ("agent_spawn", "agent_stop")]
    if not lifecycle_events:
        print("No agent lifecycle events recorded yet.")
        print("These appear when sub-agents are spawned/stopped via the Task tool.")
        return

    print(f"\nAgent Lifecycle Events: {len(lifecycle_events)}")
    print("-" * 70)

    spawns = [e for e in lifecycle_events if e.get("event_type") == "agent_spawn"]
    stops = [e for e in lifecycle_events if e.get("event_type") == "agent_stop"]
    active = max(len(spawns) - len(stops), 0)

    print(f"  Total spawns: {len(spawns)}")
    print(f"  Total stops:  {len(stops)}")
    print(f"  Active now:   {active}")
    print()

    for e in lifecycle_events:
        ts = e.get("timestamp", "?")[:19]
        etype = e.get("event_type", "?")

        if etype == "agent_spawn":
            name = e.get("spawned_agent_name", "?")
            atype = e.get("spawned_agent_type", "?")
            desc = e.get("description", "")[:50]
            print(f"  [{ts}] SPAWN  {name:15s} ({atype}) {desc}")
        elif etype == "agent_stop":
            name = e.get("stopped_agent_name", "?")
            reason = e.get("reason", "")[:50]
            print(f"  [{ts}] STOP   {name:15s} {reason}")


def cmd_errors(events: list):
    """List failed tool invocations, then tally the failures per tool.

    An event counts as a failure when it is a ``tool_failure`` event, or a
    ``tool_call`` event whose ``success`` value is falsy (a missing value
    counts as success).
    """
    def is_failure(ev):
        kind = ev.get("event_type")
        if kind == "tool_failure":
            return True
        return kind == "tool_call" and not ev.get("success", True)

    failures = [ev for ev in events if is_failure(ev)]
    if not failures:
        print("No tool failures recorded. All tools executing successfully.")
        return

    print(f"\nTool Failures: {len(failures)}")
    print("-" * 70)
    for ev in failures:
        stamp = ev.get("timestamp", "?")[:19]
        tool_name = ev.get("tool", "?")
        who = ev.get("agent", {}).get("agent_name", "?")
        preview = ev.get("error_preview", "")[:80]
        print(f"  [{stamp}] {who:12s} | {tool_name:15s}")
        if preview:
            print(f"    Error: {preview}")

    per_tool = Counter(ev.get("tool", "?") for ev in failures)
    print("\n  Error frequency by tool:")
    for tool_name, n_failed in per_tool.most_common():
        print(f"    {tool_name:20s} {n_failed}")


def cmd_sessions(events: list, sessions_file=None):
    """Show session lifecycle records, merged and sorted by timestamp.

    Two sources are combined:
      * ``session_stop`` events from *events* (session id and agent name come
        from the event's ``agent`` dict, plus an optional ``reason``);
      * records from the sessions JSONL file, read via their top-level
        ``timestamp``, ``type``, ``session_id`` and ``agent_name`` keys.

    Args:
        events: Parsed event dicts, as returned by ``load_events``.
        sessions_file: Optional path to read instead of ``SESSION_FILE_TOP``.
    """
    session_events = [e for e in events if e.get("event_type") == "session_stop"]

    # Also load the sessions JSONL log for richer data.
    path = SESSION_FILE_TOP if sessions_file is None else Path(sessions_file)
    sessions = []
    if path.exists():
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue  # tolerate malformed lines
                # Only JSON objects can be read with .get below.
                if isinstance(record, dict):
                    sessions.append(record)

    total = len(session_events) + len(sessions)
    if total == 0:
        print("No session events recorded yet.")
        print("Session tracking activates on session Stop events.")
        return

    print(f"\nSession Events: {total}")
    print("-" * 60)

    # Normalise both sources into one shape, then sort by timestamp.
    all_session_events = []
    for e in session_events:
        all_session_events.append({
            "timestamp": e.get("timestamp", "?")[:19],
            "type": "stop",
            "session_id": e.get("agent", {}).get("session_id", "?")[:12],
            "agent": e.get("agent", {}).get("agent_name", "?"),
            "reason": e.get("reason", ""),
        })
    for s in sessions:
        all_session_events.append({
            "timestamp": s.get("timestamp", "?")[:19],
            "type": s.get("type", "?"),
            "session_id": s.get("session_id", "?")[:12],
            "agent": s.get("agent_name", "?"),
        })

    all_session_events.sort(key=lambda x: x.get("timestamp", ""))

    for e in all_session_events:
        ts = e.get("timestamp", "?")
        etype = e.get("type", "?").upper()
        sid = e.get("session_id", "?")
        agent = e.get("agent", "?")
        reason = e.get("reason", "")[:40]
        extra = f" ({reason})" if reason else ""
        print(f"  [{ts}] {etype:12s} session={sid} agent={agent}{extra}")


def cmd_trace(events: list, trace_id: str):
    """Trace a correlation ID across all agents and events.

    This is the P0 gap fix: distributed tracing across agent boundaries. It
    shows the request flow from the parent agent through all child agents
    that share the correlation ID.

    Matching is a substring test against ``str(event["correlation_id"])``. If
    nothing matches, the same test is applied to each event's
    ``agent.session_id``. If there is still no match, up to 20 known
    correlation IDs (sorted) are listed with their event counts instead.

    Args:
        events: Parsed event dicts.
        trace_id: Full or partial correlation (or session) ID to look for.
    """

    def describe(ev, kind, tool):
        # One-line summary of an event for the timeline's last column.
        if kind == "tool_call":
            text = f"{tool}"
            if tool == "Task":
                text += f" -> {ev.get('description', '')[:30]}"
            elif tool in ("Read", "Write", "Edit"):
                file_path = ev.get("file_path", "")
                if file_path:
                    text += f" {file_path.split('/')[-1]}"
            return text
        if kind == "agent_spawn":
            return f"SPAWN {ev.get('spawned_agent_name', '?')} ({ev.get('spawned_agent_type', '?')})"
        if kind == "agent_stop":
            return f"STOP {ev.get('stopped_agent_name', '?')}"
        if kind == "tool_failure":
            return f"FAIL {tool}: {ev.get('error_preview', '')[:40]}"
        if kind == "notification":
            return f"NOTIFY {ev.get('notification_type', '?')}"
        return ""

    matched = [ev for ev in events if trace_id in str(ev.get("correlation_id", ""))]
    if not matched:
        # Fall back to the agent's session id.
        matched = [ev for ev in events
                   if trace_id in str(ev.get("agent", {}).get("session_id", ""))]

    if not matched:
        print(f"No events found for trace ID: {trace_id}")
        print("Available correlation IDs:")
        id_counts = Counter(ev.get("correlation_id") for ev in events if ev.get("correlation_id"))
        for cid in sorted(id_counts)[:20]:
            print(f"  {cid}  ({id_counts[cid]} events)")
        return

    banner = "=" * 70
    print(f"\n{banner}")
    print(f"  TRACE: {trace_id}")
    print(f"  Events: {len(matched)}")
    print(banner)

    per_agent = defaultdict(list)
    for ev in matched:
        per_agent[ev.get("agent", {}).get("agent_name", "primary")].append(ev)

    print(f"\n  Agents involved: {len(per_agent)}")
    for agent_name in sorted(per_agent):
        print(f"    {agent_name}: {len(per_agent[agent_name])} events")

    # Parent -> child edges as recorded on spawn events.
    spawns = [ev for ev in matched if ev.get("event_type") == "agent_spawn"]
    if spawns:
        print("\n  Agent Call Tree:")
        for ev in spawns:
            link = ev.get("parent_child", {})
            depth = link.get("child_depth", 0)
            indent = "    " + "  " * depth
            print(f"{indent}{link.get('parent_agent', '?')} -> {link.get('child_agent', '?')} (depth={depth})")

    print("\n  Timeline:")
    divider = "  " + "-" * 66
    print(divider)
    for ev in sorted(matched, key=lambda item: item.get("timestamp", "")):
        stamp = ev.get("timestamp", "?")[:19]
        kind = ev.get("event_type", "?")
        who = ev.get("agent", {}).get("agent_name", "?")
        tool = ev.get("tool", "")
        print(f"  [{stamp}] {who:12s} | {kind:14s} | {describe(ev, kind, tool)}")
    print(divider)


def cmd_metrics(metrics_file=None):
    """Show current metrics by pretty-printing the metrics JSON document.

    Args:
        metrics_file: Optional path to read instead of ``METRICS_FILE``.

    Prints a hint when the file does not exist yet. When the file cannot be
    decoded (e.g. truncated or corrupt JSON, invalid UTF-8), prints a short
    error to stderr instead of a traceback.
    """
    path = METRICS_FILE if metrics_file is None else Path(metrics_file)
    if not path.exists():
        print("No metrics file yet. Metrics accumulate after tool calls.")
        return

    try:
        with open(path, "r", encoding="utf-8") as f:
            metrics = json.load(f)
    except ValueError as err:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        print(f"Could not parse metrics file {path}: {err}", file=sys.stderr)
        return

    print(json.dumps(metrics, indent=2))


def main(argv=None):
    """Parse command-line flags and run exactly one report.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    ``--trace`` takes precedence, then ``--summary`` (also the fallback when
    no report is selected), then the remaining flags in the order checked
    below. A negative ``--recent`` value is rejected with a usage error.
    """
    parser = argparse.ArgumentParser(description="Genesis Observability Query Tool (v2)")
    parser.add_argument("--summary", action="store_true", help="Session summary")
    parser.add_argument("--recent", type=int, metavar="N", help="Last N events")
    parser.add_argument("--by-tool", action="store_true", help="Breakdown by tool")
    parser.add_argument("--by-category", action="store_true", help="Breakdown by category")
    parser.add_argument("--by-agent", action="store_true", help="Breakdown by agent")
    parser.add_argument("--orchestration", action="store_true", help="Orchestration events only")
    parser.add_argument("--lifecycle", action="store_true", help="Agent spawn/stop lifecycle")
    parser.add_argument("--errors", action="store_true", help="Failed tool calls only")
    parser.add_argument("--sessions", action="store_true", help="Session start/stop events")
    parser.add_argument("--timeline", action="store_true", help="Full timeline")
    parser.add_argument("--metrics", action="store_true", help="Current metrics")
    parser.add_argument("--trace", type=str, metavar="ID", help="Trace correlation ID across agents")
    args = parser.parse_args(argv)

    if args.recent is not None and args.recent < 0:
        parser.error("--recent N must be a non-negative integer")

    # Judge --recent by presence, not truthiness: `--recent 0` is an explicit
    # request (for zero events), not "no flag given".
    selected = args.recent is not None or any(
        value for name, value in vars(args).items() if name != "recent"
    )

    events = load_events()

    if args.trace:
        cmd_trace(events, args.trace)
    elif args.summary or not selected:
        cmd_summary(events)
    elif args.recent is not None:
        # Zero events requested: print nothing (events[-0:] would be everything).
        if args.recent > 0:
            cmd_recent(events, args.recent)
    elif args.by_tool:
        cmd_by_tool(events)
    elif args.by_category:
        cmd_by_category(events)
    elif args.by_agent:
        cmd_by_agent(events)
    elif args.orchestration:
        cmd_orchestration(events)
    elif args.lifecycle:
        cmd_lifecycle(events)
    elif args.errors:
        cmd_errors(events)
    elif args.sessions:
        cmd_sessions(events)
    elif args.timeline:
        cmd_recent(events, len(events))
    elif args.metrics:
        cmd_metrics()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
