#!/usr/bin/env python3
"""
Sunaiva Memory Health Monitor
=============================
Polls PostgreSQL, Redis, and Qdrant every 60 seconds.
Logs health metrics to KG entity file and alerts if any service is down.

Usage:
    python sunaiva_memory_monitor.py [--once] [--interval 60]

Output:
    E:\\genesis-system\\KNOWLEDGE_GRAPH\\entities\\memory_health_log.jsonl
    Console alerts on service degradation
"""

import argparse
import json
import logging
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

# ---------------------------------------------------------------------------
# Path setup - must run on Windows E: drive
# ---------------------------------------------------------------------------
# Root of the genesis system install; this monitor is pinned to the E: drive.
GENESIS_ROOT = Path("E:/genesis-system")
# Make the project's config helpers (elestio_config etc.) importable.
sys.path.insert(0, str(GENESIS_ROOT / "data" / "genesis-memory"))

# Append-only JSONL files consumed by knowledge-graph hooks.
KG_ENTITIES_DIR = GENESIS_ROOT / "KNOWLEDGE_GRAPH" / "entities"
HEALTH_LOG = KG_ENTITIES_DIR / "memory_health_log.jsonl"
ALERT_THRESHOLD = 3  # consecutive failures before loud alert

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [memory-monitor] %(levelname)s %(message)s",
)
logger = logging.getLogger("memory_monitor")

# Track consecutive failures per service
# (module-level mutable state; reset to 0 when a service reports healthy)
_failure_counts: dict = {"postgresql": 0, "redis": 0, "qdrant": 0}


# ---------------------------------------------------------------------------
# Health probes
# ---------------------------------------------------------------------------

def probe_postgresql() -> dict:
    """Probe PostgreSQL health.

    Opens a short-lived connection, counts the public-schema tables and
    reads the current database size.

    Returns:
        dict: ``{"status": "healthy", "latency_ms", "table_count",
        "db_size_mb"}`` on success, or ``{"status": "down", "latency_ms",
        "error"}`` on any failure (including a missing driver or config).
    """
    start = time.monotonic()
    conn = None
    try:
        # Imported lazily so a missing driver degrades to a "down" report
        # instead of crashing the whole monitor at import time.
        import psycopg2
        from elestio_config import PostgresConfig

        conn = psycopg2.connect(
            connect_timeout=5,
            **PostgresConfig.get_connection_params()
        )
        with conn.cursor() as cur:
            cur.execute("SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'public'")
            table_count = cur.fetchone()[0]
            cur.execute("SELECT pg_database_size(current_database())")
            db_size_bytes = cur.fetchone()[0]
        latency_ms = round((time.monotonic() - start) * 1000, 1)
        return {
            "status": "healthy",
            "latency_ms": latency_ms,
            "table_count": table_count,
            "db_size_mb": round(db_size_bytes / 1024 / 1024, 1),
        }
    except Exception as e:
        latency_ms = round((time.monotonic() - start) * 1000, 1)
        return {"status": "down", "latency_ms": latency_ms, "error": str(e)}
    finally:
        # Bug fix: the original closed the connection only on the success
        # path, leaking it whenever a query raised after connect succeeded.
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


def probe_redis() -> dict:
    """Probe Redis health.

    Pings the server and samples memory usage plus total key count.

    Returns:
        dict: ``{"status": "healthy"|"degraded", "latency_ms",
        "used_memory_mb", "key_count"}`` on success, or
        ``{"status": "down", "latency_ms", "error"}`` on any failure.
    """
    start = time.monotonic()
    client = None
    try:
        # Lazy imports: a missing client library yields a "down" report
        # rather than killing the monitor process.
        import redis as redis_lib
        from elestio_config import RedisConfig

        client = redis_lib.Redis(
            socket_connect_timeout=5,
            socket_timeout=5,
            **RedisConfig.get_connection_params()
        )
        pong = client.ping()
        info = client.info("memory")
        used_memory_mb = round(info.get("used_memory", 0) / 1024 / 1024, 1)
        key_count = client.dbsize()
        latency_ms = round((time.monotonic() - start) * 1000, 1)
        return {
            # ping() normally returns True; anything falsy is suspicious.
            "status": "healthy" if pong else "degraded",
            "latency_ms": latency_ms,
            "used_memory_mb": used_memory_mb,
            "key_count": key_count,
        }
    except Exception as e:
        latency_ms = round((time.monotonic() - start) * 1000, 1)
        return {"status": "down", "latency_ms": latency_ms, "error": str(e)}
    finally:
        # Bug fix: the original called client.close() only on success,
        # leaking the connection pool when any command raised.
        if client is not None:
            try:
                client.close()
            except Exception:
                pass


