import time
import logging
import argparse
import json
from typing import Dict, Any, List
import requests

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class AivaBenchmark:
    """
    Benchmarks latency and throughput of an AIVA text-generation endpoint.

    Each test POSTs a prompt to the endpoint, streams the response, and
    records total latency, time-to-first-token, and an approximate
    tokens-per-second figure. Results can optionally be compared against
    a previously recorded baseline.
    """

    def __init__(self, endpoint: str, baseline_results: Dict[str, Any] = None,
                 timeout: float = 120.0) -> None:
        """
        Initializes the AivaBenchmark class.

        Args:
            endpoint: The AIVA endpoint to benchmark.
            baseline_results: A dictionary containing baseline benchmark
                results for comparison, keyed by test name. Defaults to
                no baseline.
            timeout: Per-request timeout in seconds. Without this, a hung
                or unreachable server would block the benchmark forever.
        """
        self.endpoint = endpoint
        self.baseline_results = baseline_results or {}
        self.timeout = timeout
        logging.info(f"AIVA Benchmark initialized with endpoint: {self.endpoint}")

    def run_test(self, prompt: str, test_name: str) -> Dict[str, Any]:
        """
        Runs a single benchmark test against the endpoint.

        Args:
            prompt: The prompt to send to the AIVA endpoint.
            test_name: The name of the test.

        Returns:
            A dictionary with latency/throughput metrics and the full
            response text, or a dictionary with a "test_name" and "error"
            key if the request failed.
        """
        # perf_counter() is monotonic and high-resolution; time.time() is
        # wall-clock and can jump (NTP adjustments), skewing latency numbers.
        start_time = time.perf_counter()
        first_token_time = None
        tokens_received = 0
        response_content = ""

        try:
            # Stream so time-to-first-token is measurable; `with` ensures the
            # connection is released even if iteration raises mid-stream.
            with requests.post(self.endpoint, json={"prompt": prompt},
                               stream=True, timeout=self.timeout) as response:
                response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)

                for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
                    if not chunk:
                        # Skip keep-alive/empty chunks so they neither set
                        # first_token_time nor pollute the token count.
                        continue
                    if first_token_time is None:
                        first_token_time = time.perf_counter()
                    response_content += chunk
                    tokens_received += len(chunk.split())  # Approximate tokens by splitting on spaces

            end_time = time.perf_counter()

            total_latency = end_time - start_time
            # If no chunk ever arrived (e.g. empty body), fall back to the
            # total latency rather than leaving the metric undefined.
            time_to_first_token = first_token_time - start_time if first_token_time is not None else total_latency
            tokens_per_second = tokens_received / total_latency if total_latency > 0 else 0

            results = {
                "test_name": test_name,
                "prompt": prompt,
                "total_latency": total_latency,
                "time_to_first_token": time_to_first_token,
                "tokens_per_second": tokens_per_second,
                "tokens_received": tokens_received,
                "response_content": response_content
            }

            logging.info(f"Test '{test_name}' completed successfully.")
            return results

        except requests.exceptions.RequestException as e:
            # Network/HTTP failures: report per-test instead of aborting the suite.
            logging.error(f"Error during request: {e}")
            return {
                "test_name": test_name,
                "error": str(e)
            }
        except Exception as e:
            # Last-resort boundary so one broken test cannot kill the whole run.
            logging.error(f"An unexpected error occurred: {e}")
            return {
                "test_name": test_name,
                "error": str(e)
            }

    def compare_with_baseline(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """
        Compares the benchmark results with the baseline results.

        Args:
            results: The benchmark results to compare (must contain "test_name").

        Returns:
            Per-metric comparison dicts (current, baseline, difference,
            percentage_change), or an empty dict if no baseline exists
            for this test.
        """
        test_name = results["test_name"]
        if test_name not in self.baseline_results:
            logging.warning(f"No baseline found for test: {test_name}")
            return {}

        baseline = self.baseline_results[test_name]
        comparison = {}
        for metric in ["total_latency", "time_to_first_token", "tokens_per_second"]:
            # Skip metrics missing on either side (e.g. errored runs).
            if metric in results and metric in baseline:
                comparison[metric] = {
                    "current": results[metric],
                    "baseline": baseline[metric],
                    "difference": results[metric] - baseline[metric],
                    # Guard against division by zero for a 0-valued baseline.
                    "percentage_change": (results[metric] - baseline[metric]) / baseline[metric] * 100 if baseline[metric] != 0 else 0
                }
        return comparison

    def run_benchmarks(self, tests: List[Dict[str, str]]) -> Dict[str, Any]:
        """
        Runs a series of benchmark tests.

        Args:
            tests: A list of dictionaries, each with "prompt" and "test_name" keys.

        Returns:
            A mapping of test name -> results, with a "comparison" entry
            added for every test that has a baseline.
        """
        all_results = {}
        for test in tests:
            test_name = test["test_name"]
            prompt = test["prompt"]
            logging.info(f"Running test: {test_name}")
            results = self.run_test(prompt, test_name)
            all_results[test_name] = results
            comparison = self.compare_with_baseline(results)
            if comparison:
                all_results[test_name]["comparison"] = comparison
        return all_results


def main():
    """
    Command-line entry point: parse arguments, optionally load a baseline
    results file, run the benchmark suite, and print results as JSON.
    """
    parser = argparse.ArgumentParser(description="AIVA Benchmark Suite")
    parser.add_argument("--endpoint", type=str, default="http://localhost:23405/api/generate", help="AIVA endpoint URL")
    parser.add_argument("--baseline", type=str, help="Path to baseline results JSON file")
    args = parser.parse_args()

    baseline_results = None
    if args.baseline is None:
        # Explicit check instead of relying on open(None) raising TypeError,
        # which obscured the "no --baseline flag" case.
        logging.warning("No baseline file provided or file not found. Running without baseline comparison.")
    else:
        try:
            with open(args.baseline, "r") as f:
                baseline_results = json.load(f)
            logging.info(f"Loaded baseline results from {args.baseline}")
        except FileNotFoundError:
            logging.warning("No baseline file provided or file not found. Running without baseline comparison.")
        except json.JSONDecodeError as e:
            # Malformed JSON: proceed without a baseline rather than aborting.
            logging.error(f"Error decoding JSON from baseline file: {e}")

    aiva_benchmark = AivaBenchmark(args.endpoint, baseline_results)

    tests = [
        {"test_name": "simple_prompt", "prompt": "What is the capital of France?"},
        {"test_name": "complex_prompt", "prompt": "Explain the theory of relativity in simple terms."},
        {"test_name": "long_context", "prompt": "Write a long story about a cat who goes on an adventure. The cat's name is Mittens. Mittens lives in a small town. The town is called Willow Creek. Willow Creek is located in the mountains. The mountains are called the Blue Ridge Mountains. Mittens loves to explore. One day, Mittens decides to leave Willow Creek and go on an adventure. Mittens meets many interesting characters along the way. Mittens learns many valuable lessons. Mittens eventually returns to Willow Creek, a changed cat."},
    ]

    all_results = aiva_benchmark.run_benchmarks(tests)

    print(json.dumps(all_results, indent=4))


# Run the benchmark suite only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
