import json
import logging
import os
import time
from typing import Dict, Any, Optional
from urllib.parse import unquote, urlparse

import psycopg2
import psycopg2.extras

# Root logger setup: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class PostgresHealthMonitor:
    """
    Monitors the health and performance of a PostgreSQL database.

    Tracks active connections, query latency, table sizes, and index usage.
    Automatically reconnects on connection loss and exposes health data.

    The connection is opened in autocommit mode so that every metric query
    runs in its own transaction. This has three effects:

    - A failing query (e.g. pg_stat_statements not installed) does not abort
      the queries that follow it.
    - Statistics views are re-read on every poll.
    - The session is not left idle inside an open transaction between polls.
    """

    def __init__(self, db_url: str):
        """
        Initializes the PostgresHealthMonitor.

        Args:
            db_url: The PostgreSQL database URL.  Example: postgresql://user:password@host:port/database
                Percent-encoded characters in the user name, password and
                database name (e.g. ``%40`` for ``@``) are decoded.
        """
        self.db_url = db_url
        self.connection = None
        self.cursor = None
        self.reconnect_interval = 5  # seconds to wait before retrying after a failed reconnect
        self.poll_interval = 30  # seconds between health checks while connected
        self.is_running = False  # Flag to control the monitoring loop

        # Parse the database URL. urlparse keeps percent-escapes, so decode them;
        # otherwise a password such as "p%40ss" would be sent verbatim.
        parsed_url = urlparse(self.db_url)
        self.db_host = parsed_url.hostname
        self.db_port = parsed_url.port or 5432  # Default PostgreSQL port
        # Drop the leading slash; an empty path leaves the database name unset (None).
        self.db_name = unquote(parsed_url.path[1:]) or None
        self.db_user = unquote(parsed_url.username) if parsed_url.username else None
        self.db_password = unquote(parsed_url.password) if parsed_url.password else None

        logging.info(f"PostgresHealthMonitor initialized for {self.db_host}:{self.db_port}/{self.db_name}")

    def connect(self) -> None:
        """
        Establishes a connection to the PostgreSQL database.

        On success, ``self.connection`` is an autocommit connection and
        ``self.cursor`` is a RealDictCursor on it. On failure the error is
        logged and both attributes are reset to None.
        """
        try:
            self.connection = psycopg2.connect(
                host=self.db_host,
                port=self.db_port,
                database=self.db_name,
                user=self.db_user,
                password=self.db_password
            )
            # One transaction per statement; see the class docstring for why.
            self.connection.autocommit = True
            self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            logging.info("Successfully connected to PostgreSQL.")
        except psycopg2.Error as e:
            logging.error(f"Failed to connect to PostgreSQL: {e}")
            # Close anything half-opened (e.g. connection created but cursor
            # creation failed) instead of just dropping the reference.
            self.disconnect()

    def disconnect(self) -> None:
        """
        Closes the cursor and connection, if any.

        Close errors are logged, and both attributes are always reset to None.
        """
        if self.cursor:
            try:
                self.cursor.close()
                logging.info("Cursor closed.")
            except Exception as e:
                logging.error(f"Error closing cursor: {e}")
            finally:
                self.cursor = None

        if self.connection:
            try:
                self.connection.close()
                logging.info("Connection closed.")
            except Exception as e:
                logging.error(f"Error closing connection: {e}")
            finally:
                self.connection = None

    def check_connection(self) -> bool:
        """
        Checks if the connection to the PostgreSQL database is active.

        Returns:
            True if a round-trip ``SELECT 1`` succeeds, False otherwise.
        """
        # psycopg2 sets `closed` to non-zero once the connection is closed or broken.
        if self.connection is None or self.connection.closed:
            return False
        try:
            with self.connection.cursor() as cur:
                cur.execute("SELECT 1")
                cur.fetchone()
            return True
        except psycopg2.Error:
            return False

    def reconnect(self) -> None:
        """
        Attempts to reconnect: closes the current connection, then connects again.
        """
        logging.info("Attempting to reconnect to PostgreSQL...")
        self.disconnect()
        self.connect()

    def get_active_connections(self) -> Optional[int]:
        """
        Retrieves the number of sessions attached to the connected database.

        Counts every pg_stat_activity row for the current database, whatever
        its state (active, idle, ...).

        Returns:
            The session count, or None if an error occurs.
        """
        try:
            # current_database() is correct even when the URL omitted a database
            # name (self.db_name is None in that case).
            self.cursor.execute(
                "SELECT count(*) FROM pg_stat_activity WHERE datname = current_database();"
            )
            result = self.cursor.fetchone()
            return result['count'] if result else 0
        except psycopg2.Error as e:
            logging.error(f"Failed to retrieve active connections: {e}")
            return None

    def get_query_latency(self) -> Optional[float]:
        """
        Retrieves the average statement execution time in milliseconds.

        The average is taken over all statements in pg_stat_statements, so the
        extension must be installed.

        Returns:
            The average latency in ms (0.0 when no statements are recorded),
            or None if an error occurs (e.g. the extension is missing).
        """
        try:
            # pg_stat_statements counts executions in `calls` and already reports
            # times in milliseconds. The time column is `total_exec_time` from
            # PostgreSQL 13 on and `total_time` before. NOTE(review): this keys
            # off the server version; an extension not updated after a major
            # upgrade may still expose the old column name.
            time_column = (
                'total_exec_time' if self.connection.server_version >= 130000 else 'total_time'
            )
            # time_column is one of two fixed literals, so the f-string is safe.
            self.cursor.execute(f"""
                SELECT
                    CASE
                        WHEN sum(calls) > 0 THEN sum({time_column}) / sum(calls)
                        ELSE 0
                    END AS avg_query_latency_ms
                FROM pg_stat_statements;
            """)
            result = self.cursor.fetchone()
            if result and result['avg_query_latency_ms'] is not None:
                return float(result['avg_query_latency_ms'])
            return 0.0  # No data recorded yet
        except psycopg2.Error as e:
            logging.error(f"Failed to retrieve query latency: {e}")
            return None

    def get_table_sizes(self) -> Optional[Dict[str, str]]:
        """
        Retrieves the total on-disk size of every base table in the ``public`` schema.

        Returns:
            A mapping of table name to human-readable size (pg_size_pretty of
            pg_total_relation_size), or None if an error occurs.
        """
        try:
            # Build a schema-qualified, properly quoted name before casting to
            # regclass. A bare table_name fails for mixed-case or special-character
            # names, and when 'public' is not on the search_path.
            self.cursor.execute("""
                SELECT
                    table_name,
                    pg_size_pretty(
                        pg_total_relation_size(
                            (quote_ident(table_schema) || '.' || quote_ident(table_name))::regclass
                        )
                    ) AS table_size
                FROM information_schema.tables
                WHERE table_schema = 'public'
                AND table_type = 'BASE TABLE';
            """)
            return {row['table_name']: row['table_size'] for row in self.cursor.fetchall()}
        except psycopg2.Error as e:
            logging.error(f"Failed to retrieve table sizes: {e}")
            return None

    def get_index_usage(self) -> Optional[Dict[str, float]]:
        """
        Retrieves a usage figure for each index in the ``public`` schema.

        The figure is ``idx_scan * 100 / reltuples``: the number of scans of the
        index relative to its table's estimated row count (pg_class.reltuples of
        the indexed table). Despite the "percent" label it is not bounded to
        0-100. Tables with no row estimate (reltuples <= 0) are skipped.

        Returns:
            A mapping of index name to that figure, or None if an error occurs.
        """
        try:
            self.cursor.execute("""
                SELECT
                    s.indexrelname,
                    s.idx_scan * 100 / c.reltuples AS index_usage_percent
                FROM pg_stat_all_indexes AS s
                JOIN pg_class AS c ON c.oid = s.relid
                WHERE s.schemaname = 'public'
                AND c.reltuples > 0;
            """)
            return {
                row['indexrelname']: float(row['index_usage_percent'])
                for row in self.cursor.fetchall()
            }
        except psycopg2.Error as e:
            logging.error(f"Failed to retrieve index usage: {e}")
            return None

    def get_health_status(self) -> Dict[str, Any]:
        """
        Retrieves the overall health status of the PostgreSQL database.

        Returns:
            A dict with 'connection_status' (bool) and the metrics
            'active_connections', 'query_latency', 'table_sizes' and
            'index_usage'. All metrics are None when the connection is down;
            an individual metric is None when its query failed.
        """
        status: Dict[str, Any] = {}
        status['connection_status'] = self.check_connection()

        if status['connection_status']:
            status['active_connections'] = self.get_active_connections()
            status['query_latency'] = self.get_query_latency()
            status['table_sizes'] = self.get_table_sizes()
            status['index_usage'] = self.get_index_usage()
        else:
            status['active_connections'] = None
            status['query_latency'] = None
            status['table_sizes'] = None
            status['index_usage'] = None

        return status

    def start_monitoring(self) -> None:
        """
        Runs the monitoring loop until ``stop_monitoring`` clears ``is_running``.

        Each iteration reconnects if needed, then does one of two things:

        - Connected: logs the health data and sleeps ``poll_interval`` seconds.
        - Still disconnected: logs a warning and sleeps ``reconnect_interval``
          seconds before trying again.
        """
        self.is_running = True
        self.connect()
        while self.is_running:
            if not self.check_connection():
                self.reconnect()

            # connect() leaves self.connection as None on failure. That avoids a
            # second SELECT 1 here; get_health_status checks the link itself.
            if self.connection is not None:
                health_data = self.get_health_status()
                logging.info(f"PostgreSQL Health Data: {json.dumps(health_data, indent=2)}")
                delay = self.poll_interval
            else:
                logging.warning("Connection lost.  Waiting for reconnect...")
                delay = self.reconnect_interval

            time.sleep(delay)

    def stop_monitoring(self) -> None:
        """
        Stops the monitoring loop and closes the database connection.
        """
        self.is_running = False
        self.disconnect()
        logging.info("PostgreSQL monitoring stopped.")

if __name__ == '__main__':
    # Example usage. Pass the connection URL through the DATABASE_URL environment
    # variable instead of hard-coding credentials in source; the fallback is a
    # local placeholder only.
    db_url = os.environ.get("DATABASE_URL", "postgresql://postgres:password@localhost:5432/postgres")
    monitor = PostgresHealthMonitor(db_url)
    try:
        monitor.start_monitoring()
    except KeyboardInterrupt:
        print("Monitoring interrupted.")
    finally:
        # Always close the connection, whether the loop ended by Ctrl-C or an error.
        monitor.stop_monitoring()
