import random
import statistics
import time
from enum import Enum
from typing import Dict, Any, Callable, List, Union
import scipy.stats as stats  # For statistical significance testing
import json  # For experiment definition DSL

# Define Experiment Types
class ExperimentType(Enum):
    """Closed set of experiment categories understood by the framework."""

    SKILL_VARIANTS = "skill_variants"            # compare alternative skill implementations
    THRESHOLD_SETTINGS = "threshold_settings"    # compare numeric threshold choices
    PROMPT_TEMPLATES = "prompt_templates"        # compare prompt wording/templates
    RETRIEVAL_STRATEGIES = "retrieval_strategies"  # compare retrieval back-ends

# Data structure for experiment definition
class ExperimentDefinition:
    """Describes one A/B test: its variants, traffic split, and success metric."""

    # Tolerance used when validating that traffic percentages sum to 1.0.
    # Exact float equality would reject perfectly valid splits such as ten
    # variants at 0.1 each (sum([0.1] * 10) == 0.9999999999999999).
    _SPLIT_TOLERANCE = 1e-9

    def __init__(self,
                 name: str,
                 experiment_type: "ExperimentType",
                 variants: Dict[str, Any],  # Key: Variant Name, Value: Variant Configuration
                 traffic_split: Dict[str, float], # Key: Variant Name, Value: Percentage of Traffic (0.0 to 1.0)
                 metric_function: Callable[[Any], float], # Function to calculate the metric from a result
                 description: str = "",
                 minimum_runs: int = 30,  # Minimum runs per variant for statistical significance
                 confidence_level: float = 0.95 # Confidence level for winner detection
                 ):
        """
        Defines an A/B test experiment.

        Args:
            name: A unique name for the experiment.
            experiment_type: The type of experiment (e.g., Skill Variants, Threshold Settings).
            variants: A dictionary defining the different variants to test.  Each variant
                      should have a unique name. The value associated with each name is the
                      configuration or setting for that variant.
            traffic_split: A dictionary specifying how traffic should be split between the variants.
                           The keys are the variant names, and the values are the percentage of
                           traffic (as a float between 0.0 and 1.0) that should be directed to that variant.
                           The values must sum to 1.0 (within floating-point tolerance), and the
                           key set must match the key set of ``variants``.
            metric_function: A function that takes the result of running a variant and returns a numerical
                             metric value.  This metric is used to compare the performance of the variants.
            description: Optional description of the experiment.
            minimum_runs: Minimum number of runs per variant before statistical significance is checked.
            confidence_level: Confidence level for statistical significance testing
                              (e.g., 0.95 for 95% confidence; the significance
                              threshold alpha is 1 - confidence_level).

        Raises:
            ValueError: If the traffic split does not sum to 1.0, contains a
                        negative percentage, or names a different set of
                        variants than ``variants`` does.
        """
        total = sum(traffic_split.values())
        # Tolerant comparison: see _SPLIT_TOLERANCE above.
        if abs(total - 1.0) > self._SPLIT_TOLERANCE:
            raise ValueError("Traffic split percentages must sum to 1.0")
        if any(p < 0 for p in traffic_split.values()):
            raise ValueError("Traffic split percentages must be non-negative")
        # A mismatch here would surface much later as a KeyError (or a run loop
        # that never satisfies minimum_runs), so fail fast at construction.
        if set(traffic_split) != set(variants):
            raise ValueError("traffic_split keys must match variants keys")

        self.name = name
        self.experiment_type = experiment_type
        self.variants = variants
        self.traffic_split = traffic_split
        self.metric_function = metric_function
        self.description = description
        self.minimum_runs = minimum_runs
        self.confidence_level = confidence_level

    def __repr__(self):
        return f"ExperimentDefinition(name='{self.name}', type={self.experiment_type}, variants={list(self.variants.keys())}, traffic_split={self.traffic_split})"

    @classmethod
    def from_dsl(cls, dsl_string: str, metric_function: Callable[[Any], float]) -> 'ExperimentDefinition':
        """
        Creates an ExperimentDefinition from a DSL (Domain Specific Language) string (JSON).

        Args:
            dsl_string: A JSON string representing the experiment definition.
                        Required keys: "name", "experiment_type", "variants",
                        "traffic_split".  Optional: "description",
                        "minimum_runs", "confidence_level".
            metric_function: The metric function to use for the experiment.

        Returns:
            An ExperimentDefinition object.

        Raises:
            ValueError: If the JSON is malformed, a required key is missing,
                        or any value fails validation.
        """
        try:
            data = json.loads(dsl_string)
            name = data["name"]
            experiment_type = ExperimentType(data["experiment_type"])
            variants = data["variants"]
            traffic_split = data["traffic_split"]
            description = data.get("description", "")  # Optional description
            minimum_runs = data.get("minimum_runs", 30)
            confidence_level = data.get("confidence_level", 0.95)

            return cls(
                name=name,
                experiment_type=experiment_type,
                variants=variants,
                traffic_split=traffic_split,
                metric_function=metric_function,
                description=description,
                minimum_runs=minimum_runs,
                confidence_level=confidence_level
            )
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            # Re-raise everything (including constructor validation errors) as
            # a single "bad DSL" error so callers have one failure mode.
            raise ValueError(f"Invalid DSL format: {e}")


