import numpy as np
import pandas as pd
import random

class SelfImprovingValidationSystem:
    """
    Implementation of Patent 9: Self-Improving Validation System.

    This system automatically adjusts validation thresholds based on outcomes,
    uses reinforcement learning (tabular Q-learning with an epsilon-greedy
    policy) for continuous improvement, and tracks performance metrics.
    """

    # Candidate per-step threshold changes; every action nudges exactly one
    # criterion by one of these amounts.
    _ADJUSTMENTS = (-0.05, 0.05)

    def __init__(self, initial_thresholds, metrics=None,
                 learning_rate=0.1, discount_factor=0.9, exploration_rate=0.1):
        """
        Initializes the validation system.

        Args:
            initial_thresholds (dict): Initial thresholds for each validation
                criterion. Example: {'score_a': 0.7, 'score_b': 0.8}
            metrics (list | None): Performance metrics to track. Defaults to
                ['accuracy', 'precision', 'recall', 'f1_score'].  A None
                sentinel replaces the previous mutable default list, which
                was shared across every instance of the class.
            learning_rate (float): Q-learning step size (alpha).
            discount_factor (float): Q-learning discount factor (gamma).
            exploration_rate (float): Epsilon for the epsilon-greedy policy.
        """
        self.thresholds = initial_thresholds
        # Copy the caller's list so later mutation of self.metrics cannot
        # leak across instances (and vice versa).
        self.metrics = (list(metrics) if metrics is not None
                        else ['accuracy', 'precision', 'recall', 'f1_score'])
        self.history = []  # Per-epoch snapshots of thresholds and metrics.
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.q_table = {}  # Maps (state, action) -> estimated Q-value.
        self.state = self._get_state()

    def _get_state(self):
        """
        Defines the state of the system based on current thresholds.
        This could be more sophisticated, using recent performance, etc.

        Returns:
            tuple: Sorted (criterion, threshold) pairs.  Sorting makes equal
            threshold sets map to the same state regardless of dict order.
        """
        return tuple(sorted(self.thresholds.items()))

    def validate(self, data):
        """
        Validates a data point against the current thresholds.

        Args:
            data (dict): Values keyed by criterion name; must contain every
                key present in ``self.thresholds``.
                Example: {'score_a': 0.75, 'score_b': 0.85}

        Returns:
            bool: True if no value falls below its threshold, False otherwise.

        Raises:
            KeyError: If ``data`` is missing a thresholded criterion.
        """
        # `not any(v < t)` preserves the original comparison direction
        # exactly (including its treatment of non-orderable values).
        return not any(data[criterion] < threshold
                       for criterion, threshold in self.thresholds.items())

    def evaluate(self, data, ground_truth):
        """
        Evaluates the performance of the validation system on a set of data
        and calculates performance metrics.

        Args:
            data (list): A list of data points (dictionaries) to validate.
            ground_truth (list): A list of ground truth labels (True/False)
                                 corresponding to each data point.

        Returns:
            dict: The requested metrics (subset of accuracy / precision /
            recall / f1_score).  Ratios with a zero denominator report 0.
        """
        predictions = [self.validate(d) for d in data]

        # Single pass over the confusion-matrix counts.
        tp = fp = tn = fn = 0
        for predicted, actual in zip(predictions, ground_truth):
            if predicted and actual:
                tp += 1
            elif predicted:
                fp += 1
            elif actual:
                fn += 1
            else:
                tn += 1

        # Compute the base ratios once and reuse them for f1_score, instead
        # of duplicating the guarded divisions per metric.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0

        results = {}
        if 'accuracy' in self.metrics:
            results['accuracy'] = (tp + tn) / len(data) if len(data) > 0 else 0
        if 'precision' in self.metrics:
            results['precision'] = precision
        if 'recall' in self.metrics:
            results['recall'] = recall
        if 'f1_score' in self.metrics:
            denom = precision + recall
            results['f1_score'] = 2 * (precision * recall) / denom if denom > 0 else 0

        return results

    def adjust_thresholds(self, reward):
        """
        Adjusts the thresholds based on the reward received from the environment
        using a reinforcement learning approach (Q-learning).

        NOTE(review): the reward passed in is computed from the thresholds
        *before* this call, while the update credits it to the action chosen
        here — confirm this ordering matches the intended RL formulation.

        Args:
            reward (float): The reward received from the environment.
        """
        state = self.state
        action = self._choose_action()
        new_state = self._apply_action(action)

        # No need to pre-seed q_table entries with 0.0: reads below default
        # missing entries to 0.0, and _choose_best_action does the same.
        old_value = self.q_table.get((state, action), 0.0)
        best_next_action = self._choose_best_action(new_state)
        next_max = self.q_table.get((new_state, best_next_action), 0.0)

        # Standard Q-learning update:
        # Q(s,a) <- (1-alpha)*Q(s,a) + alpha*(r + gamma * max_a' Q(s',a')).
        self.q_table[(state, action)] = (
            (1 - self.learning_rate) * old_value
            + self.learning_rate * (reward + self.discount_factor * next_max)
        )

        self.state = new_state  # Update the current state

    def _choose_action(self):
        """
        Chooses an action using an epsilon-greedy approach.

        Returns:
            tuple: (criterion, adjustment) — the criterion to adjust and the
            signed adjustment amount.
        """
        if random.random() < self.exploration_rate:
            # Explore: choose a random action.
            criterion = random.choice(list(self.thresholds.keys()))
            adjustment = random.choice(self._ADJUSTMENTS)
            return (criterion, adjustment)
        # Exploit: choose the best action based on the Q-table.
        return self._choose_best_action(self.state)

    def _choose_best_action(self, state):
        """
        Chooses the best action based on the Q-table for a given state.
        Untried actions default to a Q-value of 0.0, so with a non-empty
        thresholds dict some action is always selected.

        Args:
            state (tuple): The current state of the system.

        Returns:
            tuple: The best (criterion, adjustment) action.
        """
        best_action = None
        best_value = float('-inf')
        for criterion in self.thresholds:
            for adjustment in self._ADJUSTMENTS:
                candidate = (criterion, adjustment)
                value = self.q_table.get((state, candidate), 0.0)
                if value > best_value:
                    best_value = value
                    best_action = candidate
        if best_action is None:
            # Only reachable when self.thresholds is empty (the loop above
            # always picks something otherwise, since the default Q-value
            # beats -inf).  Kept as a defensive fallback.
            criterion = random.choice(list(self.thresholds.keys()))
            adjustment = random.choice(self._ADJUSTMENTS)
            return (criterion, adjustment)

        return best_action

    def _apply_action(self, action):
        """
        Applies an action to the system by adjusting a threshold.

        Args:
            action (tuple): The (criterion, adjustment) pair to apply.

        Returns:
            tuple: The new state of the system after applying the action.
        """
        criterion, adjustment = action
        # Clamp so thresholds always stay inside [0, 1].
        new_value = self.thresholds[criterion] + adjustment
        self.thresholds[criterion] = max(0.0, min(1.0, new_value))
        return self._get_state()

    def train(self, data, ground_truth, epochs=100):
        """
        Trains the validation system using reinforcement learning.  Each
        epoch evaluates the current thresholds, uses the F1-score as the
        reward, records history, then performs one Q-learning step.

        Args:
            data (list): A list of data points (dictionaries) to validate.
            ground_truth (list): A list of ground truth labels (True/False)
                                 corresponding to each data point.
            epochs (int): The number of training epochs.
        """
        for epoch in range(epochs):
            # Evaluate performance with the thresholds as they stand.
            results = self.evaluate(data, ground_truth)

            # Reward function: F1-score (0.0 if it was not tracked).
            reward = results.get('f1_score', 0.0)

            # Snapshot thresholds (copy — adjust_thresholds mutates them).
            self.history.append({'epoch': epoch,
                                 'thresholds': self.thresholds.copy(),
                                 'metrics': results})

            # Adjust thresholds based on reward.
            self.adjust_thresholds(reward)

            print(f"Epoch {epoch + 1}/{epochs}, F1-score: {reward:.4f}, Thresholds: {self.thresholds}")

    def get_improvement_analytics(self):
        """
        Provides improvement analytics based on the training history.

        Returns:
            pandas.DataFrame: One row per epoch with 'epoch', 'thresholds'
            and 'metrics' columns.
        """
        return pd.DataFrame(self.history)


# Example Usage:
if __name__ == '__main__':
    # Toy dataset: each record carries the two scores that the system
    # checks against its thresholds, plus a matching ground-truth label.
    sample_points = [
        {'score_a': 0.75, 'score_b': 0.85},
        {'score_a': 0.60, 'score_b': 0.70},
        {'score_a': 0.90, 'score_b': 0.95},
        {'score_a': 0.55, 'score_b': 0.65},
        {'score_a': 0.80, 'score_b': 0.90},
    ]
    labels = [True, False, True, False, True]

    # Build a validator starting from hand-picked thresholds.
    validator = SelfImprovingValidationSystem({'score_a': 0.7, 'score_b': 0.8})

    # Let the Q-learning loop tune the thresholds.
    validator.train(sample_points, labels, epochs=50)

    # Dump the per-epoch history as a DataFrame.
    print("\nImprovement Analytics:")
    print(validator.get_improvement_analytics())

    # Validate one unseen record with the tuned thresholds.
    is_valid = validator.validate({'score_a': 0.82, 'score_b': 0.88})
    print(f"\nNew data validation result: {is_valid}")