#!/usr/bin/env python3
"""
Genesis Elestio Resilience Layer — Production Zero-Downtime Infrastructure
==========================================================================
Automated health monitoring, backup snapshots, connection pooling, and
auto-recovery for all Elestio managed services.

Usage:
    python3 scripts/elestio_resilience.py check       # Health check all services
    python3 scripts/elestio_resilience.py backup       # Create snapshots/backups
    python3 scripts/elestio_resilience.py fix          # Auto-fix known issues
    python3 scripts/elestio_resilience.py full         # All of the above
    python3 scripts/elestio_resilience.py report       # Generate full report

Cron (recommended):
    */5 * * * * python3 /mnt/e/genesis-system/scripts/elestio_resilience.py check >> /tmp/elestio_resilience.log 2>&1
    0 */6 * * * python3 /mnt/e/genesis-system/scripts/elestio_resilience.py backup >> /tmp/elestio_resilience.log 2>&1
"""

import json
import os
import sys
import time
import logging
from datetime import datetime, timezone
from pathlib import Path

# ─── CONFIGURATION ──────────────────────────────────────────────────────
GENESIS_DIR = Path("/mnt/e/genesis-system")
SECRETS_FILE = GENESIS_DIR / "config" / "secrets.env"
HEALTH_LOG = Path("/tmp/elestio_resilience.log")
STATE_FILE = Path("/tmp/elestio_resilience_state.json")
ALERT_FILE = GENESIS_DIR / "data" / "alerts" / "elestio_health.json"

# Load secrets
def load_secrets(path=None):
    """Parse a KEY=VALUE env file into a dict of secrets.

    Blank lines and full-line ``#`` comments are skipped; lines without
    ``=`` are ignored.  Surrounding single or double quotes on values are
    stripped.

    Args:
        path: Optional path to the env file.  Defaults to SECRETS_FILE
            (resolved at call time, not at definition time).

    Returns:
        dict mapping secret names to their unquoted string values;
        an empty dict when the file does not exist.
    """
    env_path = SECRETS_FILE if path is None else Path(path)
    secrets = {}
    if env_path.exists():
        for line in env_path.read_text().splitlines():
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if '=' in line:
                key, _, val = line.partition('=')
                secrets[key.strip()] = val.strip().strip("'\"")
    return secrets

# Parsed once at import time; every service config below reads from it.
SECRETS = load_secrets()

# Service configs
# Connection details for each managed Elestio service.  Hosts/URLs fall back
# to hard-coded Elestio defaults when the corresponding secret is missing;
# credentials default to "" when absent, so a missing secret shows up as a
# failed health check rather than a crash at import time.
SERVICES = {
    "postgresql": {
        "host": "postgresql-genesis-u50607.vm.elestio.app",
        "port": int(SECRETS.get("SUNAIVA_DB_PORT", "25432")),
        "user": "postgres",
        "password": SECRETS.get("SUNAIVA_DB_PASSWORD", ""),
        "database": "postgres",
    },
    "qdrant": {
        "url": "https://qdrant-b3knu-u50607.vm.elestio.app:6333",
        # Prefer the Genesis-specific key; fall back to the generic one.
        "api_key": SECRETS.get("GENESIS_QDRANT_API_KEY", SECRETS.get("QDRANT_API_KEY", "")),
    },
    "redis": {
        "host": SECRETS.get("GENESIS_REDIS_HOST", "redis-genesis-u50607.vm.elestio.app"),
        "port": int(SECRETS.get("GENESIS_REDIS_PORT", "6379")),
        "password": SECRETS.get("GENESIS_REDIS_PASSWORD", ""),
    },
    "n8n": {
        "url": SECRETS.get("N8N_BASE_URL", "https://n8n-genesis-u50607.vm.elestio.app"),
        "api_key": SECRETS.get("N8N_API_KEY", ""),
    },
    "browserless": {
        "url": SECRETS.get("BROWSERLESS_URL", "https://browserless-genesis-u50607.vm.elestio.app"),
        "token": SECRETS.get("BROWSERLESS_TOKEN", ""),
    },
}

# Logs go to stdout/stderr; the recommended cron lines in the module
# docstring redirect them into HEALTH_LOG.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("elestio_resilience")


# ─── HEALTH CHECKS ─────────────────────────────────────────────────────

