import logging
import os
import random
import re
import subprocess
import time
from typing import Any, Callable, Dict, List, Optional, Tuple

import psutil
import pytest
from psycopg2 import connect, OperationalError
import redis

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Genesis System Configuration (replace with actual values from environment or config file)
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "postgresql-genesis-u50607.vm.elestio.app")
POSTGRES_PORT = int(os.getenv("POSTGRES_PORT", "25432"))
POSTGRES_USER = os.getenv("POSTGRES_USER", "genesis")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "password")  # NEVER store passwords directly in code
POSTGRES_DB = os.getenv("POSTGRES_DB", "genesisdb")

REDIS_HOST = os.getenv("REDIS_HOST", "redis-genesis-u50607.vm.elestio.app")
REDIS_PORT = int(os.getenv("REDIS_PORT", "26379"))

QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant-b3knu-u50607.vm.elestio.app")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))

AIVA_OLLAMA_HOST = os.getenv("AIVA_OLLAMA_HOST", "localhost")
AIVA_OLLAMA_PORT = int(os.getenv("AIVA_OLLAMA_PORT", "23405"))


class ChaosTestFramework:
    """
    Framework for injecting chaos into the Genesis system and testing its resilience.
    """

    def __init__(self):
        """
        Initializes the ChaosTestFramework.
        """
        self.initial_state: Dict[str, any] = {}
        self.recovery_time: float = 0.0
        self.resilience_score: float = 0.0

    def record_initial_state(self) -> None:
        """
        Records the initial state of the system, including database counts, redis keys, etc.
        """
        logging.info("Recording initial system state...")
        self.initial_state["postgres_count"] = self._get_postgres_count("your_table")  # Replace with actual table
        self.initial_state["redis_keys"] = self._get_redis_keys()
        # Add more state capture as needed (Qdrant, AIVA Ollama)
        logging.info(f"Initial state recorded: {self.initial_state}")

    def _get_postgres_count(self, table_name: str) -> int:
        """
        Retrieves the count of rows in a specified PostgreSQL table.

        Args:
            table_name (str): The name of the table to query.

        Returns:
            int: The number of rows in the table.
        """
        try:
            with connect(
                host=POSTGRES_HOST,
                port=POSTGRES_PORT,
                user=POSTGRES_USER,
                password=POSTGRES_PASSWORD,
                dbname=POSTGRES_DB
            ) as conn:
                with conn.cursor() as cur:
                    cur.execute(f"SELECT COUNT(*) FROM {table_name};")
                    count = cur.fetchone()[0]
                    return count
        except OperationalError as e:
            logging.error(f"Error connecting to PostgreSQL: {e}")
            return -1  # Indicate an error

    def _get_redis_keys(self) -> int:
        """
        Retrieves the number of keys in Redis.

        Returns:
            int: The number of keys in Redis.
        """
        try:
            r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
            return len(r.keys("*"))
        except redis.exceptions.ConnectionError as e:
            logging.error(f"Error connecting to Redis: {e}")
            return -1  # Indicate an error

    def inject_chaos(self, chaos_type: str, process_name: str = None, delay_ms: int = 0, disk_fill_percentage: int = 0) -> None:
        """
        Injects chaos into the system based on the specified chaos type.

        Args:
            chaos_type (str): The type of chaos to inject (e.g., "kill_process", "network_delay", "disk_full").
            process_name (str, optional): The name of the process to kill (required for "kill_process"). Defaults to None.
            delay_ms (int, optional): The network delay in milliseconds (required for "network_delay"). Defaults to 0.
            disk_fill_percentage (int, optional): Percentage to fill the disk. Defaults to 0.
        """
        logging.info(f"Injecting chaos: {chaos_type}")
        if chaos_type == "kill_process":
            if not process_name:
                raise ValueError("Process name is required for kill_process chaos.")
            self._kill_process(process_name)
        elif chaos_type == "network_delay":
            self._add_network_delay(delay_ms)
        elif chaos_type == "disk_full":
            self._simulate_disk_full(disk_fill_percentage)
        else:
            raise ValueError(f"Unsupported chaos type: {chaos_type}")

    def _kill_process(self, process_name: str) -> None:
        """
        Kills a process by its name.

        Args:
            process_name (str): The name of the process to kill.
        """
        logging.info(f"Killing process: {process_name}")
        for proc in psutil.process_iter(['pid', 'name']):
            if process_name in proc.info['name']:
                try:
                    process = psutil.Process(proc.info['pid'])
                    process.kill()
                    logging.info(f"Process {process_name} (PID {proc.info['pid']}) killed.")
                    break  # Assuming only one process with that name
                except psutil.NoSuchProcess:
                    logging.warning(f"Process {process_name} not found.")
                except psutil.AccessDenied:
                    logging.error(f"Permission denied to kill process {process_name}.")

    def _add_network_delay(self, delay_ms: int) -> None:
        """
        Adds network delay using the `tc` command.  Requires sudo privileges and tc to be installed.

        Args:
            delay_ms (int): The delay in milliseconds.
        """
        logging.info(f"Adding network delay: {delay_ms}ms")
        try:
            # This command adds delay to the outgoing interface (eth0 - replace if needed)
            subprocess.run(
                ["sudo", "tc", "qdisc", "add", "dev", "eth0", "root", "netem", "delay", f"{delay_ms}ms"],
                check=True, capture_output=True, text=True
            )
            logging.info(f"Successfully added network delay of {delay_ms}ms.")
        except subprocess.CalledProcessError as e:
            logging.error(f"Error adding network delay: {e.stderr}")

    def _remove_network_delay(self) -> None:
        """
        Removes the network delay added by _add_network_delay.
        """
        logging.info("Removing network delay.")
        try:
            # This command removes the delay from the outgoing interface (eth0 - replace if needed)
            subprocess.run(
                ["sudo", "tc", "qdisc", "del", "dev", "eth0", "root"],
                check=True, capture_output=True, text=True
            )
            logging.info("Successfully removed network delay.")
        except subprocess.CalledProcessError as e:
            logging.error(f"Error removing network delay: {e.stderr}")

    def _simulate_disk_full(self, disk_fill_percentage: int) -> None:
        """
        Simulates a disk full condition by creating a large file.

        Args:
            disk_fill_percentage (int): The percentage of the disk to fill (0-100).
        """
        logging.info(f"Simulating disk full: {disk_fill_percentage}%")
        if not 0 <= disk_fill_percentage <= 100:
            raise ValueError("Disk fill percentage must be between 0 and 100.")

        # Determine the free space on the root partition
        statvfs = os.statvfs("/")
        free_bytes = statvfs.f_bavail * statvfs.f_frsize
        fill_bytes = int(free_bytes * (disk_fill_percentage / 100.0))

        # Create a large file to fill the disk
        fill_file = "/tmp/disk_fill_file"
        try:
            with open(fill_file, "wb") as f:
                f.seek(fill_bytes - 1)
                f.write(b"\0")  # Write a single byte to allocate the space
            logging.info(f"Created file {fill_file} to simulate disk full.")
        except OSError as e:
            logging.error(f"Error creating file to simulate disk full: {e}")

    def _remove_disk_full_simulation(self) -> None:
      """
      Removes the disk full simulation file.
      """
      fill_file = "/tmp/disk_fill_file"
      if os.path.exists(fill_file):
        try:
          os.remove(fill_file)
          logging.info(f"Removed disk full simulation file {fill_file}")
        except OSError as e:
          logging.error(f"Error removing disk full simulation file: {e}")


    def monitor_system(self, monitoring_duration: int = 10) -> None:
        """
        Monitors the system for recovery after chaos injection.

        Args:
            monitoring_duration (int): The duration in seconds to monitor the system.
        """
        logging.info(f"Monitoring system for {monitoring_duration} seconds...")
        start_time = time.time()
        while time.time() - start_time < monitoring_duration:
            if self.is_system_recovered():
                self.recovery_time = time.time() - start_time
                logging.info(f"System recovered in {self.recovery_time:.2f} seconds.")
                return
            time.sleep(1)
        logging.warning("System did not recover within the monitoring duration.")
        self.recovery_time = float('inf')  # Indicate no recovery

    def is_system_recovered(self) -> bool:
        """
        Checks if the system has recovered after chaos injection.

        Returns:
            bool: True if the system has recovered, False otherwise.
        """
        try:
            current_postgres_count = self._get_postgres_count("your_table")  # Replace with actual table
            current_redis_keys = self._get_redis_keys()

            # Define recovery conditions.  These should be tailored to your specific application.
            postgres_recovered = current_postgres_count == self.initial_state["postgres_count"]
            redis_recovered = current_redis_keys == self.initial_state["redis_keys"]

            if postgres_recovered and redis_recovered:
                logging.info("System recovery detected.")
                return True
            else:
                logging.debug(f"System not yet recovered. Postgres: {postgres_recovered}, Redis: {redis_recovered}")
                return False

        except Exception as e:
            logging.error(f"Error checking system recovery: {e}")
            return False

    def verify_data_integrity(self) -> bool:
        """
        Verifies data integrity after recovery.

        Returns:
            bool: True if data integrity is maintained, False otherwise.
        """
        # Implement data integrity checks based on your application's data model.
        # This might involve comparing data before and after the chaos test.
        logging.info("Performing data integrity checks...")
        try:
            # Example: Compare database counts before and after
            final_postgres_count = self._get_postgres_count("your_table")  # Replace with actual table
            if final_postgres_count != self.initial_state["postgres_count"]:
                logging.error("Data integrity check failed: PostgreSQL count mismatch.")
                return False

            # Example: Verify specific Redis keys
            final_redis_keys = self._get_redis_keys()
            if final_redis_keys != self.initial_state["redis_keys"]:
                logging.error("Data integrity check failed: Redis key count mismatch.")
                return False
            logging.info("Data integrity checks passed.")
            return True
        except Exception as e:
            logging.error(f"Error during data integrity check: {e}")
            return False

    def calculate_resilience_score(self) -> float:
        """
        Calculates a resilience score based on recovery time and data integrity.

        Returns:
            float: The resilience score (higher is better).
        """
        logging.info("Calculating resilience score...")
        # Define weights for recovery time and data integrity
        recovery_time_weight = 0.7
        data_integrity_weight = 0.3

        # Normalize recovery time (lower is better)
        if self.recovery_time == float('inf'):
            normalized_recovery_time = 0.0  # Penalize for no recovery
        else:
            # Assuming a maximum acceptable recovery time of 60 seconds
            normalized_recovery_time = max(0.0, 1.0 - (self.recovery_time / 60.0))

        # Data integrity score (1.0 if passed, 0.0 if failed)
        data_integrity_score = 1.0 if self.verify_data_integrity() else 0.0

        # Calculate the resilience score
        self.resilience_score = (recovery_time_weight * normalized_recovery_time) + \
                                 (data_integrity_weight * data_integrity_score)

        logging.info(f"Resilience score: {self.resilience_score:.2f}")
        return self.resilience_score

    def generate_report(self) -> Dict[str, any]:
        """
        Generates a report summarizing the chaos test results.

        Returns:
            Dict[str, any]: A dictionary containing the test results.
        """
        report = {
            "initial_state": self.initial_state,
            "recovery_time": self.recovery_time,
            "data_integrity": self.verify_data_integrity(),
            "resilience_score": self.resilience_score
        }
        logging.info(f"Chaos test report: {report}")
        return report

    def run_chaos_test(self, chaos_type: str, process_name: str = None, delay_ms: int = 0, disk_fill_percentage: int = 0, monitoring_duration: int = 10) -> Dict[str, any]:
        """
        Runs a complete chaos test.

        Args:
            chaos_type (str): The type of chaos to inject.
            process_name (str, optional): The name of the process to kill. Defaults to None.
            delay_ms (int, optional): The network delay in milliseconds. Defaults to 0.
            disk_fill_percentage (int, optional): The percentage of the disk to fill. Defaults to 0.
            monitoring_duration (int, optional): The duration to monitor for recovery. Defaults to 10.

        Returns:
            Dict[str, any]: A dictionary containing the test report.
        """
        try:
            self.record_initial_state()
            self.inject_chaos(chaos_type, process_name, delay_ms, disk_fill_percentage)
            self.monitor_system(monitoring_duration)
            report = self.generate_report()
            return report
        finally:
            # Ensure that network delay and disk full simulations are removed
            # even if the test fails.
            self._remove_network_delay()
            self._remove_disk_full_simulation()

if __name__ == '__main__':
    # Example usage
    framework = ChaosTestFramework()
    try:
      report = framework.run_chaos_test(chaos_type="kill_process", process_name="ollama", monitoring_duration=20)
      print(report)

      report = framework.run_chaos_test(chaos_type="network_delay", delay_ms=200, monitoring_duration=20)
      print(report)

      report = framework.run_chaos_test(chaos_type="disk_full", disk_fill_percentage=90, monitoring_duration=20)
      print(report)
    except Exception as e:
      logging.error(f"Test failed: {e}")
