"""
GHL OAuth Flow Handler

Implements the Authorization Code flow:
1. Generate auth URL → redirect user to GHL
2. Receive callback with auth code
3. Exchange auth code for access_token + refresh_token
4. Store in TokenVault
5. Exchange the agency token for per-location tokens (a separate step, run after the callback)

This can run as a standalone FastAPI server or be mounted
into the existing Sunaiva API.
"""

import os
import sys
import json
import secrets
import logging
from datetime import datetime, timezone

import requests

sys.path.append("/mnt/e/genesis-system/data/genesis-memory")
sys.path.append("/mnt/e/genesis-system")

from GHL.oauth.config import GHLOAuthConfig
from GHL.oauth.token_vault import TokenVault

# Module-level logger used by OAuthFlowManager to report token-exchange results.
logger = logging.getLogger(__name__)


class OAuthFlowManager:
    """Manage the GHL OAuth authorization-code flow.

    * ``initiate`` builds the authorization URL and records a one-time CSRF
      ``state`` token.
    * ``handle_callback`` validates (and consumes) that state, exchanges the
      authorization code for tokens and stores them in the ``TokenVault``.
    * ``exchange_all_locations`` uses the stored agency token to obtain a
      token for every location returned by the locations search.

    Issued state tokens are persisted in a small JSON file (``state_file``) so
    they survive between the initiate request and the callback request.
    Expected failures (bad state, HTTP/network errors, malformed responses)
    are reported by returning a dict with ``"status": "error"`` instead of
    raising.
    """

    # Default path of the JSON file that holds issued OAuth state tokens.
    STATE_FILE = "/mnt/e/genesis-system/GHL/oauth/.oauth_states.json"
    # How long (in seconds) an issued state token stays acceptable: 10 minutes.
    STATE_TTL_SECONDS = 600

    def __init__(
        self,
        config: "GHLOAuthConfig | None" = None,
        vault: "TokenVault | None" = None,
        state_file: "str | None" = None,
    ):
        """Create a flow manager.

        Args:
            config: OAuth configuration; ``GHLOAuthConfig()`` when omitted.
            vault: Token storage; ``TokenVault(config)`` when omitted.
            state_file: Path of the JSON file used to persist CSRF state
                tokens; defaults to ``STATE_FILE``.
        """
        self.config = config or GHLOAuthConfig()
        self.vault = vault or TokenVault(self.config)
        self.state_file = state_file or self.STATE_FILE

    def initiate(self) -> dict:
        """Start the OAuth flow.

        Returns:
            ``{"auth_url", "state", "instructions"}`` on success, or
            ``{"error", "status": "not_configured"}`` when the client
            credentials are missing.
        """
        if not self.config.is_configured:
            return {
                "error": "OAuth not configured. Set GHL_OAUTH_CLIENT_ID and GHL_OAUTH_CLIENT_SECRET.",
                "status": "not_configured",
            }

        state = secrets.token_urlsafe(32)
        auth_url = self.config.get_auth_url(state)

        # Persist the state (JSON file at self.state_file) so the callback
        # can verify that it originated from this flow (CSRF protection).
        self._store_state(state)

        return {
            "auth_url": auth_url,
            "state": state,
            "instructions": "Open this URL in a browser. Log in to GHL and approve the app.",
        }

    def handle_callback(self, code: str, state: str) -> dict:
        """Handle the OAuth redirect: validate ``state`` and exchange ``code``.

        Args:
            code: Authorization code from the callback query string.
            state: State token from the callback. It must match an unused,
                unexpired token issued by ``initiate``, and is consumed here.

        Returns:
            On success, a dict with ``status="success"``, the vault
            ``token_id``, ``user_type``, ``company_id``, ``location_id``,
            ``expires_in`` and ``scope_count``. Otherwise, a dict with
            ``status="error"`` and an ``error`` message.
        """
        # Validate (and consume) the CSRF state before using the code at all.
        if not self._validate_state(state):
            return {"error": "Invalid state parameter — possible CSRF attack", "status": "error"}

        if not code:
            return {"error": "Missing authorization code", "status": "error"}

        try:
            resp = requests.post(
                self.config.token_url,
                data={
                    "client_id": self.config.client_id,
                    "client_secret": self.config.client_secret,
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": self.config.redirect_uri,
                },
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=30,
            )
        except requests.RequestException as exc:
            logger.error("Token exchange request failed: %s", exc)
            return {"error": f"Token exchange request failed: {exc}", "status": "error"}

        if resp.status_code != 200:
            logger.error("Token exchange failed: %s %s", resp.status_code, resp.text)
            return {
                "error": f"Token exchange failed: {resp.status_code}",
                "details": resp.text,
                "status": "error",
            }

        try:
            data = resp.json()
        except ValueError:
            logger.error("Token exchange returned a non-JSON body: %s", resp.text)
            return {"error": "Token exchange returned an invalid JSON body", "status": "error"}
        if not isinstance(data, dict):
            return {"error": "Unexpected token response format", "status": "error"}

        access_token = data.get("access_token")
        if not access_token:
            return {"error": "No access_token in response", "status": "error"}

        refresh_token = data.get("refresh_token")
        expires_in = data.get("expires_in", self.config.access_token_ttl_seconds)
        # split() without an argument drops empty items caused by repeated spaces.
        scopes = (data.get("scope") or "").split()
        user_type = data.get("userType", "unknown")  # "Company" or "Location"
        location_id = data.get("locationId")
        company_id = data.get("companyId")

        # NOTE(review): the token is stored as the agency token regardless of
        # user_type — confirm that a "Location"-type grant belongs here too.
        token_id = self.vault.store_agency_token(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=expires_in,
            scopes=scopes,
        )

        logger.info("OAuth flow complete: %s token stored (id=%s)", user_type, token_id)
        return {
            "status": "success",
            "token_id": token_id,
            "user_type": user_type,
            "company_id": company_id,
            "location_id": location_id,
            "expires_in": expires_in,
            "scope_count": len(scopes),
        }

    def exchange_all_locations(self) -> dict:
        """After initial auth, exchange the agency token for every location token.

        Returns:
            ``status="success"`` with ``locations_total``,
            ``locations_exchanged`` and a per-location ``results`` mapping
            (keyed by location id), or ``status="error"`` with an ``error``
            message.
        """
        agency = self.vault.get_agency_token()
        if not agency:
            return {"error": "No agency token. Run OAuth flow first.", "status": "error"}

        # NOTE(review): a single search request is made; if the endpoint
        # paginates (limit/skip), later pages are never exchanged — confirm.
        try:
            resp = requests.get(
                f"{self.config.api_base_url}/locations/search",
                headers={
                    "Authorization": f"Bearer {agency.access_token}",
                    "Version": self.config.api_version,
                },
                params={"companyId": self.config.company_id},
                timeout=30,
            )
        except requests.RequestException as exc:
            logger.error("Location search request failed: %s", exc)
            return {"error": f"Failed to list locations: {exc}", "status": "error"}

        if resp.status_code != 200:
            return {"error": f"Failed to list locations: {resp.status_code}", "status": "error"}

        try:
            payload = resp.json()
        except ValueError:
            return {"error": "Failed to list locations: invalid JSON response", "status": "error"}
        locations = (payload.get("locations") or []) if isinstance(payload, dict) else []

        results = {}
        for loc in locations:
            loc_id = loc.get("id")
            if not loc_id:
                # Nothing to exchange without an id; never key results on None.
                logger.warning("Skipping location without id: %r", loc)
                continue
            loc_name = loc.get("name", "")
            token = self.vault.exchange_location_token(loc_id, loc_name)
            results[loc_id] = {
                "name": loc_name,
                "success": token is not None,
                "expires_at": token.expires_at.isoformat() if token else None,
            }

        return {
            "status": "success",
            "locations_total": len(locations),
            "locations_exchanged": sum(1 for r in results.values() if r["success"]),
            "results": results,
        }

    # ---- CSRF state persistence -------------------------------------------

    def _load_states(self) -> dict:
        """Return the persisted state map; {} if the file is missing or unreadable."""
        try:
            with open(self.state_file, encoding="utf-8") as f:
                states = json.load(f)
        except (FileNotFoundError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return {}
        return states if isinstance(states, dict) else {}

    def _save_states(self, states: dict) -> None:
        """Write the state map atomically (temp file + rename) to avoid corruption."""
        tmp_path = f"{self.state_file}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(states, f)
        os.replace(tmp_path, self.state_file)

    def _is_state_fresh(self, entry, now: datetime) -> bool:
        """Return True if *entry* has a parseable created_at younger than the TTL."""
        try:
            created = datetime.fromisoformat(entry["created_at"])
        except (KeyError, TypeError, ValueError):
            # Malformed entries are treated as expired.
            return False
        return now.timestamp() - created.timestamp() < self.STATE_TTL_SECONDS

    def _store_state(self, state: str) -> None:
        """Record a new unused OAuth state, pruning expired or malformed entries."""
        now = datetime.now(timezone.utc)
        states = {
            k: v for k, v in self._load_states().items()
            if self._is_state_fresh(v, now)
        }
        states[state] = {
            "created_at": now.isoformat(),
            "used": False,
        }
        self._save_states(states)

    def _validate_state(self, state: str) -> bool:
        """Validate and consume an OAuth state token.

        Returns True only for a known, unused state issued within
        ``STATE_TTL_SECONDS``. On success, the state is marked used so it
        cannot be replayed.
        """
        if not state:
            return False

        states = self._load_states()
        entry = states.get(state)
        if not isinstance(entry, dict) or entry.get("used"):
            return False
        # Enforce expiry here too: pruning only happens on the next _store_state.
        if not self._is_state_fresh(entry, datetime.now(timezone.utc)):
            return False

        entry["used"] = True
        self._save_states(states)
        return True
