"""
AIVA Triple Gate Validator - PM-025 (Enhanced)

Ensures all AIVA outputs pass 3 validation gates:
- Gate 1: Syntax/format validation
- Gate 2: Quality/coherence check
- Gate 3: Safety/permission check

Logs all gate results for audit trail.
"""

import ast
import hashlib
import json
import logging
import os
import re
from dataclasses import dataclass, asdict
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple

# Configure logging
# NOTE: this runs at import time; basicConfig requests INFO level and a
# timestamped "time - logger name - level - message" line format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger named after this module; used by TripleGateValidator.
logger = logging.getLogger(__name__)


class GateResult(Enum):
    """Result of a gate validation.

    The member values are the lowercase strings that end up in serialized
    reports (GateReport.to_dict stores ``result.value``).
    """
    PASS = "pass"
    FAIL = "fail"
    # WARN is declared but none of the gates visible in this module produce it.
    WARN = "warn"


@dataclass
class GateCheck:
    """Individual check within a gate.

    Attributes:
        name: Short identifier of the check (e.g. "non_empty", "length").
        passed: Whether the check succeeded.
        message: Human-readable explanation of the outcome.
        score: Contribution to the gate's averaged score; the gates in this
            module assign values between 0.0 and 1.0. Defaults to 1.0.
    """
    name: str
    passed: bool
    message: str
    score: float = 1.0


@dataclass
class GateReport:
    """Outcome of one gate: its verdict, score, individual checks and timestamp."""
    gate_name: str
    gate_number: int
    result: GateResult
    score: float
    checks: List[GateCheck]
    timestamp: str

    def to_dict(self) -> Dict:
        """Return a plain-dict form of the report.

        The ``result`` enum is replaced by its string value and every check is
        expanded into its own dict, so the mapping contains only plain values.
        """
        return {
            **asdict(self),
            # Overriding existing keys keeps their original position.
            "result": self.result.value,
            "checks": [asdict(check) for check in self.checks],
        }


@dataclass
class ValidationReport:
    """Aggregated result of running one output through all three gates."""
    output_id: str
    passed: bool
    overall_score: float
    gate1_syntax: GateReport
    gate2_quality: GateReport
    gate3_safety: GateReport
    timestamp: str
    metadata: Dict

    def to_dict(self) -> Dict:
        """Return a plain-dict form with the three gate reports nested under "gates"."""
        gate_dicts = {
            "gate1_syntax": self.gate1_syntax.to_dict(),
            "gate2_quality": self.gate2_quality.to_dict(),
            "gate3_safety": self.gate3_safety.to_dict(),
        }
        return dict(
            output_id=self.output_id,
            passed=self.passed,
            overall_score=self.overall_score,
            gates=gate_dicts,
            timestamp=self.timestamp,
            metadata=self.metadata,
        )


