import unittest
from unittest.mock import patch
import time

from rwl.metrics.collector import MetricsCollector
from rwl.metrics.aggregator import MetricsAggregator

class TestMetricsCollector(unittest.TestCase):
    """Tests for ``MetricsCollector`` and its hand-off to ``MetricsAggregator``."""

    def setUp(self):
        # Fresh instances per test so records from one test never leak into another.
        self.collector = MetricsCollector()
        self.aggregator = MetricsAggregator()

    def _single_record(self, name):
        """Assert exactly one record was collected for *name* and return it."""
        metrics = self.collector.get_metrics()
        self.assertIn(name, metrics)
        self.assertEqual(len(metrics[name]), 1)
        return metrics[name][0]

    def test_collect_success(self):
        """A successful call returns its value and records a success entry."""
        def my_function(a, b):
            return a + b

        result = self.collector.collect(my_function, 1, 2)
        self.assertEqual(result, 3)

        record = self._single_record('my_function')
        self.assertTrue(record['success'])
        self.assertEqual(record['retry_count'], 0)
        self.assertIsInstance(record['execution_time'], float)
        # The recorded result is expected in its string form, not the raw value.
        self.assertEqual(record['result'], '3')
        self.assertIsNone(record['exception'])

    def test_collect_failure(self):
        """A raising call propagates its exception and records a failure entry."""
        def my_function():
            raise ValueError("Something went wrong")

        with self.assertRaises(ValueError):
            self.collector.collect(my_function)

        record = self._single_record('my_function')
        self.assertFalse(record['success'])
        self.assertEqual(record['retry_count'], 0)
        self.assertIsInstance(record['execution_time'], float)
        self.assertIn('Something went wrong', str(record['exception']))

    def test_update_retry_count(self):
        """``update_retry_count`` overwrites the retry count of the collected record."""
        def my_function():
            return "Hello"

        self.collector.collect(my_function)
        self.collector.update_retry_count('my_function', 3)

        metrics = self.collector.get_metrics()
        self.assertEqual(metrics['my_function'][0]['retry_count'], 3)

    def test_add_quality_score(self):
        """``add_quality_score`` attaches a quality score to the collected record."""
        def my_function():
            return "World"

        self.collector.collect(my_function)
        self.collector.add_quality_score('my_function', 0.95)

        metrics = self.collector.get_metrics()
        self.assertEqual(metrics['my_function'][0]['quality_score'], 0.95)

    def test_aggregate_metrics(self):
        """The aggregator summarises per-function call, success, retry and score data."""
        def func1():
            return 1

        def func2():
            raise ValueError

        self.collector.collect(func1)
        # collect() re-raises the wrapped function's exception (as asserted in
        # test_collect_failure), so the failing call must be guarded; otherwise
        # the test errors out here before anything is aggregated.
        with self.assertRaises(ValueError):
            self.collector.collect(func2)
        self.collector.update_retry_count('func2', 2)
        self.collector.add_quality_score('func1', 0.8)

        aggregated_metrics = self.aggregator.aggregate(self.collector.get_metrics())

        self.assertIn('func1', aggregated_metrics)
        self.assertIn('func2', aggregated_metrics)

        func1_stats = aggregated_metrics['func1']
        self.assertEqual(func1_stats['num_calls'], 1)
        self.assertEqual(func1_stats['success_count'], 1)
        self.assertEqual(func1_stats['failure_count'], 0)
        self.assertEqual(func1_stats['total_retries'], 0)
        # The average is a computed float; avoid exact float equality.
        self.assertAlmostEqual(func1_stats['average_quality_score'], 0.8)

        func2_stats = aggregated_metrics['func2']
        self.assertEqual(func2_stats['num_calls'], 1)
        self.assertEqual(func2_stats['success_count'], 0)
        self.assertEqual(func2_stats['failure_count'], 1)
        self.assertEqual(func2_stats['total_retries'], 2)
        # No quality score was added for func2, so no average should be reported.
        self.assertIsNone(func2_stats.get('average_quality_score'))

if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()