#!/usr/bin/env python3
"""
Genesis Cost Tracker - Real-time API usage and cost monitoring.

Tracks costs across all Genesis features and API providers.
"""

"""
RULE 7 COMPLIANT: Uses Elestio PostgreSQL via genesis_db module.
"""
import json
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field, asdict
import threading
import logging

# RULE 7: Use PostgreSQL via genesis_db (no sqlite3)
from core.genesis_db import connection, ensure_table
from psycopg2.extras import RealDictCursor

logger = logging.getLogger(__name__)

# Cost data (per 1M tokens)
#
# Layout: provider -> model -> rate table. The rate-table keys are the ones
# GenesisCostTracker.calculate_cost() understands:
#   "input" / "output" - price per 1,000,000 input / output tokens
#   "per_call"         - flat price per API call
#   "per_minute"       - price per minute of usage
# A model only lists the keys that apply to it; absent keys contribute no cost.
API_PRICING = {
    "claude": {
        "opus-4-5": {"input": 5.00, "output": 25.00},
        "sonnet-4-5": {"input": 3.00, "output": 15.00},
        "haiku-4-5": {"input": 1.00, "output": 5.00},
        "sonnet-4": {"input": 3.00, "output": 15.00},
        "opus-4": {"input": 15.00, "output": 75.00},
    },
    "gemini": {
        "2.0-flash": {"input": 0.10, "output": 0.40},
        "2.0-flash-thinking": {"input": 0.70, "output": 3.00},
        "2.5-pro": {"input": 1.25, "output": 10.00},
        "2.5-pro-preview": {"input": 2.50, "output": 15.00},
        "3-pro": {"input": 2.00, "output": 12.00},
        "3-pro-preview": {"input": 4.00, "output": 18.00},
    },
    "supadata": {
        "youtube-transcript": {"per_call": 0.009},  # $9/1000 calls
    },
    "openai": {
        "gpt-4o": {"input": 2.50, "output": 10.00},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60},
        "whisper": {"per_minute": 0.006},  # $0.006/minute
    }
}

# Feature to API mapping
#
# Layout: feature name -> descriptor with:
#   "provider" / "model" - keys into API_PRICING used to price the feature
#   "description"        - human-readable label shown by the CLI
#   "typical_tokens"     - {"input": n, "output": n} for token-priced features
#   "typical_calls"      - call count for per-call-priced features
#   "typical_minutes"    - duration for per-minute-priced features
# The "typical_*" values feed GenesisCostTracker.estimate_feature_cost();
# actual usage is supplied by callers of record_usage().
FEATURE_API_MAP = {
    "claude_code_session": {
        "provider": "claude",
        "model": "opus-4-5",
        "description": "Claude Code terminal sessions",
        "typical_tokens": {"input": 50000, "output": 10000}
    },
    "aiva_chat": {
        "provider": "claude",
        "model": "sonnet-4-5",
        "description": "AIVA chat interface",
        "typical_tokens": {"input": 5000, "output": 2000}
    },
    "youtube_transcript": {
        "provider": "supadata",
        "model": "youtube-transcript",
        "description": "YouTube transcript extraction via Supadata",
        "typical_calls": 1
    },
    "knowledge_ingestion": {
        "provider": "claude",
        "model": "haiku-4-5",
        "description": "Knowledge graph processing",
        "typical_tokens": {"input": 20000, "output": 5000}
    },
    "gemini_research": {
        "provider": "gemini",
        "model": "2.5-pro",
        "description": "Deep research via Gemini",
        "typical_tokens": {"input": 100000, "output": 20000}
    },
    "multi_agent_orchestration": {
        "provider": "claude",
        "model": "opus-4-5",
        "description": "Multi-agent task coordination",
        "typical_tokens": {"input": 30000, "output": 15000}
    },
    "voice_transcription": {
        "provider": "openai",
        "model": "whisper",
        "description": "Voice-to-text via Whisper API",
        "typical_minutes": 5
    },
    "patent_analysis": {
        "provider": "claude",
        "model": "opus-4-5",
        "description": "Patent document analysis",
        "typical_tokens": {"input": 80000, "output": 20000}
    },
    "code_review": {
        "provider": "claude",
        "model": "sonnet-4-5",
        "description": "Code review and suggestions",
        "typical_tokens": {"input": 15000, "output": 5000}
    },
    "web_search": {
        "provider": "claude",
        "model": "sonnet-4-5",
        "description": "Web search and analysis",
        "typical_tokens": {"input": 10000, "output": 3000}
    }
}


