
import asyncio
import json
import logging
import base64
import time
from datetime import datetime
from typing import Optional, Dict, Any, List
import websockets
from websockets.server import serve as ws_serve, WebSocketServerProtocol

from core.voice.kinan_aiva_voice_channel import KinanAivaVoiceChannel, VoiceMessage
from RECEPTIONISTAI.widget.swarm_generated.server.W-K03_block7 import TelnyxService

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# RLM Gateway — shadow-mode integration
# When a WebSocket session closes (the WebRTC stand-in for Telnyx call.hangup)
# the accumulated session transcript is handed to the RLM pipeline that drives
# AIVA's learning loop.  Loading is best-effort: if anything goes wrong the
# gateway stays ``None`` and calls run normally without RLM scoring.
# ---------------------------------------------------------------------------
def _load_rlm_gateway():
    """Return the RLM gateway instance, or ``None`` when it cannot be loaded.

    Appends ``/mnt/e/genesis-system`` to ``sys.path`` so the ``AIVA`` package
    can be imported, then calls ``AIVA.rlm_gateway.get_gateway()``.  Any
    exception along the way is logged as a warning and swallowed.
    """
    try:
        import sys
        sys.path.append('/mnt/e/genesis-system')
        from AIVA.rlm_gateway import get_gateway
        gateway = get_gateway()
        logger.info("RLM Gateway loaded (shadow mode) in TelnyxWebRTCWebSocketHandler")
        return gateway
    except Exception as load_error:
        logger.warning(f"RLM Gateway unavailable (non-fatal): {load_error}")
        return None


_rlm_gateway = _load_rlm_gateway()

