import asyncio
import logging
import httpx
from typing import Optional, Dict, Any

# Configure the root logger (timestamp, level, message) at INFO level.
# NOTE(review): this runs at import time, so merely importing this module
# configures logging for the whole process; consider moving it under the
# ``if __name__ == "__main__"`` guard.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

class AIVAConnectionPool:
    """
    Manages a bounded pool of concurrent requests to the AIVA Ollama endpoint.

    Concurrency of ``execute_request`` calls is limited by an
    ``asyncio.Semaphore`` of size ``max_connections``; all HTTP I/O goes
    through one shared ``httpx.AsyncClient``. An optional background task
    pings the endpoint periodically, and ``pool_status()`` exposes counters
    and the current health flag.
    """

    def __init__(self, endpoint_url: str, max_connections: int = 3, connection_timeout: float = 10.0, health_ping_interval: int = 60):
        """
        Initializes the AIVAConnectionPool.

        Args:
            endpoint_url: The URL of the AIVA Ollama endpoint (also used as the
                health-check target).
            max_connections: Maximum number of requests allowed in flight at
                once through ``execute_request``. Must be at least 1.
            connection_timeout: Timeout in seconds passed to
                ``httpx.AsyncClient``. httpx applies a scalar timeout to the
                connect, read, write and pool phases alike, not only to
                establishing the connection.
            health_ping_interval: Seconds between background health checks.

        Raises:
            ValueError: If ``max_connections`` is less than 1. With zero slots,
                every request would wait on the semaphore forever.
        """
        if max_connections < 1:
            raise ValueError(f"max_connections must be >= 1, got {max_connections}")
        self.endpoint_url = endpoint_url
        self.max_connections = max_connections
        self.connection_timeout = connection_timeout
        self.health_ping_interval = health_ping_interval
        self.semaphore = asyncio.Semaphore(max_connections)
        self.client = httpx.AsyncClient(timeout=self.connection_timeout)
        self.healthy = True  # Assume healthy until a health check says otherwise
        self.health_check_task: Optional[asyncio.Task] = None
        self.startup_complete = False  # Set once start_health_checks() has run
        self.num_requests = 0  # execute_request() calls
        self.num_errors = 0  # failed attempts (a retried request may count several)
        self.last_error: Optional[str] = None
        # Requests currently holding a semaphore slot. Tracked explicitly so
        # pool_status() does not depend on the private Semaphore._value.
        self._in_flight = 0

    async def start_health_checks(self):
        """
        Starts the background health check task.

        If a health-check task is already running, no second loop is started.
        A repeated call therefore never pings the endpoint twice per interval.
        """
        if self.health_check_task is None or self.health_check_task.done():
            self.health_check_task = asyncio.create_task(self._health_check_loop())
        self.startup_complete = True

    async def stop_health_checks(self):
        """
        Stops the background health check task, if one is running.

        The task is cancelled and awaited so it has finished before this
        returns. The task reference is then cleared so the checks can be
        restarted later.
        """
        task = self.health_check_task
        if task is None:
            return
        self.health_check_task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _health_check_loop(self):
        """
        Periodically checks the health of the AIVA endpoint.

        Each cycle waits ``health_ping_interval`` seconds before checking, so
        the first ping happens one interval after start. The loop runs until
        the task is cancelled by ``stop_health_checks()``.
        """
        while True:
            await asyncio.sleep(self.health_ping_interval)
            await self.health_check()

    async def health_check(self) -> bool:
        """
        Performs a health check on the AIVA endpoint.

        Sends a plain GET to ``endpoint_url``. It does not take a semaphore
        slot and is not counted in ``num_requests``/``num_errors``. Any failure
        marks the pool unhealthy and records ``last_error``.

        Returns:
            True if the endpoint is healthy, False otherwise.
        """
        try:
            response = await self.client.get(self.endpoint_url)  # Basic GET request
            response.raise_for_status()  # Raises httpx.HTTPStatusError on error statuses
            self.healthy = True
            logging.info("AIVA endpoint is healthy.")
            return True
        except httpx.HTTPError as e:
            self.healthy = False
            logging.error(f"AIVA endpoint health check failed: {e}")
            self.last_error = str(e)
            return False
        except Exception as e:
            self.healthy = False
            logging.exception(f"Unexpected error during AIVA health check: {e}")
            self.last_error = str(e)
            return False

    async def execute_request(self, method: str, url: str, data: Optional[Dict[str, Any]] = None, max_retries: int = 3) -> Optional[httpx.Response]:
        """
        Executes a request to the AIVA endpoint with retry logic.

        Each attempt holds one semaphore slot only while the HTTP call is in
        flight. Exponential backoff (1 s, 2 s, 4 s, ...) happens between
        attempts, outside the semaphore. There is no sleep after the last
        attempt.

        Args:
            method: "GET" or "POST" (case-insensitive). Anything else is
                rejected immediately without making a request.
            url: The URL to request.
            data: JSON body sent with POST requests; ignored for GET.
            max_retries: Total number of attempts, including the first one.
                A value below 1 makes no request.

        Returns:
            The httpx.Response if an attempt succeeded, otherwise None.
            4xx responses are not retried. Server errors, transport errors and
            unexpected exceptions are retried until the attempts run out.
            Every failed attempt increments ``num_errors`` and updates
            ``last_error``.
        """
        self.num_requests += 1
        verb = method.upper()
        if verb not in ("GET", "POST"):
            logging.error(f"Unsupported HTTP method: {method}")
            self.num_errors += 1
            self.last_error = f"Unsupported HTTP method: {method}"
            return None

        for attempt in range(max_retries):
            if attempt:
                # Back off before a retry, outside the semaphore, so a failing
                # request never blocks a slot while it sleeps.
                await asyncio.sleep(2 ** (attempt - 1))
            try:
                async with self.semaphore:
                    self._in_flight += 1
                    try:
                        if verb == "GET":
                            response = await self.client.get(url)
                        else:
                            response = await self.client.post(url, json=data)
                    finally:
                        self._in_flight -= 1
                response.raise_for_status()  # Raises httpx.HTTPStatusError on error statuses
                return response
            except httpx.HTTPStatusError as e:
                self.num_errors += 1
                self.last_error = str(e)
                if 400 <= e.response.status_code < 500:
                    logging.error(f"Client error: {e}")
                    return None  # Don't retry client errors
                logging.warning(f"Attempt {attempt + 1} failed with status code {e.response.status_code}: {e}")
            except httpx.HTTPError as e:
                logging.error(f"HTTP error during request: {e}")
                self.num_errors += 1
                self.last_error = str(e)
            except Exception as e:
                logging.exception(f"Unexpected error during request: {e}")
                self.num_errors += 1
                self.last_error = str(e)

        # range(max_retries) counts total attempts, not retries after the first.
        logging.error(f"Request failed after {max_retries} attempts.")
        return None

    def pool_status(self) -> Dict[str, Any]:
        """
        Returns the current status of the connection pool.

        Returns:
            A dictionary with the configured endpoint and limit, the number of
            free request slots, the health flag, request/error counters, the
            last error message, and whether health checks have been started.
        """
        return {
            "endpoint_url": self.endpoint_url,
            "max_connections": self.max_connections,
            "available_connections": self.max_connections - self._in_flight,
            "healthy": self.healthy,
            "num_requests": self.num_requests,
            "num_errors": self.num_errors,
            "last_error": self.last_error,
            "startup_complete": self.startup_complete
        }

    async def close(self):
        """
        Stops the health-check task and closes the underlying HTTP client.
        """
        await self.stop_health_checks()
        await self.client.aclose()


