import redis
import json
import threading
import time
import logging
import atexit

# Configure the root logger: INFO level and above, each line prefixed with a
# timestamp and the level name.
# NOTE(review): this runs when the module is imported, so it configures
# logging for any program that imports this module; consider moving it under
# the __main__ guard.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


class SessionPersister:
    """
    Persists session state to Redis for RWL system and restores it on restart.

    Tracked state: active_prds, completed_stories, pending_stories, total_cost.
    After start(), a daemon thread writes a JSON snapshot of that state to a
    single Redis key every ``persist_interval`` seconds. stop() (also
    registered with atexit by start()) halts the thread promptly and writes
    one final snapshot, enabling graceful shutdown/resume.

    Access to the four state fields is guarded by an internal lock so the
    background thread always serializes a consistent snapshot.
    """

    def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0, persist_interval=60,
                 redis_client=None, session_key='genesis:session_state'):
        """
        Initializes the SessionPersister.

        Args:
            redis_host (str): The Redis host.
            redis_port (int): The Redis port.
            redis_db (int): The Redis database number.
            persist_interval (int): The interval in seconds to persist the session state.
            redis_client: Optional pre-configured client exposing ``get``/``set``.
                When None (default), a ``redis.Redis`` is built from
                host/port/db as before.
            session_key (str): Redis key the state is stored under.
        """
        self.redis_host = redis_host
        self.redis_port = redis_port
        self.redis_db = redis_db
        self.persist_interval = persist_interval
        if redis_client is None:
            redis_client = redis.Redis(host=self.redis_host, port=self.redis_port, db=self.redis_db)
        self.redis_client = redis_client
        self.session_key = session_key
        self.active_prds = []
        self.completed_stories = []
        self.pending_stories = []
        self.total_cost = 0.0
        self.persistence_thread = None
        self.stop_event = threading.Event()
        # Serializes access to the state fields between caller threads and the
        # persistence thread.
        self._lock = threading.Lock()
        # Set by the first stop() so repeated calls (manual + atexit) do not
        # write the state again; reset by start().
        self._stopped = False

    def start(self):
        """
        Loads any saved state and starts the background persistence thread.

        Calling start() while the thread is already running logs a warning and
        does nothing. start() may be called again after stop().
        """
        if self.persistence_thread is not None and self.persistence_thread.is_alive():
            logging.warning("Session persistence already running.")
            return
        self._load_state()
        # Clear a stop request left over from a previous stop(); otherwise the
        # new thread would exit immediately.
        self.stop_event.clear()
        self._stopped = False
        self.persistence_thread = threading.Thread(target=self._persist_state_loop, daemon=True)
        self.persistence_thread.start()
        logging.info("Session persistence started.")
        atexit.register(self.stop)  # Ensure stop is called on exit

    def stop(self):
        """
        Stops the background thread and saves the state one last time.

        Idempotent: only the first call after start() (or after construction)
        has any effect. The atexit hook registered by start() is removed so a
        stopped instance cannot overwrite newer state at interpreter exit.
        """
        if self._stopped:
            return
        self._stopped = True
        atexit.unregister(self.stop)
        self.stop_event.set()
        if self.persistence_thread and self.persistence_thread.is_alive():
            self.persistence_thread.join()
        self._save_state()  # Save one last time before exiting
        logging.info("Session persistence stopped.")

    def _persist_state_loop(self):
        """
        Body of the persistence thread: save, then wait, until stop is requested.
        """
        while not self.stop_event.is_set():
            try:
                self._save_state()
            except Exception as e:
                logging.exception("Error during persistence: %s", e)
            # Event.wait returns as soon as stop() sets the event, so shutdown
            # is not delayed by up to persist_interval seconds. It also runs
            # after a failed save, which avoids a tight error loop.
            self.stop_event.wait(self.persist_interval)

    def _snapshot(self):
        """Build the state dict from the current fields. Caller must hold self._lock."""
        return {
            'active_prds': self.active_prds,
            'completed_stories': self.completed_stories,
            'pending_stories': self.pending_stories,
            'total_cost': self.total_cost
        }

    def _save_state(self):
        """
        Saves the current session state to Redis as JSON.

        Redis errors are logged, not raised. A json.dumps failure (for example,
        a non-serializable story) propagates to the caller.
        """
        # Serialize while holding the lock so the snapshot is consistent.
        with self._lock:
            payload = json.dumps(self._snapshot())
        try:
            self.redis_client.set(self.session_key, payload)
            logging.debug("Session state saved to Redis.")
        except redis.exceptions.RedisError as e:
            logging.error("Error saving state to Redis: %s", e)

    def _load_state(self):
        """
        Loads the session state from Redis, keeping the current values on any failure.

        Redis errors, undecodable/invalid JSON, and a payload that is not a
        JSON object are logged and leave the in-memory state untouched.
        """
        try:
            state_json = self.redis_client.get(self.session_key)
        except redis.exceptions.RedisError as e:
            logging.error("Error loading state from Redis: %s", e)
            return
        if not state_json:
            logging.info("No existing session state found in Redis.")
            return
        try:
            # json.loads accepts bytes or str, so this works whether or not the
            # client decodes responses.
            state = json.loads(state_json)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            logging.error("Error decoding JSON from Redis: %s", e)
            return
        if not isinstance(state, dict):
            logging.error("Ignoring session state in Redis: expected a JSON object, got %s",
                          type(state).__name__)
            return
        with self._lock:
            self.active_prds = state.get('active_prds', [])
            self.completed_stories = state.get('completed_stories', [])
            self.pending_stories = state.get('pending_stories', [])
            self.total_cost = state.get('total_cost', 0.0)
        logging.info("Session state loaded from Redis.")

    def update_active_prds(self, active_prds):
        """
        Replaces the active PRDs.

        Args:
            active_prds (list): A list of active PRDs.
        """
        with self._lock:
            self.active_prds = active_prds

    def add_completed_story(self, story):
        """
        Adds a completed story.

        Args:
            story (str): The completed story.
        """
        with self._lock:
            self.completed_stories.append(story)

    def add_pending_story(self, story):
        """
        Adds a pending story.

        Args:
            story (str): The pending story.
        """
        with self._lock:
            self.pending_stories.append(story)

    def update_total_cost(self, total_cost):
        """
        Replaces the total cost.

        Args:
            total_cost (float): The total cost.
        """
        with self._lock:
            self.total_cost = total_cost

    def get_state(self):
        """
        Returns the current session state.

        The lists in the returned dict are the live internal lists, not copies.
        Mutate them only through the add_*/update_* methods.

        Returns:
            dict: Keys active_prds, completed_stories, pending_stories, total_cost.
        """
        with self._lock:
            return self._snapshot()


if __name__ == '__main__':
    # Example usage. A short persist interval lets the background thread
    # actually write several snapshots during the demo; with the default of
    # 60 seconds no periodic save would happen after the updates below.
    persister = SessionPersister(persist_interval=1)
    persister.start()

    # Simulate some activity
    persister.update_active_prds(['PRD-001', 'PRD-002'])
    persister.add_completed_story('US-001')
    persister.add_pending_story('US-002')
    persister.update_total_cost(123.45)

    time.sleep(5)  # Let the persister save the state a few times

    # Simulate a restart: stop() writes a final snapshot, and the new
    # instance reloads it from Redis in start().
    persister.stop()
    del persister

    persister = SessionPersister(persist_interval=1)
    persister.start()

    print("Loaded state after restart:", persister.get_state())

    time.sleep(3)
    persister.stop()