"""
OpenWork WebSocket Bridge
=========================
Core bridge for Genesis-OpenWork bidirectional communication.

Features:
- WebSocket client to connect to OpenWork desktop app
- WebSocket server for OpenWork to connect to Genesis
- Action routing and response handling
- Heartbeat and reconnection logic
- Event emission for RWL integration

Usage:
    from core.integrations.openwork_bridge import OpenWorkBridge

    bridge = OpenWorkBridge()
    await bridge.start()
    await bridge.send_action("file_operation", {"operation": "create", "path": "/tmp/test.txt"})

Author: Genesis System
Version: 1.0.0
"""

import os
import sys
import json
import asyncio
import logging
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Any, List, Callable, Awaitable, Union
from dataclasses import dataclass, field
from enum import Enum, auto
from collections import defaultdict

# Add genesis path
# GENESIS_ROOT is three directory levels above this file (presumably the
# repository root that contains ``core/`` -- TODO confirm). Prepending it to
# sys.path is a module-level side effect that affects every later import in the
# process; nothing in this module itself imports from it.
GENESIS_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(GENESIS_ROOT))

# Conditional imports
try:
    import websockets
    from websockets.client import connect as ws_connect
    from websockets.server import serve as ws_serve
    WEBSOCKETS_AVAILABLE = True
except ImportError:
    WEBSOCKETS_AVAILABLE = False

try:
    import aiohttp
    AIOHTTP_AVAILABLE = True
except ImportError:
    AIOHTTP_AVAILABLE = False

# Configure logging
# NOTE(review): basicConfig() runs at import time and configures the *root*
# logger for the whole process (it does nothing if the root logger already has
# handlers). Consider moving it under ``if __name__ == "__main__"`` so that
# importing this module as a library does not alter global logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ActionType(Enum):
    """Types of actions that can be sent to OpenWork.

    A member's value is the string sent as ``action_type`` in the
    ``action_request`` wire message (see ``OpenWorkAction.to_wire_format``).
    ``OpenWorkBridge.send_action`` also accepts these raw strings and converts
    them with ``ActionType(value)``, so an unknown string raises ``ValueError``.
    """
    FILE_OPERATION = "file_operation"  # helper: OpenWorkBridge.send_file_operation
    BROWSER_TASK = "browser_task"  # helper: OpenWorkBridge.send_browser_task
    DOCUMENT_AUTOMATION = "document_automation"
    SYSTEM_COMMAND = "system_command"
    CLIPBOARD = "clipboard"
    NOTIFICATION = "notification"  # helper: OpenWorkBridge.send_notification
    VOICE_COMMAND = "voice_command"
    APPROVAL_REQUEST = "approval_request"  # helper: OpenWorkBridge.request_approval


class ActionStatus(Enum):
    """Status of an action.

    The value strings are what ``OpenWorkAction.to_dict`` reports as
    ``status``.
    """
    PENDING = "pending"  # default for a newly created action
    SENT = "sent"  # set when the bridge transmits the action
    ACKNOWLEDGED = "acknowledged"  # NOTE(review): not assigned anywhere in this module
    IN_PROGRESS = "in_progress"  # set on an "action_progress" message
    COMPLETED = "completed"  # "action_response" with success=true
    FAILED = "failed"  # "action_response" with success=false, or a send error
    CANCELLED = "cancelled"  # NOTE(review): not assigned anywhere in this module
    APPROVAL_REQUIRED = "approval_required"  # NOTE(review): not assigned anywhere in this module
    APPROVED = "approved"  # "approval_response" with approved=true
    REJECTED = "rejected"  # "approval_response" with approved=false


