import datetime
import logging
import os
import time
from typing import Dict, Optional

import redis
import psycopg2
import psycopg2.extras

# Logging: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# PostgreSQL settings. Each value can be overridden by the environment
# variable of the same name; all values (including the port) are strings.
# NOTE(review): a default password is committed in source — consider
# requiring DB_PASSWORD to be provided explicitly outside development.
DB_HOST = os.environ.get("DB_HOST", "postgresql-genesis-u50607.vm.elestio.app")
DB_PORT = os.environ.get("DB_PORT", "25432")
DB_NAME = os.environ.get("DB_NAME", "genesis")
DB_USER = os.environ.get("DB_USER", "genesis")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "genesis")

# Redis settings, overridable the same way. REDIS_PORT is kept as a string
# here and converted with int() where the Redis client is created.
REDIS_HOST = os.environ.get("REDIS_HOST", "redis-genesis-u50607.vm.elestio.app")
REDIS_PORT = os.environ.get("REDIS_PORT", "26379")

# Error-rate level above which an alert is raised (0.05 == 5%).
ERROR_RATE_THRESHOLD = 0.05

class ErrorTracker:
    """
    Tracks and analyzes error rates across systems.

    Errors are recorded in Redis under two keys per category:
    ``error:<category>:count`` (a running counter, never reset or expired by
    this class) and ``error:<category>:timestamps`` (a list of naive-UTC
    timestamps formatted as ``%Y-%m-%d %H:%M:%S``, pushed to the head so the
    newest entry comes first). A PostgreSQL connection is also opened; within
    this class it is only used by ``close()``.
    """

    # Simulated request volume per hour, used as the error-rate denominator
    # because no real request count is available. Override on the class or
    # an instance to model a different traffic level.
    SIMULATED_REQUESTS_PER_HOUR = 100

    # Format used both to write and to parse stored timestamps.
    _TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

    # TTL (seconds) applied to the timestamps list on every write: 24 hours.
    _TIMESTAMPS_TTL_SECONDS = 60 * 60 * 24

    def __init__(self) -> None:
        """
        Initializes the ErrorTracker with database and Redis connections.

        Raises:
            Any exception raised while connecting to PostgreSQL or Redis is
            propagated. If Redis fails, the already-open database connection
            is closed first so it is not leaked.
        """
        self.db_conn = self._connect_db()
        try:
            self.redis_client = self._connect_redis()
        except Exception:
            # The caller never receives an object to close(), so release the
            # database connection here before propagating the failure.
            self.db_conn.close()
            raise
        self.error_categories = ["api_errors", "timeout_errors", "validation_errors"]

    def _connect_db(self):
        """
        Establishes a connection to the PostgreSQL database.

        Returns:
            psycopg2.extensions.connection: The database connection object.

        Raises:
            psycopg2.Error: If the connection cannot be established (logged first).
        """
        try:
            conn = psycopg2.connect(
                host=DB_HOST,
                port=DB_PORT,
                database=DB_NAME,
                user=DB_USER,
                password=DB_PASSWORD
            )
            logging.info("Database connection established.")
            return conn
        except psycopg2.Error as e:
            logging.error(f"Error connecting to database: {e}")
            raise

    def _connect_redis(self):
        """
        Establishes a connection to the Redis server and verifies it with PING.

        Returns:
            redis.Redis: The Redis client object (responses decoded to str).

        Raises:
            redis.exceptions.RedisError: If the PING fails, including
                connection and timeout errors (logged first).
        """
        try:
            redis_client = redis.Redis(host=REDIS_HOST, port=int(REDIS_PORT), decode_responses=True)
            redis_client.ping()  # Fail fast if the server is unreachable.
            logging.info("Redis connection established.")
            return redis_client
        except redis.exceptions.RedisError as e:
            # RedisError also covers TimeoutError, which is not a
            # ConnectionError subclass and previously went unlogged.
            logging.error(f"Error connecting to Redis: {e}")
            raise

    @staticmethod
    def _utcnow() -> datetime.datetime:
        """
        Return the current UTC time as a naive datetime.

        Stored timestamps carry no UTC offset, so a naive value keeps
        comparisons against parsed timestamps valid. This replaces the
        deprecated ``datetime.datetime.utcnow()``.
        """
        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    def track_error(self, category: str) -> None:
        """
        Tracks an error by incrementing its counter and recording a timestamp in Redis.

        Unknown categories are logged and recorded as "api_errors". Redis
        failures are logged and swallowed (best-effort tracking).

        Args:
            category (str): The category of the error (e.g., "api_errors").
        """
        if category not in self.error_categories:
            logging.warning(f"Invalid error category: {category}. Using 'api_errors' instead.")
            category = "api_errors"  # default if invalid category

        timestamp = self._utcnow().strftime(self._TIMESTAMP_FORMAT)
        timestamps_key = f"error:{category}:timestamps"
        try:
            # A single MULTI/EXEC transaction keeps the counter and the
            # timestamp list in step and costs one round trip instead of three.
            with self.redis_client.pipeline(transaction=True) as pipe:
                pipe.incr(f"error:{category}:count")
                pipe.lpush(timestamps_key, timestamp)
                # NOTE(review): the TTL is refreshed on every write, so while
                # errors keep arriving, entries older than 24h are never evicted
                # and the list grows without bound. A sorted set scored by time
                # plus ZREMRANGEBYSCORE would bound it.
                pipe.expire(timestamps_key, self._TIMESTAMPS_TTL_SECONDS)
                pipe.execute()
            logging.info(f"Error tracked in category '{category}'.")
        except redis.exceptions.RedisError as e:
            logging.error(f"Error incrementing error count in Redis: {e}")

    def calculate_error_rate(self, category: str, time_window_hours: int) -> float:
        """
        Calculates the error rate for a given category and time window.

        The numerator is the number of stored timestamps within the last
        ``time_window_hours`` hours. The denominator is the simulated request
        volume ``SIMULATED_REQUESTS_PER_HOUR * time_window_hours``.

        Args:
            category (str): The category of the error.
            time_window_hours (int): The time window in hours.

        Returns:
            float: errors in window / simulated requests. The value is not
            capped at 1. Returns 0.0 when the window is not positive, no
            errors are recorded, or Redis fails.
        """
        if time_window_hours <= 0:
            # An empty or negative window has no requests to divide by.
            return 0.0

        try:
            error_count = int(self.redis_client.get(f"error:{category}:count") or 0)
            if error_count == 0:
                return 0.0
            timestamps = self.redis_client.lrange(f"error:{category}:timestamps", 0, -1)
        except redis.exceptions.RedisError as e:
            logging.error(f"Error calculating error rate: {e}")
            return 0.0

        if not timestamps:
            return 0.0

        window_start = self._utcnow() - datetime.timedelta(hours=time_window_hours)
        errors_in_window = 0
        for timestamp_str in timestamps:
            try:
                timestamp = datetime.datetime.strptime(timestamp_str, self._TIMESTAMP_FORMAT)
            except ValueError as e:
                logging.warning(f"Invalid timestamp format: {timestamp_str}. Skipping. Error: {e}")
                continue
            if timestamp >= window_start:
                errors_in_window += 1

        # In a real system, the total request count would come from a
        # monitoring system; here a constant request rate is simulated.
        total_requests = self.SIMULATED_REQUESTS_PER_HOUR * time_window_hours
        if total_requests <= 0:
            return 0.0

        error_rate = errors_in_window / total_requests
        logging.info(f"Error rate for '{category}' in the last {time_window_hours} hours: {error_rate:.4f}")
        return error_rate

    def calculate_error_trend(self, category: str) -> Optional[float]:
        """
        Calculates the relative change of the last-hour error rate versus the 24-hour rate.

        The 24-hour window includes the last hour. Both rates are normalized
        per simulated request, so the result is
        ``(rate_1h - rate_24h) / rate_24h``.

        Args:
            category (str): The category of the error.

        Returns:
            Optional[float]: The relative trend (positive means worsening).
            Returns None if the 24-hour rate is zero or an unexpected error occurs.
        """
        try:
            error_rate_1h = self.calculate_error_rate(category, 1)
            error_rate_24h = self.calculate_error_rate(category, 24)

            if error_rate_24h == 0:
                return None  # Avoid division by zero and indicate no trend

            error_trend = (error_rate_1h - error_rate_24h) / error_rate_24h
            logging.info(f"Error trend for '{category}': {error_trend:.4f}")
            return error_trend

        except Exception as e:
            logging.error(f"Error calculating error trend: {e}")
            return None

    def check_and_alert(self, category: str) -> None:
        """
        Triggers an alert if the last-hour error rate exceeds ERROR_RATE_THRESHOLD.

        Unexpected exceptions are logged and swallowed.

        Args:
            category (str): The category of the error.
        """
        try:
            error_rate_1h = self.calculate_error_rate(category, 1)

            if error_rate_1h > ERROR_RATE_THRESHOLD:
                self._trigger_alert(category, error_rate_1h)

        except Exception as e:
            logging.error(f"Error checking and alerting: {e}")

    def _trigger_alert(self, category: str, error_rate: float) -> None:
        """
        Triggers an alert. Currently this only logs a warning.

        Args:
            category (str): The category of the error.
            error_rate (float): The current error rate.
        """
        # In a real system, this would send a notification to an alerting system.
        logging.warning(f"ALERT: Error rate for '{category}' is above threshold ({error_rate:.4f} > {ERROR_RATE_THRESHOLD}).")

    def close(self) -> None:
        """
        Closes the database connection and the Redis client.
        """
        if self.db_conn:
            self.db_conn.close()
            logging.info("Database connection closed.")
        if self.redis_client is not None:
            # Releases the client's pooled connections.
            self.redis_client.close()

if __name__ == '__main__':
    # Example usage: record a few errors, then inspect rates, trend and alerting.
    error_tracker = ErrorTracker()

    try:
        error_tracker.track_error("api_errors")
        error_tracker.track_error("timeout_errors")
        error_tracker.track_error("api_errors")
        error_tracker.track_error("validation_errors")

        # No delay is needed: track_error writes its timestamps synchronously,
        # so they are readable immediately.

        error_rate_1h = error_tracker.calculate_error_rate("api_errors", 1)
        print(f"API Error Rate (1h): {error_rate_1h}")

        error_rate_24h = error_tracker.calculate_error_rate("api_errors", 24)
        print(f"API Error Rate (24h): {error_rate_24h}")

        error_trend = error_tracker.calculate_error_trend("api_errors")
        print(f"API Error Trend: {error_trend}")

        error_tracker.check_and_alert("api_errors")

    except Exception as e:
        # logging.exception records the traceback, which logging.error dropped.
        logging.exception(f"An error occurred during example usage: {e}")
    finally:
        error_tracker.close()
