# knowledge_graph_engine.py
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional
import networkx as nx
from collections import defaultdict
import re

class KnowledgeGraphEngine:
    """
    Knowledge Graph Engine for AIVA, supporting entity deduplication,
    incremental updates, and text extraction.

    Entities are graph nodes keyed by their string id; relationships are
    undirected edges. Both are persisted as JSONL files (one JSON object
    per line) under ``<workspace>/KNOWLEDGE_GRAPH`` and reloaded on
    construction.
    """

    def __init__(self, workspace_path: str = "aiva_kb", embedding_model=None):
        """Set up storage paths, ensure the directory exists, and load any persisted graph.

        Args:
            workspace_path: Root directory for the knowledge base.
            embedding_model: Optional model used by semantic_search
                (placeholder -- no concrete interface is assumed here).
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "KNOWLEDGE_GRAPH"
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"
        self.kg_dir.mkdir(parents=True, exist_ok=True)
        self.graph = nx.Graph()
        self.embedding_model = embedding_model  # Placeholder for embedding model
        self._load_graph()

    def _load_graph(self):
        """Loads entities and relationships from disk (JSONL, one record per line)."""
        if not self.entities_path.exists() or not self.relationships_path.exists():
            print("No existing knowledge graph found. Starting with an empty graph.")
            return

        try:
            with open(self.entities_path, "r", encoding="utf-8") as f:
                for line in f:
                    entity = json.loads(line)
                    self.graph.add_node(entity["id"], **entity)

            with open(self.relationships_path, "r", encoding="utf-8") as f:
                for line in f:
                    rel = json.loads(line)
                    self.graph.add_edge(rel["from"], rel["to"], **rel)
            print(f"Knowledge graph loaded successfully. Nodes: {self.graph.number_of_nodes()}, Edges: {self.graph.number_of_edges()}")

        except Exception as e:
            # Corrupt files must not be fatal; fall back to an empty graph.
            print(f"Error loading knowledge graph: {e}. Starting with an empty graph.")
            self.graph = nx.Graph()

    def _save_graph(self):
        """Saves entities and relationships to disk as JSONL.

        Rewrites both files in full on every call.
        """
        print("Saving Knowledge Graph...")
        try:
            with open(self.entities_path, "w", encoding="utf-8") as f:
                for _, data in self.graph.nodes(data=True):
                    json.dump(data, f)
                    f.write('\n')

            with open(self.relationships_path, "w", encoding="utf-8") as f:
                for from_node, to_node, data in self.graph.edges(data=True):
                    # Copy before adding the endpoint keys so we do not
                    # mutate the live in-memory edge attribute dicts.
                    rel = dict(data)
                    rel["from"] = from_node
                    rel["to"] = to_node
                    json.dump(rel, f)
                    f.write('\n')
            print("Knowledge graph saved successfully.")
        except Exception as e:
            print(f"Error saving knowledge graph: {e}")

    def add_entity(self, entity_id: str, entity_type: str, **kwargs):
        """Adds a new entity to the graph, handling deduplication.

        An existing entity keeps its id and type and only has the supplied
        attributes merged in; a new entity is created with ``id`` and
        ``type`` attributes. The graph is persisted after the change.
        """
        if entity_id in self.graph:
            print(f"Entity with ID '{entity_id}' already exists. Updating attributes.")
            self.graph.nodes[entity_id].update(kwargs)
        else:
            entity_data = {"id": entity_id, "type": entity_type, **kwargs}
            self.graph.add_node(entity_id, **entity_data)
            print(f"Entity '{entity_id}' added to the graph.")
        self._save_graph()

    def add_relationship(self, from_entity: str, to_entity: str, relationship_type: str, **kwargs):
        """Adds a relationship between two entities.

        Missing endpoints are auto-created with type "unknown". An existing
        edge only has the supplied attributes merged in (its original type
        is kept). The graph is persisted after the change.
        """
        if from_entity not in self.graph or to_entity not in self.graph:
            print(f"One or both entities do not exist in the graph. Creating missing entities.")
            if from_entity not in self.graph:
                self.add_entity(from_entity, "unknown")
            if to_entity not in self.graph:
                self.add_entity(to_entity, "unknown")

        if self.graph.has_edge(from_entity, to_entity):
            print(f"Relationship between '{from_entity}' and '{to_entity}' already exists. Updating attributes.")
            self.graph.edges[from_entity, to_entity].update(kwargs)
        else:
            relationship_data = {"type": relationship_type, **kwargs}
            self.graph.add_edge(from_entity, to_entity, **relationship_data)
            print(f"Relationship '{relationship_type}' added between '{from_entity}' and '{to_entity}'.")
        self._save_graph()

    def extract_entities_from_text(self, text: str) -> List[Dict]:
        """
        Extracts entities from unstructured text using regex (extend with NLP models).

        Returns a list of ``{"id": ..., "type": ...}`` dicts. The patterns are
        deliberately simple example logic and can be replaced with a more
        sophisticated NLP model.
        """
        entities = []
        # Regex for patents (e.g., US1234567B2)
        patent_matches = re.findall(r"\b[A-Z]{2}\d{7,}[\w\d]+\b", text)
        for match in patent_matches:
            entities.append({"id": match, "type": "Patent"})
        # Regex for monetary values (e.g., $1,234.56)
        money_matches = re.findall(r"\$\d+(?:\,\d{3})*(?:\.\d{2})?", text)
        for match in money_matches:
            entities.append({"id": match, "type": "Revenue"})

        # Regex for tool names (Capitalized words)
        tool_matches = re.findall(r"[A-Z][a-z]+", text)
        for match in tool_matches:
            entities.append({"id": match, "type": "Tool"})

        return entities

    def process_text_and_update_graph(self, text: str, source_id: str):
        """
        Extracts entities from text, adds them to the graph, and creates
        "mentions" relationships from the source document to each entity.
        """
        extracted_entities = self.extract_entities_from_text(text)
        if not extracted_entities:
            print("No entities found in the text.")
            return

        # Add the source document as an entity
        self.add_entity(source_id, "SourceDocument", text=text)

        for entity in extracted_entities:
            entity_id = entity["id"]
            entity_type = entity["type"]
            self.add_entity(entity_id, entity_type)
            self.add_relationship(source_id, entity_id, "mentions")
        # No trailing save needed: add_entity/add_relationship each persist.

    def semantic_search(self, query: str, top_k: int = 5) -> List[str]:
        """
        Performs a semantic search using vector embeddings (placeholder).
        Replace with actual embedding and similarity search implementation.
        """
        if not self.embedding_model:
            print("No embedding model provided. Returning empty list.")
            return []

        # In a real implementation, this would:
        # 1. Embed the query using self.embedding_model
        # 2. Search a vector database for similar entity embeddings
        # 3. Return the IDs of the top_k most similar entities.
        print(f"Performing semantic search for query: '{query}'. This is a placeholder.")
        # Placeholder: Return the first few entity IDs
        return list(self.graph.nodes)[:top_k]

    def find_shortest_path(self, start_entity: str, end_entity: str) -> Optional[List[str]]:
        """Finds the shortest path between two entities, or None if unreachable/unknown."""
        try:
            return nx.shortest_path(self.graph, source=start_entity, target=end_entity)
        except nx.NetworkXNoPath:
            print(f"No path found between '{start_entity}' and '{end_entity}'.")
            return None
        except (nx.NodeNotFound, nx.NetworkXError) as e:
            # NodeNotFound is not a NetworkXError subclass, so it must be
            # caught explicitly or a missing endpoint would crash the caller.
            print(f"Error finding shortest path: {e}")
            return None

    def find_neighbors(self, entity_id: str, relationship_type: Optional[str] = None) -> List[str]:
        """Finds neighbors of an entity, optionally filtered by relationship type."""
        neighbors = []
        # graph[entity_id] maps each neighbor to the connecting edge's data.
        for neighbor, edge_data in self.graph[entity_id].items():
            # .get() avoids a KeyError on edges without a "type" attribute.
            if not relationship_type or edge_data.get("type") == relationship_type:
                neighbors.append(neighbor)
        return neighbors

    def query_graph(self, query_type: str, **kwargs) -> Any:
        """
        A generic query interface for the knowledge graph.
        Supports various query types based on the 'query_type' argument.
        """
        if query_type == "shortest_path":
            start_entity = kwargs.get("start_entity")
            end_entity = kwargs.get("end_entity")
            if not start_entity or not end_entity:
                return "Error: start_entity and end_entity are required for shortest_path query."
            return self.find_shortest_path(start_entity, end_entity)
        elif query_type == "neighbors":
            entity_id = kwargs.get("entity_id")
            relationship_type = kwargs.get("relationship_type")
            if not entity_id:
                return "Error: entity_id is required for neighbors query."
            return self.find_neighbors(entity_id, relationship_type)
        elif query_type == "semantic_search":
            query = kwargs.get("query")
            top_k = kwargs.get("top_k", 5)
            if not query:
                return "Error: query is required for semantic_search query."
            return self.semantic_search(query, top_k)
        else:
            return f"Error: Unknown query type '{query_type}'."

    def export_graph_to_json(self, filepath: str):
        """Exports the entire graph data to a JSON file (node-link format)."""
        graph_data = nx.node_link_data(self.graph)
        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(graph_data, f, indent=4)
            print(f"Graph exported to JSON file: {filepath}")
        except Exception as e:
            print(f"Error exporting graph to JSON: {e}")

if __name__ == "__main__":
    # Example Usage: exercise every public entry point of the engine.
    kg = KnowledgeGraphEngine(workspace_path="aiva_kb")

    # 1. Add entities
    kg.add_entity("person1", "Person", name="Alice", age=30)
    kg.add_entity("concept1", "Concept", name="AI", description="Artificial Intelligence")
    kg.add_entity("skill1", "Skill", name="Python")

    # 2. Add relationships
    kg.add_relationship("person1", "skill1", "uses")
    kg.add_relationship("person1", "concept1", "works_with")
    kg.add_relationship("skill1", "concept1", "implements")

    # 3. Extract entities from text
    sample_text = "Alice uses Python to implement AI. She also mentioned US1234567B2 and earned $1000 today."
    kg.process_text_and_update_graph(sample_text, "document1")

    # 4. Perform semantic search (placeholder)
    hits = kg.query_graph("semantic_search", query="artificial intelligence")
    print(f"Semantic search results: {hits}")

    # 5. Find shortest path
    route = kg.query_graph("shortest_path", start_entity="person1", end_entity="concept1")
    print(f"Shortest path between person1 and concept1: {route}")

    # 6. Find neighbors (unfiltered, then filtered by relationship type)
    adjacent = kg.query_graph("neighbors", entity_id="person1")
    print(f"Neighbors of person1: {adjacent}")

    adjacent_uses = kg.query_graph("neighbors", entity_id="person1", relationship_type="uses")
    print(f"Neighbors of person1 with 'uses' relationship: {adjacent_uses}")

    # 7. Export the graph
    kg.export_graph_to_json("aiva_knowledge_graph.json")

    # Load the graph again to verify round-trip persistence
    reloaded = KnowledgeGraphEngine(workspace_path="aiva_kb")
    print(f"Graph loaded again with {reloaded.graph.number_of_nodes()} nodes and {reloaded.graph.number_of_edges()} edges.")