@dataclass
class OpenWorkAction:
    """A unit of work Genesis asks OpenWork to perform, plus its lifecycle state.

    Field order is significant: ``@dataclass`` derives the positional
    ``__init__`` signature from it.
    """
    action_id: str
    action_type: ActionType
    payload: Dict[str, Any]
    status: ActionStatus = ActionStatus.PENDING
    requires_approval: bool = False
    priority: int = 5  # 1-10, 1 is highest
    created_at: datetime = field(default_factory=datetime.utcnow)
    sent_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    retries: int = 0
    max_retries: int = 3

    @staticmethod
    def _iso_or_none(moment: Optional[datetime]) -> Optional[str]:
        """Render an optional timestamp as ISO-8601, passing ``None`` through."""
        return moment.isoformat() if moment else None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly snapshot of the action's full state.

        Enum fields are reduced to their string values and timestamps to
        ISO-8601 strings. ``max_retries`` is not included.
        """
        return dict(
            action_id=self.action_id,
            action_type=self.action_type.value,
            payload=self.payload,
            status=self.status.value,
            requires_approval=self.requires_approval,
            priority=self.priority,
            created_at=self.created_at.isoformat(),
            sent_at=self._iso_or_none(self.sent_at),
            completed_at=self._iso_or_none(self.completed_at),
            result=self.result,
            error=self.error,
            retries=self.retries,
        )

    def to_wire_format(self) -> Dict[str, Any]:
        """Build the ``action_request`` message sent over the WebSocket.

        ``timestamp`` is the (naive, UTC) moment of serialization, not
        ``created_at``.
        """
        envelope: Dict[str, Any] = {"type": "action_request"}
        envelope.update(
            action_id=self.action_id,
            action_type=self.action_type.value,
            payload=self.payload,
            requires_approval=self.requires_approval,
            priority=self.priority,
        )
        envelope["timestamp"] = datetime.utcnow().isoformat()
        return envelope


@dataclass
class OpenWorkEvent:
    """A single event received from (or synthesized on behalf of) OpenWork.

    ``timestamp`` defaults to the naive UTC time at construction.
    """
    event_id: str
    event_type: str
    data: Dict[str, Any]
    timestamp: datetime = field(default_factory=datetime.utcnow)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict with the timestamp as an ISO-8601 string."""
        serialized: Dict[str, Any] = {
            name: getattr(self, name) for name in ("event_id", "event_type", "data")
        }
        serialized["timestamp"] = self.timestamp.isoformat()
        return serialized


class ConnectionState(Enum):
    """Connection state to OpenWork.

    Values come from ``auto()`` and carry no meaning; ``get_metrics`` reports
    the member ``name``.
    """
    DISCONNECTED = auto()  # initial state; also after stop() or a closed client connection
    CONNECTING = auto()  # set by start() before the server/client is up
    CONNECTED = auto()  # server listening, or client socket open
    RECONNECTING = auto()  # client mode: waiting to retry the connection
    ERROR = auto()  # a connection attempt failed, or reconnect attempts were exhausted


