import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from litellm import health_monitor

@pytest.mark.asyncio
async def test_check_model_health_healthy():
    """A completion whose first choice has message content is reported healthy.

    ``litellm.acompletion`` is patched with an ``AsyncMock`` because the code
    under test awaits it. The response it resolves to is a plain data object,
    so it is modelled with ``MagicMock``: an ``AsyncMock`` there would make any
    method call on the response (or its children) return a coroutine.
    """
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "I am healthy."

    with patch('litellm.acompletion', new_callable=AsyncMock) as mock_acompletion:
        mock_acompletion.return_value = mock_response
        result = await health_monitor.check_model_health("test_model")

    # Confirm the result actually came from the patched completion call.
    mock_acompletion.assert_awaited()
    assert result['model'] == "test_model"
    assert result['status'] == "healthy"
    assert result['response'] == "I am healthy."

@pytest.mark.asyncio
async def test_check_model_health_unhealthy():
    """A completion call that raises is reported as an unhealthy model.

    The raised exception's message is expected to appear in the result's
    ``'error'`` entry.
    """
    failure = Exception("Model unavailable")

    # Extra keyword arguments to ``patch`` are forwarded to ``new_callable``,
    # so the mock is created already configured to raise.
    with patch('litellm.acompletion', new_callable=AsyncMock, side_effect=failure):
        outcome = await health_monitor.check_model_health("test_model")

    assert outcome['model'] == "test_model"
    assert outcome['status'] == "unhealthy"
    assert "Model unavailable" in outcome['error']

@pytest.mark.asyncio
async def test_check_model_health_no_response():
    """A completion with an empty ``choices`` list is reported unhealthy.

    The expected error text is exactly ``"No response content"``. The response
    is a plain data object, so it is a ``MagicMock``; only the awaited
    ``litellm.acompletion`` itself is an ``AsyncMock``.
    """
    mock_response = MagicMock()
    mock_response.choices = []

    with patch('litellm.acompletion', new_callable=AsyncMock) as mock_acompletion:
        mock_acompletion.return_value = mock_response
        result = await health_monitor.check_model_health("test_model")

    # Confirm the unhealthy verdict came from inspecting the patched response,
    # not from some earlier failure path.
    mock_acompletion.assert_awaited()
    assert result['model'] == "test_model"
    assert result['status'] == "unhealthy"
    assert result['error'] == "No response content"


@pytest.mark.asyncio
async def test_get_all_model_health():
    """get_all_model_health returns one result per model, in ``model_list`` order.

    ``check_model_health`` is replaced with an ``AsyncMock`` whose
    ``side_effect`` list supplies the per-call results.
    """
    model_list = ["model1", "model2"]
    healthy_result = {"model": "model1", "status": "healthy", "response": "I am healthy."}
    unhealthy_result = {"model": "model2", "status": "unhealthy", "error": "Model failed"}

    # Mock stores an iterable side_effect as an iterator and yields one item per
    # call. Awaiting an AsyncMock returns that item as-is, so the dicts are
    # supplied directly. Wrapping them in asyncio.Future objects would make
    # each awaited result the Future itself, not its value.
    with patch(
        'litellm.health_monitor.check_model_health',
        new_callable=AsyncMock,
        side_effect=[healthy_result, unhealthy_result],
    ) as mock_check_model_health:
        results = await health_monitor.get_all_model_health(model_list)

    assert mock_check_model_health.call_count == len(model_list)
    assert len(results) == 2
    assert results[0] == healthy_result
    assert results[1] == unhealthy_result