import time
import tracemalloc
import logging
from typing import Callable, Dict, Any, List, Optional
import statistics
import os
import json

# Configure root logging once at import time so every profiler message
# carries a timestamp and severity level in a consistent format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class PerformanceProfiler:
    """
    A class for profiling the performance of a given function or code block.

    It tracks metrics like latency, throughput, accuracy, and cost.
    It also includes memory usage profiling, token usage analytics, cost
    attribution, bottleneck detection, optimization recommendations, and
    historical trends persisted to disk as JSON.
    """

    def __init__(self, function_name: str = "Unnamed Function", profile_dir: str = "profiles"):
        """
        Initializes the PerformanceProfiler.

        Args:
            function_name (str, optional): A name for the function being profiled.
                Also used to name the on-disk history file. Defaults to "Unnamed Function".
            profile_dir (str, optional): Directory to store historical profiles.
                Created if it does not exist. Defaults to "profiles".
        """
        self.function_name = function_name
        # Each metric name maps to the list of samples collected so far.
        self.metrics: Dict[str, List[float]] = {
            "query_processing_time": [],
            "retrieval_time": [],
            "generation_time": [],
            "queries_per_second": [],
            "concurrent_capacity": [],  # Placeholder - requires external monitoring
            "response_quality": [],      # Placeholder - requires evaluation function
            "validation_pass_rate": [],   # Placeholder - requires validation function
            "token_usage": [],            # Placeholder - requires API integration
            "api_costs": [],              # Placeholder - requires API integration
            "storage_costs": [],           # Placeholder - requires storage monitoring
            "memory_usage": [],          # Memory usage in MB
        }
        self.start_time: Optional[float] = None  # Wall-clock start of the current run
        self.end_time: Optional[float] = None    # Wall-clock end of the current run
        self.query_count = 0
        self.errors = 0       # Count of exceptions observed during profiled calls
        self.peak_memory = 0.0  # Peak traced memory of the last run, in MB
        self.profile_dir = profile_dir
        os.makedirs(self.profile_dir, exist_ok=True)  # Create directory if it doesn't exist
        self.historical_profiles: List[Dict[str, Any]] = []
        self.load_historical_profiles()  # Load existing profiles on initialization

    def start(self):
        """Starts the profiling process (wall clock + tracemalloc)."""
        self.start_time = time.time()
        tracemalloc.start()  # Start memory tracking
        logging.info(f"Profiling started for: {self.function_name}")

    def stop(self):
        """Stops the profiling process and records peak memory usage (MB)."""
        self.end_time = time.time()
        current, peak = tracemalloc.get_traced_memory()
        self.peak_memory = peak / (1024 * 1024)  # Convert bytes to MB
        self.metrics["memory_usage"].append(self.peak_memory)
        tracemalloc.stop()   # Stop memory tracking
        logging.info(f"Profiling stopped for: {self.function_name}")

    def track_latency(self, query_processing_time: Optional[float] = None,
                      retrieval_time: Optional[float] = None,
                      generation_time: Optional[float] = None):
        """Tracks latency metrics; only the components that are provided are recorded."""
        if query_processing_time is not None:
            self.metrics["query_processing_time"].append(query_processing_time)
        if retrieval_time is not None:
            self.metrics["retrieval_time"].append(retrieval_time)
        if generation_time is not None:
            self.metrics["generation_time"].append(generation_time)

    def track_throughput(self):
        """Tracks throughput (queries per second). Call this AFTER processing a batch of queries."""
        # `is not None` (rather than truthiness) so a 0.0 epoch timestamp is not skipped.
        if self.start_time is not None and self.end_time is not None:
            elapsed_time = self.end_time - self.start_time
            if elapsed_time > 0:  # Avoid division by zero
                qps = self.query_count / elapsed_time
                self.metrics["queries_per_second"].append(qps)
            else:
                self.metrics["queries_per_second"].append(0)  # Add 0 if time is zero.

    def track_accuracy(self, response_quality: Optional[float] = None,
                       validation_pass_rate: Optional[float] = None):
        """Tracks accuracy metrics; only the values that are provided are recorded."""
        if response_quality is not None:
            self.metrics["response_quality"].append(response_quality)
        if validation_pass_rate is not None:
            self.metrics["validation_pass_rate"].append(validation_pass_rate)

    def track_cost(self, token_usage: Optional[int] = None,
                   api_costs: Optional[float] = None,
                   storage_costs: Optional[float] = None):
        """Tracks cost metrics; only the values that are provided are recorded."""
        if token_usage is not None:
            self.metrics["token_usage"].append(token_usage)
        if api_costs is not None:
            self.metrics["api_costs"].append(api_costs)
        if storage_costs is not None:
            self.metrics["storage_costs"].append(storage_costs)

    def increment_query_count(self, count: int = 1):
        """Increments the query count."""
        self.query_count += count

    def increment_error_count(self, count: int = 1):
        """Increments the error count."""
        self.errors += count

    def profile(self, func: Callable, *args, **kwargs) -> Any:
        """
        Profiles the execution of a given function.

        Args:
            func (Callable): The function to be profiled.
            *args: Positional arguments to be passed to the function.
            **kwargs: Keyword arguments to be passed to the function.

        Returns:
            Any: The return value of the function being profiled.

        Raises:
            Exception: Re-raises whatever `func` raises, after counting and
                logging the error. `stop()` is guaranteed to run either way.
        """
        self.start()
        try:
            return func(*args, **kwargs)
        except Exception as e:
            self.increment_error_count()
            logging.error(f"Error during profiling of {self.function_name}: {e}")
            raise  # Re-raise the exception after logging
        finally:
            self.stop()  # Single cleanup point for both success and failure paths

    def get_average_metrics(self) -> Dict[str, float]:
        """Calculates and returns the average of each metric (0 when no samples exist)."""
        average_metrics: Dict[str, float] = {}
        for metric, values in self.metrics.items():
            if values:
                average_metrics[metric] = statistics.mean(values)
            else:
                average_metrics[metric] = 0  # Or None, depending on your preference.
        return average_metrics

    def get_summary(self) -> Dict[str, Any]:
        """
        Returns a summary of the profiling results.

        Returns:
            Dict[str, Any]: A dictionary containing the profiling summary
                (timing, counts, peak memory, and per-metric averages).
        """
        if self.start_time is not None and self.end_time is not None:
            duration = self.end_time - self.start_time
        else:
            duration = None  # Profiling never ran (or is still running)
        summary = {
            "function_name": self.function_name,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "duration": duration,
            "query_count": self.query_count,
            "error_count": self.errors,
            "peak_memory_usage_mb": self.peak_memory,
            "average_metrics": self.get_average_metrics(),
        }
        return summary

    def generate_optimization_recommendations(self) -> List[str]:
        """
        Generates optimization recommendations based on the profiling results.

        Thresholds are heuristic examples; tune them for your workload.

        Returns:
            List[str]: A list of optimization recommendations.
        """
        recommendations: List[str] = []
        average_metrics = self.get_average_metrics()

        # Latency Optimization
        if average_metrics.get("query_processing_time", 0) > 1.0:
            recommendations.append("Optimize query logic and database queries for faster processing.")
        if average_metrics.get("retrieval_time", 0) > 0.5:
            recommendations.append("Improve data retrieval efficiency, consider caching or indexing.")
        if average_metrics.get("generation_time", 0) > 2.0:
            recommendations.append("Optimize the generation algorithm, consider using more efficient data structures or algorithms.")

        # Throughput Optimization
        if average_metrics.get("queries_per_second", 0) < 10:
            recommendations.append("Improve concurrency and parallelism to increase queries per second.")
            recommendations.append("Consider load balancing to distribute requests across multiple servers.")

        # Accuracy Optimization (Assuming you have a baseline to compare against)
        if average_metrics.get("response_quality", 1) < 0.9:  # Assuming 1 is perfect quality
            recommendations.append("Review and improve the response generation logic to increase response quality.")
        if average_metrics.get("validation_pass_rate", 1) < 0.95:  # Assuming 1 is perfect validation
            recommendations.append("Improve data validation logic to increase the validation pass rate.")

        # Cost Optimization
        if average_metrics.get("token_usage", 0) > 1000:  # Example threshold
            recommendations.append("Reduce token usage by optimizing prompts, shortening responses, or using more efficient models.")
        if average_metrics.get("api_costs", 0) > 1.0:  # Example threshold
            recommendations.append("Optimize API usage, consider caching responses or using cheaper API tiers.")
        if average_metrics.get("storage_costs", 0) > 0.5:  # Example threshold
            recommendations.append("Optimize storage usage by compressing data, deleting unnecessary data, or using cheaper storage options.")

        # Memory Optimization
        if average_metrics.get("memory_usage", 0) > 500:  # Example threshold (500MB)
            recommendations.append("Optimize memory usage by using more efficient data structures, reducing object allocations, or using garbage collection more effectively.")
            recommendations.append("Consider using memory profiling tools to identify memory leaks or excessive memory consumption.")

        if not recommendations:
            recommendations.append("No specific optimization recommendations at this time. Performance appears to be within acceptable limits.")

        return recommendations

    def detect_bottlenecks(self) -> List[str]:
        """Detects potential bottlenecks based on profiling data."""
        bottlenecks: List[str] = []
        average_metrics = self.get_average_metrics()

        # Flag whichever latency component strictly dominates the other two.
        if average_metrics.get("query_processing_time", 0) > max(average_metrics.get("retrieval_time", 0), average_metrics.get("generation_time", 0)):
            bottlenecks.append("Query processing is a potential bottleneck.")
        elif average_metrics.get("retrieval_time", 0) > max(average_metrics.get("query_processing_time", 0), average_metrics.get("generation_time", 0)):
            bottlenecks.append("Data retrieval is a potential bottleneck.")
        elif average_metrics.get("generation_time", 0) > max(average_metrics.get("query_processing_time", 0), average_metrics.get("retrieval_time", 0)):
            bottlenecks.append("Response generation is a potential bottleneck.")

        if average_metrics.get("memory_usage", 0) > 800:  # Example threshold
            bottlenecks.append("High memory usage is a potential bottleneck.")

        # Errors are tracked on the instance counter, not in the metrics dict
        # ("errors" is never a metrics key, so the old dict lookup was dead code).
        if self.errors > 0:
            bottlenecks.append("Errors are impacting performance; investigate error logs.")

        return bottlenecks

    def print_summary(self):
        """Prints the profiling summary, recommendations, and bottlenecks to the console."""
        summary = self.get_summary()
        print("Profiling Summary:")
        for key, value in summary.items():
            print(f"- {key}: {value}")

        recommendations = self.generate_optimization_recommendations()
        print("\nOptimization Recommendations:")
        for recommendation in recommendations:
            print(f"- {recommendation}")

        bottlenecks = self.detect_bottlenecks()
        print("\nPotential Bottlenecks:")
        if bottlenecks:
            for bottleneck in bottlenecks:
                print(f"- {bottleneck}")
        else:
            print("- No bottlenecks detected.")

    def save_profile(self):
        """Saves the current profile to the historical profiles and to disk."""
        summary = self.get_summary()
        self.historical_profiles.append(summary)
        self.save_historical_profiles()
        logging.info(f"Profile saved for: {self.function_name}")

    def load_historical_profiles(self):
        """Loads historical profiles from disk (missing/corrupt file yields an empty history)."""
        file_path = os.path.join(self.profile_dir, f"{self.function_name}_profiles.json")
        try:
            with open(file_path, "r") as f:
                self.historical_profiles = json.load(f)
            logging.info(f"Historical profiles loaded for: {self.function_name}")
        except FileNotFoundError:
            logging.info(f"No historical profiles found for: {self.function_name}")
            self.historical_profiles = []
        except json.JSONDecodeError:
            logging.warning(f"Error decoding JSON in {file_path}.  Starting with empty history.")
            self.historical_profiles = []

    def save_historical_profiles(self):
        """Saves historical profiles to disk as pretty-printed JSON."""
        file_path = os.path.join(self.profile_dir, f"{self.function_name}_profiles.json")
        with open(file_path, "w") as f:
            json.dump(self.historical_profiles, f, indent=4)

    def get_historical_trends(self, metric: str) -> Dict[str, List[Any]]:
        """
        Retrieves historical trends for a given metric.

        Args:
            metric (str): The metric to retrieve trends for.

        Returns:
            Dict[str, List[Any]]: {"timestamps": [...], "values": [...]} with
                one entry per historical profile; a profile that lacks the
                metric contributes a None value. Both lists always have the
                same length so they can be zipped or plotted together.
        """
        trends: List[Any] = []
        timestamps: List[float] = []  # Store timestamps for x-axis plotting

        for profile in self.historical_profiles:
            # Always append to BOTH lists so they stay index-aligned
            # (the old code skipped the timestamp for missing metrics,
            # which desynchronized the lists and crashed the printer).
            timestamps.append(profile["start_time"])  # Use start_time as timestamp
            if metric in profile["average_metrics"]:
                trends.append(profile["average_metrics"][metric])
            else:
                trends.append(None)  # Handle missing data

        return {"timestamps": timestamps, "values": trends}

    def print_historical_trends(self, metric: str):
        """Prints historical trends for a given metric, one timestamped line per profile."""
        trends = self.get_historical_trends(metric)
        print(f"\nHistorical Trends for {metric}:")
        for ts, value in zip(trends["timestamps"], trends["values"]):
            timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))
            print(f"- {timestamp}: {value}")


