import logging
import os
from typing import List, Dict, Tuple

import numpy as np
import psycopg2
import redis
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Root logging setup for this module: INFO and above, one line per record
# formatted as "<timestamp> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


class ConceptLinker:
    """
    Links related documents by the semantic similarity of their summaries and
    records the links as concept-graph edges.

    Each document summary is embedded with a SentenceTransformer model, every
    pair of documents is scored by cosine similarity, and pairs at or above a
    threshold are inserted into the ``concept_graph`` table.

    The instance owns a PostgreSQL connection/cursor, a Redis client and a
    Qdrant client. Use it as a context manager (or call :meth:`close`) so all
    of them are released.
    """

    def __init__(self,
                 db_host: str = "postgresql-genesis-u50607.vm.elestio.app",
                 db_port: int = 25432,
                 db_name: str = "genesis",
                 db_user: str = "genesis",
                 db_pass: str = "genesis",
                 redis_host: str = "redis-genesis-u50607.vm.elestio.app",
                 redis_port: int = 26379,
                 qdrant_host: str = "qdrant-b3knu-u50607.vm.elestio.app",
                 qdrant_port: int = 6333,
                 embedding_model_name: str = 'all-mpnet-base-v2',
                 qdrant_collection_name: str = 'concept_embeddings'):
        """
        Store the configuration and open every backing service.

        Steps run in order: connect to PostgreSQL, connect to Redis (verified
        with PING), connect to Qdrant (verified by fetching
        ``qdrant_collection_name``), then load the embedding model. If any
        step fails, everything opened so far is closed before the original
        exception is re-raised, so a failed constructor does not leak
        connections.

        Raises:
            psycopg2.Error: the PostgreSQL connection could not be opened.
            redis.exceptions.RedisError: the Redis client could not connect.
            Exception: Qdrant was unreachable or the collection lookup failed,
                or the embedding model could not be loaded.
        """
        # NOTE(review): hostnames and credentials are hard-coded defaults in
        # source control; consider supplying them from configuration instead.
        self.db_host = db_host
        self.db_port = db_port
        self.db_name = db_name
        self.db_user = db_user
        self.db_pass = db_pass
        self.redis_host = redis_host
        self.redis_port = redis_port
        self.qdrant_host = qdrant_host
        self.qdrant_port = qdrant_port
        self.embedding_model_name = embedding_model_name
        self.qdrant_collection_name = qdrant_collection_name

        # Pre-declare every resource handle so close() works on a partially
        # initialised instance.
        self.db_conn = None
        self.db_cursor = None
        self.redis_client = None
        self.qdrant_client = None
        self.embedding_model = None

        try:
            self._connect_postgres()
            self._connect_redis()
            self._connect_qdrant()
            self._load_embedding_model()
        except BaseException:
            # __exit__ is never reached when __init__ fails, so release what
            # was opened here, then propagate the original error unchanged.
            self.close()
            raise

    def _connect_postgres(self) -> None:
        """Open the PostgreSQL connection and the cursor shared by all queries."""
        try:
            self.db_conn = psycopg2.connect(
                host=self.db_host,
                port=self.db_port,
                database=self.db_name,
                user=self.db_user,
                password=self.db_pass
            )
            self.db_cursor = self.db_conn.cursor()
        except psycopg2.Error as e:
            logging.error(f"Failed to connect to PostgreSQL: {e}")
            raise
        logging.info("Connected to PostgreSQL database.")

    def _connect_redis(self) -> None:
        """Create the Redis client (db 0) and verify it with PING."""
        try:
            self.redis_client = redis.Redis(host=self.redis_host, port=self.redis_port, db=0)
            self.redis_client.ping()  # Check connection
        except redis.exceptions.RedisError as e:
            # RedisError also covers timeouts, not just ConnectionError.
            logging.error(f"Failed to connect to Redis: {e}")
            raise
        logging.info("Connected to Redis.")

    def _connect_qdrant(self) -> None:
        """Create the Qdrant client and check that the target collection exists."""
        try:
            self.qdrant_client = QdrantClient(host=self.qdrant_host, port=self.qdrant_port)
            # Expected to raise if the collection is missing; this doubles as
            # a connectivity check.
            self.qdrant_client.get_collection(collection_name=self.qdrant_collection_name)
        except Exception as e:
            logging.error(f"Failed to connect to Qdrant or collection does not exist: {e}")
            raise
        logging.info("Connected to Qdrant.")

    def _load_embedding_model(self) -> None:
        """Load the SentenceTransformer model named by ``embedding_model_name``."""
        try:
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
        except Exception as e:
            logging.error(f"Failed to load embedding model: {e}")
            raise
        logging.info(f"Loaded embedding model: {self.embedding_model_name}")

    def close(self) -> None:
        """
        Release the DB cursor, DB connection, Redis client and Qdrant client.

        Safe to call more than once and on a partially initialised instance.
        Each resource is closed independently and its handle reset to None. A
        failure while closing one resource is logged as a warning and does not
        stop the others from being closed.
        """
        db_was_open = getattr(self, 'db_conn', None) is not None
        # Cursor before connection; the other clients are independent.
        for attr in ('db_cursor', 'db_conn', 'redis_client', 'qdrant_client'):
            resource = getattr(self, attr, None)
            if resource is None:
                continue
            closer = getattr(resource, 'close', None)
            try:
                if callable(closer):
                    closer()
            except Exception as e:  # best-effort cleanup: never raise from close()
                logging.warning(f"Failed to close {attr}: {e}")
            finally:
                setattr(self, attr, None)
        if db_was_open:
            logging.info("Closed database connection.")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None means exceptions raised in the with-body propagate.
        self.close()

    def get_document_summaries(self) -> List[Tuple[int, str]]:
        """
        Retrieve (id, summary) rows for every document with a non-NULL summary.

        Returns:
            The rows as returned by the cursor, or an empty list if the query
            fails. On failure the error is logged and the transaction is
            rolled back so the connection stays usable.
        """
        try:
            self.db_cursor.execute("SELECT id, summary FROM documents WHERE summary IS NOT NULL")
            document_summaries = self.db_cursor.fetchall()
        except psycopg2.Error as e:
            logging.error(f"Failed to retrieve document summaries: {e}")
            # A failed statement aborts the current PostgreSQL transaction;
            # without a rollback every later statement on this connection fails.
            self.db_conn.rollback()
            return []
        logging.info(f"Retrieved {len(document_summaries)} document summaries.")
        return document_summaries

    def calculate_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """
        Calculate the cosine similarity between two embeddings.

        Args:
            embedding1: The first embedding; flattened to a single row.
            embedding2: The second embedding; flattened to a single row.

        Returns:
            The cosine similarity as a built-in ``float`` (not a numpy scalar).
        """
        return float(cosine_similarity(embedding1.reshape(1, -1), embedding2.reshape(1, -1))[0][0])

    def find_similar_concepts(self, document_summaries: List[Tuple[int, str]], similarity_threshold: float = 0.7) -> List[Tuple[int, int, float]]:
        """
        Find document pairs whose summaries have cosine similarity at or above a threshold.

        All summaries are encoded in a single batch, and one similarity matrix
        is computed for all pairs.

        Args:
            document_summaries: (document ID, summary) pairs.
            similarity_threshold: Minimum cosine similarity (inclusive) for a
                pair to be reported.

        Returns:
            (doc_id1, doc_id2, score) tuples for every pair i < j in input
            order, where ``score`` is a built-in ``float``. Returns an empty
            list when fewer than two summaries are given.
        """
        doc_ids = [doc_id for doc_id, _ in document_summaries]
        if len(doc_ids) < 2:
            return []
        summaries = [summary for _, summary in document_summaries]

        # Encode once: one embedding row per summary.
        embeddings = np.asarray(
            self.embedding_model.encode(summaries, convert_to_tensor=False),
            dtype=np.float64,
        )

        # Cosine similarity matrix: L2-normalise rows, then take dot products.
        # Zero-norm rows stay zero vectors (similarity 0), as in sklearn's
        # cosine_similarity.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0.0] = 1.0
        unit = embeddings / norms
        similarity_matrix = unit @ unit.T

        # Upper triangle (i < j) in row-major order, which gives the same pair
        # order as a nested i/j loop.
        rows, cols = np.triu_indices(len(doc_ids), k=1)
        pair_scores = similarity_matrix[rows, cols]
        keep = pair_scores >= similarity_threshold

        similar_concepts: List[Tuple[int, int, float]] = []
        # .tolist() yields built-in ints/floats, which DB drivers can adapt.
        for i, j, score in zip(rows[keep].tolist(), cols[keep].tolist(), pair_scores[keep].tolist()):
            doc_id1, doc_id2 = doc_ids[i], doc_ids[j]
            similar_concepts.append((doc_id1, doc_id2, score))
            logging.info(f"Found similar concepts between document {doc_id1} and {doc_id2} with score: {score}")

        return similar_concepts

    def create_concept_graph_edges(self, similar_concepts: List[Tuple[int, int, float]]):
        """
        Insert one ``concept_graph`` row per similar pair and commit.

        Args:
            similar_concepts: (source_id, target_id, similarity_score) tuples.

        On a database error, the error is logged and the whole batch is rolled
        back; nothing is raised.
        """
        # Assumes a 'concept_graph' table with columns 'source_id', 'target_id' and 'similarity_score'.
        sql = "INSERT INTO concept_graph (source_id, target_id, similarity_score) VALUES (%s, %s, %s)"
        # psycopg2 has no adapter for numpy scalars such as numpy.float32, so
        # coerce scores to built-in floats before binding.
        rows = [(doc_id1, doc_id2, float(similarity_score))
                for doc_id1, doc_id2, similarity_score in similar_concepts]
        try:
            self.db_cursor.executemany(sql, rows)
            self.db_conn.commit()
        except psycopg2.Error as e:
            logging.error(f"Failed to create concept graph edges: {e}")
            self.db_conn.rollback()  # Rollback if any insertion fails
            return
        logging.info(f"Created {len(similar_concepts)} concept graph edges.")

    def link_concepts(self):
        """
        Run one linking pass: fetch summaries, score pairs, and write edges.

        Returns early (after logging) when there are no summaries or no pair
        reaches the default similarity threshold.
        """
        document_summaries = self.get_document_summaries()
        if not document_summaries:
            logging.warning("No document summaries found. Skipping concept linking.")
            return

        similar_concepts = self.find_similar_concepts(document_summaries)
        if not similar_concepts:
            logging.info("No similar concepts found.")
            return

        self.create_concept_graph_edges(similar_concepts)

if __name__ == '__main__':
    # Script entry point: run a single linking pass. The with-statement
    # guarantees the linker's connections are released afterwards.
    linker = ConceptLinker()
    with linker:
        linker.link_concepts()
