#!/usr/bin/env python3
"""
GENESIS ADVERSARIAL VERIFIER
============================
Implements the Generator vs Verifier pattern from Adversarial VDD.
Code passes only when the verifier finds no critical- or high-severity issues.

Pattern:
    1. Generator creates code/solution
    2. Verifier attempts to find holes/edge cases
    3. If verifier finds issues -> Generator must fix
    4. Loop until verifier cannot find issues
    5. Only then is the solution accepted

This creates adversarial pressure that produces robust code.

Usage:
    verifier = AdversarialVerifier()
    result = verifier.verify_implementation(code, tests, description)
"""

import ast
import json
import re
import subprocess
import sys
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Dict, List, Any, Optional, Callable, Tuple


class VulnerabilityType(Enum):
    """Categories of vulnerabilities the verifier looks for.

    Each member's ``.value`` is the lowercase string emitted when a finding
    is serialized or printed in reports.
    """
    EDGE_CASE = "edge_case"  # not emitted by any check in this file
    INPUT_VALIDATION = "input_validation"  # input() calls; functions lacking argument checks
    ERROR_HANDLING = "error_handling"  # bare or silently-ignored except clauses
    TYPE_SAFETY = "type_safety"  # missing return type hints on public functions
    RESOURCE_LEAK = "resource_leak"  # not emitted by any check in this file
    RACE_CONDITION = "race_condition"  # not emitted by any check in this file
    SECURITY = "security"  # eval/exec/os.system/pickle/shell=True usage
    LOGIC_ERROR = "logic_error"  # syntax errors, TODOs, empty pass, untested code, ...
    BOUNDARY = "boundary"  # not emitted by any check in this file
    NULL_REFERENCE = "null_reference"  # dict.get() without default used in an expression


@dataclass
class Vulnerability:
    """Represents a found vulnerability.

    One finding reported by the static analyzer or the adversarial verifier.
    Only ``critical`` and ``high`` severities make verification fail; lower
    severities are reported but do not block a pass.
    """
    type: VulnerabilityType  # category of the finding
    description: str  # human-readable explanation shown in reports
    location: str  # "file:line", "function:<name>", or a bare path
    severity: str  # critical, high, medium, low
    suggested_fix: Optional[str] = None  # remediation hint; not populated by the checks in this file
    test_case: Optional[str] = None  # Test that exposes the vulnerability (edge-case check stores sample inputs here)


@dataclass
class VerificationResult:
    """Outcome of one adversarial verification run.

    Attributes:
        passed: Whether verification succeeded.
        vulnerabilities: Findings of every severity from the analysis.
        iterations: Number of analysis passes that were performed.
        generator_improvements: Notes about fixes applied by a generator
            (not populated anywhere in this file).
        final_confidence: Heuristic score in [0, 1]; stays 0.0 unless set.
        timestamp: Naive local ISO-8601 time at which the object was created.
    """
    passed: bool
    vulnerabilities: List[Vulnerability] = field(default_factory=list)
    iterations: int = 0
    generator_improvements: List[str] = field(default_factory=list)
    final_confidence: float = 0.0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict:
        """Return a plain-data dict (suitable for ``json.dumps``).

        Enum categories are flattened to their string values, and a few
        fields are renamed (``improvements``, ``confidence``).
        """
        def serialize(finding):
            return {
                "type": finding.type.value,
                "description": finding.description,
                "location": finding.location,
                "severity": finding.severity,
                "suggested_fix": finding.suggested_fix,
                "test_case": finding.test_case
            }

        findings = list(map(serialize, self.vulnerabilities))
        payload = {"passed": self.passed}
        payload["vulnerabilities"] = findings
        payload["iterations"] = self.iterations
        payload["improvements"] = self.generator_improvements
        payload["confidence"] = self.final_confidence
        payload["timestamp"] = self.timestamp
        return payload


