import matplotlib.pyplot as plt
import io
import base64

class GraphGenerator:
    """Render simple matplotlib charts as inline PNG data URIs.

    Each public method builds its own figure, draws on that figure's axes
    explicitly (rather than on pyplot's implicit "current" figure), and
    always closes the figure afterwards, even if plotting or saving fails.
    """

    # Figure size in inches (width, height); override in a subclass to change it.
    _FIGURE_SIZE = (10, 6)

    def generate_line_graph(self, data: dict, title: str, x_label: str, y_label: str) -> str:
        """Generates a line graph from the given data.

        Args:
            data (dict): A dictionary where keys are labels and values are lists of data points.
                           Example: {"Metric 1": [1, 2, 3], "Metric 2": [4, 5, 6]}
            title (str): The title of the graph.
            x_label (str): The label for the x-axis.
            y_label (str): The label for the y-axis.

        Returns:
            str: A PNG data URI of the form ``data:image/png;base64,<payload>``,
                 usable directly as the ``src`` of an ``<img>`` tag.
        """
        fig, ax = plt.subplots(figsize=self._FIGURE_SIZE)
        try:
            for label, values in data.items():
                ax.plot(values, label=label)

            self._decorate_axes(ax, title, x_label, y_label)
            # Only build a legend when there is at least one series; asking
            # for a legend with no labelled artists makes matplotlib complain.
            if data:
                ax.legend()

            return self._figure_to_data_uri(fig)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly,
            # so close this one even when plotting or saving raised.
            plt.close(fig)

    def generate_bar_graph(self, data: dict, title: str, x_label: str, y_label: str) -> str:
        """Generates a bar graph from the given data.

        Args:
            data (dict): A dictionary where keys are labels and values are single data points.
                           Example: {"Metric 1": 10, "Metric 2": 20}
            title (str): The title of the graph.
            x_label (str): The label for the x-axis.
            y_label (str): The label for the y-axis.

        Returns:
            str: A PNG data URI of the form ``data:image/png;base64,<payload>``,
                 usable directly as the ``src`` of an ``<img>`` tag.
        """
        labels = list(data.keys())
        values = list(data.values())

        fig, ax = plt.subplots(figsize=self._FIGURE_SIZE)
        try:
            ax.bar(labels, values)
            self._decorate_axes(ax, title, x_label, y_label)
            return self._figure_to_data_uri(fig)
        finally:
            # See generate_line_graph: always release the figure.
            plt.close(fig)

    @staticmethod
    def _decorate_axes(ax, title, x_label, y_label):
        """Apply the title, axis labels and grid shared by every chart type."""
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.grid(True)

    @staticmethod
    def _figure_to_data_uri(fig):
        """Save ``fig`` as PNG into memory and return it as a base64 data URI."""
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            # getvalue() returns an independent bytes copy, so no buffer view
            # is still exported when the ``with`` block closes ``buf``.
            encoded = base64.b64encode(buf.getvalue()).decode('ascii')

        return f'data:image/png;base64,{encoded}'

# Example Usage (This would be in a separate reporting module)
# if __name__ == '__main__':
#     graph_generator = GraphGenerator()
#
#     # Line graph example
#     line_data = {"CPU Usage": [10, 20, 30, 40, 50], "Memory Usage": [5, 15, 25, 35, 45]}
#     line_graph_image = graph_generator.generate_line_graph(line_data, "System Resource Usage", "Time", "Percentage")
#     print(f'<img src="{line_graph_image}" alt="Line Graph">')
#
#     # Bar graph example
#     bar_data = {"Success": 90, "Failure": 10}
#     bar_graph_image = graph_generator.generate_bar_graph(bar_data, "Test Results", "Result", "Count")
#     print(f'<img src="{bar_graph_image}" alt="Bar Graph">')