import redis
import time
import logging
import json
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from psycopg2 import pool
import os
import schedule
import threading
import requests

# Configure root logging once at import time (timestamped INFO-level lines).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Connection settings, overridable via environment variables.
# NOTE(review): the fallbacks below embed concrete hostnames/ports and a
# default password — acceptable for local testing, but production should set
# the env vars rather than rely on these shipped defaults.
REDIS_HOST = os.environ.get('REDIS_HOST', 'redis-genesis-u50607.vm.elestio.app')
REDIS_PORT = int(os.environ.get('REDIS_PORT', '26379'))
POSTGRES_HOST = os.environ.get('POSTGRES_HOST', 'postgresql-genesis-u50607.vm.elestio.app')
POSTGRES_PORT = int(os.environ.get('POSTGRES_PORT', '25432'))
POSTGRES_DB = os.environ.get('POSTGRES_DB', 'postgres')
POSTGRES_USER = os.environ.get('POSTGRES_USER', 'postgres')
POSTGRES_PASSWORD = os.environ.get('POSTGRES_PASSWORD', 'postgres')
QUEUE_NAME = os.environ.get('QUEUE_NAME', 'task_queue')  # Redis list holding the tasks
API_ENDPOINT = os.environ.get('API_ENDPOINT', 'http://localhost:8000/api/queue/status') # Example, adjust as needed

