import numpy as np

class ConfidenceCalibrator:
    """
    Calibrates AIVA's confidence scores against observed accuracy.

    Confidence scores are grouped into equal-width bins over [0, 1]; for each
    bin the mean confidence is compared with the fraction of correct
    predictions. The sample-weighted absolute gap is the Expected Calibration
    Error (ECE); the sample-weighted signed gap indicates overall
    over-confidence (positive) or under-confidence (negative).

    Attributes:
        num_bins (int): Number of equal-width confidence bins.
        confidence_threshold (float): Tolerated gap between confidence and accuracy.
        bins (np.ndarray): The ``num_bins + 1`` bin edges from 0 to 1 inclusive.
        calibration_report (dict): Report produced by the most recent ``analyze``
            call (empty until ``analyze`` has been called).
    """

    def __init__(self, num_bins: int = 10, confidence_threshold: float = 0.05):
        """
        Initializes the ConfidenceCalibrator.

        Args:
            num_bins (int): The number of bins to use for calculating Expected Calibration Error (ECE).
                            More bins provide finer granularity but require more data.
            confidence_threshold (float): The allowed deviation (e.g., 0.05 for 5%) for
                                          confidence to be considered "calibrated". Must satisfy
                                          0 <= confidence_threshold < 1.
                                          Used for ECE and overall over/underconfidence detection.

        Raises:
            ValueError: If num_bins is not positive or confidence_threshold is outside [0, 1).
        """
        if num_bins <= 0:
            raise ValueError("num_bins must be a positive integer.")
        if not (0 <= confidence_threshold < 1):
            raise ValueError("confidence_threshold must be between 0 and 1.")

        self.num_bins = num_bins
        self.confidence_threshold = confidence_threshold
        # Bin edges, including both 0 and 1.
        self.bins = np.linspace(0, 1, num_bins + 1)
        self.calibration_report = {}

    @staticmethod
    def _is_empty(values) -> bool:
        """Return True for None or a zero-length sequence (safe for NumPy arrays)."""
        return values is None or len(values) == 0

    def analyze(self, confidences: list, outcomes: list) -> dict:
        """
        Analyzes a set of confidence scores and their corresponding binary outcomes.

        Args:
            confidences (list): Confidence scores (floats between 0 and 1). Lists, tuples
                                and 1-D NumPy arrays are accepted.
            outcomes (list): Binary outcomes (0 or 1), where 1 means the
                             prediction was correct, and 0 means it was incorrect.

        Returns:
            dict: A calibration report detailing ECE, over/underconfidence,
                  and per-bin analysis. All values are plain Python types
                  (bool/int/float/str/None), so the report is JSON-serializable.
                  If either input is empty, a "no data" report is returned instead.

        Raises:
            ValueError: If the inputs differ in length, a confidence lies outside
                        [0, 1] (NaN included), or an outcome is not exactly 0 or 1.
        """
        if self._is_empty(confidences) or self._is_empty(outcomes):
            self.calibration_report = {
                "status": "No data provided for analysis. Cannot calibrate.",
                "is_calibrated_within_threshold": False,
                "overall_ece": None,
                "confidence_threshold_used": self.confidence_threshold,
                "is_overconfident": False,
                "is_underconfident": False,
                "weighted_confidence_accuracy_difference": None,
                "num_bins_used": self.num_bins,
                "total_samples_analyzed": 0,
                "bin_details": [],
            }
            return self.calibration_report

        if len(confidences) != len(outcomes):
            raise ValueError("The length of confidences and outcomes lists must be equal.")

        # float64 keeps values such as 0.9 exact enough that they are compared
        # consistently against the float64 bin edges and averaged without the
        # ~1e-8 error a float32 round-trip would introduce.
        confidences_arr = np.asarray(confidences, dtype=np.float64)
        # Validate outcomes BEFORE any integer cast: casting first would silently
        # truncate 0.5 -> 0 (or wrap large ints) and let bad data through.
        outcomes_arr = np.asarray(outcomes, dtype=np.float64)

        # NaN fails both comparisons, so it is rejected here as well.
        if not ((confidences_arr >= 0) & (confidences_arr <= 1)).all():
            raise ValueError("Confidence scores must be between 0 and 1.")
        if not np.isin(outcomes_arr, [0, 1]).all():
            raise ValueError("Outcomes must be binary (0 or 1).")

        total_samples = len(confidences_arr)
        last_bin = self.num_bins - 1
        bin_details = []
        ece_sum = 0.0
        # Positive for overconfidence, negative for underconfidence.
        weighted_diff_sum = 0.0

        for i in range(self.num_bins):
            bin_lower = self.bins[i]
            bin_upper = self.bins[i + 1]

            # Bins are half-open [lower, upper) so a boundary value is counted
            # once; only the last bin is closed so that 1.0 is included.
            if i == last_bin:
                in_bin = (confidences_arr >= bin_lower) & (confidences_arr <= bin_upper)
            else:
                in_bin = (confidences_arr >= bin_lower) & (confidences_arr < bin_upper)

            num_samples_in_bin = int(np.count_nonzero(in_bin))
            bin_range = f"[{bin_lower:.2f}, {bin_upper:.2f}]"

            if num_samples_in_bin == 0:
                bin_details.append({
                    "bin_range": bin_range,
                    "num_samples": 0,
                    "avg_confidence": None,
                    "avg_accuracy": None,
                    "miscalibration_difference": None,
                    "is_overconfident_in_bin": False,
                    "is_underconfident_in_bin": False,
                })
                continue

            avg_confidence = float(confidences_arr[in_bin].mean())
            # Mean of binary outcomes is the accuracy within the bin.
            avg_accuracy = float(outcomes_arr[in_bin].mean())
            # Confidence - accuracy: positive means overconfident.
            miscalibration_diff = avg_confidence - avg_accuracy

            weight = num_samples_in_bin / total_samples
            ece_sum += weight * abs(miscalibration_diff)
            weighted_diff_sum += weight * miscalibration_diff

            bin_details.append({
                "bin_range": bin_range,
                "num_samples": num_samples_in_bin,
                "avg_confidence": avg_confidence,
                "avg_accuracy": avg_accuracy,
                "miscalibration_difference": miscalibration_diff,
                # bool(...) guarantees plain Python booleans in the report.
                "is_overconfident_in_bin": bool(miscalibration_diff > self.confidence_threshold),
                "is_underconfident_in_bin": bool(miscalibration_diff < -self.confidence_threshold),
            })

        overall_ece = float(ece_sum)
        weighted_diff = float(weighted_diff_sum)

        self.calibration_report = {
            "status": "Analysis complete.",
            "overall_ece": overall_ece,
            "is_calibrated_within_threshold": bool(overall_ece <= self.confidence_threshold),
            "confidence_threshold_used": self.confidence_threshold,
            # Overall direction is judged on the signed, sample-weighted gap.
            "is_overconfident": bool(weighted_diff > self.confidence_threshold),
            "is_underconfident": bool(weighted_diff < -self.confidence_threshold),
            "weighted_confidence_accuracy_difference": weighted_diff,
            "num_bins_used": self.num_bins,
            "total_samples_analyzed": total_samples,
            "bin_details": bin_details,
        }
        return self.calibration_report

    def get_calibration_report(self) -> dict:
        """
        Returns the last generated calibration report.

        Returns:
            dict: The report from the most recent ``analyze`` call, or an empty
                  dict if ``analyze`` has not been called yet. The same dict
                  object is returned (not a copy).
        """
        return self.calibration_report
