"""
core/observability/cost_tracker.py

LLM cost attribution and tracking for Genesis observability layer.

This module complements ``core/cost_tracker_v2.py`` (which gates execution
against budget limits) by providing Langfuse-aligned session/agent/customer
cost attribution and a JSONL audit trail.

Pricing is per 1M tokens (industry-standard notation).

Usage:
    from core.observability.cost_tracker import CostTracker

    tracker = CostTracker()
    cost = tracker.record(
        model="gemini-flash",
        input_tokens=1200,
        output_tokens=400,
        session_id="sess-abc",
        agent_id="scout-01",
        customer_id="cust-xyz",
    )
    print(f"Cost: ${cost:.6f}")
    print(tracker.get_cost_summary())

VERIFICATION_STAMP
Story: OBS-004
Verified By: parallel-builder
Verified At: 2026-02-25
Tests: 21/21
Coverage: 100%
"""

from __future__ import annotations

import json
import logging
import os
from collections import defaultdict
from datetime import datetime, timezone
from typing import Optional

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Pricing table — USD per 1M tokens, as of 2026-02
# Keep in sync with core/cost_tracker_v2.py MODEL_COSTS where possible.
#
# Keys must be lower-case: _resolve_pricing() lower-cases the requested model
# name before looking it up, first by exact key and then by substring, where
# the longest matching key wins.  Declaration order only matters as the
# tie-breaker between equally long matching keys.
#
# NOTE(review): several rates (e.g. claude-haiku-4-5, claude-opus-4-5,
# gemini-2.5-flash) look like they were copied from older model tiers —
# confirm against the providers' current price sheets before relying on them
# for billing.
# ---------------------------------------------------------------------------

MODEL_PRICING: dict[str, dict[str, float]] = {
    # Anthropic
    "claude-opus-4-6":    {"input": 15.0,  "output": 75.0},
    "claude-sonnet-4-6":  {"input": 3.0,   "output": 15.0},
    "claude-haiku-4-5":   {"input": 0.80,  "output": 4.0},
    "claude-opus-4-5":    {"input": 15.0,  "output": 75.0},
    "claude-sonnet-4":    {"input": 3.0,   "output": 15.0},
    "claude-haiku":       {"input": 0.80,  "output": 4.0},
    # Google Gemini
    "gemini-2.5-pro":     {"input": 1.25,  "output": 10.0},
    "gemini-2.5-flash":   {"input": 0.075, "output": 0.30},
    "gemini-2.0-flash":   {"input": 0.075, "output": 0.30},
    "gemini-1.5-pro":     {"input": 1.25,  "output": 5.0},
    "gemini-1.5-flash":   {"input": 0.075, "output": 0.30},
    "gemini-pro":         {"input": 1.25,  "output": 5.0},
    "gemini-flash":       {"input": 0.075, "output": 0.30},
    # Embeddings — no output charge; the first two are priced at zero, the
    # OpenAI-style "text-embedding-3-*" entries carry a small input rate.
    "text-embedding-004":         {"input": 0.0,  "output": 0.0},
    "nomic-embed-text":           {"input": 0.0,  "output": 0.0},
    "text-embedding-3-small":     {"input": 0.02, "output": 0.0},
    "text-embedding-3-large":     {"input": 0.13, "output": 0.0},
}

# Fallback rates for models that match no MODEL_PRICING key.  Deliberately
# non-zero so that unknown models do not silently count as free.
_DEFAULT_PRICING: dict[str, float] = {"input": 1.0, "output": 5.0}

# Default JSONL audit-log location — E: drive per Rule 6.
# NOTE(review): "/mnt/e" looks like a WSL mount of the E: drive — confirm the
# path exists on every host that runs this module; CostTracker accepts a
# different path through its ``log_path`` argument.
COST_LOG_PATH = "/mnt/e/genesis-system/data/observability/cost_log.jsonl"


def _resolve_pricing(model: str) -> dict[str, float]:
    """
    Return the pricing dict for a model name.

    Resolution order, applied to the lower-cased and stripped name:

    1. Exact key in ``MODEL_PRICING``.
    2. Substring match in either direction, so provider prefixes and
       version suffixes (``"openrouter/gemini-flash"``) as well as
       abbreviated names (``"gemini"``) resolve.  The longest matching key
       wins; equally long keys are resolved in table declaration order.
    3. ``_DEFAULT_PRICING``, with a warning logged.

    An empty or whitespace-only name skips steps 1-2.  An empty string is a
    substring of every key, so matching it would pick an arbitrary (and
    near-free) table entry instead of flagging the model as unknown.

    Parameters
    ----------
    model : str
        Model name as reported by the caller.

    Returns
    -------
    dict[str, float]
        Mapping with ``"input"`` and ``"output"`` rates in USD per 1M tokens.
        This is the shared table entry, not a copy — do not mutate it.
    """
    key = model.lower().strip()

    if key:
        # Exact match first (fastest path)
        exact = MODEL_PRICING.get(key)
        if exact is not None:
            return exact

        # Substring match — longer keys win (more specific).  max() returns
        # the first of several equally long keys, i.e. declaration order.
        matches = [k for k in MODEL_PRICING if k in key or key in k]
        if matches:
            return MODEL_PRICING[max(matches, key=len)]

    logger.warning(
        "CostTracker: unknown model '%s' — using default pricing %s",
        model,
        _DEFAULT_PRICING,
    )
    return _DEFAULT_PRICING


