# knowledge_graph_engine.py
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import networkx as nx
from collections import defaultdict
import re

class KnowledgeGraphEngine:
    """
    Knowledge Graph Engine for AIVA, supporting entity deduplication,
    incremental updates, and extraction from unstructured text.

    The graph lives in memory as a ``networkx.MultiDiGraph`` and is
    journaled to two JSONL files (``entities.jsonl`` and
    ``relationships.jsonl``) under ``<workspace>/KNOWLEDGE_GRAPH`` so it
    can be rebuilt on the next start.
    """

    def __init__(self, workspace_path: str = "aiva_kg"):
        """
        Args:
            workspace_path: Root directory for the knowledge-graph files;
                created on demand if it does not exist.
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "KNOWLEDGE_GRAPH"
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"
        # MultiDiGraph so the same pair of entities can be linked by
        # several differently-typed relationships at once.
        self.graph = nx.MultiDiGraph()
        self._ensure_directories()
        self._load_graph()
        # _get_max_entity_id() returns 0 for an empty graph, so the
        # counter correctly starts at 1 in both cases.
        self.entity_id_counter = self._get_max_entity_id() + 1

    def _ensure_directories(self):
        """Ensures the workspace and KG directories exist."""
        self.workspace.mkdir(parents=True, exist_ok=True)
        self.kg_dir.mkdir(parents=True, exist_ok=True)

    def _get_max_entity_id(self) -> int:
        """
        Finds the maximum numeric entity ID currently in the graph.

        Node IDs are expected to look like ``TYPE_<int>``; nodes that do
        not match that convention are ignored.

        Returns:
            int: The largest entity ID, or 0 if the graph is empty.
        """
        max_id = 0
        for node in self.graph.nodes():
            try:
                # rsplit tolerates underscores inside the type name
                # (e.g. "Sales_Lead_7" -> 7).
                max_id = max(max_id, int(str(node).rsplit("_", 1)[-1]))
            except ValueError:
                # Node ID does not follow the TYPE_<int> format.
                pass
        return max_id

    def _load_graph(self):
        """
        Loads entities and relationships from disk, if the journal files
        exist. The two files are loaded independently, so a missing
        relationships file no longer prevents entities from loading.
        """
        if not self.entities_path.exists() and not self.relationships_path.exists():
            print("Entity or relationship files not found. Starting with an empty graph.")
            return

        if self.entities_path.exists():
            try:
                with open(self.entities_path, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if not line:  # tolerate blank lines in the journal
                            continue
                        # persist=False: re-appending while replaying the
                        # journal would duplicate it on disk.
                        self.add_entity_to_graph(json.loads(line), persist=False)
                print(f"Loaded {self.graph.number_of_nodes()} entities.")
            except Exception as e:
                print(f"Error loading entities: {e}")

        if self.relationships_path.exists():
            try:
                with open(self.relationships_path, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        self.add_relationship_to_graph(json.loads(line), persist=False)
                print(f"Loaded {self.graph.number_of_edges()} relationships.")
            except Exception as e:
                print(f"Error loading relationships: {e}")

    def add_entity(self, entity_type: str, properties: Dict[str, Any]) -> str:
        """
        Adds a new entity to the graph, handling deduplication.

        If an equivalent entity (same type and properties) already exists,
        its existing ID is returned — previously a freshly minted ID was
        returned even for duplicates, which pointed at a node that was
        never created.

        Args:
            entity_type: Entity category, e.g. "Person" or "Tool".
            properties: Arbitrary JSON-serializable attributes.

        Returns:
            str: The ID of the created (or pre-existing) entity.
        """
        entity_id = self._generate_entity_id(entity_type)
        entity = {"id": entity_id, "type": entity_type, **properties}
        actual_id = self.add_entity_to_graph(entity)
        if actual_id != entity_id:
            # Duplicate: the freshly generated ID was never used, so
            # give it back to the counter.
            self.entity_id_counter -= 1
        return actual_id

    def _generate_entity_id(self, entity_type: str) -> str:
        """Generates a unique entity ID of the form ``<type>_<counter>``."""
        entity_id = f"{entity_type}_{self.entity_id_counter}"
        self.entity_id_counter += 1
        return entity_id

    def _find_duplicate(self, entity: Dict[str, Any]) -> Optional[str]:
        """
        Returns the node ID of an existing entity whose type and
        properties all equal those of ``entity`` (the ``id`` key is
        ignored), or None when there is no match.

        Unlike the previous inline comparison, a property missing from
        the stored node counts as a mismatch, so a sparse entity can no
        longer match an unrelated node of the same type.
        """
        wanted = {k: v for k, v in entity.items() if k != "id"}
        for node_id, node_data in self.graph.nodes(data=True):
            if all(node_data.get(k) == v for k, v in wanted.items()):
                return node_id
        return None

    def add_entity_to_graph(self, entity: Dict[str, Any], persist: bool = True) -> str:
        """
        Adds an entity to the graph, handling deduplication.

        Args:
            entity: Dict with at least ``id`` and ``type`` keys.
            persist: Append the entity to the on-disk journal
                (disabled while replaying the journal at load time).

        Returns:
            str: The node ID actually used — the existing node's ID when
            the entity was a duplicate.
        """
        existing_node = self._find_duplicate(entity)

        if existing_node is not None:
            print(f"Entity already exists: {existing_node}.  Updating with new properties.")
            # Merge everything except "id": overwriting the stored id
            # would desynchronise it from the graph's node key.
            self.graph.nodes[existing_node].update(
                {k: v for k, v in entity.items() if k != "id"}
            )
            if persist:
                # Journal the merged state under the *existing* ID so a
                # reload converges on the same node instead of recording
                # the duplicate's never-used ID.
                self._persist_entity(dict(self.graph.nodes[existing_node]))
            return existing_node

        self.graph.add_node(entity["id"], **entity)
        print(f"Added new entity: {entity['id']}")
        if persist:
            self._persist_entity(entity)
        return entity["id"]

    def add_relationship(self, from_entity: str, to_entity: str, relationship_type: str,
                         properties: Optional[Dict[str, Any]] = None):
        """
        Adds a relationship between two entities.

        Args:
            from_entity: Source entity ID.
            to_entity: Target entity ID.
            relationship_type: Edge label, e.g. "uses".
            properties: Optional extra edge attributes. Defaults to None
                rather than a literal ``{}`` (a mutable default would be
                shared across calls).
        """
        relationship = {
            "from": from_entity,
            "to": to_entity,
            "type": relationship_type,
            **(properties or {}),
        }
        self.add_relationship_to_graph(relationship)

    def add_relationship_to_graph(self, relationship: Dict[str, Any], persist: bool = True):
        """
        Adds a relationship edge to the graph.

        NOTE: networkx silently creates missing endpoint nodes, so a typo
        in an entity ID yields a bare node rather than an error.

        Args:
            relationship: Dict with ``from``, ``to`` and ``type`` keys,
                plus any extra edge attributes.
            persist: Append the relationship to the on-disk journal.
        """
        self.graph.add_edge(relationship["from"], relationship["to"], **relationship)
        print(f"Added relationship: {relationship['from']} --[{relationship['type']}]--> {relationship['to']}")
        if persist:
            self._persist_relationship(relationship)

    def _persist_entity(self, entity: Dict[str, Any]):
        """Appends one entity record to entities.jsonl."""
        with open(self.entities_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entity, ensure_ascii=False) + "\n")

    def _persist_relationship(self, relationship: Dict[str, Any]):
        """Appends one relationship record to relationships.jsonl."""
        with open(self.relationships_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(relationship, ensure_ascii=False) + "\n")

    def find_paths(self, start_entity: str, end_entity: str, max_depth: int = 3) -> List[List[str]]:
        """
        Finds all simple paths between two entities up to a given depth.

        Args:
            start_entity: Source node ID.
            end_entity: Target node ID.
            max_depth: Maximum number of edges per path (networkx cutoff).

        Returns:
            A list of paths, each a list of node IDs; empty when either
            endpoint is missing from the graph.
        """
        try:
            return list(nx.all_simple_paths(
                self.graph, source=start_entity, target=end_entity, cutoff=max_depth
            ))
        except nx.NodeNotFound as e:
            print(f"Node not found: {e}")
            return []

    def extract_entities_from_text(self, text: str) -> List[Tuple[str, str, Dict]]:
        """
        Extracts entities from unstructured text using regex (placeholder).
        This is a simplified example and would ideally use a proper NLP
        library; repeated mentions yield repeated tuples.

        Returns:
            A list of (entity_type, entity_name, properties) tuples.
        """
        entities = []

        # Example: extract "Tool" entities from a fixed vocabulary.
        tool_matches = re.findall(r"\b(Gemini|ChatGPT|SORA 2|GoHighLevel|Telnyx|Instantly\.ai|Apify|Pomelli|Veo|n8n)\b", text)
        for tool in tool_matches:
            entities.append(("Tool", tool, {"source": "text_extraction"}))

        # Example: extract "Skill" entities from a fixed vocabulary.
        skill_matches = re.findall(r"\b(SEO|Automation|Marketing|Sales|Coding)\b", text)
        for skill in skill_matches:
            entities.append(("Skill", skill, {"source": "text_extraction"}))

        return entities

    def integrate_vector_results(self, query: str, vector_results: List[str]) -> List[Dict]:
        """
        Integrates vector search results with the knowledge graph.

        For each entity ID from the vector search, retrieves the entity
        data plus its immediate (outgoing) neighbors in the graph.

        Args:
            query: The original search query. Currently unused; kept for
                interface stability / future relevance scoring.
            vector_results: Entity IDs returned by the vector store.

        Returns:
            A list of {"source", "entity", ...} result dicts.
        """
        integrated_results = []
        for entity_id in vector_results:
            if entity_id in self.graph:
                entity_data = self.graph.nodes[entity_id]
                integrated_results.append({"source": "vector_search", "entity": entity_data})

                # Pull in the one-hop neighborhood for graph context.
                for neighbor in self.graph.neighbors(entity_id):
                    neighbor_data = self.graph.nodes[neighbor]
                    integrated_results.append({"source": "graph_neighbor", "entity": neighbor_data, "related_to": entity_id})
            else:
                print(f"Entity ID from vector search not found in graph: {entity_id}")

        return integrated_results

    def query_entity(self, entity_id: str) -> Optional[Dict]:
        """Retrieves an entity's attribute dict by ID, or None if absent."""
        if entity_id in self.graph:
            return self.graph.nodes[entity_id]
        return None

    def get_neighbors(self, entity_id: str) -> List[Dict]:
        """Retrieves the attribute dicts of an entity's (outgoing) neighbors."""
        return [self.graph.nodes[neighbor] for neighbor in self.graph.neighbors(entity_id)]

    def save_graph(self):
        """
        Rewrites both journal files from the current in-memory graph.

        Unlike the incremental appends, this produces a compacted,
        duplicate-free snapshot of all entities and relationships.
        """
        with open(self.entities_path, "w", encoding="utf-8") as f:
            for _, data in self.graph.nodes(data=True):
                f.write(json.dumps(data, ensure_ascii=False) + "\n")

        with open(self.relationships_path, "w", encoding="utf-8") as f:
            for _, _, data in self.graph.edges(data=True):
                f.write(json.dumps(data, ensure_ascii=False) + "\n")

        print("Graph saved to disk.")


