"""
core/observability/decorators.py

Decorator patterns for automatic LLM tracing in Genesis agents.

Usage:
    @traced("agent_spawn")
    async def spawn_agent(prompt: str, model: str) -> str:
        ...

    @generation_tracked
    async def call_gemini(prompt: str) -> dict:
        response_text = ...  # completion text returned by the model API
        return {
            "model": "gemini-flash",
            "prompt": prompt,
            "completion": response_text,
            "usage": {"input": 120, "output": 40},
        }

VERIFICATION_STAMP
Story: OBS-003
Verified By: parallel-builder
Verified At: 2026-02-25
Tests: 21/21
Coverage: 100%
"""

from __future__ import annotations

import functools
import logging
import time
from typing import Any, Callable

# Module-level import so that tests can patch "core.observability.decorators.get_tracer"
# without the AttributeError that occurs when the name is only imported inside a closure.
from core.observability.langfuse_client import get_tracer  # noqa: F401 — re-exported for patching

logger = logging.getLogger(__name__)


def traced(operation_name: str) -> Callable:
    """
    Decorator that wraps an **async** function in a Langfuse trace.

    Creates a trace on entry, records a completion or error span on exit,
    and re-raises any exception from the wrapped function unchanged so
    calling code is unaffected.

    Tracing is best-effort. If the tracer cannot be obtained, the trace
    cannot be opened, or a span cannot be recorded, the failure is logged.
    The wrapped function's own result or exception is still delivered to
    the caller. A tracing failure never masks or replaces it.

    Parameters
    ----------
    operation_name : str
        Human-readable name passed to ``GenesisTracer.trace()``. It is also
        the prefix of the ``<operation_name>_complete`` and
        ``<operation_name>_error`` span names.

    Returns
    -------
    Callable
        A decorator producing an async function with the same signature as
        the function it wraps.

    Example
    -------
    ::

        @traced("rlm_ingest")
        async def ingest_document(doc_id: str) -> bool:
            ...
    """

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # ``get_tracer`` is resolved through this module's globals at
            # call time, which is exactly what
            # patch("core.observability.decorators.get_tracer") replaces.
            # No self-import is needed.
            try:
                tracer = get_tracer()
                trace_id = tracer.trace(
                    name=operation_name,
                    metadata={"function": fn.__name__, "args_count": len(args)},
                ).id
            except Exception:
                logger.warning(
                    "Could not open trace %r; running %s untraced",
                    operation_name,
                    fn.__name__,
                    exc_info=True,
                )
                tracer = None

            if tracer is None:
                # Called outside the ``except`` above so the function's own
                # exceptions are not chained onto the tracer failure.
                return await fn(*args, **kwargs)

            def record_span(suffix: str, **metadata: Any) -> None:
                # A tracing-backend failure must never replace the wrapped
                # function's result or exception, so it is only logged.
                try:
                    tracer.span(
                        trace_id=trace_id,
                        name=f"{operation_name}_{suffix}",
                        metadata=metadata,
                    )
                except Exception:
                    logger.warning(
                        "Failed to record span %s_%s",
                        operation_name,
                        suffix,
                        exc_info=True,
                    )

            start = time.monotonic()
            try:
                result = await fn(*args, **kwargs)
            except Exception as exc:
                record_span(
                    "error",
                    duration_s=round(time.monotonic() - start, 3),
                    status="error",
                    error=str(exc),
                )
                raise
            # Recorded outside the ``try`` so that a span failure is never
            # misreported as an error of the wrapped function.
            record_span(
                "complete",
                duration_s=round(time.monotonic() - start, 3),
                status="success",
            )
            return result

        return wrapper

    return decorator


def generation_tracked(fn: Callable) -> Callable:
    """
    Decorator for **async** LLM API call functions that records model,
    prompt, completion, and token usage.

    A trace named ``generation_<fn name>`` is opened on every call.

    - If the decorated function returns a ``dict`` containing a ``"model"``
      key, a generation is recorded from that dict.
    - Any other return value is passed through unchanged. Only a
      ``generation_<fn name>_complete`` span with the call duration is
      recorded.
    - If the function raises, a ``generation_<fn name>_error`` span is
      recorded and the exception is re-raised unchanged.

    Tracing is best-effort. Tracer failures are logged and never replace
    the decorated function's return value or exception.

    Expected dict shape
    -------------------
    ::

        {
            "model": "gemini-flash",
            "prompt": "...",          # str, list, or dict
            "completion": "...",      # str, list, or dict
            "usage": {"input": N, "output": N},  # optional
        }

    Parameters
    ----------
    fn : Callable
        Async function to wrap.

    Returns
    -------
    Callable
        Decorated async function with identical signature.
    """

    trace_name = f"generation_{fn.__name__}"

    @functools.wraps(fn)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        # ``get_tracer`` is resolved through this module's globals at call
        # time, so patch("core.observability.decorators.get_tracer") applies.
        try:
            tracer = get_tracer()
            trace_id = tracer.trace(name=trace_name).id
        except Exception:
            logger.warning(
                "Could not open trace %r; running untraced",
                trace_name,
                exc_info=True,
            )
            tracer = None

        if tracer is None:
            return await fn(*args, **kwargs)

        start = time.monotonic()
        try:
            result = await fn(*args, **kwargs)
        except Exception as exc:
            try:
                tracer.span(
                    trace_id=trace_id,
                    name=f"{trace_name}_error",
                    metadata={
                        "duration_s": round(time.monotonic() - start, 3),
                        "status": "error",
                        "error": str(exc),
                    },
                )
            except Exception:
                logger.warning(
                    "Failed to record error span for %s", trace_name, exc_info=True
                )
            # Re-raises the wrapped function's exception, not the tracer's.
            raise

        duration_s = round(time.monotonic() - start, 3)
        # Recording must not discard an already-successful (and possibly
        # billed) LLM call, so tracer failures here are only logged.
        try:
            if isinstance(result, dict) and "model" in result:
                tracer.generation(
                    trace_id=trace_id,
                    name=fn.__name__,
                    model=result["model"],
                    prompt=result.get("prompt", ""),
                    completion=result.get("completion", ""),
                    usage=result.get("usage", {}),
                    metadata={"duration_s": duration_s},
                )
            else:
                tracer.span(
                    trace_id=trace_id,
                    name=f"{trace_name}_complete",
                    metadata={"duration_s": duration_s, "status": "success"},
                )
        except Exception:
            logger.warning(
                "Failed to record generation for %s", trace_name, exc_info=True
            )

        return result

    return wrapper
