# knowledge_graph_engine.py
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import networkx as nx
from collections import defaultdict
import re

class KnowledgeGraphEngine:
    """
    Advanced Knowledge Graph Engine with deduplication, incremental updates,
    text extraction and path finding.

    Entities and relationships are persisted as append-only JSONL logs under
    ``<workspace>/KNOWLEDGE_GRAPH`` and replayed into an in-memory NetworkX
    graph on startup.
    """

    def __init__(self, workspace_path: str = "aiva_kg"):
        """Creates the workspace directory and replays any persisted graph data.

        Args:
            workspace_path: Root directory for the engine's on-disk state.
        """
        self.workspace = Path(workspace_path)
        self.kg_dir = self.workspace / "KNOWLEDGE_GRAPH"
        self.entities_path = self.kg_dir / "entities.jsonl"
        self.relationships_path = self.kg_dir / "relationships.jsonl"
        self.graph = nx.Graph()
        # entity_id -> merged attribute dict; used for deduplication.
        self.entity_index: Dict[str, Dict[str, Any]] = {}
        os.makedirs(self.kg_dir, exist_ok=True)
        self._load_graph()

    def _load_graph(self):
        """Replays the JSONL logs into the NetworkX graph.

        Each log is loaded independently: a workspace that has entities but no
        relationships yet (or vice versa) still loads what it has.  (The old
        code required BOTH files to exist and otherwise loaded nothing.)
        Blank lines are skipped so a trailing newline never breaks a reload.
        """
        if self.entities_path.exists():
            with open(self.entities_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        # add_entity also populates entity_index and merges duplicates.
                        self.add_entity(json.loads(line), load=True)

        if self.relationships_path.exists():
            with open(self.relationships_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        rel = json.loads(line)
                        self.add_relationship(rel["from"], rel["to"], rel, load=True)

    def add_entity(self, entity: Dict[str, Any], load: bool = False) -> str:
        """Adds an entity to the graph, handling deduplication.

        If the id already exists, the new attributes are merged on top of the
        existing ones (new values overwrite old).

        Args:
            entity: Attribute dict; must contain an "id" key.
            load: True while replaying the persisted log (skips re-persisting).

        Returns:
            The entity id.
        """
        entity_id = entity["id"]
        if entity_id in self.entity_index:
            # Entity already exists: merge attributes, new values overwrite old.
            merged_entity = {**self.entity_index[entity_id], **entity}
            self.graph.nodes[entity_id].update(merged_entity)
            self.entity_index[entity_id] = merged_entity
            print(f"Updated entity: {entity_id}")
        else:
            self.graph.add_node(entity_id, **entity)
            self.entity_index[entity_id] = entity
            print(f"Added entity: {entity_id}")

        if not load:
            self._save_entity(entity)
        return entity_id

    def _save_entity(self, entity: Dict[str, Any]):
        """Appends the entity to the entities.jsonl log (append-only; the
        merge in add_entity reconstructs the latest state on replay)."""
        with open(self.entities_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entity, ensure_ascii=False) + "\n")

    def add_relationship(self, from_id: str, to_id: str, relationship: Dict[str, Any], load: bool = False):
        """Adds an undirected edge between two existing entities.

        Both endpoints must already be in the entity index; otherwise a
        warning is printed and nothing happens.

        Args:
            from_id: Source entity id.
            to_id: Target entity id.
            relationship: Edge attribute dict (e.g. {"type": "uses"}).
            load: True while replaying the persisted log (skips re-persisting).
        """
        if from_id not in self.entity_index or to_id not in self.entity_index:
            print(f"Warning: Cannot add relationship. Entity {from_id} or {to_id} does not exist.")
            return

        self.graph.add_edge(from_id, to_id, **relationship)
        print(f"Added relationship: {relationship.get('type', 'related')} from {from_id} to {to_id}")

        if not load:
            # BUG FIX: persist the endpoints alongside the attributes.  The old
            # code wrote only `relationship`, so _load_graph crashed with
            # KeyError("from") on the next startup.
            self._save_relationship({**relationship, "from": from_id, "to": to_id})

    def _save_relationship(self, relationship: Dict[str, Any]):
        """Appends the relationship record (including "from"/"to" endpoints)
        to the relationships.jsonl log."""
        with open(self.relationships_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(relationship, ensure_ascii=False) + "\n")

    def extract_entities_from_text(self, text: str) -> List[Dict[str, Any]]:
        """
        Extracts candidate entities from unstructured text via regex heuristics.

        This is a simplified example and would ideally use an NLP model.
        Results are de-duplicated by entity id (first occurrence wins), since
        the same name frequently matches more than once.

        Args:
            text: Free-form input text.

        Returns:
            List of unique entity dicts (Person / Concept / Skill / Tool / Revenue).
        """
        # Regex patterns for the different entity types.
        person_pattern = r"(Dr\.\s*)?([A-Z][a-z]+)\s+([A-Z][a-z]+)"  # Example: John Doe, Dr. Jane Smith
        concept_pattern = r"\b([A-Z][a-z]+(?: [A-Z][a-z]+)*)\b"  # Simple noun phrase extraction
        skill_pattern = r"(?:skill|expertise) in ([a-zA-Z\s]+)"
        tool_pattern = r"(?:using|with) ([A-Za-z0-9\.]+)"
        revenue_pattern = r"\$\s*[\d,]+(?:\.\d+)?"

        entities: List[Dict[str, Any]] = []
        seen_ids = set()

        def _collect(entity: Dict[str, Any]) -> None:
            # De-duplicate within this extraction pass: the same name can occur
            # several times in the text and would otherwise be emitted repeatedly.
            if entity["id"] not in seen_ids:
                seen_ids.add(entity["id"])
                entities.append(entity)

        # Extract people
        for match in re.finditer(person_pattern, text):
            title, first_name, last_name = match.groups()
            _collect({
                "id": f"PERSON_{first_name}_{last_name}",
                "type": "Person",
                "name": f"{first_name} {last_name}",
                "title": title if title else None,
                "source": "text_extraction",
            })

        # Extract concepts
        for match in re.finditer(concept_pattern, text):
            concept_name = match.group(0)
            _collect({
                "id": f"CONCEPT_{concept_name.replace(' ', '_')}",
                "type": "Concept",
                "name": concept_name,
                "source": "text_extraction",
            })

        # Extract skills
        for match in re.finditer(skill_pattern, text):
            skill_name = match.group(1)
            _collect({
                "id": f"SKILL_{skill_name.replace(' ', '_')}",
                "type": "Skill",
                "name": skill_name,
                "source": "text_extraction",
            })

        # Extract tools
        for match in re.finditer(tool_pattern, text):
            tool_name = match.group(1)
            _collect({
                "id": f"TOOL_{tool_name.replace(' ', '_')}",
                "type": "Tool",
                "name": tool_name,
                "source": "text_extraction",
            })

        # Extract revenue
        for match in re.finditer(revenue_pattern, text):
            revenue_amount = match.group(0)
            # FIX: also strip commas so "$100,000" yields "REVENUE_100000"
            # instead of an id containing a comma.
            entity_id = "REVENUE_" + revenue_amount.replace("$", "").replace(",", "").replace(".", "_")
            _collect({
                "id": entity_id,
                "type": "Revenue",
                "amount": revenue_amount,
                "source": "text_extraction",
            })

        return entities

    def find_shortest_path(self, start_entity: str, end_entity: str) -> Optional[List[str]]:
        """Finds the shortest path between two entities.

        Returns:
            The list of node ids along the path, or None when no path exists
            or either endpoint is missing (NodeNotFound subclasses NetworkXError).
        """
        try:
            return nx.shortest_path(self.graph, source=start_entity, target=end_entity)
        except nx.NetworkXNoPath:
            print(f"No path found between {start_entity} and {end_entity}")
            return None
        except nx.NetworkXError:
            print(f"Entity {start_entity} or {end_entity} does not exist.")
            return None

    def vector_search(self, query: str, top_k: int = 5) -> List[str]:
        """
        Placeholder for vector embedding based semantic search.

        In a real implementation, this would query a vector database
        (e.g. Faiss, Qdrant) and return the top_k most similar entity IDs.
        The dummy implementation just returns the first top_k known ids.
        """
        all_entity_ids = list(self.entity_index.keys())
        return all_entity_ids[:min(top_k, len(all_entity_ids))]

    def hybrid_search(self, query: str, top_k_vector: int = 5, hop_distance: int = 1) -> List[Dict[str, Any]]:
        """
        Combines vector search with graph traversal.

        1. Perform vector search to find semantically similar entities.
        2. Expand the search to direct neighbors when hop_distance > 0.

        Returns:
            List of {"id", "data", "source"[, "related_to"]} result dicts.
        """
        vector_results = self.vector_search(query, top_k=top_k_vector)
        expanded_results = []
        visited = set(vector_results)
        # Hoisted out of the loop: neighbor expansion only happens at all
        # when at least one hop is allowed.
        expand = hop_distance > 0

        for entity_id in vector_results:
            if entity_id not in self.graph:
                continue
            expanded_results.append({
                "id": entity_id,
                "data": self.graph.nodes[entity_id],
                "source": "vector_search",
            })
            if not expand:
                continue
            for neighbor in self.graph.neighbors(entity_id):
                if neighbor not in visited:
                    expanded_results.append({
                        "id": neighbor,
                        "data": self.graph.nodes[neighbor],
                        "source": "graph_expansion",
                        "related_to": entity_id,
                    })
                    visited.add(neighbor)

        return expanded_results

    def query_interface(self, query_type: str, **kwargs) -> Any:
        """
        Unified query interface for different types of queries.

        Supported query_type values: "path" (needs start_entity, end_entity)
        and "hybrid_search" (needs query; optional top_k_vector, hop_distance).
        Returns an error string for bad input, matching the original contract.
        """
        if query_type == "path":
            start_entity = kwargs.get("start_entity")
            end_entity = kwargs.get("end_entity")
            if not start_entity or not end_entity:
                return "Error: start_entity and end_entity are required for path queries."
            return self.find_shortest_path(start_entity, end_entity)
        elif query_type == "hybrid_search":
            query = kwargs.get("query")
            if not query:
                return "Error: query is required for hybrid_search queries."
            top_k_vector = kwargs.get("top_k_vector", 5)
            hop_distance = kwargs.get("hop_distance", 1)
            return self.hybrid_search(query, top_k_vector, hop_distance)
        else:
            return f"Error: Unsupported query type: {query_type}"

    def get_entity(self, entity_id: str) -> Optional[Dict[str, Any]]:
        """Retrieves an entity's attribute dict by id, or None if unknown."""
        return self.entity_index.get(entity_id)

    def get_relationships(self, entity_id: str) -> List[Tuple[str, str, Dict[str, Any]]]:
        """Retrieves all relationships incident to an entity.

        Returns an empty list for an unknown entity instead of raising
        NetworkXError (robustness fix).
        """
        if entity_id not in self.graph:
            return []
        return [
            (entity_id, neighbor, self.graph.get_edge_data(entity_id, neighbor))
            for neighbor in self.graph.neighbors(entity_id)
        ]

if __name__ == "__main__":
    # Example Usage: exercise every public method of the engine once.
    kg_engine = KnowledgeGraphEngine()

    # 1. Add Entities (seed records fed through the dedup-aware API)
    seed_entities = (
        {"id": "PERSON_1", "type": "Person", "name": "Alice", "skill": "coding"},
        {"id": "TOOL_1", "type": "Tool", "name": "Python"},
        {"id": "CONCEPT_1", "type": "Concept", "name": "Machine Learning"},
    )
    for seed in seed_entities:
        kg_engine.add_entity(seed)

    # 2. Add Relationships between the seeded entities
    kg_engine.add_relationship("PERSON_1", "TOOL_1", {"type": "uses"})
    kg_engine.add_relationship("PERSON_1", "CONCEPT_1", {"type": "implements"})

    # 3. Extract entities from text and ingest them
    text_data = "Dr. John Smith has expertise in Data Science using Python and R. He earns $100,000 per year."
    for extracted in kg_engine.extract_entities_from_text(text_data):
        kg_engine.add_entity(extracted)

    # 4. Query for shortest path
    path = kg_engine.query_interface("path", start_entity="PERSON_1", end_entity="CONCEPT_1")
    print(f"Shortest path between PERSON_1 and CONCEPT_1: {path}")

    # 5. Hybrid Search
    hybrid_results = kg_engine.query_interface("hybrid_search", query="AI Tools")
    print(f"Hybrid search results for 'AI Tools': {json.dumps(hybrid_results, indent=2)}")

    # 6. Retrieve entity information
    entity_info = kg_engine.get_entity("PERSON_1")
    print(f"Entity information for PERSON_1: {entity_info}")

    # 7. Retrieve relationships for an entity
    entity_relationships = kg_engine.get_relationships("PERSON_1")
    print(f"Relationships for PERSON_1: {entity_relationships}")

    print(f"Graph has {kg_engine.graph.number_of_nodes()} nodes and {kg_engine.graph.number_of_edges()} edges.")