"""
GHL OAuth Token Vault

Handles storage, retrieval, and lifecycle management of OAuth tokens (token values are written to PostgreSQL as received; this module performs no encryption itself).
Uses PostgreSQL for persistence and Redis for distributed locking during refresh.

Token types:
- 'agency': The main OAuth token from app installation. Has refresh_token.
- 'location': Per-sub-account tokens exchanged via /oauth/locationToken. No refresh_token.
"""

import sys
import logging
from datetime import datetime, timezone, timedelta
from typing import Optional, Dict, Any
from dataclasses import dataclass

import psycopg2
import psycopg2.extras
import requests

sys.path.append("/mnt/e/genesis-system/data/genesis-memory")
from elestio_config import PostgresConfig, RedisConfig
from GHL.oauth.config import GHLOAuthConfig

logger = logging.getLogger(__name__)

# TCP keepalive options merged into every Elestio PostgreSQL connection.
# Added per failure_010: without them the connection flapped roughly every
# 30 minutes. Values are libpq keepalive settings (idle/interval in seconds).
PG_KEEPALIVE_PARAMS = dict(
    keepalives=1,
    keepalives_idle=30,
    keepalives_interval=10,
    keepalives_count=5,
)


@dataclass
class TokenRecord:
    """In-memory view of one ``ghl_oauth_tokens`` row.

    ``expires_at`` is compared against an aware UTC "now", so it must be a
    timezone-aware datetime (comparing naive and aware datetimes raises
    ``TypeError``).
    """

    id: str
    token_type: str
    location_id: Optional[str]
    location_name: Optional[str]
    access_token: str
    refresh_token: Optional[str]
    scopes: list
    expires_at: datetime
    last_used_at: Optional[datetime]
    revoked_at: Optional[datetime]

    @property
    def is_expired(self) -> bool:
        """True once the current UTC time has reached ``expires_at``."""
        now_utc = datetime.now(timezone.utc)
        return now_utc >= self.expires_at

    @property
    def needs_refresh(self) -> bool:
        """True once we are within ``refresh_buffer_seconds`` of expiry."""
        lead_time = timedelta(seconds=GHLOAuthConfig.refresh_buffer_seconds)
        now_utc = datetime.now(timezone.utc)
        return now_utc >= self.expires_at - lead_time

    @property
    def is_valid(self) -> bool:
        """True when the token is neither expired nor revoked."""
        if self.is_expired:
            return False
        return self.revoked_at is None


