import os
import time
import logging
import requests
import psycopg2
import redis
from qdrant_client import QdrantClient
from typing import Tuple

# Root logging: INFO and above, each record tagged with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class SmokeTest:
    """
    A suite of smoke tests to verify the basic health of the Genesis system after deployment.

    Each ``test_*`` method probes one dependency and reports success as a bool
    (logging the reason on failure) so that ``run_all_tests`` can aggregate a
    single go/no-go verdict. Connection details are read from environment
    variables, falling back to the defaults below.
    """

    # Fallback HTTP endpoints probed when SMOKE_TEST_ENDPOINTS is unset or empty.
    DEFAULT_ENDPOINTS: Tuple[str, ...] = (
        "http://localhost:8000/health",
        "http://localhost:8000/api/v1/status",
    )

    # Timeout (seconds) applied to every connection attempt / request below.
    TIMEOUT_SECONDS = 5

    def __init__(self):
        """
        Read connection details for each service from the environment.

        Recognised variables: POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER,
        POSTGRES_PASSWORD, POSTGRES_DB, REDIS_HOST, REDIS_PORT, REDIS_PASSWORD,
        QDRANT_HOST, QDRANT_PORT, OLLAMA_HOST, OLLAMA_PORT and
        SMOKE_TEST_ENDPOINTS (comma-separated list of URLs).

        Raises:
            ValueError: if one of the *_PORT variables is not an integer.
        """
        # NOTE(review): the fallback credentials below are hard-coded; prefer
        # setting POSTGRES_PASSWORD explicitly in any non-development deployment.
        self.postgres_host = os.getenv("POSTGRES_HOST", "postgresql-genesis-u50607.vm.elestio.app")
        self.postgres_port = int(os.getenv("POSTGRES_PORT", "25432"))
        self.postgres_user = os.getenv("POSTGRES_USER", "genesis")
        self.postgres_password = os.getenv("POSTGRES_PASSWORD", "genesis")
        self.postgres_db = os.getenv("POSTGRES_DB", "genesis")

        self.redis_host = os.getenv("REDIS_HOST", "redis-genesis-u50607.vm.elestio.app")
        self.redis_port = int(os.getenv("REDIS_PORT", "26379"))
        # Unset/empty -> None, i.e. no AUTH is sent (the previous behavior).
        self.redis_password = os.getenv("REDIS_PASSWORD") or None

        self.qdrant_host = os.getenv("QDRANT_HOST", "qdrant-b3knu-u50607.vm.elestio.app")
        self.qdrant_port = int(os.getenv("QDRANT_PORT", "6333"))

        self.ollama_host = os.getenv("OLLAMA_HOST", "localhost")
        self.ollama_port = int(os.getenv("OLLAMA_PORT", "23405"))

        # Blank entries are ignored; an empty result falls back to the defaults.
        raw_endpoints = os.getenv("SMOKE_TEST_ENDPOINTS", "")
        self.endpoints = [
            url.strip() for url in raw_endpoints.split(",") if url.strip()
        ] or list(self.DEFAULT_ENDPOINTS)

    def test_all_endpoints_respond(self) -> bool:
        """
        Check that every configured HTTP endpoint answers without an error status.

        Every endpoint is probed, rather than stopping at the first failure,
        so that a single run reports all broken endpoints.

        Returns:
            bool: True if every endpoint responded with a status below 400,
            False otherwise.
        """
        failures = 0
        for endpoint in self.endpoints:
            try:
                response = requests.get(endpoint, timeout=self.TIMEOUT_SECONDS)
                # raise_for_status() turns 4xx/5xx into HTTPError, a RequestException.
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                logging.error("Endpoint test failed: %s", e)
                failures += 1
            else:
                logging.info("Endpoint %s responded with status code: %s", endpoint, response.status_code)
        return failures == 0

    def test_database_connection(self) -> bool:
        """
        Tests the connection to the PostgreSQL database.

        Returns:
            bool: True if a connection could be opened, False otherwise.
        """
        try:
            conn = psycopg2.connect(
                host=self.postgres_host,
                port=self.postgres_port,
                user=self.postgres_user,
                password=self.postgres_password,
                database=self.postgres_db,
                connect_timeout=self.TIMEOUT_SECONDS,
            )
        except psycopg2.Error as e:
            logging.error("Database connection failed: %s", e)
            return False
        conn.close()
        logging.info("Database connection successful.")
        return True

    def test_redis_connection(self) -> bool:
        """
        Tests the connection to the Redis server with a PING.

        Returns:
            bool: True if PING succeeded, False on any Redis error
            (connection, timeout, authentication, ...).
        """
        client = redis.Redis(
            host=self.redis_host,
            port=self.redis_port,
            password=self.redis_password,
            socket_connect_timeout=self.TIMEOUT_SECONDS,
            # Without a socket timeout, a connected but unresponsive server
            # could block ping() indefinitely.
            socket_timeout=self.TIMEOUT_SECONDS,
        )
        try:
            client.ping()
        except redis.exceptions.RedisError as e:
            # RedisError is the base of ConnectionError, TimeoutError and auth errors.
            logging.error("Redis connection failed: %s", e)
            return False
        finally:
            # Release the sockets held by this one-shot client.
            client.connection_pool.disconnect()
        logging.info("Redis connection successful.")
        return True

    def test_qdrant_connection(self) -> bool:
        """
        Tests the connection to the Qdrant vector database.

        Returns:
            bool: True if a request to the server succeeded, False otherwise.
        """
        try:
            client = QdrantClient(host=self.qdrant_host, port=self.qdrant_port, timeout=self.TIMEOUT_SECONDS)
            # get_collections() is a lightweight call that needs a server round-trip.
            # NOTE(review): replaces get_telemetry_data(), which does not appear to
            # be part of the QdrantClient API — confirm against the pinned version.
            client.get_collections()
        except Exception as e:  # qdrant-client may raise HTTP, gRPC or its own error types
            logging.error("Qdrant connection failed: %s", e)
            return False
        logging.info("Qdrant connection successful.")
        return True

    def test_ollama_connection(self) -> bool:
        """
        Tests the connection to the Ollama server by requesting its root URL.

        Returns:
            bool: True if the server answered with a status below 400, False otherwise.
        """
        try:
            response = requests.get(f"http://{self.ollama_host}:{self.ollama_port}/", timeout=self.TIMEOUT_SECONDS)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            logging.error("Ollama connection failed: %s", e)
            return False
        logging.info("Ollama connection successful.")
        return True

    def run_all_tests(self) -> Tuple[bool, float]:
        """
        Runs all smoke tests and returns a go/no-go status and the total execution time.

        Every check runs even if an earlier one fails. A check that raises an
        unexpected exception is logged and counted as a failure instead of
        aborting the whole run.

        Returns:
            Tuple[bool, float]: A tuple containing the overall status (True for go, False for no-go)
                                 and the total execution time in seconds.
        """
        checks = (
            ("endpoints", self.test_all_endpoints_respond),
            ("database", self.test_database_connection),
            ("redis", self.test_redis_connection),
            ("qdrant", self.test_qdrant_connection),
            ("ollama", self.test_ollama_connection),
        )

        # perf_counter() is monotonic, so the duration is immune to clock adjustments.
        start_time = time.perf_counter()
        # A list (not a generator fed to all()) so no check is short-circuited away.
        results = [self._run_check(name, check) for name, check in checks]
        execution_time = time.perf_counter() - start_time

        all_tests_passed = all(results)
        if all_tests_passed:
            logging.info("All smoke tests passed.")
        else:
            logging.error("One or more smoke tests failed.")

        return all_tests_passed, execution_time

    @staticmethod
    def _run_check(name, check) -> bool:
        """Run one check, converting an unexpected exception into a logged failure."""
        try:
            return bool(check())
        except Exception:
            logging.exception("Smoke check '%s' raised an unexpected error.", name)
            return False

if __name__ == "__main__":
    import sys

    smoke_test = SmokeTest()
    status, duration = smoke_test.run_all_tests()

    print(f"Smoke Test Status: {'Go' if status else 'No-Go'}")
    print(f"Execution Time: {duration:.2f} seconds")

    # Soft time budget: a slow run is reported but does not by itself fail the gate.
    if duration > 60:
        print("WARNING: Smoke tests took longer than 60 seconds.")

    # Exit non-zero on No-Go so deployment pipelines can gate on this script's status.
    sys.exit(0 if status else 1)
