import unittest
from aiva.rate_limiting.threshold_manager import ThresholdManager

class TestThresholdManager(unittest.TestCase):
    """Exercise ``ThresholdManager`` construction and threshold transitions.

    Every valid-case test builds a fresh manager from thresholds
    ``[100, 200, 300]`` and hysteresis percentages ``[5, 5]``.
    """

    def _new_manager(self):
        """Return ``(manager, levels, bands)`` built from the shared inputs."""
        levels = [100, 200, 300]
        bands = [5, 5]
        return ThresholdManager(levels, bands), levels, bands

    def _assert_active_level(self, manager, index, value):
        """Check the manager's current index, then its current threshold."""
        self.assertEqual(manager.current_threshold_index, index)
        self.assertEqual(manager.get_current_threshold(), value)

    def test_initialization(self):
        """A new manager exposes its inputs and starts at index 0 (100)."""
        manager, levels, bands = self._new_manager()
        self.assertEqual(manager.thresholds, levels)
        self.assertEqual(manager.hysteresis_percentages, bands)
        self._assert_active_level(manager, 0, 100)

    def test_transition_to_higher_threshold(self):
        """A reading of 211 moves a fresh manager up to index 1 (200)."""
        manager, _, _ = self._new_manager()
        # 211 sits just above 200 + 200 * 0.05 = 210.
        manager.check_threshold(211)
        self._assert_active_level(manager, 1, 200)

    def test_transition_to_lower_threshold(self):
        """A reading of 189 moves a manager at index 2 down to index 1 (200)."""
        manager, _, _ = self._new_manager()
        # Force the starting point to the last of the three thresholds.
        manager.current_threshold_index = 2
        # 189 sits just below 200 - 200 * 0.05 = 190.
        manager.check_threshold(189)
        self._assert_active_level(manager, 1, 200)

    def test_no_transition(self):
        """A reading of 150 leaves a fresh manager at index 0 (100)."""
        manager, _, _ = self._new_manager()
        manager.check_threshold(150)
        self._assert_active_level(manager, 0, 100)

    def test_invalid_initialization(self):
        """Three thresholds with a single hysteresis percentage raise ValueError."""
        # NOTE(review): presumably one percentage is required per gap between
        # adjacent thresholds -- confirm against ThresholdManager.
        self.assertRaises(ValueError, ThresholdManager, [100, 200, 300], [5])

if __name__ == "__main__":
    # Executed directly as a script: discover and run the tests in this module.
    unittest.main()