import unittest
from src.aiva import metrics
from prometheus_client import REGISTRY, CollectorRegistry, generate_latest
import re

class TestMetrics(unittest.TestCase):
    """Tests for the downscale-timer gauges exposed by ``src.aiva.metrics``.

    Each test sets one or more module-level gauges, renders the exposition
    text with ``metrics.generate_metrics_string`` and checks that the
    expected sample line is present.
    """

    def setUp(self):
        # Create a fresh registry for each test to avoid conflicts.
        # NOTE(review): the gauges are module-level objects in ``metrics``;
        # these tests assume ``generate_metrics_string`` renders them for the
        # registry passed in here — confirm against its implementation.
        self.registry = CollectorRegistry(auto_describe=True)

    def _render(self):
        """Return the metrics exposition for ``self.registry`` decoded as UTF-8 text."""
        return metrics.generate_metrics_string(self.registry).decode('utf-8')

    def _assert_sample(self, text, name, value):
        """Assert that ``text`` contains the exact unlabeled sample line ``"<name> <value>"``.

        The match is anchored to a whole line (rather than a substring check)
        so that metrics whose names merely end with ``name`` (e.g.
        ``other_<name>``), or lines with extra trailing content, cannot make
        the assertion pass by accident.

        Args:
            text: Rendered exposition text.
            name: Metric name, matched literally.
            value: Expected value exactly as rendered (e.g. ``'300.0'``).
        """
        pattern = re.compile(
            rf'^{re.escape(name)} {re.escape(value)}$', re.MULTILINE
        )
        self.assertRegex(text, pattern)

    def test_downscale_timer_duration(self):
        metrics.aiva_downscale_timer_duration.set(300)
        self._assert_sample(self._render(), 'aiva_downscale_timer_duration', '300.0')

    def test_downscale_timer_remaining(self):
        metrics.aiva_downscale_timer_remaining.set(150)
        self._assert_sample(self._render(), 'aiva_downscale_timer_remaining', '150.0')

    def test_downscale_timer_active(self):
        # The "active" gauge is toggled on and then off; each state must render.
        metrics.aiva_downscale_timer_active.set(1)
        self._assert_sample(self._render(), 'aiva_downscale_timer_active', '1.0')

        metrics.aiva_downscale_timer_active.set(0)
        self._assert_sample(self._render(), 'aiva_downscale_timer_active', '0.0')

    def test_generate_metrics_string(self):
        metrics.aiva_downscale_timer_duration.set(600)
        metrics.aiva_downscale_timer_remaining.set(300)
        metrics.aiva_downscale_timer_active.set(1)

        # Render once and check that all three gauges appear in the same output.
        metrics_string = self._render()

        self._assert_sample(metrics_string, 'aiva_downscale_timer_duration', '600.0')
        self._assert_sample(metrics_string, 'aiva_downscale_timer_remaining', '300.0')
        self._assert_sample(metrics_string, 'aiva_downscale_timer_active', '1.0')

if __name__ == "__main__":
    # Run this module's tests with the default unittest runner when the
    # file is executed directly as a script.
    unittest.main()
