import datetime
import logging
import os
import time
from typing import Dict, List, Tuple

import psycopg2
import psycopg2.extras
import numpy as np

# Configure logging
# Root-logger setup runs once at import time; every message from this module
# is emitted with a timestamp and severity level.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class DatabaseConnection:
    """
    Manages the connection to the PostgreSQL database.

    Typical usage::

        db = DatabaseConnection(host, port, database, user, password)
        db.connect()
        rows = db.execute_query("SELECT ...;", (param,))
        db.disconnect()
    """

    def __init__(self, host, port, database, user, password):
        """
        Initializes the database connection parameters.

        No connection is opened here; call connect() first.

        Args:
            host: Database server hostname.
            port: Database server port.
            database: Database name.
            user: Database user name.
            password: Password for the database user.
        """
        self.host = host
        self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.conn = None  # Live psycopg2 connection; None while disconnected.

    def connect(self):
        """
        Establishes a connection to the PostgreSQL database.

        Raises:
            psycopg2.Error: If the connection cannot be established
                (logged before re-raising).
        """
        try:
            self.conn = psycopg2.connect(
                host=self.host,
                port=self.port,
                database=self.database,
                user=self.user,
                password=self.password,
            )
            logging.info("Database connection established.")
        except psycopg2.Error as e:
            logging.error(f"Failed to connect to database: {e}")
            raise

    def disconnect(self):
        """
        Closes the connection to the PostgreSQL database, if one is open.
        """
        if self.conn:
            self.conn.close()
            # Drop the reference so a stale, closed connection can never be
            # reused by execute_query(); reconnecting requires connect().
            self.conn = None
            logging.info("Database connection closed.")

    def execute_query(self, query: str, params: Tuple = None) -> List[Dict]:
        """
        Executes a SQL query and returns the results as a list of dictionaries.

        Statements that return rows (SELECT) yield those rows; statements with
        no result set (INSERT/UPDATE/DELETE) are committed and yield [].

        Args:
            query: The SQL query to execute (use %s placeholders).
            params: Optional parameters bound to the placeholders.

        Returns:
            A list of dictionaries (one per row, via RealDictCursor), or an
            empty list for statements that produce no result set.

        Raises:
            RuntimeError: If connect() has not been called, or the connection
                was closed via disconnect().
            psycopg2.Error: If the statement fails; the transaction is rolled
                back and the error logged before re-raising.
        """
        # Fail fast with a clear error instead of an AttributeError on None
        # (the original crashed with "'NoneType' object has no attribute
        # 'cursor'" when used before connect()).
        if self.conn is None:
            raise RuntimeError(
                "Database connection is not established; call connect() first."
            )
        try:
            with self.conn.cursor(
                cursor_factory=psycopg2.extras.RealDictCursor
            ) as cur:
                cur.execute(query, params)
                if cur.description:  # The statement produced a result set.
                    return cur.fetchall()
                # No result set: persist changes made by DML statements.
                self.conn.commit()
                return []
        except psycopg2.Error as e:
            logging.error(f"Error executing query: {e}")
            self.conn.rollback()
            raise


