import importlib.util
import inspect
import logging
import os
import sys
from typing import Any, Callable, List, Optional

import pytest

# Configure logging only when this file is executed as a script: calling
# basicConfig at import time would reconfigure the root logger of whatever
# program merely imports these helpers.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def generate_test_function_name(module_name: str, class_name: str, method_name: str) -> str:
    """
    Build the name of a generated pytest function.

    Dots in the module name become underscores so dotted module paths still
    yield a valid identifier. The class component is left out when
    ``class_name`` is empty (module-level functions).

    Args:
        module_name: Module the tested callable lives in (may be dotted).
        class_name: Owning class, or an empty string for module-level functions.
        method_name: Name of the tested function or method.

    Returns:
        ``test_<module>_<class>_<method>``, or ``test_<module>_<method>`` when
        there is no class.
    """
    flat_module = module_name.replace('.', '_')
    class_part = f"{class_name}_" if class_name else ""
    return f"test_{flat_module}_{class_part}{method_name}"


def is_public_method(method: Callable) -> bool:
    """
    Report whether a callable's name marks it as public.

    Any name with a leading underscore (``_protected``, ``__private`` and
    dunders such as ``__init__``) is treated as non-public.

    Args:
        method: Any object exposing a ``__name__`` attribute.

    Returns:
        False when the name starts with ``_``, True otherwise.
    """
    name = method.__name__
    return name[:1] != '_'


def get_public_methods(module: Any, class_name: Optional[str] = None) -> List[Callable]:
    """
    Gets the public functions defined by a module, or by a class within it.

    Only functions whose ``__module__`` is the inspected module are returned.
    Names merely imported into the module (e.g. ``from os.path import join``)
    and methods inherited from classes defined in other modules therefore do
    not get test stubs.

    Args:
        module: The module to inspect. A class object is also accepted; its
            defining module is then used for the ownership check.
        class_name: Name of a class inside ``module`` to inspect instead of the
            module's top-level functions.

    Returns:
        A list of public functions, ordered by name as ``inspect.getmembers``
        returns them. If ``class_name`` is not an attribute of ``module``, a
        warning is logged and an empty list is returned.
    """
    # The module the returned functions must have been defined in.
    owner = module.__name__ if inspect.ismodule(module) else getattr(module, "__module__", None)

    target = module
    if class_name:
        # Keep the try minimal so only a missing class is reported as such.
        try:
            target = getattr(module, class_name)
        except AttributeError:
            logging.warning(f"Class {class_name} not found in module {module.__name__}")
            return []

    # Ownership is checked first, so imported functions are dropped before the
    # public-name check runs.
    return [
        member
        for _, member in inspect.getmembers(target, predicate=inspect.isfunction)
        if member.__module__ == owner and is_public_method(member)
    ]


def _render_test_stub(test_name: str, target: str) -> str:
    """Return the source text of one placeholder pytest function for ``target``."""
    # The stub passes as generated. The error-case example stays commented out,
    # because an empty ``pytest.raises`` block always fails with "DID NOT RAISE".
    return f'''

def {test_name}():
    """
    Test case for {target}.
    """
    # Happy path
    # TODO: call {target} with typical inputs and assert on the result.
    assert True  # Replace with actual assertions

    # Edge cases
    # TODO: cover boundary inputs (empty values, None, zero, extremes).
    assert True  # Replace with actual assertions

    # Error cases
    # TODO: assert that invalid input raises the expected exception, e.g.
    # with pytest.raises(ValueError):  # Replace ValueError with the expected exception
    #     ...  # Replace with code that should raise an exception
'''


def _load_module_from_path(module_path: str, module_name: str) -> Any:
    """
    Execute the source file at ``module_path`` as a new module named ``module_name``.

    The module is not registered in ``sys.modules``. Loading a file whose stem
    matches an already-imported module (e.g. ``json.py``) therefore cannot
    shadow that module.

    Raises:
        ImportError: if no spec/loader can be created for the path.
        Anything (including ``SystemExit``) raised by the file's top-level code.
    """
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    # NOTE: this runs the file's top-level code -- only use on trusted sources.
    spec.loader.exec_module(module)
    return module