# Central class for running experiments
class ExperimentRunner:
    """Runs A/B experiments, analyzes their metrics, and optionally rolls out winners.

    NOTE(review): metric data is keyed by variant name only, so variant names
    must be unique across experiments run through the same runner instance —
    confirm callers respect this before sharing a runner.
    """

    def __init__(self):
        # Collected metric samples per variant name.
        self.experiment_data: Dict[str, List[float]] = {}
        # Number of successful runs per variant name.
        self.experiment_runs: Dict[str, int] = {}
        # Per-experiment flag: True once the winning variant has been rolled out.
        self.rollout_complete: Dict[str, bool] = {}
        # Latest per-experiment summary, consumed by get_dashboard_data().
        self.dashboard_data: Dict[str, Any] = {}

    def run_experiment(self, experiment_definition: "ExperimentDefinition", execution_function: Callable[[Any], Any], auto_rollout: bool = False) -> None:
        """
        Runs the A/B test experiment until every variant reaches minimum_runs.

        Args:
            experiment_definition: An ExperimentDefinition object that defines the experiment.
            execution_function: A function that takes a variant's configuration as input and
                                returns the result of running that variant.
            auto_rollout: Whether to automatically roll out the winning variant after the experiment.

        NOTE(review): a variant whose execution_function raises on every call
        never accumulates runs, so this loop would not terminate — confirm
        upstream that failures are transient.
        """
        print(f"Starting experiment: {experiment_definition.name}")
        self.rollout_complete[experiment_definition.name] = False

        # Reset per-variant bookkeeping for this experiment.
        for variant_name in experiment_definition.variants:
            self.experiment_data[variant_name] = []
            self.experiment_runs[variant_name] = 0

        # Sample variants according to the traffic split until each variant
        # has at least minimum_runs successful executions.
        while any(self.experiment_runs[variant] < experiment_definition.minimum_runs for variant in experiment_definition.variants):
            chosen_variant = self._choose_variant(experiment_definition.traffic_split)
            print(f"Running variant: {chosen_variant}")

            try:
                variant_config = experiment_definition.variants[chosen_variant]
                execution_result = execution_function(variant_config)
                metric_value = experiment_definition.metric_function(execution_result)

                # Record the sample; only successful runs count toward minimum_runs.
                self.experiment_data[chosen_variant].append(metric_value)
                self.experiment_runs[chosen_variant] += 1

                print(f"Variant {chosen_variant} - Metric: {metric_value}")

            except Exception as e:
                # Best-effort: log the failure and keep sampling.
                print(f"Error running variant {chosen_variant}: {e}")

        print(f"Experiment {experiment_definition.name} complete.  Minimum runs achieved.")

        # Summarize, publish to the dashboard, then look for a winner.
        results = self.analyze_results()
        print("Experiment Results:")
        print(results)
        self.update_dashboard_data(experiment_definition, results)

        winning_variant = self.select_winner(results, experiment_definition.confidence_level)

        if auto_rollout and winning_variant:
            print("Auto-rollout enabled.")
            # Use the default rollout function since none was provided.
            self.rollout_variant(experiment_definition, winning_variant, self._default_rollout_function)
            self.rollout_complete[experiment_definition.name] = True
        else:
            print("Auto-rollout disabled or no winner detected.")

    def _choose_variant(self, traffic_split: Dict[str, float]) -> str:
        """
        Chooses a variant name at random, weighted by the traffic split.
        """
        rand = random.random()  # Uniform in [0.0, 1.0)
        cumulative_probability = 0.0
        # Walk the cumulative distribution until the random draw falls inside
        # a variant's bucket.
        for variant, percentage in traffic_split.items():
            cumulative_probability += percentage
            if rand < cumulative_probability:
                return variant
        # Fallback for float rounding: the cumulative sum can end slightly
        # below 1.0, leaving rand uncovered — attribute it to the last variant.
        return list(traffic_split.keys())[-1]

    def analyze_results(self) -> Dict[str, Dict[str, Union[float, int]]]:
        """
        Computes per-variant summary statistics (mean, std dev, run count).

        Variants with no recorded data get None statistics and 0 runs.
        """
        results = {}
        for variant, data in self.experiment_data.items():
            if not data:
                results[variant] = {"mean": None, "std_dev": None, "runs": 0}
            else:
                results[variant] = {
                    "mean": statistics.mean(data),
                    # stdev needs >= 2 samples; report 0.0 for a single sample.
                    "std_dev": statistics.stdev(data) if len(data) > 1 else 0.0,
                    "runs": self.experiment_runs[variant]
                }
        return results

    def select_winner(self, results: Dict[str, Dict[str, Union[float, int]]], confidence_level: float = 0.95) -> Union[str, None]:
        """
        Selects the winning variant via an independent-samples t-test against
        the control variant (any variant whose name contains "control").

        Args:
            results: The results of the analysis (output from analyze_results).
            confidence_level: The desired confidence level; alpha = 1 - confidence_level.

        Returns:
            The name of the winning variant, or None if no winner can be determined.
        """
        control_variant = None
        best_variant = None
        best_mean = float('-inf')

        # Identify the control variant (if any) and the variant with the best mean.
        for variant, data in results.items():
            if "control" in variant.lower():  # Simple check for "control" in the name
                control_variant = variant
            if data["mean"] is not None and data["mean"] > best_mean:
                best_mean = data["mean"]
                best_variant = variant

        # No variant has any data at all: nothing to select.
        if best_variant is None:
            print("No statistically significant winner found.")
            return None

        if not control_variant:
            print("No control variant found.  Cannot perform statistical significance test.")
            return best_variant  # Return the best variant based on mean alone

        # The control itself has the best mean: no challenger beat it, and a
        # t-test of a sample against itself is meaningless.
        if best_variant == control_variant:
            print("No statistically significant winner found.")
            return None

        # Perform t-test since a distinct control variant exists.
        control_data = self.experiment_data[control_variant]
        best_data = self.experiment_data[best_variant]

        if len(control_data) < 2 or len(best_data) < 2:
            print("Not enough data to perform a t-test.  Need at least 2 runs per variant.")
            return None

        # Independent two-sample t-test (two-sided p-value).
        t_statistic, p_value = stats.ttest_ind(best_data, control_data)

        alpha = 1 - confidence_level

        print(f"T-test: t={t_statistic}, p={p_value}, alpha={alpha}")

        # Require both significance and the right direction (best > control);
        # a two-sided p-value alone would also accept "significantly worse".
        if p_value < alpha and t_statistic > 0:
            print(f"Winner selected: {best_variant} is statistically significantly better than the control.")
            return best_variant
        else:
            print("No statistically significant winner found.")
            return None

    def rollout_variant(self, experiment_definition: "ExperimentDefinition", winning_variant: str, rollout_function: Callable[[Any], None]) -> None:
        """
        Rolls out the winning variant by passing its configuration to
        rollout_function.  Failures are logged, not re-raised.
        """
        print(f"Rolling out variant: {winning_variant} for experiment: {experiment_definition.name}")
        try:
            winning_config = experiment_definition.variants[winning_variant]
            rollout_function(winning_config)
            print(f"Rollout of {winning_variant} successful.")
        except Exception as e:
            print(f"Error rolling out variant {winning_variant}: {e}")

    def _default_rollout_function(self, config: Dict[str, Any]) -> None:
        """
        A default rollout function that simply prints the configuration.  This should be
        overridden with a real rollout implementation.
        """
        print("Default Rollout Function:")
        print(f"Rolling out configuration: {config}")

    def is_rollout_complete(self, experiment_name: str) -> bool:
        """
        Checks if the rollout is complete for a given experiment.
        Unknown experiment names report False.
        """
        return self.rollout_complete.get(experiment_name, False)

    def update_dashboard_data(self, experiment_definition: "ExperimentDefinition", results: Dict[str, Dict[str, Union[float, int]]]) -> None:
        """
        Updates the dashboard data with the latest experiment results.
        """
        self.dashboard_data[experiment_definition.name] = {
            "experiment_name": experiment_definition.name,
            "experiment_type": experiment_definition.experiment_type.value,
            "variants": list(experiment_definition.variants.keys()),
            "results": results,
            "rollout_complete": self.is_rollout_complete(experiment_definition.name),
            "timestamp": time.time()  # When this snapshot was taken
        }

    def get_dashboard_data(self) -> Dict[str, Any]:
        """
        Returns the data for the dashboard.
        """
        return self.dashboard_data


