# knowledge_graph_engine.py
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import networkx as nx
from collections import defaultdict
import re  # For entity extraction from text

class KnowledgeGraphEngine:
    """
    Advanced Knowledge Graph System with Deduplication, Incremental Updates, and Text Extraction.

    Entities are nodes of a ``networkx.MultiDiGraph`` (directed, with parallel
    edges allowed) and relationships are its edges. The graph is persisted as
    JSON Lines under ``<workspace_path>/KNOWLEDGE_GRAPH/``:

    * ``entities.jsonl`` -- one node attribute dict per line, including ``id``.
    * ``relationships.jsonl`` -- one edge attribute dict per line, plus the
      ``from``/``to`` endpoint IDs.

    The graph is loaded in ``__init__`` and rewritten to disk after each
    successful mutation.
    """

    # --- Demo extraction rules used by extract_entities_from_text ----------
    # Two consecutive capitalized words, e.g. "Alice Smith".
    _PERSON_PATTERN = re.compile(r"\b([A-Z][a-z]+ [A-Z][a-z]+)\b")
    # Fixed skill keywords, matched case-insensitively.
    _SKILL_PATTERN = re.compile(r"\b(AI|Machine Learning|Data Science|Python)\b", re.IGNORECASE)
    # Lower-cased skill -> canonical spelling, so "python" and "Python" deduplicate.
    _SKILL_CANONICAL = {
        "ai": "AI",
        "machine learning": "Machine Learning",
        "data science": "Data Science",
        "python": "Python",
    }

    def __init__(self, workspace_path: str = "aiva_kg", embedding_model=None):
        """Open (or create) the graph stored under *workspace_path*.

        Args:
            workspace_path: Root directory. Data lives in its ``KNOWLEDGE_GRAPH``
                subfolder, which is created if missing.
            embedding_model: Optional object with an ``encode(text)`` method.
                It is only used by :meth:`semantic_search`.
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "KNOWLEDGE_GRAPH"
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"
        # MultiDiGraph: directed, and several relationships may link the same pair.
        self.graph = nx.MultiDiGraph()
        # Entity text -> entity ID; add_entity returns the existing ID for a repeated text.
        self.entity_index: Dict[str, str] = {}
        self.embedding_model = embedding_model
        # Largest numeric ID suffix issued or loaded so far; see _generate_entity_id.
        self.entity_id_counter = 0

        self.kg_dir.mkdir(parents=True, exist_ok=True)
        self._load_graph()

    # ------------------------------------------------------------------ IDs

    def _generate_entity_id(self, entity_type: str) -> str:
        """Return a new ID of the form ``<TYPE>_<NNNNNN>``.

        The numeric part comes from a single counter shared by all entity types.
        """
        self.entity_id_counter += 1
        return f"{entity_type.upper()}_{self.entity_id_counter:06d}"

    @staticmethod
    def _id_suffix(entity_id: Any) -> Optional[int]:
        """Return the integer after the last ``_`` in *entity_id*, or None if it is not numeric."""
        try:
            return int(str(entity_id).rsplit("_", 1)[-1])
        except ValueError:
            # Caller-chosen IDs such as "document123" carry no counter value.
            return None

    # ---------------------------------------------------------- persistence

    @staticmethod
    def _read_jsonl(path: Path):
        """Yield each JSON object stored one per line in *path*.

        A missing file yields nothing. Blank lines are skipped. Undecodable or
        non-object lines are reported and skipped.
        """
        if not path.exists():
            return
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")
                    continue
                if not isinstance(record, dict):
                    print(f"Skipping non-object JSON line in {path.name}")
                    continue
                yield record

    @staticmethod
    def _write_lines_atomic(path: Path, lines: List[str]) -> None:
        """Write *lines* (newline-terminated) to *path* via a temp file and ``os.replace``.

        If writing fails part-way, the previous contents of *path* stay intact.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in lines)
        os.replace(tmp_path, path)

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """Fallback for ``json.dumps``: serialize array-likes (e.g. NumPy embeddings) via ``tolist()``."""
        tolist = getattr(obj, "tolist", None)
        if callable(tolist):
            return tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _load_graph(self):
        """Populate the graph from whichever JSONL files exist on disk.

        Entities are loaded before relationships, and each file is read on its
        own. An entity ID whose part after the last ``_`` is not an integer
        (e.g. ``"document123"``) is still loaded, but does not advance
        ``entity_id_counter``.
        """
        if not (self.entities_path.exists() or self.relationships_path.exists()):
            return

        print("Loading graph from disk...")

        for entity in self._read_jsonl(self.entities_path):
            try:
                entity_id = entity["id"]
            except KeyError as e:
                print(f"Missing key in entity: {e}")
                continue
            self.graph.add_node(entity_id, **entity)
            if "text" in entity:  # Rebuild the text index used for deduplication.
                self.entity_index[entity["text"]] = entity_id
            suffix = self._id_suffix(entity_id)
            if suffix is not None:
                self.entity_id_counter = max(self.entity_id_counter, suffix)

        for rel in self._read_jsonl(self.relationships_path):
            try:
                self.graph.add_edge(rel["from"], rel["to"], **rel)
            except KeyError as e:
                print(f"Missing key in relationship: {e}")

        print(f"Graph loaded with {self.graph.number_of_nodes()} nodes and {self.graph.number_of_edges()} edges.")

    def _save_graph(self):
        """Persist the whole graph to the JSONL files.

        Each file is replaced atomically (see ``_write_lines_atomic``). Values
        exposing ``tolist()`` are converted to lists. A record that still
        cannot be JSON-encoded is reported and left out of that save.
        """
        print("Saving graph to disk...")

        entity_lines: List[str] = []
        for node_id, data in self.graph.nodes(data=True):
            try:
                # The graph key is authoritative: always persist it as "id",
                # since _load_graph skips entity records without one.
                entity_lines.append(json.dumps({**data, "id": node_id}, default=self._json_default))
            except (TypeError, ValueError) as e:
                print(f"Error encoding entity {node_id}: {e}.  Data: {data}")
        self._write_lines_atomic(self.entities_path, entity_lines)

        relationship_lines: List[str] = []
        for from_node, to_node, data in self.graph.edges(data=True):
            try:
                relationship_lines.append(
                    json.dumps(data | {"from": from_node, "to": to_node}, default=self._json_default)
                )
            except (TypeError, ValueError) as e:
                print(f"Error encoding relationship {from_node} -> {to_node}: {e}. Data: {data}")
        self._write_lines_atomic(self.relationships_path, relationship_lines)

        print("Graph saved.")

    # ------------------------------------------------------------ mutation

    def _insert_entity(
        self, entity_type: str, properties: Dict[str, Any], text: Optional[str]
    ) -> Tuple[str, bool]:
        """Add an entity in memory only (no save); return ``(entity_id, created)``."""
        if text and text in self.entity_index:
            print(f"Entity '{text}' already exists, skipping creation.")
            return self.entity_index[text], False

        entity_id = self._generate_entity_id(entity_type)
        # id/type are written last so that keys in `properties` cannot make the
        # stored id differ from the node key.
        entity_data = {**properties, "id": entity_id, "type": entity_type}
        if text:
            entity_data["text"] = text
            self.entity_index[text] = entity_id

        self.graph.add_node(entity_id, **entity_data)
        return entity_id, True

    def _insert_relationship(
        self,
        from_entity: str,
        to_entity: str,
        relationship_type: str,
        properties: Optional[Dict[str, Any]],
    ) -> bool:
        """Add an edge in memory only (no save); return whether it was added."""
        if from_entity not in self.graph or to_entity not in self.graph:
            print(f"Warning: One or both entities in relationship do not exist: {from_entity}, {to_entity}")
            return False

        attrs = {**(properties or {}), "type": relationship_type}
        # Attach the attributes after creating the edge, so a property named
        # "type" or "key" cannot collide with add_edge's own parameters.
        key = self.graph.add_edge(from_entity, to_entity)
        self.graph.edges[from_entity, to_entity, key].update(attrs)
        return True

    def add_entity(self, entity_type: str, properties: Dict[str, Any], text: Optional[str] = None) -> str:
        """Add an entity, deduplicating by *text*, and save if one was created.

        Args:
            entity_type: Entity type; upper-cased as the ID prefix.
            properties: Extra attributes to store on the node.
            text: Optional dedup key. If another entity already has this text,
                that entity's ID is returned and nothing is created.

        Returns:
            The new (or existing) entity ID.
        """
        entity_id, created = self._insert_entity(entity_type, properties, text)
        if created:
            self._save_graph()
        return entity_id

    def add_relationship(
        self,
        from_entity: str,
        to_entity: str,
        relationship_type: str,
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Add a directed ``from_entity -> to_entity`` relationship and save.

        A warning is printed, and nothing changes, if either endpoint is missing.
        *relationship_type* is stored as the edge's ``type`` attribute and
        takes precedence over any ``type`` key in *properties*.
        """
        if self._insert_relationship(from_entity, to_entity, relationship_type, properties):
            self._save_graph()

    def update_entity(self, entity_id: str, properties: Dict[str, Any]) -> None:
        """Merge *properties* into an existing entity's attributes, then save.

        If ``text`` changes, the deduplication index moves to the new text.
        """
        if entity_id not in self.graph:
            print(f"Error: Entity with ID '{entity_id}' not found.")
            return

        node_data = self.graph.nodes[entity_id]
        old_text = node_data.get("text")
        node_data.update(properties)
        new_text = node_data.get("text")
        if new_text != old_text:
            if old_text is not None and self.entity_index.get(old_text) == entity_id:
                del self.entity_index[old_text]
            if new_text:
                self.entity_index[new_text] = entity_id

        self._save_graph()

    def delete_entity(self, entity_id: str) -> None:
        """Remove *entity_id* and every relationship touching it, then save."""
        if entity_id not in self.graph:
            print(f"Error: Entity with ID '{entity_id}' not found.")
            return

        text = self.graph.nodes[entity_id].get("text")
        # Only drop the index entry if it still points at this entity.
        if text is not None and self.entity_index.get(text) == entity_id:
            del self.entity_index[text]

        self.graph.remove_node(entity_id)
        self._save_graph()

    # ------------------------------------------------------- text ingestion

    def extract_entities_from_text(self, text: str) -> List[Tuple[str, str, Dict[str, Any]]]:
        """
        Extract entities from unstructured text using regexes (for demonstration).

        In a real-world scenario, this would use a more sophisticated NLP model.

        * ``Person``: two consecutive capitalized words. Pairs that are known
          skills (e.g. "Machine Learning") are excluded. Otherwise add_entity's
          text-keyed deduplication would merge the skill into a Person entity.
        * ``Skill``: fixed keywords, matched case-insensitively and returned in
          their canonical spelling.

        All persons are listed first, then all skills, each in order of
        appearance. Repeated mentions are returned repeatedly.

        Returns:
            A list of ``(entity_type, entity_text, properties)`` tuples.
        """
        entities: List[Tuple[str, str, Dict[str, Any]]] = []

        for match in self._PERSON_PATTERN.finditer(text):
            name = match.group(1)
            if name.lower() in self._SKILL_CANONICAL:
                continue
            entities.append(("Person", name, {}))

        for match in self._SKILL_PATTERN.finditer(text):
            skill = match.group(1)
            entities.append(("Skill", self._SKILL_CANONICAL.get(skill.lower(), skill), {}))

        return entities

    def ingest_text(self, text: str, source_id: str) -> None:
        """
        Extract entities from *text* and link each one to *source_id* with a ``mentions`` edge.

        If *source_id* is not yet a node, a ``Document`` node with that ID is
        created so the links can be made. An entity already linked to the
        source by a ``mentions`` edge is not linked again. The graph is saved
        once at the end, and only if something changed.
        """
        changed = False
        if source_id not in self.graph:
            self.graph.add_node(source_id, id=source_id, type="Document")
            changed = True

        for entity_type, entity_text, properties in self.extract_entities_from_text(text):
            entity_id, created = self._insert_entity(entity_type, properties | {"source": source_id}, entity_text)
            changed |= created

            existing = self.graph.get_edge_data(entity_id, source_id, default={})
            if any(edge.get("type") == "mentions" for edge in existing.values()):
                continue
            changed |= self._insert_relationship(
                entity_id, source_id, "mentions", {"context": "extracted from text"}
            )

        if changed:
            self._save_graph()

    # -------------------------------------------------------------- queries

    def find_shortest_path(self, start_entity: str, end_entity: str) -> Optional[List[str]]:
        """Return the shortest directed path of node IDs, or None if there is no path or a node is missing."""
        try:
            return nx.shortest_path(self.graph, source=start_entity, target=end_entity)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return None

    def find_all_paths(self, start_entity: str, end_entity: str, max_length: int = 4) -> List[List[str]]:
        """Return every distinct simple directed path with at most *max_length* edges.

        Parallel edges between two nodes can make networkx report the same node
        path more than once. Such repeats are dropped, keeping first-seen order.
        Returns ``[]`` if either node is missing.
        """
        try:
            seen = set()
            unique_paths: List[List[str]] = []
            for path in nx.all_simple_paths(self.graph, source=start_entity, target=end_entity, cutoff=max_length):
                signature = tuple(path)
                if signature not in seen:
                    seen.add(signature)
                    unique_paths.append(path)
            return unique_paths
        except nx.NodeNotFound:
            return []

    def get_entity_by_id(self, entity_id: str) -> Optional[Dict[str, Any]]:
        """Return the live attribute dict of *entity_id*, or None if it is absent."""
        if entity_id in self.graph:
            return self.graph.nodes[entity_id]
        return None

    def search_entities_by_type(self, entity_type: str) -> List[Dict[str, Any]]:
        """Return the attribute dicts of all entities whose ``type`` equals *entity_type*."""
        return [data for _, data in self.graph.nodes(data=True) if data.get("type") == entity_type]

    def semantic_search(self, query: str, top_k: int = 5) -> List[str]:
        """
        Return up to *top_k* entity IDs ranked by cosine similarity to *query*.

        Requires an embedding model. Only entities with a non-None
        ``embedding`` attribute are considered.
        """
        if self.embedding_model is None:
            print("Error: Embedding model not initialized.")
            return []

        query_embedding = self.embedding_model.encode(query)
        scored = [
            (entity_id, self._cosine_similarity(query_embedding, data["embedding"]))
            for entity_id, data in self.graph.nodes(data=True)
            if data.get("embedding") is not None
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [entity_id for entity_id, _ in scored[:max(top_k, 0)]]

    def _cosine_similarity(self, a, b):
        """Return the cosine similarity of two vectors, or 0.0 if either has zero norm."""
        import numpy as np

        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            # Avoid a 0/0 NaN, which would make the ranking in semantic_search meaningless.
            return 0.0
        return float(np.dot(a, b) / denom)


if __name__ == "__main__":
    # Example usage. The engine persists to ./aiva_kg/KNOWLEDGE_GRAPH. On a second
    # run, add_entity returns the existing text-deduplicated entities, but
    # add_relationship adds the same relationships again as parallel edges.
    kg_engine = KnowledgeGraphEngine()

    # 1. Add Entities
    person_id = kg_engine.add_entity("Person", {"name": "Alice", "title": "Engineer"}, "Alice")
    concept_id = kg_engine.add_entity("Concept", {"name": "Knowledge Graph"}, "Knowledge Graph")
    skill_id = kg_engine.add_entity("Skill", {"name": "Python"}, "Python")
    tool_id = kg_engine.add_entity("Tool", {"name": "NetworkX"}, "NetworkX")
    patent_id = kg_engine.add_entity("Patent", {"number": "US1234567B2"}, "US1234567B2")
    revenue_id = kg_engine.add_entity("Revenue", {"amount": 1000000, "currency": "USD"}, "Million Dollar Revenue")

    # 2. Add Relationships
    kg_engine.add_relationship(person_id, concept_id, "knows_about")
    kg_engine.add_relationship(person_id, skill_id, "uses")
    kg_engine.add_relationship(person_id, tool_id, "uses")
    kg_engine.add_relationship(concept_id, skill_id, "requires")
    kg_engine.add_relationship(tool_id, skill_id, "implements")
    kg_engine.add_relationship(patent_id, revenue_id, "generates")

    # 3. Update Entity
    kg_engine.update_entity(person_id, {"title": "Senior Engineer"})

    # 4. Query: Find shortest path (directed; None when unreachable)
    path = kg_engine.find_shortest_path(person_id, revenue_id)
    print(f"Shortest path from Alice to Revenue: {path}")

    # 5. Ingest Text
    text = "Bob is a Data Scientist who uses Python and Machine Learning. He generated a patent."
    kg_engine.ingest_text(text, "document123")  # Source ID for the document

    # 6. Search Entities by Type
    people = kg_engine.search_entities_by_type("Person")
    print(f"People in the graph: {people}")

    # 7. Example of deduplication: Adding Alice again will not create a new entity.
    kg_engine.add_entity("Person", {"name": "Alice"}, "Alice")

    # 8. Example of finding all paths
    all_paths = kg_engine.find_all_paths(skill_id, revenue_id)
    print(f"All paths between Skill and Revenue: {all_paths}")

    # 9. Delete Entity (example)
    # kg_engine.delete_entity(person_id)
    # print(f"Graph after deleting Alice: {list(kg_engine.graph.nodes)}")

    # 10. Retrieve an entity by ID
    alice = kg_engine.get_entity_by_id(person_id)
    print(f"Alice's details: {alice}")

    # Example using the existing graph data: a fresh engine reloads what was saved above.
    rag = KnowledgeGraphEngine()

    print(f"Graph loaded with {rag.graph.number_of_nodes()} nodes and {rag.graph.number_of_edges()} edges.")

    if rag.graph.number_of_nodes() > 0:
        # Pick the first node only after confirming the graph is non-empty;
        # indexing first would raise IndexError on an empty graph.
        first_node = next(iter(rag.graph.nodes()))
        print(f"Shortest path from {first_node} to REVENUE_TECH_STACK:")
        # No entity with this ID is created above, so this prints None unless
        # the stored graph already contains it.
        path = rag.find_shortest_path(first_node, "REVENUE_TECH_STACK")
        print(path)