import unittest
import time
import threading
import random

from src.scaling_policies.latency_policy import LatencyPolicy
from src.simulators.request_simulator import RequestSimulator

class TestLatencyScaling(unittest.TestCase):
    """Integration tests for latency-driven autoscaling.

    Each test feeds synthetic request latencies through a ``RequestSimulator``
    into a ``LatencyPolicy`` and observes the instance counts the policy
    reports via :meth:`scaling_callback`.  The tests use real wall-clock
    sleeps, so each one takes several seconds to run.
    """

    # Requests submitted per latency phase; intended to be enough for the
    # policy to trigger scaling.
    REQUESTS_PER_PHASE = 20
    # Pause between consecutive submitted requests, in seconds.
    REQUEST_INTERVAL = 0.1
    # Fixed seed so the randomly drawn latencies are reproducible between runs.
    RANDOM_SEED = 1234

    def setUp(self):
        """Build the policy under test and start the request simulator."""
        self.initial_instances = 2
        self.latency_threshold = 0.5  # 500ms
        self.latency_low_threshold = 0.2  # 200ms
        self.scale_up_delay = 5  # seconds
        self.scale_down_delay = 5  # seconds

        # Shared with the policy's callback; every access goes through self.lock.
        self.current_instances = self.initial_instances
        self.scaling_events = []
        self.lock = threading.Lock()

        # Private generator: reproducible and does not disturb global `random`.
        self.rng = random.Random(self.RANDOM_SEED)

        self.policy = LatencyPolicy(
            initial_instances=self.initial_instances,
            latency_threshold=self.latency_threshold,
            latency_low_threshold=self.latency_low_threshold,
            scale_up_delay=self.scale_up_delay,
            scale_down_delay=self.scale_down_delay,
            callback=self.scaling_callback,
        )

        self.simulator = RequestSimulator(self.policy.record_latency)
        self.simulator.start()

    def tearDown(self):
        """Stop the simulator and wait for it to finish."""
        self.simulator.stop()
        self.simulator.join()

    def scaling_callback(self, new_instances, event_type):
        """Record a scaling decision reported by the policy.

        Args:
            new_instances: Instance count after the scaling event.
            event_type: Event label; the tests check for ``'scale_up'`` and
                ``'scale_down'``.
        """
        # Read the policy's latency before taking our lock so that our lock
        # is held only for the bookkeeping below.
        latency = self.policy.average_latency
        with self.lock:
            self.current_instances = new_instances
            self.scaling_events.append({
                'time': time.time(),
                'instances': new_instances,
                'type': event_type,
                'latency': latency,
            })

    def _submit_latencies(self, low, high, count=None):
        """Submit ``count`` requests with latencies drawn uniformly from [low, high].

        ``count`` defaults to ``REQUESTS_PER_PHASE``. Submissions are spaced
        ``REQUEST_INTERVAL`` seconds apart.
        """
        if count is None:
            count = self.REQUESTS_PER_PHASE
        for _ in range(count):
            self.simulator.submit_request(self.rng.uniform(low, high))
            time.sleep(self.REQUEST_INTERVAL)

    def _snapshot(self):
        """Return ``(current_instances, copy_of_scaling_events)`` read under the lock."""
        with self.lock:
            return self.current_instances, list(self.scaling_events)

    @staticmethod
    def _has_event(events, event_type):
        """Return True if any recorded event has the given ``'type'``."""
        return any(event['type'] == event_type for event in events)

    def test_scale_up(self):
        """Sustained latency above the threshold adds instances."""
        self._submit_latencies(0.6, 0.8)
        time.sleep(self.scale_up_delay + 1)  # Wait for scaling to occur

        instances, events = self._snapshot()
        self.assertGreater(instances, self.initial_instances, msg=events)
        self.assertTrue(self._has_event(events, 'scale_up'), msg=events)

    def test_scale_down(self):
        """After scaling up, sustained low latency removes instances again."""
        # First, scale up -- and verify it happened, otherwise the scale-down
        # assertions below would not be exercising the intended path.
        self._submit_latencies(0.6, 0.8)
        time.sleep(self.scale_up_delay + 1)

        instances, events = self._snapshot()
        self.assertGreater(
            instances, self.initial_instances,
            msg=f"precondition failed: no scale-up observed; events={events}",
        )

        # Then, simulate low latency
        self._submit_latencies(0.1, 0.15)
        time.sleep(self.scale_down_delay + 1)

        instances, events = self._snapshot()
        self.assertLessEqual(instances, self.initial_instances, msg=events)
        self.assertTrue(self._has_event(events, 'scale_down'), msg=events)

    def test_scaling_events_recorded(self):
        """Every recorded scaling event carries all expected fields."""
        self._submit_latencies(0.6, 0.8)
        time.sleep(self.scale_up_delay + 1)

        _, events = self._snapshot()
        self.assertTrue(events, msg="expected at least one scaling event")
        for index, event in enumerate(events):
            with self.subTest(event_index=index):
                for key in ('time', 'instances', 'type', 'latency'):
                    self.assertIsNotNone(event[key], msg=key)

if __name__ == "__main__":
    # Allow running this module directly with the standard unittest runner.
    unittest.main()
