import redis
import os
import time
import logging
from typing import Optional, List

# Configure logging. The log file defaults to the original hard-coded location;
# set REDIS_OPTIMIZATION_LOG to write elsewhere. basicConfig opens the file
# immediately, so if the target directory does not exist this raises
# FileNotFoundError at import time.
logging.basicConfig(
    filename=os.environ.get(
        "REDIS_OPTIMIZATION_LOG",
        "/mnt/e/genesis-system/redis_optimization_log.txt",
    ),
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)


def connect_redis(host: str, port: int, db: int, password: Optional[str] = None) -> redis.Redis:
    """
    Open a Redis client and verify it with a PING before handing it back.

    Responses are decoded to ``str`` (``decode_responses=True``).

    Args:
        host: The Redis host.
        port: The Redis port.
        db: The Redis database number.
        password: The Redis password (optional).

    Returns:
        A Redis client object that has successfully answered PING.

    Raises:
        redis.exceptions.ConnectionError: If the connection fails (logged, then re-raised).
    """
    try:
        client = redis.Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            decode_responses=True,
        )
        # The constructor is lazy; PING forces an actual round trip.
        client.ping()
    except redis.exceptions.ConnectionError as exc:
        logging.error(f"Failed to connect to Redis: {exc}")
        raise
    logging.info(f"Successfully connected to Redis at {host}:{port}/{db}")
    return client


def analyze_memory_usage(r: redis.Redis) -> dict:
    """
    Fetch the ``memory`` section of Redis INFO and write it to the log.

    Args:
        r: The Redis client object.

    Returns:
        The statistics returned by ``r.info('memory')``, or an empty dict
        if anything goes wrong (the error is logged, not raised).
    """
    try:
        stats = r.info('memory')
        logging.info(f"Redis memory usage: {stats}")
    except Exception as exc:
        logging.error(f"Failed to analyze memory usage: {exc}")
        stats = {}
    return stats


def identify_stale_keys(r: redis.Redis, ttl_threshold: int, batch_size: int = 500) -> List[str]:
    """
    Identifies stale keys based on a TTL threshold.

    A key is reported as stale when it has an expiry set and its remaining
    TTL (in seconds) satisfies ``0 <= ttl < ttl_threshold``. Keys with no
    expiry are logged at DEBUG level and skipped. Keys that disappeared
    between SCAN and TTL are skipped.

    TTL lookups are pipelined in batches of ``batch_size`` keys, so the scan
    costs one round trip per batch instead of one per key.

    Args:
        r: The Redis client object.
        ttl_threshold: The TTL threshold in seconds.
        batch_size: Number of TTL lookups sent per pipeline round trip
            (values below 1 are treated as 1).

    Returns:
        A list of stale keys. If an error occurs, it is logged and the keys
        collected from batches completed before the failure are returned.
    """
    stale_keys: List[str] = []
    step = max(1, batch_size)
    try:
        batch: List[str] = []
        for key in r.scan_iter():
            batch.append(key)
            if len(batch) >= step:
                _collect_stale_keys(r, batch, ttl_threshold, stale_keys)
                batch = []
        if batch:
            _collect_stale_keys(r, batch, ttl_threshold, stale_keys)
    except Exception as e:
        logging.error(f"Failed to identify stale keys: {e}")
    return stale_keys


def _collect_stale_keys(r: redis.Redis, keys: List[str], ttl_threshold: int,
                        stale_keys: List[str]) -> None:
    """Fetch TTLs for *keys* in one pipeline and append the stale ones to *stale_keys*."""
    # transaction=False: plain pipelining, no MULTI/EXEC wrapper needed for reads.
    with r.pipeline(transaction=False) as pipe:
        for key in keys:
            pipe.ttl(key)
        ttls = pipe.execute()
    for key, ttl in zip(keys, ttls):
        # Redis TTL returns -1 for "no expiry" and -2 for "key does not exist".
        # None is also treated as "no expiry" in case an older redis-py
        # client maps the reply to None.
        if ttl is None or ttl == -1:
            logging.debug("Key %s has no expiry set.", key)
        elif 0 <= ttl < ttl_threshold:
            stale_keys.append(key)
            logging.debug("Key %s has TTL %s and is considered stale.", key, ttl)


def remove_stale_keys(r: redis.Redis, keys: List[str], batch_size: int = 500) -> None:
    """
    Removes stale keys from Redis.

    Keys are deleted with DEL in chunks of at most ``batch_size`` keys. This
    avoids sending one arbitrarily large DEL command when the key list is big.

    Args:
        r: The Redis client object.
        keys: The keys to remove (any iterable; it is materialized once).
        batch_size: Maximum number of keys per DEL call (values below 1 are
            treated as 1).

    Errors are logged, not raised. If a chunk fails, earlier chunks have
    already been deleted, and that partial count is logged too.
    """
    # NOTE(review): a key's TTL may have been refreshed after it was flagged
    # as stale; nothing here re-checks the TTL before deleting.
    key_list = list(keys)
    if not key_list:
        logging.info("No stale keys found.")
        return

    step = max(1, batch_size)
    deleted_count = 0
    try:
        for start in range(0, len(key_list), step):
            deleted_count += r.delete(*key_list[start:start + step])
    except Exception as e:
        logging.error(f"Failed to remove stale keys: {e}")
        if deleted_count:
            logging.warning(f"Deleted {deleted_count} stale keys before the failure.")
        return
    logging.info(f"Deleted {deleted_count} stale keys.")


def main() -> int:
    """
    Main function to orchestrate Redis memory optimization.

    Reads REDIS_HOST, REDIS_PORT, REDIS_DB, REDIS_PASSWORD and TTL_THRESHOLD
    from the environment. It then connects, logs memory statistics, and
    removes keys whose remaining TTL is below the threshold.

    Returns:
        0 on success, or 1 if a step raised an exception. In that case the
        error and its traceback are logged. Malformed numeric environment
        values still raise ValueError before any work starts.
    """
    redis_host = os.environ.get("REDIS_HOST", "localhost")
    redis_port = int(os.environ.get("REDIS_PORT", 6379))
    redis_db = int(os.environ.get("REDIS_DB", 0))
    redis_password = os.environ.get("REDIS_PASSWORD")
    ttl_threshold = int(os.environ.get("TTL_THRESHOLD", 3600))  # 1 hour default

    try:
        r = connect_redis(redis_host, redis_port, redis_db, redis_password)
        analyze_memory_usage(r)
        stale_keys = identify_stale_keys(r, ttl_threshold)
        remove_stale_keys(r, stale_keys)
    except Exception as e:
        # logging.exception also records the traceback, which plain
        # logging.error would drop.
        logging.exception(f"An error occurred: {e}")
        return 1
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status so schedulers
    # (cron, CI) can detect a failed run; SystemExit(None) exits with 0.
    raise SystemExit(main())