import asyncio
import time
from typing import Dict, Optional, Tuple

import google.generativeai as genai
import psycopg2
import qdrant_client
import redis
from fastapi import FastAPI, Depends
from pydantic import BaseModel

# ASGI application object; the /health route below is registered on it.
app = FastAPI()

class HealthStatus(BaseModel):
    """Response body returned by the ``/health`` endpoint."""

    # Overall verdict: "healthy", "degraded" or "unhealthy".
    status: str
    # Per-component verdict, keyed by component name (e.g. "redis").
    components: Dict[str, str]
    # Per-component probe latency in seconds. A probe that fails reports no
    # latency, so None must be an accepted value; otherwise building the
    # response raises a validation error exactly when something is down.
    latencies: Dict[str, Optional[float]]

def check_redis_connection(redis_host: str, redis_port: int) -> Tuple[str, Optional[float]]:
    """Probe Redis with a PING and report its status and latency.

    Args:
        redis_host: Host name or address of the Redis server.
        redis_port: TCP port of the Redis server.

    Returns:
        A ``(status, latency)`` pair. ``status`` is ``"healthy"`` when the
        PING succeeds, ``"unhealthy"`` on a Redis connection error and
        ``"degraded"`` on any other exception. ``latency`` is the elapsed
        time in seconds for a healthy probe and ``None`` otherwise.
    """
    try:
        # perf_counter is monotonic, so the measurement is immune to
        # wall-clock adjustments that can skew time.time() differences.
        started = time.perf_counter()
        client = redis.Redis(
            host=redis_host,
            port=redis_port,
            socket_connect_timeout=5,
            # Bound the PING round-trip too; without a read timeout a server
            # that accepts the connection but never answers hangs the probe.
            socket_timeout=5,
        )
        try:
            client.ping()
            latency = time.perf_counter() - started
        finally:
            # Release the pooled socket; a new client is built on every call.
            client.close()
        return "healthy", latency
    except redis.exceptions.ConnectionError:
        return "unhealthy", None
    except Exception as e:
        print(f"Redis check failed: {e}")
        return "degraded", None

def check_postgres_connection(postgres_host: str, postgres_port: int, postgres_db: str, postgres_user: str, postgres_password: str) -> Tuple[str, Optional[float]]:
    """Probe PostgreSQL by connecting and running ``SELECT 1``.

    Args:
        postgres_host: Host name or address of the PostgreSQL server.
        postgres_port: TCP port of the PostgreSQL server.
        postgres_db: Database name to connect to.
        postgres_user: User name used for the connection.
        postgres_password: Password used for the connection.

    Returns:
        A ``(status, latency)`` pair. ``status`` is ``"healthy"`` when the
        query returns ``(1,)``, ``"degraded"`` when it returns anything else
        or a non-psycopg2 exception occurs, and ``"unhealthy"`` on a
        ``psycopg2.Error``. ``latency`` is the elapsed time in seconds for a
        healthy probe and ``None`` otherwise.
    """
    try:
        # perf_counter is monotonic, unlike time.time().
        started = time.perf_counter()
        conn = psycopg2.connect(
            host=postgres_host,
            port=postgres_port,
            database=postgres_db,
            user=postgres_user,
            password=postgres_password,
            connect_timeout=5,  # Connection timeout in seconds
        )
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
                row = cur.fetchone()
            latency = time.perf_counter() - started
        finally:
            # Always close: previously a failing query leaked the connection.
            conn.close()
        if row == (1,):
            return "healthy", latency
        return "degraded", None
    except psycopg2.Error as e:
        print(f"PostgreSQL check failed: {e}")
        return "unhealthy", None
    except Exception as e:
        print(f"PostgreSQL check failed: {e}")
        return "degraded", None

