import json
import logging
import logging.handlers
import os
import time
from typing import Any, Dict
from urllib.parse import unquote, urlparse

import psycopg2
import psycopg2.extras
from psycopg2 import sql

# Configure logging: records are written to ./logs/api.log, resolved against
# the process's current working directory at import time.
log_file_path = os.path.join(os.getcwd(), 'logs', 'api.log')
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# File handler with rotation: rolls over every day (when="D", interval=1)
# and keeps `log_rotation_days` rotated backup files.
log_rotation_days = 7
rotating_handler = logging.handlers.TimedRotatingFileHandler(
    log_file_path,
    when="D",
    interval=1,
    backupCount=log_rotation_days,
    # Explicit encoding so the file's encoding does not depend on the
    # platform locale (error messages may contain non-ASCII text).
    encoding="utf-8",
)
log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
rotating_handler.setFormatter(log_format)

# getLogger(__name__) returns the same logger object when this module is
# re-imported (e.g. importlib.reload), so drop any handler a previous import
# attached for this same file instead of stacking duplicates that would
# write every record more than once.
for _stale_handler in list(logger.handlers):
    if (
        isinstance(_stale_handler, logging.handlers.TimedRotatingFileHandler)
        and _stale_handler.baseFilename == rotating_handler.baseFilename
    ):
        logger.removeHandler(_stale_handler)
        _stale_handler.close()
logger.addHandler(rotating_handler)


