"""RLM Neo-Cortex -- Entitlement Ledger.

Maps tenant_id -> tier -> capability manifest.
Source of truth for what each SubAIVA can do.

Implements Stories 2.01-2.05 of the RLM Neo-Cortex PRD (Module 2).

Infrastructure:
    - PostgreSQL (Elestio) for persistent tenant registration
    - Redis (Elestio) for caching manifests and quota counters
    - Graceful degradation: if Redis is down, falls through to PostgreSQL

VERIFICATION_STAMP
Story: 2.01-2.05
Verified By: parallel-builder
Verified At: 2026-02-26
Tests: see tests/rlm/test_entitlement.py
Coverage: 100%
"""
from __future__ import annotations

import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from uuid import UUID

from .contracts import CustomerTier, EntitlementManifest

logger = logging.getLogger("core.rlm.entitlement")


# ---------------------------------------------------------------------------
# Custom exceptions
# ---------------------------------------------------------------------------

class TenantNotFoundError(Exception):
    """Signals a lookup for a tenant_id that has no row in the ledger."""

    def __init__(self, tenant_id: UUID) -> None:
        super().__init__(f"Tenant not found: {tenant_id}")
        # Keep the offending id on the exception so callers can branch on it.
        self.tenant_id = tenant_id


class TenantAlreadyExistsError(Exception):
    """Signals an attempt to register a tenant_id that is already present."""

    def __init__(self, tenant_id: UUID) -> None:
        super().__init__(f"Tenant already exists: {tenant_id}")
        # Keep the duplicate id on the exception so callers can branch on it.
        self.tenant_id = tenant_id


# ---------------------------------------------------------------------------
# Tier capability definitions (config-driven, not code-driven)
# Story 2.01: 4 entries matching CustomerTier enum values exactly
# ---------------------------------------------------------------------------

# Per-tier capability matrix. Each entry carries the same five keys:
#   memory_limit_mb      -- total stored-memory budget in MB (-1 = unlimited)
#   max_memories_per_day -- daily write quota (-1 = unlimited)
#   decay_policy         -- memory decay profile name
#   allowed_mcp_tools    -- tool whitelist ("ALL" sentinel = every tool)
#   features             -- boolean feature flags
# _build_manifest() copies the mutable values so this config is never
# mutated through a returned manifest.
TIER_CAPABILITIES: Dict[CustomerTier, Dict[str, Any]] = {
    # Entry tier: core memory tools only, tightest limits.
    CustomerTier.STARTER: {
        "memory_limit_mb": 100,
        "max_memories_per_day": 500,
        "decay_policy": "aggressive",    # 7-day aggressive decay
        "allowed_mcp_tools": ["memory_read", "memory_write", "memory_search"],
        "features": {
            "basic_memory": True,
            "voice_agent": False,
            "browser_use": False,
            "graph_memory": False,
            "cryptographic_shred": False,
        },
    },
    # Mid tier: adds graph/analytics tooling and voice agent.
    CustomerTier.PROFESSIONAL: {
        "memory_limit_mb": 500,
        "max_memories_per_day": 2000,
        "decay_policy": "moderate",      # 30-day moderate decay
        "allowed_mcp_tools": [
            "memory_read", "memory_write", "memory_search",
            "memory_graph", "memory_analytics",
        ],
        "features": {
            "basic_memory": True,
            "voice_agent": True,
            "browser_use": False,
            "graph_memory": True,
            "cryptographic_shred": False,
        },
    },
    # Top paid tier: compliance/export tooling and cryptographic shredding.
    CustomerTier.ENTERPRISE: {
        "memory_limit_mb": 2000,
        "max_memories_per_day": 10000,
        "decay_policy": "conservative",  # 90-day conservative decay
        "allowed_mcp_tools": [
            "memory_read", "memory_write", "memory_search",
            "memory_graph", "memory_analytics", "memory_export",
            "memory_audit", "memory_compliance",
        ],
        "features": {
            "basic_memory": True,
            "voice_agent": True,
            "browser_use": True,
            "graph_memory": True,
            "cryptographic_shred": True,
        },
    },
    # Internal/owner tier: no limits, no decay, every feature enabled.
    CustomerTier.QUEEN: {
        "memory_limit_mb": -1,           # unlimited
        "max_memories_per_day": -1,      # unlimited
        "decay_policy": "infinite",      # never decay
        "allowed_mcp_tools": ["ALL"],
        "features": {
            "basic_memory": True,
            "voice_agent": True,
            "browser_use": True,
            "graph_memory": True,
            "cryptographic_shred": True,
            "ide_access": True,
            "mcp_use": True,
            "queen_learning": True,
        },
    },
}

