# knowledge_graph_engine.py
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import networkx as nx
from collections import defaultdict
import re  # For text extraction
from uuid import uuid4  # For generating unique IDs

class KnowledgeGraphEngine:
    """
    Enhanced Knowledge Graph System with deduplication, incremental updates,
    text extraction, and pathfinding.

    Entities and relationships are persisted as append-only JSONL files under
    ``<workspace>/KNOWLEDGE_GRAPH`` and replayed into an in-memory
    ``networkx.MultiDiGraph`` on construction.
    """

    def __init__(self, workspace_path: str = "aiva_workspace"):
        """Initialize storage paths and replay any previously persisted graph.

        Args:
            workspace_path: Root directory holding the KNOWLEDGE_GRAPH store.
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "KNOWLEDGE_GRAPH"
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"
        self.kg_dir.mkdir(parents=True, exist_ok=True)
        # MultiDiGraph: directed edges, and multiple parallel edges (keyed by
        # relationship type) are allowed between the same pair of nodes.
        self.graph = nx.MultiDiGraph()
        self.entity_ids: set = set()  # Known entity IDs, for deduplication.
        self._load_graph()

    def _load_graph(self):
        """Replays entities and relationships from the JSONL files.

        Each file is loaded independently: previously, a missing
        relationships file prevented existing entities from being restored
        (and vice versa). Malformed or incomplete lines are reported and
        skipped rather than aborting the load.
        """
        if self.entities_path.exists():
            with open(self.entities_path, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        entity = json.loads(line)
                        if entity["id"] in self.entity_ids:
                            continue  # Skip duplicate entities
                        # persist=False: don't re-append what we just read.
                        self.add_entity_to_graph(entity, persist=False)
                    except json.JSONDecodeError:
                        print(f"Error decoding JSON: {line}")
                    except KeyError as e:
                        print(f"Missing key in entity: {e}, line: {line}")

        if self.relationships_path.exists():
            with open(self.relationships_path, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        rel = json.loads(line)
                        self.add_relationship_to_graph(rel, persist=False)
                    except json.JSONDecodeError:
                        print(f"Error decoding JSON: {line}")
                    except KeyError as e:
                        print(f"Missing key in relationship: {e}, line: {line}")

    def add_entity(self, entity_data: Dict[str, Any]) -> str:
        """Adds a new entity to the graph, handling deduplication.

        Args:
            entity_data: Entity attributes. If no ``"id"`` key is present, a
                UUID is generated and written back into the dict.

        Returns:
            The entity's ID (existing ID if it was already in the graph).
        """
        entity_id = entity_data.get("id")
        if not entity_id:
            entity_id = str(uuid4())  # Generate a unique ID if one doesn't exist
            entity_data["id"] = entity_id

        if entity_id in self.entity_ids:
            print(f"Entity with ID '{entity_id}' already exists. Skipping.")
            return entity_id  # Return the existing ID

        self.add_entity_to_graph(entity_data)
        return entity_id  # Return the new ID

    def add_entity_to_graph(self, entity_data: Dict[str, Any], persist: bool = True):
        """Adds the entity to the in-memory graph and optionally persists it.

        Args:
            entity_data: Entity attributes; must contain an ``"id"`` key.
            persist: When True, append the entity to the JSONL store.
        """
        entity_id = entity_data["id"]
        self.graph.add_node(entity_id, **entity_data)
        self.entity_ids.add(entity_id)

        if persist:
            self._persist_entity(entity_data)

    def add_relationship(self, from_entity: str, to_entity: str, relationship_type: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Adds a relationship between two existing entities.

        Args:
            from_entity: Source entity ID.
            to_entity: Target entity ID.
            relationship_type: Edge label (also used as the multigraph key).
            metadata: Optional extra edge attributes. Defaults to None rather
                than a mutable ``{}`` default, which would be shared across
                calls and could be silently mutated.
        """
        rel_data = {"from": from_entity, "to": to_entity, "type": relationship_type, **(metadata or {})}
        self.add_relationship_to_graph(rel_data)

    def add_relationship_to_graph(self, rel_data: Dict[str, Any], persist: bool = True) -> None:
        """Adds the relationship edge to the graph and optionally persists it.

        Relationships referencing unknown entities are reported and dropped
        (they are never persisted) instead of creating dangling nodes.

        Args:
            rel_data: Must contain ``"from"``, ``"to"`` and ``"type"`` keys.
            persist: When True, append the relationship to the JSONL store.
        """
        from_entity = rel_data["from"]
        to_entity = rel_data["to"]
        rel_type = rel_data["type"]

        if not self.graph.has_node(from_entity) or not self.graph.has_node(to_entity):
            print(f"Warning: Adding relationship between non-existent entities: {from_entity} and {to_entity}")
            return

        # key=rel_type lets one (from, to) pair carry one edge per type.
        self.graph.add_edge(from_entity, to_entity, key=rel_type, **rel_data)
        if persist:
            self._persist_relationship(rel_data)

    def _persist_entity(self, entity: Dict[str, Any]):
        """Appends the entity to the entities JSONL file.

        ensure_ascii=False keeps non-ASCII names readable in the UTF-8 file
        instead of writing \\uXXXX escapes.
        """
        with open(self.entities_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entity, ensure_ascii=False) + "\n")

    def _persist_relationship(self, relationship: Dict[str, Any]):
        """Appends the relationship to the relationships JSONL file."""
        with open(self.relationships_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(relationship, ensure_ascii=False) + "\n")

    def extract_entities_from_text(self, text: str) -> List[Dict[str, Any]]:
        """
        Extracts entities from unstructured text using regex.
        This is a placeholder; in reality, this would use an LLM or NLP model.

        Recognized patterns: ``Person: [name]`` and ``Skill: [skill]``.

        Returns:
            A list of entity dicts, each with a freshly generated UUID.
        """
        entities = []
        person_matches = re.findall(r"Person:\s*\[(.*?)\]", text)
        skill_matches = re.findall(r"Skill:\s*\[(.*?)\]", text)

        for name in person_matches:
            entities.append({"id": str(uuid4()), "type": "Person", "name": name})
        for skill in skill_matches:
            entities.append({"id": str(uuid4()), "type": "Skill", "name": skill})

        return entities

    def ingest_text(self, text: str) -> None:
        """Ingests text, extracts entities, and adds them to the graph."""
        extracted_entities = self.extract_entities_from_text(text)
        for entity in extracted_entities:
            self.add_entity(entity)

        # Dummy relationship extraction (replace with actual logic): link the
        # first two extracted entities, if any.
        if len(extracted_entities) > 1:
            self.add_relationship(extracted_entities[0]["id"], extracted_entities[1]["id"], "related_to")

    def find_shortest_path(self, source: str, target: str) -> Optional[List[str]]:
        """Finds the shortest path between two entities.

        Returns:
            The list of node IDs along the path, or None when either node is
            missing or no path exists.
        """
        try:
            return nx.shortest_path(self.graph, source=source, target=target)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return None

    def find_all_paths(self, source: str, target: str, cutoff: int = 3) -> List[List[str]]:
        """Finds all simple paths between two entities within a given length.

        Args:
            cutoff: Maximum path length (number of edges) to consider.

        Returns:
            A list of paths (each a list of node IDs); empty when either
            endpoint is missing.
        """
        try:
            return list(nx.all_simple_paths(self.graph, source=source, target=target, cutoff=cutoff))
        except nx.NodeNotFound:
            return []

    def get_entity(self, entity_id: str) -> Optional[Dict[str, Any]]:
        """Retrieves an entity's attribute dict by its ID, or None."""
        if entity_id in self.graph.nodes:
            return self.graph.nodes[entity_id]
        return None

    def search_entities(self, query: str) -> List[str]:
        """
        Placeholder for semantic search using vector embeddings.
        In a real implementation, this would query a vector database.

        Dummy implementation: case-insensitive substring match against each
        entity's ``name`` attribute.
        """
        results = []
        query_lower = query.lower()  # Hoisted out of the loop.
        for node_id, data in self.graph.nodes(data=True):
            if "name" in data and query_lower in data["name"].lower():
                results.append(node_id)
        return results

    def hybrid_search(self, query: str) -> List[Dict[str, Any]]:
        """Combines semantic search with graph traversal.

        Returns:
            Matched entities (``source == "vector_search"``) interleaved with
            their direct successors (``source == "graph_neighbor"``).
        """
        vector_results = self.search_entities(query)
        hybrid_results = []
        for entity_id in vector_results:
            entity = self.get_entity(entity_id)
            if entity:
                hybrid_results.append({"source": "vector_search", "data": entity})
                # Add direct (outgoing) neighbors for graph context.
                for neighbor in self.graph.neighbors(entity_id):
                    neighbor_entity = self.get_entity(neighbor)
                    if neighbor_entity:
                        hybrid_results.append({"source": "graph_neighbor", "data": neighbor_entity, "related_to": entity_id})
        return hybrid_results

    def export_graph(self, filepath: str) -> None:
        """Exports the graph to a file in node-link JSON format."""
        graph_data = nx.node_link_data(self.graph)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(graph_data, f, indent=4)

    def import_graph(self, filepath: str) -> None:
        """Imports a graph from a node-link JSON file.

        Replaces the in-memory graph and rebuilds the entity-ID index; the
        JSONL store on disk is NOT rewritten.
        """
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                graph_data = json.load(f)
                self.graph = nx.node_link_graph(graph_data)
                self.entity_ids = set(self.graph.nodes())
        except FileNotFoundError:
            print(f"Error: File not found: {filepath}")
        except json.JSONDecodeError:
            print(f"Error: Invalid JSON format in file: {filepath}")

    def clear_graph(self) -> None:
        """Clears the in-memory graph, the ID index, and the JSONL files."""
        self.graph.clear()
        self.entity_ids = set()
        # Truncate the persisted stores so the cleared state survives restart.
        self.entities_path.write_text("", encoding="utf-8")
        self.relationships_path.write_text("", encoding="utf-8")