class Gate1SyntaxValidator:
    """
    Gate 1: Syntax/Format Validation

    Validates that output has correct syntax and format. Every check yields a
    GateCheck; the gate passes only when all checks passed, and its score is
    the mean of the individual check scores.
    """

    def validate(self, output: Any, output_type: str = "text") -> GateReport:
        """
        Validate output syntax and format.

        Args:
            output: The output to validate
            output_type: Expected type (text, python, json, markdown). Any
                other value, or a non-string output, skips the type-specific
                check and records a passing "format_check" entry instead.

        Returns:
            GateReport with validation results
        """
        checks: List[GateCheck] = []
        timestamp = datetime.utcnow().isoformat()
        is_text = isinstance(output, str)

        # Check 1: Non-empty output (None and whitespace-only text count as empty)
        is_empty = output is None or (is_text and not output.strip())
        checks.append(GateCheck(
            name="non_empty",
            passed=not is_empty,
            message="Output is empty or None" if is_empty else "Output is non-empty",
            score=0.0 if is_empty else 1.0
        ))

        # Check 2: Type-specific validation (only meaningful for text output)
        if is_text and output_type == "python":
            checks.append(self._validate_python_syntax(output))
        elif is_text and output_type == "json":
            checks.append(self._validate_json_syntax(output))
        elif is_text and output_type == "markdown":
            checks.append(self._validate_markdown_format(output))
        else:
            checks.append(GateCheck(
                name="format_check",
                passed=True,
                message=f"Format validation skipped for type: {output_type}",
                score=1.0
            ))

        # Check 3: Encoding validation
        if is_text:
            checks.append(self._validate_encoding(output))

        # Gate score is the mean of the check scores; any failed check fails the gate
        total_score = sum(c.score for c in checks) / len(checks) if checks else 0.0
        all_passed = all(c.passed for c in checks)

        return GateReport(
            gate_name="Syntax/Format Validation",
            gate_number=1,
            result=GateResult.PASS if all_passed else GateResult.FAIL,
            score=total_score,
            checks=checks,
            timestamp=timestamp
        )

    def _validate_python_syntax(self, code: str) -> GateCheck:
        """Validate Python code syntax by parsing it (the code is never executed).

        Args:
            code: Python source text.

        Returns:
            A "python_syntax" GateCheck; failed (score 0.0) when the source
            cannot be parsed.
        """
        try:
            ast.parse(code)
        except (SyntaxError, ValueError, MemoryError, RecursionError) as e:
            # Besides SyntaxError, the parser reports null bytes as ValueError
            # on older interpreters and overly complex/nested source as
            # MemoryError or RecursionError. All of these mean "not parseable"
            # and must fail the gate rather than crash the validator.
            return GateCheck(
                name="python_syntax",
                passed=False,
                message=f"Python syntax error: {e}",
                score=0.0
            )
        return GateCheck(
            name="python_syntax",
            passed=True,
            message="Python syntax is valid",
            score=1.0
        )

    def _validate_json_syntax(self, content: str) -> GateCheck:
        """Validate JSON syntax by attempting to decode the text.

        Args:
            content: Candidate JSON document.

        Returns:
            A "json_syntax" GateCheck; failed (score 0.0) when decoding fails.
        """
        try:
            json.loads(content)
        except (json.JSONDecodeError, RecursionError) as e:
            # Extremely deep nesting exhausts the decoder's recursion budget;
            # treat that as invalid input instead of letting it propagate.
            return GateCheck(
                name="json_syntax",
                passed=False,
                message=f"JSON syntax error: {e}",
                score=0.0
            )
        return GateCheck(
            name="json_syntax",
            passed=True,
            message="JSON syntax is valid",
            score=1.0
        )

    def _validate_markdown_format(self, content: str) -> GateCheck:
        """Validate basic markdown structure.

        Only checks that triple-backtick fences come in pairs; an odd count
        fails the check with a partial score of 0.5.
        """
        # Check for unbalanced code blocks
        fence_count = content.count("```")
        if fence_count % 2 != 0:
            return GateCheck(
                name="markdown_format",
                passed=False,
                message="Unbalanced code blocks in markdown",
                score=0.5
            )
        return GateCheck(
            name="markdown_format",
            passed=True,
            message="Markdown format appears valid",
            score=1.0
        )

    def _validate_encoding(self, content: str) -> GateCheck:
        """Validate that the text can be encoded as UTF-8.

        A str can fail to encode when it contains lone surrogate code points.
        """
        try:
            content.encode('utf-8')
        except UnicodeEncodeError:
            return GateCheck(
                name="encoding",
                passed=False,
                message="Invalid UTF-8 characters detected",
                score=0.0
            )
        return GateCheck(
            name="encoding",
            passed=True,
            message="UTF-8 encoding valid",
            score=1.0
        )


