import logging
import re

# Configure the root logger when this module is imported: INFO level and above,
# with each record rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class ContextRotDetector:
    """
    Utility to detect context rot in agent sessions based on iteration count.

    The iteration count is either read from a session log (lines containing
    "Iteration: <number>") or, when no log path is configured, taken from the
    ``iteration_count`` attribute as it currently stands.
    """

    # Matches "Iteration: <digits>" anywhere in a line, so prefixes that contain
    # their own colons (e.g. "12:00:00 - INFO - Iteration: 42") parse correctly.
    _ITERATION_PATTERN = re.compile(r"Iteration:\s*(\d+)\b")

    def __init__(self, session_log_path=None, iteration_threshold: int = 30):
        """
        Initializes the ContextRotDetector.

        Args:
            session_log_path (str, optional): Path to the session log file. If
                None (or empty), no file is read and analyze_session() judges
                the current value of ``iteration_count``. Defaults to None.
            iteration_threshold (int, optional): Iteration count above which
                rot risk is considered high. Defaults to 30.
        """
        self.session_log_path = session_log_path
        self.iteration_threshold = iteration_threshold
        self.iteration_count = 0

    def analyze_session(self) -> str:
        """
        Analyzes the agent session to detect context rot.

        When a log path is configured, ``iteration_count`` is refreshed from
        the log before it is compared against ``iteration_threshold``.

        Returns:
            str: One of
                "ROT_RISK: HIGH" if the iteration count is strictly greater
                    than the threshold,
                "ROT_RISK: LOW" otherwise (including a fresh session),
                "ROT_RISK: ERROR - Log file not found" if the log is missing,
                "ROT_RISK: ERROR - Log processing failed" for any other
                    failure while reading the log.
        """
        if self.session_log_path:
            try:
                self.iteration_count = self._extract_iteration_count_from_log()
            except FileNotFoundError:
                logging.error("Session log file not found: %s", self.session_log_path)
                return "ROT_RISK: ERROR - Log file not found"
            except Exception as e:
                # Boundary handler: report the failure as a status string
                # instead of letting it propagate to the caller.
                logging.error("Error processing session log: %s", e)
                return "ROT_RISK: ERROR - Log processing failed"

        if self.iteration_count > self.iteration_threshold:
            logging.warning(
                "Iteration count is %s, exceeding threshold of %s. Context rot risk: HIGH",
                self.iteration_count,
                self.iteration_threshold,
            )
            return "ROT_RISK: HIGH"

        logging.info(
            "Iteration count is %s, below threshold of %s. Context rot risk: LOW",
            self.iteration_count,
            self.iteration_threshold,
        )
        return "ROT_RISK: LOW"

    def _extract_iteration_count_from_log(self) -> int:
        """
        Extracts the most recent iteration count from the session log file.

        Every line containing "Iteration:" is inspected. The value from the
        last line that parses is returned, because the count grows as the
        session proceeds and the latest entry reflects the current state.
        Lines that mention "Iteration:" without a parsable non-negative
        integer are logged and skipped.

        Returns:
            int: The latest iteration count, or 0 if the log holds none.

        Raises:
            FileNotFoundError: If the session log file is not found.
            Exception: If there are other issues reading the log.
        """
        latest_count = None
        try:
            # errors="replace" keeps a stray undecodable byte in the log from
            # aborting the whole analysis.
            with open(self.session_log_path, 'r', encoding='utf-8', errors='replace') as log_file:
                for line in log_file:
                    if "Iteration:" not in line:
                        continue
                    match = self._ITERATION_PATTERN.search(line)
                    if match is None:
                        logging.warning("Could not parse iteration count from line: %s", line)
                        continue  # Try the next line
                    latest_count = int(match.group(1))
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Session log file not found: {self.session_log_path}") from err
        except Exception as err:
            raise Exception(f"Error reading session log: {err}") from err

        if latest_count is None:
            logging.info("Iteration count not found in log, assuming fresh session.")
            return 0  # Assume fresh session if iteration count not found
        return latest_count


if __name__ == '__main__':
    # Demo-only dependencies, imported here so that importing this module as a
    # library does not pull them in.
    import os
    import tempfile

    # Example usage with a dummy session log.
    # The log is written BEFORE it is analysed (otherwise the detector can only
    # report a missing file), and it lives in a temporary directory so the demo
    # leaves nothing behind in the working directory.
    # To analyse a real session, pass its log path as session_log_path instead.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dummy_log_path = os.path.join(tmp_dir, 'dummy_session.log')

        # Create a dummy log file for testing
        with open(dummy_log_path, 'w', encoding='utf-8') as f:
            f.write("Some log entries...\n")
            f.write("Iteration: 15\n")
            f.write("More log entries...\n")
            f.write("Iteration: 35\n")  # Later entry, above the default threshold of 30
            f.write("Even more log entries...\n")

        detector = ContextRotDetector(session_log_path=dummy_log_path)
        rot_risk = detector.analyze_session()
        print(f"Context Rot Risk: {rot_risk}")