# knowledge_graph_engine.py
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import networkx as nx
from collections import defaultdict
import re

class KnowledgeGraphEngine:
    """
    A knowledge graph system that stores entities and relationships, supports graph
    traversal queries, handles entity deduplication, supports incremental updates,
    and extracts entities from unstructured text.

    Entities and relationships live in a ``networkx.MultiDiGraph`` (so several
    relationships may connect the same ordered pair of entities) and are persisted
    as append-only JSON Lines logs under ``<workspace>/knowledge_graph/``. The logs
    are replayed into the graph when the engine is constructed.

    NOTE: no vector-embedding / semantic search is implemented in this class;
    ``search_entities`` performs case-insensitive substring matching.
    """

    # Keys that identify a relationship record's endpoints and type in
    # relationships.jsonl. They are handled explicitly and must never be forwarded
    # to ``add_edge`` as extra attributes ("type" would collide with the explicit
    # keyword argument and raise TypeError).
    _RESERVED_RELATIONSHIP_KEYS = frozenset({"from", "to", "type"})

    # Placeholder extraction patterns, compiled once for the class.
    _PERSON_PATTERN = re.compile(r"\b[A-Z][a-z]+ [A-Z][a-z]+\b")  # Simple name pattern
    _CONCEPT_PATTERN = re.compile(
        r"\b(artificial intelligence|machine learning|blockchain)\b", re.IGNORECASE
    )
    _SKILL_PATTERN = re.compile(r"\b(programming|marketing|sales)\b", re.IGNORECASE)

    def __init__(self, workspace_path: str = "aiva_kg"):
        """Create the storage directories if needed and load any persisted graph.

        Args:
            workspace_path: Root directory; data is stored in
                ``<workspace_path>/knowledge_graph/``.
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "knowledge_graph"
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"
        self.graph = nx.MultiDiGraph()  # Use MultiDiGraph to allow multiple relationships between nodes
        self._ensure_directories()
        self._load_graph()

    def _ensure_directories(self):
        """Ensures the workspace and KG directories exist (including missing parents)."""
        self.workspace.mkdir(parents=True, exist_ok=True)
        self.kg_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _iter_jsonl(path: Path, label: str):
        """Yield each JSON object from a JSON Lines file.

        Blank lines are ignored; lines that are not valid JSON or are not JSON
        objects are reported and skipped. A missing file yields nothing.

        Args:
            path: File to read.
            label: Name used in diagnostic messages.
        """
        try:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError:
                        print(f"Skipping invalid JSON line in {label}: {line.strip()}")
                        continue
                    if isinstance(record, dict):
                        yield record
                    else:
                        print(f"Skipping non-object JSON line in {label}: {line.strip()}")
        except FileNotFoundError:
            return  # Nothing has been persisted to this file yet.
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error reading {label}: {e}")

    def _load_graph(self):
        """Replays the persisted entity and relationship logs into ``self.graph``.

        Each file is loaded independently, so entities still load when the
        relationships file does not exist. Entities are replayed first so that
        relationship endpoints resolve. Malformed records are reported and skipped
        instead of aborting the whole load. Nothing is re-persisted while loading.
        """
        if not self.entities_path.exists() and not self.relationships_path.exists():
            print("No existing knowledge graph found. Starting with an empty graph.")
            return

        # Load Entities
        for entity in self._iter_jsonl(self.entities_path, "entities.jsonl"):
            if "id" not in entity:
                print(f"Skipping entity without an 'id' in entities.jsonl: {entity}")
                continue
            self.add_entity(entity, load=True)

        # Load Relationships
        for rel in self._iter_jsonl(self.relationships_path, "relationships.jsonl"):
            if not self._RESERVED_RELATIONSHIP_KEYS.issubset(rel):
                print(f"Skipping incomplete relationship in relationships.jsonl: {rel}")
                continue
            # load=True: the record is already on disk; don't append it again.
            self.add_relationship(rel["from"], rel["to"], rel["type"], rel, load=True)

    def add_entity(self, entity_data: Dict[str, Any], load: bool = False):
        """Adds an entity to the knowledge graph, handling deduplication.

        If a node with ``entity_data["id"]`` already exists, the new attributes are
        merged into it (new values win). The merged record is appended to
        entities.jsonl unless ``load`` is true or the merge changed nothing.

        Args:
            entity_data: Entity attributes; must contain an ``"id"`` key.
            load: True when replaying persisted data (skip persistence).

        Raises:
            KeyError: if ``entity_data`` has no ``"id"``.
        """
        entity_id = entity_data["id"]
        if self.graph.has_node(entity_id):
            # Update existing entity (deduplication)
            node_attrs = self.graph.nodes[entity_id]
            merged = {**node_attrs, **entity_data}
            if merged == node_attrs:
                return  # No new information; avoid appending a redundant record.
            node_attrs.update(entity_data)
            if not load:
                self._persist_entity(merged)
        else:
            self.graph.add_node(entity_id, **entity_data)
            if not load:
                self._persist_entity(entity_data)

    def _persist_entity(self, entity_data: Dict[str, Any]):
        """Appends an entity record to entities.jsonl (best effort; errors are printed)."""
        try:
            # Serialize before opening so a bad record never touches the file.
            line = json.dumps(entity_data, ensure_ascii=False) + "\n"
            with open(self.entities_path, "a", encoding="utf-8") as f:
                f.write(line)
        except (OSError, TypeError, ValueError) as e:
            print(f"Error persisting entity: {e}")

    def add_relationship(
        self,
        from_entity: str,
        to_entity: str,
        relationship_type: str,
        relationship_data: Optional[Dict[str, Any]],
        load: bool = False,
    ):
        """Adds a directed relationship between two existing entities.

        ``"from"``, ``"to"`` and ``"type"`` keys inside ``relationship_data`` are
        ignored, because the explicit arguments define them. A relationship whose
        attributes exactly match an existing edge between the same two entities is
        not added again, so replaying or re-processing data is idempotent.

        Args:
            from_entity: Source entity id.
            to_entity: Target entity id.
            relationship_type: Stored as the edge's ``type`` attribute.
            relationship_data: Extra edge attributes (may be None).
            load: True when replaying persisted data (skip persistence).
        """
        if not self.graph.has_node(from_entity) or not self.graph.has_node(to_entity):
            print(f"Warning: Adding relationship between non-existent entities: {from_entity} and {to_entity}")
            return

        extra = {
            key: value
            for key, value in (relationship_data or {}).items()
            if key not in self._RESERVED_RELATIONSHIP_KEYS
        }
        edge_attrs = {"type": relationship_type, **extra}

        existing_edges = self.graph.get_edge_data(from_entity, to_entity) or {}
        if any(attrs == edge_attrs for attrs in existing_edges.values()):
            return  # Identical relationship already present.

        self.graph.add_edge(from_entity, to_entity, **edge_attrs)
        if not load:
            self._persist_relationship(from_entity, to_entity, relationship_type, extra)

    def _persist_relationship(self, from_entity: str, to_entity: str, relationship_type: str, relationship_data: Dict[str, Any]):
        """Appends a relationship record to relationships.jsonl (best effort; errors are printed)."""
        try:
            rel_data = {
                "from": from_entity,
                "to": to_entity,
                "type": relationship_type,
                **relationship_data
            }
            line = json.dumps(rel_data, ensure_ascii=False) + "\n"
            with open(self.relationships_path, "a", encoding="utf-8") as f:
                f.write(line)
        except (OSError, TypeError, ValueError) as e:
            print(f"Error persisting relationship: {e}")

    def find_shortest_path(self, start_entity: str, end_entity: str) -> Optional[List[str]]:
        """Finds the shortest (directed) path between two entities.

        Returns:
            List of entity ids from start to end, or None if either entity is
            unknown or no path exists.
        """
        try:
            return nx.shortest_path(self.graph, start_entity, end_entity)
        except (nx.NetworkXNoPath, nx.NodeNotFound, nx.NetworkXError):
            # NodeNotFound is not a NetworkXError subclass and must be listed explicitly.
            return None

    def find_all_paths(self, start_entity: str, end_entity: str, max_depth: int = 3) -> List[List[str]]:
        """Finds all simple directed paths between two entities.

        Args:
            max_depth: Passed to networkx as ``cutoff``; the maximum number of edges
                in a returned path.

        Returns:
            Unique node-id paths (parallel edges do not produce duplicates), or an
            empty list if either entity is unknown.
        """
        try:
            raw_paths = list(
                nx.all_simple_paths(self.graph, start_entity, end_entity, cutoff=max_depth)
            )
        except nx.NodeNotFound:
            return []

        seen = set()
        paths = []
        for path in raw_paths:
            signature = tuple(path)
            if signature not in seen:
                seen.add(signature)
                paths.append(path)
        return paths

    def get_neighbors(self, entity_id: str, relationship_type: Optional[str] = None) -> List[str]:
        """
        Retrieves the outgoing neighbors of an entity, optionally filtered by
        relationship type. Each neighbor appears once; an unknown entity yields [].
        """
        if not self.graph.has_node(entity_id):
            return []
        neighbors = []
        # graph.adj[node] maps each successor to {edge_key: edge_attrs}.
        for neighbor, edges in self.graph.adj[entity_id].items():
            if relationship_type is None or any(
                attrs.get("type") == relationship_type for attrs in edges.values()
            ):
                neighbors.append(neighbor)
        return neighbors

    def search_entities(self, query: str, entity_types: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Searches for entities with a string attribute containing ``query``.

        Matching is case-insensitive substring matching. ``entity_types``, when
        non-empty, restricts results to entities whose ``type`` is listed.

        Returns:
            Copies of the matching entities' attribute dicts, so modifying a
            result does not silently alter the (unpersisted) graph state.
        """
        needle = query.casefold()
        results = []
        for _node_id, data in self.graph.nodes(data=True):
            if entity_types and data.get("type") not in entity_types:
                continue  # Skip if entity type doesn't match

            if any(isinstance(value, str) and needle in value.casefold() for value in data.values()):
                results.append(dict(data))

        return results

    def extract_entities_from_text(self, text: str) -> List[Tuple[str, str]]:
        """
        Extracts entities from unstructured text using simple regex patterns.
        This is a placeholder; a real implementation would use NLP techniques.

        Two capitalised words are treated as a Person unless they form one of the
        known concept phrases (e.g. "Machine Learning"), which are typed Concept only.

        Returns:
            List of (entity_name, entity_type) tuples in order Person, Concept,
            Skill. Names are returned as they appear in the text.
        """
        entities = [
            (match, "Person")
            for match in self._PERSON_PATTERN.findall(text)
            if not self._CONCEPT_PATTERN.fullmatch(match)
        ]
        entities.extend((match, "Concept") for match in self._CONCEPT_PATTERN.findall(text))
        entities.extend((match, "Skill") for match in self._SKILL_PATTERN.findall(text))
        return entities

    @staticmethod
    def _make_entity_id(entity_name: str, entity_type: str) -> str:
        """Builds the id used for extracted entities, e.g. ("Alice Smith", "Person") -> "person_Alice_Smith"."""
        return f"{entity_type.lower()}_{entity_name.replace(' ', '_')}"

    def process_text_and_update_graph(self, text: str, source_id: str):
        """
        Extracts entities from text and updates the knowledge graph.

        Each distinct extracted entity is added (or merged) with ``source`` set
        to ``source_id``. Every pair of distinct entities is then linked with a
        placeholder ``related_to`` relationship (confidence 0.5). Repeated mentions
        no longer create self-loops or duplicate edges.
        """
        # Deduplicate mentions by id, preserving first-seen order.
        unique_entities: Dict[str, Tuple[str, str]] = {}
        for entity_name, entity_type in self.extract_entities_from_text(text):
            unique_entities.setdefault(
                self._make_entity_id(entity_name, entity_type), (entity_name, entity_type)
            )

        for entity_id, (entity_name, entity_type) in unique_entities.items():
            entity_data = {"id": entity_id, "type": entity_type, "name": entity_name, "source": source_id}
            self.add_entity(entity_data)

        # Placeholder relationship extraction: relate every pair of distinct entities.
        entity_ids = list(unique_entities)
        for index, first_id in enumerate(entity_ids):
            for second_id in entity_ids[index + 1:]:
                self.add_relationship(first_id, second_id, "related_to", {"source": source_id, "confidence": 0.5})

    def export_graph_to_json(self, filepath: str):
        """Exports the entire graph (nodes and edges) to a JSON file.

        The data is serialized before the file is opened, so a serialization
        error leaves any existing file at ``filepath`` untouched. Errors are printed.
        """
        graph_data = nx.node_link_data(self.graph)
        try:
            payload = json.dumps(graph_data, indent=2, ensure_ascii=False)
            with open(filepath, "w", encoding="utf-8") as f:
                f.write(payload)
            print(f"Graph exported to {filepath}")
        except (OSError, TypeError, ValueError) as e:
            print(f"Error exporting graph: {e}")


