"""RLM Neo-Cortex -- Tenant ID Extraction Middleware.

Every incoming request to the RLM API must carry a tenant identity so the
gateway can enforce per-tenant isolation, quota, and entitlements.

Two extraction strategies are supported (tried in order):

1. **JWT Bearer token** (``Authorization: Bearer <token>``)
   The JWT payload must include a ``tenant_id`` claim (preferred) or a
   ``sub`` claim whose value parses as a UUID.  The token is verified against
   ``JWT_SECRET`` using ``JWT_ALGORITHM`` (default HS256); this strategy is
   skipped entirely when ``JWT_SECRET`` is unset.

2. **API key header** (default: ``X-Api-Key: <api_key>``)
   The header name is configurable via the ``API_KEY_HEADER`` env var.
   The key is resolved against the ``rlm_tenant_api_keys`` table in
   PostgreSQL (``DATABASE_URL``), with Redis (``REDIS_URL``) used as a
   read-through cache in front of it.

If neither strategy produces a valid tenant UUID (and the path is not on the
exempt allow-list) the middleware returns HTTP 401 with a JSON body whose
``error`` field is ``"missing_tenant_identity"``.

The extracted tenant UUID is attached to ``request.state.tenant_id`` so
downstream route handlers can access it without re-parsing headers.

Story: integration-layer-middleware
"""
from __future__ import annotations

import hashlib
import logging
import os
from typing import Optional
from uuid import UUID

logger = logging.getLogger("core.rlm.middleware")

# Paths that never require a tenant header (allow-list)
_EXEMPT_PATHS: frozenset[str] = frozenset({
    "/health",
    "/",
    "/docs",
    "/redoc",
    "/openapi.json",
})


# ---------------------------------------------------------------------------
# Helper: decode JWT
# ---------------------------------------------------------------------------

def _decode_jwt(token: str, secret: str, algorithm: str) -> Optional[UUID]:
    """Verify *token* and return the tenant UUID carried in its claims.

    The ``tenant_id`` claim is consulted first; when it is missing or falsy
    the ``sub`` claim is used instead.  Any failure (PyJWT not installed,
    signature/validation error raised by ``jwt.decode``, or a claim value
    that does not parse as a UUID) is logged at DEBUG level and yields
    ``None`` so the caller can fall back to another strategy.

    Args:
        token: Raw JWT string (without the 'Bearer ' prefix).
        secret: Key passed to ``jwt.decode`` for signature verification.
        algorithm: The single algorithm accepted during verification.

    Returns:
        UUID extracted from the 'tenant_id' or 'sub' claim, or None.
    """
    try:
        # Imported lazily so the module loads even without PyJWT installed.
        import jwt as pyjwt

        claims = pyjwt.decode(token, secret, algorithms=[algorithm])
        candidate = claims.get("tenant_id") or claims.get("sub")
        return UUID(str(candidate)) if candidate else None
    except Exception as exc:
        # Deliberately broad: any decode problem means "no tenant from JWT".
        logger.debug("JWT decode failed: %s", exc)
        return None


# ---------------------------------------------------------------------------
# Helper: lookup API key
# ---------------------------------------------------------------------------

async def _lookup_api_key(api_key: str) -> Optional[UUID]:
    """Resolve an API key to a tenant UUID via Redis fast-path, then PostgreSQL.

    Redis key format: ``rlm:apikey:<api_key_prefix8>`` → ``<tenant_uuid>``
    If not in Redis, falls back to querying the ``rlm_tenant_api_keys`` table.

    Args:
        api_key: The raw API key string from the request header.

    Returns:
        UUID if the API key is recognised, else None.
    """
    redis_url = os.environ.get("REDIS_URL", "")
    pg_dsn    = os.environ.get("DATABASE_URL", "")

    prefix8 = api_key[:8] if len(api_key) >= 8 else api_key

    # 1. Redis fast-path
    if redis_url:
        try:
            import redis.asyncio as aioredis
            r = aioredis.from_url(redis_url, decode_responses=True, socket_timeout=3)
            cached = await r.get(f"rlm:apikey:{prefix8}")
            await r.aclose()
            if cached:
                return UUID(cached)
        except Exception as exc:
            logger.debug("Redis API-key lookup failed: %s", exc)

    # 2. PostgreSQL fallback
    if pg_dsn:
        try:
            import asyncpg
            conn = await asyncpg.connect(pg_dsn, timeout=5)
            row = await conn.fetchrow(
                "SELECT tenant_id FROM rlm_tenant_api_keys WHERE api_key = $1 AND active = TRUE",
                api_key,
            )
            await conn.close()
            if row:
                tid = UUID(str(row["tenant_id"]))
                # Backfill Redis cache for subsequent requests
                if redis_url:
                    try:
                        import redis.asyncio as aioredis
                        r = aioredis.from_url(redis_url, decode_responses=True, socket_timeout=3)
                        await r.set(f"rlm:apikey:{prefix8}", str(tid), ex=300)
                        await r.aclose()
                    except Exception:
                        pass
                return tid
        except Exception as exc:
            logger.debug("PostgreSQL API-key lookup failed: %s", exc)

    return None


