"""
Genesis V2 MCP Validation Module
=================================
Plain-Python validators, sanitizers, and a rate limiter for MCP tool inputs.
H-11, H-19, H-20: Input validation, sanitization, URL validation.
"""

import re
from typing import Optional, Any
from urllib.parse import urlparse
from collections import defaultdict
from time import time
from threading import Lock

# Allowed URL schemes for browser navigation. validate_url compares the
# scheme reported by urlparse against this set.
ALLOWED_URL_SCHEMES = frozenset({'http', 'https', 'file'})

# Regex source describing a stricter http(s) URL shape: a dotted domain,
# "localhost", or a dotted-quad address, followed by an optional port and path.
# NOTE(review): nothing in the visible module references this pattern, and its
# character classes are upper-case only, so it presumably has to be compiled
# with re.IGNORECASE to match ordinary URLs - confirm before relying on it.
URL_PATTERN = r'^https?://(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|localhost|\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})(?::\d+)?(?:/?|[/?]\S+)$'

# Dangerous patterns to reject in CSS selectors. sanitize_selector passes each
# entry to re.search, so entries are regex fragments and any regex
# metacharacter must be escaped (as in the eval entry below).
DANGEROUS_SELECTOR_PATTERNS = (
    'javascript:',
    'data:',
    '<script',
    '</script',
    'onerror',
    'onload',
    'onclick',
    r'eval\(',
)


def validate_url(url: str):
    """
    Validate URL for browser navigation.

    Surrounding whitespace is ignored for the checks. The caller's string is
    not modified, so callers should strip it themselves before navigating.

    Args:
        url: URL to validate

    Returns:
        Tuple of (is_valid, error_message); error_message is None when the
        URL is valid.
    """
    if not isinstance(url, str):
        return (False, 'URL must be a non-empty string')

    url = url.strip()

    # Test emptiness after stripping so a whitespace-only value gets the
    # "non-empty" message rather than a confusing empty-scheme error.
    if not url:
        return (False, 'URL must be a non-empty string')

    if len(url) > 2048:
        return (False, 'URL too long (max 2048 chars)')

    try:
        parsed = urlparse(url)
    except ValueError as e:
        # urlparse rejects malformed authorities (e.g. an unclosed IPv6 bracket).
        return (False, f'URL parsing failed: {e}')

    if parsed.scheme not in ALLOWED_URL_SCHEMES:
        # Sort so the message is stable; a frozenset has no defined order.
        allowed = ', '.join(sorted(ALLOWED_URL_SCHEMES))
        return (False, f'Invalid URL scheme: {parsed.scheme}. Allowed: {allowed}')

    if parsed.scheme == 'file':
        # Local file URLs are normally written file:///path, i.e. with an
        # empty authority, so they are judged by their path, not their host.
        if not parsed.path:
            return (False, 'File URL must have a path')
        return (True, None)

    # netloc can be non-empty without naming a host ("http://:8080/",
    # "http://user@/"); hostname is None in those cases.
    if not parsed.hostname:
        return (False, 'URL must have a valid host')

    try:
        # The port is parsed lazily; reading it raises ValueError when it is
        # non-numeric or out of range.
        _ = parsed.port
    except ValueError as e:
        return (False, f'URL parsing failed: {e}')

    return (True, None)


def sanitize_selector(selector: str):
    """
    Sanitize CSS selector for browser click operations.

    The selector is trimmed, length-checked, and screened against
    DANGEROUS_SELECTOR_PATTERNS. Nothing is rewritten beyond trimming: the
    selector is either returned as-is or rejected.

    Args:
        selector: CSS selector to sanitize

    Returns:
        Tuple of (sanitized_selector, warning_message). On success the
        trimmed selector is returned with a None warning; on rejection the
        selector is '' and the warning explains why.
    """
    if not isinstance(selector, str):
        return ('', 'Selector must be a non-empty string')

    selector = selector.strip()

    # Test emptiness after trimming: a whitespace-only selector would
    # otherwise come back as ('', None), which reads as success.
    if not selector:
        return ('', 'Selector must be a non-empty string')

    if len(selector) > 500:
        return ('', 'Selector too long (max 500 chars)')

    for pattern in DANGEROUS_SELECTOR_PATTERNS:
        # Entries are regex fragments; IGNORECASE alone handles mixed-case
        # input, so no separate lowercased copy is needed.
        if re.search(pattern, selector, re.IGNORECASE):
            return ('', f'Dangerous pattern detected in selector: {pattern}')

    return (selector, None)


