import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, List, Optional

import aioredis
import websockets
import websockets.exceptions
from websockets.server import WebSocketServerProtocol

# Root logger: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Redis connection settings. Each can be overridden through the environment
# variable of the same name; otherwise the default below applies.
REDIS_HOST = os.environ.get("REDIS_HOST", "redis-genesis-u50607.vm.elestio.app")
REDIS_PORT = int(os.environ.get("REDIS_PORT", "26379"))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD")  # None when the variable is unset

# WebSocket timing settings, both in seconds.
HEARTBEAT_INTERVAL = 30  # pause between heartbeat frames sent to each client
RECONNECTION_DELAY = 5


class WebSocketServer:
    """
    WebSocket server for real-time dashboard updates.

    Clients send JSON control messages to manage their subscriptions::

        {"type": "subscribe", "topic": "<channel>"}
        {"type": "unsubscribe", "topic": "<channel>"}

    Each topic is a Redis pub/sub channel. Every message published on a
    subscribed channel is forwarded, unchanged, to all clients subscribed to
    that topic.

    Supports:
        - Multiple clients
        - Topic subscription (one shared Redis PubSub drained by one reader task)
        - Heartbeat mechanism (``{"type": "heartbeat"}`` every
          ``HEARTBEAT_INTERVAL`` seconds, sent concurrently with message handling)
        - Reconnection logic: clients may disconnect and reconnect freely; a
          client's subscriptions are released when its connection ends
        - Efficient JSON message format

    NOTE(review): a dropped Redis pub/sub connection is logged but not
    re-established automatically; ``RECONNECTION_DELAY`` is not used yet.
    NOTE(review): ``websocket.open`` is the legacy ``websockets`` protocol API
    (matching the ``WebSocketServerProtocol`` import) — confirm the installed
    ``websockets`` version still provides it.
    """

    def __init__(self, host: str, port: int):
        """
        Initializes the WebSocket server. No I/O happens until start().

        Args:
            host: The host address to bind to.
            port: The port number to listen on.
        """
        self.host = host
        self.port = port
        # client -> topics it is subscribed to
        self.connected_clients: Dict[WebSocketServerProtocol, List[str]] = {}
        self.redis_connection: Optional[aioredis.Redis] = None
        # Single PubSub shared by all channels; created on first subscription.
        self.pubsub: Optional[aioredis.client.PubSub] = None
        # The one background task draining self.pubsub (None when idle).
        self._reader_task: Optional[asyncio.Task] = None
        # Redis channels this server has SUBSCRIBEd to. Tracked here rather than
        # read from PubSub.channels, whose removal of an unsubscribed channel may
        # lag until Redis's confirmation has been read.
        self._redis_channels = set()
        # The running server returned by ``await websockets.serve(...)``.
        self.server = None
        self.running = False

    async def start(self):
        """
        Connects to Redis, then starts the WebSocket server and serves until it
        is closed (by stop() or by cancelling this coroutine).

        Startup failures are logged rather than raised. On the way out stop()
        is always called, so the listening socket and the Redis connection are
        released however serving ended.
        """
        try:
            self.redis_connection = aioredis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
            await self.redis_connection.ping()
            logging.info(f"Connected to Redis at {REDIS_HOST}:{REDIS_PORT}")
        except aioredis.exceptions.ConnectionError as e:
            logging.error(f"Failed to connect to Redis: {e}")
            await self._close_redis()
            return

        try:
            self.running = True
            # serve() must be awaited to obtain the running server object; that
            # object provides the close()/wait_closed() used here and in stop().
            self.server = await websockets.serve(self.handler, self.host, self.port)
            logging.info(f"WebSocket server started at ws://{self.host}:{self.port}")
            # Keep serving until the server is closed.
            await self.server.wait_closed()
        except Exception as e:
            logging.error(f"Error starting WebSocket server: {e}")
        finally:
            await self.stop()

    async def stop(self):
        """
        Stops the WebSocket server and releases Redis resources.

        Idempotent: each resource is detached from ``self`` before it is
        closed, so repeated or overlapping calls close it only once.
        """
        self.running = False
        server, self.server = self.server, None
        if server is not None:
            server.close()
            # Waits for open connections and their handlers to finish, so the
            # per-client Redis cleanup in handler() runs before Redis is closed.
            await server.wait_closed()
            logging.info("WebSocket server stopped.")
        await self._close_redis()

    async def handler(self, websocket: WebSocketServerProtocol):
        """
        Serves one client connection from connect to disconnect.

        Heartbeats are sent from a background task while this coroutine
        processes the client's messages. On exit the heartbeat task is
        cancelled, every subscription the client held is released, and the
        client is removed from ``connected_clients``.

        Args:
            websocket: The WebSocket connection object.
        """
        logging.info(f"Client connected: {websocket.remote_address}")
        self.connected_clients[websocket] = []  # Initialize with no subscriptions
        # send_heartbeat loops for the lifetime of the connection, so it must run
        # concurrently; awaiting it inline would never reach the receive loop.
        heartbeat = asyncio.create_task(self.send_heartbeat(websocket))
        try:
            await self.receive_messages(websocket)  # Keep connection alive and handle messages
        except websockets.exceptions.ConnectionClosedError as e:
            logging.info(f"Client disconnected (ConnectionClosedError): {websocket.remote_address} - {e}")
        except websockets.exceptions.ConnectionClosedOK as e:
            logging.info(f"Client disconnected (ConnectionClosedOK): {websocket.remote_address} - {e}")
        except Exception as e:
            logging.error(f"Error handling WebSocket connection: {e}")
        finally:
            heartbeat.cancel()
            try:
                await self.unsubscribe_all(websocket)
            except Exception as e:
                logging.error(f"Error releasing subscriptions for {websocket.remote_address}: {e}")
            finally:
                # Always forget the client, even if cleanup above failed.
                self.connected_clients.pop(websocket, None)
                logging.info(f"Client disconnected: {websocket.remote_address}")

    async def receive_messages(self, websocket: WebSocketServerProtocol):
        """
        Processes control messages from a client until its connection closes.

        Malformed messages are logged and skipped; they never close the
        connection. ``ConnectionClosed`` is re-raised so handler() can log the
        disconnect.

        Args:
            websocket: The WebSocket connection object.
        """
        async for message in websocket:
            try:
                data = json.loads(message)
            except (ValueError, RecursionError):  # bad JSON or undecodable bytes
                logging.warning(f"Invalid JSON received: {message}")
                continue

            # Only JSON objects carry a "type"; lists, numbers and strings do not.
            if not isinstance(data, dict) or "type" not in data:
                logging.warning(f"Invalid message format: {data}")
                continue

            msg_type = data["type"]
            topic = data.get("topic")
            # Topics are Redis channel names, so require a non-empty string.
            valid_topic = isinstance(topic, str) and bool(topic)
            try:
                if msg_type == "subscribe":
                    if valid_topic:
                        await self.subscribe(websocket, topic)
                    else:
                        logging.warning(f"Invalid subscribe message: {data}")
                elif msg_type == "unsubscribe":
                    if valid_topic:
                        await self.unsubscribe(websocket, topic)
                    else:
                        logging.warning(f"Invalid unsubscribe message: {data}")
                else:
                    logging.warning(f"Unknown message type: {msg_type}")
            except websockets.exceptions.ConnectionClosed:
                raise
            except Exception as e:
                logging.error(f"Error processing message: {e}")

    async def subscribe(self, websocket: WebSocketServerProtocol, topic: str):
        """
        Subscribes a client to a specific topic and confirms it to the client.

        If the Redis subscription fails, the topic is removed from the client
        again and the error is re-raised.

        Args:
            websocket: The WebSocket connection object.
            topic: The topic to subscribe to.
        """
        topics = self.connected_clients.get(websocket)
        if topics is None:
            logging.warning(f"Client not found: {websocket.remote_address}")
            return
        if topic in topics:
            logging.info(f"Client {websocket.remote_address} already subscribed to topic: {topic}")
            return

        topics.append(topic)
        logging.info(f"Client {websocket.remote_address} subscribed to topic: {topic}")
        try:
            await self.redis_subscribe(topic)  # No-op if the channel is already subscribed.
        except Exception:
            # Don't keep (or confirm) a subscription Redis never accepted.
            if topic in topics:
                topics.remove(topic)
            raise
        await websocket.send(json.dumps({"type": "subscription_status", "topic": topic, "status": "subscribed"}))  # type: ignore

    async def unsubscribe(self, websocket: WebSocketServerProtocol, topic: str):
        """
        Unsubscribes a client from a specific topic.

        The Redis channel is released once no client holds the topic. The
        confirmation to the client is best-effort: a connection that is already
        closed (e.g. during disconnect cleanup) is silently skipped.

        Args:
            websocket: The WebSocket connection object.
            topic: The topic to unsubscribe from.
        """
        topics = self.connected_clients.get(websocket)
        if topics is None:
            logging.warning(f"Client not found: {websocket.remote_address}")
            return
        if topic not in topics:
            logging.info(f"Client {websocket.remote_address} not subscribed to topic: {topic}")
            return

        topics.remove(topic)
        logging.info(f"Client {websocket.remote_address} unsubscribed from topic: {topic}")
        # Release Redis before notifying, so cleanup happens even if the send fails.
        if not any(topic in client_topics for client_topics in self.connected_clients.values()):
            await self.redis_unsubscribe(topic)
        try:
            await websocket.send(json.dumps({"type": "subscription_status", "topic": topic, "status": "unsubscribed"}))  # type: ignore
        except websockets.exceptions.ConnectionClosed:
            pass  # Client already gone; nobody to notify.

    async def unsubscribe_all(self, websocket: WebSocketServerProtocol):
        """
        Unsubscribes a client from all topics.

        Args:
            websocket: The WebSocket connection object.
        """
        # Iterate over a copy: unsubscribe() removes entries from the live list.
        for topic in list(self.connected_clients.get(websocket, ())):
            await self.unsubscribe(websocket, topic)

    async def send_heartbeat(self, websocket: WebSocketServerProtocol):
        """
        Sends ``{"type": "heartbeat"}`` to the client every HEARTBEAT_INTERVAL
        seconds until the connection closes, the server stops running, a send
        fails, or this task is cancelled (handler() cancels it on disconnect).

        Args:
            websocket: The WebSocket connection object.
        """
        try:
            while websocket.open and self.running:
                await asyncio.sleep(HEARTBEAT_INTERVAL)
                try:
                    await websocket.send(json.dumps({"type": "heartbeat"}))  # type: ignore
                    logging.debug(f"Heartbeat sent to {websocket.remote_address}")
                except websockets.exceptions.ConnectionClosed:
                    logging.info(f"Connection closed while sending heartbeat to {websocket.remote_address}")
                    break
                except Exception as e:
                    logging.error(f"Error sending heartbeat to {websocket.remote_address}: {e}")
                    break
        except Exception as e:
            logging.error(f"Error in heartbeat loop: {e}")

    async def redis_subscribe(self, channel: str):
        """
        Ensures this server is subscribed to a Redis channel.

        All channels share one PubSub and one reader task; calling this for a
        channel that is already subscribed does nothing.

        Args:
            channel: The Redis channel to subscribe to.

        Raises:
            Whatever ``PubSub.subscribe`` raises (e.g. a Redis connection
            error); the channel is then not recorded as subscribed.
        """
        if not self.redis_connection:
            logging.error("Redis connection not initialized.")
            return
        if channel in self._redis_channels:
            return

        # Record the channel before the first await, so a concurrent
        # redis_unsubscribe() of the last other channel sees it and keeps the
        # PubSub open.
        self._redis_channels.add(channel)
        if self.pubsub is None:
            self.pubsub = self.redis_connection.pubsub()
        pubsub = self.pubsub
        try:
            await pubsub.subscribe(channel)
        except BaseException:
            self._redis_channels.discard(channel)
            raise
        logging.info(f"Subscribed to Redis channel: {channel}")

        # Exactly one reader per PubSub. It is started after subscribe() so its
        # listen() loop has a channel to wait on. Skip if stop() swapped the
        # PubSub out while we were awaiting.
        if self.pubsub is pubsub and (self._reader_task is None or self._reader_task.done()):
            self._reader_task = asyncio.create_task(self._pubsub_reader(pubsub))

    async def _pubsub_reader(self, pubsub):
        """Forwards every message received on ``pubsub`` to the subscribed clients."""
        try:
            async for message in pubsub.listen():
                if message["type"] == "message":
                    # Route by the channel the message arrived on: this single
                    # reader serves every subscribed channel.
                    await self.broadcast(message["channel"], message["data"])
        except aioredis.exceptions.ConnectionError as e:
            logging.error(f"Redis connection error in reader: {e}")
        except Exception as e:
            logging.error(f"Error in pubsub reader: {e}")

    async def redis_unsubscribe(self, channel: str):
        """
        Unsubscribes from a Redis channel, closing the shared PubSub (and its
        reader task) once no channels remain. Errors are logged, not raised.

        Args:
            channel: The Redis channel to unsubscribe from.
        """
        if channel not in self._redis_channels:
            return  # Never subscribed, or already released (e.g. during shutdown).
        self._redis_channels.discard(channel)

        pubsub = self.pubsub
        if not self.redis_connection or pubsub is None:
            logging.warning("Redis connection or pubsub not initialized.")
            return

        try:
            await pubsub.unsubscribe(channel)
            logging.info(f"Unsubscribed from Redis channel: {channel}")
        except aioredis.exceptions.ConnectionError as e:
            logging.error(f"Redis connection error during unsubscribe: {e}")
        except Exception as e:
            logging.error(f"Error unsubscribing from Redis channel: {e}")

        # Re-check after the await: another client may have subscribed meanwhile.
        if not self._redis_channels and self.pubsub is pubsub:
            await self._close_pubsub()

    async def _close_pubsub(self):
        """Cancels the reader task and closes the shared PubSub, if any."""
        # Detach first so concurrent callers never close the same objects twice.
        task, self._reader_task = self._reader_task, None
        pubsub, self.pubsub = self.pubsub, None
        if task is not None and not task.done():
            task.cancel()
            # asyncio.wait() does not re-raise the reader's CancelledError.
            await asyncio.wait({task})
        if pubsub is not None:
            try:
                await pubsub.close()
            except Exception as e:
                logging.error(f"Error closing Redis pubsub: {e}")

    async def _close_redis(self):
        """Closes the PubSub and the Redis connection; safe to call repeatedly."""
        await self._close_pubsub()
        self._redis_channels.clear()
        connection, self.redis_connection = self.redis_connection, None
        if connection is not None:
            await connection.close()
            logging.info("Redis connection closed.")

    async def broadcast(self, topic: str, message: str):
        """
        Sends ``message`` verbatim to every open client subscribed to ``topic``.

        Per-client send failures are logged and do not stop the broadcast.

        Args:
            topic: The topic to broadcast to.
            message: The message to broadcast.
        """
        # Snapshot recipients before any await: clients may connect, disconnect
        # or change subscriptions while we send, which would otherwise mutate
        # connected_clients mid-iteration.
        recipients = [ws for ws, topics in self.connected_clients.items() if topic in topics]
        for websocket in recipients:
            if not websocket.open:
                continue
            try:
                await websocket.send(message)
                logging.debug(f"Sent message to {websocket.remote_address} on topic {topic}")
            except websockets.exceptions.ConnectionClosed:
                logging.info(f"Connection closed while broadcasting to {websocket.remote_address}")
            except Exception as e:
                logging.error(f"Error broadcasting to {websocket.remote_address}: {e}")


async def main(host: str = "localhost", port: int = 8765):
    """
    Runs a WebSocketServer on ``host:port`` until it stops or is cancelled.

    Args:
        host: Interface to bind to (defaults to ``"localhost"``).
        port: TCP port to listen on (defaults to 8765).
    """
    server = WebSocketServer(host, port)
    try:
        await server.start()
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"Error running server: {e}")
    finally:
        # Release the socket and Redis connection even if start() raised or the
        # task was cancelled.
        await server.stop()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the server; exit without a traceback.
        logging.info("Server interrupted; shutting down.")