class Gate2QualityValidator:
    """
    Gate 2: Quality/Coherence Check

    Scores text output on length, word variety, sentence structure and,
    when keywords are supplied, keyword relevance. The gate passes when the
    mean check score reaches 0.7.
    """

    def __init__(self, min_length: int = 10, max_length: int = 100000):
        """Store the character-length bounds used by the length check."""
        self.min_length = min_length
        self.max_length = max_length

    def validate(self, output: Any, context: Optional[Dict] = None) -> GateReport:
        """
        Validate output quality and coherence.

        Args:
            output: The output to validate
            context: Context about what was requested; a truthy "keywords"
                entry enables the relevance check

        Returns:
            GateReport with validation results
        """
        timestamp = datetime.utcnow().isoformat()
        context = context or {}

        checks: List[GateCheck]
        if not isinstance(output, str):
            # Quality heuristics only apply to text; anything else is accepted.
            checks = [GateCheck(
                name="type_check",
                passed=True,
                message="Non-string output accepted",
                score=1.0
            )]
        else:
            checks = [
                self._check_length(output),
                self._check_repetition(output),
                self._check_coherence(output),
            ]
            keywords = context.get("keywords")
            if keywords:
                checks.append(self._check_relevance(output, keywords))

        # The gate verdict depends on the mean score, not on individual flags.
        mean_score = sum(c.score for c in checks) / len(checks) if checks else 0.0
        verdict = GateResult.PASS if mean_score >= 0.7 else GateResult.FAIL

        return GateReport(
            gate_name="Quality/Coherence Check",
            gate_number=2,
            result=verdict,
            score=mean_score,
            checks=checks,
            timestamp=timestamp
        )

    def _check_length(self, content: str) -> GateCheck:
        """Check that the character count lies within [min_length, max_length]."""
        size = len(content)
        if size < self.min_length:
            ok, points = False, 0.3
            note = f"Content too short: {size} < {self.min_length}"
        elif size > self.max_length:
            ok, points = False, 0.5
            note = f"Content too long: {size} > {self.max_length}"
        else:
            ok, points = True, 1.0
            note = f"Content length acceptable: {size}"
        return GateCheck(name="length", passed=ok, message=note, score=points)

    def _check_repetition(self, content: str) -> GateCheck:
        """Detect excessive repetition via the share of distinct words."""
        tokens = content.lower().split()
        if len(tokens) < 10:
            return GateCheck(
                name="repetition",
                passed=True,
                message="Content too short for repetition check",
                score=1.0
            )

        # Fraction of distinct words: lower means more repetition.
        unique_share = len(set(tokens)) / len(tokens)

        if unique_share < 0.3:
            ok, points = False, 0.2
            note = f"Excessive repetition detected: {unique_share:.1%} unique words"
        elif unique_share < 0.5:
            ok, points = True, 0.7
            note = f"Some repetition: {unique_share:.1%} unique words"
        else:
            ok, points = True, 1.0
            note = f"Good variety: {unique_share:.1%} unique words"
        return GateCheck(name="repetition", passed=ok, message=note, score=points)

    def _check_coherence(self, content: str) -> GateCheck:
        """Basic coherence check: count sentence-like fragments."""
        # Split on sentence terminators and keep fragments longer than 5 chars.
        fragments = (piece.strip() for piece in re.split(r'[.!?]+', content))
        sentence_count = sum(1 for fragment in fragments if len(fragment) > 5)

        if sentence_count == 0:
            return GateCheck(
                name="coherence",
                passed=False,
                message="No valid sentences detected",
                score=0.3
            )

        return GateCheck(
            name="coherence",
            passed=True,
            message=f"Found {sentence_count} valid sentences",
            score=1.0
        )

    def _check_relevance(self, content: str, keywords: List[str]) -> GateCheck:
        """Check how many expected keywords occur (case-insensitively) in the content."""
        haystack = content.lower()
        found = len([kw for kw in keywords if kw.lower() in haystack])
        relevance = found / len(keywords) if keywords else 1.0

        if relevance >= 0.3:
            # Passing results get a 0.3 bonus, capped at 1.0.
            return GateCheck(
                name="relevance",
                passed=True,
                message=f"Good relevance: {found}/{len(keywords)} keywords found",
                score=min(1.0, relevance + 0.3)
            )
        return GateCheck(
            name="relevance",
            passed=False,
            message=f"Low relevance: {found}/{len(keywords)} keywords found",
            score=relevance
        )


