import unittest
import unittest.mock
import os
import json
from litellm import usage_callback

class TestUsageCallback(unittest.TestCase):
    """Unit tests for ``usage_callback.log_usage``.

    Each test scopes its changes to ``os.environ`` with
    ``unittest.mock.patch.dict``, which snapshots the environment and restores
    it when the test ends. This keeps the ``GENESIS_COST_TRACKER_URL`` value a
    test sets or removes from leaking into other tests in the same process.
    """

    # Endpoint injected through GENESIS_COST_TRACKER_URL by the tests that need one.
    TRACKER_URL = "http://test.com/usage"

    # Keyword arguments passed to log_usage by every test. The success test also
    # expects exactly these key/value pairs to be sent as the JSON payload.
    _USAGE_KWARGS = {
        "model_name": "test_model",
        "prompt_tokens": 10,
        "completion_tokens": 5,
        "cost": 0.001,
    }

    @unittest.mock.patch.dict(os.environ, {"GENESIS_COST_TRACKER_URL": TRACKER_URL})
    @unittest.mock.patch('requests.post')
    def test_log_usage_success(self, mock_post):
        """With a tracker URL configured, the usage record is POSTed as JSON."""
        # Simulate a successful response from the cost tracker.
        mock_post.return_value.status_code = 200
        mock_post.return_value.json.return_value = {"message": "Usage logged"}

        usage_callback.log_usage(**self._USAGE_KWARGS)

        mock_post.assert_called_once_with(self.TRACKER_URL, json=self._USAGE_KWARGS)

    @unittest.mock.patch.dict(os.environ, {"GENESIS_COST_TRACKER_URL": TRACKER_URL})
    @unittest.mock.patch('requests.post')
    def test_log_usage_failure(self, mock_post):
        """An exception raised by requests.post does not propagate out of log_usage."""
        mock_post.side_effect = Exception("Request failed")

        # If log_usage let the exception escape, this call would error the test.
        usage_callback.log_usage(**self._USAGE_KWARGS)

        mock_post.assert_called_once()

    def test_log_usage_no_url(self):
        """log_usage completes without raising when GENESIS_COST_TRACKER_URL is unset."""
        # patch.dict restores os.environ on exit, so removing the variable here
        # does not affect later tests. requests.post is patched so that a
        # fallback request, if log_usage made one, could never reach the network.
        with unittest.mock.patch.dict(os.environ), \
                unittest.mock.patch('requests.post'):
            os.environ.pop("GENESIS_COST_TRACKER_URL", None)
            # Reaching the end of this block without an exception is the check.
            usage_callback.log_usage(**self._USAGE_KWARGS)
        # NOTE(review): if log_usage is meant to skip the POST entirely when no
        # URL is configured, bind the requests.post mock and add
        # assert_not_called() — confirm the intended contract.

if __name__ == "__main__":
    # Running this file directly executes the test cases defined above.
    unittest.main()
