# context_optimizer.py

import json
import os
import subprocess
import sys

class ContextOptimizer:
    """Sweep context-window sizes by reconfiguring, training and evaluating a model.

    Each step is delegated to an external Python script run in a subprocess:
    ``modelfile_generator_script`` (config), ``train_script`` (training) and
    the ``evaluation_script`` passed to the constructor (scoring).
    """

    # Helper scripts invoked by this optimizer. They are class attributes so a
    # subclass or an individual instance can point at different locations
    # without editing the methods.
    modelfile_generator_script = "/mnt/e/genesis-system/AIVA/qwen-unified/modelfile_generator.py"
    train_script = "/mnt/e/genesis-system/AIVA/qwen-unified/train.py"
    # Run helper scripts with the interpreter executing this module instead of
    # whatever "python" resolves to on PATH (it may not exist, e.g. on systems
    # that only provide "python3"). Fall back to "python" when the interpreter
    # path is unknown (sys.executable can be empty in embedded interpreters).
    python_executable = sys.executable or "python"

    def __init__(self, model_name, data_path, evaluation_script):
        """Store the model name, training-data path and evaluation script path."""
        self.model_name = model_name
        self.data_path = data_path
        self.evaluation_script = evaluation_script

    def optimize_context_window(self, context_window_range, learning_rate, batch_size):
        """Optimizes the context window size for a given model.

        For each size: rewrite the model config, train, evaluate, and keep the
        size with the highest score.

        Args:
            context_window_range: An iterable of context window sizes to test.
            learning_rate: The learning rate to use for training.
            batch_size: The batch size to use for training.

        Returns:
            A dict with keys ``'results'`` (maps each size to
            ``{'train_success': bool, 'score': float}``),
            ``'best_context_window'`` (None if no run produced a score above
            ``-inf``) and ``'best_score'``.

        Raises:
            FileNotFoundError, subprocess.CalledProcessError: propagated from
                ``modify_model_config``; a config failure aborts the sweep.
        """
        results = {}
        best_context_window = None
        best_score = float('-inf')

        for context_window in context_window_range:
            print(f"\n\nTesting context window size: {context_window}\n\n")
            # 1. Modify the model configuration.
            self.modify_model_config(context_window)

            # 2. Train the model.
            train_success = self.train_model(learning_rate, batch_size)

            if not train_success:
                print(f"Training failed for context window size: {context_window}\n")
                results[context_window] = {'train_success': False, 'score': float('-inf')}
                continue

            # 3. Evaluate the model.
            score = self.evaluate_model()

            results[context_window] = {'train_success': True, 'score': score}

            # 4. Update the best context window (strict '>' keeps the first
            # size on ties).
            if score > best_score:
                best_score = score
                best_context_window = context_window

        print("\nOptimization complete.\n")
        print(f"Best context window size: {best_context_window} with score: {best_score}\n")

        return {
            'results': results,
            'best_context_window': best_context_window,
            'best_score': best_score
        }

    def _run_script(self, script_path, *args):
        """Run ``script_path`` with ``args`` and return the CompletedProcess.

        Every argument is converted with ``str``. Output is captured as text.

        Raises:
            FileNotFoundError: if ``script_path`` is not an existing file, or
                the interpreter itself cannot be launched.
            subprocess.CalledProcessError: if the script exits non-zero.
        """
        # subprocess.run raises FileNotFoundError only for a missing
        # *executable*; a missing script would surface as a non-zero exit
        # (CalledProcessError). Check explicitly so "not found" is accurate.
        if not os.path.isfile(script_path):
            raise FileNotFoundError(f"Script not found: {script_path}")
        command = [self.python_executable, script_path, *map(str, args)]
        return subprocess.run(command, check=True, capture_output=True, text=True)

    def modify_model_config(self, context_window):
        """Modifies the model configuration for the given context window size.

        Delegates to ``self.modelfile_generator_script``, passing
        ``--model_name`` and ``--context_window``.

        Raises:
            FileNotFoundError: if the generator script or interpreter is missing.
            subprocess.CalledProcessError: if the generator exits non-zero.
        """
        try:
            self._run_script(self.modelfile_generator_script,
                             "--model_name", self.model_name,
                             "--context_window", context_window)
        except subprocess.CalledProcessError as e:
            print(f"Error modifying model config: {e.stderr}")
            raise
        except FileNotFoundError as e:
            print(f"Error: could not run model config generator: {e}")
            raise
        print(f"Successfully modified model config for context window: {context_window}")

    def train_model(self, learning_rate, batch_size):
        """Trains the model by running ``self.train_script``.

        The script receives ``--model_name``, ``--learning_rate``,
        ``--batch_size`` and ``--data_path``.

        Returns:
            True if the training script exited with status 0; False if it
            exited non-zero or could not be started (missing script or
            interpreter).
        """
        print(f"Starting training with learning rate: {learning_rate} and batch size: {batch_size}\n")
        try:
            result = self._run_script(self.train_script,
                                      "--model_name", self.model_name,
                                      "--learning_rate", learning_rate,
                                      "--batch_size", batch_size,
                                      "--data_path", self.data_path)
        except subprocess.CalledProcessError as e:
            print(f"Training failed: {e.stderr}\n")
            return False
        except FileNotFoundError as e:
            print(f"Error: {e}. Ensure training script exists.")
            return False
        print(f"Training output: {result.stdout}\n")
        return True

    @staticmethod
    def _parse_score(output):
        """Return the float ``'score'`` from evaluation output, or None.

        Accepts output that is a single JSON object, or whose last non-empty
        line is one, so evaluation scripts may print progress before the result.
        """
        candidates = [output]
        non_empty_lines = [line for line in output.splitlines() if line.strip()]
        if non_empty_lines:
            candidates.append(non_empty_lines[-1])
        for text in candidates:
            try:
                return float(json.loads(text)['score'])
            # JSONDecodeError: not JSON; KeyError: no 'score'; TypeError: the
            # JSON is not a dict or the score is null; ValueError: non-numeric.
            except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                continue
        return None

    def evaluate_model(self):
        """Evaluates the model by running ``self.evaluation_script``.

        The script receives ``--model_name`` and is expected to print a JSON
        object with a numeric ``"score"`` field, either as its whole output or
        as its last non-empty line.

        Returns:
            The score as a float, or ``float('-inf')`` if the script is
            missing, exits non-zero, or its output has no usable score.
        """
        try:
            result = self._run_script(self.evaluation_script, "--model_name", self.model_name)
        except subprocess.CalledProcessError as e:
            print(f"Evaluation failed: {e.stderr}\n")
            return float('-inf')  # Indicate failure
        except FileNotFoundError as e:
            print(f"Error: could not run evaluation script {self.evaluation_script}: {e}\n")
            return float('-inf')
        print(f"Evaluation output: {result.stdout}\n")

        score = self._parse_score(result.stdout)
        if score is None:
            print(f"Error decoding evaluation output: {result.stdout}\n")
            return float('-inf')  # Indicate failure
        print(f"Evaluation score: {score}\n")
        return score


if __name__ == '__main__':
    # Example usage: sweep three context-window sizes for the default model.
    # The paths below are environment-specific; adjust them before running.
    optimizer = ContextOptimizer(
        model_name="Qwen-Unified-AIVA",
        data_path="/mnt/e/genesis-system/data/training_data.txt",  # Replace with actual path
        evaluation_script="/mnt/e/genesis-system/AIVA/qwen-unified/evaluate.py",  # Replace with actual path
    )
    optimization_results = optimizer.optimize_context_window(
        context_window_range=[2048, 4096, 8192],
        learning_rate=0.0001,
        batch_size=32,
    )

    print("\nFinal Optimization Results:\n")
    print(json.dumps(optimization_results, indent=4))
