import copy
import hashlib
import hmac
import json
import logging
import os  # For random key generation
import time
import uuid
from datetime import datetime, timezone

# Import necessary libraries for advanced crypto
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519, padding, rsa

# Route INFO-and-above records to the root logger, each line prefixed with a
# timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


class CryptographicValidationAdvanced:
    """
    Advanced cryptographic validation skill with multiple signature algorithms,
    key rotation, Merkle tree support, HSM integration patterns, and audit logging.
    """

    def __init__(self, default_key=None, default_algorithm='hmac-sha256', hsm=None):
        """
        Initializes the CryptographicValidationAdvanced skill.

        Args:
            default_key (str, optional): A default key to use for signing. If None, a random key is generated.
            default_algorithm (str, optional): The default signature algorithm.  Options: 'hmac-sha256', 'ed25519', 'rsa'.
            hsm (object, optional): An object representing a Hardware Security Module (HSM).  If None, keys are stored in memory.
        """
        self.default_algorithm = default_algorithm.lower()  # Ensure lowercase for consistency
        self.proof_chain = []  # Initialize an empty proof chain
        self.key_store = {} # Dictionary to store keys with versions
        self.current_key_version = 1  # Start with key version 1
        self.hsm = hsm # HSM integration

        if default_key is None:
            self.rotate_key()  # Generate an initial random key and store it
        else:
            self.add_key(default_key)  # Add the provided key and store it

        self.audit_log = [] # Initialize an audit log

    def _log_audit(self, event_type, description, data=None):
        """Logs an audit event."""
        timestamp = datetime.utcnow().isoformat()
        log_entry = {
            'timestamp': timestamp,
            'event_type': event_type,
            'description': description,
            'data': data
        }
        self.audit_log.append(log_entry)
        logging.info(f"Audit Log: {log_entry}")

    def generate_signature(self, data, key_version=None, algorithm=None):
        """Generates a signature for the given data using the specified algorithm."""

        if key_version is None:
            key_version = self.current_key_version

        if algorithm is None:
            algorithm = self.default_algorithm

        key = self.get_key(key_version)

        if not key:
            raise ValueError(f"Key version {key_version} not found.")

        algorithm = algorithm.lower() # Ensure lowercase for consistency

        if algorithm == 'hmac-sha256':
            return self._generate_hmac_signature(data, key)
        elif algorithm == 'ed25519':
            return self._generate_ed25519_signature(data, key)
        elif algorithm == 'rsa':
            return self._generate_rsa_signature(data, key)
        else:
            raise ValueError(f"Unsupported signature algorithm: {algorithm}")

    def _generate_hmac_signature(self, data, key):
        """Generates an HMAC-SHA256 signature."""
        key_bytes = key.encode('utf-8')
        data_bytes = data.encode('utf-8')
        hmac_obj = hmac.new(key_bytes, data_bytes, hashlib.sha256)
        signature = hmac_obj.hexdigest()
        return signature

    def _generate_ed25519_signature(self, data, key):
        """Generates an Ed25519 signature."""
        try:
            private_key = serialization.load_pem_private_key(
                key.encode('utf-8'),
                password=None,
                backend=default_backend()
            )
            signature = private_key.sign(data.encode('utf-8'))
            return signature.hex()  # Return as hex string
        except Exception as e:
            logging.error(f"Error generating Ed25519 signature: {e}")
            raise

    def _generate_rsa_signature(self, data, key):
        """Generates an RSA signature."""
        try:
            private_key = serialization.load_pem_private_key(
                key.encode('utf-8'),
                password=None,
                backend=default_backend()
            )
            signature = private_key.sign(
                data.encode('utf-8'),
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA256()
            )
            return signature.hex() # Return as hex string
        except Exception as e:
            logging.error(f"Error generating RSA signature: {e}")
            raise

    def verify_signature(self, data, signature, key_version=None, algorithm=None):
        """Verifies the signature of the given data."""

        if key_version is None:
            key_version = self.current_key_version

        if algorithm is None:
            algorithm = self.default_algorithm

        key = self.get_key(key_version)

        if not key:
            raise ValueError(f"Key version {key_version} not found.")

        algorithm = algorithm.lower()

        if algorithm == 'hmac-sha256':
            return self._verify_hmac_signature(data, signature, key)
        elif algorithm == 'ed25519':
            return self._verify_ed25519_signature(data, signature, key)
        elif algorithm == 'rsa':
            return self._verify_rsa_signature(data, signature, key)
        else:
            raise ValueError(f"Unsupported signature algorithm: {algorithm}")

    def _verify_hmac_signature(self, data, signature, key):
        """Verifies an HMAC-SHA256 signature."""
        expected_signature = self._generate_hmac_signature(data, key)
        return hmac.compare_digest(signature, expected_signature)

    def _verify_ed25519_signature(self, data, signature, key):
        """Verifies an Ed25519 signature."""
        try:
            public_key = serialization.load_pem_public_key(
                key.encode('utf-8'),
                backend=default_backend()
            )
            public_key.verify(bytes.fromhex(signature), data.encode('utf-8'))
            return True
        except InvalidSignature:
            return False
        except Exception as e:
            logging.error(f"Error verifying Ed25519 signature: {e}")
            return False

    def _verify_rsa_signature(self, data, signature, key):
        """Verifies an RSA signature."""
        try:
            public_key = serialization.load_pem_public_key(
                key.encode('utf-8'),
                backend=default_backend()
            )
            public_key.verify(
                bytes.fromhex(signature),
                data.encode('utf-8'),
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA256()
            )
            return True
        except InvalidSignature:
            return False
        except Exception as e:
            logging.error(f"Error verifying RSA signature: {e}")
            return False


    def add_key(self, key):
        """Adds a key to the key store with the next available version."""
        self.current_key_version += 1
        self.key_store[self.current_key_version] = key
        self._log_audit("key_add", f"Key added with version {self.current_key_version}")
        return self.current_key_version

    def rotate_key(self, algorithm=None):
        """Rotates the key by generating a new key and incrementing the version."""
        if algorithm is None:
            algorithm = self.default_algorithm
        algorithm = algorithm.lower()

        if algorithm == 'hmac-sha256':
            new_key = str(uuid.uuid4())  # Generate a random key
        elif algorithm == 'ed25519':
            private_key = ed25519.Ed25519PrivateKey.generate()
            new_key = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            ).decode('utf-8') # Store as PEM string
        elif algorithm == 'rsa':
            private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=2048,
                backend=default_backend()
            )
            new_key = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            ).decode('utf-8') # Store as PEM string
        else:
            raise ValueError(f"Unsupported algorithm for key rotation: {algorithm}")

        self.current_key_version += 1
        self.key_store[self.current_key_version] = new_key
        self._log_audit("key_rotation", f"Key rotated to version {self.current_key_version}", {"algorithm": algorithm})
        return self.current_key_version

    def get_key(self, version):
        """Retrieves a key from the key store by version."""
        return self.key_store.get(version)

    def create_proof_chain(self, data, key_version=None, algorithm=None, previous_hash=None):
        """Creates a cryptographic proof chain entry."""
        if key_version is None:
            key_version = self.current_key_version

        if algorithm is None:
            algorithm = self.default_algorithm

        timestamp = datetime.utcnow().isoformat()
        signature = self.generate_signature(data, key_version, algorithm)

        data_to_hash = json.dumps({
            'data': data,
            'timestamp': timestamp,
            'signature': signature,
            'key_version': key_version,
            'algorithm': algorithm,
            'previous_hash': previous_hash
        }, sort_keys=True).encode('utf-8')

        current_hash = hashlib.sha256(data_to_hash).hexdigest()

        proof_entry = {
            'data': data,
            'timestamp': timestamp,
            'signature': signature,
            'hash': current_hash,
            'key_version': key_version,
            'algorithm': algorithm,
            'previous_hash': previous_hash
        }

        self.proof_chain.append(proof_entry)
        self._log_audit("proof_chain_create", "Proof chain entry created", proof_entry)
        return proof_entry

    def sign_output(self, output_text, key_version=None, algorithm=None):
        """Signs the AI output and creates a proof chain entry."""

        if not self.proof_chain:
            previous_hash = None
        else:
            previous_hash = self.proof_chain[-1]['hash']

        proof_entry = self.create_proof_chain(output_text, key_version, algorithm, previous_hash)

        signed_output = {
            'output_text': output_text,
            'proof_entry': proof_entry
        }
        self._log_audit("output_sign", "AI output signed", signed_output)
        return signed_output

    def verify_output_integrity(self, signed_output):
        """Verifies the integrity of the signed output."""
        output_text = signed_output['output_text']
        proof_entry = signed_output['proof_entry']
        key_version = proof_entry['key_version']
        algorithm = proof_entry['algorithm']

        # Verify the signature
        is_signature_valid = self.verify_signature(output_text, proof_entry['signature'], key_version, algorithm)
        if not is_signature_valid:
            logging.warning("Signature verification failed.")
            return False

        # Recompute the hash and compare
        data_to_hash = json.dumps({
            'data': proof_entry['data'],
            'timestamp': proof_entry['timestamp'],
            'signature': proof_entry['signature'],
            'key_version': proof_entry['key_version'],
            'algorithm': proof_entry['algorithm'],
            'previous_hash': proof_entry['previous_hash']
        }, sort_keys=True).encode('utf-8')

        recomputed_hash = hashlib.sha256(data_to_hash).hexdigest()

        if proof_entry['hash'] != recomputed_hash:
            logging.warning("Hash verification failed.")
            return False

        self._log_audit("output_verify", "Output integrity verified", signed_output)
        return True

    def detect_tampering(self, proof_chain):
        """Detects tampering in the proof chain by verifying the hash chain."""
        if not proof_chain:
            logging.warning("Empty proof chain.")
            return False

        previous_hash = None
        for entry in proof_chain:
            # Recompute the hash and compare
            data_to_hash = json.dumps({
                'data': entry['data'],
                'timestamp': entry['timestamp'],
                'signature': entry['signature'],
                'key_version': entry['key_version'],
                'algorithm': entry['algorithm'],
                'previous_hash': previous_hash
            }, sort_keys=True).encode('utf-8')

            recomputed_hash = hashlib.sha256(data_to_hash).hexdigest()

            if entry['hash'] != recomputed_hash:
                logging.error(f"Tampering detected in entry with timestamp: {entry['timestamp']}")
                self._log_audit("tampering_detect", f"Tampering detected in entry with timestamp: {entry['timestamp']}", entry)
                return True

            previous_hash = entry['hash']

        self._log_audit("tampering_detect", "No tampering detected")
        return False

    def get_confidence(self, algorithm=None):
        """Returns a confidence score based on the cryptographic strength."""
        if algorithm is None:
            algorithm = self.default_algorithm
        algorithm = algorithm.lower()

        if algorithm == 'hmac-sha256':
            return 0.95
        elif algorithm == 'ed25519':
            return 0.98
        elif algorithm == 'rsa':
            return 0.97  # Adjust based on key size
        else:
            return 0.0  # Unknown algorithm

    def get_proof_chain(self):
        """Returns the current proof chain."""
        return self.proof_chain

    def get_audit_log(self):
        """Returns the audit log."""
        return self.audit_log

    def generate_merkle_root(self, data_list):
        """Generates a Merkle root hash from a list of data."""
        hashes = [hashlib.sha256(d.encode('utf-8')).hexdigest() for d in data_list]
        return self._merkle_root(hashes)

    def _merkle_root(self, hashes):
        """Recursively calculates the Merkle root."""
        if not hashes:
            return None  # Or raise an exception

        if len(hashes) == 1:
            return hashes[0]

        if len(hashes) % 2 != 0:
            hashes.append(hashes[-1])  # Duplicate last hash if odd number

        new_hashes = []
        for i in range(0, len(hashes), 2):
            combined = hashes[i] + hashes[i + 1]
            new_hashes.append(hashlib.sha256(combined.encode('utf-8')).hexdigest())

        return self._merkle_root(new_hashes)

    def verify_merkle_proof(self, data, proof, merkle_root):
        """Verifies a Merkle proof for a given data element."""
        leaf_hash = hashlib.sha256(data.encode('utf-8')).hexdigest()
        computed_hash = leaf_hash

        for step in proof:
            if 'left' in step:
                computed_hash = hashlib.sha256((step['left'] + computed_hash).encode('utf-8')).hexdigest()
            elif 'right' in step:
                computed_hash = hashlib.sha256((computed_hash + step['right']).encode('utf-8')).hexdigest()
            else:
                raise ValueError("Invalid Merkle proof step")

        return computed_hash == merkle_root

    def generate_merkle_proof(self, data_list, target_data):
        """Generates a Merkle proof for a specific data element in a list."""
        hashes = [hashlib.sha256(d.encode('utf-8')).hexdigest() for d in data_list]
        target_hash = hashlib.sha256(target_data.encode('utf-8')).hexdigest()

        if target_hash not in hashes:
            return None  # Data not in the list

        proof = []
        def _build_proof(hashes, target_hash, proof):
            if len(hashes) == 1:
                return

            if len(hashes) % 2 != 0:
                hashes.append(hashes[-1])

            new_hashes = []
            for i in range(0, len(hashes), 2):
                left = hashes[i]
                right = hashes[i+1]
                combined = left + right
                new_hash = hashlib.sha256(combined.encode('utf-8')).hexdigest()
                new_hashes.append(new_hash)

                if left == target_hash:
                    proof.append({'right': right})
                    _build_proof(new_hashes, new_hash, proof)
                    return
                elif right == target_hash:
                    proof.append({'left': left})
                    _build_proof(new_hashes, new_hash, proof)
                    return

        _build_proof(hashes, target_hash, proof)
        return proof



    def zero_knowledge_proof_example(self, secret, knowledge):
        """
        Illustrates the concept of zero-knowledge proofs (simplified example).
        This is a conceptual example and not a secure implementation.
        """
        # In a real ZKP, the prover would demonstrate knowledge of the secret
        # without revealing the secret itself.

        # This is a placeholder.  A real ZKP would involve complex mathematical
        # protocols.

        # Example:  Proving you know the solution to a puzzle without revealing the solution.

        print("Zero-Knowledge Proof Concept:")
        print(f"  Secret:  [Secret Hidden]") # Hides the secret
        print(f"  Knowledge:  [Proof of Knowledge - Not Revealed]") # Hides the proof

        # The core idea is to convince a verifier that you possess knowledge
        # without disclosing the knowledge itself.  This involves cryptographic
        # commitments, challenges, and responses.

        return True  # Placeholder - always returns True

    def integrate_hsm(self, hsm_client):
        """
        Integrates with a Hardware Security Module (HSM).

        Args:
            hsm_client: An object representing the HSM client.  This object
                        would need to provide methods for key generation, signing,
                        and verification using the HSM.
        """
        self.hsm = hsm_client
        # Example HSM interaction (replace with actual HSM calls):
        # self.hsm.generate_key("my_key")
        # signature = self.hsm.sign("data", "my_key")
        logging.info("HSM integration initiated.")
        self._log_audit("hsm_integration", "HSM integration initiated")

    def run_performance_benchmarks(self, num_iterations=100):
        """Runs performance benchmarks for signature generation and verification."""
        data = "This is a test string for performance benchmarking."

        print(f"Running performance benchmarks with {num_iterations} iterations...")

        # HMAC-SHA256
        start_time = time.time()
        for _ in range(num_iterations):
            self.generate_signature(data, algorithm='hmac-sha256')
        hmac_gen_time = (time.time() - start_time) / num_iterations

        start_time = time.time()
        for _ in range(num_iterations):
            signature = self.generate_signature(data, algorithm='hmac-sha256')
            self.verify_signature(data, signature, algorithm='hmac-sha256')
        hmac_ver_time = (time.time() - start_time) / num_iterations

        # Ed25519
        start_time = time.time()
        for _ in range(num_iterations):
            self.generate_signature(data, algorithm='ed25519')
        ed25519_gen_time = (time.time() - start_time) / num_iterations

        start_time = time.time()
        for _ in range(num_iterations):
            signature = self.generate_signature(data, algorithm='ed25519')
            self.verify_signature(data, signature, algorithm='ed25519')
        ed25519_ver_time = (time.time() - start_time) / num_iterations

        # RSA
        start_time = time.time()
        for _ in range(num_iterations):
            self.generate_signature(data, algorithm='rsa')
        rsa_gen_time = (time.time() - start_time) / num_iterations

        start_time = time.time()
        for _ in range(num_iterations):
            signature = self.generate_signature(data, algorithm='rsa')
            self.verify_signature(data, signature, algorithm='rsa')
        rsa_ver_time = (time.time() - start_time) / num_iterations

        print("Performance Benchmarks (average time per operation in seconds):")
        print(f"  HMAC-SHA256: Generation: {hmac_gen_time:.6f}, Verification: {hmac_ver_time:.6f}")
        print(f"  Ed25519: Generation: {ed25519_gen_time:.6f}, Verification: {ed25519_ver_time:.6f}")
        print(f"  RSA: Generation: {rsa_gen_time:.6f}, Verification: {rsa_ver_time:.6f}")