class CodeAnalyzer:
    """Static code analysis for vulnerability detection.

    Two passes are made over the source:

    * a line-by-line regex scan driven by ``dangerous_patterns`` (risky calls
      such as eval/exec/os.system, bare ``except:``, TODOs, ...), and
    * an AST walk for structural issues (missing docstrings and return type
      hints, silently ignored exceptions, division by a literal zero).
    """

    # Operators that raise ZeroDivisionError when the right operand is zero.
    _ZERO_DIVISION_OPS = (ast.Div, ast.FloorDiv, ast.Mod)

    def __init__(self):
        """Build the default regex rule table in ``self.dangerous_patterns``."""
        # (regex, category, description) triples, each applied to every source
        # line. The ``\b`` anchors stop identifiers that merely end in a flagged
        # word from matching, e.g. ``ast.literal_eval(`` (the safe alternative
        # to eval), ``get_input(`` or ``bypass``.
        self.dangerous_patterns = [
            (r'\beval\s*\(', VulnerabilityType.SECURITY, "Use of eval() - code injection risk"),
            (r'\bexec\s*\(', VulnerabilityType.SECURITY, "Use of exec() - code injection risk"),
            # Any subprocess entry point: call, run, Popen, check_output, ...
            (r'subprocess\.\w+\s*\(.*shell\s*=\s*True', VulnerabilityType.SECURITY, "Shell=True in subprocess - injection risk"),
            (r'os\.system\s*\(', VulnerabilityType.SECURITY, "Use of os.system() - command injection risk"),
            (r'pickle\.load', VulnerabilityType.SECURITY, "Pickle deserialization - arbitrary code execution risk"),
            (r'\binput\s*\(.*\)', VulnerabilityType.INPUT_VALIDATION, "Unvalidated input() call"),
            (r'except\s*:', VulnerabilityType.ERROR_HANDLING, "Bare except clause - catches all errors"),
            (r'# TODO', VulnerabilityType.LOGIC_ERROR, "TODO comment found - incomplete implementation"),
            (r'\bpass\s*$', VulnerabilityType.LOGIC_ERROR, "Empty pass statement - possible incomplete implementation"),
            (r'\.get\([^,)]+\)[^.]*([\+\-\*/]|==|!=)', VulnerabilityType.NULL_REFERENCE, "dict.get() used without default, then operated on"),
        ]

    def analyze_code(self, code: str, filename: str = "code.py") -> List[Vulnerability]:
        """Perform static analysis on code.

        Args:
            code: Python source text to inspect.
            filename: Label used as the prefix of each finding's location
                (``filename:line``).

        Returns:
            Regex findings first (grouped by pattern, in table order), then AST
            findings. Security regex hits are ``high`` severity, other regex hits
            ``medium``. Source that cannot be parsed yields one ``critical``
            finding in place of the AST findings.
        """
        vulnerabilities = []
        lines = code.split('\n')  # split once rather than once per pattern

        # Pattern-based detection
        for pattern, vuln_type, description in self.dangerous_patterns:
            regex = re.compile(pattern)
            severity = "high" if vuln_type == VulnerabilityType.SECURITY else "medium"
            for i, line in enumerate(lines, 1):
                if regex.search(line):
                    vulnerabilities.append(Vulnerability(
                        type=vuln_type,
                        description=description,
                        location=f"{filename}:{i}",
                        severity=severity
                    ))

        # AST-based analysis
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError) as e:
            # Python < 3.12 raises ValueError (no line number) for NUL bytes.
            lineno = getattr(e, "lineno", None)
            vulnerabilities.append(Vulnerability(
                type=VulnerabilityType.LOGIC_ERROR,
                description=f"Syntax error: {e}",
                location=f"{filename}:{lineno}" if lineno is not None else filename,
                severity="critical"
            ))
        else:
            vulnerabilities.extend(self._analyze_ast(tree, filename))

        return vulnerabilities

    def _analyze_ast(self, tree: ast.AST, filename: str) -> List[Vulnerability]:
        """AST-based vulnerability detection.

        For every function, sync or async, it reports a missing docstring. For
        public functions (no leading ``_``) it also reports a missing return
        annotation. Both are ``low`` severity.

        It also reports ``except`` blocks whose entire body is ``pass``
        (``medium``). Finally, it reports ``/``, ``//`` or ``%``, including
        their augmented forms, applied to a literal zero (``critical``).
        """
        vulnerabilities = []

        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # Check for functions without docstrings
                if not ast.get_docstring(node):
                    vulnerabilities.append(Vulnerability(
                        type=VulnerabilityType.LOGIC_ERROR,
                        description=f"Function '{node.name}' lacks docstring",
                        location=f"{filename}:{node.lineno}",
                        severity="low"
                    ))

                # Check for missing type hints on public functions
                if not node.name.startswith('_') and node.returns is None:
                    vulnerabilities.append(Vulnerability(
                        type=VulnerabilityType.TYPE_SAFETY,
                        description=f"Function '{node.name}' lacks return type hint",
                        location=f"{filename}:{node.lineno}",
                        severity="low"
                    ))

            # Check for empty exception handlers
            if isinstance(node, ast.ExceptHandler):
                if len(node.body) == 1 and isinstance(node.body[0], ast.Pass):
                    vulnerabilities.append(Vulnerability(
                        type=VulnerabilityType.ERROR_HANDLING,
                        description="Exception silently ignored",
                        location=f"{filename}:{node.lineno}",
                        severity="medium"
                    ))

            # Check for division/modulo by a literal zero (``a / 0``, ``a %= 0``).
            divisor = None
            if isinstance(node, ast.BinOp) and isinstance(node.op, self._ZERO_DIVISION_OPS):
                divisor = node.right
            elif isinstance(node, ast.AugAssign) and isinstance(node.op, self._ZERO_DIVISION_OPS):
                divisor = node.value
            if divisor is not None and self._is_literal_zero(divisor):
                vulnerabilities.append(Vulnerability(
                    type=VulnerabilityType.LOGIC_ERROR,
                    description="Division by zero",
                    location=f"{filename}:{node.lineno}",
                    severity="critical"
                ))

        return vulnerabilities

    @staticmethod
    def _is_literal_zero(node: ast.AST) -> bool:
        """Return True if *node* is a numeric literal equal to zero, optionally signed.

        Uses ``ast.Constant``; the legacy ``ast.Num`` node class is deprecated
        and removed in Python 3.14.
        """
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            node = node.operand
        return (
            isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float, complex))
            and node.value == 0
        )