if __name__ == "__main__":
    # Demo / smoke test of the engine.
    engine = KnowledgeGraphEngine()

    # Create a few entities of different types.
    alice_id = engine.add_entity("Person", {"name": "Alice", "role": "Engineer"})
    concept_id = engine.add_entity("Concept", {"name": "AI Validation", "description": "Validating AI outputs"})
    tool_id = engine.add_entity("Tool", {"name": "Gemini", "version": "3.0"})
    revenue_id = engine.add_entity("Revenue", {"name": "Database Reactivation", "price": 2997})

    # Wire the entities together.
    for src, dst, rel_type in (
        (alice_id, concept_id, "works_on"),
        (alice_id, tool_id, "uses"),
        (tool_id, concept_id, "implements"),
        (revenue_id, concept_id, "generates"),
    ):
        engine.add_relationship(src, dst, rel_type)

    # Path finding between two entities.
    paths = engine.find_paths(alice_id, revenue_id)
    print(f"Paths between Alice and Database Reactivation: {paths}")

    # Entity extraction from free text.
    sample_text = "Alice uses Gemini for AI Validation. This generates revenue from Database Reactivation."
    extracted = engine.extract_entities_from_text(sample_text)
    print(f"Extracted entities: {extracted}")

    # Combine vector-search hits with their graph neighborhoods.
    hits = [concept_id, tool_id]
    integrated = engine.integrate_vector_results("AI", hits)
    print(f"Integrated vector results: {integrated}")

    # Direct lookup of a single entity.
    alice = engine.query_entity(alice_id)
    print(f"Entity Alice: {alice}")

    # One-hop neighborhood lookup.
    alice_neighbors = engine.get_neighbors(alice_id)
    print(f"Neighbors of Alice: {alice_neighbors}")

    # Snapshot the full graph to disk.
    engine.save_graph()

    # Re-adding an identical entity should not create a duplicate node.
    alice_id = engine.add_entity("Person", {"name": "Alice", "role": "Engineer"})
    # The ID counter keeps incrementing even after saving and loading.
    skill_id = engine.add_entity("Skill", {"name": "Testing"})
    print(f"New skill ID is: {skill_id}")