# genesis_patent_sdk/__init__.py
"""Genesis Patent Validation SDK"""

__version__ = "0.1.0"

from .client import PatentValidatorClient
from .models import ValidationRequest, ValidationResponse, AuditEvent, PatentUsageMetrics
from .decorators import validate_output, require_consensus, audit_tracked, rate_limited
from .utils import batch_process, stream_validate

# genesis_patent_sdk/client.py
import asyncio
import functools
import logging
import time
from typing import Any, List, Dict, Callable

import aiohttp
import requests
from pydantic import BaseModel, ValidationError

from .models import ValidationRequest, ValidationResponse, AuditEvent
from .utils import ResultCache

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class PatentValidatorClient:
    """
    Client for interacting with the Genesis Patent Validation System.

    Provides a synchronous interface (backed by ``requests``) and an
    asynchronous one (backed by ``aiohttp``), with retry/backoff,
    connection pooling, and optional in-memory result caching.
    """

    def __init__(self, api_url: str, api_key: str, max_retries: int = 3, timeout: int = 10,
                 max_connections: int = 10, enable_cache: bool = True):
        """
        Initializes the PatentValidatorClient.

        Args:
            api_url: The base URL of the Genesis Patent Validation API.
            api_key: The API key for authentication.
            max_retries: The maximum number of retries for API requests.
            timeout: The timeout for API requests in seconds.
            max_connections: Maximum number of concurrent connections.
            enable_cache: Whether to enable result caching.
        """
        self.api_url = api_url
        self.api_key = api_key
        self.max_retries = max_retries
        self.timeout = timeout
        # One pooled session reused for all synchronous calls.
        self.session = requests.Session()
        self.session.headers.update({"X-API-Key": self.api_key})
        self.adapter = requests.adapters.HTTPAdapter(pool_connections=max_connections, pool_maxsize=max_connections)
        self.session.mount("http://", self.adapter)
        self.session.mount("https://", self.adapter)
        self.async_session = None  # Created lazily on the first async call.
        self.cache = ResultCache() if enable_cache else None
        # Mutable state consumed by the ``rate_limited`` decorator.
        self.rate_limit_last_call = 0
        self.rate_limit_period = 1  # Default period of 1 second
        self.rate_limit_max_calls = 10

    async def _get_async_session(self):
        """
        Lazily creates and returns the shared aiohttp session.
        """
        if self.async_session is None:
            self.async_session = aiohttp.ClientSession(
                headers={"X-API-Key": self.api_key},
                connector=aiohttp.TCPConnector(limit=10),  # Bound the connection pool.
            )
        return self.async_session

    def _handle_request(self, method: str, endpoint: str, data: dict = None) -> dict:
        """
        Handles synchronous API requests with retry logic and timeout.

        Retries with exponential backoff (1s, 2s, 4s, ...) on any
        ``requests`` error — including HTTP 4xx/5xx responses — and
        re-raises once ``max_retries`` is exhausted.
        """
        url = f"{self.api_url}/{endpoint}"
        for attempt in range(self.max_retries + 1):
            try:
                response = self.session.request(method, url, json=data, timeout=self.timeout)
                response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
                return response.json()
            except requests.exceptions.RequestException as e:
                logger.warning(f"Attempt {attempt + 1} failed: {e}")
                if attempt == self.max_retries:
                    raise  # Re-raise the exception if retries are exhausted
                time.sleep(2 ** attempt)  # Exponential backoff

    async def _handle_request_async(self, method: str, endpoint: str, data: dict = None) -> dict:
        """
        Handles asynchronous API requests with retry logic and timeout.
        """
        url = f"{self.api_url}/{endpoint}"
        async_session = await self._get_async_session()
        # aiohttp expects a ClientTimeout object, not a bare number.
        request_timeout = aiohttp.ClientTimeout(total=self.timeout)
        for attempt in range(self.max_retries + 1):
            try:
                async with async_session.request(method, url, json=data, timeout=request_timeout) as response:
                    response.raise_for_status()
                    return await response.json()
            except aiohttp.ClientError as e:
                logger.warning(f"Async attempt {attempt + 1} failed: {e}")
                if attempt == self.max_retries:
                    raise  # Re-raise the exception if retries are exhausted
                await asyncio.sleep(2 ** attempt)  # Exponential backoff

    def validate(self, data: ValidationRequest) -> ValidationResponse:
        """
        Performs a full patent validation pipeline.

        Results are cached per serialized request when caching is enabled.
        """
        cache_key = f"validate:{data.json()}"
        if self.cache is not None:
            cached = self.cache.get(cache_key)  # Single lookup instead of two.
            if cached is not None:
                return cached

        response = ValidationResponse(**self._handle_request("POST", "validate", data.dict()))

        if self.cache is not None:
            self.cache.set(cache_key, response)
        return response

    async def validate_async(self, data: ValidationRequest) -> ValidationResponse:
        """
        Performs a full patent validation pipeline asynchronously.
        """
        cache_key = f"validate:{data.json()}"
        if self.cache is not None:
            cached = self.cache.get(cache_key)
            if cached is not None:
                return cached

        response_data = await self._handle_request_async("POST", "validate", data.dict())
        response = ValidationResponse(**response_data)

        if self.cache is not None:
            self.cache.set(cache_key, response)
        return response

    def validate_crypto(self, data: dict, key: str) -> dict:
        """
        Performs cryptographic validation.
        """
        return self._handle_request("POST", "validate/crypto", {"data": data, "key": key})

    async def validate_crypto_async(self, data: dict, key: str) -> dict:
        """
        Performs cryptographic validation asynchronously.
        """
        return await self._handle_request_async("POST", "validate/crypto", {"data": data, "key": key})

    def validate_consensus(self, data: dict, models: List[str]) -> dict:
        """
        Performs multi-model consensus validation.
        """
        return self._handle_request("POST", "validate/consensus", {"data": data, "models": models})

    async def validate_consensus_async(self, data: dict, models: List[str]) -> dict:
        """
        Performs multi-model consensus validation asynchronously.
        """
        return await self._handle_request_async("POST", "validate/consensus", {"data": data, "models": models})

    def detect_hallucinations(self, text: str) -> dict:
        """
        Detects hallucinations in the given text.
        """
        return self._handle_request("POST", "detect/hallucinations", {"text": text})

    async def detect_hallucinations_async(self, text: str) -> dict:
        """
        Detects hallucinations in the given text asynchronously.
        """
        return await self._handle_request_async("POST", "detect/hallucinations", {"text": text})

    def assess_risk(self, data: dict) -> dict:
        """
        Assesses the risk associated with the given data.
        """
        return self._handle_request("POST", "assess/risk", data)

    async def assess_risk_async(self, data: dict) -> dict:
        """
        Assesses the risk associated with the given data asynchronously.
        """
        return await self._handle_request_async("POST", "assess/risk", data)

    def close(self):
        """
        Closes the synchronous session and, best-effort, the async session.
        """
        self.session.close()
        if self.async_session is not None:
            try:
                asyncio.run(self.async_session.close())
            except RuntimeError:
                # asyncio.run() is illegal inside a running event loop; in
                # async code, ``await client.async_session.close()`` instead.
                logger.warning("Async session could not be closed synchronously.")
            self.async_session = None

