"""
OpenWork Health Monitor
=======================
Monitors OpenWork connection health and reports to Redis/Genesis health system.

Features:
- WebSocket connection health monitoring
- Latency tracking
- Reconnection statistics
- Redis health key updates
- Integration with Genesis health aggregator

Author: Genesis System
Version: 1.0.0
"""

import asyncio
import inspect
import json
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

logger = logging.getLogger(__name__)


class HealthStatus(Enum):
    """Severity level reported by a health check; ``.value`` is the serialized string."""

    HEALTHY = "healthy"      # component reports normal operation
    DEGRADED = "degraded"    # impaired but not failed
    UNHEALTHY = "unhealthy"  # failing, disconnected, or the check errored / timed out
    UNKNOWN = "unknown"      # no data: no check registered or component not configured


class ComponentType(Enum):
    """OpenWork components the monitor knows about; ``.value`` is the serialized key."""

    BRIDGE = "bridge"                        # OpenWorkBridge connection
    VOICE_CHANNEL = "voice_channel"          # KinanAivaVoiceChannel
    TTS_SERVICE = "tts_service"              # text-to-speech service
    STT_SERVICE = "stt_service"              # speech-to-text (no setter in this module registers a check)
    APPROVAL_WORKFLOW = "approval_workflow"  # ApprovalWorkflow
    TASK_DISPATCHER = "task_dispatcher"      # TaskDispatcher


@dataclass
class HealthMetrics:
    """
    Health snapshot for a single OpenWork component.

    Only ``component`` and ``status`` are required; counters default to zero,
    optional measurements to ``None`` and ``metadata`` to a fresh empty dict.
    """
    component: ComponentType
    status: HealthStatus
    latency_ms: Optional[float] = None         # latency in milliseconds, if reported
    uptime_seconds: float = 0
    last_activity: Optional[datetime] = None   # serialized via isoformat()
    error_count: int = 0
    warning_count: int = 0
    reconnect_count: int = 0
    messages_sent: int = 0
    messages_received: int = 0
    actions_processed: int = 0
    actions_failed: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)  # component-specific extras

    def to_dict(self) -> Dict[str, Any]:
        """
        Return a plain-dict view of these metrics.

        Enums are replaced by their string values and ``last_activity`` by its
        ISO-8601 string (``None`` when unset). ``metadata`` is included as-is.
        """
        activity = self.last_activity
        return dict(
            component=self.component.value,
            status=self.status.value,
            latency_ms=self.latency_ms,
            uptime_seconds=self.uptime_seconds,
            last_activity=None if not activity else activity.isoformat(),
            error_count=self.error_count,
            warning_count=self.warning_count,
            reconnect_count=self.reconnect_count,
            messages_sent=self.messages_sent,
            messages_received=self.messages_received,
            actions_processed=self.actions_processed,
            actions_failed=self.actions_failed,
            metadata=self.metadata,
        )


