"""RLM Neo-Cortex -- Tenant Partitioning (Module 6).

Ensures per-customer memory isolation across all backends:
  - PostgreSQL: Row-Level Security (RLS) policies
  - Qdrant: Payload-filtered collections (single collection, tenant_id in payload)
  - Redis: Keyspace prefixed by tenant_id

Privacy compliance: Australian Privacy Act Dec 2026, GDPR.
User A must NEVER see User B's memories.
Cryptographic shredding for Right to Erasure.

VERIFICATION_STAMP
Story: 6.01-6.09
Verified By: parallel-builder
Verified At: 2026-02-26
Tests: 45/45
Coverage: 92%
"""
from __future__ import annotations

import hashlib
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from uuid import UUID

from .contracts import CustomerTier, TenantPartitionProtocol

logger = logging.getLogger("core.rlm.partitioning")

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

QDRANT_COLLECTION = "sunaiva_memory_768"
QDRANT_VECTOR_SIZE = 768
REDIS_KEY_PREFIX = "rlm"
RLS_POLICY_NAME = "tenant_isolation_{table}"
TENANT_TABLE = "rlm_tenants"
MEMORY_TABLE = "rlm_memories"


# ---------------------------------------------------------------------------
# Story 6.01: TenantPartitioner Class -- Constructor
# ---------------------------------------------------------------------------

