# enhanced_cryptographic_validation.py
import hashlib
import hmac
import json
import os
import threading
from typing import Any, Dict, List, Tuple, Union

import cryptography
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.backends import default_backend

import base64

# Optional: For Argon2 key derivation (install: pip install argon2-cffi)
# from argon2 import PasswordHasher

# Optional: HSM integration (replace with your HSM library)
# from hsm_library import HSMClient

# Optional: Multi-party signature libraries (install: threshold_crypto)
# from threshold_crypto import schemes

# Install: pip install merklelib
from merklelib import MerkleTree

# Install: pip install hypothesis
import hypothesis
from hypothesis import given, strategies as st

# Install: pip install atheris
# import atheris

import timeit
import asyncio

# --- Constants ---
AES_KEY_SIZE = 256 // 8  # AES-256 key length, in bytes
AES_NONCE_SIZE = 96 // 8  # 96-bit nonce: the recommended (and most efficient) size for GCM
PBKDF2_SALT_SIZE = 128 // 8  # salt length in bytes
# NOTE(review): current OWASP guidance for PBKDF2-HMAC-SHA256 is far higher (600,000).
# Raising this changes every key derived with the default, so version such a change.
PBKDF2_ITERATIONS = 100_000
KEY_VERSION = 1  # emitted as "key_version" in every serialized AIDecision


class CryptoUtils:
    """Stateless helpers for salt generation, PBKDF2, AES-256-GCM and Ed25519."""

    @staticmethod
    def generate_salt(length: int = PBKDF2_SALT_SIZE) -> bytes:
        """Return ``length`` random bytes from ``os.urandom`` for use as a KDF salt."""
        return os.urandom(length)

    @staticmethod
    def derive_key_pbkdf2(password: str, salt: bytes, iterations: int = PBKDF2_ITERATIONS, key_size: int = AES_KEY_SIZE) -> bytes:
        """Derive a ``key_size``-byte key from ``password`` using PBKDF2-HMAC-SHA256.

        The password is UTF-8 encoded first. The same password, salt, iteration
        count and key size always produce the same key, so the salt must be stored
        alongside whatever the key protects.
        """
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=key_size,
            salt=salt,
            iterations=iterations,
            backend=default_backend()
        )
        return kdf.derive(password.encode('utf-8'))

    # Argon2 key derivation (alternative to PBKDF2, requires argon2-cffi).
    # NOTE: argon2.PasswordHasher().hash() returns an encoded *verification string*
    # (parameters + salt + hash), not raw key material, so it must not be used as
    # an AES key. For key derivation use argon2.low_level.hash_secret_raw(
    #     secret, salt, time_cost, memory_cost, parallelism,
    #     hash_len=AES_KEY_SIZE, type=argon2.low_level.Type.ID).

    @staticmethod
    def encrypt_aes_gcm(data: bytes, key: bytes) -> Tuple[bytes, bytes, bytes]:
        """Encrypt ``data`` with AES-GCM under a fresh random nonce.

        Args:
            data: Plaintext bytes.
            key: AES key (32 bytes for AES-256; the cipher also accepts 16/24).

        Returns:
            ``(ciphertext, nonce, tag)`` -- all three are needed for decryption.

        Raises:
            ValueError: If ``key`` is not a valid AES key length.
        """
        # A new nonce per message is mandatory: reusing a (key, nonce) pair breaks GCM.
        nonce = os.urandom(AES_NONCE_SIZE)
        cipher = Cipher(algorithms.AES(key), modes.GCM(nonce), backend=default_backend())
        encryptor = cipher.encryptor()
        ciphertext = encryptor.update(data) + encryptor.finalize()
        # The tag is only available after finalize().
        return ciphertext, nonce, encryptor.tag

    @staticmethod
    def decrypt_aes_gcm(ciphertext: bytes, key: bytes, nonce: bytes, tag: bytes) -> bytes:
        """Decrypt and authenticate AES-GCM ``ciphertext``.

        Raises:
            cryptography.exceptions.InvalidTag: If the key, nonce, tag or ciphertext
                do not match (wrong key or tampered data).
        """
        cipher = Cipher(algorithms.AES(key), modes.GCM(nonce, tag), backend=default_backend())
        decryptor = cipher.decryptor()
        return decryptor.update(ciphertext) + decryptor.finalize()

    @staticmethod
    def generate_ed25519_key_pair() -> Tuple[ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey]:
        """Generate a new Ed25519 key pair and return ``(private_key, public_key)``."""
        private_key = ed25519.Ed25519PrivateKey.generate()
        public_key = private_key.public_key()
        return private_key, public_key

    @staticmethod
    def sign_ed25519(data: bytes, private_key: ed25519.Ed25519PrivateKey) -> bytes:
        """Return the Ed25519 signature of ``data`` made with ``private_key``."""
        return private_key.sign(data)

    @staticmethod
    def verify_ed25519(data: bytes, signature: bytes, public_key: ed25519.Ed25519PublicKey) -> bool:
        """Return True if ``signature`` is a valid Ed25519 signature of ``data``."""
        # Import explicitly: a bare ``import cryptography`` does not itself guarantee
        # that the ``cryptography.exceptions`` submodule is loaded as an attribute.
        from cryptography.exceptions import InvalidSignature

        try:
            public_key.verify(signature, data)
        except InvalidSignature:
            return False
        return True

class AIDecision:
    """
    Represents a single AI decision with its input, output, HMAC, signature, and optional encryption.

    Integrity data:
      * ``hmac`` -- HMAC-SHA256 hex digest over the JSON encoding of
        ``{"input": ..., "output": ...}``.
      * ``signature`` -- Ed25519 signature over the JSON encoding of
        ``{"input": ..., "output": ..., "hmac": ...}``.

    When ``encrypt_sensitive_data`` is true, the literal string ``"ENCRYPTED"`` stands
    in for the output in both messages.

    NOTE(review): the MAC and signature therefore do not cover the ciphertext, nonce
    or tag. Ciphertexts made under the same encryption key could be swapped between
    records undetected. Binding them (e.g. as GCM associated data or in the MAC input)
    would change the serialized format.
    """

    def __init__(self, input_data: Any, output_data: Any, hmac_key: bytes, signing_key: ed25519.Ed25519PrivateKey, encrypt_sensitive_data: bool = False, encryption_key: Union[bytes, None] = None):
        """
        Initializes an AIDecision object, computing its HMAC and signature.

        Args:
            input_data: JSON-serializable decision input.
            output_data: JSON-serializable decision output.
            hmac_key: Key for the HMAC-SHA256 tag.
            signing_key: Ed25519 private key used to sign the decision.
            encrypt_sensitive_data: If true, encrypt the output with AES-GCM and drop
                the plaintext from this object.
            encryption_key: AES key; required when ``encrypt_sensitive_data`` is true.

        Raises:
            ValueError: If encryption is requested without an ``encryption_key``.
        """
        # Fail before doing any signing work.
        if encrypt_sensitive_data and encryption_key is None:
            raise ValueError("Encryption key is required for encrypting sensitive data.")

        self.input_data = input_data
        self.output_data = output_data
        self.hmac_key = hmac_key
        self.signing_key = signing_key
        self.encrypt_sensitive_data = encrypt_sensitive_data
        self.encryption_key = encryption_key  # AES key if encrypt_sensitive_data is True
        # Ciphertext fields, populated by _encrypt_output() or from_dict().
        self.encrypted_output: Union[bytes, None] = None
        self.nonce: Union[bytes, None] = None
        self.tag: Union[bytes, None] = None
        self.hmac = self._generate_hmac()
        self.signature = self._generate_signature()

        if self.encrypt_sensitive_data:
            self._encrypt_output()

    def _encrypt_output(self):
        """Encrypt the output with AES-GCM and clear the plaintext.

        Raises:
            ValueError: If no encryption key is set.
        """
        if self.encryption_key is None:
            raise ValueError("Encryption key is required for encrypting sensitive data.")

        output_bytes = json.dumps(self.output_data).encode('utf-8')
        self.encrypted_output, self.nonce, self.tag = CryptoUtils.encrypt_aes_gcm(output_bytes, self.encryption_key)
        self.output_data = None  # Clear the plaintext output

    def _decrypt_output(self):
        """Decrypt the stored ciphertext back into ``output_data``.

        The ciphertext, nonce and tag are kept so that ``to_dict()`` can re-emit the
        encrypted record unchanged after loading.

        Raises:
            ValueError: If the key, ciphertext, nonce or tag is missing.
            cryptography.exceptions.InvalidTag: If authentication fails (wrong key or
                modified ciphertext).
        """
        if self.encryption_key is None or self.encrypted_output is None or self.nonce is None or self.tag is None:
            raise ValueError("Encryption key, encrypted output, nonce, and tag are required for decryption.")

        decrypted_bytes = CryptoUtils.decrypt_aes_gcm(self.encrypted_output, self.encryption_key, self.nonce, self.tag)
        self.output_data = json.loads(decrypted_bytes.decode('utf-8'))  # Restore plaintext output

    def _protected_output(self) -> Any:
        """Return the value that represents the output in the HMAC/signature messages."""
        return "ENCRYPTED" if self.encrypt_sensitive_data else self.output_data

    def _hmac_message(self) -> bytes:
        """JSON-encode the fields covered by the HMAC (key order is significant)."""
        return json.dumps({"input": self.input_data, "output": self._protected_output()}).encode("utf-8")

    def _signature_message(self) -> bytes:
        """JSON-encode the fields covered by the signature (key order is significant)."""
        return json.dumps({"input": self.input_data, "output": self._protected_output(), "hmac": self.hmac}).encode("utf-8")

    def _generate_hmac(self) -> str:
        """Return the HMAC-SHA256 hex digest of the decision."""
        return hmac.new(self.hmac_key, self._hmac_message(), hashlib.sha256).hexdigest()

    def _generate_signature(self) -> bytes:
        """Return the Ed25519 signature of the decision (requires ``signing_key``)."""
        return CryptoUtils.sign_ed25519(self._signature_message(), self.signing_key)

    def verify_hmac(self) -> bool:
        """Return True if the stored HMAC matches a freshly computed one.

        Uses a constant-time comparison. A stored value that is not a string (e.g. from a
        malformed serialized record) simply fails verification.
        """
        expected = self._generate_hmac()
        if not isinstance(self.hmac, str):
            return False
        # compare_digest rejects non-ASCII str, so compare the UTF-8 bytes instead.
        return hmac.compare_digest(self.hmac.encode("utf-8"), expected.encode("utf-8"))

    def verify_signature(self, public_key: ed25519.Ed25519PublicKey) -> bool:
        """Return True if the stored signature is valid for ``public_key``."""
        return CryptoUtils.verify_ed25519(self._signature_message(), self.signature, public_key)

    @staticmethod
    def _b64encode_optional(value: Union[bytes, None]) -> Union[str, None]:
        """Base64-encode ``value`` to text; empty or missing values map to None."""
        return base64.b64encode(value).decode('utf-8') if value else None

    @staticmethod
    def _b64decode_optional(value: Union[str, None]) -> Union[bytes, None]:
        """Strictly decode an optional base64 text field; empty or missing values map to None."""
        if not value:
            return None
        return base64.b64decode(value, validate=True)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict; binary fields are base64 text.

        Encrypted decisions emit ``encrypted_output``/``nonce``/``tag`` and never the
        plaintext output. Other decisions emit ``output``.
        """
        data = {
            "input": self.input_data,
            "hmac": self.hmac,
            "signature": base64.b64encode(self.signature).decode('utf-8'),
            "key_version": KEY_VERSION
        }
        if self.encrypt_sensitive_data:
            data["encrypted_output"] = self._b64encode_optional(self.encrypted_output)
            data["nonce"] = self._b64encode_optional(self.nonce)
            data["tag"] = self._b64encode_optional(self.tag)
        else:
            data["output"] = self.output_data

        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any], hmac_key: bytes, public_key: ed25519.Ed25519PublicKey, decrypt_key: Union[bytes, None] = None) -> "AIDecision":
        """
        Creates an AIDecision object from a dictionary produced by ``to_dict()``.

        The HMAC and signature are verified before anything is decrypted. The result
        has no signing key: it can be verified but not re-signed. ``key_version``
        is currently ignored.

        Args:
            data: Serialized decision.
            hmac_key: Key used to recompute the HMAC.
            public_key: Ed25519 public key used to check the signature.
            decrypt_key: AES key; needed only if the record carries ciphertext.

        Returns:
            The verified (and, if applicable, decrypted) decision.

        Raises:
            ValueError: If HMAC or signature verification fails, a base64 field is
                malformed, or ciphertext is present but ``decrypt_key`` is missing.
            KeyError: If ``input``, ``hmac`` or ``signature`` is absent.
            cryptography.exceptions.InvalidTag: If decryption authentication fails.
        """
        # The presence of the key (not its value) records that the output was
        # encrypted, which selects the "ENCRYPTED" placeholder used in the MAC.
        is_encrypted = "encrypted_output" in data

        # Build the instance directly: __init__ would try to sign with a private
        # key that a verifier does not have.
        decision = cls.__new__(cls)
        decision.input_data = data["input"]
        decision.output_data = None if is_encrypted else data.get("output", None)
        decision.hmac_key = hmac_key
        decision.signing_key = None
        decision.encrypt_sensitive_data = is_encrypted
        decision.encryption_key = decrypt_key
        decision.encrypted_output = None
        decision.nonce = None
        decision.tag = None
        decision.hmac = data["hmac"]
        decision.signature = base64.b64decode(data["signature"], validate=True)

        if not decision.verify_hmac():
            raise ValueError("HMAC verification failed during deserialization.")

        if not decision.verify_signature(public_key):
            raise ValueError("Signature verification failed during deserialization.")

        if is_encrypted and data["encrypted_output"]:
            decision.encrypted_output = cls._b64decode_optional(data["encrypted_output"])
            decision.nonce = cls._b64decode_optional(data.get("nonce"))
            decision.tag = cls._b64decode_optional(data.get("tag"))
            decision._decrypt_output()

        return decision


class AIDecisionChain:
    """Represents a chain of AI decisions with cryptographic integrity.

    Each decision after the first has its input wrapped as
    ``{"previous_hmac": <HMAC of the preceding decision>, "input": <caller input>}``
    before its HMAC and signature are computed. Reordering or removing an interior
    decision therefore breaks a link checked by ``verify_chain``.

    NOTE(review): the first decision carries no link. Dropping decisions from the head
    (or the tail) of a serialized chain is not detected by ``verify_chain``.
    """

    def __init__(self, hmac_key: bytes, signing_key: ed25519.Ed25519PrivateKey, encrypt_sensitive_data: bool = False, encryption_key: Union[bytes, None] = None):
        """Create an empty chain.

        ``signing_key`` may be None for a verification-only chain (as built by
        ``from_list``); such a chain cannot accept new decisions.
        """
        self.chain: List[AIDecision] = []
        self.hmac_key = hmac_key
        self.signing_key = signing_key
        self.encrypt_sensitive_data = encrypt_sensitive_data
        self.encryption_key = encryption_key
        self.previous_hmac: Union[str, None] = None

    def add_decision(self, input_data: Any, output_data: Any) -> None:
        """Sign a new decision, link it to the previous one, and append it.

        Raises:
            ValueError: If the chain has no signing key, or encryption is enabled
                without an encryption key.
        """
        if self.signing_key is None:
            raise ValueError("A signing key is required to add decisions to the chain.")

        decision = AIDecision(input_data, output_data, self.hmac_key, self.signing_key, self.encrypt_sensitive_data, self.encryption_key)

        # Include the previous HMAC in the current decision's input
        if self.previous_hmac:
            decision.input_data = {"previous_hmac": self.previous_hmac, "input": input_data}
            # The signed input changed, so the HMAC and signature must be recomputed.
            decision.hmac = decision._generate_hmac()
            decision.signature = decision._generate_signature()

        self.chain.append(decision)
        self.previous_hmac = decision.hmac

    @staticmethod
    def _stored_previous_hmac(decision: AIDecision) -> Union[str, None]:
        """Return the predecessor HMAC embedded in ``decision``'s input, or None."""
        # Inputs of loaded/tampered records are not guaranteed to be dicts.
        if isinstance(decision.input_data, dict):
            return decision.input_data.get("previous_hmac")
        return None

    def _check_link(self, index: int, decision: AIDecision, previous_hmac: Union[str, None]) -> bool:
        """Return False (and report) if a non-first decision is not linked to ``previous_hmac``."""
        if index > 0 and self._stored_previous_hmac(decision) != previous_hmac:
            print(f"Previous HMAC mismatch at index {index}")
            return False
        return True

    def verify_chain(self, public_key: ed25519.Ed25519PublicKey) -> bool:
        """Verify every decision's HMAC and signature and the previous-HMAC links.

        Returns True for an empty chain. On the first failure, prints which check
        failed and at which index, then returns False.
        """
        previous_hmac = None
        for i, decision in enumerate(self.chain):
            if not decision.verify_hmac():
                print(f"HMAC verification failed for decision at index {i}")
                return False

            if not decision.verify_signature(public_key):
                print(f"Signature verification failed for decision at index {i}")
                return False

            if not self._check_link(i, decision, previous_hmac):
                return False

            previous_hmac = decision.hmac

        return True

    def to_list(self) -> List[Dict[str, Any]]:
        """Returns a list of dictionaries representing the decision chain."""
        return [decision.to_dict() for decision in self.chain]

    @classmethod
    def from_list(cls, data: List[Dict[str, Any]], hmac_key: bytes, public_key: ed25519.Ed25519PublicKey, decrypt_key: Union[bytes, None] = None) -> "AIDecisionChain":
        """Load a verification-only chain from ``to_list()`` output.

        Each decision is verified individually by ``AIDecision.from_dict``. Links
        between decisions are checked later by ``verify_chain``/``detect_tampering``.

        Raises:
            ValueError: Propagated from ``AIDecision.from_dict`` when a decision's
                HMAC or signature does not verify.
        """
        chain = cls(hmac_key, None)  # No signing key: the loaded chain is read-only.
        chain.chain = [AIDecision.from_dict(d, hmac_key, public_key, decrypt_key) for d in data]

        # Reconstruct the previous_hmac for chain verification
        if chain.chain:
            chain.previous_hmac = chain.chain[-1].hmac

        return chain

    def detect_tampering(self, public_key: ed25519.Ed25519PublicKey) -> bool:
        """Return True if ``verify_chain`` fails."""
        return not self.verify_chain(public_key)

    @staticmethod
    def _merkle_hash(value: Union[bytes, str]) -> str:
        """Hash function for merklelib: SHA-256 hex digest (str input is UTF-8 encoded)."""
        raw = value.encode('utf-8') if isinstance(value, str) else value
        return hashlib.sha256(raw).hexdigest()

    def generate_merkle_tree(self) -> MerkleTree:
        """Build a Merkle tree whose leaves are the decisions' HMACs (UTF-8 bytes)."""
        leaves = [decision.hmac.encode('utf-8') for decision in self.chain]
        # merklelib expects a hashfunc returning a hex string, not a hashlib object.
        return MerkleTree(leaves, self._merkle_hash)

    def verify_decision_in_merkle_tree(self, decision_index: int, tree: MerkleTree) -> bool:
        """Return True if the HMAC of decision ``decision_index`` has a valid audit proof in ``tree``.

        Out-of-range indices return False.
        """
        if not (0 <= decision_index < len(self.chain)):
            return False

        leaf = self.chain[decision_index].hmac.encode('utf-8')
        try:
            proof = tree.get_proof(leaf)
            return bool(tree.verify_leaf_inclusion(leaf, proof))
        except Exception:  # merklelib does not document a specific exception hierarchy
            return False

    async def async_verify_chain(self, public_key: ed25519.Ed25519PublicKey) -> bool:
        """Async variant of ``verify_chain``.

        Decisions are still checked in order. For each one, the HMAC and signature
        checks run concurrently in worker threads.
        """
        previous_hmac = None
        for i, decision in enumerate(self.chain):
            hmac_valid, signature_valid = await asyncio.gather(
                asyncio.to_thread(decision.verify_hmac),
                asyncio.to_thread(decision.verify_signature, public_key)
            )

            if not hmac_valid:
                print(f"HMAC verification failed for decision at index {i}")
                return False

            if not signature_valid:
                print(f"Signature verification failed for decision at index {i}")
                return False

            if not self._check_link(i, decision, previous_hmac):
                return False

            previous_hmac = decision.hmac

        return True


# --- Unit Tests ---
import unittest

class TestAIDecisionChain(unittest.TestCase):
    """Unit tests for AIDecisionChain: building, verification, serialization, encryption."""

    def setUp(self):
        self.hmac_key = os.urandom(32)
        self.private_key, self.public_key = CryptoUtils.generate_ed25519_key_pair()
        self.chain = AIDecisionChain(self.hmac_key, self.private_key)
        self.encryption_key = os.urandom(AES_KEY_SIZE)
        self.encrypted_chain = AIDecisionChain(self.hmac_key, self.private_key, encrypt_sensitive_data=True, encryption_key=self.encryption_key)

    def test_add_and_verify_decision(self):
        self.chain.add_decision("input1", "output1")
        self.assertTrue(self.chain.verify_chain(self.public_key))

    def test_add_multiple_decisions(self):
        self.chain.add_decision("input1", "output1")
        self.chain.add_decision("input2", "output2")
        self.chain.add_decision("input3", "output3")
        self.assertTrue(self.chain.verify_chain(self.public_key))

    def test_tampering_detection(self):
        self.chain.add_decision("input1", "output1")
        self.chain.add_decision("input2", "output2")
        chain_data = self.chain.to_list()
        # Tamper with the second decision: its own HMAC no longer matches, so
        # loading rejects it before any chain-level check runs.
        chain_data[1]["output"] = "tampered_output"
        with self.assertRaises(ValueError):
            AIDecisionChain.from_list(chain_data, self.hmac_key, self.public_key)

    def test_reordering_detection(self):
        for i in range(3):
            self.chain.add_decision(f"input{i}", f"output{i}")
        chain_data = self.chain.to_list()
        # Every decision still verifies on its own; only the previous-HMAC links break.
        chain_data[1], chain_data[2] = chain_data[2], chain_data[1]
        reordered_chain = AIDecisionChain.from_list(chain_data, self.hmac_key, self.public_key)
        self.assertTrue(reordered_chain.detect_tampering(self.public_key))

    def test_serialization_and_deserialization(self):
        self.chain.add_decision("input1", "output1")
        self.chain.add_decision("input2", "output2")
        chain_data = self.chain.to_list()
        new_chain = AIDecisionChain.from_list(chain_data, self.hmac_key, self.public_key)
        self.assertTrue(new_chain.verify_chain(self.public_key))
        self.assertEqual(len(self.chain.chain), len(new_chain.chain))
        self.assertEqual(self.chain.chain[0].input_data, new_chain.chain[0].input_data)
        self.assertEqual(self.chain.chain[1].output_data, new_chain.chain[1].output_data)

    def test_empty_chain(self):
        self.assertTrue(self.chain.verify_chain(self.public_key))  # An empty chain should be valid
        self.assertFalse(self.chain.detect_tampering(self.public_key))  # An empty chain cannot be tampered with

    def test_complex_input_output(self):
        input_data = {"feature1": 1.0, "feature2": "value"}
        output_data = [1, 2, 3]
        self.chain.add_decision(input_data, output_data)
        self.assertTrue(self.chain.verify_chain(self.public_key))

    def test_encryption_decryption(self):
        self.encrypted_chain.add_decision("input1", {"sensitive_data": "secret"})
        self.assertTrue(self.encrypted_chain.verify_chain(self.public_key))

        chain_data = self.encrypted_chain.to_list()
        decrypted_chain = AIDecisionChain.from_list(chain_data, self.hmac_key, self.public_key, self.encryption_key)
        self.assertTrue(decrypted_chain.verify_chain(self.public_key))
        self.assertEqual(decrypted_chain.chain[0].output_data, {"sensitive_data": "secret"})
        # Re-serializing a loaded encrypted chain must reproduce the original records.
        self.assertEqual(decrypted_chain.to_list(), chain_data)

    def test_merkle_tree_verification(self):
        self.chain.add_decision("input1", "output1")
        self.chain.add_decision("input2", "output2")
        tree = self.chain.generate_merkle_tree()
        self.assertTrue(self.chain.verify_decision_in_merkle_tree(0, tree))
        self.assertTrue(self.chain.verify_decision_in_merkle_tree(1, tree))
        self.assertFalse(self.chain.verify_decision_in_merkle_tree(2, tree))  # Index out of range

    def test_async_verification(self):
        self.chain.add_decision("input1", "output1")
        self.chain.add_decision("input2", "output2")
        # asyncio.run creates and closes its own loop; get_event_loop() without a
        # running loop is deprecated.
        result = asyncio.run(self.chain.async_verify_chain(self.public_key))
        self.assertTrue(result)


# --- Property-Based Testing (Hypothesis) ---
class TestAIDecisionChainHypothesis(unittest.TestCase):
    """Property-based tests for AIDecisionChain.

    ``hypothesis.settings`` is applied per test method because Hypothesis rejects it
    as a decorator on a ``unittest.TestCase`` subclass. ``setUp`` runs once per test
    method while Hypothesis runs many examples inside it, so each example builds its
    own chain to stay independent.
    """

    def setUp(self):
        self.hmac_key = os.urandom(32)
        self.private_key, self.public_key = CryptoUtils.generate_ed25519_key_pair()
        self.chain = AIDecisionChain(self.hmac_key, self.private_key)

    @hypothesis.settings(deadline=None)  # No per-example time limit; signing is slowish
    @given(input_data=st.dictionaries(st.text(), st.text()), output_data=st.text())
    def test_add_and_verify_decision_property(self, input_data, output_data):
        chain = AIDecisionChain(self.hmac_key, self.private_key)
        chain.add_decision(input_data, output_data)
        chain.add_decision(output_data, input_data)  # also exercises the previous-HMAC link
        self.assertTrue(chain.verify_chain(self.public_key))

    @hypothesis.settings(deadline=None)
    @given(num_decisions=st.integers(min_value=1, max_value=5), data=st.data())
    def test_tampering_detection_property(self, num_decisions, data):
        chain = AIDecisionChain(self.hmac_key, self.private_key)
        for i in range(num_decisions):
            chain.add_decision(f"input{i}", f"output{i}")

        chain_data = chain.to_list()
        # Draw interactively; calling .example() inside a @given test is an error.
        index_to_tamper = data.draw(st.integers(min_value=0, max_value=len(chain_data) - 1))
        chain_data[index_to_tamper]["output"] = "tampered_output"
        # Per-decision HMAC verification rejects the tampered record while loading.
        with self.assertRaises(ValueError):
            AIDecisionChain.from_list(chain_data, self.hmac_key, self.public_key)

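# --- Fuzzing (Atheris) ---
# Example harness (requires `import atheris` and `import sys`). It feeds arbitrary
# bytes to the chain deserializer. Narrow the `except` to the exceptions the
# deserializer is expected to raise, otherwise the fuzzer cannot report real crashes.
# _FUZZ_PUBLIC_KEY = CryptoUtils.generate_ed25519_key_pair()[1]
#
# def test_one_input(data):
#     """Example fuzzing function (replace with more meaningful tests)."""
#     try:
#         chain_data = json.loads(data.decode('utf-8'))
#         tampered_chain = AIDecisionChain.from_list(chain_data, b"fuzz_key", _FUZZ_PUBLIC_KEY)
#         tampered_chain.detect_tampering(_FUZZ_PUBLIC_KEY)
#     except Exception as e:
#         # Expect some exceptions during fuzzing
#         pass

# def main():
#     atheris.Setup(sys.argv + ["-atheris_runs=10000"], test_one_input)
#     atheris.Fuzz()

# if __name__ == "__main__":
#     main()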


# --- Performance Benchmarking ---
def benchmark_verification(chain: "AIDecisionChain", public_key: "ed25519.Ed25519PublicKey", num_runs: int = 100) -> float:
    """Time ``chain.verify_chain(public_key)`` over ``num_runs`` calls and print the results.

    Args:
        chain: Chain whose ``verify_chain`` is benchmarked.
        public_key: Key passed through to ``verify_chain``.
        num_runs: Number of timed calls; must be at least 1.

    Returns:
        Total elapsed wall-clock seconds across all runs.

    Raises:
        ValueError: If ``num_runs`` is less than 1 (the per-run average is undefined).
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")
    elapsed = timeit.timeit(lambda: chain.verify_chain(public_key), number=num_runs)
    print(f"Verification time ({num_runs} runs): {elapsed:.4f} seconds")
    print(f"Average verification time per chain: {elapsed/num_runs:.6f} seconds")
    return elapsed


# --- Usage Example ---
if __name__ == "__main__":
    # Generate secure random keys
    hmac_key = os.urandom(32)
    private_key, public_key = CryptoUtils.generate_ed25519_key_pair()
    encryption_key = os.urandom(AES_KEY_SIZE)

    # Create an AI decision chain
    decision_chain = AIDecisionChain(hmac_key, private_key)
    encrypted_decision_chain = AIDecisionChain(hmac_key, private_key, encrypt_sensitive_data=True, encryption_key=encryption_key)

    # Add some decisions to the chain
    decision_chain.add_decision("What is 2 + 2?", "2 + 2 = 4")
    decision_chain.add_decision("Translate 'hello' to French", "Bonjour")
    decision_chain.add_decision({"user_id": 123, "query": "Recommend a movie"}, "Recommend movie 'Inception'")
    encrypted_decision_chain.add_decision("Secret question", {"sensitive_info": "Top Secret Answer!"})

    # Verify the integrity of the chain
    if decision_chain.verify_chain(public_key):
        print("AI decision chain is valid.")
    else:
        print("AI decision chain has been tampered with!")

    if encrypted_decision_chain.verify_chain(public_key):
        print("Encrypted AI decision chain is valid.")
    else:
        print("Encrypted AI decision chain has been tampered with!")

    # Serialize the chain to a list of dictionaries
    chain_data = decision_chain.to_list()
    print("Serialized chain data:", chain_data)

    encrypted_chain_data = encrypted_decision_chain.to_list()
    print("Serialized encrypted chain data:", encrypted_chain_data)

    # Deserialize the chain from the list of dictionaries
    loaded_chain = AIDecisionChain.from_list(chain_data, hmac_key, public_key)
    loaded_encrypted_chain = AIDecisionChain.from_list(encrypted_chain_data, hmac_key, public_key, encryption_key)

    # Verify the integrity of the loaded chain
    if loaded_chain.verify_chain(public_key):
        print("Loaded AI decision chain is valid.")
    else:
        print("Loaded AI decision chain has been tampered with!")

    if loaded_encrypted_chain.verify_chain(public_key):
        print("Loaded encrypted AI decision chain is valid.")
    else:
        print("Loaded encrypted AI decision chain has been tampered with!")
    print(f"Decrypted data: {loaded_encrypted_chain.chain[0].output_data}")

    # Tamper with a decision's content: per-decision HMAC/signature checks reject
    # it while loading, so detection surfaces as a ValueError.
    chain_data[1]["output"] = "Tampered Output"
    try:
        AIDecisionChain.from_list(chain_data, hmac_key, public_key)
    except ValueError as err:
        print(f"Tampering detected in the AI decision chain! ({err})")
    else:
        print("No tampering detected.")

    # Reorder decisions: each still verifies on its own, so detection relies on
    # the chain-level previous-HMAC links.
    reordered_data = decision_chain.to_list()
    reordered_data[1], reordered_data[2] = reordered_data[2], reordered_data[1]
    reordered_chain = AIDecisionChain.from_list(reordered_data, hmac_key, public_key)
    if reordered_chain.detect_tampering(public_key):
        print("Tampering detected in the AI decision chain!")
    else:
        print("No tampering detected.")

    # Generate and verify Merkle Tree
    merkle_tree = decision_chain.generate_merkle_tree()
    # NOTE(review): merklelib exposes the root hash as ``merkle_root``; confirm
    # against the installed merklelib version.
    print("Merkle Root:", merkle_tree.merkle_root)
    if decision_chain.verify_decision_in_merkle_tree(0, merkle_tree):
        print("Decision 0 is valid in the Merkle Tree")
    else:
        print("Decision 0 is invalid in the Merkle Tree")

    # Benchmark verification
    benchmark_verification(decision_chain, public_key)

    # Run the unit tests
    unittest.main(argv=['first-arg-is-ignored'], exit=False)  # Prevent SystemExit
