import random
import logging
from sklearn.cluster import AgglomerativeClustering
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import pairwise_distances
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class MultiModelConsensusValidationAdvanced:
    """
    Consensus validation across multiple AI models.

    This skill enables AIVA to:
    1. Query multiple AI models for the same question
    2. Aggregate responses with weighted voting based on model confidence
    3. Implement Byzantine fault tolerance
    4. Perform semantic similarity clustering of responses
    5. Analyze disagreements between models
    6. Aggregate confidence scores
    7. Track model performance
    8. Dynamically select models based on past performance
    """

    def __init__(self, models, model_weights=None, similarity_threshold=0.8,
                 byzantine_tolerance=0.0, performance_history_decay=0.95):
        """
        Initializes the MultiModelConsensusValidation skill.

        Args:
            models (dict): Maps model names (str) to callables that take a query
                (str) and return a response (str). Example:
                {'model_1': my_model_query_function, 'model_2': another_fn}
            model_weights (dict, optional): Initial weight (float) per model name.
                If None, every model starts with weight 1.0. The dict is copied so
                that later weight updates never mutate the caller's dict.
            similarity_threshold (float, optional): Minimum cosine similarity for
                two responses to land in the same semantic cluster. Defaults to 0.8.
            byzantine_tolerance (float, optional): Fraction of total model weight
                assumed potentially Byzantine (malicious/faulty). Shrinks the
                confidence denominator so that disagreement from up to this
                fraction does not reduce confidence. Defaults to 0.0.
            performance_history_decay (float, optional): EMA decay used when
                folding recent performance into a model's weight; values closer
                to 1 keep more of the old weight. Defaults to 0.95.

        Raises:
            TypeError: If ``models`` or ``model_weights`` is not a dict.
            ValueError: If ``models`` is empty, a model is not callable, or a
                model name is missing from ``model_weights``.
        """
        if not isinstance(models, dict):
            raise TypeError("Models must be a dictionary of {model_name: model_query_function}.")
        if not models:
            raise ValueError("At least one model must be provided.")
        for model_name, model_function in models.items():
            if not callable(model_function):
                raise ValueError(f"The value for model '{model_name}' is not a callable function.")
        self.models = models
        self.model_names = list(models.keys())  # stable iteration order for the models

        if model_weights is None:
            self.model_weights = {name: 1.0 for name in self.model_names}  # equal initial weights
        else:
            if not isinstance(model_weights, dict):
                raise TypeError("model_weights must be a dictionary of {model_name: weight}.")
            for model_name in self.model_names:
                if model_name not in model_weights:
                    raise ValueError(f"Model '{model_name}' missing from model_weights.")
            # Copy the dict: update_model_weights() mutates it in place, and the
            # caller's dict must not change underneath them (bug fix).
            self.model_weights = dict(model_weights)

        self.similarity_threshold = similarity_threshold
        self.byzantine_tolerance = byzantine_tolerance
        self.performance_history_decay = performance_history_decay
        # Per-model agreement history accumulated since the last weight update
        # (1 = agreed with consensus, 0 = disagreed or errored).
        self.model_performance = {name: [] for name in self.model_names}

        logging.info(f"MultiModelConsensusValidation initialized with models: {self.model_names}")
        logging.info(f"Initial model weights: {self.model_weights}")

    def validate(self, query, models_to_use=None):
        """
        Validates a query by querying multiple AI models, aggregating responses with
        weighted voting, detecting disagreements, and calculating consensus confidence.

        Args:
            query (str): The query to validate.
            models_to_use (list, optional): Model names (strings) to use for this
                validation. If None, all configured models are used.

        Returns:
            tuple: ``(consensus_response, confidence, disagreements)`` where
                - consensus_response (str): response with the highest weighted vote;
                - confidence (float): weighted fraction of models agreeing with the
                  consensus, Byzantine-adjusted and clamped to [0.0, 1.0];
                - disagreements (dict): model name -> (clustered) response for every
                  model that did not agree with the consensus.

        Raises:
            TypeError: If ``query`` or ``models_to_use`` has the wrong type.
            ValueError: If ``query`` is empty or a requested model is unknown.
        """
        if not isinstance(query, str):
            raise TypeError("Query must be a string.")
        if not query:
            raise ValueError("Query cannot be empty.")

        if models_to_use is None:
            models_to_use = self.model_names
        elif not isinstance(models_to_use, list):
            raise TypeError("models_to_use must be a list of model names.")
        elif not all(isinstance(model_name, str) for model_name in models_to_use):
            raise TypeError("All elements in models_to_use must be strings (model names).")

        # Every requested model must actually be configured.
        for model_name in models_to_use:
            if model_name not in self.models:
                raise ValueError(f"Model '{model_name}' is not available.  Available models are: {self.model_names}")

        responses = self._query_models(query, models_to_use)

        # Map each model to the representative response of its semantic cluster so
        # that near-identical answers vote together.
        representatives = self.cluster_responses(list(responses.values()))
        clustered_responses = dict(zip(responses, representatives))

        # Weighted vote count per (clustered) response.
        vote_counts = {}
        for model_name, response in clustered_responses.items():
            vote_counts[response] = vote_counts.get(response, 0.0) + self.model_weights[model_name]

        if not vote_counts:
            logging.warning("No responses received from any models.")
            return "NO RESPONSE", 0.0, {}

        consensus_response = max(vote_counts, key=vote_counts.get)

        # Byzantine fault tolerance: shrink the denominator so that disagreement
        # from up to `byzantine_tolerance` of the total weight costs no confidence.
        total_weight = sum(self.model_weights[name] for name in models_to_use)
        byzantine_adjusted_total_weight = total_weight * (1.0 - self.byzantine_tolerance)
        if byzantine_adjusted_total_weight > 0:
            # Clamp: with tolerance > 0, a unanimous vote would otherwise exceed
            # the documented 1.0 upper bound (bug fix).
            confidence = min(1.0, float(vote_counts[consensus_response]) / byzantine_adjusted_total_weight)
        else:
            confidence = 0.0

        disagreements = {name: resp for name, resp in clustered_responses.items()
                         if resp != consensus_response}

        logging.info(f"Query: {query}, Consensus: {consensus_response}, Confidence: {confidence}, Disagreements: {disagreements}")

        # Track performance: score 1 only for agreeing with a real (non-"ERROR")
        # consensus, so models are never rewarded for failing together.
        for model_name in models_to_use:
            agreed = (clustered_responses[model_name] == consensus_response
                      and clustered_responses[model_name] != "ERROR")
            self.model_performance[model_name].append(1 if agreed else 0)

        # Fold the fresh performance data into the weights (EMA).
        self.update_model_weights()

        return consensus_response, confidence, disagreements

    def _query_models(self, query, models_to_use):
        """Query each model, coercing non-string replies and mapping failures to "ERROR"."""
        responses = {}
        for model_name in models_to_use:
            try:
                response = self.models[model_name](query)
                if not isinstance(response, str):
                    logging.warning(f"Model '{model_name}' returned a non-string response. Converting to string.")
                    response = str(response)
                responses[model_name] = response
            except Exception as e:
                # A failing model must not abort the whole consensus round; record
                # a sentinel and keep going.
                logging.error(f"Error querying model '{model_name}': {e}")
                responses[model_name] = "ERROR"
        return responses

    def cluster_responses(self, responses, n_clusters=None):
        """
        Clusters responses by semantic similarity (TF-IDF + complete-linkage
        agglomerative clustering) and returns, for each input response, the
        representative (first) response of the cluster it was assigned to.

        Args:
            responses (list[str]): Raw model responses, in model order.
            n_clusters (int, optional): Fixed number of clusters. If None, the
                smallest number of clusters whose members all lie within
                ``1 - similarity_threshold`` cosine distance of each other is used.

        Returns:
            list[str]: One representative response per input response.
        """
        if not responses:
            return []
        # AgglomerativeClustering requires at least 2 samples; a single response,
        # or a batch of identical responses, trivially forms one cluster (bug fix:
        # the original crashed on a single response).
        if len(responses) < 2 or len(set(responses)) == 1:
            return list(responses)

        try:
            tfidf_matrix = TfidfVectorizer().fit_transform(responses)
        except ValueError:
            # Empty vocabulary (e.g. responses made only of punctuation): fall
            # back to treating every response as its own representative.
            logging.warning("TF-IDF vectorization failed; skipping clustering.")
            return list(responses)

        # Cosine *distance* matrix: 0 = identical direction, 1 = orthogonal.
        distance_matrix = pairwise_distances(tfidf_matrix, metric="cosine")

        if n_clusters is None:
            # Max intra-cluster distance is non-increasing in the cluster count,
            # so the first k (smallest) satisfying the tightness bound is optimal.
            # (Bug fix: the original searched in the wrong direction and compared
            # a distance against the similarity threshold.)
            max_allowed_distance = 1.0 - self.similarity_threshold
            for k in range(1, len(responses) + 1):
                labels = self._fit_precomputed(k, distance_matrix)
                if self._max_intra_cluster_distance(distance_matrix, labels, k) <= max_allowed_distance:
                    n_clusters = k
                    break
            else:
                n_clusters = len(responses)  # every response in its own cluster

        labels = self._fit_precomputed(n_clusters, distance_matrix)

        # Representative = first member (by input order) of the response's cluster.
        clustered_responses = []
        for i in range(len(responses)):
            members = np.where(labels == labels[i])[0]
            clustered_responses.append(responses[members[0]])
        return clustered_responses

    @staticmethod
    def _fit_precomputed(n_clusters, distance_matrix):
        """Fit complete-linkage agglomerative clustering on a precomputed distance matrix."""
        try:
            # scikit-learn >= 1.2 ('affinity' was deprecated in 1.2, removed in 1.4)
            model = AgglomerativeClustering(n_clusters=n_clusters, linkage="complete",
                                            metric="precomputed")
        except TypeError:
            # scikit-learn < 1.2 only knows the 'affinity' keyword
            model = AgglomerativeClustering(n_clusters=n_clusters, linkage="complete",
                                            affinity="precomputed")
        return model.fit(distance_matrix).labels_

    @staticmethod
    def _max_intra_cluster_distance(distance_matrix, labels, n_clusters):
        """Largest pairwise distance between any two members of the same cluster."""
        worst = 0.0
        for cluster_id in range(n_clusters):
            idx = np.where(labels == cluster_id)[0]
            if len(idx) > 1:
                worst = max(worst, float(np.max(distance_matrix[np.ix_(idx, idx)])))
        return worst

    def update_model_weights(self):
        """
        Folds each model's recent agreement rate into its weight via an exponential
        moving average, then clears the accumulated performance history.
        """
        for model_name in self.model_names:
            history = self.model_performance[model_name]
            if history:
                average_performance = sum(history) / len(history)
                decay = self.performance_history_decay
                # EMA: keep `decay` of the old weight, blend in the fresh average.
                self.model_weights[model_name] = (decay * self.model_weights[model_name]
                                                  + (1 - decay) * average_performance)
                self.model_performance[model_name] = []
            else:
                logging.warning(f"No performance data for model {model_name}.  Keeping weight unchanged.")
        logging.info(f"Updated model weights: {self.model_weights}")

    def analyze_disagreements(self, query, disagreements):
        """Analyzes disagreements between models (Placeholder - more sophisticated analysis can be added)."""
        logging.info(f"Analyzing disagreements for query: {query}")
        for model_name, response in disagreements.items():
            logging.info(f"Model {model_name} disagreed with response: {response}")
        # In a more sophisticated implementation, you might:
        # - Categorize the types of disagreements (e.g., factual, stylistic, etc.)
        # - Identify patterns in disagreements (e.g., model X consistently disagrees on topic Y)
        # - Use the disagreement analysis to improve the training of individual models
        #   or the consensus mechanism itself.

    def dynamically_select_models(self, query, num_models=None):
        """
        Dynamically selects models based on past performance.

        Args:
            query (str): The query the selection is for (logged only).
            num_models (int, optional): How many models to select. Defaults to all.

        Returns:
            list[str]: The top ``num_models`` model names by current weight.
        """
        if num_models is None:
            num_models = len(self.models)
        # Highest-weight models first.
        ranked = sorted(self.model_weights.items(), key=lambda item: item[1], reverse=True)
        selected_models = [model_name for model_name, _ in ranked[:num_models]]

        logging.info(f"Dynamically selected models for query: {query} - {selected_models}")
        return selected_models