if __name__ == "__main__":
    # Demo: seed a tiny graph, run a few queries, ingest free text, then export.
    engine = KnowledgeGraphEngine()

    seed_entities = [
        {"id": "person_alice", "type": "Person", "name": "Alice", "skill": "programming"},
        {"id": "concept_ai", "type": "Concept", "name": "Artificial Intelligence"},
        {"id": "tool_python", "type": "Tool", "name": "Python"},
    ]
    for record in seed_entities:
        engine.add_entity(record)

    seed_relationships = [
        ("person_alice", "tool_python", "uses", {"description": "Alice uses Python for development"}),
        ("person_alice", "concept_ai", "studies", {"description": "Alice studies AI"}),
    ]
    for source, target, kind, attrs in seed_relationships:
        engine.add_relationship(source, target, kind, attrs)

    # Traversal and lookup queries on the seeded graph.
    alice_to_ai = engine.find_shortest_path("person_alice", "concept_ai")
    print(f"Shortest path from Alice to AI: {alice_to_ai}")

    alice_neighbors = engine.get_neighbors("person_alice")
    print(f"Neighbors of Alice: {alice_neighbors}")

    alice_hits = engine.search_entities("Alice")
    print(f"Search results for 'Alice': {alice_hits}")

    # Free-text extraction, then ingestion into the graph.
    sample_text = "Alice and Bob are skilled programmers working with Artificial Intelligence and Python. Bob also knows marketing."
    found = engine.extract_entities_from_text(sample_text)
    print(f"Extracted entities from text: {found}")

    engine.process_text_and_update_graph(sample_text, "document_123")

    bob_hits = engine.search_entities("Bob")
    print(f"Search results for 'Bob' after processing text: {bob_hits}")

    engine.export_graph_to_json("knowledge_graph.json")

    node_total = engine.graph.number_of_nodes()
    edge_total = engine.graph.number_of_edges()
    print(f"Graph has {node_total} nodes and {edge_total} edges.")