#!/usr/bin/env python3
"""
Memory Monitor for Multi-Agent Operations
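Helps prevent out-of-memory (OOM) crashes during intensive agent sprints by
tracking RAM/swap usage and checking whether it is safe to spawn more agents.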
"""

import psutil
import time
import json
from datetime import datetime
from pathlib import Path

class MemoryMonitor:
    """Track system RAM/swap usage and advise whether more agents can be spawned.

    Status levels produced by ``_determine_status`` (checked in this order):
      * ``CRITICAL``     -- RAM percent >= ``critical_threshold``
      * ``WARNING_SWAP`` -- swap percent > ``swap_threshold``
      * ``WARNING``      -- RAM percent >= ``warning_threshold``
      * ``OK``           -- none of the above

    Args:
        warning_threshold: RAM usage percent at which status becomes WARNING.
        critical_threshold: RAM usage percent at which status becomes
            CRITICAL; spawning is never reported as safe at this level.
        log_file: Path of the JSON-lines file that ``log_status`` appends to.
        swap_threshold: Swap usage percent above which status becomes
            WARNING_SWAP.
    """

    def __init__(self,
                 warning_threshold=70,  # Warn at 70% RAM
                 critical_threshold=85,  # Stop at 85% RAM
                 log_file="memory_monitor.log",
                 swap_threshold=50):  # Flag heavy swap above 50%
        self.warning_threshold = warning_threshold
        self.critical_threshold = critical_threshold
        self.swap_threshold = swap_threshold
        self.log_file = Path(log_file)
        # NOTE(review): not read or written by any method in this class;
        # kept as a public attribute for backward compatibility.
        self.alerts = []

    def get_memory_status(self):
        """Take a snapshot of current RAM and swap usage.

        Returns:
            dict with ``timestamp`` (local ISO-8601), RAM totals/used/available
            in GiB (rounded to 2 decimals), ``ram_percent``, swap used in GiB,
            ``swap_percent`` and the derived ``status`` level string.
        """
        gib = 1024 ** 3
        mem = psutil.virtual_memory()
        swap = psutil.swap_memory()

        return {
            'timestamp': datetime.now().isoformat(),
            'ram_total_gb': round(mem.total / gib, 2),
            'ram_used_gb': round(mem.used / gib, 2),
            'ram_available_gb': round(mem.available / gib, 2),
            'ram_percent': mem.percent,
            'swap_used_gb': round(swap.used / gib, 2),
            'swap_percent': swap.percent,
            'status': self._determine_status(mem.percent, swap.percent)
        }

    def _determine_status(self, ram_percent, swap_percent):
        """Map RAM/swap usage percentages to a status level string."""
        if ram_percent >= self.critical_threshold:
            return 'CRITICAL'
        elif swap_percent > self.swap_threshold:
            return 'WARNING_SWAP'
        elif ram_percent >= self.warning_threshold:
            return 'WARNING'
        else:
            return 'OK'

    def check_safe_to_spawn(self, agents_to_spawn=1, mb_per_agent=500,
                            status=None):
        """Check whether it is safe to spawn more agents.

        Spawning is safe only when available RAM exceeds the estimated
        requirement plus a 50% safety margin AND the system is not already
        CRITICAL.

        Args:
            agents_to_spawn: Number of agents about to be started.
            mb_per_agent: Estimated memory per agent, in MiB.
            status: Optional snapshot from ``get_memory_status``; a fresh
                snapshot is taken when omitted.

        Returns:
            dict with keys ``safe`` (bool), ``status`` (the snapshot),
            ``required_gb``, ``available_gb`` and ``recommendation``.
        """
        if status is None:
            status = self.get_memory_status()
        required_gb = (agents_to_spawn * mb_per_agent) / 1024

        has_headroom = status['ram_available_gb'] > required_gb * 1.5  # 50% safety margin
        # CRITICAL means "stop": never report it as safe, even when the
        # absolute amount of free RAM would cover this request.
        safe = has_headroom and status['status'] != 'CRITICAL'

        return {
            'safe': safe,
            'status': status,
            'required_gb': round(required_gb, 2),
            'available_gb': status['ram_available_gb'],
            'recommendation': self._get_recommendation(safe, status)
        }

    def _get_recommendation(self, safe, status):
        """Return a human-readable recommendation consistent with ``safe``."""
        level = status['status']
        if level == 'CRITICAL':
            return "STOP: Memory critical. Kill some processes or wait."
        # An unsafe verdict must not be softened into "proceed" advice.
        if not safe:
            return "WAIT: Insufficient memory. Wait for GC or reduce agents."
        if level == 'WARNING_SWAP':
            return "CAUTION: Heavy swap usage. Consider pausing."
        if level == 'WARNING':
            return "WARNING: Approaching memory limits. Proceed carefully."
        return "OK: Safe to spawn agents."

    def log_status(self, status=None):
        """Append a status snapshot as one JSON line to ``log_file`` and print it.

        Args:
            status: Snapshot from ``get_memory_status``; taken fresh if omitted.
        """
        if status is None:
            status = self.get_memory_status()

        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(status) + '\n')

        # Print to console with color coding
        self._print_status(status)

    def _print_status(self, status):
        """Print a snapshot to stdout, colouring the level with ANSI codes."""
        color = {
            'OK': '\033[92m',       # Green
            'WARNING': '\033[93m',   # Yellow
            'WARNING_SWAP': '\033[93m',
            'CRITICAL': '\033[91m'   # Red
        }.get(status['status'], '')
        reset = '\033[0m'

        print(f"{color}[{status['timestamp']}] {status['status']}{reset}")
        print(f"  RAM: {status['ram_used_gb']}/{status['ram_total_gb']} GB "
              f"({status['ram_percent']}%)")
        print(f"  Available: {status['ram_available_gb']} GB")
        if status['swap_percent'] > 0:
            print(f"  Swap: {status['swap_used_gb']} GB ({status['swap_percent']}%)")
        print()

    def monitor_loop(self, interval=15):
        """Log and print a snapshot every ``interval`` seconds until Ctrl+C."""
        print(f"Starting memory monitor (checking every {interval}s)")
        print(f"Warning threshold: {self.warning_threshold}%")
        print(f"Critical threshold: {self.critical_threshold}%")
        print(f"Swap threshold: {self.swap_threshold}%")
        print("-" * 50)

        try:
            while True:
                self.log_status()
                time.sleep(interval)
        except KeyboardInterrupt:
            print("\nMonitoring stopped.")


def quick_check(agents_to_spawn=5, mb_per_agent=500):
    """Print a short memory report and return True if it is safe to proceed.

    Args:
        agents_to_spawn: Number of agents the caller intends to start
            (default 5).
        mb_per_agent: Estimated memory per agent, in MiB (default 500).

    Returns:
        bool: the ``safe`` flag from ``MemoryMonitor.check_safe_to_spawn``.
    """
    monitor = MemoryMonitor()
    result = monitor.check_safe_to_spawn(agents_to_spawn=agents_to_spawn,
                                         mb_per_agent=mb_per_agent)

    print(f"Status: {result['status']['status']}")
    print(f"Available: {result['available_gb']} GB")
    print(f"Recommendation: {result['recommendation']}")
    print()

    return result['safe']


if __name__ == '__main__':
    import sys

    # Usage:
    #   <script> monitor     -- log/print a snapshot every 15s until Ctrl+C
    #   <script> check [N]   -- JSON report for N agents (default 5);
    #                           exit 0 if safe, 1 if unsafe, 2 on bad input
    #   <script>             -- one-shot human-readable check (any other arg)
    command = sys.argv[1] if len(sys.argv) > 1 else None

    if command == 'monitor':
        # Continuous monitoring mode
        MemoryMonitor().monitor_loop()
    elif command == 'check':
        # Quick check mode
        raw_agents = sys.argv[2] if len(sys.argv) > 2 else '5'
        try:
            agents = int(raw_agents)
        except ValueError:
            agents = -1
        if agents < 0:
            # Exit 2 so a usage error is distinguishable from "unsafe" (1).
            print(f"error: agent count must be a non-negative integer, "
                  f"got {raw_agents!r}", file=sys.stderr)
            sys.exit(2)
        monitor = MemoryMonitor()
        result = monitor.check_safe_to_spawn(agents_to_spawn=agents)
        print(json.dumps(result, indent=2))
        sys.exit(0 if result['safe'] else 1)
    else:
        # Single status check
        quick_check()
