import os
import time
import psutil
import logging
import argparse
from typing import Optional, Tuple

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class MemoryDetector:
    """
    Detects memory leaks in a given process.

    The resident set size (RSS) of the process is recorded as a baseline when the
    detector is created; every call to ``detect_leak`` takes a new sample and
    compares it with that baseline. Growth above ``threshold`` is reported as a
    possible leak.
    """

    def __init__(self, pid: int, interval: int = 60, threshold: float = 0.1) -> None:
        """
        Initializes the MemoryDetector.

        Args:
            pid: The process ID to monitor.
            interval: The sampling interval in seconds, used by ``run``.
            threshold: The relative increase in memory usage over the baseline
                that triggers an alert, expressed as a fraction
                (e.g. ``0.1`` means a 10% increase).

        Raises:
            ValueError: If the process does not exist, or exits before the
                baseline sample can be taken.
        """
        self.pid = pid
        self.interval = interval
        self.threshold = threshold
        self.process = self._get_process()
        # Baseline RSS in MB. This is 0.0 if the first reading failed for a reason
        # other than the process being gone; detect_leak() then adopts the first
        # valid reading as the baseline.
        self.initial_memory = self._get_memory_usage()
        # Most recent valid RSS reading in MB.
        self.last_memory = self.initial_memory
        self.start_time = time.time()

    def _get_process(self) -> psutil.Process:
        """
        Gets the process object from the PID.

        Returns:
            The psutil.Process object.

        Raises:
            ValueError: If the process does not exist.
        """
        try:
            return psutil.Process(self.pid)
        except psutil.NoSuchProcess as err:
            logging.error(f"Process with PID {self.pid} not found.")
            raise ValueError(f"Process with PID {self.pid} not found.") from err
        except Exception as e:
            logging.error(f"Error getting process: {e}")
            raise

    def _get_memory_usage(self) -> float:
        """
        Gets the current memory usage (RSS) of the process in MB.

        Returns:
            The memory usage in MB, or 0.0 if it could not be read for a reason
            other than the process having exited (the error is logged).

        Raises:
            ValueError: If the process no longer exists.
        """
        try:
            memory_info = self.process.memory_info()
        except psutil.NoSuchProcess as err:
            logging.error(f"Process with PID {self.pid} no longer exists.")
            raise ValueError(f"Process with PID {self.pid} no longer exists.") from err
        except Exception as e:
            logging.error(f"Error getting memory usage: {e}")
            return 0.0
        return memory_info.rss / (1024 * 1024)  # Convert bytes to MB

    def detect_leak(self) -> Optional[Tuple[float, float]]:
        """
        Takes one memory sample and compares it with the baseline.

        Returns:
            A tuple ``(current_memory_mb, percentage_increase)`` when the growth
            since the baseline exceeds ``threshold``; otherwise None. None is also
            returned when no usable sample could be taken (process gone, reading
            unavailable) or when this sample became the new baseline.
        """
        try:
            current_memory = self._get_memory_usage()
        except ValueError as e:
            logging.error(f"Error during leak detection: {e}")
            return None

        if current_memory <= 0:
            # _get_memory_usage() returns 0.0 when the reading failed; treating it
            # as a real sample would report a bogus -100% change.
            logging.warning(f"Skipping memory sample for PID {self.pid}: reading unavailable.")
            return None

        self.last_memory = current_memory

        if self.initial_memory <= 0:
            # The baseline reading failed at startup; use this sample as the
            # baseline instead of dividing by zero.
            self.initial_memory = current_memory
            logging.info(f"Baseline memory usage set to {current_memory:.2f} MB")
            return None

        increase = (current_memory - self.initial_memory) / self.initial_memory
        percentage_increase = increase * 100

        logging.info(f"Current memory usage: {current_memory:.2f} MB, Increase: {percentage_increase:.2f}%")

        # threshold is a fraction (0.1 == 10%), so compare it with the fractional
        # increase rather than the percentage.
        if increase > self.threshold:
            logging.warning(f"Memory growth exceeds threshold of {self.threshold*100:.2f}%! Current increase: {percentage_increase:.2f}%")
            return current_memory, percentage_increase

        return None

    def run(self) -> None:
        """
        Samples memory every ``interval`` seconds until interrupted or until the
        monitored process exits.
        """
        logging.info(f"Starting memory leak detection for PID {self.pid}...")
        try:
            while True:
                leak_detected = self.detect_leak()
                if leak_detected:
                    current_memory, percentage_increase = leak_detected
                    logging.critical(f"POSSIBLE MEMORY LEAK DETECTED! Current memory: {current_memory:.2f} MB, Increase: {percentage_increase:.2f}%")
                elif not self.process.is_running():
                    # Once the target has exited every further sample fails, so
                    # stop instead of logging the same error forever.
                    logging.info(f"Process with PID {self.pid} has exited; stopping memory leak detection.")
                    break

                time.sleep(self.interval)
        except KeyboardInterrupt:
            logging.info("Memory leak detection stopped.")
        except Exception as e:
            # logging.exception also records the traceback for diagnosis.
            logging.exception(f"An unexpected error occurred: {e}")


def main(argv: Optional[list] = None) -> int:
    """
    Parse command-line arguments and run the memory leak detector.

    Args:
        argv: Argument list to parse, excluding the program name. When None,
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        The process exit status: 0 when monitoring finished normally, 1 when the
        detector could not be started. Invalid arguments exit with status 2 via
        ``argparse``.
    """
    parser = argparse.ArgumentParser(description="Detect memory leaks in a process.")
    parser.add_argument("pid", type=int, help="The process ID to monitor.")
    parser.add_argument("--interval", type=int, default=60, help="The sampling interval in seconds (default: 60).")
    parser.add_argument("--threshold", type=float, default=0.1, help="The percentage increase threshold (default: 0.1).")

    args = parser.parse_args(argv)

    # An interval of 0 busy-loops and a negative one makes time.sleep() raise
    # inside the detector loop, so reject both before starting.
    if args.interval <= 0:
        parser.error("--interval must be a positive number of seconds.")
    if args.threshold < 0:
        parser.error("--threshold must not be negative.")

    try:
        detector = MemoryDetector(args.pid, args.interval, args.threshold)
        detector.run()
    except ValueError as e:
        logging.error(e)
        return 1
    except Exception as e:
        logging.exception(f"An error occurred: {e}")
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