# genesis_patent_sdk/models.py
from typing import Optional, List

from pydantic import BaseModel, validator

class ValidationRequest(BaseModel):
    """
    Input payload for a patent validation call.

    Attributes default to the US jurisdiction and model version "v1";
    ``patent_text`` must be non-empty.
    """
    patent_text: str
    prior_art_search_query: str
    jurisdiction: str = "US"
    model_version: str = "v1"

    @validator("patent_text")
    def patent_text_must_not_be_empty(cls, v):
        # Guard clause: any truthy (non-empty) text passes through unchanged.
        if v:
            return v
        raise ValueError("Patent text cannot be empty")

class ValidationResponse(BaseModel):
    """
    Data model for validation responses.

    Mirrors the JSON payload returned by the ``validate`` endpoint.
    """
    is_valid: bool  # Overall verdict of the validation pipeline
    confidence_score: float  # Confidence in the verdict; presumably in [0, 1] — TODO confirm against API docs
    explanation: str  # Human-readable rationale for the verdict
    related_patents: Optional[List[str]] = None  # Related patent identifiers, when supplied by the API

class AuditEvent(BaseModel):
    """
    Data model for audit events.

    Constructed by the ``audit_tracked`` decorator for every tracked call.
    """
    event_type: str  # Name of the tracked operation
    timestamp: float  # Unix timestamp (seconds) when the call occurred
    user_id: str  # Caller identity; the decorator currently always passes "unknown"
    data: dict  # Positional and keyword arguments of the tracked call

