import os
import os.path as osp
import logging

logger = logging.getLogger(__name__)


def safe_join(base, filename):
    """
    Safely join a base directory and a filename to prevent path traversal.

    The joined path is normalized with ``os.path.abspath``, which collapses
    ``..`` and ``.`` components lexically. It is accepted only if it is the
    base directory itself or lies beneath it. Containment is decided per path
    component with ``os.path.commonpath``, so a sibling that merely shares a
    string prefix (e.g. ``/srv/data_evil`` vs. ``/srv/data``) is rejected.

    Note: symbolic links are NOT resolved. A symlink inside ``base`` that
    points elsewhere is still accepted.

    Args:
        base (str): The base directory. Must be an absolute path.
        filename (str): The filename (or relative path) to join to ``base``.
            An absolute ``filename`` replaces ``base`` during the join (standard
            ``os.path.join`` behavior). It is therefore accepted only if it
            already lies inside ``base``.

    Returns:
        str | None: The absolute, normalized joined path. Returns None if the
        result would fall outside ``base`` or the paths cannot be compared.

    Raises:
        TypeError: if base or filename are not strings
        ValueError: if base is not an absolute path
    """
    if not isinstance(base, str) or not isinstance(filename, str):
        raise TypeError("base and filename must be strings")

    if not osp.isabs(base):
        raise ValueError("base must be an absolute path")

    try:
        # Normalize both paths so ".." segments and trailing separators are resolved.
        base = osp.abspath(base)
        filepath = osp.abspath(osp.join(base, filename))
        # A plain str.startswith() would accept "/base_evil" for base "/base";
        # commonpath compares whole components instead.
        is_inside = osp.commonpath([base, filepath]) == base
    except (ValueError, OSError) as e:
        # ValueError: e.g. paths on different drives (Windows) or an invalid path;
        # OSError: the platform could not normalize the path.
        logger.exception("An error occurred while joining paths: %s", e)
        return None

    if is_inside:
        return filepath

    logger.warning(
        "Path traversal attempt detected. base=%s, filename=%s, resolved path=%s",
        base,
        filename,
        filepath,
    )
    return None  # Path traversal detected


def is_safe_path(base, path_to_check):
    """
    Check whether a path lies within an allowed base directory.

    Both paths are normalized with ``os.path.abspath`` before comparison. A
    relative ``path_to_check`` is therefore resolved against the current
    working directory. Containment is decided per path component with
    ``os.path.commonpath``, so ``/srv/data_evil`` is NOT considered inside
    ``/srv/data``. Symbolic links are not resolved.

    Args:
        base (str): The allowed base directory. Must be an absolute path.
        path_to_check (str): The path to validate.

    Returns:
        bool: True if the path is the base directory or lies beneath it,
        False otherwise (including when the paths cannot be compared).

    Raises:
        TypeError: if base or path_to_check are not strings
        ValueError: if base is not an absolute path
    """
    if not isinstance(base, str) or not isinstance(path_to_check, str):
        raise TypeError("base and path_to_check must be strings")

    if not osp.isabs(base):
        raise ValueError("base must be an absolute path")

    try:
        base = osp.abspath(base)
        path_to_check = osp.abspath(path_to_check)
        # Component-wise comparison; a string-prefix test would accept siblings
        # such as "/base_evil" for base "/base".
        return osp.commonpath([base, path_to_check]) == base
    except (ValueError, OSError) as e:
        # ValueError: e.g. different drives on Windows; OSError: abspath could
        # not determine the current directory for a relative path.
        logger.exception("An error occurred while checking path safety: %s", e)
        return False


if __name__ == '__main__':
    # Example usage (for manual testing only). Replace base_dir with a
    # directory appropriate for your system.
    base_dir = "/tmp/safe_dir"
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(base_dir, exist_ok=True)

    print(f"Base directory: {base_dir}")

    safe_file = "my_file.txt"
    unsafe_file = "../../../etc/passwd"

    safe_path = safe_join(base_dir, safe_file)
    unsafe_path = safe_join(base_dir, unsafe_file)

    print(f"Safe file path: {safe_path}")
    print(f"Unsafe file path: {unsafe_path}")

    # safe_join returns None for rejected input, and is_safe_path raises
    # TypeError for non-string arguments, so only check paths that exist.
    for candidate in (safe_path, unsafe_path):
        if candidate is None:
            print(f"Is {candidate} safe? False (rejected by safe_join)")
        else:
            print(f"Is {candidate} safe? {is_safe_path(base_dir, candidate)}")

    # Example with an absolute path filename
    absolute_file = "/tmp/another_file.txt"
    absolute_path_result = safe_join(base_dir, absolute_file)

    print(f"Absolute path result: {absolute_path_result}")