import unittest
import random
import statistics

from aiva.utils.jitter import generate_jitter

class TestJitterGeneration(unittest.TestCase):
    """Range and distribution checks for ``generate_jitter(min_val, max_val)``."""

    def _assert_within(self, value, low, high):
        """Assert ``low <= value <= high``.

        Uses the dedicated comparison assertions so that a failure reports
        the offending value and bound instead of ``False is not true``.
        """
        self.assertGreaterEqual(value, low)
        self.assertLessEqual(value, high)

    def test_jitter_within_range(self):
        """Every sample drawn for a symmetric range stays inside its bounds."""
        min_val = -0.5
        max_val = 0.5
        for _ in range(1000):
            jitter_value = generate_jitter(min_val, max_val)
            self._assert_within(jitter_value, min_val, max_val)

    def test_different_ranges(self):
        """Samples stay inside the bounds for several different ranges."""
        test_ranges = [
            (-1.0, 1.0),
            (0.0, 1.0),
            (-5.0, -2.0),
            (10.0, 15.0),
        ]

        for min_val, max_val in test_ranges:
            # subTest makes a failure report which range was being exercised.
            with self.subTest(min_val=min_val, max_val=max_val):
                for _ in range(1000):
                    jitter_value = generate_jitter(min_val, max_val)
                    self._assert_within(jitter_value, min_val, max_val)

    def test_random_distribution(self):
        """Sample mean and standard deviation look like a uniform draw.

        This test assumes ``generate_jitter`` is (roughly) uniform over the
        requested range; the tolerances are deliberately loose so that
        ordinary sampling noise does not cause failures.
        """
        min_val = 0.0
        max_val = 1.0
        num_samples = 1000
        jitter_values = [generate_jitter(min_val, max_val) for _ in range(num_samples)]

        width = max_val - min_val

        # The mean must be close to the midpoint of the range. Merely checking
        # min_val <= mean <= max_val would be vacuous: it holds automatically
        # whenever every individual sample is in range.
        midpoint = (min_val + max_val) / 2
        mean_val = statistics.mean(jitter_values)
        # 10% of the range width is more than ten standard errors of the mean
        # for 1000 uniform samples, so a false failure is practically
        # impossible while a clearly biased generator is still caught.
        self.assertAlmostEqual(mean_val, midpoint, delta=0.1 * width)

        # The expected standard deviation for a uniform distribution on
        # [a, b] is (b - a) / sqrt(12).
        expected_std = width / (12**0.5)
        std_val = statistics.stdev(jitter_values)
        # Allow +/-50% tolerance around the expected standard deviation.
        self._assert_within(std_val, expected_std * 0.5, expected_std * 1.5)


# When this file is executed directly (rather than imported), hand control to
# unittest's command-line runner.
if __name__ == '__main__':
    unittest.main()