#!/usr/bin/env python3
"""
Genesis Validation Gate 1 - Input Validation
==============================================
Validates input format, structure, and content safety.
Foundation layer enforcing P1 (Crypto), P4 (Audit), P8 (Privacy).

PM-042: Validation Gate 1 - Input Enhancement
- Validates input format and content
- Rejects malformed input
- Logs validation results
"""

import re
import json
import hashlib
import logging
from typing import Dict, Any, List, Tuple, Optional
from datetime import datetime, timezone
from dataclasses import dataclass, asdict

# Module-level logging setup. basicConfig() runs at import time and gives the
# root logger INFO level (plus a default handler if it has none yet), so the
# gate's INFO-level messages below are emitted.
# NOTE(review): configuring the root logger on import also affects any
# application importing this module; consider moving it under __main__.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="ValidationGate1")


@dataclass
class ValidationResult:
    """Record of a single validation run produced by the input gate."""

    valid: bool                 # overall pass/fail verdict
    score: float                # fraction of checks that passed (0.0-1.0)
    checks: Dict[str, bool]     # check name -> whether it passed
    errors: List[str]           # problems that make the result invalid
    warnings: List[str]         # advisory findings (e.g. detected PII)
    timestamp: str              # ISO-8601 creation time string
    input_hash: str             # SHA-256 hex digest of the serialized input

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a new plain dict built by ``dataclasses.asdict``."""
        as_mapping = asdict(self)
        return as_mapping


class InputValidationGate:
    """
    Gate 1: Input Validation Layer
    Validates input format, structure, and basic content safety.
    Enforces P1 (Cryptographic integrity), P4 (Audit trail), P8 (Privacy).

    ``validate`` runs a fixed set of boolean checks and scores the input as
    the fraction that passed. It records a ``ValidationResult`` (including the
    pass/fail verdict, errors and warnings) in ``validation_log``.
    """

    # PII patterns for detection; _detect_pii matches them case-insensitively
    # against the serialized input.
    PII_PATTERNS = {
        "ssn": r'\b\d{3}-\d{2}-\d{4}\b',
        "credit_card": r'\b(?:\d{4}[-\s]?){3}\d{4}\b',
        # TLD is letters only. "|" must not appear inside [...]: there it is a
        # literal pipe character, not alternation.
        "email": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        "phone_us": r'\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b',
        "phone_intl": r'\+\d{1,3}[-.\s]?\d{6,14}\b',
        "ip_address": r'\b(?:\d{1,3}\.){3}\d{1,3}\b',
        "api_key": r'\b(?:sk|pk|api)[_-]?[a-zA-Z0-9]{20,}\b'
    }

    # Required fields for common input types. Missing or falsy values fail
    # the required-fields check.
    REQUIRED_FIELDS = {
        "task": ["task_id", "description"],
        "lead": ["email", "phone"],
        "contact": ["first_name", "email"],
        "message": ["content"],
        "webhook": ["event_type", "payload"]
    }

    # Maximum input sizes. "total_bytes" is compared with the UTF-8 size of
    # the serialized input. "string"/"array" apply to a top-level str/list only.
    MAX_SIZES = {
        "string": 50000,
        "array": 1000,
        "object_depth": 10,
        "total_bytes": 1_000_000
    }

    # Injection heuristics used by _check_injection, compiled once.
    # NOTE(review): these are coarse and prone to false positives. For
    # example, the SQL rule matches ordinary words such as "update" or
    # "select" anywhere in the text.
    _INJECTION_PATTERNS = (
        # SQL injection
        re.compile(r"(?:(?:;|\')(?:--|#|/\*)|\b(?:SELECT|INSERT|UPDATE|DELETE|DROP|UNION)\b)", re.IGNORECASE),
        # XSS
        re.compile(r"<script[^>]*>|javascript:|on\w+\s*=", re.IGNORECASE),
        # Command injection
        re.compile(r"(?:;|\||`|\$\(|\$\{)(?:[^;|`]*(?:rm|cat|wget|curl|bash|sh))", re.IGNORECASE),
        # Path traversal
        re.compile(r"\.\.(?:/|\\)", re.IGNORECASE),
    )

    def __init__(self, memory_stores=None, strict_mode: bool = False):
        """
        Initialize input validation gate.

        Args:
            memory_stores: Optional memory backend reference (stored as
                ``self.memory``; not used by the checks in this class)
            strict_mode: If True, any warning also makes the result invalid
        """
        self.memory = memory_stores
        self.strict_mode = strict_mode
        # NOTE(review): grows by one entry per validate() call with no cap;
        # long-running processes may want to trim it.
        self.validation_log: List[ValidationResult] = []

        logger.info("Validation Gate 1 (Input) initialized")

    def validate(self, input_data: Any, input_type: str = "generic",
                 metadata: Optional[Dict[str, Any]] = None) -> Tuple[float, Dict[str, bool]]:
        """
        Validate input data.

        Args:
            input_data: The input to validate
            input_type: Type of input (task, lead, contact, message, webhook);
                other values only receive the type-independent checks
            metadata: Additional validation context. Keys read here:
                ``timestamp``/``request_id`` (audit readiness),
                ``input_hash`` (tamper check) and ``signature``.

        Returns:
            Tuple of (score, checks_dict). ``score`` is the fraction of checks
            that passed. The overall verdict is recorded in ``validation_log``,
            not returned. A result is valid only if score >= 0.8 and there are
            no errors; in strict mode there must also be no warnings.
        """
        metadata = metadata or {}
        errors: List[str] = []
        warnings: List[str] = []
        checks: Dict[str, bool] = {}

        # Serialize once: the same text feeds the audit hash and the size,
        # PII and injection scans. "surrogatepass" keeps str input containing
        # lone surrogates from raising UnicodeEncodeError. Valid text encodes
        # exactly as plain UTF-8.
        input_str = self._serialize(input_data)
        input_hash = hashlib.sha256(input_str.encode("utf-8", "surrogatepass")).hexdigest()

        # P1: Format validation
        checks["format_valid"] = self._validate_format(input_data, input_type, errors)

        # P1: Structure validation
        checks["structure_valid"] = self._validate_structure(input_data, input_type, errors, warnings)

        # P1: Size validation
        checks["size_valid"] = self._validate_size(input_data, errors, input_str)

        # P8: PII detection (reported as a warning, not an error)
        pii_result, pii_types = self._detect_pii(input_str)
        checks["no_pii_exposure"] = not pii_result
        if pii_result:
            warnings.append(f"PII detected: {', '.join(pii_types)}")

        # P4: Audit readiness
        checks["audit_ready"] = "timestamp" in metadata or "request_id" in metadata

        # P1: Tamper check (only when the caller supplies an expected hash)
        if "input_hash" in metadata:
            checks["tamper_proof"] = input_hash == metadata["input_hash"]
        else:
            checks["tamper_proof"] = True  # No previous hash to compare

        # P1: Signature validation (placeholder for Ed25519)
        checks["signature_valid"] = self._validate_signature(input_data, metadata)

        # P8: Injection prevention
        checks["no_injection"] = self._check_injection(input_str, errors)

        # P4: Required fields
        checks["required_fields_present"] = self._check_required_fields(
            input_data, input_type, errors
        )

        # Score = fraction of checks that passed
        score = sum(checks.values()) / len(checks)

        # Errors always invalidate; in strict mode warnings do too.
        valid = score >= 0.8 and not errors
        if self.strict_mode and warnings:
            valid = False

        result = ValidationResult(
            valid=valid,
            score=score,
            checks=checks,
            errors=errors,
            warnings=warnings,
            # Naive UTC ISO string, the same format datetime.utcnow() produced
            # (utcnow() is deprecated since Python 3.12).
            timestamp=datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            input_hash=input_hash
        )

        self._log_validation(result, input_type)

        return score, checks

    @staticmethod
    def _serialize(input_data: Any) -> str:
        """Render input as text: JSON for dict/list (other values via str()), str() otherwise."""
        if isinstance(input_data, (dict, list)):
            return json.dumps(input_data, default=str)
        return str(input_data)

    def _validate_format(self, input_data: Any, input_type: str,
                         errors: List[str]) -> bool:
        """Validate input format matches expected type."""
        if input_type in ["task", "lead", "contact", "webhook"]:
            if not isinstance(input_data, dict):
                errors.append(f"Expected dict for {input_type}, got {type(input_data).__name__}")
                return False
        elif input_type == "message":
            if not isinstance(input_data, (str, dict)):
                errors.append(f"Expected str or dict for message, got {type(input_data).__name__}")
                return False
        return True

    def _validate_structure(self, input_data: Any, input_type: str,
                            errors: List[str], warnings: List[str]) -> bool:
        """Validate dict inputs: warn on None/"" values, fail when nesting is too deep."""
        if not isinstance(input_data, dict):
            return True  # Non-dict validated elsewhere

        # Check for null/empty values (warning only)
        empty_keys = [k for k, v in input_data.items() if v is None or v == ""]
        if empty_keys:
            warnings.append(f"Empty values for: {', '.join(empty_keys)}")

        # Check nested depth
        depth = self._get_object_depth(input_data)
        if depth > self.MAX_SIZES["object_depth"]:
            errors.append(f"Object depth {depth} exceeds max {self.MAX_SIZES['object_depth']}")
            return False

        return True

    def _validate_size(self, input_data: Any, errors: List[str],
                       input_str: Optional[str] = None) -> bool:
        """
        Validate input size limits.

        Args:
            input_data: The raw input
            errors: Error list appended to on failure
            input_str: Pre-serialized input (computed here when omitted)

        ``total_bytes`` is compared with the UTF-8 byte size of the serialized
        input. ``string`` and ``array`` only apply when the input itself is a
        str or list; nested values are not checked here.
        """
        if input_str is None:
            input_str = self._serialize(input_data)

        size = len(input_str.encode("utf-8", "surrogatepass"))
        if size > self.MAX_SIZES["total_bytes"]:
            errors.append(f"Input size {size} exceeds max {self.MAX_SIZES['total_bytes']}")
            return False

        if isinstance(input_data, str) and len(input_data) > self.MAX_SIZES["string"]:
            errors.append(f"String length {len(input_data)} exceeds max {self.MAX_SIZES['string']}")
            return False

        if isinstance(input_data, list) and len(input_data) > self.MAX_SIZES["array"]:
            errors.append(f"Array length {len(input_data)} exceeds max {self.MAX_SIZES['array']}")
            return False

        return True

    def _detect_pii(self, text: str) -> Tuple[bool, List[str]]:
        """
        Detect PII patterns in text.

        Returns:
            Tuple of (pii_detected, list_of_pii_types) in PII_PATTERNS order
        """
        detected = [
            pii_type
            for pii_type, pattern in self.PII_PATTERNS.items()
            if re.search(pattern, text, re.IGNORECASE)
        ]
        return len(detected) > 0, detected

    def _validate_signature(self, input_data: Any,
                            metadata: Dict[str, Any]) -> bool:
        """
        Check the signature supplied in metadata, if any.

        Returns True when no signature is supplied. Otherwise returns whether
        the supplied value is truthy (``None``/empty values fail instead of
        raising).
        """
        if "signature" not in metadata:
            return True  # No signature to validate

        # NOTE(review): placeholder. Nothing is verified against input_data;
        # any non-empty signature passes. Implement real Ed25519 verification
        # before relying on this check.
        return bool(metadata.get("signature"))

    def _check_injection(self, text: str, errors: List[str]) -> bool:
        """Return False (and record an error) if any injection heuristic matches."""
        if any(pattern.search(text) for pattern in self._INJECTION_PATTERNS):
            errors.append("Potential injection pattern detected")
            return False
        return True

    def _check_required_fields(self, input_data: Any, input_type: str,
                               errors: List[str]) -> bool:
        """Check that required fields for input_type are present and truthy (dicts only)."""
        if not isinstance(input_data, dict):
            return True

        required = self.REQUIRED_FIELDS.get(input_type, [])
        missing = [f for f in required if f not in input_data or not input_data[f]]

        if missing:
            errors.append(f"Missing required fields: {', '.join(missing)}")
            return False

        return True

    def _get_object_depth(self, obj: Any, current_depth: int = 0) -> int:
        """Calculate maximum nesting depth of dict/list containers (scalars add one level)."""
        if isinstance(obj, dict):
            if not obj:
                return current_depth
            return max(self._get_object_depth(v, current_depth + 1) for v in obj.values())
        elif isinstance(obj, list):
            if not obj:
                return current_depth
            return max(self._get_object_depth(item, current_depth + 1) for item in obj)
        return current_depth

    def _log_validation(self, result: "ValidationResult", input_type: str) -> None:
        """Append the result to validation_log and log a one-line summary."""
        self.validation_log.append(result)

        # Invalid results are logged at WARNING so they stand out.
        log_level = logging.WARNING if not result.valid else logging.INFO
        logger.log(
            log_level,
            "Gate1 Validation [%s]: valid=%s, score=%.2f, errors=%d, warnings=%d",
            input_type, result.valid, result.score,
            len(result.errors), len(result.warnings),
        )

    def get_validation_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent results as dicts (empty for limit <= 0)."""
        # Guard: log[-0:] would return the entire log, not zero entries.
        if limit <= 0:
            return []
        return [r.to_dict() for r in self.validation_log[-limit:]]

    def get_metrics(self) -> Dict[str, Any]:
        """Get validation gate metrics."""
        if not self.validation_log:
            return {"total_validations": 0}

        total_count = len(self.validation_log)
        valid_count = sum(1 for r in self.validation_log if r.valid)

        return {
            "total_validations": total_count,
            "valid_count": valid_count,
            "invalid_count": total_count - valid_count,
            "pass_rate": valid_count / total_count,
            "avg_score": sum(r.score for r in self.validation_log) / total_count
        }


