# knowledge_graph_engine.py
import json
import os
import re
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx

class KnowledgeGraphEngine:
    """
    Knowledge Graph Engine for AIVA's knowledge graph system.

    Entities and relationships are persisted as JSONL files under
    ``<workspace>/KNOWLEDGE_GRAPH`` and mirrored in an in-memory undirected
    ``networkx.Graph``.  Every mutating operation rewrites both files, so the
    on-disk state always matches the in-memory graph.
    """

    def __init__(self, workspace_path: str = "aiva_kg", vector_engine=None):
        """
        Args:
            workspace_path: Root directory of the knowledge-graph store.
            vector_engine: Optional object exposing
                ``search(query, entity_types=..., limit=...)`` for semantic
                entity search; when None a substring search is used instead.
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "KNOWLEDGE_GRAPH"
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"
        self.graph = nx.Graph()
        self.vector_engine = vector_engine  # Optional: Integration with vector embeddings
        # pathlib equivalent of os.makedirs(..., exist_ok=True); creates the
        # workspace directory too if it is missing.
        self.kg_dir.mkdir(parents=True, exist_ok=True)
        self._load_graph()

    def _load_graph(self):
        """Loads entities and relationships from disk.

        FIX: the original skipped loading *both* files when either one was
        missing; each file is now loaded independently so a partial store
        (e.g. entities written but no relationships yet) is not discarded.
        """
        if not self.entities_path.exists() and not self.relationships_path.exists():
            print("KG files not found, creating new KG")
            return

        # Load Entities (one JSON object per line, keyed on "id")
        if self.entities_path.exists():
            with open(self.entities_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue  # tolerate trailing blank lines
                    entity = json.loads(line)
                    self.graph.add_node(entity["id"], **entity)

        # Load Relationships (endpoints in "from"/"to")
        if self.relationships_path.exists():
            with open(self.relationships_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    rel = json.loads(line)
                    self.graph.add_edge(rel["from"], rel["to"], **rel)

    def _save_graph(self):
        """Rewrites both JSONL files from the current in-memory graph."""
        # Save Entities
        with open(self.entities_path, "w", encoding="utf-8") as f:
            for node_id, data in self.graph.nodes(data=True):
                json.dump(data, f, ensure_ascii=False)
                f.write("\n")

        # Save Relationships; "from"/"to" are stored explicitly so the edge
        # can be reconstructed by _load_graph.
        with open(self.relationships_path, "w", encoding="utf-8") as f:
            for from_node, to_node, data in self.graph.edges(data=True):
                rel = {"from": from_node, "to": to_node, **data}
                json.dump(rel, f, ensure_ascii=False)
                f.write("\n")

    def add_entity(self, entity_id: str, entity_type: str, **attributes):
        """Adds a new entity to the graph or merges attributes into an existing one.

        Args:
            entity_id: Unique node identifier (also stored as the "id" attribute).
            entity_type: Stored as the "type" attribute for new entities.
            **attributes: Arbitrary JSON-serializable attributes.
        """
        if entity_id in self.graph:
            # Update existing entity in place; unspecified attributes survive.
            self.graph.nodes[entity_id].update(attributes)
        else:
            self.graph.add_node(entity_id, id=entity_id, type=entity_type, **attributes)
        self._save_graph()

    def add_relationship(self, from_entity: str, to_entity: str, relationship_type: str, **attributes):
        """Adds or updates a relationship between two existing entities.

        Raises:
            ValueError: If either endpoint is not present in the graph.
        """
        if from_entity not in self.graph or to_entity not in self.graph:
            raise ValueError("One or both entities do not exist in the graph.")

        if self.graph.has_edge(from_entity, to_entity):
            # FIX: the original ignored relationship_type when the edge
            # already existed; refresh it along with the other attributes.
            edge_data = self.graph[from_entity][to_entity]
            edge_data["type"] = relationship_type
            edge_data.update(attributes)
        else:
            self.graph.add_edge(from_entity, to_entity, type=relationship_type, **attributes)
        self._save_graph()

    def remove_entity(self, entity_id: str):
        """Removes an entity (and its incident edges); no-op if absent."""
        if entity_id not in self.graph:
            return  # Entity doesn't exist
        self.graph.remove_node(entity_id)
        self._save_graph()

    def remove_relationship(self, from_entity: str, to_entity: str):
        """Removes a relationship between two entities; no-op if absent."""
        if not self.graph.has_edge(from_entity, to_entity):
            return  # Relationship doesn't exist
        self.graph.remove_edge(from_entity, to_entity)
        self._save_graph()

    def find_shortest_path(self, start_entity: str, end_entity: str) -> Optional[List[str]]:
        """Finds the shortest path between two entities.

        Returns:
            The node-id path including both endpoints, or None when no path
            exists or either entity is missing.
        """
        try:
            return nx.shortest_path(self.graph, source=start_entity, target=end_entity)
        except nx.NetworkXNoPath:
            return None  # No path exists
        except nx.NetworkXError:
            return None  # One or both entities don't exist

    def find_neighbors(self, entity_id: str, relationship_type: Optional[str] = None) -> List[str]:
        """Finds the neighbors of a given entity, optionally filtered by relationship type.

        FIX: the original raised NetworkXError for an unknown entity; an
        empty list is returned instead, consistent with "no neighbors".
        """
        if entity_id not in self.graph:
            return []
        return [
            neighbor
            for neighbor in self.graph.neighbors(entity_id)
            if relationship_type is None
            or self.graph.get_edge_data(entity_id, neighbor).get("type") == relationship_type
        ]

    def search_entities(self, query: str, entity_types: Optional[List[str]] = None, limit: int = 10) -> List[str]:
        """Searches for entities by query, via the vector engine when available.

        Args:
            query: Case-insensitive substring matched against node ids and
                attribute values (fallback mode).
            entity_types: Optional whitelist of "type" values.
            limit: Maximum number of ids returned.
        """
        if self.vector_engine is None:
            lowered = query.lower()
            results: List[str] = []
            for node_id, data in self.graph.nodes(data=True):
                if entity_types and data.get("type") not in entity_types:
                    continue
                # FIX: the original matched against str(data.values()), whose
                # "dict_values([...])" wrapper text could cause false hits;
                # check each attribute value individually instead.
                if lowered in node_id.lower() or any(
                    lowered in str(value).lower() for value in data.values()
                ):
                    results.append(node_id)
                    if len(results) >= limit:
                        break
            return results
        # Use vector engine to search for similar entities
        return self.vector_engine.search(query, entity_types=entity_types, limit=limit)

    def extract_entities_from_text(self, text: str) -> List[Tuple[str, str, Dict]]:
        """
        Extracts entities and their attributes from unstructured text using
        regex patterns.

        Returns:
            A list of tuples ``(entity_id, entity_type, attributes)``.

        NOTE(review): all patterns use re.IGNORECASE, so the [A-Z]/[a-z]
        classes in the person pattern also match the opposite case — confirm
        this is intended before tightening.
        """
        entities: List[Tuple[str, str, Dict]] = []

        # "Patent" entities with ids like "US1234567B2"
        patent_pattern = r"(Patent)\s+(US\d+[A-Z]\d*)"
        for match in re.finditer(patent_pattern, text, re.IGNORECASE):
            entities.append((match.group(2), match.group(1), {}))

        # "Person" entities; the two-word name doubles as the id
        person_pattern = r"(Person)\s+([A-Z][a-z]+ [A-Z][a-z]+)"
        for match in re.finditer(person_pattern, text, re.IGNORECASE):
            entities.append((match.group(2), match.group(1), {}))

        # "Revenue" entities with dollar amounts; a synthetic unique id is
        # derived from the amount, which is also stored as an attribute.
        revenue_pattern = r"(Revenue)\s+\$(\d+)"
        for match in re.finditer(revenue_pattern, text, re.IGNORECASE):
            amount = match.group(2)
            entities.append((f"REVENUE_{amount}", match.group(1), {"amount": int(amount)}))

        return entities

    def process_text_and_update_graph(self, text: str):
        """
        Extracts entities from text and adds them to the graph, also handling
        relationships.
        """
        extracted_entities = self.extract_entities_from_text(text)

        for entity_id, entity_type, attributes in extracted_entities:
            self.add_entity(entity_id, entity_type, **attributes)
            print(f"Added/Updated entity: {entity_id} ({entity_type})")

        # Placeholder relationship extraction: until keyword-driven linking
        # ("uses", "depends on", ...) exists, the first two extracted
        # entities are linked as generically related.
        if len(extracted_entities) >= 2:
            entity1_id, _, _ = extracted_entities[0]
            entity2_id, _, _ = extracted_entities[1]
            self.add_relationship(entity1_id, entity2_id, "related_to", source="text_extraction")
            print(f"Added relationship: {entity1_id} -> {entity2_id} (related_to)")

    def get_entity_details(self, entity_id: str) -> Optional[Dict]:
        """Returns the attribute dict for an entity, or None if absent.

        NOTE: the returned dict is the live networkx attribute mapping;
        mutating it changes the graph without triggering a save.
        """
        if entity_id in self.graph:
            return self.graph.nodes[entity_id]
        return None

    def export_graph(self, filepath: str):
        """Exports the entire graph (nodes + edges with data) to a JSON file."""
        graph_data = {
            "nodes": list(self.graph.nodes(data=True)),
            "edges": list(self.graph.edges(data=True)),
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(graph_data, f, indent=2, ensure_ascii=False)

    def import_graph(self, filepath: str):
        """Imports a graph from a JSON file, replacing the current graph.

        Errors are reported on stdout rather than raised, preserving the
        original best-effort contract.
        """
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                graph_data = json.load(f)

            self.graph.clear()  # Clear existing graph

            # Add nodes.  FIX: backfill "id" so a later _save_graph /
            # _load_graph round trip (which keys entities on data["id"])
            # keeps working even for exports produced by other tools.
            for node_id, node_data in graph_data["nodes"]:
                node_data.setdefault("id", node_id)
                self.graph.add_node(node_id, **node_data)

            # Add edges
            for from_node, to_node, edge_data in graph_data["edges"]:
                self.graph.add_edge(from_node, to_node, **edge_data)
            self._save_graph()
            print(f"Graph imported successfully from {filepath}")

        except FileNotFoundError:
            print(f"Error: File not found at {filepath}")
        except json.JSONDecodeError:
            print(f"Error: Invalid JSON format in {filepath}")
        except Exception as e:
            print(f"An error occurred during import: {e}")

    def get_graph_stats(self) -> Dict:
        """Returns basic statistics about the graph.

        FIX: the original built ``defaultdict(int, [(type, 1), ...])``, where
        duplicate keys overwrite one another, so every type always reported a
        count of 1.  ``Counter`` tallies true frequencies and, like
        ``defaultdict(int)``, returns 0 for missing keys.
        """
        return {
            "num_nodes": self.graph.number_of_nodes(),
            "num_edges": self.graph.number_of_edges(),
            "node_types": Counter(
                data.get("type") for _, data in self.graph.nodes(data=True)
            ),
            "edge_types": Counter(
                data.get("type") for _, _, data in self.graph.edges(data=True)
            ),
        }


# Demo (requires the aiva_kg directory to exist):
if __name__ == "__main__":
    # Build a small graph, query it, then round-trip it through JSON.
    engine = KnowledgeGraphEngine()

    # Seed a handful of entities.
    for eid, etype, attrs in [
        ("person1", "Person", {"name": "Alice", "skill": "AI Development"}),
        ("concept1", "Concept", {"name": "Reinforcement Learning"}),
        ("tool1", "Tool", {"name": "TensorFlow"}),
    ]:
        engine.add_entity(eid, etype, **attrs)

    # Wire them together.
    for src, dst, rel in [
        ("person1", "concept1", "works_on"),
        ("person1", "tool1", "uses"),
        ("concept1", "tool1", "implemented_in"),
    ]:
        engine.add_relationship(src, dst, rel)

    # Text search over the seeded entities.
    search_results = engine.search_entities("AI", entity_types=["Person", "Concept"])
    print(f"Search results for 'AI': {search_results}")

    # Graph traversal.
    path = engine.find_shortest_path("person1", "tool1")
    print(f"Shortest path between person1 and tool1: {path}")

    # Regex-based entity extraction from free text.
    sample_text = (
        "Patent US1234567B2 describes a new AI algorithm. "
        "Person John Doe is the inventor. Revenue $1000 was generated."
    )
    engine.process_text_and_update_graph(sample_text)

    stats = engine.get_graph_stats()
    print(f"Graph Statistics: {stats}")

    # Export, then import into a fresh instance.
    export_file = "aiva_kg/KNOWLEDGE_GRAPH/graph_export.json"
    engine.export_graph(export_file)
    second_engine = KnowledgeGraphEngine()  # new instance
    second_engine.import_graph(export_file)
    stats2 = second_engine.get_graph_stats()
    print(f"Imported Graph Statistics: {stats2}")

    # Deletion cascades to incident edges.
    engine.remove_entity("concept1")
    print(f"Graph stats after removing concept1: {engine.get_graph_stats()}")