def check_postgresql():
    """Check PostgreSQL connectivity and basic health.

    Probes connectivity, total connection count, buffer cache hit ratio,
    dead-tuple bloat, and long-running queries.

    Returns:
        dict with "service", "status" ("healthy"/"degraded"/"unreachable"),
        the collected metrics, and a list of human-readable "issues".
        Never raises; any failure is reported as status "unreachable".
    """
    conn = None
    try:
        import psycopg2
        cfg = SERVICES["postgresql"]
        conn = psycopg2.connect(
            host=cfg["host"], port=cfg["port"],
            user=cfg["user"], password=cfg["password"],
            database=cfg["database"], connect_timeout=10,
        )
        cur = conn.cursor()

        # Basic connectivity probe.
        cur.execute("SELECT 1")

        # Connection count (the "/100" in the issue text reflects the
        # assumed server max_connections — confirm against the server).
        cur.execute("SELECT count(*) FROM pg_stat_activity")
        active_conns = cur.fetchone()[0]

        # Buffer cache hit ratio for the main database.
        cur.execute("""
            SELECT CASE WHEN blks_hit + blks_read = 0 THEN 1.0
                   ELSE blks_hit::float / (blks_hit + blks_read) END
            FROM pg_stat_database WHERE datname = 'postgres'
        """)
        cache_hit = cur.fetchone()[0]

        # Dead tuples (bloat indicator).
        cur.execute("""
            SELECT sum(n_dead_tup) FROM pg_stat_user_tables
        """)
        dead_tuples = cur.fetchone()[0] or 0

        # Queries active for more than 5 minutes.
        cur.execute("""
            SELECT count(*) FROM pg_stat_activity
            WHERE state = 'active' AND query_start < now() - interval '5 minutes'
        """)
        long_queries = cur.fetchone()[0]

        cur.close()
        conn.close()
        conn = None  # Closed cleanly; nothing for the except path to release.

        issues = []
        if active_conns > 80:
            issues.append(f"High connection count: {active_conns}/100")
        if cache_hit < 0.95:
            issues.append(f"Low cache hit ratio: {cache_hit:.2%}")
        if dead_tuples > 10000:
            issues.append(f"High dead tuples: {dead_tuples:,} (needs VACUUM)")
        if long_queries > 0:
            issues.append(f"{long_queries} queries running >5 minutes")
        status = "degraded" if issues else "healthy"

        return {
            "service": "postgresql",
            "status": status,
            "connections": active_conns,
            "cache_hit_ratio": round(cache_hit, 4),
            "dead_tuples": dead_tuples,
            "long_running_queries": long_queries,
            "issues": issues,
        }
    except Exception as e:
        # Release the connection if a query failed mid-flight (the original
        # code leaked it on any error after connect()).
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass
        return {"service": "postgresql", "status": "unreachable", "error": str(e), "issues": [str(e)]}


