import psycopg2
import psycopg2.extras
import datetime
import os
import logging
from typing import Optional, List, Tuple

# Configure logging for the whole process: INFO level and above, each line
# formatted as "<timestamp> - <level> - <message>". This is a module-level
# statement, so it runs as soon as the module is imported or executed.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class CostTrackerDB:
    """
    Handles database interactions for tracking API costs and AIVA fixed costs.

    The connection is opened lazily: every query method connects on demand when
    no usable connection is held, so calling connect() up front is optional.
    connect() enables autocommit, so statements are persisted as they execute.
    """

    def __init__(self, dbname: str, user: str, password: str, host: str, port: int):
        """
        Stores the connection settings; no connection is opened here.

        Args:
            dbname (str): The name of the database.
            user (str): The database user.
            password (str): The database password.
            host (str): The database host.
            port (int): The database port.
        """
        self.dbname = dbname
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.conn = None  # Set by connect(); reset to None by close()

    def connect(self):
        """
        Connects to the PostgreSQL database and enables autocommit.

        Raises:
            psycopg2.Error: If the connection cannot be established. The error
                is logged before being re-raised.
        """
        try:
            self.conn = psycopg2.connect(
                dbname=self.dbname,
                user=self.user,
                password=self.password,
                host=self.host,
                port=self.port
            )
            self.conn.autocommit = True  # Enable autocommit
            logging.info("Successfully connected to the database.")
        except psycopg2.Error as e:
            logging.error(f"Error connecting to the database: {e}")
            raise

    def close(self):
        """
        Closes the database connection, if one is held.

        The reference is dropped even when closing fails, so the next query
        method opens a fresh connection instead of reusing a closed one.
        """
        if self.conn:
            try:
                self.conn.close()
            finally:
                self.conn = None
            logging.info("Database connection closed.")

    def _ensure_connection(self):
        """Opens a connection when none is held or the held one is closed."""
        if self.conn is None or self.conn.closed:
            self.connect()

    @staticmethod
    def _build_api_costs_query(provider, model, start_date, end_date):
        """Returns the (sql, params) pair for query_api_costs; falsy filters are skipped."""
        sql = "SELECT * FROM api_costs WHERE 1=1"  # Neutral base so filters can be appended
        params = []

        if provider:
            sql += " AND provider = %s"
            params.append(provider)
        if model:
            sql += " AND model = %s"
            params.append(model)
        if start_date:
            sql += " AND timestamp >= %s"
            params.append(start_date)
        if end_date:
            sql += " AND timestamp <= %s"
            params.append(end_date)

        sql += " ORDER BY timestamp"
        return sql, params

    def insert_api_cost(self, provider: str, model: str, tokens_in: int, tokens_out: int, cost_usd: float):
        """
        Inserts API cost data into the api_costs table.

        Args:
            provider (str): The API provider (e.g., OpenAI).
            model (str): The model used (e.g., gpt-4).
            tokens_in (int): Number of input tokens.
            tokens_out (int): Number of output tokens.
            cost_usd (float): Cost in USD.

        Raises:
            psycopg2.Error: If connecting or inserting fails. The error is
                logged before being re-raised.
        """
        sql = """
            INSERT INTO api_costs (provider, model, tokens_in, tokens_out, cost_usd)
            VALUES (%s, %s, %s, %s, %s);
        """
        try:
            self._ensure_connection()
            # The with-block closes the cursor on every path and never touches
            # an unbound name when connecting or cursor creation fails.
            with self.conn.cursor() as cur:
                cur.execute(sql, (provider, model, tokens_in, tokens_out, cost_usd))
            # Kept for connections assigned to self.conn without autocommit.
            self.conn.commit()
            logging.info(f"Inserted API cost: Provider={provider}, Model={model}, Cost={cost_usd}")
        except psycopg2.Error as e:
            logging.error(f"Error inserting API cost data: {e}")
            raise

    def insert_aiva_fixed_cost(self, cost_usd: float):
        """
        Inserts AIVA fixed cost data into the aiva_fixed_costs table.

        Args:
            cost_usd (float): The fixed cost in USD.

        Raises:
            psycopg2.Error: If connecting or inserting fails. The error is
                logged before being re-raised.
        """
        sql = """
            INSERT INTO aiva_fixed_costs (cost_usd)
            VALUES (%s);
        """
        try:
            self._ensure_connection()
            with self.conn.cursor() as cur:
                cur.execute(sql, (cost_usd,))
            # Kept for connections assigned to self.conn without autocommit.
            self.conn.commit()
            logging.info(f"Inserted AIVA fixed cost: Cost={cost_usd}")
        except psycopg2.Error as e:
            logging.error(f"Error inserting AIVA fixed cost data: {e}")
            raise

    def query_api_costs(self, provider: Optional[str] = None, model: Optional[str] = None,
                        start_date: Optional[datetime.datetime] = None, end_date: Optional[datetime.datetime] = None) -> List[dict]:
        """
        Queries the api_costs table based on specified filters.

        Filters that are None (or otherwise falsy) are not applied. Rows are
        ordered by their timestamp column.

        Args:
            provider (Optional[str]): Filter by API provider.
            model (Optional[str]): Filter by model.
            start_date (Optional[datetime.datetime]): Inclusive lower bound on timestamp.
            end_date (Optional[datetime.datetime]): Inclusive upper bound on timestamp.

        Returns:
            List[dict]: One dict-like row per match, keyed by column name
                (rows are fetched with RealDictCursor).

        Raises:
            psycopg2.Error: If connecting or querying fails. The error is
                logged before being re-raised.
        """
        sql, params = self._build_api_costs_query(provider, model, start_date, end_date)
        try:
            self._ensure_connection()
            with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                results = cur.fetchall()
            logging.info(f"Query executed successfully.  Returned {len(results)} results")
            return results
        except psycopg2.Error as e:
            logging.error(f"Error querying API costs: {e}")
            raise

    def table_exists(self, table_name: str) -> bool:
        """
        Checks if a table exists in the database.

        The lookup goes through pg_tables without a schema filter, so a table
        with this name in any schema counts as existing.

        Args:
            table_name (str): The name of the table to check.

        Returns:
            bool: True if the table exists, False otherwise. False is also
                returned (after logging) when connecting or querying fails.
        """
        sql = "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = %s);"
        try:
            self._ensure_connection()
            with self.conn.cursor() as cur:
                cur.execute(sql, (table_name,))
                return cur.fetchone()[0]
        except psycopg2.Error as e:
            logging.error(f"Error checking if table exists: {e}")
            return False


