# knowledge_graph_engine.py
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional
import networkx as nx
from collections import defaultdict
import spacy

class KnowledgeGraphEngine:
    """
    A small, file-backed knowledge graph built on an undirected NetworkX graph.

    Capabilities:
        1. Stores typed entities (nodes) with typed relationships (edges)
        2. Supports graph traversal queries (shortest path, neighbors)
        3. Integrates with vector embeddings for semantic search (placeholder)
        4. Handles entity deduplication (by entity ID)
        5. Supports incremental updates; every mutation is written to disk
        6. Extracts entities from unstructured text with spaCy

    Persistence layout (inside ``<workspace>/KNOWLEDGE_GRAPH``):
        * ``entities.jsonl``      - one JSON object per entity
        * ``relationships.jsonl`` - one JSON object per relationship, with
          ``"from"`` and ``"to"`` keys naming the two endpoints

    The in-memory graph is undirected, so the direction of a relationship is
    kept in the ``"from"`` / ``"to"`` attributes of each edge.
    """

    def __init__(self, workspace_path: str = "aiva_kg", nlp_model: str = "en_core_web_sm") -> None:  # Use smaller model by default
        """
        Create the storage directory, load any persisted graph and load spaCy.

        Args:
            workspace_path: Directory under which ``KNOWLEDGE_GRAPH`` is created.
            nlp_model: Name of the spaCy pipeline passed to ``spacy.load``.
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "KNOWLEDGE_GRAPH"
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"

        self.kg_dir.mkdir(parents=True, exist_ok=True)

        self.graph = nx.Graph()
        self._load_graph()
        self.nlp = spacy.load(nlp_model)  # Load the spacy model

        self.entity_types = ["Person", "Concept", "Skill", "Tool", "Patent", "Revenue"]
        self.relationship_types = ["uses", "implements", "depends_on", "generates"]

    # ------------------------------------------------------------------
    # Persistence helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _read_jsonl(path: Path) -> List[Dict[str, Any]]:
        """Return the JSON objects stored one per line in *path*.

        A missing file yields an empty list and blank lines are skipped, so a
        partially written or hand-edited store does not abort loading.
        """
        if not path.exists():
            return []
        with open(path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    @staticmethod
    def _dump_jsonl(records: List[Dict[str, Any]]) -> str:
        """Serialize *records* as JSON Lines text (one object per line)."""
        return "".join(json.dumps(record) + "\n" for record in records)

    @staticmethod
    def _atomic_write_text(path: Path, text: str) -> None:
        """Write *text* to *path* through a sibling temp file and ``os.replace``.

        The previous file content stays intact until the new content has been
        written completely.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp_path, path)

    @staticmethod
    def _without_keys(mapping: Dict[str, Any], *keys: str) -> Dict[str, Any]:
        """Return a shallow copy of *mapping* with the given *keys* removed."""
        return {key: value for key, value in mapping.items() if key not in keys}

    def _load_graph(self) -> None:
        """Loads entities and relationships into a NetworkX graph.

        Each file is loaded independently: if only one of the two files is
        present its content is still loaded, so a later save does not
        overwrite it with an empty graph.
        """
        # Load Entities
        for entity in self._read_jsonl(self.entities_path):
            self.graph.add_node(entity["id"], **entity)

        # Load Relationships ("from"/"to" are kept as edge attributes so the
        # direction is still known after the edge is stored undirected).
        for rel in self._read_jsonl(self.relationships_path):
            self.graph.add_edge(rel["from"], rel["to"], **rel)

    def _save_graph(self) -> None:
        """Saves the graph data to JSONL files.

        Everything is serialized before any file is touched, so a value that
        cannot be encoded as JSON raises without truncating the existing files.
        """
        # The node key is authoritative for "id"; this also gives nodes that
        # have no attributes a record that can be loaded again.
        entities = [{**data, "id": node_id} for node_id, data in self.graph.nodes(data=True)]
        # "from"/"to" stored on the edge take precedence over the (u, v) order
        # reported by the undirected graph.
        relationships = [
            {"from": u, "to": v, **data}
            for u, v, data in self.graph.edges(data=True)
        ]

        entity_text = self._dump_jsonl(entities)
        relationship_text = self._dump_jsonl(relationships)

        self._atomic_write_text(self.entities_path, entity_text)
        self._atomic_write_text(self.relationships_path, relationship_text)

    # ------------------------------------------------------------------
    # Mutations
    # ------------------------------------------------------------------
    def add_entity(self, entity_id: str, entity_type: str, **kwargs) -> None:
        """Adds a new entity to the graph, handling deduplication.

        Args:
            entity_id: Unique ID of the entity. If it already exists a message
                is printed and nothing is changed.
            entity_type: One of ``self.entity_types``.
            **kwargs: Additional attributes stored on the entity. ``id`` and
                ``type`` keys are ignored in favour of the explicit arguments.

        Raises:
            ValueError: If *entity_type* is not a known entity type.
        """
        if entity_id in self.graph:
            print(f"Entity with ID '{entity_id}' already exists. Skipping.")
            return

        if entity_type not in self.entity_types:
            raise ValueError(f"Invalid entity type: {entity_type}. Must be one of {self.entity_types}")

        # The validated arguments must win over same-named keyword attributes.
        entity_data = {
            "id": entity_id,
            "type": entity_type,
            **self._without_keys(kwargs, "id", "type"),
        }
        self.graph.add_node(entity_id, **entity_data)
        self._save_graph()

    def add_relationship(self, from_entity: str, to_entity: str, relationship_type: str, **kwargs) -> None:
        """Adds a relationship between two entities.

        The graph holds at most one edge per pair of entities; adding another
        relationship between the same pair updates that edge.

        Args:
            from_entity: ID of the entity the relationship starts at.
            to_entity: ID of the entity the relationship points to.
            relationship_type: One of ``self.relationship_types``.
            **kwargs: Additional attributes stored on the relationship.

        Raises:
            ValueError: If either entity is missing or the type is unknown.
        """
        if from_entity not in self.graph or to_entity not in self.graph:
            raise ValueError("Both entities must exist in the graph.")

        if relationship_type not in self.relationship_types:
            raise ValueError(f"Invalid relationship type: {relationship_type}. Must be one of {self.relationship_types}")

        # Record the direction explicitly: the undirected graph reports edge
        # endpoints in its own iteration order, not in the order given here.
        attributes = {
            **kwargs,
            "type": relationship_type,
            "from": from_entity,
            "to": to_entity,
        }
        self.graph.add_edge(from_entity, to_entity, **attributes)
        self._save_graph()

    def update_entity(self, entity_id: str, **kwargs) -> None:
        """Updates an existing entity with new data.

        Raises:
            ValueError: If the entity does not exist, or if ``id`` is given
                with a value different from *entity_id* (the stored ID must
                stay equal to the node key).
        """
        if entity_id not in self.graph:
            raise ValueError(f"Entity with ID '{entity_id}' does not exist.")

        if "id" in kwargs and kwargs["id"] != entity_id:
            raise ValueError(f"Cannot change the ID of entity '{entity_id}'.")

        self.graph.nodes[entity_id].update(kwargs)
        self._save_graph()

    def delete_entity(self, entity_id: str) -> None:
        """Deletes an entity (and the relationships attached to it) from the graph.

        Raises:
            ValueError: If the entity does not exist.
        """
        if entity_id not in self.graph:
            raise ValueError(f"Entity with ID '{entity_id}' does not exist.")

        self.graph.remove_node(entity_id)
        self._save_graph()

    def delete_relationship(self, from_entity: str, to_entity: str) -> None:
        """Deletes a relationship between two entities.

        Raises:
            ValueError: If either entity is missing or no relationship exists
                between them.
        """
        if from_entity not in self.graph or to_entity not in self.graph:
            raise ValueError("Both entities must exist in the graph.")

        if not self.graph.has_edge(from_entity, to_entity):
            raise ValueError(f"No relationship exists between '{from_entity}' and '{to_entity}'.")

        self.graph.remove_edge(from_entity, to_entity)
        self._save_graph()

    def integrate_vector_embeddings(self, entity_id: str, embedding: List[float]) -> None:
        """
        Placeholder for integrating vector embeddings for semantic search.
        In a real implementation, this would store the embedding in a vector database
        and use it for similarity search.

        Raises:
            ValueError: If the entity does not exist.
        """
        if entity_id not in self.graph:
            raise ValueError(f"Entity with ID '{entity_id}' does not exist.")

        # In a real implementation, store the embedding in a vector database.
        # For this example, we'll just store it in the entity data.
        self.graph.nodes[entity_id]["embedding"] = embedding
        self._save_graph()

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def find_path(self, start_entity: str, end_entity: str) -> Optional[List[str]]:
        """Finds a path between two entities in the graph.

        Returns:
            The list of entity IDs on a shortest path, or ``None`` when the
            entities are not connected or one of them is not in the graph.
        """
        try:
            return nx.shortest_path(self.graph, source=start_entity, target=end_entity)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            # NodeNotFound is raised for unknown endpoints; treat it like
            # "no path" so the Optional return contract holds.
            return None

    def extract_entities_from_text(self, text: str) -> List[Dict[str, str]]:
        """
        Extracts entities from unstructured text using spaCy.
        This is a basic implementation and can be improved with custom NER models.

        Returns:
            One ``{"text": ..., "label": ...}`` dict per entity spaCy found.
        """
        doc = self.nlp(text)
        return [{"text": ent.text, "label": ent.label_} for ent in doc.ents]

    def semantic_search(self, query: str, top_k: int = 5) -> List[str]:
        """
        Placeholder for semantic search using vector embeddings.
        In a real implementation, this would query a vector database to find the
        most similar entities to the query.

        Returns:
            At most *top_k* entity IDs (currently just the first ones in the
            graph, regardless of *query*).
        """
        # This is a placeholder and returns dummy results.
        # Replace with actual vector search logic.
        print(f"Performing semantic search for query: {query}")
        if top_k <= 0:
            # A negative slice bound would return "all but the last |top_k|".
            return []
        return list(self.graph.nodes)[:top_k]

    def get_entity(self, entity_id: str) -> Optional[Dict[str, Any]]:
        """Retrieves an entity from the graph by its ID.

        Returns:
            The attribute dict of the node, or ``None`` if the ID is unknown.
            The dict is the graph's own storage: changes made to it are
            visible in the graph but are not persisted until the next save.
        """
        if entity_id not in self.graph:
            return None
        return self.graph.nodes[entity_id]

    def get_all_entities_of_type(self, entity_type: str) -> List[Dict[str, Any]]:
        """Retrieves all entities of a specific type."""
        return [
            data
            for _, data in self.graph.nodes(data=True)
            if data.get("type") == entity_type
        ]

    def get_neighbors(self, entity_id: str) -> List[str]:
        """Retrieves the neighbors of a given entity.

        Raises:
            ValueError: If the entity does not exist.
        """
        if entity_id not in self.graph:
            raise ValueError(f"Entity with ID '{entity_id}' does not exist.")
        return list(self.graph.neighbors(entity_id))

    def get_relationships(self, entity_id: str) -> List[Dict[str, Any]]:
        """Retrieves all relationships associated with a given entity.

        Returns:
            One dict per neighbor containing ``"target"`` (the neighbor ID),
            ``"type"`` (``"related"`` when the edge has no type) and the
            remaining edge attributes.

        Raises:
            ValueError: If the entity does not exist.
        """
        if entity_id not in self.graph:
            raise ValueError(f"Entity with ID '{entity_id}' does not exist.")

        relationships = []
        for neighbor in self.graph.neighbors(entity_id):
            edge_data = self.graph.get_edge_data(entity_id, neighbor)
            relationships.append({
                "target": neighbor,
                "type": edge_data.get("type", "related"),
                # A stored "target" attribute must not replace the neighbor.
                **self._without_keys(edge_data, "target"),
            })
        return relationships

    # ------------------------------------------------------------------
    # Import / export
    # ------------------------------------------------------------------
    def export_graph_to_json(self, filepath: str) -> None:
        """Exports the entire graph data to a single JSON file.

        The file holds ``"nodes"`` (entity attribute dicts) and ``"edges"``
        (edge attribute dicts extended with ``"source"`` and ``"target"``).
        """
        graph_data = {
            "nodes": [dict(data) for _, data in self.graph.nodes(data=True)],
            "edges": [
                {"source": u, "target": v, **self._without_keys(data, "source", "target")}
                for u, v, data in self.graph.edges(data=True)
            ],
        }

        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(graph_data, f, indent=4)

    def import_graph_from_json(self, filepath: str) -> None:
        """Imports graph data from a JSON file, overwriting the existing graph.

        Errors are reported with ``print`` rather than raised. The current
        graph is only replaced after the whole file has been read and
        converted, so a malformed file leaves the in-memory graph untouched.
        """
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                graph_data = json.load(f)

            imported = nx.Graph()  # Build aside; swap in only on success
            for node_data in graph_data["nodes"]:
                imported.add_node(node_data["id"], **node_data)

            for edge_data in graph_data["edges"]:
                source, target = edge_data["source"], edge_data["target"]
                # "source"/"target" are export framing, not edge attributes.
                attributes = self._without_keys(edge_data, "source", "target")
                attributes.setdefault("from", source)
                attributes.setdefault("to", target)
                imported.add_edge(source, target, **attributes)

            self.graph = imported
            self._save_graph()  # Update JSONL files
            print(f"Graph imported successfully from {filepath}")

        except FileNotFoundError:
            print(f"Error: File not found at {filepath}")
        except json.JSONDecodeError:
            print(f"Error: Invalid JSON format in {filepath}")
        except KeyError as e:
            print(f"Error: Missing key in JSON data: {e}")
        except Exception as e:
            print(f"An unexpected error occurred: {e}")

