import asyncio
import logging
import os
import time
from typing import Any, Callable, Dict, List, Optional

import httpx
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import CollectionStatus

# Module-wide logging setup: INFO level, each record prefixed with a
# timestamp and its level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


class QdrantHealthMonitor:
    """
    Monitors the health and performance of a Qdrant vector database.

    ``QdrantClient`` is a synchronous client, so every call made through it is
    run in the event loop's default executor (see ``_call``) rather than being
    awaited directly -- its methods return plain values, not awaitables.
    Raw REST endpoints are queried through an ``httpx.AsyncClient``.

    Instances can be used as an async context manager so that both clients
    are closed on exit::

        async with QdrantHealthMonitor(host="localhost") as monitor:
            status = await monitor.get_health_status()
    """

    def __init__(self, host: str = "localhost", port: int = 6333, https: bool = False, timeout: int = 10):
        """
        Initializes the QdrantHealthMonitor.

        Args:
            host: The hostname of the Qdrant instance.
            port: The port number of the Qdrant instance.
            https: Whether to use HTTPS for the connection.
            timeout: Timeout in seconds for Qdrant client operations.
        """
        self.host = host
        self.port = port
        self.https = https
        self.timeout = timeout
        self.client = QdrantClient(host=self.host, port=self.port, https=self.https, timeout=self.timeout)
        self.http_client = httpx.AsyncClient(base_url=f"http{'s' if https else ''}://{host}:{port}", timeout=timeout)

    async def __aenter__(self) -> "QdrantHealthMonitor":
        """Return the monitor itself for use in ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close both clients; exceptions from the ``async with`` body propagate."""
        await self.close()

    async def _call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Run a blocking (synchronous) client call in the default executor and await it."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    async def check_connection(self) -> bool:
        """
        Checks the connection to the Qdrant instance via its ``/healthz`` endpoint.

        Returns:
            True if the endpoint answers with a 2xx status, False on a
            connection failure, any other error, or a non-2xx status.
        """
        try:
            response = await self.http_client.get("/healthz")
        except httpx.ConnectError as e:
            logger.error("Connection error: %s", e)
            return False
        except Exception as e:
            logger.exception("Unexpected error during connection check: %s", e)
            return False
        # A reachable server that answers 4xx/5xx is not healthy.
        if not response.is_success:
            logger.error("Health check returned HTTP %s", response.status_code)
            return False
        return True

    async def get_collection_count(self) -> Optional[int]:
        """
        Retrieves the number of collections in the Qdrant instance.

        Returns:
            The number of collections, or None if an error occurs.
        """
        try:
            collections = await self._call(self.client.get_collections)
        except Exception as e:
            logger.exception("Error getting collection count: %s", e)
            return None
        return len(collections.collections)

    @staticmethod
    def _vectors_in(collection_info: Any) -> int:
        """Return the vector count reported for one collection, falling back to its point count."""
        # vectors_count may be None (it is optional in the client's model), and
        # NOTE(review): newer Qdrant versions may not report it at all -- confirm
        # for your server. points_count can undercount when points hold several
        # named vectors.
        count = getattr(collection_info, "vectors_count", None)
        if count is None:
            count = getattr(collection_info, "points_count", None)
        return count or 0

    async def get_vector_count(self) -> Optional[int]:
        """
        Retrieves the total number of vectors across all collections.

        Returns:
            The total number of vectors, or None if an error occurs.
        """
        try:
            collections = await self._call(self.client.get_collections)
            total = 0
            for collection in collections.collections:
                collection_info = await self._call(self.client.get_collection, collection_name=collection.name)
                total += self._vectors_in(collection_info)
            return total
        except Exception as e:
            logger.exception("Error getting vector count: %s", e)
            return None

    @staticmethod
    def _dummy_query_vector(collection_info: Any) -> Any:
        """Build an all-zeros query vector matching the collection's (first) dense vector config."""
        vectors = collection_info.config.params.vectors
        if isinstance(vectors, dict):
            # Named vectors: search against the first one, as a (name, vector) pair.
            if not vectors:
                raise ValueError("collection has no dense vectors to search")
            name, params = next(iter(vectors.items()))
            return (name, [0.0] * params.size)
        return [0.0] * vectors.size

    async def get_search_latency(self, collection_name: str = "test_collection", vector: Optional[List[float]] = None, limit: int = 1) -> Optional[float]:
        """
        Measures the search latency for a given collection.

        Args:
            collection_name: The name of the collection to test.
            vector: The search vector. If None, an all-zeros vector sized from
                the collection's configuration is used.
            limit: The number of results to return.

        Returns:
            The search latency in seconds, or None if an error occurs.
        """
        try:
            if vector is None:
                collection_info = await self._call(self.client.get_collection, collection_name=collection_name)
                query_vector = self._dummy_query_vector(collection_info)
            else:
                query_vector = vector

            def _timed_search() -> float:
                # Timed inside the worker so executor scheduling is excluded;
                # perf_counter is monotonic, unlike time.time().
                started = time.perf_counter()
                self.client.search(collection_name=collection_name, query_vector=query_vector, limit=limit)
                return time.perf_counter() - started

            return await self._call(_timed_search)
        except Exception as e:
            logger.exception("Error getting search latency: %s", e)
            return None

    async def get_memory_usage(self) -> Optional[Dict[str, Any]]:
        """
        Retrieves memory usage statistics from the Qdrant telemetry endpoint.

        Returns:
            The ``memory`` section of the ``/telemetry`` result, or None if the
            request fails or the server does not report memory statistics.
        """
        try:
            response = await self.http_client.get("/telemetry")
            response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
            result = response.json().get("result") or {}
            memory = result.get("memory")
        except httpx.HTTPStatusError as e:
            logger.error("HTTP error getting memory usage: %s", e)
            return None
        except Exception as e:
            logger.exception("Error getting memory usage: %s", e)
            return None
        if memory is None:
            # NOTE(review): the memory section is only reported by some Qdrant
            # versions/builds -- confirm against your server.
            logger.warning("Qdrant telemetry did not include memory statistics")
        return memory

    async def validate_collection_schemas(self) -> Dict[str, bool]:
        """
        Validates the schemas of all collections.

        A collection counts as valid when its info can be fetched without error.

        Returns:
            A dictionary mapping collection names to whether validation passed,
            or an empty dictionary if the collection list cannot be fetched.
        """
        validation_results: Dict[str, bool] = {}
        try:
            collections = await self._call(self.client.get_collections)
        except Exception as e:
            logger.exception("Error validating collection schemas: %s", e)
            return {}
        for collection in collections.collections:
            try:
                # Assumes an invalid schema makes get_collection raise.
                await self._call(self.client.get_collection, collection_name=collection.name)
                validation_results[collection.name] = True
            except Exception as e:
                logger.error("Schema validation failed for collection %s: %s", collection.name, e)
                validation_results[collection.name] = False
        return validation_results

    async def get_health_status(self) -> Dict[str, Any]:
        """
        Gathers all health metrics and returns a comprehensive health status report.

        Each probe handles its own errors (returning False/None/{}), so one
        failing metric does not prevent the others from being reported.

        Returns:
            A dictionary containing the health status report.
        """
        connection_status = await self.check_connection()
        collection_count = await self.get_collection_count()
        vector_count = await self.get_vector_count()
        search_latency = await self.get_search_latency()  # Uses default test collection
        memory_usage = await self.get_memory_usage()
        schema_validations = await self.validate_collection_schemas()

        return {
            "connection_status": connection_status,
            "collection_count": collection_count,
            "vector_count": vector_count,
            "search_latency": search_latency,
            "memory_usage": memory_usage,
            "schema_validations": schema_validations,
        }

    async def close(self) -> None:
        """
        Closes both the Qdrant client and the async HTTP client.

        The HTTP client is closed even if closing the Qdrant client raises.
        """
        try:
            self.client.close()
        finally:
            await self.http_client.aclose()


async def main():
    """
    Example usage of the QdrantHealthMonitor.

    The target instance defaults to the example host and port below. Set the
    ``QDRANT_HOST`` / ``QDRANT_PORT`` environment variables to point at a
    different Qdrant instance without editing the code.
    """
    host = os.environ.get("QDRANT_HOST", "qdrant-b3knu-u50607.vm.elestio.app")
    port = int(os.environ.get("QDRANT_PORT", "6333"))
    monitor = QdrantHealthMonitor(host=host, port=port)

    try:
        health_status = await monitor.get_health_status()
        print("Qdrant Health Status:")
        print(health_status)
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception("Error getting health status: %s", e)
    finally:
        await monitor.close()


if __name__ == "__main__":
    # Script entry point: drive the async example on a fresh event loop.
    asyncio.run(main())
