#!/usr/bin/env python3
"""Unit tests for cryptographic operations within the Patent OS."""

import os
import sys
import logging
import unittest
from typing import List, Tuple, Optional

# Add the project root to the Python path for module discovery
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

try:
    from patent_os.crypto import ed25519_key_pair, sign_message, verify_signature, rotate_key, create_merkle_tree, verify_merkle_path, ProofChain, InvalidProofChainError
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Please ensure that the 'patent_os' package is installed and available in your Python environment.")
    print("You may need to run 'pip install -e .' from the project's root directory if you are running tests locally.")
    sys.exit(1)


# Configure logging for the test run.
# NOTE(review): this executes at import time and sets the root logger to DEBUG
# (logging.basicConfig is a no-op if the root logger already has handlers), so
# it can also affect other test modules collected in the same process;
# consider moving it under the ``__main__`` guard.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')


class TestCrypto(unittest.TestCase):
    """Test suite for cryptographic operations.

    Covers Ed25519 signing/verification, key rotation, Merkle tree creation
    and path verification, ``ProofChain`` integrity checks, and the error
    paths taken when malformed keys are supplied.
    """

    def setUp(self):
        """Create a message, two independent key pairs and a Merkle tree before each test."""
        self.message = b"This is a test message."
        self.public_key, self.private_key = ed25519_key_pair()
        self.data_blocks = [b"data1", b"data2", b"data3", b"data4"]
        self.merkle_tree = create_merkle_tree(self.data_blocks)
        # Second, independent key pair used as the target of key rotation.
        self.rotated_public_key, self.rotated_private_key = ed25519_key_pair()

    def _build_chain(self, *blocks):
        """Return a new ProofChain with *blocks* appended in order, each signed with the test private key."""
        chain = ProofChain()
        for block in blocks:
            chain.add_data(block, self.private_key)
        return chain

    def test_ed25519_signing_verification(self):
        """A signature verifies for its own message only, and a garbage signature is rejected."""
        signature = sign_message(self.message, self.private_key)
        self.assertTrue(verify_signature(self.message, signature, self.public_key))
        self.assertFalse(verify_signature(b"tampered message", signature, self.public_key))
        # NOTE(review): a malformed signature is expected to yield False here,
        # while a malformed public key is expected to raise ValueError
        # (see test_verify_signature_invalid_key) - confirm this is intended.
        self.assertFalse(verify_signature(self.message, b"invalid signature", self.public_key))

    def test_key_rotation(self):
        """Rotating to a new key changes the data; rotating back restores it."""
        rotated_data = rotate_key(self.message, self.private_key, self.rotated_public_key)
        self.assertNotEqual(rotated_data, self.message)  # The data must actually change.

        # Apply the inverse rotation (new private key -> original public key).
        unrotated_data = rotate_key(rotated_data, self.rotated_private_key, self.public_key)
        self.assertEqual(unrotated_data, self.message)

    def test_merkle_tree_creation(self):
        """create_merkle_tree returns a non-None list."""
        self.assertIsNotNone(self.merkle_tree)
        self.assertIsInstance(self.merkle_tree, list)

    def test_merkle_path_verification(self):
        """Every data block verifies against the root using its own proof path."""
        # The tree is indexed as tree[0] -> root and tree[1][i] -> proof path
        # for block i (the layout this test relies on from create_merkle_tree).
        # NOTE(review): consider also asserting that a block that is not in
        # the tree fails verification; only the positive case is covered.
        root = self.merkle_tree[0]
        paths = self.merkle_tree[1]
        for index, block in enumerate(self.data_blocks):
            # subTest reports exactly which block failed, and keeps checking the rest.
            with self.subTest(index=index, block=block):
                self.assertTrue(verify_merkle_path(block, paths[index], root))

    def test_proof_chain_creation_and_verification(self):
        """A freshly built chain verifies, and get_data supports positive and negative indices."""
        original_data = b"genesis block"
        data1 = b"data block 1"
        data2 = b"data block 2"
        chain = self._build_chain(original_data, data1, data2)

        self.assertTrue(chain.verify_chain())
        self.assertEqual(chain.get_data(-1), data2)
        self.assertEqual(chain.get_data(0), original_data)

    def test_proof_chain_tampering(self):
        """Tampering with a block's data or signature is detected."""
        original_data = b"genesis block"
        data1 = b"data block 1"

        # Each scenario uses its own chain and subtest, so a failure in one
        # does not hide the result of the other.
        with self.subTest(tampered="data"):
            chain = self._build_chain(original_data, data1)
            # Reach into the chain's internal storage to simulate tampering.
            chain.chain[1]['data'] = b"tampered data"
            self.assertFalse(chain.verify_chain())

        with self.subTest(tampered="signature"):
            chain = self._build_chain(original_data, data1)
            chain.chain[1]['signature'] = b"invalid signature"
            # NOTE(review): a corrupted signature is expected to raise, whereas
            # corrupted data returns False - confirm this asymmetry is intended.
            with self.assertRaises(InvalidProofChainError):
                chain.verify_chain()

    def test_proof_chain_empty(self):
        """An empty chain verifies and reports no data at index -1."""
        chain = ProofChain()
        self.assertTrue(chain.verify_chain())  # An empty chain is considered valid.
        # NOTE(review): an empty chain returns None for -1, while an
        # out-of-range positive index raises IndexError (see the next test).
        self.assertIsNone(chain.get_data(-1))

    def test_proof_chain_out_of_bounds_access(self):
        """get_data raises IndexError for an index past the end of the chain."""
        chain = self._build_chain(b"some data")
        with self.assertRaises(IndexError):
            chain.get_data(1)  # Only index 0 exists.

    def test_rotate_key_invalid_key(self):
        """rotate_key rejects a malformed private key and a malformed public key."""
        with self.subTest(invalid="private key"):
            with self.assertRaises(ValueError):
                rotate_key(self.message, b"invalid private key", self.rotated_public_key)
        with self.subTest(invalid="public key"):
            with self.assertRaises(ValueError):
                rotate_key(self.message, self.private_key, b"invalid public key")

    def test_sign_message_invalid_key(self):
        """sign_message rejects a malformed private key with ValueError."""
        with self.assertRaises(ValueError):
            sign_message(self.message, b"invalid private key")

    def test_verify_signature_invalid_key(self):
        """verify_signature rejects a malformed public key with ValueError."""
        signature = sign_message(self.message, self.private_key)
        with self.assertRaises(ValueError):
            verify_signature(self.message, signature, b"invalid public key")

    def test_merkle_tree_empty_data(self):
        """A tree built from no data blocks is empty."""
        empty_tree = create_merkle_tree([])
        self.assertEqual(len(empty_tree), 0)


if __name__ == "__main__":
    # Running this file directly executes the suite with unittest's CLI runner.
    unittest.main()