import os
import logging
import datetime
import json
from typing import Dict, Any

import redis
import psycopg2
from psycopg2 import extras

# Root logger: INFO and above, each line prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Connection settings are read from the environment. The fallbacks below point
# at specific hosted instances, so override them outside of local development.
# Ports stay as strings (their raw environment form); callers convert as needed.
DATABASE_HOST = os.getenv("DATABASE_HOST", "postgresql-genesis-u50607.vm.elestio.app")
DATABASE_PORT = os.getenv("DATABASE_PORT", "25432")
DATABASE_NAME = os.getenv("DATABASE_NAME", "postgres")
DATABASE_USER = os.getenv("DATABASE_USER", "postgres")
# Deliberately no fallback: the password must come from the environment (None if unset).
DATABASE_PASSWORD = os.getenv("DATABASE_PASSWORD")
REDIS_HOST = os.getenv("REDIS_HOST", "redis-genesis-u50607.vm.elestio.app")
REDIS_PORT = os.getenv("REDIS_PORT", "26379")

# Daily spend above this amount triggers an alert unless a caller supplies
# its own threshold.
DEFAULT_DAILY_THRESHOLD = 1000.0

class CostAggregator:
    """
    Aggregate recorded costs and serve daily, weekly and monthly spend totals.

    Each recorded cost is inserted into the PostgreSQL ``costs`` table. The
    daily, weekly (Monday-based) and monthly totals are then recomputed from
    the database and written to Redis, where ``get_costs`` reads them. A
    warning-level alert is logged whenever the daily total exceeds the
    configured threshold. All day/week/month boundaries are computed in UTC,
    matching the UTC default of the ``timestamp`` column.
    """

    def __init__(self, daily_threshold: float = DEFAULT_DAILY_THRESHOLD):
        """
        Connect to Redis and PostgreSQL and make sure the ``costs`` schema exists.

        Args:
            daily_threshold: Daily spend above which an alert is triggered.

        Raises:
            redis.exceptions.RedisError: If Redis cannot be reached.
            psycopg2.Error: If PostgreSQL cannot be reached or the schema
                cannot be created.
        """
        self.daily_threshold = daily_threshold
        self.redis_client = self._connect_redis()
        try:
            self.db_conn = self._connect_db()
            self.db_cursor = self.db_conn.cursor(cursor_factory=extras.RealDictCursor)
            self._create_table_if_not_exists()
        except Exception:
            # Construction failed part-way. The caller never receives an object
            # to close(), so release whatever was opened so far before re-raising.
            self.close()
            raise

    def _connect_redis(self) -> redis.Redis:
        """
        Create a Redis client for database 0 and verify the server answers PING.

        Returns:
            A Redis client instance.

        Raises:
            redis.exceptions.RedisError: If the server is unreachable or the
                ping fails (connection errors, timeouts, authentication errors).
        """
        try:
            redis_client = redis.Redis(host=REDIS_HOST, port=int(REDIS_PORT), db=0)
            # Constructing the client does not contact the server; PING forces a round trip.
            redis_client.ping()
        except redis.exceptions.RedisError as e:
            logging.error(f"Failed to connect to Redis: {e}")
            raise
        logging.info("Connected to Redis successfully.")
        return redis_client

    def _connect_db(self) -> psycopg2.extensions.connection:
        """
        Open a connection to the configured PostgreSQL database.

        Returns:
            A PostgreSQL connection instance.

        Raises:
            psycopg2.Error: If the connection cannot be established.
        """
        try:
            conn = psycopg2.connect(
                host=DATABASE_HOST,
                port=DATABASE_PORT,
                database=DATABASE_NAME,
                user=DATABASE_USER,
                password=DATABASE_PASSWORD,
            )
        except psycopg2.Error as e:
            logging.error(f"Failed to connect to PostgreSQL: {e}")
            raise
        logging.info("Connected to PostgreSQL successfully.")
        return conn

    def _create_table_if_not_exists(self) -> None:
        """
        Create the ``costs`` table and its timestamp index if they do not exist.

        Raises:
            psycopg2.Error: If the DDL fails. The transaction is rolled back first.
        """
        try:
            self.db_cursor.execute("""
                CREATE TABLE IF NOT EXISTS costs (
                    id SERIAL PRIMARY KEY,
                    timestamp TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() at time zone 'utc'),
                    amount DECIMAL NOT NULL
                )
            """)
            # Every insert triggers an aggregation filtered on timestamp.
            # Without an index that query must scan the whole table each time.
            self.db_cursor.execute(
                'CREATE INDEX IF NOT EXISTS costs_timestamp_idx ON costs ("timestamp")'
            )
            self.db_conn.commit()
            logging.info("Ensured 'costs' table exists.")
        except psycopg2.Error as e:
            logging.error(f"Error creating table: {e}")
            self.db_conn.rollback()
            raise

    def record_cost(self, amount: float) -> None:
        """
        Insert a cost into the database, then refresh the Redis aggregates.

        Args:
            amount: The cost amount to record.

        Raises:
            psycopg2.Error: If the insert fails (nothing is recorded), or if the
                subsequent aggregation query fails (the cost IS recorded; only
                the cache refresh failed).
            redis.exceptions.RedisError: If the refreshed totals cannot be
                written to Redis (the cost IS recorded).
        """
        try:
            self.db_cursor.execute("INSERT INTO costs (amount) VALUES (%s)", (amount,))
            self.db_conn.commit()
        except psycopg2.Error as e:
            logging.error(f"Error recording cost: {e}")
            self.db_conn.rollback()
            raise
        logging.info(f"Recorded cost: {amount}")
        # The insert is already committed. Refresh outside the try above so a
        # cache failure is not reported (or "rolled back") as a failed insert,
        # which would invite a retry that double-counts the cost.
        self._update_cache()

    def _update_cache(self) -> None:
        """
        Recompute daily, weekly and monthly spend from the database and store it in Redis.

        Windows start at UTC midnight of the current day, of the Monday of the
        current week, and of the first day of the current month. An alert is
        triggered when the daily total exceeds ``self.daily_threshold``.

        Raises:
            psycopg2.Error: If the aggregation query fails.
            redis.exceptions.RedisError: If the totals cannot be written to Redis.
        """
        today = datetime.datetime.now(datetime.timezone.utc).date()
        start_of_week = today - datetime.timedelta(days=today.weekday())
        start_of_month = today.replace(day=1)
        # Earliest instant any window includes. Older rows contribute 0 to
        # every SUM below, so filtering them out leaves all totals unchanged.
        earliest = min(start_of_week, start_of_month)

        try:
            self.db_cursor.execute("""
                SELECT
                    SUM(CASE WHEN timestamp >= %s THEN amount ELSE 0 END) as daily_spend,
                    SUM(CASE WHEN timestamp >= %s THEN amount ELSE 0 END) as weekly_spend,
                    SUM(CASE WHEN timestamp >= %s THEN amount ELSE 0 END) as monthly_spend
                FROM costs
                WHERE timestamp >= %s
            """, (today, start_of_week, start_of_month, earliest))
            result = self.db_cursor.fetchone()
            # psycopg2 implicitly opened a transaction for this SELECT. End it so
            # the connection does not sit "idle in transaction" between inserts.
            self.db_conn.commit()
        except psycopg2.Error as e:
            logging.error(f"Error updating cache: {e}")
            # Clear the aborted transaction so the connection stays usable.
            self.db_conn.rollback()
            raise

        if not result:
            return

        # SUM over zero matching rows yields NULL (None); treat that as zero spend.
        daily_spend = float(result['daily_spend'] or 0)
        weekly_spend = float(result['weekly_spend'] or 0)
        monthly_spend = float(result['monthly_spend'] or 0)

        # MSET writes all three keys atomically, so readers never observe a
        # mix of old and new totals.
        self.redis_client.mset({
            "daily_spend": str(daily_spend),
            "weekly_spend": str(weekly_spend),
            "monthly_spend": str(monthly_spend),
        })

        logging.info(f"Cache updated: daily={daily_spend}, weekly={weekly_spend}, monthly={monthly_spend}")

        if daily_spend > self.daily_threshold:
            self._trigger_alert(daily_spend)

    def _trigger_alert(self, daily_spend: float) -> None:
        """
        Report that the daily spend exceeds the configured threshold.

        This is called on every cache refresh while the daily total stays above
        the threshold, so it fires once per recorded cost, not once per day.

        Args:
            daily_spend: The current daily spend.
        """
        logging.warning(f"ALERT: Daily spend of {daily_spend} exceeds threshold of {self.daily_threshold}")
        # Placeholder: a real system would send an email, SMS, push notification, etc.
        # For now the alert is only logged.

    def get_costs(self) -> Dict[str, float]:
        """
        Return the cached daily, weekly and monthly spend.

        Missing keys read as 0.0. This is best-effort: if the cache cannot be
        read or parsed, the error is logged and all three totals are 0.0.

        Returns:
            A dict with ``daily_spend``, ``weekly_spend`` and ``monthly_spend``.
        """
        keys = ["daily_spend", "weekly_spend", "monthly_spend"]
        try:
            # A single MGET returns a consistent snapshot of all three totals.
            raw_values = self.redis_client.mget(keys)
            return {key: float(raw or 0) for key, raw in zip(keys, raw_values)}
        except Exception as e:
            logging.error(f"Error retrieving costs from cache: {e}")
            return dict.fromkeys(keys, 0.0)

    def close(self) -> None:
        """
        Close the database cursor and connection and the Redis connection pool.

        Safe to call more than once and on a partially constructed instance.
        """
        db_cursor = getattr(self, "db_cursor", None)
        if db_cursor is not None and not db_cursor.closed:
            db_cursor.close()
        db_conn = getattr(self, "db_conn", None)
        if db_conn is not None and not db_conn.closed:
            db_conn.close()
            logging.info("Database connection closed.")
        redis_client = getattr(self, "redis_client", None)
        if redis_client is not None:
            redis_client.connection_pool.disconnect()

if __name__ == '__main__':
    # Demo run: record a few costs against a lowered threshold and print the cached totals.
    aggregator = None
    try:
        aggregator = CostAggregator(daily_threshold=500)  # Lower threshold so the demo alerts

        # Simulated costs. Assuming no earlier costs today, the running daily
        # total first exceeds 500 at the third cost (100 + 250 + 300 = 650),
        # and every later refresh alerts again.
        for amount in (100.0, 250.0, 300.0, 600.0):
            aggregator.record_cost(amount)

        costs = aggregator.get_costs()
        print(f"Current Costs: {costs}")

    except Exception as e:
        # Log with traceback and exit non-zero so the failure is visible to
        # whatever launched the script.
        logging.exception(f"An error occurred: {e}")
        raise SystemExit(1)
    finally:
        if aggregator is not None:
            aggregator.close()
