import networkx as nx
import json
from networkx.algorithms import centrality
from neo4j import GraphDatabase

class PatentKnowledgeGraph:
    """
    An advanced knowledge graph for patents, concepts, and algorithms, with Neo4j support.

    Two backends are supported:

    * NetworkX (default): an in-memory ``nx.MultiDiGraph`` whose nodes are keyed by
      their identifier (patent ID, concept name or algorithm name) and carry a
      ``type`` attribute ("Patent", "Concept" or "Algorithm").
    * Neo4j: used when a URI, user and password are all supplied. Nodes carry the
      label ``Patent``/``Concept``/``Algorithm`` and are identified by the
      ``patent_id``/``concept_name``/``algorithm_name`` property respectively.
    """

    # Property names that identify a node in Neo4j, in lookup priority order.
    _NODE_KEYS = ("patent_id", "concept_name", "algorithm_name")
    # D3.js group number per node type; unknown types fall back to group 0.
    _NODE_GROUPS = {"Patent": 1, "Concept": 2, "Algorithm": 3}

    def __init__(self, neo4j_uri=None, neo4j_user=None, neo4j_password=None):
        """
        Initializes the knowledge graph using NetworkX or Neo4j.

        Args:
            neo4j_uri (str, optional): The URI for the Neo4j database. Defaults to None (using NetworkX).
            neo4j_user (str, optional): The username for the Neo4j database. Defaults to None.
            neo4j_password (str, optional): The password for the Neo4j database. Defaults to None.

        Neo4j is used only when all three connection arguments are given.
        """
        self.use_neo4j = neo4j_uri is not None and neo4j_user is not None and neo4j_password is not None
        if self.use_neo4j:
            self.driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
            try:
                self._create_indexes()  # Ensure indexes for performance
            except Exception:
                # The caller never receives `self` if we raise here, so close the
                # driver ourselves instead of leaking its connection pool.
                self.driver.close()
                raise
        else:
            self.graph = nx.MultiDiGraph()

    def close(self):
        """Closes the Neo4j driver, if it's open. A no-op for the NetworkX backend."""
        if self.use_neo4j:
            self.driver.close()

    def _execute_query(self, query, parameters=None):
        """Executes a Cypher query against the Neo4j database and returns all rows as dicts."""
        with self.driver.session() as session:
            result = session.run(query, parameters)
            return result.data()

    def _create_indexes(self):
        """Creates indexes in Neo4j to speed up common queries."""
        queries = [
            "CREATE INDEX IF NOT EXISTS patent_id FOR (p:Patent) ON (p.patent_id)",
            "CREATE INDEX IF NOT EXISTS concept_name FOR (c:Concept) ON (c.concept_name)",
            "CREATE INDEX IF NOT EXISTS algorithm_name FOR (a:Algorithm) ON (a.algorithm_name)"
        ]
        for query in queries:
            self._execute_query(query)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _quote_identifier(name):
        """
        Returns *name* as a backtick-quoted Cypher identifier.

        Relationship types cannot be passed as query parameters, so they must be
        spliced into the query text. Quoting them (and doubling embedded backticks)
        keeps a crafted type string from injecting arbitrary Cypher.
        """
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _unique(items):
        """Returns *items* as a list with duplicates removed, keeping first-seen order."""
        return list(dict.fromkeys(items))

    @classmethod
    def _neo4j_node_id(cls, properties):
        """Returns the first truthy identifying property of a Neo4j node's property dict, or None."""
        for key in cls._NODE_KEYS:
            value = properties.get(key)
            if value:
                return value
        return None

    @classmethod
    def _neo4j_node_type(cls, labels):
        """Returns the first known node type ("Patent"/"Concept"/"Algorithm") among *labels*, or None."""
        for label in labels or ():
            if label in cls._NODE_GROUPS:
                return label
        return None

    def _nx_node_type(self, node):
        """
        Returns the ``type`` attribute of a NetworkX node, or None if it has none.

        ``add_relationship`` implicitly creates attribute-less nodes when an endpoint
        doesn't exist yet, so indexing ``['type']`` directly would raise KeyError.
        """
        return self.graph.nodes[node].get('type')

    # ------------------------------------------------------------------ writers

    def add_patent(self, patent_id, title, claims, abstract):
        """
        Adds a patent node to the graph, or updates it if it already exists.

        Args:
            patent_id (str): The unique identifier for the patent.
            title (str): The title of the patent.
            claims (str): The claims of the patent.
            abstract (str): The abstract of the patent.
        """
        if self.use_neo4j:
            # MERGE (not CREATE) so repeated calls update one node instead of
            # creating duplicates, matching NetworkX's add_node semantics.
            query = """
            MERGE (p:Patent {patent_id: $patent_id})
            SET p.title = $title, p.claims = $claims, p.abstract = $abstract
            """
            parameters = {"patent_id": patent_id, "title": title, "claims": claims, "abstract": abstract}
            self._execute_query(query, parameters)
        else:
            self.graph.add_node(patent_id, type="Patent", title=title, claims=claims, abstract=abstract)

    def add_concept(self, concept_name, description=""):
        """
        Adds a concept node to the graph, or updates it if it already exists.

        Args:
            concept_name (str): The name of the concept.
            description (str, optional): A description of the concept. Defaults to "".
        """
        if self.use_neo4j:
            query = """
            MERGE (c:Concept {concept_name: $concept_name})
            SET c.description = $description
            """
            parameters = {"concept_name": concept_name, "description": description}
            self._execute_query(query, parameters)
        else:
            self.graph.add_node(concept_name, type="Concept", description=description)

    def add_algorithm(self, algorithm_name, description=""):
        """
        Adds an algorithm node to the graph, or updates it if it already exists.

        Args:
            algorithm_name (str): The name of the algorithm.
            description (str, optional): A description of the algorithm. Defaults to "".
        """
        if self.use_neo4j:
            query = """
            MERGE (a:Algorithm {algorithm_name: $algorithm_name})
            SET a.description = $description
            """
            parameters = {"algorithm_name": algorithm_name, "description": description}
            self._execute_query(query, parameters)
        else:
            self.graph.add_node(algorithm_name, type="Algorithm", description=description)

    def add_relationship(self, source_node, target_node, relation_type, attributes=None):
        """
        Adds a directed relationship between two nodes in the graph.

        Args:
            source_node (str): The identifier of the source node.
            target_node (str): The identifier of the target node.
            relation_type (str): The type of relationship between the nodes (e.g., "DEPENDS_ON", "IMPLEMENTS").
            attributes (dict, optional): Additional attributes for the relationship. Defaults to None.

        Notes:
            With Neo4j, nothing is created if either endpoint does not exist. With
            NetworkX, missing endpoints are created as nodes without a ``type``.
        """
        if self.use_neo4j:
            # The parentheses matter: AND binds tighter than OR in Cypher, and without
            # them the source/target conditions mix and unrelated pairs get linked.
            query = f"""
            MATCH (s), (t)
            WHERE (s.patent_id = $source_node OR s.concept_name = $source_node OR s.algorithm_name = $source_node)
              AND (t.patent_id = $target_node OR t.concept_name = $target_node OR t.algorithm_name = $target_node)
            CREATE (s)-[r:{self._quote_identifier(relation_type)}]->(t)
            SET r += $attributes
            """
            parameters = {"source_node": source_node, "target_node": target_node, "attributes": attributes or {}}
            self._execute_query(query, parameters)
        else:
            # Copy so the caller's dict isn't mutated. Let relation_type take precedence
            # over a 'type' key, because passing both as kwargs would raise TypeError.
            edge_attributes = dict(attributes or {})
            edge_attributes['type'] = relation_type
            self.graph.add_edge(source_node, target_node, **edge_attributes)

    # ------------------------------------------------------------------ queries

    def find_related_patents(self, patent_id, relation_type=None):
        """
        Finds patents directly connected to a given patent, in either direction.

        Args:
            patent_id (str): The identifier of the patent to find related patents for.
            relation_type (str, optional): The type of relationship to filter by. Defaults to None
                (any relationship type).

        Returns:
            list: The distinct IDs of patents related to the given patent. The list is
            empty if the patent is unknown.
        """
        if self.use_neo4j:
            type_filter = 'WHERE type(r) = $relation_type' if relation_type else ''
            query = f"""
            MATCH (p:Patent {{patent_id: $patent_id}})-[r]->(related:Patent)
            {type_filter}
            RETURN related.patent_id AS related_patent_id
            UNION
            MATCH (p:Patent {{patent_id: $patent_id}})<-[r]-(related:Patent)
            {type_filter}
            RETURN related.patent_id AS related_patent_id
            """
            parameters = {"patent_id": patent_id, "relation_type": relation_type}
            results = self._execute_query(query, parameters)
            return [row['related_patent_id'] for row in results]

        if patent_id not in self.graph or self._nx_node_type(patent_id) != "Patent":
            return []

        def matches(other, data):
            return self._nx_node_type(other) == "Patent" and (not relation_type or data.get('type') == relation_type)

        # Walk only this node's incident edges instead of scanning the whole graph.
        related = [t for _, t, d in self.graph.out_edges(patent_id, data=True) if matches(t, d)]
        related += [s for s, _, d in self.graph.in_edges(patent_id, data=True) if matches(s, d)]
        # Deduplicate like the Neo4j UNION does (parallel edges / edges both ways).
        return self._unique(related)

    def trace_concept_lineage(self, concept_name):
        """
        Traces the lineage of a concept.

        The lineage is every algorithm the concept IMPLEMENTS, plus every patent
        connected to one of those algorithms. A patent counts in either edge direction,
        e.g. ``Patent -[IMPLEMENTS]-> Algorithm`` or ``Algorithm -> Patent``.

        Args:
            concept_name (str): The name of the concept to trace.

        Returns:
            dict: ``{"algorithms": [...], "patents": [...]}`` with distinct names/IDs.
            An algorithm is listed even if no patent uses it.
        """
        if self.use_neo4j:
            # OPTIONAL MATCH keeps algorithms that have no patents attached.
            query = """
            MATCH (c:Concept {concept_name: $concept_name})-[:IMPLEMENTS]->(a:Algorithm)
            OPTIONAL MATCH (a)--(p:Patent)
            RETURN a.algorithm_name AS algorithm, p.patent_id AS patent
            """
            parameters = {"concept_name": concept_name}
            results = self._execute_query(query, parameters)
            algorithms = self._unique(row['algorithm'] for row in results)
            patents = self._unique(row['patent'] for row in results if row['patent'] is not None)
            return {"algorithms": algorithms, "patents": patents}

        if concept_name not in self.graph or self._nx_node_type(concept_name) != "Concept":
            return {"algorithms": [], "patents": []}

        algorithms = self._unique(
            target
            for _, target, data in self.graph.out_edges(concept_name, data=True)
            if data.get('type') == "IMPLEMENTS" and self._nx_node_type(target) == "Algorithm"
        )
        patents = []
        for algorithm in algorithms:
            neighbors = [*self.graph.successors(algorithm), *self.graph.predecessors(algorithm)]
            patents.extend(n for n in neighbors if self._nx_node_type(n) == "Patent")
        return {"algorithms": algorithms, "patents": self._unique(patents)}

    def find_all_validations_for_type(self, validation_type):
        """
        Finds all patents that have a validation of a specific type.

        Args:
            validation_type (str): The type of validation to search for. This is matched
                against the ``validation_type`` attribute of VALIDATES relationships.

        Returns:
            list: The distinct IDs of patents with at least one such validation.
        """
        if self.use_neo4j:
            query = """
            MATCH (p:Patent)-[r:VALIDATES {validation_type: $validation_type}]->()
            RETURN DISTINCT p.patent_id AS patent_id
            """
            parameters = {"validation_type": validation_type}
            results = self._execute_query(query, parameters)
            return [row['patent_id'] for row in results]

        return self._unique(
            source
            for source, _, data in self.graph.edges(data=True)
            if data.get('type') == "VALIDATES"
            and data.get('validation_type') == validation_type
            and self._nx_node_type(source) == "Patent"
        )

    # ------------------------------------------------------------------ analytics

    def calculate_page_rank(self):
        """
        Calculates PageRank over every node in the graph (not only patents).

        Returns:
            dict | None: Node -> PageRank score for NetworkX. None for Neo4j, which needs GDS.
        """
        if self.use_neo4j:
            # PageRank calculation in Neo4j is best done using the Graph Data Science library.
            # This requires installing the GDS library and using its procedures.
            print("PageRank calculation requires the Neo4j Graph Data Science library.  Please install and configure it for optimal performance.")
            return None
        return nx.pagerank(self.graph)

    def calculate_degree_centrality(self):
        """
        Calculates degree centrality for nodes in the graph.

        Returns:
            dict | None: Node -> degree centrality for NetworkX. None for Neo4j, which needs GDS.
        """
        if self.use_neo4j:
            # Degree Centrality calculation in Neo4j is best done using the Graph Data Science library.
            # This requires installing the GDS library and using its procedures.
            print("Degree Centrality calculation requires the Neo4j Graph Data Science library.  Please install and configure it for optimal performance.")
            return None
        return centrality.degree_centrality(self.graph)

    # ------------------------------------------------------------------ export / stats

    def export_to_d3js_json(self, filename="graph.json"):
        """
        Exports the graph data to a JSON format suitable for D3.js visualization.

        The output is ``{"nodes": [{"id", "group", ...properties}], "links": [{"source", "target", "type"}]}``.

        Args:
            filename (str, optional): The name of the file to export to. Defaults to "graph.json".
        """
        if self.use_neo4j:
            nodes_data = self._execute_query("MATCH (n) RETURN n, labels(n) AS labels")
            # Result.data() flattens relationships into plain (start, type, end) tuples,
            # so the type must be requested explicitly with type(r).
            links_data = self._execute_query("MATCH (s)-[r]->(t) RETURN s, type(r) AS rel_type, t")

            nodes = []
            for row in nodes_data:
                properties = row['n']
                # Derive the type from the node's labels rather than guessing from properties.
                node_type = self._neo4j_node_type(row['labels'])
                nodes.append({
                    "id": self._neo4j_node_id(properties),
                    "group": self._get_node_group({"type": node_type}),
                    "type": node_type,
                    **properties,
                })
            links = [
                {"source": self._neo4j_node_id(row['s']), "target": self._neo4j_node_id(row['t']), "type": row['rel_type']}
                for row in links_data
            ]
        else:
            nodes = [
                {"id": node_id, "group": self._get_node_group(node_data), **node_data}
                for node_id, node_data in self.graph.nodes(data=True)
            ]
            links = [
                {"source": source, "target": target, "type": data.get('type')}
                for source, target, data in self.graph.edges(data=True)
            ]

        graph_data = {"nodes": nodes, "links": links}
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(graph_data, f, indent=4)
        print(f"Graph data exported to {filename}")

    def _get_node_group(self, node_data):
        """
        Assigns a group number to each node type for D3.js visualization.

        Returns 1 for Patent, 2 for Concept, 3 for Algorithm, and 0 for anything else.
        """
        return self._NODE_GROUPS.get(node_data.get('type'), 0)

    def print_graph_stats(self):
        """Prints the number of nodes and edges in the graph."""
        if self.use_neo4j:
            node_count = self._execute_query("MATCH (n) RETURN count(n) AS nodeCount")[0]['nodeCount']
            edge_count = self._execute_query("MATCH ()-[r]->() RETURN count(r) AS edgeCount")[0]['edgeCount']
        else:
            node_count = self.graph.number_of_nodes()
            edge_count = self.graph.number_of_edges()
        print(f"Number of nodes: {node_count}")
        print(f"Number of edges: {edge_count}")

    def find_cross_patent_relationships(self, source_country, target_country):
        """
        Finds relationships from patents of one country to patents of another.

        The country is taken to be a prefix of the patent ID (e.g. "US", "DE").

        Args:
            source_country (str): ID prefix the source patent must start with.
            target_country (str): ID prefix the target patent must start with.

        Returns:
            list[dict]: Rows with ``source_patent``, ``relationship_type`` and ``target_patent``.
        """
        if self.use_neo4j:
            query = """
            MATCH (p1:Patent)-[r]->(p2:Patent)
            WHERE p1.patent_id STARTS WITH $source_country AND p2.patent_id STARTS WITH $target_country
            RETURN p1.patent_id AS source_patent, type(r) AS relationship_type, p2.patent_id AS target_patent
            """
            parameters = {"source_country": source_country, "target_country": target_country}
            return self._execute_query(query, parameters)

        def has_prefix(node, prefix):
            return (
                isinstance(node, str)
                and node.startswith(prefix)
                and self._nx_node_type(node) == "Patent"
            )

        return [
            {"source_patent": source, "relationship_type": data.get('type'), "target_patent": target}
            for source, target, data in self.graph.edges(data=True)
            if has_prefix(source, source_country) and has_prefix(target, target_country)
        ]

    def semantic_relationship_extraction(self, text):
        """
        Placeholder for semantic relationship extraction from text using NLP techniques.
        This is a simplified example and would require integration with an NLP library.

        Returns:
            list: Always empty until an NLP backend is integrated.
        """
        # In a real implementation, this would use NLP to identify entities and relationships.
        # For example, using spaCy or transformers.
        print("Semantic Relationship Extraction is a placeholder. Requires NLP integration.")
        return []