# Backward compatibility alias: code that still imports the older name
# ``FoundationGate`` gets this exact class object (not a subclass or wrapper).
FoundationGate = InputValidationGate


if __name__ == "__main__":
    # Self-test: push a few representative inputs through a fresh gate, print
    # each score/check breakdown, then print the aggregate metrics.
    demo_gate = InputValidationGate()

    print("\n=== Validation Gate 1 (Input) Test ===")

    # Each case: (label, payload, input_type, metadata)
    demo_cases = [
        (
            "Valid input",
            {
                "task_id": "task_123",
                "description": "Process lead data",
                "priority": "high"
            },
            "task",
            {"timestamp": datetime.utcnow().isoformat()},
        ),
        (
            "PII input",
            {"message": "Contact John at john@example.com or 123-456-7890"},
            "message",
            None,
        ),
        ("Malformed input", "not a dict", "task", None),
        (
            "Injection input",
            {"query": "SELECT * FROM users; DROP TABLE users;--"},
            "message",
            None,
        ),
    ]

    for label, payload, kind, extra in demo_cases:
        case_score, case_checks = demo_gate.validate(payload, kind, extra)
        print(f"{label}: score={case_score:.2f}, checks={case_checks}")

    print(f"\nGate Metrics: {demo_gate.get_metrics()}")
