# patent_5_consensus_validation.py
import random
from typing import Callable, Dict, List, Optional

class AIModel:
    """
    Lightweight wrapper that pairs a model name with its prediction callable.

    ``model.predict(data)`` is simply the callable supplied at construction;
    it is expected to map category names to confidence scores.
    """

    def __init__(self, name: str, prediction_function: Callable[[str], Dict[str, float]]):
        """
        Store the model's identity and its prediction callable.

        Args:
            name: Human-readable identifier, shown by ``repr``.
            prediction_function: Callable that takes the input data (a string)
                                 and returns a mapping of category -> confidence,
                                 e.g. {'category_a': 0.8, 'category_b': 0.2}.
        """
        # Exposed as a plain attribute so callers invoke ``model.predict(data)``.
        self.predict = prediction_function
        self.name = name

    def __repr__(self):
        return "AIModel(name='{}')".format(self.name)


def majority_vote(model_outputs: List[Dict[str, float]]) -> Dict[str, float]:
    """
    Implements a majority vote consensus algorithm.

    Each model casts exactly one vote: for the category it scores highest
    (ties inside a single model resolve to the first-listed category). A
    category wins only with a strict majority, i.e. more than half of
    ``len(model_outputs)`` votes.

    Args:
        model_outputs: A list of dictionaries, where each dictionary represents
                       the output of an AI model (category -> confidence).
                       An empty dictionary casts no vote but still counts
                       toward the number of models.

    Returns:
        A single-entry dictionary ``{winning_category: average_confidence}``.
        The confidence is averaged over every model output that reports the
        winning category. Returns an empty dictionary if no category obtains
        a strict majority, including ties and empty input.
    """
    votes: Dict[str, int] = {}
    for output in model_outputs:
        if not output:
            continue  # a model with no predictions abstains
        top_category = max(output, key=output.get)
        votes[top_category] = votes.get(top_category, 0) + 1

    if not votes:
        return {}

    winning_category = max(votes, key=votes.get)
    # Strict majority required. Two categories cannot both exceed half the
    # models, so a tie for first place also ends up here.
    if votes[winning_category] <= len(model_outputs) // 2:
        return {}

    # Non-empty: at least the models that voted for the winner report it.
    confidences = [
        output[winning_category]
        for output in model_outputs
        if winning_category in output
    ]
    return {winning_category: sum(confidences) / len(confidences)}


def weighted_average(model_outputs: List[Dict[str, float]]) -> Dict[str, float]:
    """
    Combine model outputs with a confidence-weighted mean per category.

    Each confidence a model reports for a category is weighted by itself, so
    a category's score is sum(c * c) / sum(c) over the models that report it.
    Models that omit a category contribute nothing to that category.

    Args:
        model_outputs: A list of dictionaries, where each dictionary represents
                       the output of an AI model (category -> confidence).

    Returns:
        A mapping of category -> weighted confidence, ordered by first
        appearance. Categories whose confidences sum to zero or less are omitted.
    """
    # label -> [running sum of squared confidences, running sum of confidences]
    running: Dict[str, List[float]] = {}
    for prediction in model_outputs:
        for label, score in prediction.items():
            acc = running.setdefault(label, [0, 0])
            acc[0] += score * score
            acc[1] += score

    return {
        label: squares / weight
        for label, (squares, weight) in running.items()
        if weight > 0
    }


