# knowledge_graph_engine.py
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional
import networkx as nx
from collections import defaultdict
import re
import spacy

class KnowledgeGraphEngine:
    """
    Advanced Knowledge Graph System with deduplication, incremental updates,
    text extraction, and query interface.
    """
    def __init__(self, workspace_path: str = "aiva_kg", nlp_model: str = "en_core_web_sm"):
        """Set up storage paths, the in-memory graph and the NLP pipeline.

        Args:
            workspace_path: Root directory of the workspace. The graph files
                live in its ``KNOWLEDGE_GRAPH`` sub-directory, which is
                created if it does not exist.
            nlp_model: Model name handed to ``spacy.load``.
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "KNOWLEDGE_GRAPH"
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"

        self.graph = nx.Graph()
        self.entity_ids = set()  # Track existing entity IDs for deduplication
        self.nlp = spacy.load(nlp_model)

        # Vocabulary of entity types (each listed once). The methods visible
        # in this class do not validate against these lists.
        self.entity_types = [
            "Person",
            "Concept",
            "Skill",
            "Tool",
            "Patent",
            "Revenue",
            "strategy_node",
            "technology_enabler",
            "protocol",
            "pricing_model",
            "service_architecture",
            "market_segment",
            "premortem_risk",
            "infrastructure",
            "value_proposition",
            "retention_protocol",
            "decision_point",
            "brand_strategy",
            "sales_framework",
            "success_criteria",
            "milestone",
            "localization_rule",
            "quality_control",
            "market_intelligence",
            "pricing_intelligence",
            "strategic_axiom",
            "decision_record",
            "research_request",
            "strategic_intelligence",
            "pricing_recommendation",
            "performance_data",
            "market_data",
            "strategic_plan",
            "competitive_intelligence",
            "mandatory_protocol",
            "critical_learning",
            "strategic_insight",
        ]
        self.relationship_types = ["uses", "implements", "depends_on", "generates", "related_to"]

        # Create directories if they don't exist, then restore persisted state.
        self.kg_dir.mkdir(parents=True, exist_ok=True)
        self._load_graph()

    def _load_graph(self):
        """Populate the in-memory graph from the JSONL files on disk.

        The entity file and the relationship file are loaded independently,
        so persisted entities are restored even when no relationship has been
        written yet. Blank lines are ignored. Entities whose ID was already
        seen are skipped (first occurrence wins).

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If an entity lacks ``"id"`` or a relationship lacks
                ``"from"``/``"to"``.
        """
        # Load Entities
        if self.entities_path.exists():
            with open(self.entities_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank / trailing empty lines
                    entity = json.loads(line)
                    entity_id = entity["id"]
                    if entity_id in self.entity_ids:  # Deduplication during load
                        continue
                    self.graph.add_node(entity_id, **entity)
                    self.entity_ids.add(entity_id)

        # Load Relationships
        if self.relationships_path.exists():
            with open(self.relationships_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    rel = json.loads(line)
                    self.graph.add_edge(rel["from"], rel["to"], **rel)

    def add_entity(self, entity_id: str, entity_type: str, **kwargs):
        """Add a new entity to the graph and append it to the entity file.

        Args:
            entity_id: Unique identifier; used as the graph node key and
                stored under ``"id"``.
            entity_type: Stored under ``"type"``.
            **kwargs: Extra attributes stored on the entity. They must be
                JSON-serialisable.

        Returns:
            True if the entity was added, False if ``entity_id`` is already
            known (nothing is changed in that case).

        Raises:
            TypeError: If an attribute cannot be serialised to JSON. Neither
                the graph nor the file is modified when this happens.
        """
        if entity_id in self.entity_ids:
            print(f"Entity with ID '{entity_id}' already exists. Skipping.")
            return False

        entity = {"id": entity_id, "type": entity_type, **kwargs}
        # The explicit arguments win over stray "id"/"type" keyword
        # attributes, so the stored record always matches the node key.
        entity["id"] = entity_id
        entity["type"] = entity_type

        # Serialise and persist first: if either step fails, the in-memory
        # graph must not get ahead of what is on disk.
        record = json.dumps(entity) + "\n"
        with open(self.entities_path, "a", encoding="utf-8") as f:
            f.write(record)

        self.graph.add_node(entity_id, **entity)
        self.entity_ids.add(entity_id)

        return True

    def add_relationship(self, from_entity: str, to_entity: str, relationship_type: str, **kwargs):
        """Add an edge between two existing entities and append it to disk.

        Args:
            from_entity: ID of the source entity; stored under ``"from"``.
            to_entity: ID of the target entity; stored under ``"to"``.
            relationship_type: Stored under ``"type"``.
            **kwargs: Extra JSON-serialisable attributes stored on the edge.

        Returns:
            True if the relationship was added, False if either endpoint is
            not a node of the graph (nothing is changed in that case).

        Raises:
            TypeError: If an attribute cannot be serialised to JSON. Neither
                the graph nor the file is modified when this happens.
        """
        if from_entity not in self.graph or to_entity not in self.graph:
            print("One or both entities do not exist in the graph.")
            return False

        relationship = {"from": from_entity, "to": to_entity, "type": relationship_type, **kwargs}
        # The explicit arguments win over stray "from"/"to"/"type" keyword
        # attributes, so the stored record always matches the actual edge.
        relationship["from"] = from_entity
        relationship["to"] = to_entity
        relationship["type"] = relationship_type

        # Serialise and persist first: if either step fails, the in-memory
        # graph must not get ahead of what is on disk.
        record = json.dumps(relationship) + "\n"
        with open(self.relationships_path, "a", encoding="utf-8") as f:
            f.write(record)

        self.graph.add_edge(from_entity, to_entity, **relationship)

        return True
    
    def extract_entities_from_text(self, text: str) -> List[Dict[str, Any]]:
        """Run the NLP pipeline over ``text`` and describe each named entity.

        Returns one dict per entity in ``doc.ents``, in document order, with
        the keys ``"text"``, ``"label"``, ``"start_char"`` and ``"end_char"``.
        """
        parsed = self.nlp(text)
        return [
            {
                "text": span.text,
                "label": span.label_,
                "start_char": span.start_char,
                "end_char": span.end_char,
            }
            for span in parsed.ents
        ]

    def infer_relationships(self, text: str, source_entity_id: str) -> None:
        """Link ``source_entity_id`` to every named entity found in ``text``.

        Each extracted entity gets an ID derived from a SHA-256 digest of its
        surface text, so the same mention maps to the same node in every
        process (the built-in ``hash`` of a string is salted per interpreter
        run and is therefore unsuitable for persisted IDs). A ``"mentions"``
        relationship is added from the source to the extracted entity unless
        such an edge already exists; entities extracted earlier are linked
        too rather than being skipped.

        Args:
            text: Unstructured text to analyse; also stored as the
                relationship's ``context``.
            source_entity_id: ID of the entity the text belongs to.
        """
        import hashlib  # function-scope import keeps this block self-contained

        for entity in self.extract_entities_from_text(text):
            mention = entity["text"]
            digest = hashlib.sha256(mention.encode("utf-8")).hexdigest()[:12]
            new_entity_id = f"EXTRACTED_{digest}"

            if new_entity_id not in self.entity_ids:
                self.add_entity(new_entity_id, entity["label"], text=mention, source="text_extraction")

            # Never create a self-loop, and never duplicate an existing edge
            # (e.g. the same mention occurring twice in one text).
            if new_entity_id == source_entity_id:
                continue
            if not self.graph.has_edge(source_entity_id, new_entity_id):
                self.add_relationship(source_entity_id, new_entity_id, "mentions", context=text)

    def find_shortest_path(self, start_entity: str, end_entity: str) -> Optional[List[str]]:
        """Return the shortest node path from ``start_entity`` to ``end_entity``.

        Returns ``None`` when the two entities are not connected or when
        either of them is not a node of the graph.
        """
        try:
            return nx.shortest_path(self.graph, source=start_entity, target=end_entity)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            # Both failure modes mean "no route" to the caller.
            return None

    def find_all_paths(self, start_entity: str, end_entity: str, cutoff: int = 3) -> List[List[str]]:
        """Return every simple path between two entities, bounded by ``cutoff``.

        ``cutoff`` is forwarded to ``nx.all_simple_paths``. An empty list is
        returned when networkx reports a missing node.
        """
        try:
            path_iter = nx.all_simple_paths(
                self.graph,
                source=start_entity,
                target=end_entity,
                cutoff=cutoff,
            )
            # Materialise inside the try block so errors raised lazily during
            # iteration are handled as well.
            return list(path_iter)
        except nx.NodeNotFound:
            return []

    def hybrid_retrieve(self, query: str, vector_results: List[str]) -> List[Dict]:
        """
        Expand vector-search hits with their direct graph neighbours.

        For every ID in ``vector_results`` that is a node of the graph, the
        result contains one ``"vector"`` record for the hit itself, followed
        by one ``"graph_expansion"`` record per 1-hop neighbour (tagged with
        ``"related_to"``). IDs not present in the graph are ignored.
        ``query`` is accepted but not used by this implementation.
        """
        combined = []
        for hit_id in vector_results:
            if hit_id not in self.graph:
                continue

            # The hit itself comes first ...
            combined.append({
                "id": hit_id,
                "data": self.graph.nodes[hit_id],
                "source": "vector"
            })
            # ... followed by its immediate neighbours as relational context.
            combined.extend(
                {
                    "id": adjacent_id,
                    "data": self.graph.nodes[adjacent_id],
                    "source": "graph_expansion",
                    "related_to": hit_id
                }
                for adjacent_id in self.graph.neighbors(hit_id)
            )
        return combined

    def get_entity(self, entity_id: str) -> Optional[Dict]:
        """Return the attribute mapping stored for ``entity_id``, or ``None`` if unknown."""
        if entity_id not in self.graph:
            return None
        return self.graph.nodes[entity_id]

    def get_neighbors(self, entity_id: str) -> List[str]:
        """Return the IDs adjacent to ``entity_id``; empty list if it is not in the graph."""
        if entity_id not in self.graph:
            return []
        return list(self.graph.neighbors(entity_id))

    def query_graph(self, query_type: str, **kwargs) -> Any:
        """
        Single entry point that routes a named query to the matching method.

        Supported query types and the keyword arguments they read:
            - shortest_path: ``start_entity``, ``end_entity``.
            - all_paths: ``start_entity``, ``end_entity``, optional ``cutoff`` (default 3).
            - entity: ``entity_id``.
            - neighbors: ``entity_id``.

        Raises ValueError for any other ``query_type`` and KeyError when a
        required keyword argument is missing.
        """
        # Handlers are zero-argument callables so keyword arguments are only
        # looked up for the query type that was actually requested.
        handlers = {
            "shortest_path": lambda: self.find_shortest_path(
                kwargs["start_entity"], kwargs["end_entity"]
            ),
            "all_paths": lambda: self.find_all_paths(
                kwargs["start_entity"], kwargs["end_entity"], kwargs.get("cutoff", 3)
            ),
            "entity": lambda: self.get_entity(kwargs["entity_id"]),
            "neighbors": lambda: self.get_neighbors(kwargs["entity_id"]),
        }

        handler = handlers.get(query_type)
        if handler is None:
            raise ValueError(f"Unsupported query type: {query_type}")
        return handler()

    def save_graph(self):
        """Rewrite both JSONL files from the current in-memory graph.

        Every node is written with an ``"id"`` key, including nodes that
        carry no attributes (for example nodes that only exist because an
        edge referenced them), so the entity file can always be loaded again.
        Edge attributes override the ``"from"``/``"to"`` defaults taken from
        the edge endpoints.

        All records are serialised before any file is touched, and each file
        is written to a temporary sibling and then moved into place with
        ``os.replace``, so a failure cannot leave a truncated file behind.

        Raises:
            TypeError: If a node or edge attribute is not JSON-serialisable;
                the files on disk are left unchanged.
        """
        entity_lines = [
            json.dumps({"id": node_id, **attrs}, ensure_ascii=False) + "\n"
            for node_id, attrs in self.graph.nodes(data=True)
        ]
        relationship_lines = [
            json.dumps({"from": source, "to": target, **attrs}, ensure_ascii=False) + "\n"
            for source, target, attrs in self.graph.edges(data=True)
        ]

        outputs = (
            (self.entities_path, entity_lines),
            (self.relationships_path, relationship_lines),
        )
        for path, lines in outputs:
            tmp_path = path.with_name(path.name + ".tmp")
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.writelines(lines)
            # Swap the finished file in; readers never see a partial write.
            os.replace(tmp_path, path)
        print("Graph saved to disk.")


if __name__ == "__main__":
    # Demo walkthrough: build a small graph, query it, persist it, reload it.
    engine = KnowledgeGraphEngine()

    # 1. Seed entities (ID, type, extra attributes).
    demo_entities = [
        ("person1", "Person", {"name": "Alice", "skill": "Python"}),
        ("concept1", "Concept", {"name": "Machine Learning"}),
        ("tool1", "Tool", {"name": "TensorFlow"}),
        ("patent1", "Patent", {"title": "AI Invention"}),
        ("revenue1", "Revenue", {"amount": 1000000}),
    ]
    for ent_id, ent_type, attrs in demo_entities:
        engine.add_entity(ent_id, ent_type, **attrs)

    # 2. Seed relationships (source, target, type).
    demo_relationships = [
        ("person1", "concept1", "studies"),
        ("person1", "tool1", "uses"),
        ("concept1", "tool1", "implements"),
        ("tool1", "patent1", "relates_to"),
        ("concept1", "revenue1", "generates"),
    ]
    for src, dst, rel_type in demo_relationships:
        engine.add_relationship(src, dst, rel_type)

    # 3. Re-adding an existing ID is expected to be rejected.
    engine.add_entity("person1", "Person", name="Alice", skill="Java")

    # 4. Extract entities from free text and link them to person1.
    sample_text = "Alice is a skilled Python programmer who uses TensorFlow for machine learning."
    engine.infer_relationships(sample_text, "person1")

    # 5. Exercise each query type of the query interface.
    path_result = engine.query_graph("shortest_path", start_entity="person1", end_entity="patent1")
    print(f"Shortest path between person1 and patent1: {path_result}")

    paths_result = engine.query_graph("all_paths", start_entity="person1", end_entity="patent1", cutoff=4)
    print(f"All paths between person1 and patent1 (cutoff=4): {paths_result}")

    person_record = engine.query_graph("entity", entity_id="person1")
    print(f"Information about person1: {person_record}")

    concept_neighbors = engine.query_graph("neighbors", entity_id="concept1")
    print(f"Neighbors of concept1: {concept_neighbors}")

    # 6. Hybrid retrieval, using stand-in vector search hits.
    vector_hits = ["concept1", "tool1"]
    expanded = engine.hybrid_retrieve("machine learning tools", vector_hits)
    print(f"Hybrid retrieval results: {json.dumps(expanded, indent=2)}")

    # 7. Persist the whole graph.
    engine.save_graph()

    # 8. A fresh engine picks up entities.jsonl and relationships.jsonl.
    print("\nLoading graph from existing files...")
    reloaded = KnowledgeGraphEngine()
    node_count = reloaded.graph.number_of_nodes()
    edge_count = reloaded.graph.number_of_edges()
    print(f"Graph loaded with {node_count} nodes and {edge_count} edges.")

    # Query the reloaded graph to show the data survived the round trip.
    reloaded_person = reloaded.query_graph("entity", entity_id="person1")
    print(f"Information about person1 (from loaded graph): {reloaded_person}")