import unittest
import time
import os
import json
import logging

# Route timestamped records at INFO level and above through the root logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# The Qwen client is expected to live in qwen_unified/qwen.py. When it cannot
# be imported, the name is still bound (to None) so that later references to
# Qwen do not raise NameError.
try:
    from qwen_unified.qwen import Qwen
except ImportError as import_error:
    Qwen = None
    logging.error(f"Failed to import Qwen class: {import_error}. Ensure the qwen module is correctly installed and accessible.")

class ContextStressTest(unittest.TestCase):
    """Stress-test the Qwen client with prompts of increasing size.

    Each probe sends a "summarize this text" prompt built from a repeated
    filler sentence, records whether a response came back and how long the
    call took, and finally writes a Markdown report of all probes.

    The class attributes below are the tunable knobs; override them on a
    subclass or instance to change the run without editing method bodies.
    """

    # Model name passed to the Qwen client in setUp.
    MODEL_NAME = "Qwen-72B-Chat"
    # Context sizes (in KB) probed by test_context_size_limits, in order.
    CONTEXT_SIZES_KB = (2, 4, 8, 16, 32, 64, 128)
    # Pause between consecutive requests, to avoid rate limits.
    REQUEST_DELAY_SECONDS = 2
    # Default destination of the Markdown report.
    REPORT_PATH = "/mnt/e/genesis-system/AIVA/qwen-unified/tests/context_stress_report.md"
    # Sentence repeated to build the synthetic context.
    _FILLER_SENTENCE = "This is a test sentence. "

    def setUp(self):
        """Create the Qwen client, or leave ``self.qwen`` as None.

        ``self.qwen`` stays None when the QWEN_API_KEY environment variable
        is missing, when the Qwen class could not be imported, or when the
        client constructor raises. ``self.report`` starts out empty.
        """
        self.report = {}
        self.qwen = None

        api_key = os.environ.get("QWEN_API_KEY")
        if not api_key:
            logging.warning("QWEN_API_KEY environment variable not set. Tests may fail.")
            return
        if Qwen is None:
            return

        try:
            self.qwen = Qwen(api_key=api_key, model_name=self.MODEL_NAME)
        except Exception as e:
            # The client's exception types are not known here, so this stays
            # broad; the failure is logged and the test is skipped later.
            logging.error(f"Failed to initialize Qwen: {e}")

    def generate_long_context(self, size_kb):
        """Return filler text of at most ``size_kb`` kilobytes.

        Only whole copies of the filler sentence are used, so the result may
        be slightly shorter than ``size_kb * 1024`` characters. A size of
        zero (or less) yields an empty string.
        """
        sentence = self._FILLER_SENTENCE
        size_bytes = size_kb * 1024
        return sentence * (size_bytes // len(sentence))

    def _probe_context_size(self, size_kb):
        """Send one prompt of roughly ``size_kb`` KB and return its result record.

        The record always has ``status`` and ``duration`` keys, plus
        ``response`` on success or ``error`` on failure. Any exception is
        converted into a failure record rather than propagated.
        """
        # Start the timer before the try block so the except handler can
        # always compute a duration, even if building the prompt fails.
        started = time.perf_counter()
        try:
            context = self.generate_long_context(size_kb)
            prompt = f"Summarize the following text. {context}"
            # Restart the timer so a successful run measures only the call.
            started = time.perf_counter()
            response = self.qwen.generate_response(prompt)
            duration = time.perf_counter() - started

            if response:
                logging.info(f"Context size {size_kb}KB: Success. Response length: {len(response)}")
                return {"status": "success", "duration": duration, "response": response}

            logging.error(f"Context size {size_kb}KB: Failure - No response received.")
            return {"status": "failure", "duration": duration, "error": "No response received."}
        except Exception as e:
            # A rejected oversized prompt is an expected outcome of this
            # test, so every error is recorded instead of aborting the run.
            logging.error(f"Context size {size_kb}KB: Failure - {e}")
            return {
                "status": "failure",
                "duration": time.perf_counter() - started,
                "error": str(e),
            }

    def test_context_size_limits(self):
        """Tests Qwen's context size handling with increasing context sizes.

        Skipped when no Qwen client is available. Passes only when at least
        one size succeeds and at least one fails, i.e. the limit was reached
        somewhere inside the probed range.
        """
        if self.qwen is None:
            self.skipTest("Qwen initialization failed. Skipping tests.")

        results = {}
        # Share the dict with self.report so partial results are visible
        # even if the run is interrupted.
        self.report = results

        sizes = tuple(self.CONTEXT_SIZES_KB)
        last_index = len(sizes) - 1
        for index, size_kb in enumerate(sizes):
            results[size_kb] = self._probe_context_size(size_kb)
            # Throttle between requests only; no need to wait after the last.
            if index < last_index and self.REQUEST_DELAY_SECONDS > 0:
                time.sleep(self.REQUEST_DELAY_SECONDS)

        self.generate_report_file()

        success_count = sum(1 for result in results.values() if result["status"] == "success")
        failure_count = sum(1 for result in results.values() if result["status"] == "failure")

        self.assertGreater(success_count, 0, "At least one test should succeed.")
        self.assertGreater(failure_count, 0, "At least one test should fail, indicating context limits were reached.")

    def generate_report_file(self, report_path=None):
        """Write a Markdown summary of ``self.report``.

        Args:
            report_path: Destination file. Defaults to ``self.REPORT_PATH``.

        Failures are logged, not raised, so a missing directory or malformed
        record never fails the test itself.
        """
        if report_path is None:
            report_path = self.REPORT_PATH

        try:
            # Build the whole document first so a malformed record cannot
            # leave a half-written file behind.
            parts = [
                "# Context Stress Test Report\n\n",
                "This report summarizes the results of the context stress test for the Qwen model.\n\n",
                "## Test Results\n\n",
            ]
            for size_kb, result in self.report.items():
                parts.append(f"### Context Size: {size_kb}KB\n")
                parts.append(f"- Status: {result['status']}\n")
                parts.append(f"- Duration: {result['duration']:.2f} seconds\n")
                if "error" in result:
                    parts.append(f"- Error: {result['error']}\n")
                if "response" in result:
                    preview = str(result["response"])[:100]
                    parts.append(f"- First 100 chars of Response: {preview}...\n")
                parts.append("\n")

            # Explicit encoding: model output may contain non-ASCII text.
            with open(report_path, "w", encoding="utf-8") as f:
                f.write("".join(parts))
            logging.info(f"Report file generated at: {report_path}")
        except Exception as e:
            logging.error(f"Failed to generate report file: {e}")

# When this file is executed directly (not imported), hand control to
# unittest's command-line runner for the tests in this module.
if __name__ == "__main__":
    unittest.main()