# Example Usage (Illustrative)

if __name__ == '__main__':

    # --- Step 1: describe the experiment with the JSON DSL ---------------
    experiment_dsl = """
    {
        "name": "Summarization Skill A/B Test v2",
        "experiment_type": "skill_variants",
        "variants": {
            "variant_A": {"skill_name": "summarization_v1", "temperature": 0.7},
            "variant_B": {"skill_name": "summarization_v2", "temperature": 0.9},
            "control": {"skill_name": "summarization_baseline", "temperature": 0.8}
        },
        "traffic_split": {
            "variant_A": 0.3,
            "variant_B": 0.3,
            "control": 0.4
        },
        "description": "Testing different summarization skills and temperature settings.",
        "minimum_runs": 20,
        "confidence_level": 0.95
    }
    """

    def summarization_quality(result: str) -> float:
        """Toy metric: scores a summary by the strongest quality keyword it contains."""
        lowered = result.lower()
        # Ordered best-to-worst so the highest matching tier wins.
        for keyword, score in (("excellent", 1.0), ("good", 0.7), ("ok", 0.5)):
            if keyword in lowered:
                return score
        return 0.2

    definition = ExperimentDefinition.from_dsl(experiment_dsl, summarization_quality)

    # --- Step 2: simulate executing a summarization skill ----------------
    def execute_summarization_skill(config: Dict[str, Any]) -> str:
        """Stand-in for real skill execution; returns a canned summary per skill."""
        skill_name = config["skill_name"]
        temperature = config["temperature"]
        print(f"Executing skill: {skill_name} with temperature: {temperature}")
        # (simulated processing time, canned result) keyed by skill name.
        simulated = {
            "summarization_v1": (0.1, "Good summarization."),
            "summarization_v2": (0.2, "Excellent summarization!"),
        }
        delay, summary = simulated.get(skill_name, (0.05, "OK summarization."))
        time.sleep(delay)
        return summary

    # --- Step 3: run the experiment with auto-rollout enabled ------------
    runner = ExperimentRunner()
    runner.run_experiment(definition, execute_summarization_skill, auto_rollout=True)

    # --- Step 4: inspect dashboard data and rollout state ----------------
    dashboard_snapshot = runner.get_dashboard_data()
    print("Dashboard Data:")
    print(dashboard_snapshot)

    if runner.is_rollout_complete(definition.name):
        print(f"Rollout for experiment {definition.name} is complete.")
    else:
        print(f"Rollout for experiment {definition.name} is not complete.")
