#!/usr/bin/env python3
"""
Genesis Session Heartbeat & State Backup (Layer 5 of Defense System)
====================================================================
PostToolUse hook that periodically saves a structured session state summary
to hive/session_backups/ and hive/session_recovery/LATEST.md.

This runs on EVERY tool call but only writes state every N tool calls
or when context usage crosses critical thresholds.

Designed to be lightweight: ~5ms on skip, ~50ms on write.

Why this exists:
  Session 7e3fae84 crashed at 20MB/4436 messages due to thinking block
  corruption after autocompact. The entire session was unrecoverable.
  This hook ensures that even if a session dies, the next session can
  instantly recover the previous session's operational state.
"""

import sys
import json
import os
import time
from datetime import datetime, timezone
from pathlib import Path

# Directories
BACKUP_DIR = Path("/mnt/e/genesis-system/hive/session_backups")
RECOVERY_DIR = Path("/mnt/e/genesis-system/hive/session_recovery")
STATE_DIR = Path("/mnt/e/genesis-system/data/context_state")
EVENTS_DIR = Path("/mnt/e/genesis-system/data/observability")

# How often to write state (every N tool calls)
HEARTBEAT_INTERVAL = 25  # Write every 25 tool calls
CRITICAL_THRESHOLD = 70  # Always write when context > 70%
COUNTER_FILE = STATE_DIR / "heartbeat_counter.json"


def get_counter() -> dict:
    """Return the persisted heartbeat counter state.

    Falls back to a zeroed default on any failure (missing file, bad
    JSON, unreadable path) so the hook never crashes the session.
    """
    default = {"count": 0, "last_backup": ""}
    try:
        if not COUNTER_FILE.exists():
            return default
        return json.loads(COUNTER_FILE.read_text())
    except Exception:
        return default


def update_counter(count: int, last_backup: str = ""):
    """Persist the heartbeat counter state; any write failure is ignored."""
    payload = {"count": count, "last_backup": last_backup}
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
        COUNTER_FILE.write_text(json.dumps(payload))
    except Exception:
        pass


def get_context_usage() -> float:
    """Return the current context usage percentage from the statusline state.

    Any failure (missing file, bad JSON, missing key) yields 0 so the
    caller can always treat the result as a valid percentage.
    """
    try:
        state_path = STATE_DIR / "current.json"
        if not state_path.exists():
            return 0
        with open(state_path, "r") as fh:
            return json.load(fh).get("used_percentage", 0)
    except Exception:
        return 0


def get_recent_tool_calls(limit: int = 10, events_file=None) -> list:
    """Return up to ``limit`` of the most recent tool_call events.

    Reads only the tail of the observability JSONL log (last ~50KB) and
    scans it newest-first so that ``limit`` bounds tool_call events, not
    raw lines. The original sliced the last ``limit`` lines *before*
    filtering, so interleaved non-tool events silently shrank the result.

    Args:
        limit: Maximum number of tool-call summaries to return.
        events_file: Optional path override for the events log; defaults
            to the module-level observability log (backward compatible).

    Returns:
        List of {"tool", "category", "timestamp"} dicts in chronological
        (oldest-first) order; empty list on any I/O or parse failure.
    """
    events = []
    try:
        path = events_file if events_file is not None else EVENTS_DIR / "events.jsonl"
        if not path.exists():
            return events

        with open(path, "rb") as f:
            f.seek(0, 2)
            file_size = f.tell()
            # The last 50KB comfortably holds more than `limit` events.
            read_size = min(file_size, 50000)
            f.seek(file_size - read_size)
            tail = f.read().decode("utf-8", errors="replace")

        lines = tail.strip().split("\n")
        # When the read started mid-file, the first line is almost
        # certainly a truncated record -- drop it explicitly rather than
        # relying on the JSON parser to choke on it.
        if read_size < file_size and lines:
            lines = lines[1:]

        # Walk newest-first so `limit` counts tool_call events themselves.
        for line in reversed(lines):
            if len(events) >= limit:
                break
            try:
                event = json.loads(line.strip())
            except Exception:
                continue
            if event.get("event_type") == "tool_call":
                events.append({
                    "tool": event.get("tool", "?"),
                    "category": event.get("category", "?"),
                    "timestamp": event.get("timestamp", ""),
                })
        events.reverse()  # restore chronological order for the backup doc
    except Exception:
        pass
    return events