class Gate3SafetyValidator:
    """
    Gate 3: Safety/Permission Check

    Validates that output does not contain dangerous content
    and action is permitted. The gate is strict: a single failed
    check fails the whole gate.
    """

    # Patterns that indicate potentially dangerous content
    # (compiled case-insensitively in __init__).
    DANGEROUS_PATTERNS = [
        r'rm\s+-rf\s+/',                    # Recursive delete of an absolute path
        r'DROP\s+DATABASE',                  # Database deletion
        r'DELETE\s+FROM\s+\w+\s*;',         # Mass deletion without WHERE
        r'api[_-]?key\s*[=:]\s*["\']?\w{20,}', # API key exposure
        r'password\s*[=:]\s*["\'][^"\']+["\']', # Password in plaintext
        # Dangerous eval. The lookbehind keeps ordinary words that merely end
        # in "eval" ("retrieval(", "medieval (") from tripping the gate.
        r'(?<![A-Za-z])eval\s*\(',
        r'exec\s*\(',                        # Dangerous exec
        r'__import__',                       # Dynamic import
    ]

    # Well-known secret token shapes (case-sensitive), compiled once.
    _SECRET_PATTERNS = tuple(re.compile(p) for p in (
        r'sk-[a-zA-Z0-9]{20,}',  # OpenAI keys
        r'AKIA[A-Z0-9]{16}',      # AWS access keys
        r'ghp_[a-zA-Z0-9]{36}',   # GitHub tokens
    ))

    BLOCKED_ACTIONS = [
        "delete_production",
        "modify_credentials",
        "disable_logging",
        "bypass_validation",
        "force_push_main"
    ]

    def __init__(self, permission_manager=None):
        """
        Initialize with optional permission manager.

        Args:
            permission_manager: PermissionManager instance for checking permissions
        """
        self.permission_manager = permission_manager
        # Read through self so a subclass can override DANGEROUS_PATTERNS.
        self.compiled_patterns = [re.compile(p, re.IGNORECASE) for p in self.DANGEROUS_PATTERNS]

    def validate(
        self,
        output: Any,
        action: Optional[str] = None,
        context: Optional[Dict] = None
    ) -> GateReport:
        """
        Validate output safety and permission compliance.

        Args:
            output: The output to validate
            action: The action this output is for (for permission check)
            context: Additional context

        Returns:
            GateReport with validation results
        """
        checks: List[GateCheck] = []
        timestamp = datetime.utcnow().isoformat()
        context = context or {}

        # NOTE(review): content scanning only runs for str output; other types
        # (dict, list, ...) are not inspected here.
        if isinstance(output, str):
            # Check 1: Dangerous patterns
            checks.append(self._check_dangerous_patterns(output))

            # Check 2: Credential exposure
            checks.append(self._check_credential_exposure(output))

        if action:
            # Check 3: Action permission
            checks.append(self._check_action_permission(action, context))
            # Check 4: Blocked actions
            checks.append(self._check_blocked_actions(action))
        else:
            checks.append(GateCheck(
                name="action_permission",
                passed=True,
                message="No action specified for permission check",
                score=1.0
            ))

        # Calculate overall score
        # Safety gate is strict - any failure is serious
        all_passed = all(c.passed for c in checks)
        total_score = sum(c.score for c in checks) / len(checks) if checks else 0.0

        return GateReport(
            gate_name="Safety/Permission Check",
            gate_number=3,
            result=GateResult.PASS if all_passed else GateResult.FAIL,
            score=total_score,
            checks=checks,
            timestamp=timestamp
        )

    def _check_dangerous_patterns(self, content: str) -> GateCheck:
        """Check for dangerous patterns in content; reports the first pattern that matches."""
        for pattern in self.compiled_patterns:
            if pattern.search(content):
                return GateCheck(
                    name="dangerous_patterns",
                    passed=False,
                    message=f"Dangerous pattern detected: {pattern.pattern}",
                    score=0.0
                )
        return GateCheck(
            name="dangerous_patterns",
            passed=True,
            message="No dangerous patterns detected",
            score=1.0
        )

    def _check_credential_exposure(self, content: str) -> GateCheck:
        """Check for potential credential exposure using known token shapes."""
        if any(pattern.search(content) for pattern in self._SECRET_PATTERNS):
            # The message deliberately does not echo the matched secret.
            return GateCheck(
                name="credential_exposure",
                passed=False,
                message="Potential credential exposure detected",
                score=0.0
            )

        return GateCheck(
            name="credential_exposure",
            passed=True,
            message="No credential exposure detected",
            score=1.0
        )

    def _check_action_permission(self, action: str, context: Dict) -> GateCheck:
        """Check if action is permitted.

        With a permission manager, the action passes when the manager answers
        APPROVED or NOTIFY. Without one, the action is allowed by default but
        with a reduced score of 0.8.
        """
        if self.permission_manager:
            # Imported lazily so the module loads without the permission package.
            from .permission_manager import PermissionResult
            result, message = self.permission_manager.check_permission(action, context)
            passed = result in [PermissionResult.APPROVED, PermissionResult.NOTIFY]
            return GateCheck(
                name="action_permission",
                passed=passed,
                message=message,
                score=1.0 if passed else 0.0
            )

        # Without permission manager, allow by default but log
        return GateCheck(
            name="action_permission",
            passed=True,
            message="Permission check skipped (no permission manager)",
            score=0.8
        )

    def _check_blocked_actions(self, action: str) -> GateCheck:
        """Check if action is in blocked list (exact, case-sensitive match)."""
        if action in self.BLOCKED_ACTIONS:
            return GateCheck(
                name="blocked_action",
                passed=False,
                message=f"Action '{action}' is explicitly blocked",
                score=0.0
            )
        return GateCheck(
            name="blocked_action",
            passed=True,
            message=f"Action '{action}' not in blocked list",
            score=1.0
        )


