import re
import os
import logging
import psycopg2
import redis
from typing import List, Tuple

# Configure root logging at import time: INFO and above, with a timestamp,
# level name and message on each line.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class SecurityScanner:
    """
    A security scanner to identify potential vulnerabilities in the codebase.

    Each check walks ``codebase_path`` recursively, reads files whose names end
    with the check's extensions, and appends findings to ``self.vulnerabilities``
    as ``(file_path, vulnerability_type, description)`` tuples. All checks are
    regex/keyword heuristics: expect both false positives and false negatives.
    """

    # Variable names whose assignment from a string literal is reported as a
    # hardcoded secret. Matching is case-insensitive and not anchored on the
    # left, so e.g. ``db_password = "..."`` is caught as well.
    _SECRET_NAMES = (
        "password",
        "api_key",
        "secret_key",
        "access_token",
        "DATABASE_URL",
        "REDIS_URL",
    )

    # ``name = 'value'`` or ``name = "value"`` on a single line. The closing
    # quote must match the opening one, so a value such as "it's" is captured
    # whole instead of being cut at the apostrophe.
    _SECRET_PATTERNS = [
        re.compile(
            rf"(?P<key>{name})\s*=\s*(?P<quote>['\"])(?P<value>.*?)(?P=quote)",
            re.IGNORECASE,
        )
        for name in _SECRET_NAMES
    ]

    _SQL_PATTERNS = [
        re.compile(r"cursor\.execute\(.*?\+.*?\)"),  # Basic string concatenation
        re.compile(r"cursor\.execute\(f['\"].*?\{\s*.*?['\"]"),  # f-string interpolation
        re.compile(r"subprocess\.run\(.*?sql.*?\)"),  # subprocess execution with SQL
    ]

    _XSS_PATTERNS = [
        re.compile(r"render_template_string\(.*?\+.*?\)"),  # Jinja2 template string concatenation
        re.compile(r"unescape\(.*?\)"),  # Unescape functions
        # Assigning or appending to innerHTML. The (?!=) lookahead skips
        # comparisons such as ``innerHTML == x`` / ``innerHTML === x``.
        re.compile(r"innerHTML\s*\+?=(?!=)"),
    ]

    # Any of these substrings counts as evidence of sanitization. "escape" must
    # not be preceded by "un"/"Un": unescape() reverses escaping and must not
    # be mistaken for sanitization.
    _SANITIZATION_PATTERN = re.compile(r"(?<![Uu]n)escape|sanitize|strip|filter")

    _AUTH_KEYWORDS = ("login_required", "is_authenticated", "has_permission")

    def __init__(self, codebase_path: str):
        """
        Initializes the SecurityScanner with the path to the codebase.

        Args:
            codebase_path (str): The path to the root directory of the codebase.
        """
        self.codebase_path = codebase_path
        self.vulnerabilities: List[Tuple[str, str, str]] = []  # (file_path, vulnerability_type, description)

    def scan(self) -> List[Tuple[str, str, str]]:
        """
        Runs every check against the codebase and returns the findings.

        Findings from any previous ``scan()`` are discarded first, so calling
        ``scan()`` repeatedly does not accumulate duplicates.

        Returns:
            List[Tuple[str, str, str]]: A list of tuples, where each tuple represents a vulnerability
                                        and contains the file path, vulnerability type, and description.
        """
        logging.info(f"Starting security scan of codebase: {self.codebase_path}")
        # Rebind rather than clear(): a list returned by an earlier scan stays intact.
        self.vulnerabilities = []
        self.scan_for_hardcoded_secrets()
        self.scan_for_sql_injection()
        self.scan_for_xss_patterns()
        self.validate_input_sanitization()
        self.validate_auth_checks()
        logging.info("Security scan completed.")
        return self.vulnerabilities

    def _read_file(self, file_path: str) -> str:
        """
        Reads the content of a file as UTF-8 text.

        Undecodable bytes are replaced with U+FFFD rather than aborting, so a file
        with a few non-UTF-8 bytes is still scanned.

        Args:
            file_path (str): The path to the file.

        Returns:
            str: The content of the file. Returns an empty string if the file can't be opened or read.
        """
        try:
            with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                return f.read()
        except OSError as e:
            logging.error(f"Error reading file {file_path}: {e}")
            return ""

    def _iter_files(self, extensions: Tuple[str, ...]):
        """
        Yields ``(file_path, content)`` for each file under ``codebase_path``
        whose name ends with one of ``extensions``.

        Directories and files are visited in sorted order so reports are
        deterministic. Errors raised while listing a directory (including a
        missing root) are logged instead of being silently ignored.
        """
        def _log_walk_error(err: OSError) -> None:
            logging.error(f"Error walking {err.filename}: {err}")

        for root, dirs, files in os.walk(self.codebase_path, onerror=_log_walk_error):
            dirs.sort()  # in place, so os.walk descends in sorted order
            for name in sorted(files):
                if name.endswith(extensions):
                    file_path = os.path.join(root, name)
                    yield file_path, self._read_file(file_path)

    def _record(self, file_path: str, vulnerability_type: str, description: str, log_message: str) -> None:
        """Appends one finding to ``self.vulnerabilities`` and logs it as a warning."""
        self.vulnerabilities.append((file_path, vulnerability_type, description))
        logging.warning(log_message)

    @staticmethod
    def _redact(value: str) -> str:
        """
        Masks a secret for reporting: keeps the first two characters of values
        longer than four characters and hides everything else.
        """
        prefix = value[:2] if len(value) > 4 else ""
        return f"{prefix}****"

    def scan_for_hardcoded_secrets(self):
        """
        Scans .py/.js/.env/.txt files for string literals assigned to
        secret-looking names (passwords, API keys, tokens, connection URLs).

        Secret values are redacted in both the report and the log so the scan
        does not itself leak them. Empty values (``password = ""``) are skipped.
        """
        logging.info("Scanning for hardcoded secrets...")
        for file_path, content in self._iter_files(('.py', '.js', '.env', '.txt')):
            for pattern in self._SECRET_PATTERNS:
                for match in pattern.finditer(content):
                    value = match.group("value")
                    if not value:
                        continue  # an empty literal holds no secret
                    finding = f"{match.group('key')} = {self._redact(value)}"
                    self._record(
                        file_path,
                        "Hardcoded Secret",
                        f"Potential hardcoded secret found: {finding}",
                        f"Hardcoded secret found in {file_path}: {finding}",
                    )

    def scan_for_sql_injection(self):
        """
        Scans .py files for SQL built by string concatenation or f-strings
        inside ``cursor.execute(...)``, and for ``subprocess.run`` calls
        mentioning sql.

        The full matched call text is reported. Only single-line calls are
        matched, because ``.`` in the patterns does not cross newlines.
        """
        logging.info("Scanning for SQL injection vulnerabilities...")
        for file_path, content in self._iter_files(('.py',)):
            for pattern in self._SQL_PATTERNS:
                for match in pattern.finditer(content):
                    snippet = match.group(0)
                    self._record(
                        file_path,
                        "SQL Injection",
                        f"Potential SQL injection vulnerability: {snippet}",
                        f"SQL injection vulnerability found in {file_path}: {snippet}",
                    )

    def scan_for_xss_patterns(self):
        """
        Scans .py/.js/.html files for potential Cross-Site Scripting (XSS) sinks:
        concatenated ``render_template_string`` calls, ``unescape(...)`` calls,
        and assignments/appends to ``innerHTML``.
        """
        logging.info("Scanning for XSS vulnerabilities...")
        for file_path, content in self._iter_files(('.py', '.js', '.html')):
            for pattern in self._XSS_PATTERNS:
                for match in pattern.finditer(content):
                    snippet = match.group(0)
                    self._record(
                        file_path,
                        "XSS Vulnerability",
                        f"Potential XSS vulnerability: {snippet}",
                        f"XSS vulnerability found in {file_path}: {snippet}",
                    )

    def validate_input_sanitization(self):
        """
        Flags .py files that contain none of the sanitization markers
        ("escape" — but not "unescape" — "sanitize", "strip", "filter").

        This is a basic placeholder; more sophisticated analysis would be needed
        in a real-world scenario. Empty or unreadable files are skipped, because
        there is no code in them to judge.
        """
        logging.info("Validating input sanitization...")
        for file_path, content in self._iter_files(('.py',)):
            if not content.strip():
                continue
            if not self._SANITIZATION_PATTERN.search(content):
                self._record(
                    file_path,
                    "Input Sanitization",
                    "Missing input sanitization for user-provided input.",
                    f"Missing input sanitization in {file_path}",
                )

    def validate_auth_checks(self):
        """
        Flags .py files that mention none of the authentication keywords
        ("login_required", "is_authenticated", "has_permission").

        This is a basic placeholder; more sophisticated analysis would be needed
        in a real-world scenario. Empty or unreadable files are skipped, because
        there is no code in them to judge.
        """
        logging.info("Validating authentication checks...")
        for file_path, content in self._iter_files(('.py',)):
            if not content.strip():
                continue
            if not any(keyword in content for keyword in self._AUTH_KEYWORDS):
                self._record(
                    file_path,
                    "Authentication Checks",
                    "Missing authentication checks for protected resources.",
                    f"Missing authentication checks in {file_path}",
                )


if __name__ == '__main__':
    import sys

    # Usage: python <this file> [CODEBASE_PATH]
    # Without an argument, fall back to the original example path.
    codebase_path = sys.argv[1] if len(sys.argv) > 1 else "/mnt/e/genesis-system"

    # os.walk() yields nothing for a missing directory, so without this check a
    # mistyped path would be reported as "No vulnerabilities found."
    if not os.path.isdir(codebase_path):
        print(f"Error: not a directory: {codebase_path}", file=sys.stderr)
        sys.exit(2)

    scanner = SecurityScanner(codebase_path)
    vulnerabilities = scanner.scan()

    if vulnerabilities:
        print("Vulnerabilities found:")
        for file_path, vulnerability_type, description in vulnerabilities:
            print(f"- File: {file_path}")
            print(f"  Type: {vulnerability_type}")
            print(f"  Description: {description}")
    else:
        print("No vulnerabilities found.")