class TokenVault:
    """Manages GHL OAuth tokens in PostgreSQL with Redis-based distributed locking.

    Tokens are stored in ``ghl_oauth_tokens``. Writes and failed refresh/exchange
    attempts are recorded in ``ghl_oauth_audit``. One autocommit PostgreSQL
    connection is cached per instance and liveness-checked before reuse.
    """

    # DDL file executed by init_schema() when no path is given.
    DEFAULT_SCHEMA_PATH = "/mnt/e/genesis-system/GHL/oauth/db_schema.sql"
    # Redis key guarding the agency-token refresh across processes.
    _AGENCY_REFRESH_LOCK_KEY = "ghl_oauth:refresh_lock:agency"
    # Timeout (seconds) passed to every token-endpoint HTTP call.
    _HTTP_TIMEOUT_SECONDS = 30
    # The lock must outlive the HTTP call plus the DB writes after it. If it
    # expired mid-refresh, a second process could start a concurrent refresh
    # with the same (rotating) refresh_token.
    _REFRESH_LOCK_TTL_SECONDS = 60

    def __init__(self, config: Optional[GHLOAuthConfig] = None):
        self.config = config or GHLOAuthConfig()
        self._conn = None

    def _connect(self):
        """Open a new autocommit PostgreSQL connection with TCP keepalives."""
        # Build a fresh dict so the params returned by PostgresConfig are never mutated.
        params = {**PostgresConfig.get_connection_params(), **PG_KEEPALIVE_PARAMS}
        conn = psycopg2.connect(**params)
        conn.autocommit = True
        return conn

    def _get_conn(self):
        """Return a live PostgreSQL connection, reconnecting if needed.

        A newly opened connection is returned directly. A cached one is probed
        with ``SELECT 1`` and, if the probe fails, closed and replaced.
        """
        if self._conn is None or self._conn.closed:
            self._conn = self._connect()
            return self._conn

        try:
            with self._conn.cursor() as cur:
                cur.execute("SELECT 1")
        except psycopg2.Error:
            logger.warning("PostgreSQL liveness check failed, reconnecting")
            stale, self._conn = self._conn, None
            try:
                stale.close()
            except psycopg2.Error:
                # Best-effort cleanup: the connection is already broken; we
                # only want to free its socket if the driver still allows it.
                pass
            self._conn = self._connect()
        return self._conn

    def _get_redis(self):
        """Get a Redis connection for distributed locking."""
        import redis
        return redis.Redis(**RedisConfig.get_connection_params())

    def init_schema(self, schema_path: Optional[str] = None):
        """Execute the schema DDL file against PostgreSQL.

        Args:
            schema_path: Path to the SQL file. Defaults to DEFAULT_SCHEMA_PATH.
                Whether re-running is safe depends on the DDL itself
                (expected to use CREATE ... IF NOT EXISTS).
        """
        path = schema_path or self.DEFAULT_SCHEMA_PATH
        with open(path, encoding="utf-8") as f:
            sql = f.read()
        conn = self._get_conn()
        with conn.cursor() as cur:
            cur.execute(sql)
        logger.info("GHL OAuth schema initialized")

    def store_agency_token(
        self,
        access_token: str,
        refresh_token: str,
        expires_in: int,
        scopes: list,
    ) -> str:
        """Store the agency-level OAuth token from app installation or refresh.

        Args:
            access_token: Agency access token.
            refresh_token: Agency refresh token (GHL rotates it on each refresh).
            expires_in: Lifetime in seconds, counted from now (UTC).
            scopes: Granted scope strings.

        Returns:
            The id of the active agency token row, as a string.
        """
        expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        conn = self._get_conn()
        with conn.cursor() as cur:
            # Upsert — only one agency token
            cur.execute("""
                INSERT INTO ghl_oauth_tokens
                    (token_type, location_id, access_token, refresh_token, scopes, expires_at)
                VALUES ('agency', NULL, %s, %s, %s, %s)
                ON CONFLICT (token_type, location_id)
                DO UPDATE SET
                    access_token = EXCLUDED.access_token,
                    refresh_token = EXCLUDED.refresh_token,
                    scopes = EXCLUDED.scopes,
                    expires_at = EXCLUDED.expires_at,
                    last_refreshed_at = NOW(),
                    revoked_at = NULL
                RETURNING id
            """, (access_token, refresh_token, scopes, expires_at))
            token_id = str(cur.fetchone()[0])

            # A plain UNIQUE constraint treats NULLs as distinct, so the
            # ON CONFLICT above cannot match an existing agency row
            # (location_id IS NULL) unless the constraint is declared
            # NULLS NOT DISTINCT. Retire any older agency rows explicitly so
            # only the newest one stays active. This is a no-op when the
            # upsert updated the existing row in place.
            cur.execute("""
                UPDATE ghl_oauth_tokens
                SET revoked_at = NOW()
                WHERE token_type = 'agency'
                  AND revoked_at IS NULL
                  AND id <> %s
            """, (token_id,))

            # Audit
            cur.execute("""
                INSERT INTO ghl_oauth_audit (token_id, operation, details)
                VALUES (%s, 'created', %s)
            """, (token_id, psycopg2.extras.Json({
                "type": "agency",
                "expires_at": expires_at.isoformat(),
                "scope_count": len(scopes),
            })))

        logger.info("Agency token stored, expires %s", expires_at.isoformat())
        return token_id

    def store_location_token(
        self,
        location_id: str,
        location_name: str,
        access_token: str,
        expires_in: int,
        scopes: list,
    ) -> str:
        """Store a per-location token obtained from the location-token exchange.

        An empty or None ``location_name`` does not overwrite a name already
        stored for this location.

        Returns:
            The id of the location token row, as a string.
        """
        expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        conn = self._get_conn()
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO ghl_oauth_tokens
                    (token_type, location_id, location_name, access_token, scopes, expires_at)
                VALUES ('location', %s, %s, %s, %s, %s)
                ON CONFLICT (token_type, location_id)
                DO UPDATE SET
                    access_token = EXCLUDED.access_token,
                    location_name = COALESCE(
                        NULLIF(EXCLUDED.location_name, ''),
                        ghl_oauth_tokens.location_name
                    ),
                    scopes = EXCLUDED.scopes,
                    expires_at = EXCLUDED.expires_at,
                    last_refreshed_at = NOW(),
                    revoked_at = NULL
                RETURNING id
            """, (location_id, location_name, access_token, scopes, expires_at))
            token_id = str(cur.fetchone()[0])

            cur.execute("""
                INSERT INTO ghl_oauth_audit (token_id, operation, location_id, details)
                VALUES (%s, 'exchanged', %s, %s)
            """, (token_id, location_id, psycopg2.extras.Json({
                "location_name": location_name,
                "expires_at": expires_at.isoformat(),
            })))

        logger.info("Location token stored for %s (%s)", location_id, location_name)
        return token_id

    def get_agency_token(self) -> Optional[TokenRecord]:
        """Return the newest non-revoked agency token, or None if there is none."""
        conn = self._get_conn()
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute("""
                SELECT * FROM ghl_oauth_tokens
                WHERE token_type = 'agency' AND revoked_at IS NULL
                ORDER BY created_at DESC LIMIT 1
            """)
            row = cur.fetchone()
            if not row:
                return None
            return self._row_to_record(row)

    def get_location_token(self, location_id: str) -> Optional[TokenRecord]:
        """Return the newest non-revoked token for a location (sub-account), or None."""
        conn = self._get_conn()
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute("""
                SELECT * FROM ghl_oauth_tokens
                WHERE token_type = 'location'
                  AND location_id = %s
                  AND revoked_at IS NULL
                ORDER BY created_at DESC LIMIT 1
            """, (location_id,))
            row = cur.fetchone()
            if not row:
                return None
            return self._row_to_record(row)

    def get_all_location_tokens(self) -> list:
        """Return all non-revoked location tokens ordered by location name."""
        conn = self._get_conn()
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute("""
                SELECT * FROM ghl_oauth_tokens
                WHERE token_type = 'location' AND revoked_at IS NULL
                ORDER BY location_name
            """)
            return [self._row_to_record(row) for row in cur.fetchall()]

    def refresh_agency_token(self) -> Optional[TokenRecord]:
        """Refresh the agency token via its refresh_token under a Redis lock.

        Returns:
            The refreshed (or already-fresh) agency token. If another process
            holds the lock, returns the currently stored token. Returns None
            when no refreshable token exists or the refresh request fails;
            failures are logged and audited.
        """
        r = self._get_redis()
        # redis-py's Lock stores a unique token as the key's value and releases
        # with a compare-and-delete. That way we never delete a lock another
        # process acquired after ours expired.
        lock = r.lock(self._AGENCY_REFRESH_LOCK_KEY, timeout=self._REFRESH_LOCK_TTL_SECONDS)
        if not lock.acquire(blocking=False):
            logger.info("Another process is refreshing the agency token, skipping")
            return self.get_agency_token()

        try:
            agency = self.get_agency_token()
            if not agency or not agency.refresh_token:
                logger.error("No agency token or refresh_token available")
                return None

            # Re-check after acquiring lock (another process may have refreshed)
            if not agency.needs_refresh:
                return agency

            try:
                resp = requests.post(
                    self.config.token_url,
                    data={
                        "client_id": self.config.client_id,
                        "client_secret": self.config.client_secret,
                        "grant_type": "refresh_token",
                        "refresh_token": agency.refresh_token,
                        "user_type": "Company",
                    },
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                    timeout=self._HTTP_TIMEOUT_SECONDS,
                )
            except requests.RequestException as exc:
                # Treat network failures like HTTP failures: log, audit, return None.
                logger.error("Agency token refresh request failed: %s", exc)
                self._audit("failed", None, {"error": str(exc), "operation": "refresh_agency"})
                return None

            if resp.status_code != 200:
                logger.error("Agency token refresh failed: %s %s", resp.status_code, resp.text)
                self._audit("failed", None, {"error": resp.text, "operation": "refresh_agency"})
                return None

            data = resp.json()
            # CRITICAL: GHL uses token rotation — must update refresh_token
            new_refresh = data.get("refresh_token", agency.refresh_token)
            self.store_agency_token(
                access_token=data["access_token"],
                refresh_token=new_refresh,
                expires_in=data.get("expires_in", self.config.access_token_ttl_seconds),
                scopes=self._parse_scopes(data.get("scope"), agency.scopes),
            )
            logger.info("Agency token refreshed successfully")
            return self.get_agency_token()

        finally:
            self._release_lock(lock)

    @staticmethod
    def _release_lock(lock) -> None:
        """Release a redis-py Lock, tolerating one that has already expired."""
        from redis.exceptions import LockError
        try:
            lock.release()
        except LockError:
            # The TTL elapsed before we finished. The key is gone or now owned
            # by another process, and deleting it would break their lock.
            logger.warning("Agency refresh lock expired before it could be released")

    def exchange_location_token(self, location_id: str, location_name: str = "") -> Optional[TokenRecord]:
        """Exchange the agency token for a location-specific access token.

        Refreshes the agency token first if it is inside its refresh buffer.

        Returns:
            The stored location token, or None if no agency token is available
            or the exchange fails (failures are logged and audited).
        """
        agency = self.get_agency_token()
        if not agency:
            logger.error("No agency token available for location exchange")
            return None

        # Refresh agency token if needed
        if agency.needs_refresh:
            agency = self.refresh_agency_token()
            if not agency:
                return None

        try:
            resp = requests.post(
                self.config.location_token_url,
                json={
                    "companyId": self.config.company_id,
                    "locationId": location_id,
                },
                headers={
                    "Authorization": f"Bearer {agency.access_token}",
                    "Version": self.config.api_version,
                    "Content-Type": "application/json",
                },
                timeout=self._HTTP_TIMEOUT_SECONDS,
            )
        except requests.RequestException as exc:
            logger.error("Location token exchange request failed for %s: %s", location_id, exc)
            self._audit("failed", location_id, {"error": str(exc), "operation": "exchange_location"})
            return None

        if resp.status_code != 200:
            logger.error(
                "Location token exchange failed for %s: %s %s",
                location_id, resp.status_code, resp.text,
            )
            self._audit("failed", location_id, {"error": resp.text, "operation": "exchange_location"})
            return None

        data = resp.json()
        self.store_location_token(
            location_id=location_id,
            location_name=location_name,
            access_token=data["access_token"],
            expires_in=data.get("expires_in", self.config.access_token_ttl_seconds),
            scopes=self._parse_scopes(data.get("scope"), []),
        )

        logger.info("Location token exchanged for %s (%s)", location_id, location_name)
        return self.get_location_token(location_id)

    def get_valid_token(self, location_id: str) -> Optional[str]:
        """Return a usable access token for a location, re-exchanging if needed.

        If re-exchange fails but the stored token has not actually expired
        (it was only inside the refresh buffer), that token is returned.

        Returns:
            The access token string, or None if none is usable.
        """
        token = self.get_location_token(location_id)

        if token and token.is_valid and not token.needs_refresh:
            self._mark_used(token.id)
            return token.access_token

        # Missing, expired, or inside the refresh buffer: re-exchange from the
        # agency token, reusing the stored location name when we have one.
        name = token.location_name if token else ""
        new_token = self.exchange_location_token(location_id, name)
        if new_token and new_token.is_valid:
            self._mark_used(new_token.id)
            return new_token.access_token

        # Re-exchange failed, but the old token has not expired yet, so keep
        # serving it rather than failing the caller.
        if token and token.is_valid:
            logger.warning(
                "Re-exchange failed for %s; using existing token until it expires", location_id
            )
            self._mark_used(token.id)
            return token.access_token

        return None

    def exchange_all_locations(self, locations: list) -> Dict[str, bool]:
        """Exchange tokens for all locations.

        Args:
            locations: Dicts carrying an id under ``id`` or ``location_id`` and
                an optional name under ``name`` or ``location_name``. Entries
                without an id are skipped with a warning.

        Returns:
            Mapping of location_id to whether the exchange succeeded.
        """
        results: Dict[str, bool] = {}
        for loc in locations:
            loc_id = loc.get("id") or loc.get("location_id")
            if not loc_id:
                logger.warning("Skipping location entry without an id: %r", loc)
                continue
            loc_name = loc.get("name") or loc.get("location_name", "")
            results[loc_id] = self.exchange_location_token(loc_id, loc_name) is not None
        return results

    def status(self) -> Dict[str, Any]:
        """Summarise all non-revoked tokens with a valid/expired status (per DB NOW())."""
        conn = self._get_conn()
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute("""
                SELECT token_type, location_id, location_name,
                       expires_at, last_used_at, revoked_at,
                       CASE WHEN expires_at > NOW() THEN 'valid'
                            ELSE 'expired' END as status
                FROM ghl_oauth_tokens
                WHERE revoked_at IS NULL
                ORDER BY token_type, location_name
            """)
            tokens = [
                {
                    "type": row["token_type"],
                    "location_id": row["location_id"],
                    "location_name": row["location_name"],
                    "expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
                    "last_used": row["last_used_at"].isoformat() if row["last_used_at"] else None,
                    "status": row["status"],
                }
                for row in cur.fetchall()
            ]
        return {"tokens": tokens, "total": len(tokens)}

    def _mark_used(self, token_id: str):
        """Set last_used_at = NOW() on the given token row."""
        conn = self._get_conn()
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE ghl_oauth_tokens SET last_used_at = NOW() WHERE id = %s",
                (token_id,)
            )

    def _audit(self, operation: str, location_id: Optional[str], details: dict):
        """Insert a row into ghl_oauth_audit with JSON ``details`` (no token_id)."""
        conn = self._get_conn()
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO ghl_oauth_audit (operation, location_id, details)
                VALUES (%s, %s, %s)
            """, (operation, location_id, psycopg2.extras.Json(details)))

    @staticmethod
    def _parse_scopes(raw_scope: Optional[str], fallback: list) -> list:
        """Split a whitespace-delimited OAuth ``scope`` string; use fallback if empty/missing."""
        # str.split() with no argument collapses repeated whitespace, so no
        # empty-string scopes are produced.
        return raw_scope.split() if raw_scope else fallback

    @staticmethod
    def _row_to_record(row) -> TokenRecord:
        """Convert a DictCursor row from ghl_oauth_tokens into a TokenRecord."""
        return TokenRecord(
            id=str(row["id"]),
            token_type=row["token_type"],
            location_id=row["location_id"],
            location_name=row["location_name"],
            access_token=row["access_token"],
            refresh_token=row["refresh_token"],
            scopes=row["scopes"] or [],
            expires_at=row["expires_at"],
            last_used_at=row["last_used_at"],
            revoked_at=row["revoked_at"],
        )
