import unittest
import time
from src.aiva.scaler import Scaler
from src.aiva import metrics
from prometheus_client import REGISTRY, CollectorRegistry, generate_latest
import threading

class TestScaler(unittest.TestCase):
    """Tests for ``Scaler``'s downscale timer lifecycle and metric reporting.

    Each test builds a fresh ``Scaler`` with a 2-second downscale duration so
    timer behaviour can be observed quickly. The gauges checked here are
    attributes of the ``metrics`` module, so their values are shared across
    tests; ``tearDown`` stops any timer a test left running so a background
    timer from one test cannot fire while a later test is executing.
    """

    def setUp(self):
        """Create a fresh ``Scaler`` and make sure no timer is running."""
        # NOTE(review): nothing in this file registers collectors into this
        # registry, so it starts empty. test_update_metrics passes it to
        # metrics.generate_metrics_string — confirm that helper does not
        # depend on the argument, or use prometheus_client.REGISTRY instead.
        self.registry = CollectorRegistry(auto_describe=True)
        self.scaler = Scaler(downscale_duration=2)  # Short duration for testing
        # Ensure the timer is not running before each test
        self.scaler.stop_downscale_timer()

    def tearDown(self):
        """Stop any timer a test left running so it cannot fire later."""
        self.scaler.stop_downscale_timer()

    def test_start_stop_downscale_timer(self):
        """Low utilisation starts the timer; high utilisation stops it."""
        self.scaler.check_resource_utilization(0.4)  # Below threshold
        self.assertTrue(self.scaler.is_active)
        self.assertIsNotNone(self.scaler.timer)
        self.assertTrue(self.scaler.timer.is_alive())
        self.assertEqual(metrics.aiva_downscale_timer_active._value.get(), 1)

        self.scaler.check_resource_utilization(0.6)  # Above threshold
        self.assertFalse(self.scaler.is_active)
        timer = self.scaler.timer
        if timer is not None:  # the scaler may have cleared its timer reference
            # If this is a threading.Timer, cancel() only signals the thread;
            # it exits asynchronously, so wait (bounded) before checking.
            timer.join(timeout=1)
            self.assertFalse(timer.is_alive())
        self.assertEqual(metrics.aiva_downscale_timer_active._value.get(), 0)

    def test_downscale_execution(self):
        """The timer invokes ``downscale`` once the duration elapses."""
        # Replace the downscale method so no real downscaling happens.
        downscale_executed = threading.Event()

        def mock_downscale():
            self.scaler.is_active = False
            metrics.aiva_downscale_timer_active.set(0)
            downscale_executed.set()

        self.scaler.downscale = mock_downscale

        self.scaler.check_resource_utilization(0.4)  # Start the timer
        # wait() returns False on timeout, so assert on it directly.
        self.assertTrue(
            downscale_executed.wait(5), "downscale did not run within 5 seconds"
        )
        self.assertFalse(self.scaler.is_active)
        self.assertEqual(metrics.aiva_downscale_timer_active._value.get(), 0)

    def test_get_remaining_time(self):
        """Remaining time counts down while running and is 0 once stopped."""
        self.scaler.check_resource_utilization(0.4)  # Start the timer
        time.sleep(1)  # Wait for some time to pass
        remaining_time = self.scaler.get_remaining_time()
        self.assertGreater(remaining_time, 0)
        self.assertLess(remaining_time, 2)
        self.scaler.stop_downscale_timer()
        self.assertEqual(self.scaler.get_remaining_time(), 0)

    def test_update_metrics(self):
        """``update_metrics`` exposes the remaining-time metric."""
        self.scaler.check_resource_utilization(0.4)  # Start the timer
        time.sleep(1)
        self.scaler.update_metrics()
        metrics_string = metrics.generate_metrics_string(self.registry).decode('utf-8')
        # The remaining-time metric must appear in the exported text.
        self.assertIn('aiva_downscale_timer_remaining', metrics_string)
        # With a 2 s duration and ~1 s elapsed, the timer must still be
        # running, so the remaining time is strictly positive.
        remaining_time = self.scaler.get_remaining_time()
        self.assertGreater(remaining_time, 0)

if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()