def get_active_agents() -> list:
    """Summarize agent activity from the observability metrics file.

    Returns a single-element list (a list for schema compatibility with
    the backup document) of spawn/stop/active counts, or an empty list
    when the metrics file is missing or unreadable.
    """
    try:
        metrics_path = EVENTS_DIR / "metrics.json"
        if not metrics_path.exists():
            return []
        metrics = json.loads(metrics_path.read_text())
        summary = {
            "active_count": metrics.get("agents_active", 0),
            "total_spawns": metrics.get("agent_spawns", 0),
            "total_stops": metrics.get("agent_stops", 0),
        }
        return [summary]
    except Exception:
        return []


def get_compaction_count() -> int:
    """Count compaction events (one JSONL line each); 0 on any failure."""
    total = 0
    try:
        log_path = STATE_DIR / "compaction_log.jsonl"
        if log_path.exists():
            with open(log_path, "r") as fh:
                for _ in fh:
                    total += 1
    except Exception:
        return 0
    return total


def write_session_state(session_id: str, context_pct: float):
    """Write a structured session-state backup and refresh the recovery doc.

    Produces a timestamped JSON snapshot in BACKUP_DIR, rewrites the
    human-readable LATEST.md, and prunes old snapshots to the 20 most
    recent. All I/O failures are swallowed -- this hook must never crash
    the session it is protecting (the mkdir was previously outside the
    try and could raise).

    Args:
        session_id: Current session identifier (first 8 chars used in the
            backup filename).
        context_pct: Current context usage percentage; drives the
            periodic/critical heartbeat_type label.
    """
    now = datetime.now(timezone.utc)
    timestamp = now.strftime("%Y%m%d_%H%M%S")

    recent_tools = get_recent_tool_calls(15)
    agents = get_active_agents()
    compactions = get_compaction_count()

    state = {
        "timestamp": now.isoformat(),
        "session_id": session_id,
        "context_usage_pct": round(context_pct, 1),
        "compaction_count": compactions,
        "recent_tools": recent_tools,
        "agent_status": agents,
        "heartbeat_type": "periodic" if context_pct < CRITICAL_THRESHOLD else "critical",
    }

    # Write JSON backup (best-effort: mkdir and write both inside the try).
    try:
        BACKUP_DIR.mkdir(parents=True, exist_ok=True)
        backup_file = BACKUP_DIR / f"heartbeat_{session_id[:8]}_{timestamp}.json"
        with open(backup_file, "w") as f:
            json.dump(state, f, indent=2)
    except Exception:
        pass

    # Write LATEST.md for human-readable recovery
    write_recovery_doc(state)

    # Prune old backups (keep last 20). Sort by mtime, not filename: names
    # begin with the session-id prefix, so a lexical sort would order
    # backups by session rather than by age and could delete fresh files.
    try:
        backups = sorted(
            BACKUP_DIR.glob("heartbeat_*.json"),
            key=lambda p: p.stat().st_mtime,
        )
        for old_file in backups[:-20]:
            old_file.unlink()
    except Exception:
        pass