class OpenWorkBridge:
    """
    Main bridge for Genesis-OpenWork communication.

    Can operate in two modes:
    1. Client mode: Genesis connects to OpenWork server
    2. Server mode: Genesis hosts server, OpenWork connects

    Outgoing actions are placed on a queue and transmitted by a background
    task once a peer can receive them (in server mode: once at least one
    OpenWork client is connected). Incoming JSON messages are dispatched on
    their ``"type"`` field; events are delivered to handlers registered with
    :meth:`on_event`.

    NOTE(review): the server accepts any WebSocket client without
    authentication, and with the default ``server_host="0.0.0.0"`` it listens
    on every network interface. Any peer that can reach the port can, for
    example, send ``approval_response`` messages for pending actions --
    consider binding to ``127.0.0.1`` or adding authentication.
    """

    def __init__(
        self,
        mode: str = "server",  # "server" or "client"
        server_host: str = "0.0.0.0",
        server_port: int = 8767,
        client_url: Optional[str] = None,
        heartbeat_interval: int = 30,
        reconnect_delay: int = 5,
        max_reconnect_attempts: int = 10
    ):
        """
        Create a bridge. No network activity happens until :meth:`start`.

        Args:
            mode: ``"server"`` to host a WebSocket server; any other value
                runs in client mode and connects to ``client_url``.
            server_host: Interface the server binds to (server mode).
            server_port: Port the server listens on; also used to build the
                default ``client_url``.
            client_url: URL to connect to in client mode. Defaults to
                ``ws://localhost:<server_port>``.
            heartbeat_interval: Seconds between ``ping`` messages.
            reconnect_delay: Base delay in seconds for client reconnects; it
                doubles per attempt, capped at 32x the base.
            max_reconnect_attempts: Consecutive failed reconnects before the
                bridge gives up and stays in ``ConnectionState.ERROR``.

        Raises:
            ImportError: If the ``websockets`` package is not installed.
        """
        if not WEBSOCKETS_AVAILABLE:
            raise ImportError("websockets library required. Install with: pip install websockets")

        self.mode = mode
        self.server_host = server_host
        self.server_port = server_port
        self.client_url = client_url or f"ws://localhost:{server_port}"
        self.heartbeat_interval = heartbeat_interval
        self.reconnect_delay = reconnect_delay
        self.max_reconnect_attempts = max_reconnect_attempts

        # Connection state
        self.state = ConnectionState.DISCONNECTED
        self._server = None
        self._client_ws = None
        self._connected_clients: List[Any] = []

        # Action management
        self._pending_actions: Dict[str, OpenWorkAction] = {}
        # NOTE(review): completed actions are kept for the bridge's lifetime;
        # a long-running process may want to cap this history.
        self._completed_actions: Dict[str, OpenWorkAction] = {}
        self._action_queue: asyncio.Queue = asyncio.Queue()

        # Event handling
        self._event_handlers: Dict[str, List[Callable[[OpenWorkEvent], Awaitable[None]]]] = defaultdict(list)
        self._action_callbacks: Dict[str, Callable[[OpenWorkAction], Awaitable[None]]] = {}

        # Background tasks (references are kept so they are not garbage
        # collected and so stop() can cancel them)
        self._running = False
        self._heartbeat_task = None
        self._queue_processor_task = None
        self._reconnect_task = None
        self._receiver_task = None
        self._reconnect_attempts = 0

        # Metrics
        self._metrics = {
            "actions_sent": 0,
            "actions_completed": 0,
            "actions_failed": 0,
            "events_received": 0,
            "reconnections": 0  # counts reconnect *attempts*
        }

        logger.info(f"OpenWorkBridge initialized in {mode} mode")

    async def start(self):
        """
        Start the bridge and its background tasks.

        In server mode this returns once the server is listening. In client
        mode a failed initial connection does not block: reconnection is
        retried in the background (state ``RECONNECTING``) while queued
        actions wait. Calling ``start()`` on a running bridge is a no-op.

        Raises:
            Exception: Any error raised while starting the WebSocket server
                (e.g. ``OSError`` when the port is already in use). The
                bridge is then left stopped in ``ConnectionState.ERROR``.
        """
        if self._running:
            return

        self._running = True
        self.state = ConnectionState.CONNECTING

        try:
            if self.mode == "server":
                await self._start_server()
            else:
                await self._start_client()
        except Exception:
            # Leave the bridge restartable instead of stuck in "running".
            self._running = False
            self.state = ConnectionState.ERROR
            raise

        # Start background tasks
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        self._queue_processor_task = asyncio.create_task(self._process_action_queue())

        logger.info(f"OpenWorkBridge started ({self.mode} mode)")

    async def stop(self):
        """
        Stop the bridge: cancel background tasks and close all connections.

        Safe to call more than once. If called from inside one of the
        bridge's own tasks (e.g. an event handler), that task is not
        cancelled; it finishes on its own.
        """
        self._running = False
        self.state = ConnectionState.DISCONNECTED

        # Cancel background tasks -- never the task we are running in, since
        # a task awaiting itself raises RuntimeError.
        current = asyncio.current_task()
        background = (
            self._heartbeat_task,
            self._queue_processor_task,
            self._reconnect_task,
            self._receiver_task,
        )
        for task in background:
            if task is None or task is current:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.debug(f"Background task ended with error: {e}")
        self._heartbeat_task = None
        self._queue_processor_task = None
        self._reconnect_task = None
        self._receiver_task = None

        # Close connections
        if self._server:
            self._server.close()
            await self._server.wait_closed()
            self._server = None

        if self._client_ws:
            await self._client_ws.close()
            self._client_ws = None

        # Iterate over a copy: each client's handler removes itself from the
        # list when its connection closes, which can happen during the awaits.
        for client in list(self._connected_clients):
            try:
                await client.close()
            except Exception as e:
                logger.debug(f"Error closing OpenWork client: {e}")

        logger.info("OpenWorkBridge stopped")

    async def _start_server(self):
        """Start the WebSocket server and mark the bridge connected."""
        self._server = await ws_serve(
            self._handle_client_connection,
            self.server_host,
            self.server_port
        )
        self.state = ConnectionState.CONNECTED
        logger.info(f"OpenWork bridge server listening on ws://{self.server_host}:{self.server_port}")

    async def _start_client(self):
        """
        Connect to OpenWork at ``client_url`` and start the receive loop.

        On failure the state becomes ``ERROR`` and a background reconnect is
        scheduled via :meth:`_schedule_reconnect`; this method does not wait
        for it.
        """
        try:
            websocket = await ws_connect(self.client_url)
        except Exception as e:
            logger.error(f"Failed to connect to OpenWork: {e}")
            self.state = ConnectionState.ERROR
            await self._schedule_reconnect()
            return

        if not self._running:
            # stop() ran while we were connecting; don't leak the socket.
            await websocket.close()
            return

        self._client_ws = websocket
        self.state = ConnectionState.CONNECTED
        self._reconnect_attempts = 0

        # Start message receiver
        self._receiver_task = asyncio.create_task(self._receive_messages(websocket))

        logger.info(f"Connected to OpenWork at {self.client_url}")

    async def _handle_client_connection(self, websocket, path=None):
        """
        Serve one incoming OpenWork connection until it closes.

        ``path`` is optional because newer ``websockets`` releases call the
        handler with the connection only, while older ones also pass the
        request path. It is unused.
        """
        client_id = id(websocket)
        self._connected_clients.append(websocket)
        logger.info(f"OpenWork client connected: {client_id}")

        try:
            async for message in websocket:
                await self._handle_message(websocket, message)
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            if websocket in self._connected_clients:
                self._connected_clients.remove(websocket)
            logger.info(f"OpenWork client disconnected: {client_id}")

    async def _receive_messages(self, websocket):
        """
        Dispatch messages from the client connection until it closes, then
        schedule a reconnect unless the bridge is stopping.
        """
        try:
            async for message in websocket:
                await self._handle_message(websocket, message)
        except websockets.exceptions.ConnectionClosed:
            pass

        # ``async for`` also ends *without* raising on a clean close
        # (codes 1000/1001), so reconnect handling must run in both cases.
        if not self._running:
            return
        logger.warning("Connection to OpenWork closed")
        self.state = ConnectionState.DISCONNECTED
        await self._schedule_reconnect()

    async def _handle_message(self, websocket, message: str):
        """
        Parse one incoming JSON message and dispatch it on its ``type``.

        Malformed messages and handler errors are logged rather than raised,
        so a bad message cannot terminate the receive loop.
        """
        try:
            data = json.loads(message)
            if not isinstance(data, dict):
                logger.warning(f"Ignoring non-object message: {message[:100]}")
                return

            msg_type = data.get("type")

            if msg_type == "pong":
                # Heartbeat response (not tracked)
                pass

            elif msg_type == "action_response":
                # Response to an action we sent
                await self._handle_action_response(data)

            elif msg_type == "action_progress":
                # Progress update for an action
                await self._handle_action_progress(data)

            elif msg_type == "event":
                # Event from OpenWork
                await self._handle_event(data)

            elif msg_type == "approval_response":
                # Response to approval request
                await self._handle_approval_response(data)

            elif msg_type == "voice_input":
                # Voice input from OpenWork
                await self._handle_voice_input(data)

            else:
                logger.warning(f"Unknown message type: {msg_type}")

        except json.JSONDecodeError:
            logger.warning(f"Invalid JSON message: {message[:100]}")
        except Exception as e:
            logger.exception(f"Message handling error: {e}")

    async def _finalize_action(self, action: OpenWorkAction):
        """
        Move ``action`` from pending to completed, stamp ``completed_at`` and
        invoke its registered callback. Callback errors are logged, not raised.
        """
        action.completed_at = datetime.utcnow()
        self._pending_actions.pop(action.action_id, None)
        self._completed_actions[action.action_id] = action

        callback = self._action_callbacks.pop(action.action_id, None)
        if callback is not None:
            try:
                await callback(action)
            except Exception:
                logger.exception(f"Callback for action {action.action_id} failed")

    async def _fail_action(self, action: OpenWorkAction, error: Any):
        """Mark ``action`` FAILED with ``error``, count it, and finalize it."""
        action.status = ActionStatus.FAILED
        action.error = error
        self._metrics["actions_failed"] += 1
        await self._finalize_action(action)

    async def _handle_action_response(self, data: Dict[str, Any]):
        """
        Record the final outcome of a pending action.

        ``data["success"]`` selects COMPLETED (storing ``data["result"]``) or
        FAILED (storing ``data["error"]``). Responses for unknown or already
        finished actions are ignored.
        """
        action_id = data.get("action_id")
        if not action_id or action_id not in self._pending_actions:
            return

        action = self._pending_actions[action_id]

        if data.get("success", False):
            action.status = ActionStatus.COMPLETED
            action.result = data.get("result")
            self._metrics["actions_completed"] += 1
            await self._finalize_action(action)
        else:
            await self._fail_action(action, data.get("error", "Unknown error"))

        logger.info(f"Action {action_id} {action.status.value}: {action.result or action.error}")

    async def _handle_action_progress(self, data: Dict[str, Any]):
        """Mark a pending action IN_PROGRESS and emit an ``action_progress`` event."""
        action_id = data.get("action_id")
        if not action_id or action_id not in self._pending_actions:
            return

        action = self._pending_actions[action_id]
        action.status = ActionStatus.IN_PROGRESS

        # Emit progress event
        event = OpenWorkEvent(
            event_id=str(uuid.uuid4()),
            event_type="action_progress",
            data={
                "action_id": action_id,
                "progress": data.get("progress", 0),
                "message": data.get("message", "")
            }
        )
        await self._emit_event(event)

    async def _handle_event(self, data: Dict[str, Any]):
        """Wrap an OpenWork ``event`` message in an OpenWorkEvent and emit it."""
        event = OpenWorkEvent(
            event_id=data.get("event_id", str(uuid.uuid4())),
            event_type=data.get("event_type", "unknown"),
            data=data.get("data", {})
        )

        self._metrics["events_received"] += 1
        await self._emit_event(event)

    async def _handle_approval_response(self, data: Dict[str, Any]):
        """
        Apply a user's approval decision to a pending action.

        Approved actions are re-queued for sending. Rejected ones are
        finalized as REJECTED, with ``data["reason"]`` as the error, and their
        callback is invoked.
        """
        action_id = data.get("action_id")
        approved = data.get("approved", False)

        if not action_id or action_id not in self._pending_actions:
            return

        action = self._pending_actions[action_id]

        if approved:
            action.status = ActionStatus.APPROVED
            # Re-queue for execution
            await self._action_queue.put(action)
        else:
            action.status = ActionStatus.REJECTED
            action.error = data.get("reason", "Rejected by user")
            await self._finalize_action(action)

        logger.info(f"Action {action_id} {'approved' if approved else 'rejected'}")

    async def _handle_voice_input(self, data: Dict[str, Any]):
        """Emit a ``voice_input`` event carrying the text/audio from OpenWork."""
        event = OpenWorkEvent(
            event_id=str(uuid.uuid4()),
            event_type="voice_input",
            data={
                "text": data.get("text", ""),
                "audio": data.get("audio"),  # Base64 encoded
                "source": "openwork"
            }
        )
        await self._emit_event(event)

    async def _emit_event(self, event: OpenWorkEvent):
        """
        Deliver ``event`` to handlers for its type, then to wildcard ("*")
        handlers. A failing handler is logged and does not stop the others.
        """
        # Build a fresh list: extending the registered list in place would
        # permanently append the wildcard handlers to it on every emit.
        handlers = list(self._event_handlers.get(event.event_type, []))
        if event.event_type != "*":
            handlers.extend(self._event_handlers.get("*", []))  # Wildcard handlers

        for handler in handlers:
            try:
                await handler(event)
            except Exception as e:
                logger.error(f"Event handler error: {e}")

    async def _heartbeat_loop(self):
        """Send a ping every ``heartbeat_interval`` seconds while connected."""
        while self._running:
            try:
                await asyncio.sleep(self.heartbeat_interval)

                if self.state == ConnectionState.CONNECTED:
                    await self._send_heartbeat()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Heartbeat error: {e}")

    async def _send_heartbeat(self):
        """Send a ``ping`` to every connected client (server) or to the server (client)."""
        ping_msg = json.dumps({
            "type": "ping",
            "timestamp": datetime.utcnow().isoformat()
        })

        if self.mode == "server":
            targets = list(self._connected_clients)  # copy: list may shrink during awaits
        elif self._client_ws:
            targets = [self._client_ws]
        else:
            targets = []

        for target in targets:
            try:
                await target.send(ping_msg)
            except Exception as e:
                # Best effort: closed connections are cleaned up when their
                # receive loop ends.
                logger.debug(f"Heartbeat send failed: {e}")

    async def _schedule_reconnect(self):
        """
        Schedule a background reconnection attempt (client mode only).

        The delay is ``reconnect_delay * 2**(attempt - 1)`` with the exponent
        capped at 5. After ``max_reconnect_attempts`` consecutive failures the
        bridge gives up and stays in ``ConnectionState.ERROR``. Returns
        immediately; the attempt runs in ``_reconnect_task``.
        """
        if not self._running or self.mode == "server":
            return

        pending = self._reconnect_task
        if pending is not None and not pending.done() and pending is not asyncio.current_task():
            return  # an attempt is already scheduled

        if self._reconnect_attempts >= self.max_reconnect_attempts:
            logger.error("Max reconnection attempts reached")
            self.state = ConnectionState.ERROR
            return

        self.state = ConnectionState.RECONNECTING
        self._reconnect_attempts += 1
        self._metrics["reconnections"] += 1

        delay = self.reconnect_delay * (2 ** min(self._reconnect_attempts - 1, 5))
        logger.info(f"Reconnecting in {delay}s (attempt {self._reconnect_attempts})")

        self._reconnect_task = asyncio.create_task(self._reconnect_after(delay))

    async def _reconnect_after(self, delay: float):
        """Sleep ``delay`` seconds, then retry the client connection if still running."""
        await asyncio.sleep(delay)

        if self._running and self.mode != "server":
            await self._start_client()

    def _can_send(self) -> bool:
        """Return True when an action could be delivered to a peer right now."""
        if self.state != ConnectionState.CONNECTED:
            return False
        if self.mode == "server":
            # The server counts as connected as soon as it listens; actions
            # are only deliverable once an OpenWork client has connected.
            return bool(self._connected_clients)
        return self._client_ws is not None

    async def _process_action_queue(self):
        """Background loop: send queued actions whenever a peer can receive them."""
        while self._running:
            try:
                action = await asyncio.wait_for(
                    self._action_queue.get(),
                    timeout=1.0
                )

                if not self._can_send():
                    # Re-queue until a peer can receive it
                    await self._action_queue.put(action)
                    await asyncio.sleep(1)
                    continue

                await self._send_action(action)

            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Action queue processing error: {e}")

    async def _deliver(self, message: str) -> int:
        """
        Send ``message`` to the peer(s) and return how many received it.

        In server mode every connected client is tried, and a failure on one
        client does not prevent delivery to the others.

        Raises:
            ConnectionError: If there is no peer to send to.
            Exception: The last send error, if no client received the message.
        """
        if self.mode != "server":
            if self._client_ws is None:
                raise ConnectionError("Not connected to OpenWork")
            await self._client_ws.send(message)
            return 1

        delivered = 0
        last_error: Optional[Exception] = None
        for client in list(self._connected_clients):  # copy: list may shrink during awaits
            try:
                await client.send(message)
                delivered += 1
            except Exception as e:
                last_error = e
                logger.warning(f"Failed to send to OpenWork client {id(client)}: {e}")

        if delivered == 0:
            if last_error is not None:
                raise last_error
            raise ConnectionError("No OpenWork clients connected")
        return delivered

    async def _send_action(self, action: OpenWorkAction):
        """
        Serialize and transmit ``action``, tracking it as pending.

        A transmission error re-queues the action while
        ``action.retries < action.max_retries``. After that, or if the payload
        is not JSON-serializable, the action is finalized as FAILED and its
        callback (if any) is invoked.
        """
        if action.completed_at is not None:
            # Finalized (e.g. answered or rejected) while waiting in the queue.
            return

        try:
            message = json.dumps(action.to_wire_format())
        except (TypeError, ValueError) as e:
            logger.error(f"Action {action.action_id} is not JSON-serializable: {e}")
            await self._fail_action(action, str(e))
            return

        action.status = ActionStatus.SENT
        action.sent_at = datetime.utcnow()
        self._pending_actions[action.action_id] = action

        try:
            await self._deliver(message)
        except Exception as e:
            logger.error(f"Failed to send action: {e}")
            if action.retries < action.max_retries:
                action.retries += 1
                action.status = ActionStatus.PENDING
                logger.info(
                    f"Re-queueing action {action.action_id} "
                    f"(retry {action.retries}/{action.max_retries})"
                )
                await self._action_queue.put(action)
            else:
                await self._fail_action(action, str(e))
            return

        self._metrics["actions_sent"] += 1
        logger.info(f"Action sent: {action.action_id} ({action.action_type.value})")

    # Public API

    async def send_action(
        self,
        action_type: Union[ActionType, str],
        payload: Dict[str, Any],
        requires_approval: bool = False,
        priority: int = 5,
        callback: Optional[Callable[[OpenWorkAction], Awaitable[None]]] = None
    ) -> OpenWorkAction:
        """
        Send an action to OpenWork.

        The action is queued and sent in the background; this returns
        immediately with the (still PENDING) action object.

        Args:
            action_type: Type of action, as an ActionType or its string value
            payload: Action payload (must be JSON-serializable)
            requires_approval: Whether action requires user approval
            priority: Priority (1-10, 1 is highest); forwarded to OpenWork,
                the local queue itself is FIFO
            callback: Awaited once the action reaches a final state
                (completed, failed or rejected)

        Returns:
            OpenWorkAction object

        Raises:
            ValueError: If ``action_type`` is a string that is not a valid
                ActionType value.
        """
        if isinstance(action_type, str):
            action_type = ActionType(action_type)

        action = OpenWorkAction(
            action_id=str(uuid.uuid4()),
            action_type=action_type,
            payload=payload,
            requires_approval=requires_approval,
            priority=priority
        )

        if callback:
            self._action_callbacks[action.action_id] = callback

        await self._action_queue.put(action)
        return action

    async def send_file_operation(
        self,
        operation: str,
        path: str,
        content: Optional[str] = None,
        **kwargs
    ) -> OpenWorkAction:
        """Queue a FILE_OPERATION action; extra kwargs are merged into the payload."""
        return await self.send_action(
            ActionType.FILE_OPERATION,
            {
                "operation": operation,
                "path": path,
                "content": content,
                **kwargs
            }
        )

    async def send_browser_task(
        self,
        task: str,
        url: Optional[str] = None,
        **kwargs
    ) -> OpenWorkAction:
        """Queue a BROWSER_TASK action; extra kwargs are merged into the payload."""
        return await self.send_action(
            ActionType.BROWSER_TASK,
            {
                "task": task,
                "url": url,
                **kwargs
            }
        )

    async def send_notification(
        self,
        title: str,
        message: str,
        **kwargs
    ) -> OpenWorkAction:
        """Queue a NOTIFICATION action; extra kwargs are merged into the payload."""
        return await self.send_action(
            ActionType.NOTIFICATION,
            {
                "title": title,
                "message": message,
                **kwargs
            }
        )

    async def request_approval(
        self,
        action_description: str,
        action_payload: Dict[str, Any],
        timeout_seconds: int = 300
    ) -> OpenWorkAction:
        """
        Request user approval for an action.

        ``timeout_seconds`` is only forwarded in the payload; the bridge does
        not enforce it locally.
        """
        return await self.send_action(
            ActionType.APPROVAL_REQUEST,
            {
                "description": action_description,
                "payload": action_payload,
                "timeout": timeout_seconds
            },
            requires_approval=True
        )

    def on_event(
        self,
        event_type: str,
        handler: Optional[Callable[[OpenWorkEvent], Awaitable[None]]] = None
    ):
        """
        Register an async event handler for ``event_type`` ("*" for all events).

        Can be called directly, ``bridge.on_event("voice_input", fn)``, or used
        as a decorator, ``@bridge.on_event("*")``. Returns the handler in
        direct form and the decorator in decorator form.
        """
        def register(fn: Callable[[OpenWorkEvent], Awaitable[None]]):
            self._event_handlers[event_type].append(fn)
            return fn

        if handler is None:
            return register
        return register(handler)

    def get_pending_actions(self) -> List[Dict[str, Any]]:
        """Get sent (or retry-waiting) actions that have not reached a final state."""
        return [a.to_dict() for a in self._pending_actions.values()]

    def get_completed_actions(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Get up to ``limit`` finished actions, most recently completed first."""
        actions = list(self._completed_actions.values())
        actions.sort(key=lambda a: a.completed_at or datetime.min, reverse=True)
        return [a.to_dict() for a in actions[:limit]]

    def get_metrics(self) -> Dict[str, Any]:
        """Get bridge counters plus a snapshot of connection and queue state."""
        return {
            **self._metrics,
            "state": self.state.name,
            "mode": self.mode,
            "connected_clients": len(self._connected_clients),
            "pending_actions": len(self._pending_actions),
            "queue_size": self._action_queue.qsize()
        }

    @property
    def is_connected(self) -> bool:
        """
        Check if bridge is connected.

        In server mode this is True while the server is listening, even if no
        OpenWork client is connected yet.
        """
        return self.state == ConnectionState.CONNECTED


# CLI for testing
async def main():
    """
    CLI for testing the bridge.

    ``--test`` always starts a server-mode bridge on ``--port`` and prints
    every event it receives. Otherwise the bridge runs in ``--mode``,
    connecting to ``--url`` in client mode. Runs until interrupted, then
    stops the bridge.
    """
    import argparse

    parser = argparse.ArgumentParser(description="OpenWork Bridge")
    parser.add_argument("--mode", type=str, default="server", choices=["server", "client"])
    parser.add_argument("--port", type=int, default=8767)
    parser.add_argument("--url", type=str, help="OpenWork URL (for client mode)")
    parser.add_argument("--test", action="store_true", help="Run test")
    args = parser.parse_args()

    if args.test:
        print("Testing OpenWork bridge...")

        bridge = OpenWorkBridge(mode="server", server_port=args.port)

        async def log_event(event):
            print(f"Event: {event.event_type} - {event.data}")

        # Register event handler (explicit two-argument form).
        bridge.on_event("*", log_event)

        await bridge.start()
        print(f"Bridge started in server mode on port {args.port}")
        print("Connect with a WebSocket client to test.")
    else:
        bridge = OpenWorkBridge(
            mode=args.mode,
            server_port=args.port,
            client_url=args.url
        )

        await bridge.start()

        print(f"OpenWork Bridge running ({args.mode} mode)")
        print(f"Metrics: {json.dumps(bridge.get_metrics(), indent=2)}")

    print("Press Ctrl+C to stop.\n")

    try:
        # Block forever. On Ctrl+C, asyncio.run() cancels this coroutine
        # (CancelledError), so cleanup must live in ``finally`` rather than in
        # an ``except KeyboardInterrupt`` clause.
        await asyncio.Future()
    finally:
        print("\nStopping...")
        await bridge.stop()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() cancels main() and then re-raises Ctrl+C; exit with
        # the conventional SIGINT status (128 + 2) instead of a traceback.
        raise SystemExit(130)