def validate_kg_search_input(query: str, limit: int = 10):
    """
    Validate kg_search tool inputs.

    Args:
        query: Search text; must contain at least one non-whitespace
            character and be at most 1000 characters long.
        limit: Maximum number of results; must be an int from 1 to 100
            (bool values are rejected).

    Returns:
        Tuple of (is_valid, error_message); error_message is None when the
        inputs are valid.
    """
    # A query made only of whitespace carries no search text, so it is
    # treated the same as an empty one.
    if not isinstance(query, str) or not query.strip():
        return (False, 'Query must be a non-empty string')

    if len(query) > 1000:
        return (False, 'Query too long (max 1000 chars)')

    # bool is a subclass of int; without the explicit check, limit=True
    # would be accepted as 1.
    if isinstance(limit, bool) or not isinstance(limit, int) or not 1 <= limit <= 100:
        return (False, 'Limit must be an integer between 1 and 100')

    return (True, None)


def validate_entity_json(entity_json: str):
    """
    Validate entity JSON for kg_ingest.

    The input must be a non-empty string of at most 50000 characters that
    decodes to a JSON object containing both 'id' and 'type' keys. The values
    of those keys are not inspected.

    Args:
        entity_json: Raw JSON text describing one entity.

    Returns:
        Tuple of (is_valid, error_message, parsed_dict). On success
        error_message is None and parsed_dict is the decoded object; on
        failure parsed_dict is None.
    """
    import json

    if not entity_json or not isinstance(entity_json, str):
        return (False, 'Entity JSON must be a non-empty string', None)

    # The cap is counted in characters, as len() reports for a str.
    if len(entity_json) > 50000:
        return (False, 'Entity JSON too large (max 50KB)', None)

    try:
        data = json.loads(entity_json)
    except json.JSONDecodeError as e:
        return (False, f'Invalid JSON: {e}', None)
    except RecursionError:
        # Thousands of nested brackets fit under the size cap and make the
        # decoder exceed the recursion limit; report it as invalid input
        # rather than letting the exception escape the validator.
        return (False, 'Invalid JSON: nesting too deep', None)

    if not isinstance(data, dict):
        return (False, 'Entity must be a JSON object', None)

    if 'id' not in data:
        return (False, "Entity must have 'id' field", None)

    if 'type' not in data:
        return (False, "Entity must have 'type' field", None)

    return (True, None, data)


class RateLimiter:
    """In-memory sliding-window limiter that counts recent calls per key."""

    def __init__(self, max_calls: int = 100, window_seconds: int = 60):
        # At most max_calls calls are admitted per key within any trailing
        # window of window_seconds seconds.
        self.max_calls = max_calls
        self.window_seconds = window_seconds
        # key -> timestamps (from time()) of the calls that were admitted
        self.calls = defaultdict(list)
        # Guards every read and write of self.calls
        self._lock = Lock()

    def is_allowed(self, key: str = 'default'):
        """
        Record a call for the key if its quota permits.

        Args:
            key: Bucket to count the call against.

        Returns:
            Tuple of (is_allowed, remaining_calls). A rejected call is not
            recorded and reports 0 remaining.
        """
        current = time()
        cutoff = current - self.window_seconds

        with self._lock:
            # Keep only the timestamps that still fall inside the window.
            recent = [stamp for stamp in self.calls[key] if stamp > cutoff]
            self.calls[key] = recent

            if len(recent) < self.max_calls:
                recent.append(current)
                return (True, self.max_calls - len(recent))

            return (False, 0)


# Global rate limiter instance, created at import time with RateLimiter's
# default arguments. check_rate_limit uses the tool name as the key, so each
# tool is counted separately.
_rate_limiter = RateLimiter()


def check_rate_limit(tool_name: str):
    """
    Check the shared rate limiter for one tool call.

    Args:
        tool_name: Name of the tool; used as the rate-limit key.

    Returns:
        Tuple of (is_allowed, error_message); error_message is None when the
        call is allowed.
    """
    permitted, _ = _rate_limiter.is_allowed(tool_name)
    if permitted:
        return (True, None)
    return (False, f'Rate limit exceeded for {tool_name}. Try again later.')