def check_qdrant_connection(qdrant_host: str, qdrant_port: int) -> Tuple[str, Optional[float]]:
    """Probe Qdrant by issuing a request through a fresh client.

    Args:
        qdrant_host: Host name or address of the Qdrant server.
        qdrant_port: Port of the Qdrant server.

    Returns:
        A ``(status, latency)`` pair: ``("healthy", seconds)`` when the probe
        call completes, ``("unhealthy", None)`` when anything raises.
    """
    try:
        # perf_counter is monotonic, unlike time.time().
        started = time.perf_counter()
        client = qdrant_client.QdrantClient(host=qdrant_host, port=qdrant_port, timeout=5)
        try:
            # NOTE(review): confirm the installed qdrant-client exposes
            # get_telemetry_data(); if it does not, the AttributeError is
            # swallowed below and this probe always reports "unhealthy".
            client.get_telemetry_data()  # An operation to test the connection.
            latency = time.perf_counter() - started
        finally:
            # A new client is created per call, so release it every time.
            client.close()
        return "healthy", latency
    except Exception as e:
        print(f"Qdrant check failed: {e}")
        return "unhealthy", None

def check_gemini_api_availability(gemini_api_key: str) -> Tuple[str, Optional[float]]:
    """Probe the Gemini API by asking a fixed question and checking the answer.

    Note that every call performs a real ``generate_content`` request.

    Args:
        gemini_api_key: API key passed to ``genai.configure``.

    Returns:
        A ``(status, latency)`` pair. ``status`` is ``"healthy"`` when the
        response text contains ``"Paris"``, ``"degraded"`` when a response
        arrives without it, and ``"unhealthy"`` when anything raises.
        ``latency`` is the elapsed time in seconds for a healthy probe and
        ``None`` otherwise.
    """
    try:
        # perf_counter is monotonic, unlike time.time().
        started = time.perf_counter()
        genai.configure(api_key=gemini_api_key)
        model = genai.GenerativeModel('gemini-1.5-pro-latest')
        response = model.generate_content("What is the capital of France?", stream=False)
        latency = time.perf_counter() - started

        # Simple sanity check on the response content.
        if response and "Paris" in response.text:
            return "healthy", latency
        return "degraded", None
    except Exception as e:
        print(f"Gemini API check failed: {e}")
        return "unhealthy", None

@app.get("/health", response_model=HealthStatus)
async def health_check(
    redis_host: str = "localhost",
    redis_port: int = 6379,
    postgres_host: str = "localhost",
    postgres_port: int = 5432,
    postgres_db: str = "mydatabase",
    postgres_user: str = "myuser",
    postgres_password: str = "mypassword",
    qdrant_host: str = "localhost",
    qdrant_port: int = 6333,
    gemini_api_key: str = "YOUR_GEMINI_API_KEY" #Replace with an env variable or config file
):
    """Performs a health check on all Genesis components.

    Runs the Redis, PostgreSQL, Qdrant and Gemini probes and combines their
    results. The overall status is ``"unhealthy"`` if any component is
    unhealthy, otherwise ``"degraded"`` if any component is degraded,
    otherwise ``"healthy"``.

    NOTE(review): these parameters are exposed as request query parameters,
    so credentials and target hosts can be supplied by the caller. Consider
    loading them from server-side configuration instead.

    Returns:
        A ``HealthStatus`` with the overall status, the per-component status
        and the latency in seconds of every probe that produced one.
    """
    # The probes do blocking network I/O. Calling them directly inside an
    # ``async def`` handler would stall the event loop for up to the sum of
    # their timeouts, so run them on the default thread pool, concurrently.
    loop = asyncio.get_running_loop()
    results = await asyncio.gather(
        loop.run_in_executor(None, check_redis_connection, redis_host, redis_port),
        loop.run_in_executor(
            None,
            check_postgres_connection,
            postgres_host,
            postgres_port,
            postgres_db,
            postgres_user,
            postgres_password,
        ),
        loop.run_in_executor(None, check_qdrant_connection, qdrant_host, qdrant_port),
        loop.run_in_executor(None, check_gemini_api_availability, gemini_api_key),
    )

    # gather() preserves argument order, so names line up with results.
    names = ("redis", "postgres", "qdrant", "gemini_api")
    components = {name: status for name, (status, _) in zip(names, results)}

    # Failed probes report a latency of None. Leave those entries out so the
    # mapping only holds floats and the response validates precisely when a
    # component is down.
    latencies = {
        name: latency
        for name, (_, latency) in zip(names, results)
        if latency is not None
    }

    overall_status = "healthy"
    if any(status == "unhealthy" for status in components.values()):
        overall_status = "unhealthy"
    elif any(status == "degraded" for status in components.values()):
        overall_status = "degraded"

    return HealthStatus(status=overall_status, components=components, latencies=latencies)