class BaselineGenerator:
    """
    Generates performance baselines for various components.

    Baselines are the p50/p95/p99 percentiles of latency, throughput, and
    error rate over the trailing 24 hours of rows in the `metrics` table.
    They are persisted to the `baselines` table and used for simple
    threshold-based anomaly detection.
    """

    # Metric columns read from the `metrics` table, and the percentiles
    # computed for each of them.
    _METRICS = ("latency", "throughput", "error_rate")
    _PERCENTILES = (50, 95, 99)

    # Baseline keys in the exact order the SQL placeholders below expect.
    _BASELINE_KEYS = (
        "latency_p50", "latency_p95", "latency_p99",
        "throughput_p50", "throughput_p95", "throughput_p99",
        "error_rate_p50", "error_rate_p95", "error_rate_p99",
    )

    def __init__(self, db_connection: "DatabaseConnection"):
        """
        Initializes the BaselineGenerator with a database connection.

        Args:
            db_connection: Object exposing ``execute_query(query, params)``
                (normally a connected DatabaseConnection).
        """
        self.db_connection = db_connection

    def calculate_baselines(self, component: str) -> Dict[str, float]:
        """
        Calculates performance baselines (p50, p95, p99) for latency, throughput, and error rate
        for a given component based on the last 24 hours of metrics.

        Args:
            component: The name of the component to generate baselines for.

        Returns:
            A dictionary containing the calculated baselines for latency, throughput, and error rate.
            Returns an empty dictionary if no data is available or an error occurs.
        """
        # datetime.utcnow() is deprecated since Python 3.12; derive the same
        # naive-UTC value from an aware "now" so comparisons against the
        # (presumed naive) `timestamp` column behave as before.
        end_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        start_time = end_time - datetime.timedelta(days=1)

        try:
            # Fetch metrics data for the last 24 hours
            query = """
                SELECT latency, throughput, error_rate
                FROM metrics
                WHERE component = %s
                AND timestamp >= %s
                AND timestamp <= %s;
            """
            params = (component, start_time, end_time)
            metrics_data = self.db_connection.execute_query(query, params)

            if not metrics_data:
                logging.warning(
                    f"No metrics data found for component '{component}' in the last 24 hours."
                )
                return {}

            baselines: Dict[str, float] = {}
            for metric in self._METRICS:
                # NULL columns are skipped; a metric with no usable values at
                # all makes the whole baseline unusable, so bail out.
                values = [
                    float(row[metric])
                    for row in metrics_data
                    if row[metric] is not None
                ]
                if not values:
                    logging.warning(
                        f"Insufficient metrics data (latency, throughput, or error_rate) found for component '{component}'."
                    )
                    return {}
                for pct in self._PERCENTILES:
                    baselines[f"{metric}_p{pct}"] = float(
                        np.percentile(values, pct)
                    )

            return baselines

        except Exception as e:
            logging.error(f"Error calculating baselines for component '{component}': {e}")
            return {}

    def save_baselines(self, component: str, baselines: Dict[str, float]) -> None:
        """
        Saves the calculated baselines to the PostgreSQL database.

        Uses check-then-write (SELECT, then UPDATE or INSERT). NOTE(review):
        this is not race-safe under concurrent writers; an
        ``INSERT ... ON CONFLICT`` upsert would be, if `component` is unique
        in the `baselines` table — confirm the schema before changing.

        Args:
            component: The name of the component the baselines belong to.
            baselines: A dictionary containing every key in _BASELINE_KEYS.
        """
        try:
            # One ordered tuple serves both statements below.
            values = tuple(baselines[key] for key in self._BASELINE_KEYS)

            # Check if a baseline already exists for the component
            check_query = "SELECT id FROM baselines WHERE component = %s;"
            existing_baseline = self.db_connection.execute_query(check_query, (component,))

            if existing_baseline:
                # Update the existing baseline
                update_query = """
                    UPDATE baselines
                    SET latency_p50 = %s, latency_p95 = %s, latency_p99 = %s,
                        throughput_p50 = %s, throughput_p95 = %s, throughput_p99 = %s,
                        error_rate_p50 = %s, error_rate_p95 = %s, error_rate_p99 = %s,
                        updated_at = NOW()
                    WHERE component = %s;
                """
                self.db_connection.execute_query(update_query, values + (component,))
                logging.info(f"Updated baselines for component '{component}' in the database.")
            else:
                # Insert a new baseline
                insert_query = """
                    INSERT INTO baselines (component, latency_p50, latency_p95, latency_p99,
                                            throughput_p50, throughput_p95, throughput_p99,
                                            error_rate_p50, error_rate_p95, error_rate_p99)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
                """
                self.db_connection.execute_query(insert_query, (component,) + values)
                logging.info(f"Saved baselines for component '{component}' to the database.")

        except Exception as e:
            logging.error(f"Error saving baselines for component '{component}': {e}")

    def detect_anomalies(self, component: str, current_metrics: Dict[str, float]) -> Dict[str, bool]:
        """
        Detects anomalies by comparing current metrics against the stored baselines.

        Args:
            component: The name of the component to check for anomalies.
            current_metrics: A dictionary containing the current metrics (latency, throughput, error_rate).

        Returns:
            A dictionary indicating whether each metric is anomalous (True) or not (False).
            Empty when no baseline exists or an error occurs.
        """
        try:
            # Fetch the baselines from the database
            query = "SELECT * FROM baselines WHERE component = %s;"
            baseline_data = self.db_connection.execute_query(query, (component,))

            if not baseline_data:
                logging.warning(f"No baselines found for component '{component}'.")
                return {}

            baselines = baseline_data[0]

            anomalies = {
                # High latency / error rate beyond the p99 baseline is anomalous.
                "latency": current_metrics["latency"] > baselines["latency_p99"],
                # BUG FIX: the original compared throughput against
                # throughput_p99 — the *high* end of normal throughput — which
                # flags ~99% of normal samples as anomalous. Only upper
                # percentiles are stored, so the median is the best available
                # low-side threshold (a dedicated low percentile such as p1
                # would be better if the schema ever grows one).
                "throughput": current_metrics["throughput"] < baselines["throughput_p50"],
                "error_rate": current_metrics["error_rate"] > baselines["error_rate_p99"],
            }

            return anomalies

        except Exception as e:
            logging.error(f"Error detecting anomalies for component '{component}': {e}")
            return {}