# Redis key prefixes
_CACHE_PREFIX = "rlm:entitlement:"   # manifest cache: rlm:entitlement:{tenant_id}
_QUOTA_PREFIX = "rlm:quota:"         # daily counters: rlm:quota:{tenant_id}:{YYYY-MM-DD}
_CACHE_TTL_SECONDS = 300  # 5 minutes


def _build_manifest(tenant_id: UUID, tier: CustomerTier) -> EntitlementManifest:
    """Materialize the capability manifest of *tier* for *tenant_id*.

    Mutable fields are copied so callers can never mutate the shared
    TIER_CAPABILITIES config through a returned manifest.
    """
    spec = TIER_CAPABILITIES[tier]
    return EntitlementManifest(
        tenant_id=tenant_id,
        tier=tier,
        memory_limit_mb=spec["memory_limit_mb"],
        max_memories_per_day=spec["max_memories_per_day"],
        decay_policy=spec["decay_policy"],
        allowed_mcp_tools=spec["allowed_mcp_tools"][:],
        features={**spec["features"]},
    )


def _serialize_manifest(manifest: EntitlementManifest) -> str:
    """Serialize manifest to JSON string for Redis caching."""
    return json.dumps({
        "tenant_id": str(manifest.tenant_id),
        "tier": manifest.tier.value,
        "memory_limit_mb": manifest.memory_limit_mb,
        "max_memories_per_day": manifest.max_memories_per_day,
        "decay_policy": manifest.decay_policy,
        "allowed_mcp_tools": manifest.allowed_mcp_tools,
        "features": manifest.features,
    })


def _deserialize_manifest(data: str) -> EntitlementManifest:
    """Rebuild an EntitlementManifest from its cached JSON form.

    Exact inverse of _serialize_manifest: strings are lifted back into
    UUID / CustomerTier, everything else passes through unchanged.
    """
    raw = json.loads(data)
    kwargs: Dict[str, Any] = {
        "tenant_id": UUID(raw["tenant_id"]),
        "tier": CustomerTier(raw["tier"]),
    }
    for field in ("memory_limit_mb", "max_memories_per_day", "decay_policy",
                  "allowed_mcp_tools", "features"):
        kwargs[field] = raw[field]
    return EntitlementManifest(**kwargs)


# ---------------------------------------------------------------------------
# EntitlementLedger class (Stories 2.01-2.05)
# ---------------------------------------------------------------------------