class EdgeCaseGenerator:
    """Generates adversarial test cases to find edge cases.

    ``edge_cases`` maps a type category to a list of awkward input values.
    A parameter's annotation selects the category by its lowercased name.
    Python built-in names (``str``, ``int``) are mapped onto the table's
    descriptive keys (``string``, ``integer``). Unknown annotations fall back
    to the ``string`` values.
    """

    # Annotation names that differ from the corresponding ``edge_cases`` key.
    _TYPE_ALIASES = {"str": "string", "int": "integer"}

    def __init__(self):
        """Populate ``self.edge_cases`` with the default adversarial inputs."""
        self.edge_cases = {
            "string": ["", " ", "\n", "\t", "a" * 10000, None, "!@#$%^&*()", "\x00", "🔥"],
            "integer": [0, -1, 1, -2147483648, 2147483647, None],
            "float": [0.0, -0.0, float('inf'), float('-inf'), float('nan'), None],
            "list": [[], [None], [[]], [1] * 10000, None],
            "dict": [{}, {"": ""}, {None: None}, None],
            "bool": [True, False, None, 0, 1, "", []],
        }

    def generate_test_cases(self, function_signature: str) -> List[Dict]:
        """Generate adversarial test cases for a function.

        Args:
            function_signature: Text such as ``"f(x: int, name: str)"``; only
                ``name: Type`` pairs are read (see ``_parse_params``).

        Returns:
            One dict per (parameter, edge value) with keys ``param``,
            ``value``, ``type`` (the lowercased annotation name as written)
            and ``description``.
        """
        test_cases = []

        # Parse signature to get parameter types
        params = self._parse_params(function_signature)

        for param_name, param_type in params.items():
            # Map built-in names such as ``int`` onto the table's ``integer`` key.
            category = self._TYPE_ALIASES.get(param_type, param_type)
            edges = self.edge_cases.get(category, self.edge_cases["string"])
            for edge in edges:
                test_cases.append({
                    "param": param_name,
                    "value": edge,
                    "type": param_type,
                    "description": f"Edge case: {param_name}={repr(edge)}"
                })

        return test_cases

    def _parse_params(self, signature: str) -> Dict[str, str]:
        """Parse function signature to extract parameters and types.

        Every ``identifier: identifier`` pair is taken. For a generic
        annotation only the outer name is kept (``List[int]`` -> ``list``).
        Type names are lowercased.
        """
        params = {}
        # Simple regex to extract param:type pairs
        matches = re.findall(r'(\w+)\s*:\s*(\w+)', signature)
        for name, typ in matches:
            params[name] = typ.lower()
        return params