def main():
    """
    Main function to run the baseline generator.

    Connection settings are read from the environment:
    DB_HOST (required), DB_PASSWORD (required), DB_PORT (default "25432"),
    DB_NAME (default "genesis"), DB_USER (default "genesis").

    Raises:
        SystemExit: If DB_HOST or DB_PASSWORD is not set.
    """
    # SECURITY FIX: the previous defaults embedded a real production hostname
    # and a placeholder password directly in the source. Require both via the
    # environment and fail fast with a clear message instead.
    db_host = os.environ.get("DB_HOST")
    db_password = os.environ.get("DB_PASSWORD")
    if not db_host or not db_password:
        raise SystemExit("DB_HOST and DB_PASSWORD environment variables must be set.")

    db_port = int(os.environ.get("DB_PORT", "25432"))
    db_name = os.environ.get("DB_NAME", "genesis")
    db_user = os.environ.get("DB_USER", "genesis")

    db_connection = DatabaseConnection(db_host, db_port, db_name, db_user, db_password)

    try:
        db_connection.connect()
        baseline_generator = BaselineGenerator(db_connection)

        # TODO: replace with the real component inventory (e.g. from config).
        components = ["component_a", "component_b", "component_c"]

        for component in components:
            logging.info(f"Calculating baselines for component: {component}")
            baselines = baseline_generator.calculate_baselines(component)

            if baselines:
                logging.info(f"Baselines for {component}: {baselines}")
                baseline_generator.save_baselines(component, baselines)

                # Simulated current metrics for anomaly detection; replace
                # with live measurements in production.
                current_metrics = {
                    "latency": np.random.rand() * 100,  # Example latency
                    "throughput": np.random.rand() * 1000,  # Example throughput
                    "error_rate": np.random.rand() * 0.1,  # Example error rate
                }
                anomalies = baseline_generator.detect_anomalies(component, current_metrics)
                logging.info(f"Anomalies detected for {component}: {anomalies}")
            else:
                logging.warning(f"Could not calculate baselines for {component}.")

    except Exception as e:
        logging.error(f"An error occurred: {e}")

    finally:
        db_connection.disconnect()


if __name__ == "__main__":
    main()
