"""
Comprehensive secret scanner for codebase auditing.
Scans files and git history for potential hardcoded secrets.
"""

import fnmatch
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Set, Tuple


class Severity(Enum):
    """Severity level attached to each secret pattern and to the findings it produces."""

    # Each value is the label printed as a section header in the audit report.
    # CRITICAL and HIGH findings make the audit return a non-zero exit code;
    # LOW is not assigned to any pattern in this file.
    CRITICAL = "CRITICAL"
    HIGH = "HIGH"
    MEDIUM = "MEDIUM"
    LOW = "LOW"


@dataclass
class Finding:
    """One pattern match reported by the scanner, from a file or from git history."""

    # Path relative to the scan root for file findings, or
    # "git_history:<first 8 chars of the commit hash>" for history findings.
    file_path: str
    # 1-based line number within the file; 0 for git-history findings.
    line_number: int
    # The whitespace-stripped source line for file findings; a fixed
    # placeholder string for git-history findings.
    line_content: str
    # Key of the pattern in SecretAuditor.PATTERNS that matched.
    pattern_name: str
    # Severity configured for that pattern.
    severity: Severity
    # Short preview of the matched text (the scanner keeps the first 20
    # characters and uses "..." as a truncation marker).
    match_text: str


class SecretAuditor:
    """
    Scans a codebase for hardcoded secrets using regular-expression patterns.

    Each entry in ``PATTERNS`` is applied line by line to every text file under
    the root directory and, optionally, to the commit messages and changed
    lines of recent git commits. No entropy is actually computed: the
    ``high_entropy_string`` pattern is a length-based heuristic that flags long
    base64-like runs.
    """

    # Patterns for common secrets: name -> (regex source, severity)
    PATTERNS = {
        'aws_access_key_id': (
            r'AKIA[0-9A-Z]{16}',
            Severity.CRITICAL
        ),
        'aws_secret_key': (
            # Exactly 40 key characters, bounded on both sides. Without the
            # look-arounds every 40-character window of any longer base64 or
            # path-like run is reported as CRITICAL; such longer runs are
            # still reported (as MEDIUM) by 'high_entropy_string'.
            r'(?<![0-9a-zA-Z/+])[0-9a-zA-Z/+]{40}(?![0-9a-zA-Z/+=])',
            Severity.CRITICAL
        ),
        'generic_api_key': (
            r'(?i)(api[_-]?key|apikey)\s*[:=]\s*["\']?[a-z0-9]{16,}["\']?',
            Severity.HIGH
        ),
        'private_key': (
            r'-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----',
            Severity.CRITICAL
        ),
        'password_assignment': (
            r'(?i)(password|passwd|pwd)\s*[:=]\s*["\'][^"\']{8,}["\']',
            Severity.HIGH
        ),
        'auth_token': (
            r'(?i)(auth[_-]?token|access[_-]?token)\s*[:=]\s*["\']?[a-z0-9]{20,}["\']?',
            Severity.HIGH
        ),
        'openai_key': (
            r'sk-[a-zA-Z0-9]{20,}',
            Severity.CRITICAL
        ),
        'database_url': (
            r'(postgres|mysql|mongodb)://[^:]+:[^@]+@',
            Severity.CRITICAL
        ),
        'jwt_secret': (
            r'(?i)(jwt[_-]?secret|secret[_-]?key)\s*[:=]\s*["\'][^"\']{10,}["\']',
            Severity.HIGH
        ),
        'high_entropy_string': (
            r'[a-zA-Z0-9+/]{40,}',
            Severity.MEDIUM
        )
    }

    # File and directory names to exclude. Each entry is an fnmatch-style
    # pattern compared against individual path components (never against the
    # whole path string), so 'env' skips a directory named "env" but not
    # "environment.py".
    # NOTE(review): '.env' is skipped entirely, presumably because it is
    # expected to be git-ignored -- confirm that is the intent.
    EXCLUDED_PATHS = {
        '.git', '__pycache__', 'node_modules', '.venv', 'venv',
        'env', '.env', '.pytest_cache', '.mypy_cache', '*.pyc',
        'audit_secrets.py', '.env.example', 'README.md'
    }

    # Lazily built list of (name, compiled regex, severity); built on first
    # use so it also works for instances created without __init__.
    _compiled = None

    def __init__(self, root_path: str = "."):
        """Create an auditor rooted at ``root_path`` (resolved to an absolute path)."""
        self.root_path = Path(root_path).resolve()
        self.findings: List[Finding] = []

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _compiled_patterns(self):
        """Return PATTERNS as (name, compiled regex, severity) triples, compiling once."""
        if self._compiled is None:
            self._compiled = [
                (name, re.compile(pattern), severity)
                for name, (pattern, severity) in self.PATTERNS.items()
            ]
        return self._compiled

    @staticmethod
    def _preview(text: str, limit: int = 20) -> str:
        """Return ``text`` cut to ``limit`` characters, adding "..." only when truncated."""
        return text[:limit] + "..." if len(text) > limit else text

    def _scan_text(self, text: str) -> List[Tuple[str, Severity, str]]:
        """Return (pattern name, severity, match preview) for every match in ``text``."""
        return [
            (name, severity, self._preview(match.group()))
            for name, regex, severity in self._compiled_patterns()
            for match in regex.finditer(text)
        ]

    def _is_excluded_name(self, name: str) -> bool:
        """Return True if a single path component matches any EXCLUDED_PATHS pattern."""
        return any(fnmatch.fnmatch(name, pattern) for pattern in self.EXCLUDED_PATHS)

    def _display_path(self, file_path: Path) -> str:
        """Return ``file_path`` relative to the root, or unchanged if it is not under it."""
        try:
            return str(file_path.relative_to(self.root_path))
        except ValueError:
            return str(file_path)

    @staticmethod
    def _history_lines(show_output: str) -> List[str]:
        """Extract scannable lines from ``git show --format=%B --unified=0`` output.

        Returns the commit-message lines plus the added/removed lines of each
        hunk with their diff prefix removed. Diff headers (``diff --``,
        ``index``, ``---``/``+++``, ``@@``) are dropped because the object
        hashes and paths they contain are not file content.
        """
        lines: List[str] = []
        in_diff = False
        in_hunk = False
        prefix_width = 1
        # Split on "\n" only: str.splitlines() would also break on characters
        # such as form feeds that can occur inside a changed line.
        for line in show_output.split('\n'):
            if line.startswith('diff --'):
                # Start of a per-file header; everything before the first one
                # is the commit message.
                in_diff, in_hunk = True, False
                continue
            if not in_diff:
                lines.append(line)
                continue
            if line.startswith('@@'):
                in_hunk = True
                # "@@" marks an ordinary diff (1 prefix column); combined
                # merge diffs use one more "@" and one more column per parent.
                prefix_width = len(line) - len(line.lstrip('@')) - 1
                continue
            # Skip "\ No newline at end of file" markers.
            if in_hunk and not line.startswith('\\'):
                # Strip the +/- columns so they cannot merge with the content
                # and defeat the boundary checks in the patterns.
                lines.append(line[prefix_width:])
        return lines

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def should_scan_file(self, file_path: Path) -> bool:
        """Determine if file should be scanned based on exclusion rules.

        A file is skipped when any component of its path (relative to the
        root when possible) matches an EXCLUDED_PATHS pattern, when it looks
        binary (a NUL byte in the first 1024 bytes), or when it cannot be read.
        """
        try:
            parts = file_path.relative_to(self.root_path).parts
        except ValueError:
            # Not under the root (e.g. a relative path was passed in).
            parts = file_path.parts

        if any(self._is_excluded_name(part) for part in parts):
            return False

        # Only scan text files
        try:
            with open(file_path, 'rb') as f:
                return b'\0' not in f.read(1024)
        except (OSError, ValueError):
            return False

    def scan_file(self, file_path: Path) -> List[Finding]:
        """Scan a single file for secrets.

        Lines whose first non-blank characters are ``#`` or ``//`` are skipped
        as a basic guard against false positives in comments. Read errors are
        reported on stdout and whatever was found before the error is returned.
        """
        findings: List[Finding] = []
        display_path = self._display_path(file_path)

        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                for line_num, line in enumerate(f, 1):
                    stripped = line.strip()
                    # Avoid false positives in comments (basic check)
                    if stripped.startswith(('#', '//')):
                        continue

                    for pattern_name, severity, match_text in self._scan_text(line):
                        findings.append(Finding(
                            file_path=display_path,
                            line_number=line_num,
                            line_content=stripped,
                            pattern_name=pattern_name,
                            severity=severity,
                            match_text=match_text,
                        ))
        except OSError as e:
            print(f"Error scanning {file_path}: {e}")

        return findings

    def scan_directory(self) -> List[Finding]:
        """Recursively scan directory for secrets.

        Directories and files are visited in sorted order so repeated runs
        produce the same report.
        """
        all_findings: List[Finding] = []

        for root, dirs, files in os.walk(self.root_path):
            # Modify dirs in-place so os.walk does not descend into excluded ones
            dirs[:] = sorted(d for d in dirs if not self._is_excluded_name(d))

            for file in sorted(files):
                file_path = Path(root) / file
                if self.should_scan_file(file_path):
                    all_findings.extend(self.scan_file(file_path))

        return all_findings

    def scan_git_history(self, max_commits: int = 100) -> List[Finding]:
        """Scan git history for secrets that may have been committed.

        Args:
            max_commits: Maximum number of commits (across all refs, newest
                first as ordered by ``git log``) to inspect.

        Returns:
            One Finding per match in a commit message or in a line added or
            removed by a commit. ``file_path`` is ``git_history:<sha8>`` and
            ``line_number`` is 0. An empty list is returned, with a message on
            stdout, if git cannot be run or the root is not a repository.
        """
        findings: List[Finding] = []

        try:
            # Get the most recent commits across all refs
            result = subprocess.run(
                ['git', 'log', '--all', f'--max-count={max_commits}',
                 '--pretty=format:%H'],
                capture_output=True,
                text=True,
                cwd=self.root_path
            )
        except OSError as e:
            print(f"Error scanning git history: {e}")
            return findings

        if result.returncode != 0:
            print("Warning: Could not access git history")
            return findings

        # split() drops the empty string an empty repository would produce.
        for commit in result.stdout.split():
            try:
                # --format=%B prints only the message (no "commit <sha>"
                # header, whose 40-hex hash would match the key patterns);
                # --unified=0 limits the patch to changed lines.
                diff_result = subprocess.run(
                    ['git', 'show', '--no-color', '--format=%B',
                     '--unified=0', commit],
                    capture_output=True,
                    # Diffs may contain arbitrary bytes; never abort on them.
                    encoding='utf-8',
                    errors='replace',
                    cwd=self.root_path
                )
            except OSError as e:
                print(f"Error scanning git history: {e}")
                break

            if diff_result.returncode != 0:
                print(f"Warning: Could not read commit {commit[:8]}")
                continue

            for line in self._history_lines(diff_result.stdout):
                for pattern_name, severity, match_text in self._scan_text(line):
                    findings.append(Finding(
                        file_path=f"git_history:{commit[:8]}",
                        line_number=0,
                        line_content="[Commit message or diff content]",
                        pattern_name=pattern_name,
                        severity=severity,
                        match_text=match_text,
                    ))

        return findings

    def generate_report(self) -> str:
        """Generate formatted report of ``self.findings``, grouped by severity.

        Sections appear from CRITICAL down to LOW; severities with no findings
        are omitted. Each entry shows at most the first 80 characters of the
        offending line.
        """
        if not self.findings:
            return "✅ No secrets detected in codebase!"

        report = ["🔒 SECRET AUDIT REPORT", "=" * 50, ""]

        # Group by severity
        by_severity = {}
        for finding in self.findings:
            by_severity.setdefault(finding.severity, []).append(finding)

        for severity in [Severity.CRITICAL, Severity.HIGH, Severity.MEDIUM, Severity.LOW]:
            group = by_severity.get(severity)
            if not group:
                continue

            report.append(f"\n{severity.value} ({len(group)} findings):")
            report.append("-" * 40)

            for finding in group:
                report.append(f"  File: {finding.file_path}:{finding.line_number}")
                report.append(f"  Pattern: {finding.pattern_name}")
                report.append(f"  Match: {finding.match_text}")
                report.append(f"  Line: {finding.line_content[:80]}")
                report.append("")

        return "\n".join(report)

    def run_audit(self, include_git_history: bool = True) -> int:
        """
        Run complete audit, print the report and return an exit code.

        Args:
            include_git_history: Also scan recent git commits when True.

        Returns:
            int: 0 if there are no CRITICAL or HIGH findings (MEDIUM and LOW
            findings are reported but do not fail the audit), 1 otherwise or
            when the root path is not a directory.
        """
        print("🔍 Starting secret audit...")
        print(f"Scanning directory: {self.root_path}")

        # Fail closed: walking a missing path finds nothing and would
        # otherwise be reported as a passing audit.
        if not self.root_path.is_dir():
            print(f"Error: {self.root_path} is not a directory")
            return 1

        # Scan current files
        self.findings = self.scan_directory()

        # Scan git history if requested
        if include_git_history:
            print("Scanning git history...")
            self.findings.extend(self.scan_git_history())

        # Generate report
        print(self.generate_report())

        # Return appropriate exit code
        critical_count = sum(1 for f in self.findings if f.severity == Severity.CRITICAL)
        high_count = sum(1 for f in self.findings if f.severity == Severity.HIGH)

        if critical_count > 0:
            print(f"\n❌ CRITICAL: {critical_count} critical secrets detected!")
            return 1
        if high_count > 0:
            print(f"\n⚠️  WARNING: {high_count} high-risk patterns detected")
            return 1

        print("\n✅ Audit passed")
        return 0


# Command-line entry point: parse options, run the audit and use its result
# as the process exit status.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Audit codebase for hardcoded secrets")
    parser.add_argument("--path", default=".", help="Root path to scan")
    parser.add_argument("--no-history", action="store_true", help="Skip git history scan")

    args = parser.parse_args()

    # A mistyped --path would be walked as an empty tree and reported as a
    # clean audit, so reject it up front (parser.error exits with status 2).
    if not Path(args.path).is_dir():
        parser.error(f"--path is not a directory: {args.path}")

    auditor = SecretAuditor(args.path)
    sys.exit(auditor.run_audit(include_git_history=not args.no_history))