import unittest
from unittest.mock import patch
import matplotlib.pyplot as plt
import io
import base64
from .graph_generator import GraphGenerator  # Relative import

class TestGraphGenerator(unittest.TestCase):

    def setUp(self):
        self.graph_generator = GraphGenerator()

    def test_generate_line_graph(self):
        data = {"Metric 1": [1, 2, 3], "Metric 2": [4, 5, 6]}
        title = "Test Line Graph"
        x_label = "X-Axis"
        y_label = "Y-Axis"

        image_data = self.graph_generator.generate_line_graph(data, title, x_label, y_label)

        self.assertTrue(image_data.startswith("data:image/png;base64,"))

        #Attempt to decode the base64 data, if it fails, the image is likely corrupted
        try:
            base64.b64decode(image_data.split(',')[1])
        except Exception as e:
            self.fail(f"Base64 decoding failed: {e}")

    def test_generate_bar_graph(self):
        data = {"Metric 1": 10, "Metric 2": 20}
        title = "Test Bar Graph"
        x_label = "X-Axis"
        y_label = "Y-Axis"

        image_data = self.graph_generator.generate_bar_graph(data, title, x_label, y_label)

        self.assertTrue(image_data.startswith("data:image/png;base64,"))

        #Attempt to decode the base64 data, if it fails, the image is likely corrupted
        try:
            base64.b64decode(image_data.split(',')[1])
        except Exception as e:
            self.fail(f"Base64 decoding failed: {e}")

    #Mocking test to avoid actual plotting
    @patch('matplotlib.pyplot.savefig')
    def test_generate_line_graph_mocked(self, mock_savefig):
        data = {"Metric 1": [1, 2, 3], "Metric 2": [4, 5, 6]}
        title = "Test Line Graph"
        x_label = "X-Axis"
        y_label = "Y-Axis"

        self.graph_generator.generate_line_graph(data, title, x_label, y_label)

        mock_savefig.assert_called_once()

    @patch('matplotlib.pyplot.savefig')
    def test_generate_bar_graph_mocked(self, mock_savefig):
        data = {"Metric 1": 10, "Metric 2": 20}
        title = "Test Bar Graph"
        x_label = "X-Axis"
        y_label = "Y-Axis"

        self.graph_generator.generate_bar_graph(data, title, x_label, y_label)

        mock_savefig.assert_called_once()

if __name__ == '__main__':
    # Run this module's tests through unittest's command-line runner.
    # Because GraphGenerator is pulled in with a relative import
    # (``from .graph_generator import ...``), launch this file as a module,
    # e.g. ``python -m <package>.<this_module>``. Executing it directly as a
    # script path leaves no parent package for that import to resolve against.
    unittest.main()