if __name__ == "__main__":
    # Demonstration run: builds a tiny graph in the default workspace and
    # exercises each public operation of the engine once.
    engine = KnowledgeGraphEngine()

    # Seed entities: (entity ID, entity type, extra attributes)
    demo_entities = [
        ("person1", "Person", {"name": "Alice", "occupation": "Engineer"}),
        ("concept1", "Concept", {"name": "AI", "description": "Artificial Intelligence"}),
        ("tool1", "Tool", {"name": "Python", "purpose": "Programming"}),
        ("patent1", "Patent", {"name": "US1234567", "description": "AI invention"}),
        ("revenue1", "Revenue", {"amount": 1000000, "currency": "USD"}),
    ]
    for demo_id, demo_type, demo_attrs in demo_entities:
        engine.add_entity(demo_id, demo_type, **demo_attrs)

    # Seed relationships: (from, to, relationship type)
    demo_relationships = [
        ("person1", "tool1", "uses"),
        ("person1", "concept1", "implements"),
        ("tool1", "concept1", "depends_on"),
        ("concept1", "revenue1", "generates"),
    ]
    for start_id, end_id, relation in demo_relationships:
        engine.add_relationship(start_id, end_id, relation)

    # Shortest path between two entities
    route = engine.find_path("person1", "revenue1")
    print(f"Path from person1 to revenue1: {route}")

    # Named-entity extraction from free text
    sample_text = "Alice uses Python to implement AI and generate revenue."
    found = engine.extract_entities_from_text(sample_text)
    print(f"Extracted entities from text: {found}")

    # Attach an embedding (placeholder implementation)
    engine.integrate_vector_embeddings("concept1", [0.1, 0.2, 0.3])

    # Semantic search (placeholder implementation)
    hits = engine.semantic_search("artificial intelligence")
    print(f"Semantic search results: {hits}")

    # Look up a single entity
    print(f"Retrieved entity: {engine.get_entity('person1')}")

    # Change an attribute and read the entity back
    engine.update_entity("person1", occupation="Software Engineer")
    print(f"Updated entity: {engine.get_entity('person1')}")

    # All entities of one type
    print(f"All persons: {engine.get_all_entities_of_type('Person')}")

    # Direct neighbors of an entity
    print(f"Neighbors of person1: {engine.get_neighbors('person1')}")

    # Relationships attached to an entity
    print(f"Relationships of person1: {engine.get_relationships('person1')}")

    # Snapshot the graph, remove a relationship and an entity, then restore
    # the snapshot (the import overwrites the modified graph).
    snapshot_file = "knowledge_graph.json"
    engine.export_graph_to_json(snapshot_file)
    engine.delete_relationship("person1", "tool1")
    engine.delete_entity("revenue1")
    engine.import_graph_from_json(snapshot_file)

    print("Example completed.")