class AdversarialVerifier:
    """
    Main verifier class implementing the Generator vs Verifier pattern.

    The verifier acts as an adversary trying to break the code.
    Strategies currently implemented (none of them execute the code):
    1. Static code analysis (``CodeAnalyzer``)
    2. Edge-case heuristics: functions that accept parameters but show no
       input validation are flagged, with sample adversarial inputs attached
    3. Test-coverage heuristics: public functions never called from the
       supplied tests are flagged

    Boundary testing, error injection and resource stress testing are not
    implemented yet.
    """

    def __init__(self, max_iterations: int = 5, confidence_threshold: float = 0.9):
        """Create a verifier.

        Args:
            max_iterations: Upper bound on analysis passes per
                ``verify_implementation`` call.
            confidence_threshold: Stored on the instance; no method in this
                class consults it.
        """
        self.max_iterations = max_iterations
        self.confidence_threshold = confidence_threshold
        self.analyzer = CodeAnalyzer()
        self.edge_generator = EdgeCaseGenerator()
        self.verification_history: List[VerificationResult] = []

    def verify_implementation(
        self,
        code: str,
        tests: Optional[str] = None,
        description: str = "",
        file_path: Optional[str] = None
    ) -> VerificationResult:
        """
        Verify an implementation adversarially.

        The code passes when no finding has ``critical`` or ``high`` severity.
        Lower-severity findings are still returned. Every call is appended to
        ``verification_history``.

        Args:
            code: The code to verify
            tests: Optional test code
            description: What the code should do (not used by the current checks)
            file_path: Path if verifying a file; used as the location prefix
                of static-analysis findings

        Returns:
            VerificationResult with pass/fail and found vulnerabilities
        """
        result = VerificationResult(passed=False, iterations=0)
        filename = file_path or "code.py"

        for iteration in range(self.max_iterations):
            result.iterations = iteration + 1

            # Phase 1: Static Analysis
            vulnerabilities = self.analyzer.analyze_code(code, filename)

            # Phase 2: Edge Case Testing
            vulnerabilities.extend(self._test_edge_cases(code))

            # Phase 3: Test Verification (if tests provided)
            if tests:
                vulnerabilities.extend(self._verify_tests(code, tests))

            # Only critical/high findings decide pass/fail
            critical_vulns = [v for v in vulnerabilities if v.severity in ["critical", "high"]]

            if not critical_vulns:
                # Calculate confidence based on analysis depth
                result.passed = True
                result.final_confidence = self._calculate_confidence(
                    code, tests, len(vulnerabilities)
                )
                result.vulnerabilities = vulnerabilities  # Include all, even low severity
                break

            result.vulnerabilities = vulnerabilities
            # NOTE: nothing in this method revises `code`, so each further
            # iteration reproduces the same findings. A Generator would be
            # invoked here to fix the code (VerifierLoop drives that via a callback).

        self.verification_history.append(result)
        return result

    def _test_edge_cases(self, code: str) -> List[Vulnerability]:
        """Flag functions that take parameters but show no input validation.

        No code is executed. For each function (sync or async), edge-case
        inputs are generated from its signature. If the body lacks validation
        (per ``_has_input_validation``), a ``medium`` finding is emitted with a
        sample of those inputs.

        Source that does not parse yields no findings here; static analysis
        reports the syntax error.
        """
        vulnerabilities = []

        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            return vulnerabilities

        for node in ast.walk(tree):
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue

            signature = f"{node.name}({', '.join(self._signature_params(node))})"
            test_cases = self.edge_generator.generate_test_cases(signature)

            # We can't actually run the tests without more context,
            # but we can note that edge cases should be tested.
            if test_cases and not self._has_input_validation(node):
                vulnerabilities.append(Vulnerability(
                    type=VulnerabilityType.INPUT_VALIDATION,
                    description=f"Function '{node.name}' may need input validation for edge cases",
                    location=f"function:{node.name}",
                    severity="medium",
                    test_case=str(test_cases[:3])  # First 3 test cases
                ))

        return vulnerabilities

    @staticmethod
    def _signature_params(func_node: ast.FunctionDef) -> List[str]:
        """Render a function's parameters as ``"name: annotation"`` strings.

        Positional-only, regular and keyword-only parameters are included.
        Unannotated parameters are rendered as ``str``. A leading ``self`` or
        ``cls`` is treated as the implicit receiver and skipped.
        """
        positional = [*func_node.args.posonlyargs, *func_node.args.args]
        if positional and positional[0].arg in ("self", "cls"):
            positional = positional[1:]

        rendered = []
        for arg in positional + func_node.args.kwonlyargs:
            annotation = ast.unparse(arg.annotation) if arg.annotation else "str"
            rendered.append(f"{arg.arg}: {annotation}")
        return rendered

    def _has_input_validation(self, func_node: ast.FunctionDef) -> bool:
        """Heuristically decide whether a function validates its inputs.

        Returns True if anywhere inside the function (nested definitions
        included) one of the following appears:

        * an ``assert`` statement;
        * a call to the bare name ``isinstance`` or ``validate``;
        * an ``if`` whose test text mentions ``isinstance`` or ``type(``.
        """
        for node in ast.walk(func_node):
            # `assert` is a statement, never a call; it is matched via ast.Assert below.
            if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                if node.func.id in ('isinstance', 'validate'):
                    return True
            if isinstance(node, ast.Assert):
                return True
            if isinstance(node, ast.If):
                # Check if it's a type check
                test = ast.unparse(node.test) if hasattr(ast, 'unparse') else ""
                if 'isinstance' in test or 'type(' in test:
                    return True
        return False

    def _verify_tests(self, code: str, tests: str) -> List[Vulnerability]:
        """Flag public functions in ``code`` that the tests never call.

        A function counts as tested if its name appears anywhere in ``tests``
        as a called name (``f(...)``) or a called attribute (``obj.f(...)``).
        Names starting with ``_`` are ignored. Test code that does not parse
        is reported as a ``medium`` finding. Findings are sorted by function
        name so report order is deterministic.
        """
        vulnerabilities = []

        try:
            code_tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # Static analysis already reports syntax errors in the code.
            return vulnerabilities

        try:
            test_tree = ast.parse(tests)
        except (SyntaxError, ValueError) as exc:
            vulnerabilities.append(Vulnerability(
                type=VulnerabilityType.LOGIC_ERROR,
                description=f"Test code could not be parsed: {exc}",
                location="tests",
                severity="medium"
            ))
            return vulnerabilities

        # Get all public function names in code
        code_funcs = {
            node.name
            for node in ast.walk(code_tree)
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and not node.name.startswith('_')
        }

        # Get all function names called in tests
        test_calls = set()
        for node in ast.walk(test_tree):
            if isinstance(node, ast.Call):
                if isinstance(node.func, ast.Name):
                    test_calls.add(node.func.id)
                elif isinstance(node.func, ast.Attribute):
                    test_calls.add(node.func.attr)

        # Check for untested functions
        for func in sorted(code_funcs - test_calls):
            vulnerabilities.append(Vulnerability(
                type=VulnerabilityType.LOGIC_ERROR,
                description=f"Function '{func}' appears to be untested",
                location=f"function:{func}",
                severity="medium"
            ))

        return vulnerabilities

    def _calculate_confidence(
        self,
        code: str,
        tests: Optional[str],
        vuln_count: int
    ) -> float:
        """Calculate a heuristic confidence score in [0.5, 1.0].

        The score starts at 0.5 and is raised by each of the following,
        then capped at 1.0:

        * longer code (more than 50 / 100 lines);
        * tests being supplied, with a further bump when the tests are more
          than half as long as the code;
        * fewer findings (none, or fewer than three).
        """
        confidence = 0.5  # Base confidence

        # More code analyzed = more confidence
        lines = len(code.split('\n'))
        if lines > 50:
            confidence += 0.1
        if lines > 100:
            confidence += 0.1

        # Tests provided = more confidence
        if tests:
            confidence += 0.2
            test_lines = len(tests.split('\n'))
            if test_lines > lines * 0.5:  # Good test coverage
                confidence += 0.1

        # Fewer vulnerabilities = more confidence
        if vuln_count == 0:
            confidence += 0.2
        elif vuln_count < 3:
            confidence += 0.1

        return min(confidence, 1.0)

    def verify_file(self, file_path: Path) -> VerificationResult:
        """Verify a Python file.

        A sibling ``test_<name>`` file, if present, is used as the test code.
        Both files are read as UTF-8, the default encoding of Python source.
        A missing file yields a failed result (not recorded in history).

        Args:
            file_path: File to verify; a ``str`` path is accepted too.
        """
        file_path = Path(file_path)
        if not file_path.exists():
            return VerificationResult(
                passed=False,
                vulnerabilities=[Vulnerability(
                    type=VulnerabilityType.LOGIC_ERROR,
                    description=f"File not found: {file_path}",
                    location=str(file_path),
                    severity="critical"
                )]
            )

        code = file_path.read_text(encoding="utf-8")

        # Look for test file
        test_path = file_path.parent / f"test_{file_path.name}"
        tests = test_path.read_text(encoding="utf-8") if test_path.exists() else None

        return self.verify_implementation(
            code=code,
            tests=tests,
            file_path=str(file_path)
        )

    def generate_report(self) -> str:
        """Generate a markdown report of verification history."""
        lines = [
            "# Adversarial Verification Report",
            f"\nGenerated: {datetime.now().isoformat()}",
            f"\nTotal verifications: {len(self.verification_history)}",
            ""
        ]

        passed = sum(1 for r in self.verification_history if r.passed)
        lines.append(f"**Passed:** {passed}/{len(self.verification_history)}")
        lines.append("")

        for i, result in enumerate(self.verification_history, 1):
            status = "✅ PASSED" if result.passed else "❌ FAILED"
            lines.append(f"## Verification #{i} - {status}")
            lines.append(f"- Iterations: {result.iterations}")
            lines.append(f"- Confidence: {result.final_confidence:.1%}")
            lines.append(f"- Vulnerabilities: {len(result.vulnerabilities)}")

            if result.vulnerabilities:
                lines.append("\n### Found Issues:")
                for v in result.vulnerabilities:
                    lines.append(f"- [{v.severity.upper()}] {v.type.value}: {v.description}")
                    lines.append(f"  - Location: {v.location}")

            lines.append("")

        return "\n".join(lines)