class TripleGateValidator:
    """
    Main validator orchestrating all three gates.

    Every validation is appended as one JSON line to
    ``<log_dir>/triple_gate_validation.jsonl`` for the audit trail.

    Usage:
        validator = TripleGateValidator()
        report = validator.validate(output, output_type="python", action="write_file")
        if report.passed:
            # Proceed with output
        else:
            # Handle validation failure
    """

    def __init__(self, log_dir: str = "logs", permission_manager=None):
        """
        Initialize the triple gate validator.

        Creates ``log_dir`` (including parents) if it does not exist.

        Args:
            log_dir: Directory for validation logs
            permission_manager: Optional PermissionManager instance
        """
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_file = self.log_dir / "triple_gate_validation.jsonl"

        self.gate1 = Gate1SyntaxValidator()
        self.gate2 = Gate2QualityValidator()
        self.gate3 = Gate3SafetyValidator(permission_manager)

        logger.info("TripleGateValidator initialized")

    def validate(
        self,
        output: Any,
        output_type: str = "text",
        action: Optional[str] = None,
        context: Optional[Dict] = None,
        output_id: Optional[str] = None
    ) -> ValidationReport:
        """
        Run complete triple-gate validation.

        Args:
            output: The output to validate
            output_type: Type of output (text, python, json, markdown)
            action: Action this output is for (for permission check)
            context: Additional context (keywords, task info)
            output_id: Unique ID for this output. When omitted, an ID is
                derived from a SHA-256 digest of ``str(output)``, so the same
                output gets the same ID in every process.

        Returns:
            ValidationReport with results from all gates
        """
        timestamp = datetime.utcnow().isoformat()
        if not output_id:
            # The built-in hash() of a str is salted per process, which made
            # audit-log IDs unstable between runs; a content digest is stable.
            # backslashreplace keeps un-encodable text (lone surrogates) from
            # raising before Gate 1 can report it.
            digest = hashlib.sha256(
                str(output).encode("utf-8", errors="backslashreplace")
            ).hexdigest()
            output_id = f"output_{digest[:16]}"
        context = context or {}

        # Gate 1: Syntax/Format
        gate1_report = self.gate1.validate(output, output_type)
        logger.debug("Gate 1 result: %s", gate1_report.result.value)

        # Gate 2: Quality/Coherence
        gate2_report = self.gate2.validate(output, context)
        logger.debug("Gate 2 result: %s", gate2_report.result.value)

        # Gate 3: Safety/Permission
        gate3_report = self.gate3.validate(output, action, context)
        logger.debug("Gate 3 result: %s", gate3_report.result.value)

        gate_reports = (gate1_report, gate2_report, gate3_report)

        # Overall score is the unweighted mean of the three gate scores
        overall_score = (gate1_report.score + gate2_report.score + gate3_report.score) / 3

        # All gates must pass for overall pass
        # Safety gate failure is critical
        passed = all(r.result == GateResult.PASS for r in gate_reports)

        report = ValidationReport(
            output_id=output_id,
            passed=passed,
            overall_score=overall_score,
            gate1_syntax=gate1_report,
            gate2_quality=gate2_report,
            gate3_safety=gate3_report,
            timestamp=timestamp,
            metadata={
                "output_type": output_type,
                "action": action,
                "context_keys": list(context.keys())
            }
        )

        # Log the validation
        self._log_validation(report)

        logger.info(
            "Validation complete: passed=%s, score=%.2f, id=%s",
            passed, overall_score, output_id
        )

        return report

    def _log_validation(self, report: ValidationReport) -> None:
        """Append the report as one JSON line to the audit log.

        Logging is best-effort: any failure is reported through the module
        logger and never interrupts validation.
        """
        try:
            # Explicit encoding so the log does not depend on the platform default.
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(report.to_dict()) + "\n")
        except Exception as e:  # deliberate catch-all: audit logging must not break validation
            logger.error("Failed to log validation: %s", e)

    def get_improvement_suggestions(self, report: ValidationReport) -> List[str]:
        """
        Generate improvement suggestions based on failed checks.

        Args:
            report: Validation report to analyze

        Returns:
            List of improvement suggestions, one per failed check, in gate order
        """
        gate_reports = (report.gate1_syntax, report.gate2_quality, report.gate3_safety)
        return [
            f"[Gate {gate_report.gate_number}] {check.name}: {check.message}"
            for gate_report in gate_reports
            for check in gate_report.checks
            if not check.passed
        ]