class TelnyxWebRTCWebSocketHandler:
    """
    WebSocket handler for real-time voice streaming between browser widget and Telnyx,
    bridging to KinanAivaVoiceChannel for STT/TTS.

    For each client connection it:
      * sends a Telnyx WebRTC client token (``{"type": "telnyx_token", ...}``),
      * forwards client ``audio`` frames to ``KinanAivaVoiceChannel.receive_audio_chunk``,
      * relays queued AIVA responses back to the client as ``aiva_audio`` frames,
      * accumulates client ``transcription`` frames and, on disconnect, submits the
        joined transcript to the RLM gateway in shadow mode (when it loaded).
    """

    def __init__(
        self,
        kinan_aiva_channel: KinanAivaVoiceChannel,
        telnyx_service: TelnyxService,
        websocket_host: str = "0.0.0.0",
        websocket_port: int = 8767,
    ):
        self.kinan_aiva_channel = kinan_aiva_channel
        self.telnyx_service = telnyx_service
        self.ws_host = websocket_host
        self.ws_port = websocket_port
        self.ws_clients: List[WebSocketServerProtocol] = []
        self._server = None
        self._running = False

        # Per-session call metadata for RLM scoring on hangup
        # Key: session_id, Value: {"start_time": float, "transcripts": List[str], "caller": str}
        self._call_sessions: Dict[str, Dict[str, Any]] = {}

        logger.info(f"TelnyxWebRTCWebSocketHandler initialized (ws://{websocket_host}:{websocket_port})")

    async def start(self):
        """Start the WebSocket server and wait until it is closed.

        A second call while the server is running returns immediately.  If the
        server fails to start (e.g. the port is already bound), the running flag
        is cleared before the error propagates so ``start`` can be retried.
        """
        if self._running:
            return

        self._running = True
        try:
            self._server = await ws_serve(
                self._handle_websocket,
                self.ws_host,
                self.ws_port
            )
        except BaseException:
            # Without this reset a failed bind would make every later start() a no-op.
            self._running = False
            raise
        logger.info(f"Telnyx WebRTC WebSocket server started on ws://{self.ws_host}:{self.ws_port}")

        await self._server.wait_closed()  # Keep server running until stop() closes it

    async def stop(self):
        """Stop the WebSocket server (if one was started) and wait for it to close."""
        self._running = False
        if self._server:
            self._server.close()
            await self._server.wait_closed()
        logger.info("Telnyx WebRTC WebSocket server stopped")

    async def _handle_websocket(self, websocket: WebSocketServerProtocol, path: str = ""):
        """Serve one client connection from connect to disconnect.

        ``path`` is unused.  It defaults to ``""`` so both calling conventions work:
        websockets releases that pass ``(websocket, path)`` and newer releases that
        pass only the connection.

        On disconnect (normal or not) the AIVA response task is cancelled, the client
        is deregistered, and the session is handed to ``_on_call_hangup``.
        """
        import uuid
        session_id = str(uuid.uuid4())  # Unique session ID for Telnyx token

        self.ws_clients.append(websocket)
        client_id = id(websocket)
        logger.info(f"Telnyx WebRTC client connected: {client_id}, Session ID: {session_id}")

        # Record session start for RLM duration calculation
        self._call_sessions[session_id] = {
            "start_time": time.monotonic(),
            "transcripts": [],
            "caller": "",
        }

        aiva_response_task = None
        try:
            # 1. Generate Telnyx WebRTC token
            telnyx_token = await self.telnyx_service.generate_client_token(session_id)
            await self._send_message_to_client(websocket, {
                "type": "telnyx_token",
                "token": telnyx_token,
                "session_id": session_id
            })
            logger.info(f"Sent Telnyx token for session {session_id}")

            # 2. Relay AIVA's responses for this session concurrently with reading frames
            aiva_response_task = asyncio.create_task(self._process_aiva_responses(session_id, websocket))

            async for message in websocket:
                await self._handle_ws_message(websocket, message, session_id)
        except websockets.exceptions.ConnectionClosed:
            pass
        except Exception as e:
            logger.exception(f"Error during WebSocket connection for session {session_id}: {e}")
        finally:
            if aiva_response_task is not None:
                aiva_response_task.cancel()
                await asyncio.gather(aiva_response_task, return_exceptions=True)
            if websocket in self.ws_clients:
                self.ws_clients.remove(websocket)
            logger.info(f"Telnyx WebRTC client disconnected: {client_id}, Session ID: {session_id}")

            # === CALL.HANGUP → RLM GATEWAY (shadow mode) ===
            # Equivalent of Telnyx call.hangup webhook for WebRTC sessions.
            # Feed session transcript into AIVA's learning loop.
            await self._on_call_hangup(session_id)

    async def _handle_ws_message(self, websocket: WebSocketServerProtocol, message: str, session_id: str):
        """Dispatch one client frame: SDP offer, ICE candidate, transcription, audio or ping.

        Malformed frames (invalid JSON, non-object JSON, undecodable base64 audio) are
        logged and dropped; no error from a single frame escapes to the caller.
        """
        try:
            data = json.loads(message)
        except ValueError:  # JSONDecodeError, or undecodable bytes (UnicodeDecodeError)
            logger.warning(f"Session {session_id}: Invalid JSON from WebSocket: {message[:100]}")
            return
        if not isinstance(data, dict):
            logger.warning(f"Session {session_id}: Ignoring non-object JSON message from client.")
            return

        msg_type = data.get("type")
        try:
            if msg_type == "sdp_offer":
                logger.info(f"Session {session_id}: Received SDP offer from client.")
                # Per the original design note, the Telnyx client SDK negotiates SDP
                # directly with Telnyx, so the server only acknowledges.
                await self._send_message_to_client(websocket, {"type": "ack", "message": "SDP offer received"})

            elif msg_type == "ice_candidate":
                logger.info(f"Session {session_id}: Received ICE candidate from client.")
                # As with SDP, ICE is handled client-side by the Telnyx SDK.
                await self._send_message_to_client(websocket, {"type": "ack", "message": "ICE candidate received"})

            elif msg_type == "transcription":
                # Transcription segments accumulate for RLM scoring at hangup
                text = data.get("text") or ""
                caller = data.get("caller") or ""
                if isinstance(text, str) and text.strip():
                    text = text.strip()
                    self._accumulate_transcript(session_id, text, str(caller))
                    logger.debug(f"Session {session_id}: Transcript segment ({len(text)} chars) accumulated for RLM")

            elif msg_type == "audio":
                audio_b64 = data.get("audio")
                if audio_b64:
                    try:
                        audio_bytes = base64.b64decode(audio_b64)
                    except (ValueError, TypeError) as e:  # binascii.Error subclasses ValueError
                        logger.warning(f"Session {session_id}: Dropping undecodable audio frame: {e}")
                        return
                    await self.kinan_aiva_channel.receive_audio_chunk(session_id, audio_bytes)
                    logger.debug(f"Session {session_id}: Received {len(audio_bytes)} bytes of audio and sent to AIVA channel.")

            elif msg_type == "ping":
                await self._send_message_to_client(websocket, {"type": "pong"})

            else:
                logger.warning(f"Session {session_id}: Unknown message type from client: {msg_type}")

        except Exception as e:
            logger.error(f"Session {session_id}: WebSocket message handling error: {e}")

    def _outbound_queue_for(self, session_id: str) -> "asyncio.Queue":
        """Return the channel's outbound queue for *session_id*, creating it if absent."""
        streams = self.kinan_aiva_channel.outbound_audio_streams
        if session_id not in streams:
            streams[session_id] = asyncio.Queue()
        return streams[session_id]

    async def _process_aiva_responses(self, session_id: str, websocket: WebSocketServerProtocol):
        """
        Relay AIVA's queued responses for one session back to its client.

        Each queue item is unpacked as ``(audio_data, text)`` and sent as an
        ``aiva_audio`` frame with base64 audio and a UTC timestamp.  Runs until
        cancelled, or until the queue itself cannot be obtained or read.
        """
        while True:
            try:
                # Per the original design note, this queue is populated by
                # kinan_aiva_voice_channel.speak_to_kinan.
                queue = self._outbound_queue_for(session_id)
                item = await queue.get()
            except asyncio.CancelledError:
                logger.info(f"AIVA response processing for session {session_id} cancelled.")
                break
            except Exception as e:
                # Not item-specific: retrying could spin forever without yielding
                # to the event loop, so stop this session's relay instead.
                logger.warning(f"Outbound audio stream for session {session_id} not found, stopping processing. ({e})")
                break

            try:
                audio_data, text = item
                await self._send_message_to_client(websocket, {
                    "type": "aiva_audio",
                    "audio": base64.b64encode(audio_data).decode(),
                    "text": text,
                    "timestamp": datetime.utcnow().isoformat()
                })
                logger.debug(f"Sent AIVA's audio response to session {session_id}.")
            except asyncio.CancelledError:
                logger.info(f"AIVA response processing for session {session_id} cancelled.")
                break
            except Exception as e:
                # A malformed item (wrong tuple shape, non-bytes audio) drops only that item.
                logger.error(f"Error processing AIVA responses for session {session_id}: {e}")
        # NOTE(review): the session's queue is left in outbound_audio_streams after the
        # session ends; confirm whether KinanAivaVoiceChannel removes it, otherwise
        # entries accumulate per session.

    async def _on_call_hangup(self, session_id: str):
        """
        Triggered when a WebRTC session ends (equivalent to Telnyx call.hangup).

        Pops the session's metadata, joins its transcript segments, and — when the
        RLM gateway loaded and the transcript is non-empty — awaits
        ``process_interaction`` in shadow mode.  Gateway failures are logged, not raised.
        """
        session_data = self._call_sessions.pop(session_id, {})
        if not session_data:
            return

        duration_seconds = int(time.monotonic() - session_data.get("start_time", time.monotonic()))
        transcripts      = session_data.get("transcripts", [])
        caller_number    = session_data.get("caller", "")
        full_transcript  = " ".join(transcripts).strip()

        logger.info(
            f"Session {session_id} ended: duration={duration_seconds}s, "
            f"transcript_segments={len(transcripts)}, rlm_available={_rlm_gateway is not None}"
        )

        if _rlm_gateway is None:
            logger.debug(f"RLM Gateway not available — skipping shadow scoring for session {session_id}")
            return

        if not full_transcript:
            logger.debug(f"No transcript collected for session {session_id} — skipping RLM")
            return

        try:
            rlm_result = await _rlm_gateway.process_interaction(
                call_id=session_id,
                transcript=full_transcript,
                caller_number=caller_number,
                call_duration_seconds=duration_seconds,
                outcome="completed",
            )
            # Tolerate a None/non-dict result or a null violations list rather than
            # misreporting a successful call as a scoring failure.
            result = rlm_result if isinstance(rlm_result, dict) else {}
            logger.info(
                f"RLM shadow scored session {session_id}: "
                f"reward={result.get('reward_score')}, "
                f"outcome={result.get('outcome_label')}, "
                f"cai_violations={len(result.get('cai_violations') or [])}"
            )
        except Exception as e:
            logger.warning(f"RLM shadow scoring failed for session {session_id} (non-fatal): {e}")

    def _accumulate_transcript(self, session_id: str, text: str, caller: str = ""):
        """
        Add a transcript segment to the session for RLM processing at hangup.

        Whitespace-only segments are ignored; the first non-empty ``caller`` seen
        for the session is kept.  Unknown session IDs are ignored.
        """
        session = self._call_sessions.get(session_id)
        if session is None:
            return
        segment = text.strip()
        if segment:
            session["transcripts"].append(segment)
        if caller and not session["caller"]:
            session["caller"] = caller

    async def _process_inbound_audio(self):
        """Placeholder for an inbound audio processing loop (not yet implemented)."""
        # This will be where audio is fed to KinanAivaVoiceChannel
        pass

    async def _send_audio_to_telnyx(self, audio_data: bytes):
        """Placeholder for sending processed audio back to Telnyx (not yet implemented)."""
        # This is where TTS from AIVA would be sent to Telnyx for playback
        pass

    async def _send_message_to_client(self, websocket: WebSocketServerProtocol, message: Dict[str, Any]):
        """Send *message* as JSON to the client; send failures are logged, not raised."""
        try:
            await websocket.send(json.dumps(message))
        except Exception as e:
            logger.error(f"Error sending message to client: {e}")