if __name__ == '__main__':
    # Demonstration of the PerformanceProfiler API.

    def my_function(n: int) -> int:
        """Simulated workload: sleep briefly, then sum the integers 0..n-1."""
        time.sleep(0.1)  # pretend to do real work
        return sum(range(n))

    profiler = PerformanceProfiler(function_name="my_function")

    # --- Profile a single call ---
    result = profiler.profile(my_function, 1000)
    print(f"Result of my_function: {result}")
    profiler.print_summary()
    profiler.save_profile()

    # --- Profile a batch of calls to measure throughput ---
    num_calls = 10
    profiler.start()
    for _ in range(num_calls):
        my_function(500)
        profiler.increment_query_count()  # each execution counts as one query
    profiler.stop()
    profiler.track_throughput()

    profiler.print_summary()
    profiler.save_profile()

    # --- Track individual latency components ---
    def another_function():
        t0 = time.time()
        time.sleep(0.05)
        retrieval = time.time() - t0

        t0 = time.time()
        time.sleep(0.03)
        generation = time.time() - t0

        profiler.track_latency(retrieval_time=retrieval, generation_time=generation)
        return "Done"

    profiler.profile(another_function)
    profiler.print_summary()
    profiler.save_profile()

    # --- Error handling during profiling ---
    def function_with_error():
        raise ValueError("Simulated error")

    try:
        profiler.profile(function_with_error)
    except ValueError:
        pass  # the simulated failure is expected; profiler counted it

    profiler.print_summary()
    profiler.save_profile()

    # --- Historical trend reporting ---
    profiler.print_historical_trends("queries_per_second")
    profiler.print_historical_trends("memory_usage")