def probe_qdrant() -> dict:
    """Probe Qdrant health via its REST API.

    Lists all collections and, when the configured collection exists,
    additionally fetches its vector count (best effort).

    Returns:
        dict: ``{"status": "healthy", "latency_ms", "collection_count",
        "genesis_vectors_count"}`` on success (vector count may be None),
        or ``{"status": "down", "latency_ms", "error"}`` on failure.
    """
    start = time.monotonic()
    try:
        import requests
        from elestio_config import QdrantConfig

        cfg = QdrantConfig()
        auth_headers = {"api-key": cfg.api_key}

        listing = requests.get(
            f"{cfg.url}/collections",
            headers=auth_headers,
            timeout=8,
        )
        listing.raise_for_status()
        collections = listing.json().get("result", {}).get("collections", [])
        latency_ms = round((time.monotonic() - start) * 1000, 1)

        # Best-effort detail lookup for the configured collection only.
        vector_count = None
        target = next(
            (c for c in collections if c.get("name") == cfg.collection_name),
            None,
        )
        if target is not None:
            try:
                detail = requests.get(
                    f"{cfg.url}/collections/{cfg.collection_name}",
                    headers=auth_headers,
                    timeout=8,
                )
                if detail.status_code == 200:
                    vector_count = (
                        detail.json().get("result", {}).get("vectors_count", None)
                    )
            except Exception:
                # Detail failures never downgrade the overall probe result.
                pass

        return {
            "status": "healthy",
            "latency_ms": latency_ms,
            "collection_count": len(collections),
            "genesis_vectors_count": vector_count,
        }
    except Exception as e:
        latency_ms = round((time.monotonic() - start) * 1000, 1)
        return {"status": "down", "latency_ms": latency_ms, "error": str(e)}


# ---------------------------------------------------------------------------
# Alerting
# ---------------------------------------------------------------------------

def _alert(service: str, result: dict):
    """Log a service-down alert and append it to the alert JSONL file.

    Args:
        service: one of the keys of ``_failure_counts``
            ("postgresql" / "redis" / "qdrant").
        result: probe result dict; reads "status" and (optionally) "error".
    """
    count = _failure_counts[service]
    # Escalate only after ALERT_THRESHOLD consecutive failures.
    severity = "CRITICAL" if count >= ALERT_THRESHOLD else "WARNING"
    msg = (
        f"[{severity}] {service.upper()} is {result['status'].upper()} "
        f"(consecutive failures: {count}). Error: {result.get('error', 'unknown')}"
    )
    logger.error(msg)
    # Write alert to a dedicated alert file so hooks can pick it up
    alert_file = KG_ENTITIES_DIR / "memory_health_alerts.jsonl"
    alert_entry = {
        # Aware UTC clock (utcnow() is deprecated); keep the legacy
        # "...Z"-suffixed naive-ISO format used elsewhere in the logs.
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
        "service": service,
        "severity": severity,
        "status": result["status"],
        "consecutive_failures": count,
        "error": result.get("error", ""),
    }
    try:
        # Bug fix: on a fresh install the first alert can fire before
        # run_health_check() creates the entities directory, so ensure it
        # exists here too.
        KG_ENTITIES_DIR.mkdir(parents=True, exist_ok=True)
        with open(alert_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(alert_entry) + "\n")
    except Exception as exc:
        logger.error("Failed to write alert: %s", exc)


# ---------------------------------------------------------------------------
# Core monitoring loop
# ---------------------------------------------------------------------------