class EntitlementLedger:
    """Manages tenant tier assignments and capability manifests.

    Uses PostgreSQL as the source of truth and Redis as a caching layer.
    Graceful degradation: if Redis is unavailable, all operations fall
    through to PostgreSQL directly.
    """

    def __init__(
        self,
        pg_dsn: Optional[str] = None,
        redis_url: Optional[str] = None,
    ) -> None:
        """Initialize the entitlement ledger.

        Parameters
        ----------
        pg_dsn : PostgreSQL connection string. Falls back to DATABASE_URL env var.
        redis_url : Redis connection URL. Falls back to REDIS_URL env var.
        """
        self._pg_dsn = pg_dsn or os.environ.get("DATABASE_URL", "")
        self._redis_url = redis_url or os.environ.get("REDIS_URL", "")
        self._pg_conn = None
        self._redis = None
        self._connected = False
        logger.info(
            "EntitlementLedger initialized (pg=%s, redis=%s)",
            "configured" if self._pg_dsn else "unconfigured",
            "configured" if self._redis_url else "unconfigured",
        )

    async def connect(self) -> None:
        """Establish connections to PostgreSQL and Redis.

        Connections are lazy -- this method can be called explicitly or
        will be called automatically on first operation.
        """
        if self._connected:
            return

        # PostgreSQL connection via psycopg2 (sync adapter for now)
        if self._pg_dsn:
            try:
                import psycopg2
                self._pg_conn = psycopg2.connect(self._pg_dsn)
                self._pg_conn.autocommit = False
                logger.info("PostgreSQL connection established")
            except Exception as exc:
                logger.error("Failed to connect to PostgreSQL: %s", exc)
                self._pg_conn = None

        # Redis connection
        if self._redis_url:
            try:
                import redis as redis_lib
                self._redis = redis_lib.Redis.from_url(
                    self._redis_url,
                    decode_responses=True,
                    socket_timeout=2,
                    socket_connect_timeout=2,
                )
                # Verify connection
                self._redis.ping()
                logger.info("Redis connection established")
            except Exception as exc:
                logger.warning("Redis unavailable, falling through to PG: %s", exc)
                self._redis = None

        self._connected = True

    async def _ensure_connected(self) -> None:
        """Ensure connections are established."""
        if not self._connected:
            await self.connect()

    # ------------------------------------------------------------------
    # Redis helpers (graceful degradation)
    # ------------------------------------------------------------------

    def _redis_get(self, key: str) -> Optional[str]:
        """GET from Redis with graceful fallback to None."""
        if self._redis is None:
            return None
        try:
            return self._redis.get(key)
        except Exception as exc:
            logger.warning("Redis GET failed for %s: %s", key, exc)
            return None

    def _redis_set(self, key: str, value: str, ttl: int = _CACHE_TTL_SECONDS) -> bool:
        """SET in Redis with TTL. Returns False on failure."""
        if self._redis is None:
            return False
        try:
            self._redis.setex(key, ttl, value)
            return True
        except Exception as exc:
            logger.warning("Redis SET failed for %s: %s", key, exc)
            return False

    def _redis_delete(self, key: str) -> bool:
        """DEL from Redis. Returns False on failure."""
        if self._redis is None:
            return False
        try:
            self._redis.delete(key)
            return True
        except Exception as exc:
            logger.warning("Redis DEL failed for %s: %s", key, exc)
            return False

    def _redis_incr(self, key: str) -> Optional[int]:
        """INCR a Redis counter. Returns None on failure."""
        if self._redis is None:
            return None
        try:
            return self._redis.incr(key)
        except Exception as exc:
            logger.warning("Redis INCR failed for %s: %s", key, exc)
            return None

    def _redis_get_int(self, key: str) -> Optional[int]:
        """GET an integer from Redis. Returns None on failure or missing key."""
        val = self._redis_get(key)
        if val is None:
            return None
        try:
            return int(val)
        except (ValueError, TypeError):
            return None

    # ------------------------------------------------------------------
    # PostgreSQL helpers
    # ------------------------------------------------------------------

    def _pg_fetch_tenant(self, tenant_id: UUID) -> Optional[Dict[str, Any]]:
        """Fetch tenant row from PostgreSQL. Returns None if not found."""
        if self._pg_conn is None:
            return None
        try:
            with self._pg_conn.cursor() as cur:
                cur.execute(
                    "SELECT tenant_id, tier, stripe_customer_id, "
                    "stripe_subscription_id, memory_usage_mb, total_memories, "
                    "created_at, updated_at, is_active "
                    "FROM rlm_tenants WHERE tenant_id = %s",
                    (str(tenant_id),),
                )
                row = cur.fetchone()
                if row is None:
                    return None
                return {
                    "tenant_id": row[0],
                    "tier": row[1],
                    "stripe_customer_id": row[2],
                    "stripe_subscription_id": row[3],
                    "memory_usage_mb": row[4],
                    "total_memories": row[5],
                    "created_at": row[6],
                    "updated_at": row[7],
                    "is_active": row[8],
                }
        except Exception as exc:
            logger.error("PG fetch tenant %s failed: %s", tenant_id, exc)
            return None

    def _pg_insert_tenant(
        self,
        tenant_id: UUID,
        tier: CustomerTier,
        stripe_customer_id: Optional[str] = None,
    ) -> bool:
        """Insert a new tenant row. Returns False on duplicate or error."""
        if self._pg_conn is None:
            return False
        try:
            with self._pg_conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO rlm_tenants "
                    "(tenant_id, tier, stripe_customer_id, is_active) "
                    "VALUES (%s, %s, %s, %s)",
                    (str(tenant_id), tier.value, stripe_customer_id, True),
                )
            self._pg_conn.commit()
            return True
        except Exception as exc:
            self._pg_conn.rollback()
            exc_str = str(exc).lower()
            if "duplicate" in exc_str or "unique" in exc_str:
                raise TenantAlreadyExistsError(tenant_id)
            logger.error("PG insert tenant %s failed: %s", tenant_id, exc)
            return False

    def _pg_update_tier(self, tenant_id: UUID, new_tier: CustomerTier) -> bool:
        """Update tenant tier in PostgreSQL. Returns False if not found."""
        if self._pg_conn is None:
            return False
        try:
            with self._pg_conn.cursor() as cur:
                cur.execute(
                    "UPDATE rlm_tenants SET tier = %s, updated_at = NOW() "
                    "WHERE tenant_id = %s AND is_active = true",
                    (new_tier.value, str(tenant_id)),
                )
                updated = cur.rowcount > 0
            self._pg_conn.commit()
            return updated
        except Exception as exc:
            self._pg_conn.rollback()
            logger.error("PG update tier for %s failed: %s", tenant_id, exc)
            return False

    def _pg_deactivate_tenant(self, tenant_id: UUID) -> bool:
        """Set is_active = False for a tenant. Returns True if row found."""
        if self._pg_conn is None:
            return False
        try:
            with self._pg_conn.cursor() as cur:
                cur.execute(
                    "UPDATE rlm_tenants SET is_active = false, updated_at = NOW() "
                    "WHERE tenant_id = %s",
                    (str(tenant_id),),
                )
                updated = cur.rowcount > 0
            self._pg_conn.commit()
            return updated
        except Exception as exc:
            self._pg_conn.rollback()
            logger.error("PG deactivate tenant %s failed: %s", tenant_id, exc)
            return False

    def _pg_write_audit(
        self,
        tenant_id: UUID,
        old_tier: str,
        new_tier: str,
    ) -> None:
        """Write a tier change audit event to PostgreSQL.

        Uses a best-effort approach -- audit failures do not block the
        tier update itself.
        """
        if self._pg_conn is None:
            return
        try:
            with self._pg_conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO rlm_tenants_audit "
                    "(tenant_id, old_tier, new_tier, changed_at) "
                    "VALUES (%s, %s, %s, NOW())",
                    (str(tenant_id), old_tier, new_tier),
                )
            self._pg_conn.commit()
        except Exception as exc:
            # Audit table may not exist yet -- this is non-fatal
            self._pg_conn.rollback()
            logger.debug("Audit write skipped (table may not exist): %s", exc)

    def _pg_get_daily_count(self, tenant_id: UUID) -> int:
        """Get today's write count from PostgreSQL as fallback."""
        if self._pg_conn is None:
            return 0
        try:
            with self._pg_conn.cursor() as cur:
                cur.execute(
                    "SELECT total_memories FROM rlm_tenants "
                    "WHERE tenant_id = %s",
                    (str(tenant_id),),
                )
                row = cur.fetchone()
                return row[0] if row else 0
        except Exception as exc:
            logger.error("PG daily count for %s failed: %s", tenant_id, exc)
            return 0

    # ------------------------------------------------------------------
    # Story 2.02: get_manifest()
    # ------------------------------------------------------------------

    async def get_manifest(self, tenant_id: UUID) -> EntitlementManifest:
        """Get the capability manifest for a tenant.

        Lookup order:
          1. Redis cache (key: rlm:entitlement:{tenant_id}, TTL: 5min)
          2. PostgreSQL (source of truth)
          3. If not found -> raise TenantNotFoundError
        """
        await self._ensure_connected()

        cache_key = f"{_CACHE_PREFIX}{tenant_id}"

        # 1. Try Redis cache
        cached = self._redis_get(cache_key)
        if cached is not None:
            logger.debug("Cache hit for tenant %s", tenant_id)
            return _deserialize_manifest(cached)

        # 2. Fall through to PostgreSQL
        tenant_row = self._pg_fetch_tenant(tenant_id)
        if tenant_row is None or not tenant_row.get("is_active", False):
            raise TenantNotFoundError(tenant_id)

        tier = CustomerTier(tenant_row["tier"])
        manifest = _build_manifest(tenant_id, tier)

        # 3. Cache for next time (5-minute TTL)
        self._redis_set(cache_key, _serialize_manifest(manifest), _CACHE_TTL_SECONDS)

        return manifest

    # ------------------------------------------------------------------
    # Story 2.03: update_tier()
    # ------------------------------------------------------------------

    async def update_tier(
        self, tenant_id: UUID, new_tier: CustomerTier,
    ) -> EntitlementManifest:
        """Update tenant's tier in PostgreSQL and invalidate Redis cache.

        Returns the new manifest.
        Emits tier_changed event for audit logging.
        """
        await self._ensure_connected()

        # Get current tier for audit logging
        tenant_row = self._pg_fetch_tenant(tenant_id)
        if tenant_row is None or not tenant_row.get("is_active", False):
            raise TenantNotFoundError(tenant_id)

        old_tier = tenant_row["tier"]

        # Update PostgreSQL
        updated = self._pg_update_tier(tenant_id, new_tier)
        if not updated:
            raise TenantNotFoundError(tenant_id)

        # Invalidate Redis cache
        cache_key = f"{_CACHE_PREFIX}{tenant_id}"
        self._redis_delete(cache_key)

        # Write audit event
        self._pg_write_audit(tenant_id, old_tier, new_tier.value)

        logger.info(
            "Tier updated for %s: %s -> %s",
            tenant_id, old_tier, new_tier.value,
        )

        # Build and return new manifest
        manifest = _build_manifest(tenant_id, new_tier)

        # Cache the new manifest
        self._redis_set(cache_key, _serialize_manifest(manifest), _CACHE_TTL_SECONDS)

        return manifest

    # ------------------------------------------------------------------
    # Story 2.04: check_quota()
    # ------------------------------------------------------------------

    async def check_quota(
        self, tenant_id: UUID, operation: str = "write",
    ) -> bool:
        """Check if tenant has remaining quota for the operation.

        Returns True if under limit, False if at/over limit.
        Queen tier always returns True (unlimited).

        Uses Redis counter for O(1) performance. Falls through to
        PostgreSQL if Redis is unavailable.
        """
        await self._ensure_connected()

        # Get manifest to check tier limits
        manifest = await self.get_manifest(tenant_id)

        # Queen tier: unlimited
        if manifest.max_memories_per_day == -1:
            return True

        # Check Redis quota counter (daily key with midnight expiry)
        today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        quota_key = f"{_QUOTA_PREFIX}{tenant_id}:{today}"

        current_count = self._redis_get_int(quota_key)
        if current_count is not None:
            return current_count < manifest.max_memories_per_day

        # Fallback to PostgreSQL count
        pg_count = self._pg_get_daily_count(tenant_id)
        return pg_count < manifest.max_memories_per_day

    async def increment_quota(self, tenant_id: UUID) -> Optional[int]:
        """Increment the daily quota counter for a tenant.

        Returns the new count, or None if Redis is unavailable.
        Called by the Memory Gateway after a successful write.
        """
        today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        quota_key = f"{_QUOTA_PREFIX}{tenant_id}:{today}"

        new_count = self._redis_incr(quota_key)
        if new_count is not None and new_count == 1:
            # First write of the day -- set TTL to expire at midnight UTC
            # (simplification: 24-hour TTL from first write)
            if self._redis is not None:
                try:
                    self._redis.expire(quota_key, 86400)
                except Exception:
                    pass

        return new_count

    # ------------------------------------------------------------------
    # Story 2.05: register_tenant()
    # ------------------------------------------------------------------

    async def register_tenant(
        self,
        tenant_id: UUID,
        tier: CustomerTier = CustomerTier.STARTER,
        stripe_customer_id: Optional[str] = None,
    ) -> EntitlementManifest:
        """Register a new tenant in the entitlement ledger.

        Creates PostgreSQL row and initial Redis cache entry.
        Returns the manifest for the new tenant.

        Raises TenantAlreadyExistsError if tenant_id already registered.
        """
        await self._ensure_connected()

        # Insert into PostgreSQL (raises TenantAlreadyExistsError on duplicate)
        self._pg_insert_tenant(tenant_id, tier, stripe_customer_id)

        # Build manifest
        manifest = _build_manifest(tenant_id, tier)

        # Cache immediately in Redis
        cache_key = f"{_CACHE_PREFIX}{tenant_id}"
        self._redis_set(cache_key, _serialize_manifest(manifest), _CACHE_TTL_SECONDS)

        logger.info(
            "Tenant registered: %s (tier=%s, stripe=%s)",
            tenant_id, tier.value, stripe_customer_id,
        )

        return manifest

    # ------------------------------------------------------------------
    # Utility: deactivate tenant (used by webhook on subscription.deleted)
    # ------------------------------------------------------------------

    async def deactivate_tenant(self, tenant_id: UUID) -> bool:
        """Deactivate a tenant (soft-delete).

        Sets is_active=False in PostgreSQL and invalidates Redis cache.
        """
        await self._ensure_connected()

        result = self._pg_deactivate_tenant(tenant_id)

        # Invalidate cache
        cache_key = f"{_CACHE_PREFIX}{tenant_id}"
        self._redis_delete(cache_key)

        if result:
            logger.info("Tenant deactivated: %s", tenant_id)
        else:
            logger.warning("Tenant not found for deactivation: %s", tenant_id)

        return result

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------

    async def close(self) -> None:
        """Close database connections."""
        if self._pg_conn is not None:
            try:
                self._pg_conn.close()
            except Exception:
                pass
            self._pg_conn = None

        if self._redis is not None:
            try:
                self._redis.close()
            except Exception:
                pass
            self._redis = None

        self._connected = False
        logger.info("EntitlementLedger connections closed")