# Example usage
if __name__ == '__main__':
    import os

    # Neo4j connection details. Override them via environment variables rather than
    # editing credentials into the source; the defaults match the original placeholders.
    NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
    NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
    NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "your_neo4j_password")

    # Choose the backend: Neo4j by default, or set KG_BACKEND=networkx for the in-memory graph.
    use_neo4j = os.environ.get("KG_BACKEND", "neo4j").strip().lower() != "networkx"

    if use_neo4j:
        kg = PatentKnowledgeGraph(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    else:
        kg = PatentKnowledgeGraph()

    try:
        # Add patents
        kg.add_patent("US1234567B1", "Method for Image Enhancement", "A method comprising...", "This patent describes...")
        kg.add_patent("US7654321B2", "Improved Image Filtering Technique", "An improved technique...", "This patent improves upon...")
        kg.add_patent("DE9876543C3", "German Patent for Noise Reduction", "A method for reducing noise...", "This patent focuses on...")

        # Add concepts
        kg.add_concept("Image Enhancement", "Techniques to improve image quality")
        kg.add_concept("Image Filtering", "Techniques to remove noise from images")

        # Add algorithms
        kg.add_algorithm("Bilateral Filtering", "A non-linear filtering technique")
        kg.add_algorithm("Unsharp Masking", "A technique to sharpen images")

        # Add relationships
        kg.add_relationship("US7654321B2", "US1234567B1", "DEPENDS_ON")
        kg.add_relationship("Image Enhancement", "Unsharp Masking", "IMPLEMENTS")
        kg.add_relationship("Image Filtering", "Bilateral Filtering", "IMPLEMENTS")
        kg.add_relationship("US1234567B1", "Bilateral Filtering", "IMPLEMENTS")
        kg.add_relationship("US7654321B2", "Unsharp Masking", "IMPLEMENTS")
        kg.add_relationship("US1234567B1", "Performance", "VALIDATES", attributes={"validation_type": "Performance"})

        # Query the graph
        related_patents = kg.find_related_patents("US1234567B1")
        print(f"Patents related to US1234567B1: {related_patents}")

        concept_lineage = kg.trace_concept_lineage("Image Enhancement")
        print(f"Lineage of Image Enhancement: {concept_lineage}")

        performance_validations = kg.find_all_validations_for_type("Performance")
        print(f"Patents with Performance validations: {performance_validations}")

        # Graph algorithms (only for NetworkX, requires Neo4j GDS for Neo4j)
        if not use_neo4j:
            pagerank = kg.calculate_page_rank()
            if pagerank:
                print(f"PageRank: {pagerank}")

            degree_centrality = kg.calculate_degree_centrality()
            if degree_centrality:
                print(f"Degree Centrality: {degree_centrality}")

        # Export to D3.js JSON
        kg.export_to_d3js_json()

        # Cross-patent relationships (Neo4j only)
        if use_neo4j:
            cross_patent_relationships = kg.find_cross_patent_relationships("US", "DE")
            print(f"Cross-patent relationships between US and DE: {cross_patent_relationships}")

        # Semantic relationship extraction (placeholder)
        semantic_relationships = kg.semantic_relationship_extraction("The bilateral filter enhances images by reducing noise.")
        print(f"Semantic relationships: {semantic_relationships}")

        kg.print_graph_stats()

    finally:
        # close() is a no-op for the NetworkX backend, so it is always safe to call.
        kg.close()