# Example usage (for testing purposes):
async def main():
    """
    Example usage of the AIVAConnectionPool.

    Starts health checks, sends five concurrent GET requests, and logs the
    pool status after each one. The pool is closed in a ``finally``, so the
    HTTP client and the health-check task are released even if a step raises
    or the coroutine is cancelled.
    """
    base_url = "http://localhost:23405"
    pool = AIVAConnectionPool(endpoint_url=base_url, max_connections=3)
    try:
        await pool.start_health_checks()

        # Simulate some requests
        async def make_request(i: int):
            # Replace with a valid AIVA endpoint.
            # NOTE(review): Ollama-style /api/generate usually expects a POST
            # with a JSON body; confirm against the AIVA deployment.
            response = await pool.execute_request("GET", f"{base_url}/api/generate")
            if response is not None:
                logging.info(f"Request {i} successful: {response.status_code}")
            else:
                logging.error(f"Request {i} failed.")
            logging.info(f"Pool Status: {pool.pool_status()}")

        await asyncio.gather(*(make_request(i) for i in range(5)))

        print("Pool Status:", pool.pool_status())

        # Idle briefly before shutdown. With the default 60 s interval, no
        # health ping fires during this window.
        await asyncio.sleep(5)
    finally:
        await pool.close()


if __name__ == "__main__":
    asyncio.run(main())