# Example usage (for testing purposes)
async def main():
    """Run the handler standalone for manual testing.

    Reads ``GEMINI_API_KEY`` (and, for the dummy fallback, ``TELNYX_API_KEY``)
    from the environment, builds the voice channel and Telnyx service, starts the
    channel in the background, and serves WebSocket clients until the server closes.
    """
    import importlib
    import os  # used for environment lookups; not imported at module level in this file

    logging.basicConfig(level=logging.INFO)

    # Initialize dependencies
    # Placeholder for actual API key
    gemini_api_key = os.environ.get("GEMINI_API_KEY", "YOUR_GEMINI_API_KEY_HERE")
    kinan_aiva_channel = KinanAivaVoiceChannel(gemini_api_key=gemini_api_key)

    # The generated module's name contains a hyphen ("W-K03_block7"), which an
    # ``import`` statement cannot spell (SyntaxError), so load it by string.
    try:
        generated_module = importlib.import_module(
            "RECEPTIONISTAI.widget.swarm_generated.server.W-K03_block7"
        )
        telnyx_service_cls = generated_module.TelnyxService
    except (ImportError, AttributeError):
        logger.warning("Could not import TelnyxService directly from generated code. Using dummy. Ensure generated files are accessible.")

        class DummySettings:
            telnyx_api_key = os.environ.get("TELNYX_API_KEY", "YOUR_TELNYX_API_KEY_HERE")

        class DummyTelnyxService:
            def __init__(self, settings=None):
                self.settings = settings or DummySettings()

            async def generate_client_token(self, session_id: str) -> str:
                logger.warning("Using dummy TelnyxService.generate_client_token")
                await asyncio.sleep(0.1)
                return f"dummy_telnyx_token_{session_id}"

        telnyx_service_cls = DummyTelnyxService
    telnyx_service = telnyx_service_cls()

    handler = TelnyxWebRTCWebSocketHandler(
        kinan_aiva_channel=kinan_aiva_channel,
        telnyx_service=telnyx_service
    )

    # Start the AIVA channel in the background.  Keep a reference: the event loop
    # holds only weak references to tasks, so an unreferenced task can vanish.
    channel_task = asyncio.create_task(kinan_aiva_channel.start())

    print(f"Starting Telnyx WebRTC WebSocket Handler on ws://{handler.ws_host}:{handler.ws_port}")
    try:
        await handler.start()
    finally:
        channel_task.cancel()
        await asyncio.gather(channel_task, return_exceptions=True)


if __name__ == "__main__":
    asyncio.run(main())