def check_qdrant():
    """Check Qdrant connectivity, collection health, and indexing status.

    Probes the root endpoint (version), the collection list, the main
    ``genesis_memories`` collection, and the snapshot list; flags indexing
    gaps and missing/stale backups.

    Returns:
        dict with "service", "status", the collected metrics, and "issues".
        Never raises; any failure is reported as status "unreachable".
    """
    try:
        import urllib.request
        import ssl
        # NOTE(review): TLS verification is intentionally disabled here —
        # confirm whether the Elestio endpoint's cert chain can be trusted.
        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE

        cfg = SERVICES["qdrant"]
        url = cfg["url"]
        headers = {"Api-Key": cfg["api_key"]}

        # Root endpoint: basic health + version.
        req = urllib.request.Request(f"{url}/", headers=headers)
        with urllib.request.urlopen(req, context=ctx, timeout=15) as resp:
            health = json.loads(resp.read())

        # All collections.
        req = urllib.request.Request(f"{url}/collections", headers=headers)
        with urllib.request.urlopen(req, context=ctx, timeout=15) as resp:
            colls = json.loads(resp.read())["result"]["collections"]

        # Main collection details.
        req = urllib.request.Request(f"{url}/collections/genesis_memories", headers=headers)
        with urllib.request.urlopen(req, context=ctx, timeout=15) as resp:
            main_coll = json.loads(resp.read())["result"]

        points = main_coll.get("points_count", 0)
        indexed = main_coll.get("indexed_vectors_count", 0)
        index_ratio = indexed / points if points > 0 else 1.0

        # Snapshot inventory (backups).
        req = urllib.request.Request(f"{url}/snapshots", headers=headers)
        with urllib.request.urlopen(req, context=ctx, timeout=15) as resp:
            snaps = json.loads(resp.read())["result"]

        issues = []
        if index_ratio < 0.9:
            issues.append(f"Indexing gap: {indexed}/{points} vectors indexed ({index_ratio:.0%})")
        if not snaps:
            issues.append("No snapshots exist — data not backed up")
        else:
            newest = max(snaps, key=lambda s: s.get("creation_time", ""))
            snap_age_str = newest.get("creation_time", "")
            if snap_age_str:
                try:
                    # Handle both naive and aware datetime strings.
                    clean = snap_age_str.replace("Z", "+00:00")
                    if "+" not in clean and "-" not in clean[10:]:
                        clean += "+00:00"
                    snap_time = datetime.fromisoformat(clean)
                    if snap_time.tzinfo is None:
                        snap_time = snap_time.replace(tzinfo=timezone.utc)
                    age_hours = (datetime.now(timezone.utc) - snap_time).total_seconds() / 3600
                    if age_hours > 24:
                        issues.append(f"Latest snapshot is {age_hours:.0f}h old (>24h)")
                except Exception:
                    pass  # Best-effort: skip age check if parsing fails.

        status = "degraded" if issues else "healthy"

        return {
            "service": "qdrant",
            "status": status,
            "version": health.get("version", "?"),
            "collections": len(colls),
            "main_collection_points": points,
            "main_collection_indexed": indexed,
            "snapshots": len(snaps),
            "issues": issues,
        }
    except Exception as e:
        return {"service": "qdrant", "status": "unreachable", "error": str(e), "issues": [str(e)]}


def check_redis():
    """Check Redis connectivity, memory, and connection health.

    Collects INFO stats, the client list, and key count; flags zombie
    (long-idle) connections, high client counts, and memory fragmentation.

    Returns:
        dict with "service", "status", the collected metrics, and "issues".
        Never raises; any failure is reported as status "unreachable".
    """
    try:
        import redis
        cfg = SERVICES["redis"]
        r = redis.Redis(
            host=cfg["host"], port=cfg["port"],
            password=cfg["password"], decode_responses=True,
            socket_timeout=10,
        )
        try:
            info = r.info()
            client_list = r.client_list()
            total_keys = r.dbsize()
        finally:
            # Always release our connection, even if a command fails
            # (the original leaked it on any error).
            r.close()

        connected = info["connected_clients"]
        memory_used = info["used_memory"]

        # Zombie connections: idle for more than 5 minutes.
        zombies = sum(1 for c in client_list if int(c.get("idle", 0)) > 300)

        # Allocator fragmentation (RSS / used memory).
        frag_ratio = info.get("mem_fragmentation_ratio", 1.0)

        issues = []
        if zombies > 10:
            issues.append(f"{zombies} zombie connections (idle >5min)")
        if connected > 500:
            issues.append(f"High connection count: {connected}")
        if frag_ratio > 2.0:
            issues.append(f"High memory fragmentation: {frag_ratio:.1f}x")
        status = "degraded" if issues else "healthy"

        return {
            "service": "redis",
            "status": status,
            "version": info["redis_version"],
            "connected_clients": connected,
            "zombie_connections": zombies,
            "memory_used_mb": round(memory_used / 1024 / 1024, 1),
            "total_keys": total_keys,
            "fragmentation_ratio": round(frag_ratio, 2),
            "issues": issues,
        }
    except Exception as e:
        return {"service": "redis", "status": "unreachable", "error": str(e), "issues": [str(e)]}