if __name__ == "__main__":
    # Example Usage: build a tiny demo graph, query it, and round-trip it.
    engine = KnowledgeGraphEngine()

    # Add entities (one add_entity call per spec; IDs are captured in order).
    entity_specs = [
        {"type": "Person", "name": "Alice", "id": "alice123"},
        {"type": "Skill", "name": "Python Programming"},
        {"type": "Tool", "name": "Jupyter Notebook"},
        {"type": "Patent", "name": "AI Invention", "patent_number": "US1234567B2"},
        {"type": "Revenue", "amount": 1000000, "currency": "USD"},
    ]
    alice_id, skill_id, tool_id, patent_id, revenue_id = (
        engine.add_entity(spec) for spec in entity_specs
    )

    # Add relationships
    for src, dst, rel_type in (
        (alice_id, skill_id, "has_skill"),
        (alice_id, tool_id, "uses"),
        (skill_id, patent_id, "generates"),
        (patent_id, revenue_id, "generates"),
    ):
        engine.add_relationship(src, dst, rel_type)

    # Search for entities
    python_hits = engine.search_entities("Python")
    print(f"Search results for 'Python': {python_hits}")

    # Find shortest path
    alice_to_revenue = engine.find_shortest_path(alice_id, revenue_id)
    print(f"Shortest path from Alice to Revenue: {alice_to_revenue}")

    # Ingest text
    engine.ingest_text("The engineer Bob [BobTheBuilder] is great at using Terraform [TerraformSkill].")

    # Search for entities after ingestion
    bob_hits = engine.search_entities("Bob")
    print(f"Search results for 'Bob': {bob_hits}")

    # Export and import graph
    engine.export_graph("knowledge_graph.json")
    imported_engine = KnowledgeGraphEngine()
    imported_engine.import_graph("knowledge_graph.json")
    print(f"Number of nodes in original graph: {len(engine.graph.nodes)}")
    print(f"Number of nodes in imported graph: {len(imported_engine.graph.nodes)}")

    # Clear the graph
    # engine.clear_graph()  # Uncomment to clear the graph after the example
    # print(f"Number of nodes after clearing graph: {len(engine.graph.nodes)}")