class APILogger:
    """
    A class for logging API calls to both a file and a PostgreSQL database.

    File records go through the module-level ``logger``; database rows go to
    the ``api_logs`` table, which is created on connect if missing. The
    instance owns one connection/cursor pair. Release it with ``close()`` or
    by using the instance in a ``with`` statement.
    """

    def __init__(self, db_url: str):
        """
        Initializes the APILogger and opens the database connection.

        Args:
            db_url: A PostgreSQL URL such as
                ``postgresql://user:password@host[:port]/dbname``.
                Percent-encoded characters in the user name, password and
                database name are decoded before connecting.

        Raises:
            Exception: Whatever URL parsing, ``psycopg2.connect`` or table
                creation raises. Any partially opened connection is closed
                before the exception propagates.
        """
        self.db_url = db_url
        self.conn = None
        self.cursor = None
        self._connect()

    def __enter__(self) -> "APILogger":
        """Return this logger so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        """Close the database connection when leaving a ``with`` block."""
        self.close()
        return False  # never suppress exceptions raised in the with-body

    @staticmethod
    def _connection_params(db_url: str) -> Dict[str, Any]:
        """
        Split a PostgreSQL URL into keyword arguments for ``psycopg2.connect``.

        ``urlparse`` leaves percent-escapes in place, so the user name,
        password and database name are unquoted. For example, ``p%40ss``
        becomes ``p@ss``. The port defaults to 5432 when the URL has none.
        """
        url = urlparse(db_url)
        return {
            "host": url.hostname,
            "port": url.port or 5432,
            "database": unquote(url.path[1:]),
            "user": unquote(url.username) if url.username is not None else None,
            "password": unquote(url.password) if url.password is not None else None,
        }

    def _connect(self):
        """
        Establishes a connection to the PostgreSQL database and ensures the
        ``api_logs`` table exists.

        On failure the error is logged and everything already opened is
        closed, so a failed constructor does not leak a connection. The
        exception is then re-raised.
        """
        try:
            self.conn = psycopg2.connect(**self._connection_params(self.db_url))
            self.cursor = self.conn.cursor()
            self._create_table()
        except Exception as e:
            logger.exception("Database connection error: %s", e)
            self.close()
            raise

    def _safe_rollback(self):
        """
        Roll back the current transaction if the connection is still open.

        Used inside error handlers. A rollback on a closed connection raises
        ``InterfaceError``, which would replace the original error, so it is
        skipped. A failing rollback is only logged.
        """
        if self.conn is None or self.conn.closed:
            return
        try:
            self.conn.rollback()
        except psycopg2.Error as e:
            logger.error("Error rolling back transaction: %s", e)

    def _create_table(self):
        """
        Creates the api_logs table if it doesn't exist.

        Raises:
            Exception: Any error from executing or committing the DDL, after
                the transaction has been rolled back (when possible).
        """
        try:
            create_table_query = sql.SQL("""
                CREATE TABLE IF NOT EXISTS api_logs (
                    id SERIAL PRIMARY KEY,
                    endpoint VARCHAR(255) NOT NULL,
                    model VARCHAR(255),
                    prompt_tokens INTEGER,
                    completion_tokens INTEGER,
                    latency_ms INTEGER,
                    timestamp TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'utc'),
                    context JSONB
                )
            """)
            self.cursor.execute(create_table_query)
            self.conn.commit()
        except Exception as e:
            logger.exception("Error creating table: %s", e)
            self._safe_rollback()
            raise

    def log_api_call(self, endpoint: str, model: str, prompt_tokens: int, completion_tokens: int, latency_ms: int, context: Dict[str, Any]):
        """
        Logs an API call to both the file and the PostgreSQL database.

        This is best-effort: errors are written to the file log and are not
        propagated to the caller.

        Args:
            endpoint: The API endpoint.
            model: The model used for the API call.
            prompt_tokens: The number of prompt tokens.
            completion_tokens: The number of completion tokens.
            latency_ms: The latency of the API call in milliseconds.
            context: Additional context; must be JSON-serializable because
                it is dumped to the file log and stored as JSONB.
        """
        try:
            log_data = {
                "endpoint": endpoint,
                "model": model,
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "latency_ms": latency_ms,
                "context": context
            }

            # Log to file
            logger.info(json.dumps(log_data))

            # Log to database
            self._log_to_db(endpoint, model, prompt_tokens, completion_tokens, latency_ms, context)

        except Exception as e:
            logger.exception("Error logging API call: %s", e)

    def _log_to_db(self, endpoint: str, model: str, prompt_tokens: int, completion_tokens: int, latency_ms: int, context: Dict[str, Any]):
        """
        Logs the API call details to the PostgreSQL database.

        Errors are logged and the transaction is rolled back; nothing is
        raised to the caller.

        Args:
            endpoint: The API endpoint.
            model: The model used for the API call.
            prompt_tokens: The number of prompt tokens.
            completion_tokens: The number of completion tokens.
            latency_ms: The latency of the API call in milliseconds.
            context: A dictionary containing additional context information.
        """
        try:
            # A dropped server connection leaves self.conn closed. Reopen it
            # lazily so one outage does not disable database logging forever.
            if self.conn is None or self.conn.closed:
                self.close()
                self._connect()
            insert_query = sql.SQL("""
                INSERT INTO api_logs (endpoint, model, prompt_tokens, completion_tokens, latency_ms, context)
                VALUES (%s, %s, %s, %s, %s, %s)
            """)
            # psycopg2.extras.Json adapts the dict for the JSONB column;
            # it requires the psycopg2.extras submodule to be imported.
            self.cursor.execute(insert_query, (endpoint, model, prompt_tokens, completion_tokens, latency_ms, psycopg2.extras.Json(context)))
            self.conn.commit()
        except Exception as e:
            logger.exception("Error inserting into database: %s", e)
            self._safe_rollback()

    def close(self):
        """
        Closes the cursor and database connection.

        Safe to call more than once. The connection is closed even if
        closing the cursor fails.
        """
        cursor, self.cursor = self.cursor, None
        conn, self.conn = self.conn, None
        try:
            if cursor is not None:
                cursor.close()
        finally:
            if conn is not None:
                conn.close()


if __name__ == '__main__':
    # Example usage. The database URL is read from the DATABASE_URL
    # environment variable so real credentials need not be committed to
    # source; the literal below is only a fallback placeholder.
    db_url = os.environ.get(
        "DATABASE_URL",
        "postgresql://user:password@postgresql-genesis-u50607.vm.elestio.app:25432/database",
    )
    api_logger = APILogger(db_url)

    try:
        api_logger.log_api_call(
            endpoint="/api/v1/generate",
            model="gpt-3.5-turbo",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=250,
            context={"user_id": "user123", "request_id": "req456"}
        )
        # NOTE: log_api_call catches its own errors, so reaching this line
        # does not prove the row was stored; check logs/api.log for errors.
        print("API call logged successfully.")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        api_logger.close()
