# /mnt/e/genesis-system/core/rwl/attacks/path_traversal_tests.py
import os
import pytest

# Base directory that every test below passes to write_to_file(); all writes
# the tests expect to succeed land inside it.
ALLOWED_BASE_DIR = '/mnt/e/genesis-system/core/rwl/output'

# Created (idempotently) at import time so the tests have somewhere to write.
os.makedirs(name=ALLOWED_BASE_DIR, exist_ok=True)


def sanitize_path(base_dir, filename):
    """Resolve *filename* under *base_dir*, rejecting anything that escapes it.

    Both paths are canonicalised with ``os.path.realpath`` before the
    containment check. That resolves ``..`` components and also symlinks, so
    none of the following can break out of *base_dir*:

    * ``../`` sequences,
    * absolute filenames,
    * a symlink inside *base_dir* that points elsewhere.

    Args:
        base_dir: Directory that the resulting path must live inside. It may
            be relative or non-normalised; it is canonicalised here.
        filename: Untrusted, caller-supplied file name relative to *base_dir*.

    Returns:
        The canonical absolute path of the target file.

    Raises:
        ValueError: If the resolved path lies outside *base_dir*. Also raised
            if it resolves to *base_dir* itself (e.g. ``''`` or ``'.'``),
            which does not name a file that can be written.
    """
    base = os.path.realpath(base_dir)
    target = os.path.realpath(os.path.join(base, filename))

    # Compare whole path components, not raw string prefixes. A plain
    # startswith() check would accept '/srv/output_evil' for base
    # '/srv/output'. commonpath() raises ValueError itself when the paths
    # cannot be compared (e.g. different drives on Windows). That is the same
    # exception type callers already treat as a rejection.
    if os.path.commonpath([base, target]) != base:
        raise ValueError("Path traversal detected!")

    # The base directory itself is "inside" by the check above but is not a
    # file; opening it for writing would fail with IsADirectoryError.
    if target == base:
        raise ValueError("Filename does not name a file inside the base directory!")

    return target


def write_to_file(base_dir, filename, content):
    """Write *content* to *filename* inside *base_dir*, refusing unsafe paths.

    The target path is validated with ``sanitize_path``. The file is then
    opened in text mode, which truncates any existing contents.

    Args:
        base_dir: Directory the file must live inside.
        filename: Untrusted file name, relative to *base_dir*.
        content: Text to write, encoded as UTF-8.

    Returns:
        True if the file was written. False if a ``ValueError`` was raised,
        either because ``sanitize_path`` rejected the path, or while opening
        or writing (e.g. an embedded NUL byte in the path, or text that
        cannot be encoded as UTF-8).

    Raises:
        OSError: For I/O failures such as a missing parent directory or
            insufficient permissions. Only ``ValueError`` is caught here.
    """
    try:
        filepath = sanitize_path(base_dir, filename)
        # Explicit encoding so the bytes on disk do not depend on the locale.
        # open() stays inside the try because it can also raise ValueError
        # (embedded NUL), which should count as a rejected path.
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(content)
    except ValueError as e:
        print(f"Error: {e}")
        return False
    return True


# Test Cases

def test_valid_file_path():
    """A plain filename is written inside the base directory."""
    filename = 'safe_file.txt'
    content = 'This is safe content.'
    filepath = os.path.join(ALLOWED_BASE_DIR, filename)
    try:
        assert write_to_file(ALLOWED_BASE_DIR, filename, content) is True
        assert os.path.isfile(filepath)
        with open(filepath, 'r', encoding='utf-8') as f:
            assert f.read() == content
    finally:
        # Clean up even when an assertion above fails, so reruns start fresh.
        if os.path.isfile(filepath):
            os.remove(filepath)


def test_path_traversal_up_one_level():
    """A single '../' escape is rejected, including into a look-alike sibling."""
    content = 'This is unsafe content.'

    filename = '../unsafe_file.txt'
    assert write_to_file(ALLOWED_BASE_DIR, filename, content) is False
    # Nothing may appear where the '..' would have led.
    assert not os.path.exists(os.path.normpath(os.path.join(ALLOWED_BASE_DIR, filename)))

    # A sibling whose name merely starts with the base directory's name (e.g.
    # 'output_evil' next to 'output') must be rejected too. A naive
    # string-prefix containment check lets this through.
    sibling = f"../{os.path.basename(ALLOWED_BASE_DIR)}_evil/unsafe_file.txt"
    assert write_to_file(ALLOWED_BASE_DIR, sibling, content) is False
    assert not os.path.exists(os.path.normpath(os.path.join(ALLOWED_BASE_DIR, sibling)))


def test_path_traversal_multiple_levels_up():
    """Several stacked '../' components are rejected."""
    filename = '../../unsafe_file.txt'
    content = 'This is unsafe content.'
    assert write_to_file(ALLOWED_BASE_DIR, filename, content) is False
    # Nothing may appear where the '..' components would have led.
    assert not os.path.exists(os.path.normpath(os.path.join(ALLOWED_BASE_DIR, filename)))


def test_path_traversal_absolute_path():
    """An absolute filename pointing outside the base directory is rejected."""
    # os.path.join() discards base_dir when handed an absolute component, so
    # this is the classic absolute-path attack. Use a harmless target next to
    # the sandbox rather than a real system file such as /etc/passwd. That
    # file always exists, so "was it created?" cannot be checked. A broken
    # sanitizer running with enough privileges could also clobber it.
    filename = os.path.join(os.path.dirname(ALLOWED_BASE_DIR), 'absolute_unsafe_file.txt')
    content = 'This is unsafe content.'
    assert os.path.isabs(filename)
    assert write_to_file(ALLOWED_BASE_DIR, filename, content) is False
    assert not os.path.exists(filename)


def test_path_traversal_with_normalization():
    """'..' components are judged after normalisation, not by appearance."""
    content = 'This is unsafe content.'

    # Looks nested, but climbs out of the base once 'safe_dir/..' collapses.
    escaping = 'safe_dir/../../unsafe_file.txt'
    assert write_to_file(ALLOWED_BASE_DIR, escaping, content) is False
    assert not os.path.exists(os.path.normpath(os.path.join(ALLOWED_BASE_DIR, escaping)))

    # Normalises to '<base>/unsafe_file.txt', which is still inside the base
    # directory, so this write must be accepted.
    contained = 'safe_dir/../unsafe_file.txt'
    target = os.path.normpath(os.path.join(ALLOWED_BASE_DIR, contained))
    try:
        assert write_to_file(ALLOWED_BASE_DIR, contained, content) is True
        assert os.path.isfile(target)
    finally:
        if os.path.isfile(target):
            os.remove(target)


def test_empty_filename():
    """An empty filename resolves to the base directory itself and is rejected."""
    # os.path.join(base, '') is just the base directory. That cannot be opened
    # for writing, so write_to_file must refuse it instead of crashing with
    # IsADirectoryError.
    assert write_to_file(ALLOWED_BASE_DIR, '', 'This is content.') is False
    # The sandbox directory itself must be left intact.
    assert os.path.isdir(ALLOWED_BASE_DIR)


if __name__ == "__main__":
    # Plain-script entry point: call each test in declaration order. The first
    # failing assertion aborts the run with its traceback. Under pytest these
    # functions are collected automatically instead.
    for test_case in (
        test_valid_file_path,
        test_path_traversal_up_one_level,
        test_path_traversal_multiple_levels_up,
        test_path_traversal_absolute_path,
        test_path_traversal_with_normalization,
        test_empty_filename,
    ):
        test_case()
    print("All tests completed.")
