import json
import logging
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams

# Root-logger setup performed at import time: INFO level, timestamped lines.
# NOTE(review): this affects every program that imports this module; consider
# moving it to the application entry point.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class KnowledgeGraphBuilder:
    """
    Builds and manages a knowledge graph, supporting node and edge manipulation,
    persistence to JSON and Qdrant, and graph traversal queries.

    The graph is held in memory as a ``networkx.MultiDiGraph``, so several edges
    (e.g. with different relations) may connect the same ordered pair of nodes.
    When a node is added with a vector, the vector is stored in a Qdrant
    collection named after the graph; the point's payload carries the node's
    attributes plus its original ``node_id``.
    """

    def __init__(
        self,
        graph_name: str,
        qdrant_host: str = "localhost",
        qdrant_port: int = 6333,
        vector_size: int = 1536,
        distance: Distance = Distance.COSINE,
    ):
        """
        Initializes the KnowledgeGraphBuilder.

        Args:
            graph_name: The name of the knowledge graph. Used for Qdrant collection name.
            qdrant_host: The hostname of the Qdrant instance.
            qdrant_port: The port of the Qdrant instance.
            vector_size: Dimensionality of node vectors. Only used when the Qdrant
                collection does not exist yet and has to be created.
            distance: Similarity metric of the Qdrant collection. Only used when
                the collection has to be created.
        """
        self.graph_name = graph_name
        self.graph = nx.MultiDiGraph()  # Use MultiDiGraph to allow multiple edges between nodes
        self.qdrant_host = qdrant_host
        self.qdrant_port = qdrant_port
        self.vector_size = vector_size
        self.distance = distance
        self.qdrant_client = self._connect_to_qdrant()
        self._create_qdrant_collection()
        logging.info(f"KnowledgeGraphBuilder initialized for graph: {self.graph_name}")

    def _connect_to_qdrant(self) -> QdrantClient:
        """
        Creates the Qdrant client for the configured host and port.

        Returns:
            A QdrantClient instance.

        Raises:
            Exception: Whatever the QdrantClient constructor raises (logged, then re-raised).
        """
        try:
            client = QdrantClient(host=self.qdrant_host, port=self.qdrant_port)
            logging.info(f"Connected to Qdrant at {self.qdrant_host}:{self.qdrant_port}")
            return client
        except Exception as e:
            logging.error(f"Failed to connect to Qdrant: {e}")
            raise

    def _create_qdrant_collection(self) -> None:
        """
        Creates the Qdrant collection for storing node vectors if it doesn't exist.

        An existing collection is left untouched (its stored vectors are kept);
        it is not checked against ``vector_size``/``distance``.
        """
        try:
            existing = {c.name for c in self.qdrant_client.get_collections().collections}
            if self.graph_name in existing:
                logging.info(f"Qdrant collection '{self.graph_name}' already exists.")
                return
            # create_collection (not recreate_collection): never drop existing data.
            self.qdrant_client.create_collection(
                collection_name=self.graph_name,
                vectors_config=VectorParams(size=self.vector_size, distance=self.distance),
            )
            logging.info(f"Qdrant collection '{self.graph_name}' created.")
        except Exception as e:
            logging.error(f"Failed to create Qdrant collection: {e}")
            raise

    @staticmethod
    def _point_id(node_id: Any) -> Any:
        """
        Maps a node id to a valid Qdrant point id.

        Qdrant point ids must be unsigned integers or UUIDs. Non-negative ints
        and UUID strings are passed through (UUIDs in canonical form); any other
        id is mapped deterministically to a UUIDv5, so re-upserting the same node
        overwrites the same point. The original id is kept in the payload.
        """
        if isinstance(node_id, int) and not isinstance(node_id, bool) and node_id >= 0:
            return node_id
        try:
            return str(uuid.UUID(str(node_id)))
        except ValueError:
            return str(uuid.uuid5(uuid.NAMESPACE_URL, str(node_id)))

    def _edge_data(self, source: str, target: str) -> List[Dict[str, Any]]:
        """
        Returns the attribute dicts of every edge from ``source`` to ``target``.

        Handles both multigraphs (``get_edge_data`` returns ``{key: attrs}``) and
        simple graphs (it returns ``attrs``), since ``load_from_json`` may build
        either. Returns an empty list when no such edge exists.
        """
        data = self.graph.get_edge_data(source, target)
        if data is None:
            return []
        if self.graph.is_multigraph():
            return list(data.values())
        return [data]

    def add_node(self, node_id: str, attributes: Optional[Dict[str, Any]] = None, vector: Optional[List[float]] = None) -> None:
        """
        Adds a node to the knowledge graph, or updates it if it already exists.

        Args:
            node_id: The unique identifier for the node.
            attributes: A dictionary of attributes associated with the node. For an
                existing node these are merged into its current attributes.
            vector: The vector representation of the node for Qdrant. Skipped when
                None or empty.

        Raises:
            Exception: Any failure from networkx or Qdrant (logged, then re-raised).
        """
        try:
            if self.graph.has_node(node_id):
                logging.warning(f"Node with id '{node_id}' already exists.  Updating attributes.")

            self.graph.add_node(node_id, **(attributes or {}))

            # Explicit None/len check: plain truthiness is ambiguous for array-likes.
            if vector is not None and len(vector) > 0:
                # Store the merged node attributes so the payload mirrors the graph.
                self._upsert_node_vector(node_id, vector, self.graph.nodes[node_id])

            logging.info(f"Node '{node_id}' added to the graph.")

        except Exception as e:
            logging.error(f"Failed to add node '{node_id}': {e}")
            raise

    def _upsert_node_vector(self, node_id: str, vector: List[float], attributes: Optional[Dict[str, Any]] = None) -> None:
        """
        Upserts the vector representation of a node to Qdrant.

        Args:
            node_id: The unique identifier for the node.
            vector: The vector representation of the node.
            attributes: Additional payload to store along with the vector. Not mutated.
        """
        try:
            # Build a new dict so the caller's attributes (or the graph's node
            # attribute dict) never gain a 'node_id' key as a side effect.
            payload = {**(attributes or {}), "node_id": node_id}  # node_id kept for easy retrieval

            self.qdrant_client.upsert(
                collection_name=self.graph_name,
                points=[
                    models.PointStruct(
                        id=self._point_id(node_id),
                        vector=vector,
                        payload=payload,
                    )
                ],
            )
            logging.info(f"Node vector for '{node_id}' upserted to Qdrant.")
        except Exception as e:
            logging.error(f"Failed to upsert node vector for '{node_id}' to Qdrant: {e}")
            raise

    def add_edge(self, source_node_id: str, target_node_id: str, relation: str, attributes: Optional[Dict[str, Any]] = None) -> None:
        """
        Adds an edge between two existing nodes in the knowledge graph.

        Args:
            source_node_id: The identifier of the source node.
            target_node_id: The identifier of the target node.
            relation: The type of relationship between the nodes. Stored as the
                edge's 'relation' attribute, taking precedence over any
                'relation' key in ``attributes``.
            attributes: A dictionary of attributes associated with the edge.

        Raises:
            ValueError: If either endpoint node does not exist.
        """
        try:
            if not self.graph.has_node(source_node_id):
                raise ValueError(f"Source node with id '{source_node_id}' does not exist.")
            if not self.graph.has_node(target_node_id):
                raise ValueError(f"Target node with id '{target_node_id}' does not exist.")

            # Merge instead of relation=..., **attributes, which raises TypeError
            # when attributes already contains 'relation'.
            edge_attributes = {**(attributes or {}), "relation": relation}
            self.graph.add_edge(source_node_id, target_node_id, **edge_attributes)
            logging.info(f"Edge added between '{source_node_id}' and '{target_node_id}' with relation '{relation}'.")

        except Exception as e:
            logging.error(f"Failed to add edge between '{source_node_id}' and '{target_node_id}': {e}")
            raise

    def query_neighbors(self, node_id: str, relation: Optional[str] = None) -> List[str]:
        """
        Queries the successors of a given node, optionally filtering by relation type.

        Args:
            node_id: The identifier of the node to query.
            relation: If given (non-empty), only neighbors connected by at least
                one outgoing edge with this 'relation' attribute are returned.

        Returns:
            A list of neighbor node IDs, each listed once.

        Raises:
            ValueError: If the node does not exist.
        """
        try:
            if not self.graph.has_node(node_id):
                raise ValueError(f"Node with id '{node_id}' does not exist.")

            neighbors = []
            for neighbor in self.graph.neighbors(node_id):
                if not relation or any(
                    data.get("relation") == relation for data in self._edge_data(node_id, neighbor)
                ):
                    neighbors.append(neighbor)
            logging.info(f"Neighbors of '{node_id}' queried.")
            return neighbors
        except Exception as e:
            logging.error(f"Failed to query neighbors of '{node_id}': {e}")
            raise

    def find_path(self, source_node_id: str, target_node_id: str) -> Optional[List[str]]:
        """
        Finds a shortest directed path (fewest hops) between two nodes.

        Args:
            source_node_id: The identifier of the source node.
            target_node_id: The identifier of the target node.

        Returns:
            A list of node IDs representing the path, or None if no path exists.

        Raises:
            ValueError: If either node does not exist.
        """
        try:
            if not self.graph.has_node(source_node_id):
                raise ValueError(f"Source node with id '{source_node_id}' does not exist.")
            if not self.graph.has_node(target_node_id):
                raise ValueError(f"Target node with id '{target_node_id}' does not exist.")

            try:
                path = nx.shortest_path(self.graph, source=source_node_id, target=target_node_id)
                logging.info(f"Path found between '{source_node_id}' and '{target_node_id}'.")
                return path
            except nx.NetworkXNoPath:
                logging.info(f"No path found between '{source_node_id}' and '{target_node_id}'.")
                return None
        except Exception as e:
            logging.error(f"Failed to find path between '{source_node_id}' and '{target_node_id}': {e}")
            raise

    def save_to_json(self, file_path: str) -> None:
        """
        Saves the knowledge graph to a JSON file in networkx node-link format.

        Args:
            file_path: The path to the JSON file.

        Raises:
            Exception: I/O errors or non-JSON-serializable attributes (logged, then re-raised).
        """
        try:
            data = nx.node_link_data(self.graph)
            with open(file_path, 'w', encoding="utf-8") as f:
                json.dump(data, f, indent=4)
            logging.info(f"Knowledge graph saved to JSON file: {file_path}")
        except Exception as e:
            logging.error(f"Failed to save knowledge graph to JSON file: {e}")
            raise

    def load_from_json(self, file_path: str) -> None:
        """
        Replaces the in-memory graph with one loaded from a node-link JSON file.

        The Qdrant collection is not touched. The current graph is only replaced
        once the file has been read and parsed successfully.

        Args:
            file_path: The path to the JSON file.
        """
        try:
            with open(file_path, 'r', encoding="utf-8") as f:
                data = json.load(f)
            self.graph = nx.node_link_graph(data)
            logging.info(f"Knowledge graph loaded from JSON file: {file_path}")
        except Exception as e:
            logging.error(f"Failed to load knowledge graph from JSON file: {e}")
            raise

    def semantic_search(self, query_vector: List[float], limit: int = 10) -> List[Tuple[str, float]]:
        """
        Performs a semantic search in Qdrant based on a query vector.

        Args:
            query_vector: The vector representation of the query.
            limit: The maximum number of results to return.

        Returns:
            A list of (node ID, similarity score) tuples, read from each hit's
            'node_id' payload field.
        """
        try:
            search_result = self.qdrant_client.search(
                collection_name=self.graph_name,
                query_vector=query_vector,
                limit=limit,
            )
            results = [(hit.payload["node_id"], hit.score) for hit in search_result]
            logging.info(f"Semantic search performed in Qdrant, returning {len(results)} results.")
            return results
        except Exception as e:
            logging.error(f"Failed to perform semantic search in Qdrant: {e}")
            raise

    def traverse_graph(self, start_node: str, depth: int) -> Dict[str, Any]:
        """
        Collects every node within ``depth`` outgoing hops of ``start_node``.

        Uses breadth-first search, so each node is recorded at its shortest hop
        distance and nothing within range is missed. Nodes exactly ``depth`` hops
        away still list their outgoing edges, but those neighbors are not expanded.

        Args:
            start_node: The node to start the traversal from.
            depth: The maximum number of hops; 0 returns only the start node, a
                negative value returns an empty dict.

        Returns:
            A dict mapping each reached node ID to
            ``{"attributes": {...}, "neighbors": {neighbor_id: [edge_attrs, ...]}}``.
            All dicts are copies, so mutating the result does not modify the graph.

        Raises:
            ValueError: If the start node does not exist.
        """
        try:
            if not self.graph.has_node(start_node):
                raise ValueError(f"Node with id '{start_node}' does not exist.")

            traversal_result: Dict[str, Any] = {}
            seen = {start_node}
            frontier = [start_node]

            # Level-order BFS: frontier holds the nodes exactly `level` hops away.
            for level in range(depth + 1):
                next_frontier = []
                for node in frontier:
                    entry = {"attributes": dict(self.graph.nodes[node]), "neighbors": {}}
                    traversal_result[node] = entry
                    for neighbor in self.graph.neighbors(node):
                        entry["neighbors"][neighbor] = [dict(data) for data in self._edge_data(node, neighbor)]
                        if level < depth and neighbor not in seen:
                            seen.add(neighbor)
                            next_frontier.append(neighbor)
                frontier = next_frontier
                if not frontier:
                    break

            logging.info(f"Graph traversal started from '{start_node}' with depth {depth}.")
            return traversal_result
        except Exception as e:
            logging.error(f"Failed to traverse graph from '{start_node}' with depth {depth}: {e}")
            raise