class VerifierLoop:
    """
    Runs the full Generator-Verifier adversarial loop.

    This coordinates between:
    - Generator (creates/fixes code)
    - Verifier (tries to break it)

    Loop continues until verifier cannot find issues.
    """

    def __init__(
        self,
        generator_callback: Optional[Callable[[str, List[Vulnerability]], str]] = None,
        max_rounds: int = 10,
        verifier: Optional[AdversarialVerifier] = None
    ):
        """
        Args:
            generator_callback: Called as ``generator_callback(code, vulnerabilities)``
                after a failed round; must return revised code. Without one,
                the loop stops after the first failed round.
            max_rounds: Maximum number of verification rounds; values below 1
                are treated as 1 so at least one verification always happens.
            verifier: Verifier to use; a default ``AdversarialVerifier()`` is
                created when omitted.
        """
        self.verifier = verifier if verifier is not None else AdversarialVerifier()
        self.generator = generator_callback
        self.max_rounds = max_rounds
        # Accumulates across run() calls; it is never cleared here.
        self.round_history: List[Dict] = []

    def run(self, initial_code: str, spec: str) -> Tuple[str, VerificationResult]:
        """
        Run the adversarial loop.

        Args:
            initial_code: Starting code from generator
            spec: What the code should do

        Returns:
            (final_code, final_result). ``final_result`` is always the
            verification of ``final_code``. The generator is not called after
            the last allowed round, because its output would be returned
            without ever being verified.
        """
        current_code = initial_code
        rounds = max(1, self.max_rounds)

        for round_num in range(rounds):
            # Verify current code
            result = self.verifier.verify_implementation(
                code=current_code,
                description=spec
            )

            self.round_history.append({
                "round": round_num + 1,
                "passed": result.passed,
                "vulnerabilities": len(result.vulnerabilities),
                "critical": len([v for v in result.vulnerabilities if v.severity == "critical"])
            })

            if result.passed:
                return current_code, result

            # No generator, or no round left to verify a fix: return what we have.
            if not self.generator or round_num == rounds - 1:
                break

            current_code = self.generator(current_code, result.vulnerabilities)

        return current_code, result