# Example Usage (for testing):
if __name__ == '__main__':

    # Create an instance of the CryptographicValidationAdvanced skill
    crypto_validator = CryptographicValidationAdvanced(default_algorithm='hmac-sha256')

    # Example AI output
    ai_output = "The capital of France is Paris."

    # Sign the output
    signed_output = crypto_validator.sign_output(ai_output)
    print("Signed Output (HMAC):", signed_output)

    # Verify the output integrity
    is_valid = crypto_validator.verify_output_integrity(signed_output)
    print("Output Integrity Valid (HMAC):", is_valid)

    # Rotate the key
    new_key_version = crypto_validator.rotate_key()
    print(f"Key rotated to version: {new_key_version}")

    # Example of creating another output to chain together.
    ai_output2 = "The sky is blue."
    signed_output2 = crypto_validator.sign_output(ai_output2)
    print("Signed Output 2 (HMAC, new key):", signed_output2)

    # Verify the output integrity of the second output.
    is_valid2 = crypto_validator.verify_output_integrity(signed_output2)
    print("Output Integrity Valid 2 (HMAC, new key):", is_valid2)

    # Get the complete proof chain
    proof_chain = crypto_validator.get_proof_chain()
    print("Proof Chain:", proof_chain)

    # Detect tampering (example of a tampered chain).
    # Copy each entry dict: a plain slice (proof_chain[:]) shares the entry dicts
    # with the live chain, so editing it would also corrupt the original chain and
    # make the clean-chain check below report tampering too.
    tampered_chain = [dict(entry) for entry in proof_chain]
    if tampered_chain:
        tampered_chain[0]['data'] = "Tampered data"
        tampered = crypto_validator.detect_tampering(tampered_chain)
        print("Tampering Detected (Tampered Chain):", tampered)

    # Detect tampering (clean chain)
    tampered = crypto_validator.detect_tampering(proof_chain)
    print("Tampering Detected (Original Chain):", tampered)

    # Example using Ed25519
    crypto_validator_ed25519 = CryptographicValidationAdvanced(default_algorithm='ed25519')
    ai_output_ed25519 = "This is a message signed with Ed25519."
    signed_output_ed25519 = crypto_validator_ed25519.sign_output(ai_output_ed25519)
    print("Signed Output (Ed25519):", signed_output_ed25519)
    is_valid_ed25519 = crypto_validator_ed25519.verify_output_integrity(signed_output_ed25519)
    print("Output Integrity Valid (Ed25519):", is_valid_ed25519)

    # Example using RSA
    crypto_validator_rsa = CryptographicValidationAdvanced(default_algorithm='rsa')
    ai_output_rsa = "This is a message signed with RSA."
    signed_output_rsa = crypto_validator_rsa.sign_output(ai_output_rsa)
    print("Signed Output (RSA):", signed_output_rsa)
    is_valid_rsa = crypto_validator_rsa.verify_output_integrity(signed_output_rsa)
    print("Output Integrity Valid (RSA):", is_valid_rsa)

    # Example Merkle Tree
    data_list = ["data1", "data2", "data3", "data4", "data5"]
    merkle_root = crypto_validator.generate_merkle_root(data_list)
    print("Merkle Root:", merkle_root)

    # Generate and verify merkle proof
    target_data = "data3"
    merkle_proof = crypto_validator.generate_merkle_proof(data_list, target_data)
    print("Merkle Proof:", merkle_proof)
    is_proof_valid = crypto_validator.verify_merkle_proof(target_data, merkle_proof, merkle_root)
    print("Merkle Proof Valid:", is_proof_valid)

    # Zero-knowledge proof example
    crypto_validator.zero_knowledge_proof_example("my_secret", "proof_of_secret")

    # Run performance benchmarks
    crypto_validator.run_performance_benchmarks()

    # Access the audit log
    audit_log = crypto_validator.get_audit_log()
    print("Audit Log:", audit_log)