class CostTracker:
    """
    Tracks LLM API costs with per-session, per-agent, and per-customer attribution.

    Every call to :meth:`record` prices one generation, adds the cost to the
    in-memory attribution buckets and appends one line to a JSONL audit file.
    All totals returned by the query helpers are in-memory and scoped to this
    instance; nothing is read back from the audit file.

    Instances are lightweight and can be created per-request or shared as a
    singleton via ``get_cost_tracker()``.

    Parameters
    ----------
    log_path : str
        Path to the JSONL audit file.  Parent directory is created on first
        write; a bare filename is written to the current working directory.
    """

    def __init__(self, log_path: str = COST_LOG_PATH) -> None:
        self.log_path = log_path
        # defaultdict avoids KeyError and lets callers query arbitrary ids
        self._session_costs: dict[str, float] = defaultdict(float)
        self._agent_costs: dict[str, float] = defaultdict(float)
        self._customer_costs: dict[str, float] = defaultdict(float)
        # Running total of every recorded cost.  Kept separately from the
        # buckets because all three ids are optional, so no single bucket is
        # guaranteed to see every record.
        self._total_cost: float = 0.0

    # ------------------------------------------------------------------
    # Core record method
    # ------------------------------------------------------------------

    def record(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int,
        session_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        customer_id: Optional[str] = None,
    ) -> float:
        """
        Record one generation's cost and append to the JSONL audit log.

        Parameters
        ----------
        model : str
            Model name (e.g. ``"gemini-flash"``).
        input_tokens : int
            Number of prompt/input tokens consumed.
        output_tokens : int
            Number of completion/output tokens produced.
        session_id : str, optional
            Session identifier (maps to Langfuse session).
        agent_id : str, optional
            Genesis agent identifier.
        customer_id : str, optional
            Customer / SubAIVA identifier for billing attribution.

        Returns
        -------
        float
            Computed cost in USD (rounded to 6 decimal places).

        Notes
        -----
        The cost is rounded *before* it is accumulated, so a single call
        cheaper than half a micro-dollar contributes ``0.0`` to every total.
        Ids that are ``None`` or empty are not attributed to any bucket but
        the record still counts towards :meth:`get_daily_total` and is still
        written to the audit log.
        """
        pricing = _resolve_pricing(model)
        cost = round(
            (input_tokens / 1_000_000) * pricing["input"]
            + (output_tokens / 1_000_000) * pricing["output"],
            6,
        )

        # Grand total first: it must include records with no ids at all.
        self._total_cost += cost

        # Accumulate into attribution buckets
        if session_id:
            self._session_costs[session_id] += cost
        if agent_id:
            self._agent_costs[agent_id] += cost
        if customer_id:
            self._customer_costs[customer_id] += cost

        # Write JSONL entry
        entry: dict = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "model": model,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost_usd": cost,
            "session_id": session_id,
            "agent_id": agent_id,
            "customer_id": customer_id,
        }
        self._append_log(entry)

        return cost

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------

    def get_session_cost(self, session_id: str) -> float:
        """Return total cost accumulated for *session_id* (0.0 if never seen)."""
        # .get() rather than indexing so that a query does not insert a key
        # into the defaultdict and pollute get_cost_summary().
        return self._session_costs.get(session_id, 0.0)

    def get_agent_cost(self, agent_id: str) -> float:
        """Return total cost accumulated for *agent_id* (0.0 if never seen)."""
        return self._agent_costs.get(agent_id, 0.0)

    def get_customer_cost(self, customer_id: str) -> float:
        """Return total cost accumulated for *customer_id* (0.0 if never seen)."""
        return self._customer_costs.get(customer_id, 0.0)

    def get_daily_total(self) -> float:
        """
        Return the total cost of every record made through this instance.

        Records without a ``session_id`` are included, so the figure can be
        larger than the sum of the per-session costs.

        Note: this reflects in-memory state only; it does not read from the
        JSONL file and is not reset at day boundaries.  For persistent
        totals, read ``cost_log.jsonl`` directly.

        Returns
        -------
        float
            Total cost in USD, rounded to 6 decimal places.
        """
        return round(self._total_cost, 6)

    def get_cost_summary(self) -> dict:
        """
        Return a summary dict with all attribution breakdowns.

        Returns
        -------
        dict
            Keys: ``daily_total_usd`` (rounded to 4 decimal places),
            ``sessions``, ``agents``, ``customers`` (plain-dict copies of the
            per-id totals, unrounded).
        """
        return {
            "daily_total_usd": round(self.get_daily_total(), 4),
            "sessions": dict(self._session_costs),
            "agents": dict(self._agent_costs),
            "customers": dict(self._customer_costs),
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _append_log(self, entry: dict) -> None:
        """
        Append a JSONL record to the audit file, creating dirs as needed.

        Writing is best-effort: an ``OSError`` is logged with its traceback
        and swallowed so that a broken audit log never fails the caller.
        """
        try:
            parent = os.path.dirname(self.log_path)
            # A bare filename has no parent component; os.makedirs("") would
            # raise FileNotFoundError and the record would never be written.
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(self.log_path, "a", encoding="utf-8") as fh:
                fh.write(json.dumps(entry) + "\n")
        except OSError:
            logger.exception("CostTracker: failed to write cost log at %s", self.log_path)


# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------

_cost_tracker: Optional[CostTracker] = None


def get_cost_tracker() -> CostTracker:
    """
    Return the shared module-level ``CostTracker``.

    The instance is built with default arguments on the first call and the
    same object is handed back on every later call.
    """
    global _cost_tracker  # noqa: PLW0603
    if _cost_tracker is not None:
        return _cost_tracker
    _cost_tracker = CostTracker()
    return _cost_tracker
