# knowledge_graph_engine.py
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional
import networkx as nx
from collections import defaultdict
import re

class KnowledgeGraphEngine:
    """
    Knowledge Graph Engine: stores entities with relationships, supports graph
    traversal, integrates with vector embeddings, handles deduplication,
    supports incremental updates, and extracts entities from unstructured text.

    Persistence is append-only JSONL: one file for entities and one for
    relationships, both under ``<workspace>/KNOWLEDGE_GRAPH``.
    """

    def __init__(self, workspace_path: str = "e:/genesis-system", embedding_model=None):
        """
        Args:
            workspace_path: Root directory holding the KNOWLEDGE_GRAPH folder.
            embedding_model: Optional object exposing ``encode(text) -> vector``
                (e.g. a sentence-transformers model). Only required for
                ``semantic_search``.
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "KNOWLEDGE_GRAPH"
        # Create the storage directory up front so the append helpers never
        # fail with FileNotFoundError on a fresh workspace.
        self.kg_dir.mkdir(parents=True, exist_ok=True)
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"
        self.graph = nx.Graph()
        self.embedding_model = embedding_model  # Optional: vector embedding model
        self._load_graph()

    def _load_graph(self) -> None:
        """Loads entities and relationships from JSONL files into the graph.

        Each file is loaded independently, so a workspace that has entities but
        no relationships yet (or vice versa) still loads whatever is present.
        Blank or corrupt lines (e.g. a torn partial append) are skipped instead
        of aborting the whole load.
        """
        if self.entities_path.exists():
            entity_ids = set()
            with open(self.entities_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        entity = json.loads(line)
                    except json.JSONDecodeError:
                        continue  # tolerate a corrupt/partial line
                    entity_id = entity["id"]
                    if entity_id not in entity_ids:  # deduplication during load
                        self.graph.add_node(entity_id, **entity)
                        entity_ids.add(entity_id)

        if self.relationships_path.exists():
            with open(self.relationships_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        rel = json.loads(line)
                    except json.JSONDecodeError:
                        continue  # tolerate a corrupt/partial line
                    self.graph.add_edge(rel["from"], rel["to"], **rel)

    def add_entity(self, entity_data: Dict[str, Any], deduplicate: bool = True) -> bool:
        """Adds a new entity to the graph and appends it to the JSONL store.

        Args:
            entity_data: Entity attributes; must contain an ``"id"`` key.
            deduplicate: When True, an entity whose id already exists is
                skipped (not overwritten).

        Returns:
            True if the entity was added, False if it was skipped as a duplicate.
        """
        entity_id = entity_data["id"]

        if deduplicate and entity_id in self.graph:
            # Optionally, update existing entity data instead of rejecting:
            # self.graph.nodes[entity_id].update(entity_data)
            print(f"Entity with ID '{entity_id}' already exists. Skipping.")
            return False

        self.graph.add_node(entity_id, **entity_data)
        self._append_to_jsonl(self.entities_path, entity_data)
        return True

    def add_relationship(self, from_entity: str, to_entity: str,
                         relationship_data: Dict[str, Any]) -> bool:
        """Adds a relationship between two existing entities.

        Args:
            from_entity: Source entity id (must already be in the graph).
            to_entity: Target entity id (must already be in the graph).
            relationship_data: Edge attributes, typically including ``"type"``.

        Returns:
            True on success, False if either endpoint is missing.
        """
        if from_entity not in self.graph or to_entity not in self.graph:
            print("Error: One or both entities do not exist in the graph.")
            return False

        self.graph.add_edge(from_entity, to_entity, **relationship_data)
        self._append_to_jsonl(self.relationships_path, {
            "from": from_entity,
            "to": to_entity,
            **relationship_data
        })
        return True

    def _append_to_jsonl(self, file_path: Path, data: Dict[str, Any]) -> None:
        """Appends a single JSON object as one line to a JSONL file."""
        with open(file_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(data, ensure_ascii=False) + "\n")

    def find_shortest_path(self, start_entity: str, end_entity: str) -> Optional[List[str]]:
        """Finds the shortest path between two entities.

        Returns:
            The list of node ids along the path, or None when no path exists
            or either endpoint is not in the graph.
        """
        try:
            return nx.shortest_path(self.graph, source=start_entity, target=end_entity)
        # NodeNotFound is raised for an unknown endpoint and is not a
        # NetworkXError subclass, so it must be caught explicitly.
        except (nx.NetworkXNoPath, nx.NodeNotFound, nx.NetworkXError):
            return None

    def graph_traverse(self, start_entity: str, relationship_type: str,
                       max_depth: int = 2) -> List[Dict]:
        """
        Breadth-first traversal from ``start_entity``, following only edges
        whose ``"type"`` attribute equals ``relationship_type``, up to
        ``max_depth`` hops.

        Returns:
            A list of hop records: from/to ids, relationship type, depth, and
            the target entity's attribute dict. Empty if the start entity is
            not in the graph.
        """
        from collections import deque  # local import, matching file style

        if start_entity not in self.graph:
            return []

        results: List[Dict] = []
        visited = {start_entity}
        queue = deque([(start_entity, 0)])  # deque: O(1) popleft vs list.pop(0)

        while queue:
            node_id, depth = queue.popleft()
            if depth >= max_depth:
                continue

            for neighbor in self.graph.neighbors(node_id):
                edge_data = self.graph.get_edge_data(node_id, neighbor)
                if edge_data and edge_data.get("type") == relationship_type:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        results.append({
                            "from": node_id,
                            "to": neighbor,
                            "relationship": relationship_type,
                            "depth": depth + 1,
                            "entity_data": self.graph.nodes[neighbor]
                        })
                        queue.append((neighbor, depth + 1))

        return results

    def semantic_search(self, query: str, top_k: int = 5) -> List[str]:
        """
        Performs a semantic search using vector embeddings.
        Requires an embedding model to be initialized.

        Entities without a usable text field ("title" or "name") are skipped
        rather than ranked, so only genuinely comparable entities are returned.

        NOTE: embeddings are recomputed for every entity on each call; cache
        them externally if this becomes a hot path.

        Returns:
            Up to ``top_k`` entity ids, most similar first.
        """
        if not self.embedding_model:
            print("Error: Embedding model not initialized.")
            return []

        # 1. Embed the query once.
        query_embedding = self.embedding_model.encode(query)

        # 2. Embed each entity's representative text and score it immediately.
        similarities: Dict[str, float] = {}
        for node_id, node_data in self.graph.nodes(data=True):
            # Prioritize "title", then fall back to "name".
            description = node_data.get("title", node_data.get("name", ""))
            if not description:
                continue  # nothing to embed — skip instead of penalizing
            entity_embedding = self.embedding_model.encode(description)
            similarities[node_id] = self._cosine_similarity(query_embedding, entity_embedding)

        # 3. Sort by similarity (descending) and return the top_k ids.
        ranked = sorted(similarities.items(), key=lambda item: item[1], reverse=True)
        return [entity_id for entity_id, _similarity in ranked[:top_k]]

    def _cosine_similarity(self, vec1, vec2) -> float:
        """Cosine similarity of two vectors; 0 when either has zero magnitude."""
        import numpy as np
        dot_product = np.dot(vec1, vec2)
        magnitude_vec1 = np.linalg.norm(vec1)
        magnitude_vec2 = np.linalg.norm(vec2)
        if magnitude_vec1 == 0 or magnitude_vec2 == 0:
            return 0
        return dot_product / (magnitude_vec1 * magnitude_vec2)

    def extract_entities_from_text(self, text: str) -> List[Dict[str, Any]]:
        """
        Extracts entities from unstructured text using a rule-based approach.
        This is a simplified example and can be enhanced with NLP techniques.

        Repeated matches of the same string yield a single entity (deduplicated
        by generated id, first occurrence wins, order preserved).
        """
        # Define entity patterns (can be improved with NER models)
        entity_patterns = {
            "Person": r"\b[A-Z][a-z]+\s[A-Z][a-z]+\b",  # Simple name pattern
            "Tool": r"\b(Gemini|ChatGPT|SORA|GoHighLevel|Telnyx|Instantly\.ai|Apify|Pomelli|Veo|n8n|PostgreSQL|Redis|Qdrant|ElevenLabs|Nano Banana Pro|Skywork AI|NotebookLM|Baidu Ernie 5\.0|Opal|DeepCode|Sim AI|AgentKit|Perplexity AI|HeyGen AI)\b",
            "Revenue": r"\$\d+(?:\,\d{3})*(?:\.\d{2})?",  # Matches dollar amounts
            "Skill": r"\b(SEO|Automation|Coding|Web Dev|Marketing|Sales|AI|Data Analysis)\b",
            "Concept": r"\b(AI Agents|Funnel Builder|Cash Flow|Affiliate Marketing|Digital Asset|Brand Story|Market Analysis|Database Reactivation|Lead Generation)\b",
            "Patent": r"\bP[0-9]+\b"
        }

        # Keyed by generated id so repeated matches don't create duplicates;
        # dicts preserve insertion order, so extraction order is kept.
        extracted: Dict[str, Dict[str, Any]] = {}
        for entity_type, pattern in entity_patterns.items():
            for match in re.findall(pattern, text):
                entity_id = f"{entity_type}_{match.replace(' ', '_')}"  # Unique ID
                if entity_id not in extracted:
                    extracted[entity_id] = {
                        "id": entity_id,
                        "type": entity_type,
                        "name": match,
                        "source": "text_extraction"
                    }

        return list(extracted.values())

    def process_text_and_ingest(self, text: str) -> List[Dict[str, Any]]:
        """
        Extracts entities from text and adds them to the knowledge graph.

        Returns:
            The list of extracted entity dicts (including any that were
            skipped by ``add_entity`` as duplicates of existing nodes).
        """
        extracted_entities = self.extract_entities_from_text(text)
        for entity in extracted_entities:
            self.add_entity(entity)
        return extracted_entities

if __name__ == "__main__":
    # Demo / smoke test of the engine's main capabilities.
    engine = KnowledgeGraphEngine()

    # --- Entities ---------------------------------------------------------
    alice = {"id": "person_1", "type": "Person", "name": "Alice", "age": 30}
    ml_concept = {"id": "concept_1", "type": "Concept", "name": "Machine Learning"}
    engine.add_entity(alice)
    engine.add_entity(ml_concept)

    # --- Relationships ----------------------------------------------------
    engine.add_relationship("person_1", "concept_1", {"type": "studies"})

    # --- Traversal --------------------------------------------------------
    route = engine.find_shortest_path("person_1", "concept_1")
    print(f"Shortest path between person_1 and concept_1: {route}")

    hops = engine.graph_traverse(start_entity="person_1", relationship_type="studies", max_depth=2)
    print(f"Graph traversal results from person_1 with 'studies' relationship: {hops}")

    # --- Text extraction & ingestion --------------------------------------
    sample_text = "Alice is learning about Machine Learning and using Python. Bob uses Gemini."
    found = engine.process_text_and_ingest(sample_text)
    print(f"Extracted entities from text: {found}")

    # Semantic search requires an embedding model, e.g.:
    #   from sentence_transformers import SentenceTransformer
    #   model = SentenceTransformer('all-MiniLM-L6-v2')
    #   engine_emb = KnowledgeGraphEngine(embedding_model=model)
    #   engine_emb.add_entity({"id": "entity_x", "type": "Concept", "title": "Advanced AI Techniques"})
    #   engine_emb.add_entity({"id": "entity_y", "type": "Concept", "title": "Basic Machine Learning"})
    #   print(f"Semantic Search Results: {engine_emb.semantic_search('Learning AI', top_k=2)}")

    # --- Summary ----------------------------------------------------------
    print(f"Graph has {engine.graph.number_of_nodes()} nodes and {engine.graph.number_of_edges()} edges.")