# Example Usage (Requires defining some mock models)
if __name__ == '__main__':

    def _make_weather_mock(weather_answer, other_answer):
        """Build a mock model: one canned reply for weather queries, another otherwise."""
        def _query(query):
            return weather_answer if "weather" in query.lower() else other_answer
        return _query

    # Mock AI models (replace with actual model implementations)
    mock_model_1 = _make_weather_mock("Sunny", "Unknown")
    mock_model_2 = _make_weather_mock("Mostly Sunny", "Cloudy")
    mock_model_3 = _make_weather_mock("Rainy", "Unknown")
    mock_model_4 = _make_weather_mock("Sunny with a chance of clouds", "Partly Cloudy")

    def _report(validator, question, result):
        """Print one validation result in the standard demo format."""
        response, conf, disagreed = result
        print(f"Query: {question}")
        print(f"Consensus Response: {response}")
        print(f"Confidence: {conf}")
        print(f"Disagreements: {disagreed}")
        print(f"Model Weights: {validator.model_weights}")

    # Create an instance of the MultiModelConsensusValidation skill
    models = {
        'model_1': mock_model_1,
        'model_2': mock_model_2,
        'model_3': mock_model_3,
        'model_4': mock_model_4,
    }
    # Example of different initial weights
    initial_weights = {'model_1': 0.8, 'model_2': 0.9, 'model_3': 0.5, 'model_4': 0.7}
    consensus_validator = MultiModelConsensusValidationAdvanced(
        models, model_weights=initial_weights, byzantine_tolerance=0.1)

    # Example query, validated with every configured model
    query = "What is the weather like today?"
    _report(consensus_validator, query, consensus_validator.validate(query))

    # Example using specific models
    subset_result = consensus_validator.validate(query, models_to_use=['model_1', 'model_2'])
    print("\nUsing only model_1 and model_2:")
    _report(consensus_validator, query, subset_result)

    # Example of handling errors within the models
    def mock_model_5(query):
        raise ValueError("This model always fails")

    consensus_validator_error = MultiModelConsensusValidationAdvanced(
        {'model_1': mock_model_1, 'model_5': mock_model_5})
    error_result = consensus_validator_error.validate(query)
    print("\nTesting with error model:")
    _report(consensus_validator_error, query, error_result)

    # Example of dynamic model selection
    selected_models = consensus_validator.dynamically_select_models(query, num_models=2)
    print("\nDynamically selected models:", selected_models)

    # Run a few more queries to update model weights
    for extra_query in ("What is the capital of France?", "Tell me a joke."):
        consensus_validator.validate(extra_query)
    print(f"\nUpdated Model Weights after more queries: {consensus_validator.model_weights}")

    # Analyze Disagreements
    _, _, flat_earth_disagreements = consensus_validator.validate("Is the Earth flat?")
    consensus_validator.analyze_disagreements("Is the Earth flat?", flat_earth_disagreements)