import datetime
import logging
import os
from typing import List, Dict, Tuple

import psycopg2
import redis
from qdrant_client import QdrantClient

# Configure the root logger (INFO level, timestamped format). This runs at
# import time, so any program importing this module inherits this setup unless
# it configured logging first (basicConfig does nothing once the root logger
# already has handlers).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class CostOptimizer:
    """
    Analyzes cost data and provides recommendations for cost optimization.

    Usage records are read from the PostgreSQL ``usage_data`` table. The Redis
    and Qdrant connections are opened by :meth:`run_optimization`, but nothing
    in this class reads from them. The Ollama host/port are stored but unused.
    """

    def __init__(self, db_host: str, db_port: int, db_name: str, db_user: str, db_pass: str,
                 redis_host: str, redis_port: int,
                 qdrant_host: str, qdrant_port: int,
                 ollama_host: str, ollama_port: int):
        """
        Initializes the CostOptimizer with database and service connection details.

        No connections are opened here; call the ``connect_to_*`` methods (or
        :meth:`run_optimization`) to do that.

        Args:
            db_host: PostgreSQL host.
            db_port: PostgreSQL port.
            db_name: PostgreSQL database name.
            db_user: PostgreSQL username.
            db_pass: PostgreSQL password.
            redis_host: Redis host.
            redis_port: Redis port.
            qdrant_host: Qdrant host.
            qdrant_port: Qdrant port.
            ollama_host: Ollama host.
            ollama_port: Ollama port.
        """
        self.db_host = db_host
        self.db_port = db_port
        self.db_name = db_name
        self.db_user = db_user
        self.db_pass = db_pass
        self.redis_host = redis_host
        self.redis_port = redis_port
        self.qdrant_host = qdrant_host
        self.qdrant_port = qdrant_port
        self.ollama_host = ollama_host
        self.ollama_port = ollama_port
        # Connection handles; None until the matching connect_* method succeeds.
        self.db_connection = None
        self.redis_client = None
        self.qdrant_client = None

    def connect_to_database(self) -> None:
        """
        Establishes a connection to the PostgreSQL database.

        Raises:
            psycopg2.Error: If the connection cannot be established.
        """
        try:
            self.db_connection = psycopg2.connect(
                host=self.db_host,
                port=self.db_port,
                database=self.db_name,
                user=self.db_user,
                password=self.db_pass
            )
            logging.info("Connected to PostgreSQL database.")
        except psycopg2.Error as e:
            logging.error(f"Error connecting to PostgreSQL: {e}")
            raise

    def connect_to_redis(self) -> None:
        """
        Establishes a connection to the Redis server and verifies it with PING.

        Raises:
            redis.exceptions.ConnectionError: If the server cannot be reached.
        """
        try:
            self.redis_client = redis.Redis(host=self.redis_host, port=self.redis_port, decode_responses=True)
            self.redis_client.ping()  # redis.Redis connects lazily; force a round trip
            logging.info("Connected to Redis.")
        except redis.exceptions.ConnectionError as e:
            logging.error(f"Error connecting to Redis: {e}")
            raise

    def connect_to_qdrant(self) -> None:
        """
        Establishes a connection to the Qdrant vector database.

        Raises:
            Exception: Whatever the Qdrant client raises when the server is
                unreachable (re-raised after logging).
        """
        try:
            self.qdrant_client = QdrantClient(host=self.qdrant_host, port=self.qdrant_port)
            # Cheap, documented round trip used purely as a connectivity check.
            self.qdrant_client.get_collections()
            logging.info("Connected to Qdrant.")
        except Exception as e:
            logging.error(f"Error connecting to Qdrant: {e}")
            raise

    def get_cost_data(self, days: int = 7) -> List[Dict]:
        """
        Retrieves cost data from the database for the specified number of days.

        Args:
            days: The number of days to retrieve cost data for, counting back
                from the current (naive, local) time.

        Returns:
            A list of dictionaries with keys ``timestamp``, ``model_name``,
            ``prompt_tokens``, ``completion_tokens`` and ``cost`` — one per row.

        Raises:
            RuntimeError: If no database connection is currently held.
            psycopg2.Error: If the query fails.
        """
        if not self.db_connection:
            raise RuntimeError("Database connection not established.")

        # NOTE(review): naive local time; if usage_data.timestamp is timestamptz
        # the comparison uses the session time zone — confirm this is intended.
        end_date = datetime.datetime.now()
        start_date = end_date - datetime.timedelta(days=days)

        query = """
            SELECT timestamp, model_name, prompt_tokens, completion_tokens, cost
            FROM usage_data
            WHERE timestamp BETWEEN %s AND %s
        """
        # Keys in the same order as the SELECT list above.
        columns = ('timestamp', 'model_name', 'prompt_tokens', 'completion_tokens', 'cost')

        try:
            # The cursor context manager closes the cursor even when the query fails.
            with self.db_connection.cursor() as cursor:
                cursor.execute(query, (start_date, end_date))
                results = cursor.fetchall()
        except psycopg2.Error as e:
            logging.error(f"Error fetching cost data: {e}")
            raise

        cost_data = [dict(zip(columns, row)) for row in results]
        logging.info(f"Retrieved {len(cost_data)} cost records from the database.")
        return cost_data

    def analyze_cost_data(self, cost_data: List[Dict]) -> Dict:
        """
        Analyzes cost data to identify expensive patterns, underused models, and inefficient prompts.

        Args:
            cost_data: A list of dictionaries representing cost records; each
                must have ``model_name`` and ``cost`` keys. A ``cost`` of None
                (e.g. a NULL column) is counted as zero.

        Returns:
            A dictionary with:
              - ``expensive_models``: up to three ``(model, cost)`` pairs,
                most expensive first;
              - ``underused_models``: models whose total cost is below the
                mean per-model cost;
              - ``inefficient_prompts``: currently always an empty list;
              - ``total_cost``: the sum of all record costs.
        """
        model_costs = {}
        total_cost = 0

        for record in cost_data:
            model_name = record['model_name']
            cost = record['cost']
            if cost is None:  # NULL cost: count as zero rather than crash on '+'
                cost = 0
            total_cost += cost
            model_costs[model_name] = model_costs.get(model_name, 0) + cost

        expensive_models = sorted(model_costs.items(), key=lambda item: item[1], reverse=True)[:3]

        # Simple heuristic: a model is "underused" if it costs less than the
        # average per-model cost. Note a model can appear in both lists.
        underused_threshold = total_cost / len(model_costs) if model_costs else 0
        underused_models = [model for model, cost in model_costs.items() if cost < underused_threshold]

        # TODO: prompt efficiency analysis is not implemented yet.
        inefficient_prompts = []

        analysis_results = {
            'expensive_models': expensive_models,
            'underused_models': underused_models,
            'inefficient_prompts': inefficient_prompts,
            'total_cost': total_cost
        }

        logging.info("Cost data analysis completed.")
        return analysis_results

    def generate_recommendations(self, analysis_results: Dict) -> List[str]:
        """
        Generates cost optimization recommendations based on the analysis results.

        Args:
            analysis_results: A dictionary as returned by :meth:`analyze_cost_data`.

        Returns:
            A list of strings, where each string is a cost optimization recommendation.
        """
        recommendations = [
            f"Consider reducing usage of model '{model}' which cost ${cost:.2f}."
            for model, cost in analysis_results['expensive_models']
        ]
        recommendations.extend(
            f"Evaluate the necessity of model '{model}' as it is underused."
            for model in analysis_results['underused_models']
        )

        if analysis_results['inefficient_prompts']:
            recommendations.append("Review and optimize prompts for better efficiency.")

        logging.info("Cost optimization recommendations generated.")
        return recommendations

    def calculate_potential_savings(self, analysis_results: Dict, savings_rate: float = 0.1) -> float:
        """
        Estimates potential savings from optimizing the most expensive models.

        This is a placeholder heuristic: it assumes a flat ``savings_rate``
        reduction on the cost of every model in ``expensive_models``.

        Args:
            analysis_results: A dictionary as returned by :meth:`analyze_cost_data`.
            savings_rate: Fraction of each expensive model's cost assumed to be
                saved (default 0.1, i.e. 10%).

        Returns:
            The potential savings amount as a float.
        """
        # float() first: psycopg2 returns NUMERIC columns as Decimal, and
        # Decimal * float raises TypeError.
        potential_savings = sum(
            (float(cost) * savings_rate for _, cost in analysis_results['expensive_models']),
            0.0,
        )

        logging.info(f"Potential savings calculated: ${potential_savings:.2f}")
        return potential_savings

    def generate_markdown_report(self, analysis_results: Dict, recommendations: List[str], potential_savings: float) -> str:
        """
        Generates a markdown report summarizing the cost optimization analysis and recommendations.

        Lines are emitted without leading indentation: in Markdown, text
        indented four or more spaces renders as a code block, not as headings
        and lists.

        Args:
            analysis_results: A dictionary as returned by :meth:`analyze_cost_data`.
            recommendations: A list of cost optimization recommendations.
            potential_savings: The potential savings amount.

        Returns:
            A string containing the markdown report, ending with a newline.
        """
        lines = [
            "# Cost Optimization Report",
            "",
            "## Analysis Results",
            "",
            f"- Total Cost: ${analysis_results['total_cost']:.2f}",
            "",
            "### Expensive Models:",
            "",
        ]
        lines.extend(f"- {model}: ${cost:.2f}" for model, cost in analysis_results['expensive_models'])

        lines += ["", "### Underused Models:", ""]
        lines.extend(f"- {model}" for model in analysis_results['underused_models'])

        lines += ["", "### Inefficient Prompts:", ""]
        prompts = analysis_results['inefficient_prompts']
        if prompts:
            lines.extend(f"- {prompt}" for prompt in prompts)
        else:
            lines.append("- No inefficient prompts identified.")

        lines += ["", "## Recommendations", ""]
        lines.extend(f"- {recommendation}" for recommendation in recommendations)

        lines += [
            "",
            "## Potential Savings",
            "",
            f"Estimated Potential Savings: ${potential_savings:.2f}",
        ]

        report = "\n".join(lines) + "\n"
        logging.info("Markdown report generated.")
        return report

    def run_optimization(self, days: int = 7) -> str:
        """
        Runs the cost optimization process end to end.

        Connects to PostgreSQL, Redis and Qdrant, fetches and analyzes the cost
        data, and renders the report. The database connection is always closed
        afterwards. A failure to reach Redis or Qdrant aborts the run even
        though neither is read by the analysis.

        Args:
            days: The number of days to analyze cost data for.

        Returns:
            The markdown report, or an ``"Error running cost optimization: ..."``
            string if any step raised (errors are logged, not propagated).
        """
        try:
            self.connect_to_database()
            self.connect_to_redis()
            self.connect_to_qdrant()

            cost_data = self.get_cost_data(days)
            analysis_results = self.analyze_cost_data(cost_data)
            recommendations = self.generate_recommendations(analysis_results)
            potential_savings = self.calculate_potential_savings(analysis_results)
            return self.generate_markdown_report(analysis_results, recommendations, potential_savings)

        except Exception as e:
            # Top-level boundary: log with traceback and report the error as text.
            logging.exception(f"Error running cost optimization: {e}")
            return f"Error running cost optimization: {e}"
        finally:
            if self.db_connection:
                self.db_connection.close()
                # Drop the closed handle so get_cost_data() reports "not
                # established" instead of failing on a closed connection.
                self.db_connection = None
                logging.info("Database connection closed.")


if __name__ == '__main__':
    # Example usage: every setting is read from an environment variable and
    # falls back to the example value given here (replace with your own
    # configuration). NOTE(review): the fallbacks name specific hosts; consider
    # requiring the variables instead of shipping these defaults.
    env = os.environ
    optimizer = CostOptimizer(
        db_host=env.get("DB_HOST", "postgresql-genesis-u50607.vm.elestio.app"),
        db_port=int(env.get("DB_PORT", "25432")),
        db_name=env.get("DB_NAME", "your_db_name"),
        db_user=env.get("DB_USER", "your_db_user"),
        db_pass=env.get("DB_PASS", "your_db_pass"),
        redis_host=env.get("REDIS_HOST", "redis-genesis-u50607.vm.elestio.app"),
        redis_port=int(env.get("REDIS_PORT", "26379")),
        qdrant_host=env.get("QDRANT_HOST", "qdrant-b3knu-u50607.vm.elestio.app"),
        qdrant_port=int(env.get("QDRANT_PORT", "6333")),
        ollama_host=env.get("OLLAMA_HOST", "localhost"),
        ollama_port=int(env.get("OLLAMA_PORT", "23405")),
    )
    print(optimizer.run_optimization())
