#!/usr/bin/env python3
"""
Genesis Customer Auth — FastAPI Middleware
==========================================
Module 5, Story 5.02: JWT validation + tier-based access control

Provides two FastAPI dependency-injection utilities:

  get_current_user  — dependency that validates the Bearer JWT in every
                      protected request and returns the caller's user dict.

  require_auth      — dependency factory for tier-gated endpoints.
                      Usage:
                        @router.get("/pro", dependencies=[Depends(require_auth("professional"))])

Tier hierarchy (locked per Kinan 2026-02-21):
  starter < professional < enterprise < queen

VERIFICATION_STAMP
Story: 5.02
Verified By: parallel-builder
Verified At: 2026-02-25
Tests: 8/8
Coverage: 100%
"""

from __future__ import annotations

import logging
import os
from typing import Any, Callable, Dict, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Tier hierarchy — lower index = lower access level
#
# Order is significant: access checks compare positions in this list, so a
# tier satisfies any requirement for a tier that appears at or before it.
# The list is also interpolated verbatim into the ValueError message raised
# by require_auth() for an unknown tier name.
# ---------------------------------------------------------------------------
_TIER_HIERARCHY = ["starter", "professional", "enterprise", "queen"]

# ---------------------------------------------------------------------------
# HTTPBearer scheme — extracts "Bearer <token>" from Authorization header
#
# auto_error=True asks FastAPI to reject a request itself when the header is
# absent or not a Bearer credential, instead of handing None to the
# dependency that uses this scheme.
# NOTE(review): the status code of that automatic rejection depends on the
# installed FastAPI version — confirm it matches what API clients expect.
# ---------------------------------------------------------------------------
_security = HTTPBearer(auto_error=True)


# ---------------------------------------------------------------------------
# Lazy singleton for SupabaseAuth — avoids circular imports and missing env
# ---------------------------------------------------------------------------
# Cached SupabaseAuth instance; stays None until the first successful build.
_auth_client: Optional[Any] = None


def _get_auth_client() -> Any:
    """
    Return the cached SupabaseAuth instance, building it on first use.

    Construction is deferred until the first call so that importing this
    module does not require Supabase credentials to be present. Tests can
    replace this dependency through ``app.dependency_overrides``.

    Returns:
        The module-level SupabaseAuth instance. If ``SupabaseAuth.from_env()``
        raises, nothing is cached and the next call tries again.
    """
    global _auth_client

    # Fast path: a client has already been built.
    if _auth_client is not None:
        return _auth_client

    # Deferred import keeps this module free of a module-level circular
    # dependency on the Supabase client package.
    from core.auth.supabase_client import SupabaseAuth  # noqa: PLC0415

    _auth_client = SupabaseAuth.from_env()
    return _auth_client


# ---------------------------------------------------------------------------
# Core dependency: get_current_user
# ---------------------------------------------------------------------------

async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(_security),
    _auth: Any = Depends(_get_auth_client),
) -> Dict[str, Any]:
    """
    FastAPI dependency: validate the Bearer JWT and return the user profile.

    Takes the token that the HTTPBearer scheme extracted from the
    ``Authorization: Bearer <token>`` header, passes it to
    ``_auth.get_user()`` for validation, and returns whatever user object
    that call produces.

    A request with no usable ``Authorization`` header is rejected by the
    HTTPBearer scheme (``auto_error=True``) and is not expected to reach
    this function body; the status code of that rejection is chosen by
    FastAPI, not by this function.

    Args:
        credentials: Parsed Authorization header, injected by FastAPI.
        _auth: Auth client exposing an awaitable ``get_user(token)``;
            injected by FastAPI and overridable in tests.

    Returns:
        The user dict returned by ``_auth.get_user()``, unchanged. It is
        expected to carry id, email, subscription_tier, metadata,
        created_at and updated_at — that shape is defined by the auth
        client and is not checked here.

    Raises:
        HTTPException(401): The token is empty, ``get_user`` raised
            ``ValueError`` (its message becomes the response detail), or
            ``get_user`` raised any other exception.
        HTTPException(403): ``get_user`` returned ``None``.
    """
    token = credentials.credentials
    if not token:
        # Defensive guard for an empty credential string.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization token is missing",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        user = await _auth.get_user(token)
    except ValueError as exc:
        # ValueError is treated as a token rejection; its text is returned
        # to the caller as the error detail.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    except Exception as exc:
        # Unexpected failure: log it with the full traceback so the cause
        # can be diagnosed, but give the client only a generic message.
        logger.exception("Unexpected error validating JWT: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    if user is None:
        # The call succeeded but produced no user record.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User not found",
        )

    return user


# ---------------------------------------------------------------------------
# Dependency factory: require_auth
# ---------------------------------------------------------------------------

def require_auth(required_tier: Optional[str] = None) -> Callable:
    """
    Build a FastAPI dependency that enforces a minimum subscription tier.

    The returned dependency first resolves the caller through
    ``get_current_user`` (so the Bearer JWT is validated), then compares the
    caller's ``subscription_tier`` against ``required_tier`` using the
    ordering of ``_TIER_HIERARCHY``.

    Usage:
        # Any authenticated user
        @router.get("/profile", dependencies=[Depends(require_auth())])

        # Professional tier or above
        @router.get("/advanced", dependencies=[Depends(require_auth("professional"))])

        # Enterprise or above
        @router.get("/enterprise", dependencies=[Depends(require_auth("enterprise"))])

    Args:
        required_tier: Lowest tier allowed through. One of: "starter",
                       "professional", "enterprise", "queen".
                       ``None`` means authentication alone is sufficient.

    Returns:
        An async dependency callable. It returns the current user dict when
        access is granted and raises HTTPException(403) when the caller's
        tier ranks below ``required_tier``.

    Raises:
        ValueError: At factory-call time, if ``required_tier`` is given but
            is not one of the recognised tier strings.
    """
    tier_is_unknown = (
        required_tier is not None and required_tier not in _TIER_HIERARCHY
    )
    if tier_is_unknown:
        raise ValueError(
            f"Unknown tier: '{required_tier}'. "
            f"Valid tiers: {_TIER_HIERARCHY}"
        )

    async def _dependency(
        user: Dict[str, Any] = Depends(get_current_user),
    ) -> Dict[str, Any]:
        """Per-request tier check; FastAPI injects the authenticated user."""
        # Without a tier requirement a valid JWT is all that is needed.
        if required_tier is None:
            return user

        # A user dict without a tier key is treated as "starter".
        user_tier = user.get("subscription_tier", "starter")

        # Unrecognised tier values rank as the lowest level (see _tier_index).
        if _tier_index(user_tier) >= _tier_index(required_tier):
            return user

        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                f"This endpoint requires '{required_tier}' subscription "
                f"or above. Your current tier is '{user_tier}'."
            ),
        )

    return _dependency


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _tier_index(tier: str) -> int:
    """
    Return the position of ``tier`` in ``_TIER_HIERARCHY``.

    A higher index means a more privileged tier. A value that is not in the
    hierarchy is logged as a warning and ranked 0, the same as the lowest
    tier.
    """
    if tier in _TIER_HIERARCHY:
        return _TIER_HIERARCHY.index(tier)

    # Unknown tier: fall back to the lowest access level.
    logger.warning("Unrecognised subscription tier '%s', treating as starter", tier)
    return 0