def check_n8n():
    """Check n8n API reachability and workflow health.

    Lists workflows through the REST API and reports active/inactive
    counts; "unreachable" on any API failure.
    """
    try:
        import urllib.request
        cfg = SERVICES["n8n"]
        base_url = cfg["url"]
        api_headers = {"X-N8N-API-KEY": cfg["api_key"]}

        request = urllib.request.Request(f"{base_url}/api/v1/workflows", headers=api_headers)
        with urllib.request.urlopen(request, timeout=15) as response:
            payload = json.loads(response.read())

        workflows = payload.get("data", [])
        active_count = len([w for w in workflows if w.get("active")])
        inactive_count = len(workflows) - active_count

        status = "healthy"
        issues = []
        if not workflows:
            issues.append("No workflows found")
        if active_count == 0:
            issues.append("No active workflows")
            status = "degraded"

        return {
            "service": "n8n",
            "status": status,
            "total_workflows": len(workflows),
            "active_workflows": active_count,
            "inactive_workflows": inactive_count,
            "issues": issues,
        }
    except Exception as e:
        return {"service": "n8n", "status": "unreachable", "error": str(e), "issues": [str(e)]}


def check_browserless():
    """Check Browserless availability via the /pressure endpoint.

    Returns:
        dict with "service", "status", session/capacity metrics, and
        "issues".  Never raises; failures are reported as "unreachable".
    """
    try:
        import urllib.request
        cfg = SERVICES["browserless"]
        url = cfg["url"]
        token = cfg["token"]

        # NOTE(review): the token travels in the query string, so it may
        # appear in proxy/server logs — consider a header if supported.
        req = urllib.request.Request(f"{url}/pressure?token={token}")
        with urllib.request.urlopen(req, timeout=15) as resp:
            data = json.loads(resp.read())

        pressure = data.get("pressure", {})
        available = pressure.get("isAvailable", False)
        running = pressure.get("running", 0)
        queued = pressure.get("queued", 0)
        max_conc = pressure.get("maxConcurrent", 10)

        issues = []
        if not available:
            issues.append("Service reports unavailable")
        if running >= max_conc:
            issues.append(f"At capacity: {running}/{max_conc}")
        # Consistent with the other checkers: any issue degrades the status.
        # (Previously an at-capacity-but-available service still reported
        # "healthy" despite having an issue listed.)
        status = "degraded" if issues else "healthy"

        return {
            "service": "browserless",
            "status": status,
            "available": available,
            "running_sessions": running,
            "queued": queued,
            "max_concurrent": max_conc,
            "issues": issues,
        }
    except Exception as e:
        return {"service": "browserless", "status": "unreachable", "error": str(e), "issues": [str(e)]}


# ─── BACKUP / SNAPSHOT ──────────────────────────────────────────────────

def backup_qdrant():
    """Create Qdrant full snapshot.

    POSTs to the /snapshots endpoint and logs the resulting snapshot
    name and size; returns a result dict either way.
    """
    try:
        import urllib.request
        import ssl
        # TLS verification disabled, matching the other Qdrant calls here.
        tls_ctx = ssl.create_default_context()
        tls_ctx.check_hostname = False
        tls_ctx.verify_mode = ssl.CERT_NONE

        qdrant_cfg = SERVICES["qdrant"]
        base = qdrant_cfg["url"]
        auth_headers = {"Api-Key": qdrant_cfg["api_key"]}

        request = urllib.request.Request(f"{base}/snapshots", method="POST", headers=auth_headers)
        with urllib.request.urlopen(request, context=tls_ctx, timeout=120) as response:
            body = json.loads(response.read())

        snap = body.get("result", {})
        log.info(f"Qdrant snapshot created: {snap.get('name')} ({snap.get('size', 0) / 1024 / 1024:.0f} MB)")
        return {"service": "qdrant", "action": "snapshot", "success": True, "snapshot": snap}
    except Exception as e:
        log.error(f"Qdrant snapshot FAILED: {e}")
        return {"service": "qdrant", "action": "snapshot", "success": False, "error": str(e)}


def backup_postgresql():
    """Trigger PostgreSQL VACUUM ANALYZE for maintenance.

    Not a true backup: reclaims dead-tuple space and refreshes planner
    statistics for the whole database.

    Returns:
        dict describing the action, with "success" and optional "error".
    """
    try:
        import psycopg2
        cfg = SERVICES["postgresql"]
        conn = psycopg2.connect(
            host=cfg["host"], port=cfg["port"],
            user=cfg["user"], password=cfg["password"],
            database=cfg["database"], connect_timeout=10,
        )
        try:
            # VACUUM cannot run inside a transaction block.
            conn.autocommit = True
            cur = conn.cursor()
            cur.execute("VACUUM ANALYZE")
            cur.close()
        finally:
            # Always release the connection, even if VACUUM fails
            # (the original leaked it on any error after connect()).
            conn.close()
        log.info("PostgreSQL VACUUM ANALYZE completed")
        return {"service": "postgresql", "action": "vacuum_analyze", "success": True}
    except Exception as e:
        log.error(f"PostgreSQL VACUUM FAILED: {e}")
        return {"service": "postgresql", "action": "vacuum_analyze", "success": False, "error": str(e)}