# Singleton instance; stays None until get_triple_gate_validator() creates it.
_validator: Optional[TripleGateValidator] = None


def get_triple_gate_validator() -> TripleGateValidator:
    """Return the shared validator, constructing it with defaults on first use."""
    global _validator
    if _validator is not None:
        return _validator
    _validator = TripleGateValidator()
    return _validator


# Backwards compatibility - create global validator instance.
# NOTE: this runs at import time, so importing the module constructs a
# TripleGateValidator, which creates its log directory ("logs" by default).
validator = get_triple_gate_validator()


if __name__ == "__main__":
    # Demo: run three sample outputs through the gates and print each verdict.
    validator = TripleGateValidator()

    # Valid Python code
    python_code = '''
def hello_world():
    """Print hello world."""
    print("Hello, World!")
    return True
'''
    # Invalid JSON (trailing comma)
    bad_json = '{"key": "value",}'
    # Dangerous shell command
    dangerous = "rm -rf / --no-preserve-root"

    # (label, sample output, keyword arguments for validate)
    demo_cases = [
        ("\nPython code validation", python_code,
         {"output_type": "python", "action": "write_file"}),
        ("Invalid JSON validation", bad_json, {"output_type": "json"}),
        ("Dangerous content validation", dangerous, {"action": "execute_command"}),
    ]
    for label, sample, options in demo_cases:
        report = validator.validate(sample, **options)
        print(f"{label}: passed={report.passed}, score={report.overall_score:.2f}")

    # Suggestions are shown for the last report only (the dangerous sample).
    suggestions = validator.get_improvement_suggestions(report)
    if suggestions:
        print("\nImprovement suggestions:")
        for suggestion in suggestions:
            print(f"  - {suggestion}")