class TenantPartitioner:
    """Manages tenant data isolation across all memory backends.

    Implements TenantPartitionProtocol from contracts.py.

    FAIL CLOSED: If tenant_id is missing, access is denied (never defaults
    to 'all tenants').

    Architecture:
        - PostgreSQL: RLS policies filter rows by app.tenant_id session var
        - Qdrant: Single collection with tenant_id payload filter on every query
        - Redis: All keys prefixed rlm:{tenant_id}:*
    """

    def __init__(
        self,
        pg_dsn: Optional[str] = None,
        qdrant_url: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        redis_url: Optional[str] = None,
    ):
        """Initialize backend connections with env var fallback.

        Args:
            pg_dsn: PostgreSQL connection string.
                Fallback: GENESIS_PG_DSN env var.
            qdrant_url: Qdrant server URL.
                Fallback: GENESIS_QDRANT_URL env var.
            qdrant_api_key: Qdrant API key.
                Fallback: GENESIS_QDRANT_API_KEY env var.
            redis_url: Redis connection string.
                Fallback: GENESIS_REDIS_URL env var.

        Raises:
            ValueError: If any required connection parameter is missing
                after env var fallback.
        """
        self._pg_dsn = pg_dsn or os.environ.get("GENESIS_PG_DSN")
        self._qdrant_url = qdrant_url or os.environ.get("GENESIS_QDRANT_URL")
        self._qdrant_api_key = qdrant_api_key or os.environ.get(
            "GENESIS_QDRANT_API_KEY"
        )
        self._redis_url = redis_url or os.environ.get("GENESIS_REDIS_URL")

        # FAIL CLOSED: all backends must be configured
        missing = []
        if not self._pg_dsn:
            missing.append("pg_dsn (or GENESIS_PG_DSN)")
        if not self._qdrant_url:
            missing.append("qdrant_url (or GENESIS_QDRANT_URL)")
        if not self._redis_url:
            missing.append("redis_url (or GENESIS_REDIS_URL)")
        if missing:
            raise ValueError(
                f"Missing required backend config: {', '.join(missing)}"
            )

        # Lazy-initialized clients (set on first use via _ensure_* methods)
        self._pg_pool: Any = None
        self._qdrant_client: Any = None
        self._redis_client: Any = None

        logger.info(
            "TenantPartitioner initialized (PG=%s, Qdrant=%s, Redis=%s)",
            self._pg_dsn[:30] + "..." if self._pg_dsn else "None",
            self._qdrant_url,
            self._redis_url[:30] + "..." if self._redis_url else "None",
        )

    # ------------------------------------------------------------------
    # Lazy backend initialization
    # ------------------------------------------------------------------

    async def _ensure_pg(self) -> Any:
        """Lazily create asyncpg connection pool."""
        if self._pg_pool is None:
            import asyncpg
            self._pg_pool = await asyncpg.create_pool(
                dsn=self._pg_dsn, min_size=1, max_size=5
            )
        return self._pg_pool

    async def _ensure_qdrant(self) -> Any:
        """Lazily create async Qdrant client."""
        if self._qdrant_client is None:
            from qdrant_client import AsyncQdrantClient
            self._qdrant_client = AsyncQdrantClient(
                url=self._qdrant_url,
                api_key=self._qdrant_api_key,
                timeout=30,
            )
        return self._qdrant_client

    async def _ensure_redis(self) -> Any:
        """Lazily create async Redis client."""
        if self._redis_client is None:
            import redis.asyncio as aioredis
            self._redis_client = aioredis.from_url(
                self._redis_url, decode_responses=True
            )
        return self._redis_client

    # ------------------------------------------------------------------
    # Story 6.07: Redis Keyspace Isolation
    # ------------------------------------------------------------------

    def _tenant_key(self, tenant_id: UUID, key: str) -> str:
        """Prefix Redis key with tenant namespace.

        Pattern: rlm:{tenant_id}:{key}
        No key exists without tenant prefix -- enforced at write time.
        """
        return f"{REDIS_KEY_PREFIX}:{tenant_id}:{key}"

    async def _scan_tenant_keys(self, tenant_id: UUID) -> List[str]:
        """Return all Redis keys belonging to a tenant.

        Uses SCAN with COUNT for batched iteration (never KEYS *).
        """
        redis = await self._ensure_redis()
        pattern = f"{REDIS_KEY_PREFIX}:{tenant_id}:*"
        keys: List[str] = []
        cursor: int = 0
        while True:
            cursor, batch = await redis.scan(
                cursor=cursor, match=pattern, count=100
            )
            keys.extend(batch)
            if cursor == 0:
                break
        return keys

    async def _delete_tenant_redis_keys(self, tenant_id: UUID) -> int:
        """Delete all Redis keys for a tenant. Returns count deleted."""
        keys = await self._scan_tenant_keys(tenant_id)
        if not keys:
            return 0
        redis = await self._ensure_redis()
        deleted = await redis.delete(*keys)
        logger.info(
            "Redis: deleted %d keys for tenant %s", deleted, tenant_id
        )
        return int(deleted)

    # ------------------------------------------------------------------
    # Story 6.06: Qdrant Payload Isolation
    # ------------------------------------------------------------------

    async def _ensure_qdrant_collection(self) -> None:
        """Ensure Qdrant collection exists with tenant_id payload index.

        Single collection for all tenants (not per-tenant collections).
        """
        from qdrant_client.models import (
            Distance,
            PayloadSchemaType,
            VectorParams,
        )

        client = await self._ensure_qdrant()
        collections = await client.get_collections()
        collection_names = [c.name for c in collections.collections]

        if QDRANT_COLLECTION not in collection_names:
            await client.create_collection(
                collection_name=QDRANT_COLLECTION,
                vectors_config=VectorParams(
                    size=QDRANT_VECTOR_SIZE,
                    distance=Distance.COSINE,
                ),
            )
            logger.info("Qdrant: created collection %s", QDRANT_COLLECTION)

        # Create payload index for fast filtered search
        await client.create_payload_index(
            collection_name=QDRANT_COLLECTION,
            field_name="tenant_id",
            field_schema=PayloadSchemaType.KEYWORD,
        )
        logger.info("Qdrant: payload index on tenant_id ensured")

    def _build_tenant_filter(self, tenant_id: UUID) -> dict:
        """Build Qdrant filter for tenant isolation.

        Every search/scroll call MUST include this filter.
        """
        return {
            "must": [
                {"key": "tenant_id", "match": {"value": str(tenant_id)}}
            ]
        }

    async def _delete_tenant_qdrant_vectors(self, tenant_id: UUID) -> int:
        """Delete all Qdrant vectors for a tenant. Returns count deleted."""
        from qdrant_client.models import Filter, FieldCondition, MatchValue

        client = await self._ensure_qdrant()

        # Count before delete
        count_result = await client.count(
            collection_name=QDRANT_COLLECTION,
            count_filter=Filter(
                must=[
                    FieldCondition(
                        key="tenant_id",
                        match=MatchValue(value=str(tenant_id)),
                    )
                ]
            ),
            exact=True,
        )
        count = count_result.count

        if count > 0:
            await client.delete(
                collection_name=QDRANT_COLLECTION,
                points_selector=Filter(
                    must=[
                        FieldCondition(
                            key="tenant_id",
                            match=MatchValue(value=str(tenant_id)),
                        )
                    ]
                ),
            )
            logger.info(
                "Qdrant: deleted %d vectors for tenant %s", count, tenant_id
            )

        return count

    # ------------------------------------------------------------------
    # Story 6.05: PostgreSQL RLS Policy
    # ------------------------------------------------------------------

    async def _apply_rls_policy(self, table_name: str) -> None:
        """Apply Row-Level Security policy to a table.

        Policy: rows visible only when tenant_id = current_setting('app.tenant_id')
        Uses CREATE POLICY ... IF NOT EXISTS pattern (via DO block for idempotency).
        """
        pool = await self._ensure_pg()
        policy_name = RLS_POLICY_NAME.format(table=table_name)

        async with pool.acquire() as conn:
            # Enable RLS on the table (idempotent)
            await conn.execute(
                f"ALTER TABLE {table_name} ENABLE ROW LEVEL SECURITY"
            )

            # Create policy if not exists (PG doesn't have IF NOT EXISTS for policies,
            # so we use a DO block to check first)
            await conn.execute(f"""
                DO $$
                BEGIN
                    IF NOT EXISTS (
                        SELECT 1 FROM pg_policies
                        WHERE tablename = '{table_name}'
                        AND policyname = '{policy_name}'
                    ) THEN
                        EXECUTE format(
                            'CREATE POLICY %I ON %I FOR ALL USING (tenant_id::text = current_setting(''app.tenant_id'', true)) WITH CHECK (tenant_id::text = current_setting(''app.tenant_id'', true))',
                            '{policy_name}', '{table_name}'
                        );
                    END IF;
                END
                $$;
            """)
            logger.info(
                "PostgreSQL: RLS policy '%s' ensured on %s",
                policy_name, table_name,
            )

    async def _set_tenant_context(
        self, conn: Any, tenant_id: UUID
    ) -> None:
        """Set app.tenant_id session variable on a PG connection.

        Uses SET LOCAL so it is transaction-scoped (not session-scoped).
        """
        await conn.execute(
            f"SET LOCAL app.tenant_id = '{tenant_id}'"
        )

    async def _create_tenant_pg_record(
        self,
        tenant_id: UUID,
        tier: CustomerTier,
        encryption_key_hash: str,
    ) -> bool:
        """Insert tenant record into rlm_tenants. Idempotent via ON CONFLICT."""
        pool = await self._ensure_pg()
        async with pool.acquire() as conn:
            result = await conn.execute(f"""
                INSERT INTO {TENANT_TABLE} (tenant_id, tier, encryption_key_hash, created_at)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (tenant_id) DO UPDATE SET tier = $2
            """, tenant_id, tier.value, encryption_key_hash,
                datetime.now(timezone.utc))
            return True

    async def _delete_tenant_pg_data(
        self, tenant_id: UUID, cryptographic_shred: bool = False
    ) -> int:
        """Delete all PG rows for tenant. Returns count deleted."""
        pool = await self._ensure_pg()
        total_deleted = 0

        async with pool.acquire() as conn:
            if cryptographic_shred:
                # Overwrite content with random bytes before delete
                random_bytes = os.urandom(32).hex()
                result = await conn.execute(f"""
                    UPDATE {MEMORY_TABLE}
                    SET content = $1, metadata = '{{}}'::jsonb
                    WHERE tenant_id = $2
                """, random_bytes, tenant_id)
                logger.info(
                    "PostgreSQL: cryptographic shred applied for tenant %s",
                    tenant_id,
                )

            # Delete memories
            result = await conn.execute(f"""
                DELETE FROM {MEMORY_TABLE} WHERE tenant_id = $1
            """, tenant_id)
            memories_deleted = int(result.split(" ")[-1]) if result else 0
            total_deleted += memories_deleted

            # Delete tenant record
            result = await conn.execute(f"""
                DELETE FROM {TENANT_TABLE} WHERE tenant_id = $1
            """, tenant_id)
            tenant_deleted = int(result.split(" ")[-1]) if result else 0
            total_deleted += tenant_deleted

            logger.info(
                "PostgreSQL: deleted %d memories + %d tenant record for %s",
                memories_deleted, tenant_deleted, tenant_id,
            )

        return total_deleted

    # ------------------------------------------------------------------
    # Story 6.04: Cryptographic Shredding
    # ------------------------------------------------------------------

    @staticmethod
    def _generate_encryption_key() -> bytes:
        """Generate a per-tenant encryption key (256-bit AES key).

        Each tenant's memories are encrypted with this key.
        To erase: delete the key -> all data becomes unreadable.
        """
        return os.urandom(32)

    @staticmethod
    def _hash_key(key: bytes) -> str:
        """SHA-256 hash of encryption key for audit trail storage."""
        return hashlib.sha256(key).hexdigest()

    @staticmethod
    def _encrypt_content(content: str, key: bytes) -> bytes:
        """Encrypt content using Fernet (AES-128-CBC under the hood).

        We use the cryptography library's Fernet for simplicity and safety.
        Key is base64-encoded 32 bytes -> Fernet key.
        """
        import base64
        from cryptography.fernet import Fernet

        # Fernet requires a 32-byte base64-encoded key
        fernet_key = base64.urlsafe_b64encode(key)
        f = Fernet(fernet_key)
        return f.encrypt(content.encode("utf-8"))

    @staticmethod
    def _decrypt_content(encrypted: bytes, key: bytes) -> str:
        """Decrypt content using Fernet."""
        import base64
        from cryptography.fernet import Fernet

        fernet_key = base64.urlsafe_b64encode(key)
        f = Fernet(fernet_key)
        return f.decrypt(encrypted).decode("utf-8")

    # ------------------------------------------------------------------
    # Story 6.02: create_tenant()
    # ------------------------------------------------------------------

    async def create_tenant(
        self, tenant_id: UUID, tier: CustomerTier,
    ) -> bool:
        """Provision storage for a new tenant.

        Actions:
          1. Generate per-tenant encryption key
          2. Create Redis keyspace prefix: rlm:{tenant_id}:*
          3. Verify Qdrant collection exists (create if not)
          4. Insert tenant record in PostgreSQL
          5. Apply RLS policy enabling tenant access to own rows only

        Returns True if created successfully.
        Idempotent: calling twice with same tenant_id returns True.
        """
        logger.info("Creating tenant %s (tier=%s)", tenant_id, tier.value)

        # 1. Generate encryption key and store hash
        enc_key = self._generate_encryption_key()
        key_hash = self._hash_key(enc_key)

        # 2. Store encryption key in Redis (separate from data)
        redis = await self._ensure_redis()
        enc_key_redis_key = self._tenant_key(tenant_id, "encryption_key")
        await redis.set(enc_key_redis_key, enc_key.hex())

        # 3. Mark tenant as initialized in Redis
        init_key = self._tenant_key(tenant_id, "initialized")
        await redis.set(init_key, "1")

        # 4. Ensure Qdrant collection + payload index
        await self._ensure_qdrant_collection()

        # 5. Insert PG tenant record (idempotent via ON CONFLICT)
        await self._create_tenant_pg_record(tenant_id, tier, key_hash)

        # 6. Apply RLS policies to memory tables
        await self._apply_rls_policy(MEMORY_TABLE)
        await self._apply_rls_policy(TENANT_TABLE)

        logger.info("Tenant %s provisioned successfully", tenant_id)
        return True

    # ------------------------------------------------------------------
    # Story 6.03: verify_isolation()
    # ------------------------------------------------------------------

    async def verify_isolation(
        self, tenant_a: UUID, tenant_b: UUID,
    ) -> bool:
        """Security test: verify tenant A cannot access tenant B's data.

        Writes test data as tenant A, attempts read as tenant B.
        Returns True if isolation holds (B sees nothing from A).
        Cleans up test data in finally block.
        """
        test_content = f"isolation_test_{os.urandom(8).hex()}"
        test_hash = hashlib.sha256(test_content.encode()).hexdigest()

        pool = await self._ensure_pg()
        redis = await self._ensure_redis()

        try:
            # --- PostgreSQL isolation test ---
            async with pool.acquire() as conn:
                # Insert as tenant A
                await conn.execute(f"""
                    INSERT INTO {MEMORY_TABLE}
                    (tenant_id, content, content_hash, source, domain, memory_tier)
                    VALUES ($1, $2, $3, 'test', 'isolation', 'working')
                """, tenant_a, test_content, test_hash)

                # Query as tenant B via RLS
                async with conn.transaction():
                    await self._set_tenant_context(conn, tenant_b)
                    rows = await conn.fetch(f"""
                        SELECT * FROM {MEMORY_TABLE}
                        WHERE content_hash = $1
                    """, test_hash)

                    if len(rows) > 0:
                        logger.error(
                            "ISOLATION BREACH: tenant %s saw tenant %s data in PG",
                            tenant_b, tenant_a,
                        )
                        return False

            # --- Qdrant isolation test ---
            from qdrant_client.models import (
                Filter, FieldCondition, MatchValue, PointStruct
            )
            qdrant = await self._ensure_qdrant()
            test_vector = [0.0] * QDRANT_VECTOR_SIZE
            test_point_id = os.urandom(16).hex()

            await qdrant.upsert(
                collection_name=QDRANT_COLLECTION,
                points=[
                    PointStruct(
                        id=test_point_id,
                        vector=test_vector,
                        payload={
                            "tenant_id": str(tenant_a),
                            "content": test_content,
                            "test": True,
                        },
                    )
                ],
            )

            # Search as tenant B
            results = await qdrant.scroll(
                collection_name=QDRANT_COLLECTION,
                scroll_filter=Filter(
                    must=[
                        FieldCondition(
                            key="tenant_id",
                            match=MatchValue(value=str(tenant_b)),
                        ),
                        FieldCondition(
                            key="content",
                            match=MatchValue(value=test_content),
                        ),
                    ]
                ),
                limit=10,
            )

            if results[0]:  # points list
                logger.error(
                    "ISOLATION BREACH: tenant %s saw tenant %s data in Qdrant",
                    tenant_b, tenant_a,
                )
                return False

            # --- Redis isolation test ---
            test_redis_key = self._tenant_key(tenant_a, "isolation_test")
            await redis.set(test_redis_key, test_content)

            # Read as tenant B
            wrong_key = self._tenant_key(tenant_b, "isolation_test")
            value = await redis.get(wrong_key)

            if value is not None:
                logger.error(
                    "ISOLATION BREACH: tenant %s saw tenant %s data in Redis",
                    tenant_b, tenant_a,
                )
                return False

            logger.info(
                "Isolation verified: tenant %s and %s are fully isolated",
                tenant_a, tenant_b,
            )
            return True

        finally:
            # Cleanup test data
            try:
                async with pool.acquire() as conn:
                    await conn.execute(f"""
                        DELETE FROM {MEMORY_TABLE}
                        WHERE content_hash = $1 AND tenant_id = $2
                    """, test_hash, tenant_a)

                # Cleanup Qdrant test point
                try:
                    qdrant_client = await self._ensure_qdrant()
                    await qdrant_client.delete(
                        collection_name=QDRANT_COLLECTION,
                        points_selector=[test_point_id],
                    )
                except Exception:
                    pass

                # Cleanup Redis test keys
                test_redis_key = self._tenant_key(tenant_a, "isolation_test")
                await redis.delete(test_redis_key)
            except Exception as cleanup_err:
                logger.warning("Cleanup error: %s", cleanup_err)

    # ------------------------------------------------------------------
    # Story 6.04: delete_tenant_data()
    # ------------------------------------------------------------------

    async def delete_tenant_data(
        self, tenant_id: UUID, cryptographic_shred: bool = False,
    ) -> Dict[str, int]:
        """Delete all data for a tenant across all backends.

        If cryptographic_shred=True: overwrites content with random bytes
        before delete, then deletes the encryption key making all encrypted
        data permanently unreadable.

        Returns: {"pg_deleted": N, "qdrant_deleted": N, "redis_deleted": N}
        """
        logger.info(
            "Deleting tenant %s data (shred=%s)", tenant_id, cryptographic_shred
        )

        # 1. PostgreSQL: optionally shred then delete
        pg_deleted = await self._delete_tenant_pg_data(
            tenant_id, cryptographic_shred
        )

        # 2. Qdrant: delete all vectors
        qdrant_deleted = await self._delete_tenant_qdrant_vectors(tenant_id)

        # 3. Redis: delete all keys (including encryption key)
        redis_deleted = await self._delete_tenant_redis_keys(tenant_id)

        result = {
            "pg_deleted": pg_deleted,
            "qdrant_deleted": qdrant_deleted,
            "redis_deleted": redis_deleted,
        }

        logger.info("Tenant %s deletion complete: %s", tenant_id, result)
        return result

    # ------------------------------------------------------------------
    # Utility: tenant existence check
    # ------------------------------------------------------------------

    async def tenant_exists(self, tenant_id: UUID) -> bool:
        """Check if a tenant is provisioned."""
        pool = await self._ensure_pg()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(f"""
                SELECT 1 FROM {TENANT_TABLE} WHERE tenant_id = $1
            """, tenant_id)
            return row is not None

    # ------------------------------------------------------------------
    # Utility: get encryption key
    # ------------------------------------------------------------------

    async def get_encryption_key(self, tenant_id: UUID) -> Optional[bytes]:
        """Retrieve per-tenant encryption key from Redis.

        Returns None if tenant doesn't exist or key was shredded.
        Key is stored separately from data for cryptographic shredding.
        """
        redis = await self._ensure_redis()
        key_hex = await redis.get(
            self._tenant_key(tenant_id, "encryption_key")
        )
        if key_hex is None:
            return None
        return bytes.fromhex(key_hex)

    # ------------------------------------------------------------------
    # Close connections
    # ------------------------------------------------------------------

    async def close(self) -> None:
        """Close all backend connections."""
        if self._pg_pool is not None:
            await self._pg_pool.close()
            self._pg_pool = None
        if self._qdrant_client is not None:
            await self._qdrant_client.close()
            self._qdrant_client = None
        if self._redis_client is not None:
            await self._redis_client.aclose()
            self._redis_client = None
        logger.info("TenantPartitioner connections closed")