@dataclass
class UsageRecord:
    """Single API usage record.

    One instance describes one usage event: which feature triggered it, which
    provider/model served it, how much was consumed and what it cost. The
    fields mirror the columns of the ``cost_usage`` table.
    """
    # ISO-8601 timestamp string of when the usage was recorded
    timestamp: str
    # Feature name (normally a key of FEATURE_API_MAP)
    feature: str
    # Provider / model the feature maps to; "unknown" for unmapped features
    provider: str
    model: str
    # Consumption counters; only those relevant to the pricing model are non-zero
    input_tokens: int = 0
    output_tokens: int = 0
    api_calls: int = 0
    minutes: float = 0.0
    # Computed cost of this event in US dollars
    cost_usd: float = 0.0
    # Free-form caller context; serialized with json.dumps when persisted
    metadata: Dict[str, Any] = field(default_factory=dict)


class GenesisCostTracker:
    """Track and analyze Genesis system costs.

    Prices usage events from the ``API_PRICING`` table, attributes them to a
    provider/model through ``FEATURE_API_MAP`` and persists them in the
    PostgreSQL ``cost_usage`` table via ``core.genesis_db`` (RULE 7).
    """

    # Aggregate queries. Each takes exactly one parameter: the cutoff timestamp.
    _FEATURE_COSTS_SQL = """
        SELECT feature, SUM(cost_usd) as total
        FROM cost_usage
        WHERE timestamp >= %s
        GROUP BY feature
        ORDER BY total DESC
    """
    _PROVIDER_COSTS_SQL = """
        SELECT provider, SUM(cost_usd) as total
        FROM cost_usage
        WHERE timestamp >= %s
        GROUP BY provider
        ORDER BY total DESC
    """
    _DAILY_COSTS_SQL = """
        SELECT DATE(timestamp) as day, SUM(cost_usd) as total
        FROM cost_usage
        WHERE timestamp >= %s
        GROUP BY DATE(timestamp)
        ORDER BY day DESC
    """

    def __init__(self, db_path: Optional[str] = None):
        """Load the pricing tables and make sure the storage schema exists.

        Args:
            db_path: Ignored; storage always goes through genesis_db.
        """
        # RULE 7: db_path is ignored - uses PostgreSQL via genesis_db
        self.pricing = API_PRICING
        self.feature_map = FEATURE_API_MAP
        # Create the lock before touching the database so every attribute
        # exists before any other method can run.
        self._lock = threading.Lock()
        self._init_db()

    def _init_db(self):
        """Initialize PostgreSQL database for cost tracking (RULE 7).

        Ensures the ``cost_usage`` table exists, then creates its lookup
        indexes. Index creation is best-effort: a failure is logged as a
        warning and does not abort construction.
        """
        ensure_table('cost_usage', '''
            id SERIAL PRIMARY KEY,
            timestamp TIMESTAMPTZ NOT NULL,
            feature TEXT NOT NULL,
            provider TEXT NOT NULL,
            model TEXT NOT NULL,
            input_tokens INTEGER DEFAULT 0,
            output_tokens INTEGER DEFAULT 0,
            api_calls INTEGER DEFAULT 0,
            minutes REAL DEFAULT 0.0,
            cost_usd REAL NOT NULL,
            metadata JSONB
        ''')
        # Create indexes
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_cost_timestamp ON cost_usage(timestamp)")
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_cost_feature ON cost_usage(feature)")
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_cost_provider ON cost_usage(provider)")
        except Exception as e:
            logger.warning("Index creation warning: %s", e)

    def calculate_cost(
        self,
        provider: str,
        model: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
        api_calls: int = 0,
        minutes: float = 0.0
    ) -> float:
        """Calculate cost for a given API usage.

        Args:
            provider: Provider key in the pricing table.
            model: Model key under that provider.
            input_tokens: Input tokens consumed (priced per 1M).
            output_tokens: Output tokens consumed (priced per 1M).
            api_calls: Number of calls, for per-call pricing.
            minutes: Duration in minutes, for per-minute pricing.

        Returns:
            Cost in USD rounded to 6 decimals. Unknown providers or models,
            and quantities that are ``None``, zero or negative, contribute 0.
        """
        model_pricing = self.pricing.get(provider, {}).get(model, {})

        # Billable quantity per pricing key. ``None`` is treated as "no usage"
        # so callers can pass through usage fields that were never populated.
        quantities = {
            "input": (input_tokens or 0) / 1_000_000,
            "output": (output_tokens or 0) / 1_000_000,
            "per_call": api_calls or 0,
            "per_minute": minutes or 0.0,
        }

        cost = 0.0
        for key, quantity in quantities.items():
            if quantity > 0 and key in model_pricing:
                cost += quantity * model_pricing[key]

        return round(cost, 6)

    def record_usage(
        self,
        feature: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
        api_calls: int = 0,
        minutes: float = 0.0,
        metadata: Optional[Dict] = None
    ) -> "UsageRecord":
        """Record a usage event.

        The feature is resolved to a provider/model through the feature map
        (unmapped features are recorded as ``unknown``/``unknown`` and cost
        0), priced, and inserted into ``cost_usage``.

        Args:
            feature: Feature name that produced the usage.
            input_tokens: Input tokens consumed.
            output_tokens: Output tokens consumed.
            api_calls: Number of API calls made.
            minutes: Minutes of usage.
            metadata: Optional JSON-serializable context stored with the row.

        Returns:
            The UsageRecord that was built. Persistence is best-effort: if the
            insert fails, a warning is logged and the record is still returned.
        """
        feature_info = self.feature_map.get(feature, {})
        provider = feature_info.get("provider", "unknown")
        model = feature_info.get("model", "unknown")

        cost = self.calculate_cost(
            provider, model,
            input_tokens, output_tokens,
            api_calls, minutes
        )

        record = UsageRecord(
            # NOTE(review): naive local timestamp going into a TIMESTAMPTZ
            # column; the database will interpret it in the session time zone.
            # Confirm that matches the application's local zone.
            timestamp=datetime.now().isoformat(),
            feature=feature,
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            api_calls=api_calls,
            minutes=minutes,
            cost_usd=cost,
            metadata=metadata or {}
        )

        with self._lock:
            try:
                with connection() as conn:
                    cursor = conn.cursor()
                    cursor.execute("""
                        INSERT INTO cost_usage
                        (timestamp, feature, provider, model, input_tokens, output_tokens,
                         api_calls, minutes, cost_usd, metadata)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    """, (
                        record.timestamp, record.feature, record.provider, record.model,
                        record.input_tokens, record.output_tokens, record.api_calls,
                        record.minutes, record.cost_usd, json.dumps(record.metadata)
                    ))
            except Exception as e:
                # Cost tracking must never break the feature being tracked.
                logger.warning("Failed to record usage: %s", e)

        return record

    @staticmethod
    def _cutoff(days: int) -> str:
        """Return the ISO timestamp marking the start of the last ``days`` days."""
        return (datetime.now() - timedelta(days=days)).isoformat()

    def _cost_breakdown(self, sql: str, days: int, what: str) -> Dict[str, float]:
        """Run a two-column (key, total) aggregate query and return it as a dict.

        Args:
            sql: Statement with a single ``%s`` placeholder for the cutoff.
            days: Look-back window in days.
            what: Label used in the warning logged on failure.

        Returns:
            Mapping of stringified key to total rounded to 4 decimals, in
            query order; ``{}`` if anything fails (the failure is logged).
        """
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute(sql, (self._cutoff(days),))
                # str() is a no-op for the TEXT feature/provider keys and
                # converts the DATE key of the daily query.
                return {str(row[0]): round(row[1], 4) for row in cursor.fetchall()}
        except Exception as e:
            logger.warning("Failed to get %s: %s", what, e)
            return {}

    def get_feature_costs(self, days: int = 30) -> Dict[str, float]:
        """Get costs broken down by feature (RULE 7: PostgreSQL).

        Returns a mapping of feature name to total USD over the last ``days``
        days, most expensive first; ``{}`` on database failure.
        """
        return self._cost_breakdown(self._FEATURE_COSTS_SQL, days, "feature costs")

    def get_provider_costs(self, days: int = 30) -> Dict[str, float]:
        """Get costs broken down by provider (RULE 7: PostgreSQL).

        Returns a mapping of provider name to total USD over the last ``days``
        days, most expensive first; ``{}`` on database failure.
        """
        return self._cost_breakdown(self._PROVIDER_COSTS_SQL, days, "provider costs")

    def get_daily_costs(self, days: int = 30) -> Dict[str, float]:
        """Get daily cost totals (RULE 7: PostgreSQL).

        Returns a mapping of date string to total USD over the last ``days``
        days, newest day first; ``{}`` on database failure.
        """
        return self._cost_breakdown(self._DAILY_COSTS_SQL, days, "daily costs")

    def get_total_cost(self, days: int = 30) -> float:
        """Get total cost for time period (RULE 7: PostgreSQL).

        Returns the summed USD cost over the last ``days`` days rounded to 4
        decimals; ``0.0`` when there are no rows or the query fails.
        """
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    SELECT SUM(cost_usd) FROM cost_usage WHERE timestamp >= %s
                """, (self._cutoff(days),))
                result = cursor.fetchone()[0]
                # SUM over zero rows is NULL; fall back to a float so the
                # return type is float on every path.
                return round(result or 0.0, 4)
        except Exception as e:
            logger.warning("Failed to get total cost: %s", e)
            return 0.0

    def estimate_feature_cost(self, feature: str, count: int = 1) -> Dict[str, Any]:
        """Estimate cost for running a feature N times.

        Args:
            feature: Feature name to estimate.
            count: Number of runs to estimate for.

        Returns:
            ``{"error": ...}`` for an unknown feature; otherwise a dict with
            the feature's description, provider, model, ``single_use_cost``,
            ``total_cost`` (both rounded to 4 decimals) and ``count``.
        """
        if feature not in self.feature_map:
            return {"error": f"Unknown feature: {feature}"}

        info = self.feature_map[feature]
        provider = info["provider"]
        model = info["model"]

        # Feed every declared "typical_*" quantity into the price calculation
        # so a feature that combines pricing modes (e.g. tokens plus per-call
        # fees) is estimated in full rather than by its first mode only.
        typical_tokens = info.get("typical_tokens", {})
        single_cost = self.calculate_cost(
            provider, model,
            input_tokens=typical_tokens.get("input", 0),
            output_tokens=typical_tokens.get("output", 0),
            api_calls=info.get("typical_calls", 0),
            minutes=info.get("typical_minutes", 0.0),
        )

        return {
            "feature": feature,
            "description": info.get("description", ""),
            "provider": provider,
            "model": model,
            "single_use_cost": round(single_cost, 4),
            "total_cost": round(single_cost * count, 4),
            "count": count
        }

    def get_all_feature_estimates(self) -> Dict[str, Dict[str, Any]]:
        """Get single-run cost estimates for every feature in the feature map."""
        return {
            feature: self.estimate_feature_cost(feature)
            for feature in self.feature_map
        }

    def generate_cost_report(self, days: int = 30) -> str:
        """Generate a comprehensive cost report.

        Args:
            days: Reporting period in days.

        Returns:
            Multi-line text with the total, per-provider and per-feature
            costs for the period, a daily breakdown of up to the last 7 days,
            and the estimated per-use cost of every feature.
        """
        rule = "=" * 60
        report = [
            rule,
            "GENESIS COST REPORT",
            f"Period: Last {days} days",
            f"Generated: {datetime.now().isoformat()}",
            rule,
        ]

        # Total
        report.append(f"\n💰 TOTAL COST: ${self.get_total_cost(days):.4f}")

        # By Provider
        report.append("\n📊 BY PROVIDER:")
        report.extend(
            f"  {provider}: ${cost:.4f}"
            for provider, cost in self.get_provider_costs(days).items()
        )

        # By Feature
        report.append("\n🔧 BY FEATURE:")
        report.extend(
            f"  {feature}: ${cost:.4f}"
            for feature, cost in self.get_feature_costs(days).items()
        )

        # Daily breakdown: capped at a week, but never wider than the
        # reporting period itself.
        daily_days = min(days, 7)
        report.append(f"\n📅 DAILY COSTS (last {daily_days} days):")
        report.extend(
            f"  {day}: ${cost:.4f}"
            for day, cost in self.get_daily_costs(daily_days).items()
        )

        # Estimated costs
        report.append("\n📈 ESTIMATED PER-USE COSTS:")
        report.extend(
            f"  {feature}: ${est['single_use_cost']:.4f}"
            for feature, est in self.get_all_feature_estimates().items()
        )

        report.append("\n" + rule)
        return "\n".join(report)


# Singleton instance, created lazily by get_tracker()
_tracker_instance = None
# Serializes first-time construction so concurrent callers share one tracker
_tracker_lock = threading.Lock()


def get_tracker() -> GenesisCostTracker:
    """Get the global cost tracker instance.

    The tracker is built on first use and reused afterwards. Construction is
    guarded by a lock so that two threads racing on the first call cannot each
    build their own tracker (and run the schema setup twice).

    Returns:
        The process-wide GenesisCostTracker.
    """
    global _tracker_instance
    if _tracker_instance is None:
        with _tracker_lock:
            # Re-check under the lock: another thread may have finished
            # construction while this one was waiting.
            if _tracker_instance is None:
                _tracker_instance = GenesisCostTracker()
    return _tracker_instance


def record(feature: str, **kwargs) -> UsageRecord:
    """Record usage for ``feature`` on the shared global tracker.

    Keyword arguments are forwarded unchanged to
    ``GenesisCostTracker.record_usage``; its UsageRecord is returned.
    """
    shared_tracker = get_tracker()
    return shared_tracker.record_usage(feature, **kwargs)


if __name__ == "__main__":
    import sys

    tracker = GenesisCostTracker()

    if len(sys.argv) > 1:
        cmd = sys.argv[1]

        if cmd == "report":
            days = int(sys.argv[2]) if len(sys.argv) > 2 else 30
            print(tracker.generate_cost_report(days))

        elif cmd == "estimate":
            feature = sys.argv[2] if len(sys.argv) > 2 else None
            if feature:
                est = tracker.estimate_feature_cost(feature)
                print(json.dumps(est, indent=2))
            else:
                for f, e in tracker.get_all_feature_estimates().items():
                    print(f"{f}: ${e['single_use_cost']:.4f}")

        elif cmd == "total":
            days = int(sys.argv[2]) if len(sys.argv) > 2 else 30
            print(f"Total cost (last {days} days): ${tracker.get_total_cost(days):.4f}")

        elif cmd == "features":
            for feature, info in FEATURE_API_MAP.items():
                est = tracker.estimate_feature_cost(feature)
                print(f"\n{feature}:")
                print(f"  Description: {info.get('description', 'N/A')}")
                print(f"  Provider: {info['provider']}")
                print(f"  Model: {info['model']}")
                print(f"  Est. cost per use: ${est['single_use_cost']:.4f}")

    else:
        print("Genesis Cost Tracker")
        print("\nCommands:")
        print("  report [days]     - Generate cost report")
        print("  estimate [feature] - Estimate feature costs")
        print("  total [days]      - Get total cost")
        print("  features          - List all features with costs")
        print("\nEstimated per-use costs:")
        for f, e in tracker.get_all_feature_estimates().items():
            print(f"  {f}: ${e['single_use_cost']:.4f}")