def detect_disagreements(model_outputs: List[Dict[str, float]], threshold: float = 0.2) -> List[str]:
    """
    Detects disagreements between AI models based on a confidence threshold.

    A category is flagged when the spread (max - min) of the confidences the
    models assign to it is strictly greater than ``threshold``. A model that
    does not report a category is treated as giving it 0.0.

    Args:
        model_outputs: A list of dictionaries, where each dictionary represents
                       the output of an AI model (category -> confidence).
        threshold: The maximum tolerated confidence spread.

    Returns:
        The flagged categories, in order of first appearance across
        ``model_outputs``. The order is deterministic, unlike set iteration
        order for strings.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    categories = dict.fromkeys(
        category for output in model_outputs for category in output
    )

    disagreements = []
    for category in categories:
        confidences = [output.get(category, 0.0) for output in model_outputs]
        if max(confidences) - min(confidences) > threshold:
            disagreements.append(category)
    return disagreements


def consensus_confidence(consensus_output: Dict[str, float], model_outputs: List[Dict[str, float]]) -> float:
    """
    Calculates a consensus confidence score based on model agreement.

    The consensus's top category (highest score; first on ties) is scored as
    the mean of two quantities:
    - the fraction of models that report that category;
    - the average confidence those models give it.

    Args:
        consensus_output: The output of a consensus algorithm. It may hold a
                          single category (majority vote) or several
                          (weighted average).
        model_outputs: A list of dictionaries, where each dictionary represents
                       the output of an AI model.

    Returns:
        A consensus confidence score. It lies between 0 and 1 provided the
        model confidences do. Returns 0.0 when there is no consensus or no
        model reports the top category.
    """
    if not consensus_output:
        return 0.0  # No consensus

    # Use the highest-scoring category; the first-inserted key is arbitrary
    # for multi-category outputs.
    category = max(consensus_output, key=consensus_output.get)
    supporting = [output[category] for output in model_outputs if category in output]

    if not supporting:
        return 0.0

    average_confidence_in_agreement = sum(supporting) / len(supporting)
    agreement_ratio = len(supporting) / len(model_outputs)

    # Combine agreement ratio and average confidence
    return (agreement_ratio + average_confidence_in_agreement) / 2


def create_dummy_model(
    name: str,
    bias: Optional[Dict[str, float]] = None,
    noise_level: float = 0.1,
    *,
    categories: Optional[List[str]] = None,
    rng: Optional[random.Random] = None,
) -> AIModel:
    """
    Creates a dummy AI model for testing purposes.

    Args:
        name: The name of the model.
        bias: Base confidence per category. Categories not listed default to
              0.2; keys that are not in ``categories`` are ignored. A copy is
              taken, so later changes to the caller's dict do not affect the
              model. If None, every category uses the 0.2 default.
        noise_level: Half-width of the uniform noise added to each base
                     confidence.
        categories: Categories the model predicts. Defaults to
                    ["category_a", "category_b", "category_c"].
        rng: Random source used for the noise, for reproducible runs. Defaults
             to the module-level ``random`` functions (shared global state).

    Returns:
        An AIModel instance. Each ``predict`` call returns one confidence per
        category, clamped to [0, 1].
    """
    # Snapshot the inputs so the returned model's behaviour is fixed at
    # creation.
    frozen_bias = dict(bias) if bias else {}
    if categories is not None:
        labels = tuple(categories)
    else:
        labels = ("category_a", "category_b", "category_c")
    source = rng if rng is not None else random

    def prediction_function(data: str) -> Dict[str, float]:
        """
        Dummy prediction function. It ignores ``data`` and returns biased,
        noisy confidences for each category.
        """
        predictions = {}
        for category in labels:
            base_confidence = frozen_bias.get(category, 0.2)  # Default confidence if no bias
            noise = source.uniform(-noise_level, noise_level)
            # Clamp so the confidence stays within [0, 1]
            predictions[category] = max(0.0, min(1.0, base_confidence + noise))
        return predictions

    return AIModel(name, prediction_function)


def main():
    """
    Run an end-to-end demo of multi-model consensus validation.

    The demo:
    - builds three biased dummy models and queries each with the same input;
    - prints the raw outputs;
    - prints the majority-vote and weighted-average consensus results;
    - prints the categories the models disagree on;
    - prints a confidence score for each consensus result.
    """
    # 1. Build the model ensemble
    models = [
        create_dummy_model("Model_A", bias={"category_a": 0.7, "category_b": 0.3}),
        create_dummy_model("Model_B", bias={"category_a": 0.6, "category_c": 0.4}),
        create_dummy_model("Model_C", bias={"category_b": 0.8}),
    ]

    # 2. Query every model with the same input
    input_data = "Example data for classification."
    model_outputs = [model.predict(input_data) for model in models]

    print("Model Outputs:")
    for model, output in zip(models, model_outputs):
        print(f"{model.name}: {output}")

    # 3. Combine the outputs with each consensus strategy
    majority_vote_result = majority_vote(model_outputs)
    weighted_average_result = weighted_average(model_outputs)

    print("\nMajority Vote Consensus:", majority_vote_result)
    print("Weighted Average Consensus:", weighted_average_result)

    # 4. Report categories the models disagree on
    flagged = detect_disagreements(model_outputs)
    print("\nDetected Disagreements:", flagged)

    # 5. Score how much the models back each consensus result
    majority_confidence = consensus_confidence(majority_vote_result, model_outputs)
    weighted_confidence = consensus_confidence(weighted_average_result, model_outputs)

    print("\nMajority Vote Confidence:", majority_confidence)
    print("Weighted Average Confidence:", weighted_confidence)


if __name__ == "__main__":
    main()