class PatentUsageMetrics(BaseModel):
    """
    Data model for tracking patent usage metrics.

    Aggregate counters for a single patent. Not populated anywhere in this
    module — presumably filled by a metrics backend; TODO confirm.
    """
    patent_id: str  # Identifier of the patent being tracked
    num_validations: int  # Total number of validation runs recorded
    total_validation_time: float  # Cumulative validation time; units assumed seconds — TODO confirm

# genesis_patent_sdk/decorators.py
import functools
import logging
import time
from typing import Callable, Any, List

from pydantic import BaseModel, ValidationError

from .models import AuditEvent
from .client import PatentValidatorClient

logger = logging.getLogger(__name__)

def validate_output(model: "type[BaseModel]"):
    """
    Decorator that validates a function's return value against a Pydantic model.

    Args:
        model: The Pydantic model *class* to validate against (the original
            annotation said ``BaseModel`` instance, but a class is required
            since it is called as ``model(**result)``).

    Returns:
        A decorator; the wrapped function returns a ``model`` instance built
        from the mapping the original function returned.

    Raises:
        ValidationError: Propagated (after logging) when the output does not
            conform to ``model``.
    """
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            try:
                # Result must be a mapping usable as keyword arguments.
                return model(**result)
            except ValidationError as e:
                logger.error(f"Output validation failed for {func.__name__}: {e}")
                raise
        return wrapper
    return decorator

def require_consensus(num_models: int):
    """
    Decorator to require agreement from a specified number of models.

    The wrapped function is invoked once per model (``models=[model]``) and
    every result's ``is_valid`` flag must agree.

    Args:
        num_models: Minimum number of models that must be supplied.

    Raises:
        ValueError: If fewer than ``num_models`` models are given, or if the
            per-model results disagree.
    """
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Pop "models" so re-passing models=[model] below does not raise
            # "got multiple values for keyword argument 'models'".
            models = kwargs.pop("models", None)
            if models is None or len(models) < num_models:
                raise ValueError(f"Need at least {num_models} models for consensus.")

            results = [func(*args, **kwargs, models=[model]) for model in models]
            # Simple consensus check: every model must report the same verdict.
            if len(set(r['is_valid'] for r in results)) != 1:
                raise ValueError("Consensus not reached among models.")

            return results[0]  # All results agree; return the first.
        return wrapper
    return decorator

def audit_tracked(event_type: str):
    """
    Decorator that logs each call of the wrapped function as an audit event.

    Audit logging is strictly best-effort: any failure while building or
    logging the event is swallowed (with a warning) and the wrapped function
    still runs.
    """
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                # Decorated callables are expected to be client methods, so
                # args[0] is the client instance.
                caller: PatentValidatorClient = args[0]
                event = AuditEvent(
                    event_type=event_type,
                    timestamp=time.time(),
                    user_id="unknown",  # TODO: wire in real user identity.
                    data={"args": args[1:], "kwargs": kwargs},
                )
                logger.info(f"Audit event: {event}")
                # A production system would persist this event somewhere durable.
            except Exception as e:
                logger.warning(f"Failed to create audit log: {e}")
            return func(*args, **kwargs)
        return wrapper
    return decorator

def rate_limited(period: int = 1, max_calls: int = 10):
    """
    Decorator implementing client-side fixed-window rate limiting.

    Allows at most ``max_calls`` invocations per ``period`` seconds, sleeping
    until the window expires when the budget is spent. State is kept on the
    client instance (``args[0]``) via ``rate_limit_last_call`` (window start)
    and ``rate_limit_max_calls`` (remaining budget).

    The original implementation reset ``rate_limit_last_call`` on *every*
    call, so the window never anchored and the budget never replenished
    correctly; this version anchors the window start and resets the budget
    when a new window begins.
    """
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            client: PatentValidatorClient = args[0]
            now = time.time()
            elapsed = now - client.rate_limit_last_call
            if elapsed >= period:
                # New window: anchor it and spend one call from a fresh budget.
                client.rate_limit_last_call = now
                client.rate_limit_max_calls = max_calls - 1
            elif client.rate_limit_max_calls > 0:
                # Still inside the window with budget left.
                client.rate_limit_max_calls -= 1
            else:
                # Budget exhausted: wait out the window, then start a new one.
                time.sleep(period - elapsed)
                client.rate_limit_last_call = time.time()
                client.rate_limit_max_calls = max_calls - 1
            return func(*args, **kwargs)
        return wrapper
    return decorator