def initialize_database(db: CostTrackerDB):
    """
    Verifies that the tables the cost tracker relies on are present.

    Both api_costs and aiva_fixed_costs are looked up through
    db.table_exists() before anything is reported. A missing table is logged
    as an error and a present one as info; nothing is created or returned.

    Args:
        db (CostTrackerDB): The database wrapper used to look up each table.
    """
    required_tables = ("api_costs", "aiva_fixed_costs")

    # Run every lookup first so the reports below are emitted together.
    presence = [(table, db.table_exists(table)) for table in required_tables]

    for table, found in presence:
        if found:
            logging.info(f"The {table} table exists.")
        else:
            logging.error(f"The {table} table does not exist. Ensure the schema has been created.")


if __name__ == '__main__':
    # Example run: connect, verify the tables, write one API cost row and one
    # AIVA fixed-cost row, then read API costs back for a fixed date window.
    # Connection settings come from DB_* environment variables; the fallback
    # values below are placeholders to replace with your own.
    env = os.environ
    db = CostTrackerDB(
        env.get("DB_NAME", "genesis"),
        env.get("DB_USER", "genesis"),
        env.get("DB_PASSWORD", "genesis"),
        env.get("DB_HOST", "postgresql-genesis-u50607.vm.elestio.app"),
        int(env.get("DB_PORT", 25432)),
    )

    try:
        db.connect()

        # Report whether the expected tables are present.
        initialize_database(db)

        # Sample API usage record.
        provider = "OpenAI"
        db.insert_api_cost(provider, "gpt-4", 1000, 500, 0.01)

        # AIVA fixed cost, allocated per day from a 119.00 figure over 30 days.
        db.insert_aiva_fixed_cost(119.00 / 30)

        # Read back this provider's rows for calendar year 2023.
        window_start = datetime.datetime(2023, 1, 1)
        window_end = datetime.datetime(2024, 1, 1)
        rows = db.query_api_costs(provider=provider, start_date=window_start, end_date=window_end)
        print("Query Results:", rows)

    except Exception as e:
        logging.error(f"An error occurred: {e}")
    finally:
        # Always release the connection, whether or not the run succeeded.
        db.close()