def main() -> int:
    """Command-line entry point for the adversarial verifier.

    Verifies a file (positional argument), an inline snippet (``--code``),
    or, when neither is given, a built-in demo sample. It prints the result
    and, with ``--report``, a markdown report.

    Returns:
        0 if verification passed, 1 otherwise. This is suitable for use as a
        process exit status.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Genesis Adversarial Verifier")
    parser.add_argument("file", nargs="?", help="Python file to verify")
    parser.add_argument("--code", help="Inline code to verify")
    parser.add_argument("--report", action="store_true", help="Generate report")
    args = parser.parse_args()

    verifier = AdversarialVerifier()

    if args.file:
        result = verifier.verify_file(Path(args.file))
    elif args.code:
        result = verifier.verify_implementation(args.code)
    else:
        # Demo with sample code
        sample_code = '''
def divide(a, b):
    return a / b

def greet(name):
    print("Hello " + name)
'''
        print("Demo: Verifying sample code...\n")
        result = verifier.verify_implementation(sample_code, file_path="sample.py")

    # Print results
    print("=" * 50)
    print("ADVERSARIAL VERIFICATION RESULT")
    print("=" * 50)
    print(f"Status: {'PASSED ✅' if result.passed else 'FAILED ❌'}")
    print(f"Iterations: {result.iterations}")
    print(f"Confidence: {result.final_confidence:.1%}")
    print(f"Issues Found: {len(result.vulnerabilities)}")

    if result.vulnerabilities:
        print("\nVulnerabilities:")
        for v in result.vulnerabilities:
            print(f"  [{v.severity.upper()}] {v.type.value}")
            print(f"    {v.description}")
            print(f"    Location: {v.location}")

    if args.report:
        print("\n" + verifier.generate_report())

    # Non-zero status lets shells/CI detect a failed verification.
    return 0 if result.passed else 1


if __name__ == "__main__":
    # Propagate main()'s status as the process exit code (None -> 0).
    sys.exit(main())