def cleanup_qdrant_snapshots(keep=3):
    """Delete all but the most recent Qdrant snapshots.

    Best-effort: individual delete failures are logged and skipped so the
    remaining snapshots are still processed.

    Args:
        keep: Number of newest snapshots to retain (default 3, matching
            the original hard-coded behavior).

    Returns:
        dict with "deleted" count, or an "error" key if listing failed.
    """
    try:
        import urllib.request
        import ssl
        # TLS verification disabled, matching the other Qdrant calls here.
        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE

        cfg = SERVICES["qdrant"]
        url = cfg["url"]
        headers = {"Api-Key": cfg["api_key"]}

        req = urllib.request.Request(f"{url}/snapshots", headers=headers)
        with urllib.request.urlopen(req, context=ctx, timeout=15) as resp:
            snaps = json.loads(resp.read())["result"]

        if len(snaps) <= keep:
            return {"action": "cleanup_snapshots", "deleted": 0}

        # Newest first; everything past the first `keep` gets deleted.
        snaps.sort(key=lambda s: s.get("creation_time", ""), reverse=True)
        deleted = 0
        for snap in snaps[keep:]:
            name = snap["name"]
            try:
                req = urllib.request.Request(
                    f"{url}/snapshots/{name}", method="DELETE", headers=headers
                )
                # Context-manage the response so the socket is closed
                # (the original leaked it).
                with urllib.request.urlopen(req, context=ctx, timeout=15):
                    pass
                deleted += 1
                log.info(f"Deleted old Qdrant snapshot: {name}")
            except Exception as e:
                # Best-effort: log instead of silently swallowing.
                log.warning(f"Failed to delete Qdrant snapshot {name}: {e}")

        return {"action": "cleanup_snapshots", "deleted": deleted}
    except Exception as e:
        return {"action": "cleanup_snapshots", "error": str(e)}


# ─── AUTO-FIX ───────────────────────────────────────────────────────────

def fix_redis_zombies():
    """Kill zombie Redis connections and ensure timeout is set.

    Sets the server-side idle ``timeout`` to 300s when it is 0 (disabled)
    so the server reaps idle clients itself, and reports how many
    long-idle (>10 min) connections currently exist.

    Returns:
        dict with "zombies_found" and "timeout_set", or "error" on failure.
    """
    try:
        import redis
        cfg = SERVICES["redis"]
        r = redis.Redis(
            host=cfg["host"], port=cfg["port"],
            password=cfg["password"], decode_responses=True,
            socket_timeout=10,
        )
        try:
            # Ensure the server reaps idle connections on its own.
            current_timeout = r.config_get("timeout").get("timeout", "0")
            if current_timeout == "0":
                r.config_set("timeout", 300)
                log.info("Redis: Set timeout=300s")

            # Count and report zombies (idle > 10 minutes).
            clients = r.client_list()
            zombies = [c for c in clients if int(c.get("idle", 0)) > 600]
            if zombies:
                log.info(f"Redis: {len(zombies)} zombie connections will be cleaned by timeout")
        finally:
            # Always release our own connection, even on command failure
            # (the original leaked it on any error).
            r.close()

        return {"action": "fix_redis_zombies", "zombies_found": len(zombies), "timeout_set": True}
    except Exception as e:
        return {"action": "fix_redis_zombies", "error": str(e)}


# ─── MAIN COMMANDS ──────────────────────────────────────────────────────