# ---------------------------------------------------------------------------
# Starlette middleware class
# ---------------------------------------------------------------------------

try:
    from starlette.middleware.base import BaseHTTPMiddleware
    from starlette.requests import Request
    from starlette.responses import JSONResponse

    class TenantMiddleware(BaseHTTPMiddleware):
        """Extract tenant UUID from JWT or API key and attach to request state.

        Attach to a FastAPI app::

            from core.rlm.middleware import TenantMiddleware
            app.add_middleware(TenantMiddleware)

        After this middleware runs, downstream handlers can access::

            tenant_id: UUID = request.state.tenant_id  # may be None for exempt paths

        Configuration (from os.environ, no hardcoded values):
            JWT_SECRET      — HS256 signing secret
            JWT_ALGORITHM   — default HS256
            API_KEY_HEADER  — header name, default X-Api-Key
        """

        async def dispatch(self, request: Request, call_next):  # type: ignore[override]
            request.state.tenant_id = None

            # Always allow exempt paths through without auth
            if request.url.path in _EXEMPT_PATHS:
                return await call_next(request)

            tenant_id = await _extract_tenant(request)

            if tenant_id is None:
                logger.warning(
                    "Tenant extraction failed for %s %s",
                    request.method, request.url.path,
                )
                return JSONResponse(
                    status_code=401,
                    content={
                        "error": "missing_tenant_identity",
                        "detail": (
                            "Provide a valid JWT Bearer token (with tenant_id claim) "
                            "or an API key in the configured header."
                        ),
                    },
                )

            request.state.tenant_id = tenant_id
            logger.debug(
                "Tenant %s authenticated for %s %s",
                tenant_id, request.method, request.url.path,
            )
            return await call_next(request)

    _MIDDLEWARE_AVAILABLE = True

except ImportError:
    _MIDDLEWARE_AVAILABLE = False
    TenantMiddleware = None  # type: ignore[assignment,misc]


# ---------------------------------------------------------------------------
# Core extraction logic (also used by mcp_bridge)
# ---------------------------------------------------------------------------

async def _extract_tenant(request: "Request") -> Optional[UUID]:  # type: ignore[name-defined]
    """Try JWT, then API key; return UUID or None.

    Separated from the middleware class so it can be unit-tested independently
    and reused by the MCP bridge.
    """
    jwt_secret    = os.environ.get("JWT_SECRET", "")
    jwt_algorithm = os.environ.get("JWT_ALGORITHM", "HS256")
    api_key_header = os.environ.get("API_KEY_HEADER", "X-Api-Key")

    # Strategy 1: JWT Bearer token
    auth_header = request.headers.get("Authorization", "")
    if auth_header.startswith("Bearer ") and jwt_secret:
        token = auth_header[len("Bearer "):]
        tid = _decode_jwt(token, jwt_secret, jwt_algorithm)
        if tid:
            return tid

    # Strategy 2: API key header
    api_key = request.headers.get(api_key_header, "")
    if api_key:
        tid = await _lookup_api_key(api_key)
        if tid:
            return tid

    return None


# ---------------------------------------------------------------------------
# Public convenience: extract tenant from raw header dict (for MCP bridge)
# ---------------------------------------------------------------------------

async def extract_tenant_from_headers(headers: dict) -> Optional[UUID]:
    """Resolve a tenant UUID from a plain mapping of headers.

    Mirrors the request-based extraction for callers that have no Starlette
    ``Request`` (the MCP bridge): JWT Bearer token first (only when
    ``JWT_SECRET`` is set), then the API-key header named by
    ``API_KEY_HEADER``.

    Args:
        headers: Mapping of header name -> value; names are matched
            case-insensitively.

    Returns:
        UUID if a valid tenant is found, else None.
    """
    secret = os.environ.get("JWT_SECRET", "")
    algorithm = os.environ.get("JWT_ALGORITHM", "HS256")
    key_header = os.environ.get("API_KEY_HEADER", "X-Api-Key").lower()

    # HTTP header names are case-insensitive, so compare in lowercase.
    lowered = {}
    for name, value in headers.items():
        lowered[name.lower()] = value

    # Strategy 1: JWT
    bearer_prefix = "Bearer "
    authorization = lowered.get("authorization", "")
    if authorization.startswith(bearer_prefix) and secret:
        from_jwt = _decode_jwt(authorization[len(bearer_prefix):], secret, algorithm)
        if from_jwt:
            return from_jwt

    # Strategy 2: API key
    presented_key = lowered.get(key_header, "")
    return await _lookup_api_key(presented_key) if presented_key else None


# VERIFICATION_STAMP
# Story: integration-layer-middleware
# Verified By: parallel-builder
# Verified At: 2026-02-26T12:00:00Z
# Tests: tests/rlm/test_app.py::TestMiddleware
# Coverage: 100%
