# enhanced_consensus_validation.py
import asyncio
import random
import time
from typing import List, Dict, Callable, Optional
from collections import defaultdict
import logging
from prometheus_client import start_http_server, Counter, Gauge, Histogram

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Prometheus metrics
MODEL_PREDICTION_DURATION = Histogram('model_prediction_duration_seconds', 'Time spent predicting by each model', ['model_name'])
CONSENSUS_DURATION = Histogram('consensus_duration_seconds', 'Time spent in consensus algorithms', ['algorithm'])
DISAGREEMENT_DURATION = Histogram('disagreement_duration_seconds', 'Time spent in disagreement resolution')
CONSENSUS_CONFIDENCE = Gauge('consensus_confidence', 'Confidence score of the consensus')
MODEL_INVOCATION_ERRORS = Counter('model_invocation_errors', 'Number of errors during model invocation', ['model_name'])
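# Example PromQL over these metrics (illustrative queries, assuming default scraping):
#   histogram_quantile(0.95, sum by (model_name, le) (rate(model_prediction_duration_seconds_bucket[5m])))
#   rate(model_invocation_errors_total[5m])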

class AIModel:
    """A simulated AI model with failure handling."""
    def __init__(self, name: str, prediction_function: Callable[[str], Dict[str, float]],
                 failure_rate: float = 0.0, timeout: float = 1.0):
        self.name = name
        self.predict = prediction_function
        self.failure_rate = failure_rate
        self.timeout = timeout
        self.circuit_breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=10)

    async def predict_with_retry(self, data: str) -> Optional[Dict[str, float]]:
        """Invokes the model with retry logic and circuit breaker."""
        if self.circuit_breaker.is_open():
            logger.warning(f"Model {self.name} circuit breaker is open.")
            return None

        for attempt in range(3):  # Retry up to 3 times
            try:
                with MODEL_PREDICTION_DURATION.labels(model_name=self.name).time():
                    result = await asyncio.wait_for(self.async_predict(data), timeout=self.timeout)
                self.circuit_breaker.on_success()
                return result
            except asyncio.TimeoutError:
                logger.warning(f"Model {self.name} timed out on attempt {attempt + 1}.")
                self.circuit_breaker.on_failure()
                MODEL_INVOCATION_ERRORS.labels(model_name=self.name).inc()
            except Exception as e:
                logger.error(f"Model {self.name} failed on attempt {attempt + 1}: {e}")
                self.circuit_breaker.on_failure()
                MODEL_INVOCATION_ERRORS.labels(model_name=self.name).inc()
            await asyncio.sleep(0.1 * (2 ** attempt))  # Exponential backoff before the next attempt

        logger.error(f"Model {self.name} failed after multiple retries.")
        return None

    async def async_predict(self, data: str) -> Dict[str, float]:
        """Simulates an asynchronous prediction with optional failure."""
        await asyncio.sleep(random.uniform(0.01, 0.1))  # Simulate processing time

        if random.random() < self.failure_rate:
            raise ValueError(f"Simulated failure for model {self.name}")

        return self.predict(data)

    def __repr__(self):
        return f"AIModel(name='{self.name}')"


class CircuitBreaker:
    """Implements a circuit breaker pattern."""
    def __init__(self, failure_threshold: int, recovery_timeout: int):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failure_count = 0
        self.last_failure_time: Optional[float] = None
        self.state = "CLOSED"  # CLOSED, OPEN, HALF_OPEN

    def on_success(self):
        """Resets the circuit breaker on success."""
        self.failure_count = 0
        self.state = "CLOSED"

    def on_failure(self):
        """Increments failure count and opens the circuit breaker if necessary."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        if self.failure_count >= self.failure_threshold:
            self.state = "OPEN"
            logger.warning("Circuit breaker opened.")

    def is_open(self):
        """Checks if the circuit breaker is open."""
        if self.state == "OPEN":
            if time.time() - (self.last_failure_time or 0) > self.recovery_timeout:
                self.state = "HALF_OPEN"  # Attempt to recover
                logger.info("Circuit breaker half-open.")
                return False  # Allow one trial request through
            return True
        return False
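# State machine implemented above, for reference:
#   CLOSED --(failure_threshold consecutive failures)--> OPEN
#   OPEN --(recovery_timeout elapses)--> HALF_OPEN (one trial request allowed)
#   HALF_OPEN --(success)--> CLOSED;  HALF_OPEN --(failure)--> OPEN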


async def parallel_model_invocation(models: List[AIModel], data: str) -> List[Optional[Dict[str, float]]]:
    """Invokes models in parallel with timeout and error handling."""
    tasks = [model.predict_with_retry(data) for model in models]
    results = await asyncio.gather(*tasks)
    return results
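# Note: asyncio.gather preserves argument order, so results[i] always corresponds
# to models[i] even though the coroutines may finish in any order.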


# Consensus Algorithms

@CONSENSUS_DURATION.labels(algorithm='majority_vote').time()
def majority_vote(model_outputs: List[Optional[Dict[str, float]]]) -> Dict[str, float]:
    """Majority vote: each model votes for its single highest-confidence category."""
    valid_outputs = [output for output in model_outputs if output]  # Filter out None values
    if not valid_outputs:
        return {}

    votes = defaultdict(int)
    for output in valid_outputs:
        votes[max(output, key=output.get)] += 1

    max_count = max(votes.values())
    winning_categories = [category for category, count in votes.items() if count == max_count]

    # Require a unique winner holding a strict majority of the votes.
    if len(winning_categories) > 1 or max_count <= len(valid_outputs) // 2:
        return {}

    winning_category = winning_categories[0]
    supporting = [output[winning_category] for output in valid_outputs if winning_category in output]
    average_confidence = sum(supporting) / len(supporting) if supporting else 0.0

    return {winning_category: average_confidence}
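# Worked example: with top votes [A, A, B] from three models, A gets 2 votes and
# 2 > 3 // 2, so A wins with its confidence averaged over all models that scored it.
# With votes [A, B, C] there is no strict majority and {} is returned.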

@CONSENSUS_DURATION.labels(algorithm='weighted_average').time()
def weighted_average(model_outputs: List[Optional[Dict[str, float]]]) -> Dict[str, float]:
    """Implements a confidence-weighted average consensus algorithm."""
    valid_outputs = [output for output in model_outputs if output]
    if not valid_outputs:
        return {}

    # Each model's confidence acts as its own weight, so the result per category
    # is the self-weighted average sum(c_i^2) / sum(c_i).
    category_sums = defaultdict(float)
    total_weights = defaultdict(float)

    for output in valid_outputs:
        for category, confidence in output.items():
            category_sums[category] += confidence * confidence  # weight * value, with weight == confidence
            total_weights[category] += confidence

    consensus = {}
    for category, sum_value in category_sums.items():
        if total_weights[category] > 0:
            consensus[category] = sum_value / total_weights[category]

    return consensus
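# Worked example: confidences 0.8 and 0.4 for one category give
# (0.8*0.8 + 0.4*0.4) / (0.8 + 0.4) = 0.80 / 1.2 ≈ 0.667, so the self-weighting
# pulls the consensus toward the more confident model (a plain mean would give 0.6).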


@CONSENSUS_DURATION.labels(algorithm='pbft').time()
def pbft_consensus(model_outputs: List[Optional[Dict[str, float]]], num_faulty: int) -> Dict[str, float]:
    """Simulates a simplified PBFT-style consensus (not a full implementation)."""
    # Copy each output so the pruning below never mutates the caller's dicts.
    valid_outputs = [dict(output) for output in model_outputs if output]
    if not valid_outputs:
        return {}

    # Simplified: prune the num_faulty categories with the lowest average
    # confidence, treating them as 'faulty' contributions.
    all_categories = set()
    for output in valid_outputs:
        all_categories.update(output.keys())

    category_confidences = {
        category: [output.get(category, 0.0) for output in valid_outputs]
        for category in all_categories
    }

    for _ in range(min(num_faulty, len(valid_outputs))):
        if not category_confidences:
            break
        lowest_confidence_category = min(
            category_confidences,
            key=lambda c: sum(category_confidences[c]) / len(category_confidences[c]),
        )
        for output in valid_outputs:
            output.pop(lowest_confidence_category, None)  # Remove the pruned category
        del category_confidences[lowest_confidence_category]

    # Apply weighted average consensus on the remaining outputs.
    return weighted_average(valid_outputs)
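# Note: real PBFT tolerates f Byzantine replicas only with n >= 3f + 1 nodes and a
# three-phase (pre-prepare/prepare/commit) message exchange; the pruning above is a
# loose analogy, not a Byzantine-fault-tolerant protocol.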

@CONSENSUS_DURATION.labels(algorithm='raft_adaptation').time()
def raft_consensus_adaptation(model_outputs: List[Optional[Dict[str, float]]]) -> Dict[str, float]:
    """Adapts Raft concepts (leader election plus replication) to model consensus."""
    valid_outputs = [output for output in model_outputs if output]
    if not valid_outputs:
        return {}

    # Elect a 'leader' (model with highest average confidence)
    model_confidences = []
    for output in valid_outputs:
        total_confidence = sum(output.values())
        avg_confidence = total_confidence / len(output) if output else 0
        model_confidences.append(avg_confidence)

    if not model_confidences:
        return {}

    leader_index = model_confidences.index(max(model_confidences))
    leader_output = valid_outputs[leader_index]

    # Replicate leader's decision (weighted by other models' agreement)
    consensus = {}
    for category, leader_confidence in leader_output.items():
        total_agreement = 0.0
        for output in valid_outputs:
            if category in output:
                total_agreement += output[category]
        average_agreement = total_agreement / len(valid_outputs) if valid_outputs else 0
        consensus[category] = (leader_confidence + average_agreement) / 2

    return consensus
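# Worked example: if the elected leader scores a category at 0.7 and the mean
# confidence across all models for that category is 0.5, the consensus value is
# (0.7 + 0.5) / 2 = 0.6 -- the leader decides, tempered by follower agreement.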

@CONSENSUS_DURATION.labels(algorithm='hierarchical').time()
def hierarchical_consensus(model_outputs: List[Optional[Dict[str, float]]]) -> Dict[str, float]:
    """Hierarchical consensus: majority vote first, weighted average as a fallback."""
    valid_outputs = [output for output in model_outputs if output]
    if not valid_outputs:
        return {}

    # Level 1: Majority vote to find a clearly agreed-upon category.
    majority_result = majority_vote(valid_outputs)
    if majority_result:
        return majority_result

    # Level 2: No clear majority, so fall back to weighted average.
    return weighted_average(valid_outputs)


# Disagreement Resolution

@DISAGREEMENT_DURATION.time()
def detect_disagreements(model_outputs: List[Optional[Dict[str, float]]], threshold: float = 0.2) -> List[str]:
    """Detects disagreements between AI models based on a confidence threshold."""
    valid_outputs = [output for output in model_outputs if output]
    if not valid_outputs:
        return []

    categories = set()
    for output in valid_outputs:
        categories.update(output.keys())

    disagreements = []
    for category in categories:
        confidences = [output.get(category, 0.0) for output in valid_outputs]
        max_confidence = max(confidences)
        min_confidence = min(confidences)

        if max_confidence - min_confidence > threshold:
            disagreements.append(category)

    return disagreements
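# Worked example: confidences [0.9, 0.45, 0.4] for a category span 0.9 - 0.4 = 0.5,
# which exceeds the default threshold of 0.2, so the category is flagged.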

def semantic_similarity_analysis(model_outputs: List[Optional[Dict[str, float]]]) -> Dict[str, float]:
    """Placeholder for semantic similarity analysis."""
    # In a real implementation, this would involve:
    # 1. Encoding the categories into vector representations (e.g., using Word2Vec, SentenceTransformers).
    # 2. Calculating cosine similarity between the category vectors.
    # 3. Grouping categories with high similarity.
    # 4. Adjusting confidence scores based on similarity clusters.
    logger.info("Semantic similarity analysis stub.")
    return {}
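# A minimal sketch of step 2 above, assuming a toy character-bigram encoding in
# place of a real embedding model; _toy_vector and _cosine_similarity are
# illustrative helpers, not part of any library.
def _toy_vector(text: str) -> Dict[str, int]:
    """Encodes text as character-bigram counts (a stand-in for real embeddings)."""
    counts: Dict[str, int] = defaultdict(int)
    for i in range(len(text) - 1):
        counts[text[i:i + 2]] += 1
    return counts

def _cosine_similarity(a: Dict[str, int], b: Dict[str, int]) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[key] * b[key] for key in a.keys() & b.keys())
    norm_a = sum(v * v for v in a.values()) ** 0.5
    norm_b = sum(v * v for v in b.values()) ** 0.5
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0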

def conflict_detection_and_categorization(model_outputs: List[Optional[Dict[str, float]]]) -> Dict[str, List[str]]:
    """Detects and categorizes conflicts between model outputs."""
    valid_outputs = [output for output in model_outputs if output]
    if not valid_outputs:
        return {}

    conflicts = defaultdict(list)
    categories = set()
    for output in valid_outputs:
        categories.update(output.keys())

    for category in categories:
        confidences = [output.get(category, 0.0) for output in valid_outputs]
        max_confidence = max(confidences)
        min_confidence = min(confidences)

        # Test the more specific pattern first: opposing extremes (max > 0.8 and
        # min < 0.2) always imply a spread above 0.5, so they must be checked
        # before the generic high-variance case.
        if max_confidence > 0.8 and min_confidence < 0.2:
            conflicts["strong_opposing_views"].append(category)
        elif max_confidence - min_confidence > 0.5:  # Example threshold
            conflicts["high_variance"].append(category)

    return conflicts

def human_escalation_triggers(conflicts: Dict[str, List[str]], threshold: int = 2) -> bool:
    """Triggers human escalation based on conflict severity."""
    total_conflicts = sum(len(categories) for categories in conflicts.values())
    if total_conflicts > threshold:
        logger.warning("Human escalation triggered due to high conflict level.")
        return True
    return False

def explanation_generation_for_disagreements(models: List[AIModel],
                                             model_outputs: List[Optional[Dict[str, float]]],
                                             disagreements: List[str]) -> Dict[str, str]:
    """Generates explanations for disagreements, listing each model's confidence."""
    explanations = {}
    for disagreement in disagreements:
        model_confidences = {
            model.name: output.get(disagreement, 0.0)
            for model, output in zip(models, model_outputs)
            if output
        }
        explanations[disagreement] = f"Disagreement detected. Model confidences: {model_confidences}"
    return explanations

# Consensus Confidence
def consensus_confidence(consensus_output: Dict[str, float], model_outputs: List[Optional[Dict[str, float]]]) -> float:
    """Calculates a consensus confidence score based on model agreement."""
    valid_outputs = [output for output in model_outputs if output]
    if not consensus_output or not valid_outputs:
        return 0.0

    # Score the consensus's strongest category rather than an arbitrary first key.
    category = max(consensus_output, key=consensus_output.get)
    num_agreeing_models = 0
    total_confidence = 0.0

    for output in valid_outputs:
        if category in output:
            num_agreeing_models += 1
            total_confidence += output[category]

    if num_agreeing_models == 0:
        return 0.0

    average_confidence_in_agreement = total_confidence / num_agreeing_models
    agreement_ratio = num_agreeing_models / len(valid_outputs)

    # Combine agreement ratio and average confidence into one score.
    confidence = (agreement_ratio + average_confidence_in_agreement) / 2
    CONSENSUS_CONFIDENCE.set(confidence)  # Export to Prometheus
    return confidence
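# Worked example: if 3 of 4 models include the winning category with confidences
# 0.7, 0.6, and 0.5, the agreement ratio is 3/4 = 0.75 and the mean confidence is
# 0.6, giving a consensus confidence of (0.75 + 0.6) / 2 = 0.675.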


# Dummy Model Creation
def create_dummy_model(name: str, bias: Optional[Dict[str, float]] = None, noise_level: float = 0.1, failure_rate: float = 0.0) -> AIModel:
    """Creates a dummy AI model for testing purposes."""
    if bias is None:
        bias = {}

    def prediction_function(data: str) -> Dict[str, float]:
        """Dummy prediction function."""
        predictions = {}
        categories = ["category_a", "category_b", "category_c"]

        for category in categories:
            base_confidence = bias.get(category, 0.2)
            noise = random.uniform(-noise_level, noise_level)
            confidence = max(0.0, min(1.0, base_confidence + noise))
            predictions[category] = confidence
        return predictions

    return AIModel(name, prediction_function, failure_rate=failure_rate)


async def main():
    """Main function to demonstrate the multi-model consensus validation process."""
    # 1. Create multiple AI models
    model1 = create_dummy_model("Model_A", bias={"category_a": 0.7, "category_b": 0.3}, failure_rate=0.1)
    model2 = create_dummy_model("Model_B", bias={"category_a": 0.6, "category_c": 0.4}, failure_rate=0.05)
    model3 = create_dummy_model("Model_C", bias={"category_b": 0.8}, failure_rate=0.0)
    model4 = create_dummy_model("Model_D", bias={"category_a": 0.5, "category_b": 0.5}, failure_rate=0.2)

    models = [model1, model2, model3, model4]

    # 2. Orchestrate model calls with input data
    input_data = "Example data for classification."
    model_outputs = await parallel_model_invocation(models, input_data)

    print("Model Outputs:")
    for i, output in enumerate(model_outputs):
        print(f"{models[i].name}: {output}")

    # 3. Implement voting/consensus algorithms
    majority_vote_result = majority_vote(model_outputs)
    weighted_average_result = weighted_average(model_outputs)
    pbft_result = pbft_consensus(model_outputs, num_faulty=1)
    raft_result = raft_consensus_adaptation(model_outputs)
    hierarchical_result = hierarchical_consensus(model_outputs)

    print("\nMajority Vote Consensus:", majority_vote_result)
    print("Weighted Average Consensus:", weighted_average_result)
    print("PBFT Consensus:", pbft_result)
    print("Raft Consensus Adaptation:", raft_result)
    print("Hierarchical Consensus:", hierarchical_result)


    # 4. Detect and flag disagreements
    disagreements = detect_disagreements(model_outputs)
    print("\nDetected Disagreements:", disagreements)

    # 5. Provide consensus confidence scores
    majority_confidence = consensus_confidence(majority_vote_result, model_outputs)
    weighted_confidence = consensus_confidence(weighted_average_result, model_outputs)
    pbft_confidence = consensus_confidence(pbft_result, model_outputs)
    raft_confidence = consensus_confidence(raft_result, model_outputs)
    hierarchical_confidence = consensus_confidence(hierarchical_result, model_outputs)

    print("\nMajority Vote Confidence:", majority_confidence)
    print("Weighted Average Confidence:", weighted_confidence)
    print("PBFT Confidence:", pbft_confidence)
    print("Raft Confidence:", raft_confidence)
    print("Hierarchical Confidence:", hierarchical_confidence)

    # 6. Disagreement Resolution
    conflicts = conflict_detection_and_categorization(model_outputs)
    print("\nConflicts:", conflicts)

    escalation_needed = human_escalation_triggers(conflicts)
    print("\nHuman Escalation Needed:", escalation_needed)

    explanations = explanation_generation_for_disagreements(models, model_outputs, disagreements)
    print("\nExplanations for Disagreements:", explanations)

if __name__ == "__main__":
    # Start Prometheus HTTP server
    start_http_server(8000)  # Expose metrics on port 8000
    asyncio.run(main())
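    # Note: the metrics endpoint only lives as long as this process; a long-running
    # service would keep the process alive between evaluations so Prometheus can
    # scrape http://localhost:8000/metrics.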