def write_recovery_doc(state: dict):
    """Render the human-readable LATEST.md session recovery document.

    Best-effort: the directory creation and file write are both inside a
    try/except so that a read-only or missing filesystem cannot crash the
    heartbeat hook (previously the mkdir sat outside the try and could
    raise out of the hook).

    Args:
        state: Snapshot dict as produced by write_session_state.
    """
    ts = state.get("timestamp", "unknown")
    sid = state.get("session_id", "unknown")
    ctx = state.get("context_usage_pct", 0)
    compactions = state.get("compaction_count", 0)
    agents = state.get("agent_status", [])
    recent = state.get("recent_tools", [])

    # Format recent tools (at most the last 10) as markdown bullets.
    tool_lines = []
    for t in recent[-10:]:
        tool_lines.append(f"- {t.get('tool', '?')} ({t.get('category', '?')}) @ {t.get('timestamp', '?')}")

    agent_info = ""
    if agents:
        a = agents[0]
        agent_info = f"Active: {a.get('active_count', 0)}, Total spawns: {a.get('total_spawns', 0)}, Total stops: {a.get('total_stops', 0)}"

    md = f"""# Session Recovery State
## Auto-generated by Session Heartbeat Hook

**Last Updated**: {ts}
**Session ID**: {sid}
**Context Usage**: {ctx}%
**Compaction Events**: {compactions}

## Agent Status
{agent_info}

## Recent Tool Activity
{chr(10).join(tool_lines) if tool_lines else "No recent tool activity recorded."}

## Recovery Instructions
1. Read this file to understand where the previous session left off
2. Check `/mnt/e/genesis-system/hive/session_backups/` for detailed JSON state
3. Check `/mnt/e/genesis-system/data/context_backups/` for pre-compaction transcripts
4. Read MEMORY.md for strategic context and war room status
5. Check `/mnt/e/genesis-system/data/observability/events.jsonl` for full event trail

## Previous Session Death Cause
If this file exists but the session is gone, the session likely crashed.
Common causes:
- Thinking block corruption after autocompact (see SESSION_CRASH_POSTMORTEM.md)
- Context overflow without graceful handoff
- API rate limit or network error during critical operation
- Multiple agent notifications flooding during autocompact

## Quick Recovery Checklist
- [ ] Read MEMORY.md war room status
- [ ] Check hive/session_backups/ for latest heartbeat JSON
- [ ] Check if any agents were running (agent_status above)
- [ ] Resume from last known good state
"""

    try:
        RECOVERY_DIR.mkdir(parents=True, exist_ok=True)
        recovery_file = RECOVERY_DIR / "LATEST.md"
        with open(recovery_file, "w") as f:
            f.write(md)
    except Exception:
        pass


def main():
    """PostToolUse hook entry point.

    Reads the hook payload from stdin (required by the hook protocol,
    though the payload content is unused), bumps the heartbeat counter,
    and writes a state backup every HEARTBEAT_INTERVAL tool calls -- or
    on every call once context usage crosses CRITICAL_THRESHOLD. Always
    prints an empty JSON object so the hook injects no context.
    """
    try:
        # Payload is intentionally discarded; parsing just validates that
        # stdin carried well-formed hook JSON.
        json.loads(sys.stdin.read())
    except Exception:
        # NOTE: JSONDecodeError is an Exception subclass, so a single
        # broad clause covers both parse and read failures.
        print(json.dumps({}))
        return

    # Get session info
    session_id = os.environ.get("CLAUDE_SESSION_ID", "unknown")

    # Get context usage
    context_pct = get_context_usage()

    # Bump the persisted tool-call counter.
    counter = get_counter()
    count = counter.get("count", 0) + 1

    # Write on the periodic interval...
    should_write = count >= HEARTBEAT_INTERVAL
    if should_write:
        count = 0

    # ...and on every call once context usage is critical.
    if context_pct >= CRITICAL_THRESHOLD:
        should_write = True

    if should_write:
        write_session_state(session_id, context_pct)
        update_counter(0, datetime.now(timezone.utc).isoformat())
    else:
        update_counter(count)

    # No additional context injection (keep lightweight)
    print(json.dumps({}))


if __name__ == "__main__":
    main()