class QueueMonitor:
    """
    Monitors a Redis task queue and reports health metrics.

    Tracks queue depth, estimates the processing rate from a rolling history
    of depth samples, reports an average wait time (placeholder), and flags
    tasks that have sat in the queue longer than a threshold.  Alert
    conditions are logged; collected metrics can be POSTed to an HTTP
    endpoint.
    """

    def __init__(self, redis_host: str, redis_port: int, pg_host: str, pg_port: int, pg_db: str, pg_user: str, pg_password: str, queue_name: str):
        """
        Initializes the QueueMonitor and opens its backing connections.

        Args:
            redis_host (str): Redis host address.
            redis_port (int): Redis port number.
            pg_host (str): PostgreSQL host address.
            pg_port (int): PostgreSQL port number.
            pg_db (str): PostgreSQL database name.
            pg_user (str): PostgreSQL user name.
            pg_password (str): PostgreSQL password.
            queue_name (str): The name of the Redis list used as the queue.

        Raises:
            Exception: Propagated when the PostgreSQL pool cannot be created.
        """
        self.redis_host = redis_host
        self.redis_port = redis_port
        # db=0 selects the default Redis logical database.
        self.redis_client = redis.Redis(host=self.redis_host, port=self.redis_port, db=0)
        self.pg_host = pg_host
        self.pg_port = pg_port
        self.pg_db = pg_db
        self.pg_user = pg_user
        self.pg_password = pg_password
        self.queue_name = queue_name
        self.db_pool = self._create_db_pool()
        # Alerting thresholds.
        self.queue_depth_threshold = 100
        self.stuck_task_age_threshold = timedelta(minutes=60)  # older than this => stuck
        # Rolling (timestamp, queue_depth) samples used to estimate throughput.
        self.processing_history: list[tuple[datetime, int]] = []
        self.processing_history_length = 60  # keep at most this many samples
        self.lock = threading.Lock()  # guards processing_history

    def _create_db_pool(self) -> pool.SimpleConnectionPool:
        """Creates a 1-10 connection pool to the PostgreSQL database."""
        try:
            db_pool = pool.SimpleConnectionPool(
                1,   # min connections
                10,  # max connections
                host=self.pg_host,
                port=self.pg_port,
                database=self.pg_db,
                user=self.pg_user,
                password=self.pg_password
            )
            logging.info("PostgreSQL connection pool created successfully.")
            return db_pool
        except Exception as e:
            logging.error(f"Error creating PostgreSQL connection pool: {e}")
            raise

    def _get_db_connection(self):
        """Borrows a connection from the pool; returns None on failure."""
        try:
            return self.db_pool.getconn()
        except Exception as e:
            logging.error(f"Error getting database connection: {e}")
            return None

    def _release_db_connection(self, conn):
        """Returns a borrowed connection to the pool."""
        try:
            self.db_pool.putconn(conn)
        except Exception as e:
            logging.error(f"Error releasing database connection: {e}")

    def get_queue_depth(self) -> int:
        """
        Retrieves the current depth of the Redis queue.

        Returns:
            int: The number of items in the queue, or -1 when Redis is
            unreachable (callers treat negative values as "no reading").
        """
        try:
            return self.redis_client.llen(self.queue_name)
        except redis.exceptions.ConnectionError as e:
            logging.error(f"Error connecting to Redis: {e}")
            return -1  # error sentinel

    def calculate_processing_rate(self) -> float:
        """
        Estimates tasks processed per second from the depth history.

        Samples the current depth, appends it to the rolling history, and
        computes (oldest_depth - newest_depth) / elapsed_seconds.  A positive
        value means the queue is draining; negative means it is growing.

        Returns:
            float: Estimated processing rate, or 0.0 with insufficient data.
        """
        with self.lock:  # protect processing_history from concurrent samplers
            depth = self.get_queue_depth()
            # Skip error readings (-1) so a Redis outage does not corrupt
            # the rate calculation with bogus depth values.
            if depth >= 0:
                self.processing_history.append((datetime.now(), depth))
                # Trim to the most recent samples only.
                self.processing_history = self.processing_history[-self.processing_history_length:]

            if len(self.processing_history) < 2:
                return 0.0  # not enough data to calculate a rate

            oldest_time, oldest_depth = self.processing_history[0]
            newest_time, newest_depth = self.processing_history[-1]

            time_diff = (newest_time - oldest_time).total_seconds()
            if time_diff <= 0:
                return 0.0
            # Depth decreasing over time => tasks were processed.
            return (oldest_depth - newest_depth) / time_diff

    def analyze_stuck_tasks(self) -> list[Dict[str, Any]]:
        """
        Identifies tasks enqueued longer ago than stuck_task_age_threshold.

        Tasks are expected to be JSON objects carrying an ISO-8601
        'enqueue_time' field; entries without one are skipped, and malformed
        entries are logged and skipped.

        Returns:
            list[Dict[str, Any]]: One dict per stuck task with its queue
            index ('task_index'), the raw enqueue time string
            ('enqueue_time'), and the decoded payload ('task_data').
        """
        stuck_tasks: list[Dict[str, Any]] = []
        try:
            # Fetch the whole queue in ONE round trip instead of issuing an
            # LINDEX call per item; this is also a consistent snapshot,
            # whereas per-index reads can race with concurrent pops.
            raw_tasks = self.redis_client.lrange(self.queue_name, 0, -1)
        except redis.exceptions.ConnectionError as e:
            logging.error(f"Error connecting to Redis: {e}")
            return stuck_tasks

        now = datetime.now()
        for i, raw in enumerate(raw_tasks):
            try:
                task_data = json.loads(raw.decode('utf-8'))  # tasks stored as JSON strings
                enqueue_time_str = task_data.get('enqueue_time')
                if not enqueue_time_str:
                    continue  # no timestamp recorded for this task
                enqueue_time = datetime.fromisoformat(enqueue_time_str)
                if now - enqueue_time > self.stuck_task_age_threshold:
                    stuck_tasks.append({
                        'task_index': i,
                        'enqueue_time': enqueue_time_str,
                        'task_data': task_data
                    })
            except Exception as e:
                # One bad entry must not abort the scan of the rest.
                logging.error(f"Error analyzing task at index {i}: {e}")

        return stuck_tasks

    def get_average_wait_time(self) -> float:
        """
        Calculates the average wait time for tasks in the queue.

        TODO: real implementation requires tracking enqueue and dequeue
        times per task (separate structure or database); currently a stub.

        Returns:
            float: Always 0.0 (placeholder value).
        """
        return 0.0

    def check_alerts(self, queue_depth: int, stuck_tasks: list[Dict[str, Any]]):
        """
        Logs warnings for alert conditions.

        Args:
            queue_depth (int): The current depth of the queue (negative
                values indicate an error reading and never trigger).
            stuck_tasks (list[Dict[str, Any]]): A list of stuck tasks.
        """
        if queue_depth > self.queue_depth_threshold:
            logging.warning(f"ALERT: Queue depth ({queue_depth}) exceeds threshold ({self.queue_depth_threshold})")

        if stuck_tasks:
            logging.warning(f"ALERT: Found {len(stuck_tasks)} stuck tasks: {stuck_tasks}")

    def collect_metrics(self) -> Dict[str, Any]:
        """
        Collects queue metrics, checks for alerts, and returns the metrics.

        Returns:
            Dict[str, Any]: 'queue_depth' (int, -1 on Redis error),
            'processing_rate' (float), 'avg_wait_time' (float), and
            'stuck_tasks' (list of dicts).
        """
        queue_depth = self.get_queue_depth()
        processing_rate = self.calculate_processing_rate()
        avg_wait_time = self.get_average_wait_time()
        stuck_tasks = self.analyze_stuck_tasks()

        self.check_alerts(queue_depth, stuck_tasks)

        return {
            'queue_depth': queue_depth,
            'processing_rate': processing_rate,
            'avg_wait_time': avg_wait_time,
            'stuck_tasks': stuck_tasks
        }

    def post_metrics(self, metrics: Dict[str, Any], api_endpoint: str):
        """
        Posts the collected metrics as JSON to the specified API endpoint.

        Args:
            metrics (Dict[str, Any]): The metrics to post (JSON-serializable).
            api_endpoint (str): The API endpoint to send the metrics to.
        """
        try:
            # timeout= prevents a dead endpoint from hanging the scheduler
            # thread forever; json= serializes and sets Content-Type for us.
            response = requests.post(api_endpoint, json=metrics, timeout=10)
            response.raise_for_status()  # 4xx/5xx -> RequestException below
            logging.info(f"Metrics posted successfully to {api_endpoint}. Status code: {response.status_code}")
        except requests.exceptions.RequestException as e:
            logging.error(f"Error posting metrics to {api_endpoint}: {e}")

    def run_monitor(self, api_endpoint: str):
        """
        Runs one monitoring cycle: collect metrics and post them.

        Args:
            api_endpoint (str): Destination URL for the metrics payload.
        """
        try:
            metrics = self.collect_metrics()
            self.post_metrics(metrics, api_endpoint)
        except Exception as e:
            # Never let one bad cycle kill the scheduler loop.
            logging.error(f"Error during monitor execution: {e}")

def main():
    """
    Initialize the QueueMonitor and run it on a fixed schedule until stopped.

    Runs collect-and-post every 10 seconds; exits cleanly on Ctrl+C.
    """
    try:
        queue_monitor = QueueMonitor(
            redis_host=REDIS_HOST,
            redis_port=REDIS_PORT,
            pg_host=POSTGRES_HOST,
            pg_port=POSTGRES_PORT,
            pg_db=POSTGRES_DB,
            pg_user=POSTGRES_USER,
            pg_password=POSTGRES_PASSWORD,
            queue_name=QUEUE_NAME
        )

        # Schedule the monitor to run every 10 seconds.
        schedule.every(10).seconds.do(queue_monitor.run_monitor, api_endpoint=API_ENDPOINT)

        logging.info("Queue monitor started.  Press Ctrl+C to exit.")

        while True:
            schedule.run_pending()
            time.sleep(1)

    except KeyboardInterrupt:
        # Ctrl+C is the documented way to stop; exit cleanly instead of
        # dumping a traceback (KeyboardInterrupt is not an Exception subclass,
        # so the handler below never caught it).
        logging.info("Queue monitor stopped by user.")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


# Run the monitor only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