def run_check():
    """Run health checks on all services.

    Logs a one-line status per service (plus its issues), persists the
    combined state to STATE_FILE and ALERT_FILE, and returns it.

    Returns:
        dict with "timestamp", "services" (per-service result dicts), and
        "overall" ("healthy" / "degraded" / "critical").
    """
    results = []
    for checker in [check_postgresql, check_qdrant, check_redis, check_n8n, check_browserless]:
        result = checker()
        results.append(result)
        icon = {"healthy": "✓", "degraded": "⚠", "unreachable": "✗"}.get(result["status"], "?")
        log.info(f"  {icon} {result['service']}: {result['status']}")
        for issue in result.get("issues", []):
            log.info(f"    → {issue}")

    # Roll-up: any unreachable service is "critical"; otherwise any
    # non-healthy service degrades the whole system.  (Bug fix: the
    # previous ternary returned "degraded" in both non-healthy branches,
    # making the unreachable distinction dead code.)
    if all(r["status"] == "healthy" for r in results):
        overall = "healthy"
    elif any(r["status"] == "unreachable" for r in results):
        overall = "critical"
    else:
        overall = "degraded"

    state = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "services": results,
        "overall": overall,
    }
    STATE_FILE.write_text(json.dumps(state, indent=2, default=str))

    # Save alert file for dashboards
    ALERT_FILE.parent.mkdir(parents=True, exist_ok=True)
    ALERT_FILE.write_text(json.dumps(state, indent=2, default=str))

    return state


def run_backup():
    """Run backups/maintenance on all services."""
    # Snapshot Qdrant, vacuum PostgreSQL, then prune old snapshots.
    return [
        backup_qdrant(),
        backup_postgresql(),
        cleanup_qdrant_snapshots(),
    ]


def run_fix():
    """Auto-fix known issues."""
    # Currently the only automated remediation is Redis zombie cleanup.
    return [fix_redis_zombies()]


def run_full():
    """Full cycle: check → fix → backup."""
    log.info("=== Elestio Resilience — Full Cycle ===")
    # Dict literal evaluates left-to-right, preserving the check → fix →
    # backup ordering.
    outcome = {
        "check": run_check(),
        "fix": run_fix(),
        "backup": run_backup(),
    }
    log.info("=== Full cycle complete ===")
    return outcome


def run_report():
    """Generate comprehensive report.

    Runs the health checks, then logs a per-service summary with
    service-specific detail lines, and returns the check state.
    """
    emit = log.info
    emit("=== GENESIS ELESTIO INFRASTRUCTURE REPORT ===")
    emit(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    emit("")

    state = run_check()

    emit("")
    emit("=== SERVICE SUMMARY ===")
    for svc in state["services"]:
        emit(f"  {svc['service'].upper()}: {svc['status'].upper()}")

        # Service-specific detail lines.
        kind = svc["service"]
        if kind == "postgresql":
            emit(f"    Connections: {svc.get('connections', '?')}/100")
            emit(f"    Cache hit: {svc.get('cache_hit_ratio', 0):.1%}")
            emit(f"    Dead tuples: {svc.get('dead_tuples', 0):,}")
        elif kind == "qdrant":
            emit(f"    Version: {svc.get('version', '?')}")
            emit(f"    Collections: {svc.get('collections', 0)}")
            emit(f"    Main collection: {svc.get('main_collection_points', 0):,} points")
            emit(f"    Snapshots: {svc.get('snapshots', 0)}")
        elif kind == "redis":
            emit(f"    Version: {svc.get('version', '?')}")
            emit(f"    Memory: {svc.get('memory_used_mb', 0)} MB")
            emit(f"    Keys: {svc.get('total_keys', 0)}")
            emit(f"    Clients: {svc.get('connected_clients', 0)}")
            emit(f"    Zombies: {svc.get('zombie_connections', 0)}")
        elif kind == "n8n":
            emit(f"    Workflows: {svc.get('active_workflows', 0)} active / {svc.get('total_workflows', 0)} total")
        elif kind == "browserless":
            emit(f"    Available: {svc.get('available', False)}")
            emit(f"    Sessions: {svc.get('running_sessions', 0)}/{svc.get('max_concurrent', 10)}")

        for issue in svc.get("issues", []):
            emit(f"    ⚠ {issue}")

    emit("")
    emit(f"Overall: {state['overall'].upper()}")
    return state


# ─── ENTRY POINT ────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Default to a health check when no subcommand is given.
    command = sys.argv[1] if len(sys.argv) > 1 else "check"

    # Dispatch table maps each subcommand to its runner.
    handlers = {
        "check": run_check,
        "backup": run_backup,
        "fix": run_fix,
        "full": run_full,
        "report": run_report,
    }
    handler = handlers.get(command)
    if handler is None:
        print(f"Usage: {sys.argv[0]} {{check|backup|fix|full|report}}")
        sys.exit(1)
    handler()