# genesis_patent_sdk/utils.py
import asyncio
import logging
from typing import List, Any, Callable

logger = logging.getLogger(__name__)

class ResultCache:
    """
    Simple in-memory LRU result cache.

    Uses a plain dict's insertion order (guaranteed since Python 3.7) as the
    recency order: the first key is the least recently used. This replaces the
    original list-based bookkeeping, whose ``list.remove`` made every hit and
    update O(n). Consider a more robust solution (e.g. Redis) for production.
    """
    def __init__(self, max_size: int = 100):
        """
        Args:
            max_size: Maximum number of entries before LRU eviction kicks in.
        """
        self.cache = {}
        self.max_size = max_size

    def get(self, key: str) -> Any:
        """
        Return the cached value for ``key``, or None on a miss.

        A hit re-inserts the entry so it becomes the most recently used.
        """
        try:
            value = self.cache.pop(key)
        except KeyError:
            return None
        self.cache[key] = value  # Re-insert to mark as most recently used.
        return value

    def set(self, key: str, value: Any):
        """
        Store ``value`` under ``key``, evicting the LRU entry when full.
        """
        if key in self.cache:
            del self.cache[key]  # Remove so re-insertion refreshes recency.
        elif len(self.cache) >= self.max_size:
            oldest = next(iter(self.cache))  # First key is least recently used.
            del self.cache[oldest]
        self.cache[key] = value

def batch_process(data: List[Any], func: Callable[[Any], Any], batch_size: int = 10) -> List[Any]:
    """
    Apply ``func`` to every item of ``data``, working through it in batches
    of ``batch_size``, and return the results in order.
    """
    processed: List[Any] = []
    for start in range(0, len(data), batch_size):
        chunk = data[start:start + batch_size]
        processed += [func(item) for item in chunk]
    return processed

async def batch_process_async(data: List[Any], func: Callable[[Any], Any], batch_size: int = 10) -> List[Any]:
    """
    Apply the coroutine function ``func`` to every item of ``data``, awaiting
    each batch of ``batch_size`` items concurrently, and return the results
    in order.
    """
    collected: List[Any] = []
    for offset in range(0, len(data), batch_size):
        chunk = data[offset:offset + batch_size]
        collected += await asyncio.gather(*(func(item) for item in chunk))
    return collected

def stream_validate(data: List[str], client, chunk_size: int = 5) -> None:  # Simplified example
    """
    Validate patent text in chunks of ``chunk_size`` strings via ``client``
    and print each chunk's result to the console. Per-chunk failures are
    logged and do not stop the stream.
    """
    for start in range(0, len(data), chunk_size):
        chunk_number = start // chunk_size + 1
        chunk = data[start:start + chunk_size]
        try:
            request = ValidationRequest(patent_text=" ".join(chunk), prior_art_search_query="example query")
            response = client.validate(request)
            print(f"Chunk {chunk_number}: Valid = {response.is_valid}, Confidence = {response.confidence_score}")
        except Exception as e:
            logger.error(f"Error validating chunk {chunk_number}: {e}")

# Example usage (not part of the package, but shows how to use it)
if __name__ == "__main__":
    # Replace with your actual API URL and key
    api_url = "http://localhost:8000"
    api_key = "your_api_key"

    client = PatentValidatorClient(api_url, api_key)

    try:
        # Basic validation round-trip.
        sample = ValidationRequest(
            patent_text="This is a sample patent text.",
            prior_art_search_query="Relevant prior art query",
        )

        response = client.validate(sample)
        print(f"Validation Result: {response}")

        # Demonstrate stacking the SDK decorators on an ad-hoc client method.
        @audit_tracked("complex_operation")
        @rate_limited(period=1, max_calls=5)
        @validate_output(ValidationResponse)
        def complex_operation(self, data: ValidationRequest) -> dict:
            """Simulates a complex operation that needs validation, auditing, and rate limiting."""
            return self._handle_request("POST", "validate", data.dict())

        # Bind the standalone function onto the client instance as a method.
        client.complex_operation = complex_operation.__get__(client)

        try:
            complex_result = client.complex_operation(sample)
            print(f"Complex Operation Result: {complex_result}")
        except Exception as e:
            print(f"Error during complex operation: {e}")

    except Exception as e:
        print(f"Error: {e}")

    finally:
        client.close()
