import time
import threading
from prometheus_client import start_http_server, Counter, Histogram, Gauge
from collections import deque

class RWLMetricsCollector:
    """
    Collects and exposes metrics related to RWL (Reasoning, Writing, Learning) executions.

    Uses the Prometheus client library to expose an HTTP endpoint for metrics
    scraping. In addition to the exported counters/histogram/gauge, it keeps
    in-process sliding windows of recent durations and costs per model so that
    moving averages can be queried directly.
    """

    def __init__(self, port=8000, moving_average_window=100):
        """
        Initializes the metrics collector and starts the HTTP server.

        Args:
            port (int): The port number to expose the metrics endpoint.
            moving_average_window (int): The number of data points to use for
                calculating moving averages.
        """
        self.total_stories = Counter('rwl_total_stories', 'Total number of RWL stories executed')
        self.success_stories = Counter('rwl_success_stories', 'Number of successfully executed RWL stories')
        self.failed_stories = Counter('rwl_failed_stories', 'Number of failed RWL stories')

        # Buckets for story duration (in seconds) - tailored for expected durations.
        # BUGFIX: labelnames must be an iterable of label strings. The previous
        # ('model') is just the string 'model' (no trailing comma), which
        # prometheus_client iterates character by character, registering bogus
        # labels 'm','o','d','e','l' and making .labels(model=...) raise.
        self.story_duration = Histogram('rwl_story_duration_seconds', 'Duration of RWL story execution',
                                        buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0, float('inf')),
                                        labelnames=('model',))

        self.current_cost = Gauge('rwl_current_cost_usd', 'Current cost of RWL executions',
                                  labelnames=('model',))

        self.story_durations = {}  # model name -> deque of recent durations (sliding window)
        self.story_costs = {}      # model name -> deque of recent costs (sliding window)
        self.moving_average_window = moving_average_window
        # Protects the moving-average deques above; the Prometheus metric
        # objects are internally thread-safe and do not need this lock.
        self.lock = threading.Lock()

        print(f"Starting RWL Metrics Server on port {port}")
        start_http_server(port)  # Expose metrics on the /metrics endpoint

    def record_story(self, duration, success, model, cost):
        """
        Records the metrics for a single RWL story execution.

        Args:
            duration (float): The duration of the story execution in seconds.
            success (bool): Whether the story execution was successful.
            model (str): The model used for the story execution (e.g., 'gpt-3.5-turbo', 'gpt-4').
            cost (float): The cost of the story execution in USD.
        """
        with self.lock:  # Ensure thread safety when updating the sliding windows
            self.total_stories.inc()
            if success:
                self.success_stories.inc()
            else:
                self.failed_stories.inc()

            self.story_duration.labels(model=model).observe(duration)
            self.current_cost.labels(model=model).inc(cost)  # Accumulate cost for the model

            # Update the moving-average window for durations; deque(maxlen=...)
            # automatically evicts the oldest sample once the window is full.
            if model not in self.story_durations:
                self.story_durations[model] = deque(maxlen=self.moving_average_window)
            self.story_durations[model].append(duration)

            # Update the moving-average window for costs.
            if model not in self.story_costs:
                self.story_costs[model] = deque(maxlen=self.moving_average_window)
            self.story_costs[model].append(cost)

    def get_moving_average_duration(self, model):
        """
        Calculates the moving average duration for a given model.

        Args:
            model (str): The model to calculate the moving average for.

        Returns:
            float: The moving average duration in seconds, or None if no data is available.
        """
        with self.lock:
            window = self.story_durations.get(model)
            if window:
                return sum(window) / len(window)
            return None

    def get_moving_average_cost(self, model):
        """
        Calculates the moving average cost for a given model.

        Args:
            model (str): The model to calculate the moving average for.

        Returns:
            float: The moving average cost in USD, or None if no data is available.
        """
        with self.lock:
            window = self.story_costs.get(model)
            if window:
                return sum(window) / len(window)
            return None

    def reset_cost(self, model):
        """
        Resets the current cost for a specific model.

        Note: no lock needed — Gauge.set is thread-safe and this does not
        touch the moving-average deques.

        Args:
            model (str): The model to reset the cost for.
        """
        self.current_cost.labels(model=model).set(0)

if __name__ == '__main__':
    import random

    # Example usage: start on a non-default port so it does not clash with a
    # real deployment, then simulate a stream of story executions.
    metrics = RWLMetricsCollector(port=8001)

    models = ['gpt-3.5-turbo', 'gpt-4']

    try:
        while True:
            model = random.choice(models)
            duration = random.uniform(0.2, 5)   # Duration between 0.2 and 5 seconds
            success = random.random() > 0.1     # 90% success rate
            cost = random.uniform(0.001, 0.01)  # Cost between $0.001 and $0.01

            metrics.record_story(duration, success, model, cost)

            avg_duration = metrics.get_moving_average_duration(model)
            avg_cost = metrics.get_moving_average_cost(model)
            print(f"Recorded story: model={model}, duration={duration:.2f}s, success={success}, cost=${cost:.4f}")
            # Compare against None explicitly: the getters return "float or
            # None", and a legitimate average of 0.0 is falsy and would be
            # silently skipped by a bare truthiness test.
            if avg_duration is not None:
                print(f"  Moving average duration ({model}): {avg_duration:.2f}s")
            if avg_cost is not None:
                print(f"  Moving average cost ({model}): ${avg_cost:.4f}")

            time.sleep(random.uniform(0.1, 1))  # Simulate variable load

    except KeyboardInterrupt:
        print("Exiting simulation...")