from typing import List, Tuple, Dict

class CostOptimizer:
    """Track per-story costs against a budget and suggest an agent count.

    Every recorded story is kept as a ``(story_id, agent_count, story_cost)``
    tuple. A running total of all story costs is also maintained, so budget
    checks and the mean cost per story never need to re-sum the history.
    """

    def __init__(self, budget_threshold: float):
        """Create an optimizer with an empty story history.

        Args:
            budget_threshold: Total cost above which ``should_pause`` reports True.
        """
        self.budget_threshold = budget_threshold
        # History of (story_id, agent_count, story_cost) in recording order.
        self.stories: List[Tuple[str, int, float]] = []
        # Running sum of every story_cost passed to record_story.
        self.total_cost = 0.0

    def record_story(self, story_id: str, agent_count: int, story_cost: float):
        """Append a story to the history and add its cost to the running total.

        Args:
            story_id: Identifier of the story; stored but not interpreted here.
            agent_count: Number of agents used for the story; used as the
                grouping key by ``recommend_agent_count``.
            story_cost: Cost of the story.
        """
        self.stories.append((story_id, agent_count, story_cost))
        self.total_cost += story_cost

    def get_cost_per_story(self) -> float:
        """Return the mean cost per recorded story.

        Returns:
            ``total_cost`` divided by the number of recorded stories, or 0.0
            when no story has been recorded (avoids division by zero).
        """
        if not self.stories:
            return 0.0
        return self.total_cost / len(self.stories)

    def recommend_agent_count(self) -> int:
        """Return the agent count with the lowest historical mean cost per story.

        Stories are grouped by ``agent_count`` and each group's mean cost is
        compared. On a tie, the agent count that was recorded first wins.
        With no history, 1 is returned.

        NOTE(review): this minimizes cost alone; it does not weigh throughput
        or how many samples back each group's mean.
        """
        if not self.stories:
            return 1

        # Single pass: accumulate per-agent-count cost sums and story counts
        # instead of materializing a list of costs for every group. Dict
        # insertion order preserves first-seen order, which min() relies on
        # for tie-breaking.
        cost_sums: Dict[int, float] = {}
        story_counts: Dict[int, int] = {}
        for _, agent_count, cost in self.stories:
            cost_sums[agent_count] = cost_sums.get(agent_count, 0.0) + cost
            story_counts[agent_count] = story_counts.get(agent_count, 0) + 1

        return min(cost_sums, key=lambda count: cost_sums[count] / story_counts[count])

    def should_pause(self) -> bool:
        """Return True when the total cost strictly exceeds the budget threshold.

        A total exactly equal to ``budget_threshold`` does not trigger a pause.
        """
        return self.total_cost > self.budget_threshold

    def get_current_metrics(self) -> Dict[str, float]:
        """Return a snapshot of the current cost metrics for monitoring.

        Returns:
            A dict with ``total_cost``, ``cost_per_story`` (see
            ``get_cost_per_story``) and ``recommended_agents`` (see
            ``recommend_agent_count``). ``recommended_agents`` holds an int,
            which the ``float`` annotation accepts.
        """
        return {
            "total_cost": self.total_cost,
            "cost_per_story": self.get_cost_per_story(),
            "recommended_agents": self.recommend_agent_count(),
        }