def generate_test_code(module_path: str, module_name: Optional[str] = None) -> str:
    """
    Generates placeholder pytest test code for a given Python module.

    The module is executed in order to inspect it. One stub is emitted for each
    public method of every class defined in the module, and one for each public
    module-level function defined in it.

    Args:
        module_path: The path to the Python module.
        module_name: Import name used in the generated ``from ... import *`` line
            and in the test names. Defaults to the file's stem (``pkg/util.py``
            -> ``util``).

    Returns:
        A string containing the generated test code, or ``""`` if the module
        could not be loaded or inspected (the error is logged).
    """
    try:
        if module_name is None:
            module_name = os.path.splitext(os.path.basename(module_path))[0]
        module = _load_module_from_path(module_path, module_name)

        parts = [
            "\nimport pytest\n"
            f"from {module_name} import *  # Import everything from the module\n"
        ]

        for class_name, cls in inspect.getmembers(module, inspect.isclass):
            # getmembers also yields classes imported into the module; skip those.
            if cls.__module__ != module_name:
                continue
            for method in get_public_methods(module, class_name):
                test_name = generate_test_function_name(module_name, class_name, method.__name__)
                parts.append(_render_test_stub(test_name, f"{class_name}.{method.__name__}"))

        for function in get_public_methods(module):
            test_name = generate_test_function_name(module_name, "", function.__name__)
            parts.append(_render_test_stub(test_name, function.__name__))

        return "".join(parts)

    except (Exception, SystemExit) as e:
        # SystemExit is caught as well: a script calling sys.exit() at import
        # time must not terminate the whole generation run.
        logging.error(f"Error generating test code for {module_path}: {e}")
        return ""


def generate_all_tests(base_path: str):
    """
    Generates and saves test files for all Python modules under the given base path.

    Every ``*.py`` file except ``__init__.py`` produces
    ``<base_path>/tests/test_<name>_test.py``. ``<name>`` is the module's path
    relative to ``base_path`` with separators replaced by underscores
    (``util.py`` -> ``test_util_test.py``, ``pkg/util.py`` ->
    ``test_pkg_util_test.py``), so same-named files in different packages do
    not overwrite each other. The ``tests`` output directory and hidden
    directories (``.git``, ``.venv``, ...) are not scanned. Re-running the
    generator therefore never produces tests for previously generated tests.

    Args:
        base_path: The base directory to search for Python modules.
    """
    tests_dir = os.path.join(base_path, "tests")
    tests_dir_abs = os.path.abspath(tests_dir)

    for root, dirs, files in os.walk(base_path):
        # Prune in place so os.walk does not descend into these directories.
        dirs[:] = [
            d for d in dirs
            if not d.startswith(".")
            and os.path.abspath(os.path.join(root, d)) != tests_dir_abs
        ]
        for file in sorted(files):
            if not file.endswith(".py") or file == "__init__.py":
                continue

            module_path = os.path.join(root, file)
            test_code = generate_test_code(module_path)
            if not test_code:
                continue

            # Build the name from the relative path only. Rewriting ".py" in the
            # full path would corrupt directories whose names contain ".py".
            rel_stem = os.path.splitext(os.path.relpath(module_path, base_path))[0]
            flat_name = rel_stem.replace(os.sep, "_")
            test_file_path = os.path.join(tests_dir, f"test_{flat_name}_test.py")

            try:
                os.makedirs(tests_dir, exist_ok=True)
                with open(test_file_path, "w", encoding="utf-8") as f:
                    f.write(test_code)
                logging.info(f"Test file generated: {test_file_path}")
            except OSError as e:
                logging.error(f"Error writing to file {test_file_path}: {e}")

if __name__ == '__main__':
    # Usage: python <this script> [BASE_PATH]
    # Without an argument, the project path this script was originally written
    # for is used.
    project_base_path = sys.argv[1] if len(sys.argv) > 1 else "/mnt/e/genesis-system"
    if not os.path.isdir(project_base_path):
        # os.walk silently yields nothing for a missing directory, so fail loudly.
        sys.exit(f"Base path is not a directory: {project_base_path}")
    generate_all_tests(project_base_path)
    print("Test generation complete. Check the 'tests' directory.")