def run_health_check() -> dict:
    """Run one full health check across all services.

    Probes PostgreSQL, Redis, and Qdrant; updates the consecutive-failure
    counters (alerting on failures, logging recoveries); appends a JSON
    snapshot to HEALTH_LOG; and logs a one-line console summary.

    Returns:
        dict snapshot with "timestamp", "overall", per-service "services"
        results, and a copy of the current "failure_counts".
    """
    # Aware UTC clock (utcnow() is deprecated since 3.12); keep the legacy
    # naive-ISO-plus-"Z" string format so downstream parsers are unaffected.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    logger.info("Running health check...")

    # Note: local names avoid shadowing the "redis" package imported by the
    # probe functions.
    results = {
        "postgresql": probe_postgresql(),
        "redis": probe_redis(),
        "qdrant": probe_qdrant(),
    }

    # Update failure counters; alert on anything not healthy.
    for svc, result in results.items():
        if result["status"] == "healthy":
            if _failure_counts[svc] > 0:
                logger.info("[RECOVERED] %s is back online.", svc.upper())
            _failure_counts[svc] = 0
        else:
            _failure_counts[svc] += 1
            _alert(svc, result)

    # Overall health: "healthy" only when every service is healthy.
    # (The original if/elif/else mapped every non-healthy mix to "degraded"
    # anyway, so this collapses the dead branch without changing behavior.)
    overall = (
        "healthy"
        if all(r["status"] == "healthy" for r in results.values())
        else "degraded"
    )

    snapshot = {
        "timestamp": timestamp,
        "overall": overall,
        "services": results,
        "failure_counts": dict(_failure_counts),
    }

    # Log to KG entity file (best effort: a write failure must not stop
    # monitoring).
    KG_ENTITIES_DIR.mkdir(parents=True, exist_ok=True)
    try:
        with open(HEALTH_LOG, "a", encoding="utf-8") as f:
            f.write(json.dumps(snapshot) + "\n")
    except Exception as e:
        logger.error("Failed to write health log: %s", e)

    # Console summary
    pg, rds, qd = results["postgresql"], results["redis"], results["qdrant"]
    status_line = (
        f"PG={pg['status']}({pg.get('latency_ms', '?')}ms) | "
        f"Redis={rds['status']}({rds.get('latency_ms', '?')}ms) | "
        f"Qdrant={qd['status']}({qd.get('latency_ms', '?')}ms) | "
        f"Overall={overall}"
    )
    logger.info(status_line)

    return snapshot


def monitor_loop(interval_seconds: int = 60):
    """Run health checks forever, sleeping between iterations.

    Args:
        interval_seconds: pause between consecutive checks.

    A KeyboardInterrupt (during the check or the sleep) stops the loop
    cleanly; any other exception from a check is logged and the loop
    continues after the normal sleep.
    """
    logger.info("Sunaiva Memory Health Monitor started (interval=%ds)", interval_seconds)
    logger.info("Health log: %s", HEALTH_LOG)

    while True:
        try:
            # Inner guard: ordinary errors must not kill the monitor, and
            # the sleep still happens so we keep a steady cadence.
            try:
                run_health_check()
            except Exception as exc:  # KeyboardInterrupt passes through
                logger.error("Unexpected error in health check: %s", exc)
            time.sleep(interval_seconds)
        except KeyboardInterrupt:
            logger.info("Monitor stopped by user.")
            break


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: run a single check (--once) or the endless loop."""
    cli = argparse.ArgumentParser(description="Sunaiva Memory Health Monitor")
    cli.add_argument(
        "--once", action="store_true",
        help="Run one health check and exit (for testing/cron use)",
    )
    cli.add_argument(
        "--interval", type=int, default=60,
        help="Check interval in seconds (default: 60)",
    )
    opts = cli.parse_args()

    if not opts.once:
        monitor_loop(interval_seconds=opts.interval)
        return

    # One-shot mode: print the snapshot and use the exit code to signal
    # health (0 = healthy, 1 = anything else) for cron/CI callers.
    snapshot = run_health_check()
    print(json.dumps(snapshot, indent=2))
    sys.exit(0 if snapshot["overall"] == "healthy" else 1)


# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()