@dataclass
class HealthCheckResult:
    """
    Outcome of one full health-check pass.

    ``components`` maps each component's string value to its metrics;
    ``alerts`` and ``recommendations`` are human-readable messages.
    """
    timestamp: datetime
    overall_status: HealthStatus
    components: Dict[str, HealthMetrics]
    alerts: List[str]
    recommendations: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view with every component serialized via ``to_dict``."""
        serialized_components = {}
        for name, metrics in self.components.items():
            serialized_components[name] = metrics.to_dict()

        return dict(
            timestamp=self.timestamp.isoformat(),
            overall_status=self.overall_status.value,
            components=serialized_components,
            alerts=self.alerts,
            recommendations=self.recommendations,
        )


class HealthChecker:
    """
    Performs health checks on individual components.

    Each component has at most one registered check callable. A check may be a
    plain function returning ``HealthMetrics``, an ``async def`` function, or
    any callable that returns an awaitable resolving to ``HealthMetrics``.
    Awaitable checks are bounded by ``timeout_seconds``. Synchronous checks run
    inline and are not subject to the timeout.
    """

    def __init__(self, timeout_seconds: float = 5.0):
        # Per-check timeout (seconds) applied to awaitable checks.
        self.timeout = timeout_seconds
        # Registered check callables, keyed by component.
        self._check_functions: Dict[ComponentType, Callable] = {}

    def register_check(
        self,
        component: ComponentType,
        check_func: Callable[[], HealthMetrics]
    ):
        """Register (or replace) the health check function for a component."""
        self._check_functions[component] = check_func

    async def check_component(self, component: ComponentType) -> HealthMetrics:
        """
        Run the health check for a single component.

        Check failures are reported, not raised:

        - an unregistered component yields UNKNOWN;
        - a timeout or an exception from the check yields UNHEALTHY, with the
          reason in ``metadata["error"]``.
        """
        if component not in self._check_functions:
            return HealthMetrics(
                component=component,
                status=HealthStatus.UNKNOWN,
                metadata={"error": "No health check registered"}
            )

        try:
            check_func = self._check_functions[component]
            result = check_func()

            # Await anything awaitable, not only ``async def`` functions:
            # lambdas, partials and objects with an async ``__call__`` also
            # return coroutines, which iscoroutinefunction() would not detect.
            if inspect.isawaitable(result):
                result = await asyncio.wait_for(result, timeout=self.timeout)

            return result

        except asyncio.TimeoutError:
            logger.warning(
                f"Health check timed out for {component.value} after {self.timeout}s"
            )
            return HealthMetrics(
                component=component,
                status=HealthStatus.UNHEALTHY,
                metadata={"error": "Health check timed out"}
            )
        except Exception as e:
            logger.error(f"Health check failed for {component.value}: {e}")
            return HealthMetrics(
                component=component,
                status=HealthStatus.UNHEALTHY,
                metadata={"error": str(e)}
            )

    async def check_all(self) -> Dict[ComponentType, HealthMetrics]:
        """
        Run the health checks for all registered components concurrently.

        Returns:
            Mapping of component -> metrics, in registration order.
        """
        # Snapshot the keys first: a register_check() call while checks are
        # awaiting must not mutate the dict we are iterating.
        components = list(self._check_functions)
        # check_component() never raises for check failures, so gather() yields
        # one result per component. Total latency is bounded by the slowest
        # check, not by the sum of all checks.
        results = await asyncio.gather(
            *(self.check_component(component) for component in components)
        )
        return dict(zip(components, results))


class OpenWorkHealthMonitor:
    """
    Main health monitoring system for OpenWork integration.

    Monitors all OpenWork components and reports health to Redis.

    Components are attached with the ``set_*`` methods, each of which also
    registers the matching check on ``self.checker``. ``start()`` launches a
    background loop that runs ``check_health()`` every
    ``health_check_interval`` seconds. Each result is written to Redis (when a
    client is given and ``metrics_enabled`` is true), and alert callbacks are
    invoked whenever the overall status is DEGRADED or UNHEALTHY.
    """

    # Maximum number of results kept in memory (24 hours at 1-minute intervals).
    _MAX_HISTORY = 1440

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        redis_client: Optional[Any] = None
    ):
        """
        Args:
            config: Optional settings. Recognised keys are:
                ``health_check_interval`` (seconds, default 60),
                ``redis_health_key`` (default ``"genesis:health:openwork"``) and
                ``metrics_enabled`` (default True).
            redis_client: Optional sync or asyncio Redis client exposing ``set``.
        """
        self.config = config or {}
        self.redis = redis_client

        # Configuration. Read from self.config (never None): reading the raw
        # ``config`` argument crashed with AttributeError when it was omitted.
        self._check_interval = self.config.get("health_check_interval", 60)
        self._redis_key = self.config.get("redis_health_key", "genesis:health:openwork")
        self._metrics_enabled = self.config.get("metrics_enabled", True)

        # Health checker
        self.checker = HealthChecker()

        # Component references (set externally)
        self._bridge = None
        self._voice_channel = None
        self._tts_service = None
        self._approval_workflow = None
        self._task_dispatcher = None

        # Metrics tracking
        self._start_time = datetime.utcnow()
        self._last_check: Optional[HealthCheckResult] = None
        self._health_history: List[HealthCheckResult] = []

        # Background task
        self._monitor_task: Optional[asyncio.Task] = None
        self._is_running = False

        # Alert callbacks
        self._alert_callbacks: List[Callable] = []

        logger.info("OpenWorkHealthMonitor initialized")

    def set_bridge(self, bridge):
        """Set reference to OpenWorkBridge."""
        self._bridge = bridge
        self.checker.register_check(ComponentType.BRIDGE, self._check_bridge)

    def set_voice_channel(self, channel):
        """Set reference to KinanAivaVoiceChannel."""
        self._voice_channel = channel
        self.checker.register_check(ComponentType.VOICE_CHANNEL, self._check_voice_channel)

    def set_tts_service(self, service):
        """Set reference to TTS service."""
        self._tts_service = service
        self.checker.register_check(ComponentType.TTS_SERVICE, self._check_tts_service)

    def set_approval_workflow(self, workflow):
        """Set reference to ApprovalWorkflow."""
        self._approval_workflow = workflow
        self.checker.register_check(ComponentType.APPROVAL_WORKFLOW, self._check_approval_workflow)

    def set_task_dispatcher(self, dispatcher):
        """Set reference to TaskDispatcher."""
        self._task_dispatcher = dispatcher
        self.checker.register_check(ComponentType.TASK_DISPATCHER, self._check_task_dispatcher)

    def add_alert_callback(self, callback: Callable):
        """
        Add a callback for health alerts.

        The callback receives the ``HealthCheckResult``. It may be sync or
        return an awaitable; an awaitable return value is awaited.
        """
        self._alert_callbacks.append(callback)

    async def start(self):
        """Start the background monitoring loop (no-op if already running)."""
        if self._is_running:
            return

        self._is_running = True
        self._start_time = datetime.utcnow()
        self._monitor_task = asyncio.create_task(self._monitoring_loop())

        logger.info("OpenWorkHealthMonitor started")

    async def stop(self):
        """Stop the monitor, cancelling and awaiting the background task."""
        self._is_running = False

        if self._monitor_task:
            self._monitor_task.cancel()
            try:
                await self._monitor_task
            except asyncio.CancelledError:
                pass
            # Drop the finished task so a later start() begins from a clean state.
            self._monitor_task = None

        logger.info("OpenWorkHealthMonitor stopped")

    async def _monitoring_loop(self):
        """Run check -> report -> sleep until stopped; errors are logged and retried."""
        while self._is_running:
            try:
                result = await self.check_health()
                await self._process_health_result(result)
                await asyncio.sleep(self._check_interval)

            except asyncio.CancelledError:
                break
            except Exception as e:
                # logger.exception keeps the traceback for diagnosis.
                logger.exception(f"Monitoring loop error: {e}")
                await asyncio.sleep(self._check_interval)

    async def check_health(self) -> HealthCheckResult:
        """
        Perform a complete health check.

        The result is also stored as the latest check and appended to the
        bounded in-memory history.

        Returns:
            HealthCheckResult with all component statuses
        """
        # Run all component checks
        component_results = await self.checker.check_all()

        # Determine overall status
        overall_status = self._calculate_overall_status(component_results)

        # Generate alerts and recommendations
        alerts, recommendations = self._generate_alerts_and_recommendations(component_results)

        result = HealthCheckResult(
            timestamp=datetime.utcnow(),
            overall_status=overall_status,
            components={c.value: m for c, m in component_results.items()},
            alerts=alerts,
            recommendations=recommendations
        )

        self._last_check = result
        self._health_history.append(result)

        # Keep history bounded (drop the oldest entries in place).
        if len(self._health_history) > self._MAX_HISTORY:
            del self._health_history[:-self._MAX_HISTORY]

        return result

    def _calculate_overall_status(
        self,
        components: Dict[ComponentType, HealthMetrics]
    ) -> HealthStatus:
        """
        Combine component statuses into one overall status.

        Rules:

        - no components reported -> UNKNOWN;
        - any component UNHEALTHY -> UNHEALTHY;
        - otherwise any component DEGRADED or UNKNOWN -> DEGRADED;
        - otherwise HEALTHY.
        """
        if not components:
            return HealthStatus.UNKNOWN

        statuses = [m.status for m in components.values()]

        if any(s == HealthStatus.UNHEALTHY for s in statuses):
            return HealthStatus.UNHEALTHY
        elif any(s == HealthStatus.DEGRADED for s in statuses):
            return HealthStatus.DEGRADED
        elif any(s == HealthStatus.UNKNOWN for s in statuses):
            return HealthStatus.DEGRADED
        else:
            return HealthStatus.HEALTHY

    def _generate_alerts_and_recommendations(
        self,
        components: Dict[ComponentType, HealthMetrics]
    ) -> tuple[List[str], List[str]]:
        """Generate alert and recommendation messages from per-component metrics."""
        alerts = []
        recommendations = []

        for component, metrics in components.items():
            # Status-based alerts
            if metrics.status == HealthStatus.UNHEALTHY:
                alerts.append(f"{component.value} is UNHEALTHY")
            elif metrics.status == HealthStatus.DEGRADED:
                alerts.append(f"{component.value} is DEGRADED")

            # Error count alerts
            if metrics.error_count > 10:
                alerts.append(f"{component.value} has {metrics.error_count} errors")
                recommendations.append(f"Review {component.value} error logs")

            # Reconnection alerts
            if metrics.reconnect_count > 5:
                alerts.append(f"{component.value} has reconnected {metrics.reconnect_count} times")
                recommendations.append(f"Check {component.value} connection stability")

            # Latency alerts
            if metrics.latency_ms and metrics.latency_ms > 1000:
                alerts.append(f"{component.value} latency is {metrics.latency_ms:.0f}ms")
                recommendations.append(f"Investigate {component.value} performance")

            # Inactivity alerts (last_activity is compared against naive UTC)
            if metrics.last_activity:
                inactive_seconds = (datetime.utcnow() - metrics.last_activity).total_seconds()
                if inactive_seconds > 300:  # 5 minutes
                    alerts.append(f"{component.value} inactive for {inactive_seconds:.0f}s")

            # Action failure alerts
            if metrics.actions_processed > 0:
                failure_rate = metrics.actions_failed / metrics.actions_processed
                if failure_rate > 0.1:  # >10% failure rate
                    alerts.append(f"{component.value} action failure rate: {failure_rate:.1%}")
                    recommendations.append(f"Review {component.value} action errors")

        return alerts, recommendations

    async def _process_health_result(self, result: HealthCheckResult):
        """Process a health check result: update Redis, then trigger alerts."""

        # Update Redis
        if self.redis and self._metrics_enabled:
            try:
                await self._update_redis(result)
            except Exception as e:
                logger.error(f"Failed to update Redis health: {e}")

        # Trigger alert callbacks if unhealthy
        if result.overall_status in [HealthStatus.UNHEALTHY, HealthStatus.DEGRADED]:
            for callback in self._alert_callbacks:
                try:
                    # Await whatever is awaitable so that async callables not
                    # detected by iscoroutinefunction() still run.
                    outcome = callback(result)
                    if inspect.isawaitable(outcome):
                        await outcome
                except Exception as e:
                    logger.error(f"Alert callback error: {e}")

    async def _update_redis(self, result: HealthCheckResult):
        """Write the health result as JSON to Redis with a TTL of 3x the check interval."""
        if not self.redis or not hasattr(self.redis, "set"):
            return

        try:
            # Main health key
            health_data = result.to_dict()
            health_data["monitor_uptime"] = self.get_uptime()

            # redis-py only accepts an int ``ex``, and Redis rejects 0, so
            # coerce the TTL and keep it at least one second.
            ttl = max(1, int(self._check_interval * 3))

            # redis.asyncio clients define ``set`` as a plain method that returns
            # an awaitable, so iscoroutinefunction() is False for them. Call it
            # and await the reply when needed; this handles sync clients too.
            reply = self.redis.set(
                self._redis_key,
                json.dumps(health_data),
                ex=ttl
            )
            if inspect.isawaitable(reply):
                await reply

            logger.debug(f"Updated Redis health key: {self._redis_key}")

        except Exception as e:
            logger.error(f"Redis update error: {e}")

    # Component-specific health checks

    async def _check_bridge(self) -> HealthMetrics:
        """Check OpenWorkBridge health (UNHEALTHY when not ``connected``)."""
        if not self._bridge:
            return HealthMetrics(
                component=ComponentType.BRIDGE,
                status=HealthStatus.UNKNOWN,
                metadata={"error": "Bridge not configured"}
            )

        try:
            # Get bridge status
            is_connected = getattr(self._bridge, 'connected', False)
            stats = getattr(self._bridge, 'get_statistics', lambda: {})()

            status = HealthStatus.HEALTHY if is_connected else HealthStatus.UNHEALTHY

            return HealthMetrics(
                component=ComponentType.BRIDGE,
                status=status,
                latency_ms=stats.get("avg_latency_ms"),
                uptime_seconds=stats.get("uptime_seconds", 0),
                last_activity=stats.get("last_activity"),
                error_count=stats.get("error_count", 0),
                reconnect_count=stats.get("reconnect_count", 0),
                messages_sent=stats.get("messages_sent", 0),
                messages_received=stats.get("messages_received", 0),
                actions_processed=stats.get("actions_processed", 0),
                actions_failed=stats.get("actions_failed", 0),
                metadata={
                    "connected": is_connected,
                    "mode": getattr(self._bridge, 'mode', 'unknown'),
                    "pending_actions": stats.get("pending_actions", 0)
                }
            )

        except Exception as e:
            return HealthMetrics(
                component=ComponentType.BRIDGE,
                status=HealthStatus.UNHEALTHY,
                metadata={"error": str(e)}
            )

    async def _check_voice_channel(self) -> HealthMetrics:
        """Check KinanAivaVoiceChannel health (DEGRADED when not ``is_active``)."""
        if not self._voice_channel:
            return HealthMetrics(
                component=ComponentType.VOICE_CHANNEL,
                status=HealthStatus.UNKNOWN,
                metadata={"error": "Voice channel not configured"}
            )

        try:
            is_active = getattr(self._voice_channel, 'is_active', False)
            stats = getattr(self._voice_channel, 'get_statistics', lambda: {})()

            status = HealthStatus.HEALTHY if is_active else HealthStatus.DEGRADED

            return HealthMetrics(
                component=ComponentType.VOICE_CHANNEL,
                status=status,
                uptime_seconds=stats.get("uptime_seconds", 0),
                last_activity=stats.get("last_activity"),
                messages_sent=stats.get("messages_sent", 0),
                messages_received=stats.get("messages_received", 0),
                error_count=stats.get("error_count", 0),
                metadata={
                    "stt_model": stats.get("stt_model"),
                    "tts_backend": stats.get("tts_backend"),
                    "connected_clients": stats.get("connected_clients", 0)
                }
            )

        except Exception as e:
            return HealthMetrics(
                component=ComponentType.VOICE_CHANNEL,
                status=HealthStatus.UNHEALTHY,
                metadata={"error": str(e)}
            )

    async def _check_tts_service(self) -> HealthMetrics:
        """Check TTS service health (DEGRADED when the primary backend is unavailable)."""
        if not self._tts_service:
            return HealthMetrics(
                component=ComponentType.TTS_SERVICE,
                status=HealthStatus.UNKNOWN,
                metadata={"error": "TTS service not configured"}
            )

        try:
            stats = getattr(self._tts_service, 'get_statistics', lambda: {})()

            # Check if primary backend is available
            primary_available = stats.get("primary_backend_available", True)
            status = HealthStatus.HEALTHY if primary_available else HealthStatus.DEGRADED

            return HealthMetrics(
                component=ComponentType.TTS_SERVICE,
                status=status,
                latency_ms=stats.get("avg_latency_ms"),
                actions_processed=stats.get("requests_total", 0),
                actions_failed=stats.get("requests_failed", 0),
                metadata={
                    "active_backend": stats.get("active_backend"),
                    "daily_cost": stats.get("daily_cost", 0),
                    "budget_remaining": stats.get("budget_remaining")
                }
            )

        except Exception as e:
            return HealthMetrics(
                component=ComponentType.TTS_SERVICE,
                status=HealthStatus.UNHEALTHY,
                metadata={"error": str(e)}
            )

    async def _check_approval_workflow(self) -> HealthMetrics:
        """Check ApprovalWorkflow health (DEGRADED above 20 pending requests)."""
        if not self._approval_workflow:
            return HealthMetrics(
                component=ComponentType.APPROVAL_WORKFLOW,
                status=HealthStatus.UNKNOWN,
                metadata={"error": "Approval workflow not configured"}
            )

        try:
            stats = getattr(self._approval_workflow, 'get_statistics', lambda: {})()
            pending = len(getattr(self._approval_workflow, 'get_pending_requests', lambda: [])())

            # Degraded if too many pending requests
            status = HealthStatus.HEALTHY
            if pending > 20:
                status = HealthStatus.DEGRADED

            return HealthMetrics(
                component=ComponentType.APPROVAL_WORKFLOW,
                status=status,
                actions_processed=stats.get("total_actions", 0),
                metadata={
                    "pending_requests": pending,
                    "auto_approve_enabled": stats.get("auto_approve_enabled"),
                    "by_decision": stats.get("by_decision", {})
                }
            )

        except Exception as e:
            return HealthMetrics(
                component=ComponentType.APPROVAL_WORKFLOW,
                status=HealthStatus.UNHEALTHY,
                metadata={"error": str(e)}
            )

    async def _check_task_dispatcher(self) -> HealthMetrics:
        """
        Check TaskDispatcher health from its completed/dispatched ratio.

        Thresholds: >= 90% HEALTHY, >= 70% DEGRADED, otherwise UNHEALTHY.
        HEALTHY when nothing has been dispatched yet.
        """
        if not self._task_dispatcher:
            return HealthMetrics(
                component=ComponentType.TASK_DISPATCHER,
                status=HealthStatus.UNKNOWN,
                metadata={"error": "Task dispatcher not configured"}
            )

        try:
            stats = getattr(self._task_dispatcher, 'get_statistics', lambda: {})()

            dispatched = stats.get("tasks_dispatched", 0)
            completed = stats.get("tasks_completed", 0)
            failed = stats.get("tasks_failed", 0)

            # Calculate health based on success rate
            if dispatched > 0:
                success_rate = completed / dispatched
                if success_rate >= 0.9:
                    status = HealthStatus.HEALTHY
                elif success_rate >= 0.7:
                    status = HealthStatus.DEGRADED
                else:
                    status = HealthStatus.UNHEALTHY
            else:
                status = HealthStatus.HEALTHY

            return HealthMetrics(
                component=ComponentType.TASK_DISPATCHER,
                status=status,
                actions_processed=dispatched,
                actions_failed=failed,
                metadata={
                    "tasks_pending": stats.get("tasks_pending", 0),
                    "tasks_completed": completed,
                    "success_rate": f"{(completed/dispatched*100):.1f}%" if dispatched else "N/A"
                }
            )

        except Exception as e:
            return HealthMetrics(
                component=ComponentType.TASK_DISPATCHER,
                status=HealthStatus.UNHEALTHY,
                metadata={"error": str(e)}
            )

    # Public API

    def get_current_health(self) -> Optional[HealthCheckResult]:
        """Get the most recent health check result (None before the first check)."""
        return self._last_check

    def get_health_history(
        self,
        hours: int = 1,
        component: Optional[ComponentType] = None
    ) -> List[HealthCheckResult]:
        """
        Get health check history.

        Args:
            hours: Number of hours of history
            component: Optional filter. When given, only results in which
                this component was checked are returned.

        Returns:
            List of health check results, oldest first
        """
        cutoff = datetime.utcnow() - timedelta(hours=hours)

        history = [
            r for r in self._health_history
            if r.timestamp > cutoff
        ]

        # The filter was documented but previously ignored.
        if component is not None:
            history = [r for r in history if component.value in r.components]

        return history

    def get_uptime(self) -> float:
        """Get monitor uptime in seconds (since construction or the last start())."""
        return (datetime.utcnow() - self._start_time).total_seconds()

    def get_summary(self) -> Dict[str, Any]:
        """Get a compact health summary for dashboards/APIs."""
        result = self._last_check

        if not result:
            return {
                "status": "unknown",
                "message": "No health check performed yet",
                "uptime": self.get_uptime()
            }

        return {
            "status": result.overall_status.value,
            "timestamp": result.timestamp.isoformat(),
            "components": {
                name: {
                    "status": metrics.status.value,
                    "latency_ms": metrics.latency_ms
                }
                for name, metrics in result.components.items()
            },
            "alerts": result.alerts,
            "recommendations": result.recommendations,
            "uptime": self.get_uptime()
        }


# Factory function
def create_health_monitor(
    config_path: Optional[str] = None,
    redis_client: Optional[Any] = None
) -> OpenWorkHealthMonitor:
    """
    Create a health monitor from configuration.

    Only the top-level ``"monitoring"`` section of the JSON file is used. The
    monitor's built-in defaults apply when:

    - no path is given;
    - the file does not exist (a warning is logged);
    - the ``"monitoring"`` section is missing or ``null``.

    Args:
        config_path: Path to openwork_config.json
        redis_client: Optional Redis client

    Returns:
        Configured OpenWorkHealthMonitor

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
        OSError: If the file exists but cannot be read.
    """
    config: Dict[str, Any] = {}

    if config_path:
        config_file = Path(config_path)
        if config_file.exists():
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            with config_file.open(encoding="utf-8") as f:
                full_config = json.load(f)
            # ``or {}`` also covers an explicit ``"monitoring": null``.
            config = full_config.get("monitoring") or {}
        else:
            logger.warning(
                "Health monitor config not found at %s; using defaults", config_file
            )

    return OpenWorkHealthMonitor(config=config, redis_client=redis_client)


# Example usage
async def example_usage():
    """
    Demonstrate the health monitor.

    No components are registered, so each check reports an overall UNKNOWN
    status and the alert callback (DEGRADED/UNHEALTHY only) does not fire.
    The monitor is always stopped on exit, even if a step raises.
    """

    # Create monitor
    monitor = OpenWorkHealthMonitor(
        config={
            "health_check_interval": 10,
            "redis_health_key": "genesis:health:openwork",
            "metrics_enabled": True
        }
    )

    # Add alert callback
    def on_alert(result: HealthCheckResult):
        print(f"ALERT: {result.overall_status.value}")
        for alert in result.alerts:
            print(f"  - {alert}")

    monitor.add_alert_callback(on_alert)

    # Start monitor
    await monitor.start()
    try:
        # Run a health check
        result = await monitor.check_health()
        print("\nHealth Check Result:")
        print(f"  Status: {result.overall_status.value}")
        print(f"  Components: {len(result.components)}")
        print(f"  Alerts: {len(result.alerts)}")

        # Get summary
        summary = monitor.get_summary()
        print("\nSummary:")
        print(json.dumps(summary, indent=2, default=str))

        # Wait a bit to see monitoring loop
        await asyncio.sleep(15)
    finally:
        # Cancel the background task even when a step above fails.
        await monitor.stop()


if __name__ == "__main__":
    # Script entry point: show INFO-level logs, then run the demo to completion.
    logging.basicConfig(level=logging.INFO)
    demo = example_usage()
    asyncio.run(demo)
