import unittest
from rwl.hardening.context_hardened import ContextHardenedObject, ContextHardenedKey, generate_key

class TestContextHardenedObject(unittest.TestCase):
    """Behavioral tests for ContextHardenedObject: hashing, verification, reconstruction."""

    def test_creation_and_verification(self):
        """A freshly hardened value verifies against itself and rejects a different value."""
        secret = "sensitive_data"
        obj = ContextHardenedObject(secret)
        accepted = obj.verify(secret)
        self.assertTrue(accepted)
        rejected = obj.verify("wrong_data")
        self.assertFalse(rejected)

    def test_hashing_with_salt(self):
        """With a caller-supplied salt the stored hash matches a known answer."""
        secret = "sensitive_data"
        obj = ContextHardenedObject(secret, salt="fixed_salt")
        # Known-answer value for ("sensitive_data", "fixed_salt"); it has to be
        # regenerated if the hashing scheme behind ContextHardenedObject changes.
        known_digest = "7e322f19055c9995708d87391c19c44156d05e47c172990651540b8644d50006"
        digest = obj.get_hashed_representation()[1]
        self.assertEqual(digest, known_digest)
        self.assertTrue(obj.verify(secret))

    def test_from_hashed_representation(self):
        """An object rebuilt from (salt, hash) drops the plaintext yet still verifies."""
        secret = "sensitive_data"
        salt, digest = ContextHardenedObject(secret).get_hashed_representation()

        rebuilt = ContextHardenedObject.from_hashed_representation(salt, digest)
        # Only the salt and hash are carried over; the plaintext must not be.
        self.assertIsNone(rebuilt.data)
        self.assertEqual(rebuilt.salt, salt)
        self.assertEqual(rebuilt.hashed_data, digest)
        self.assertTrue(rebuilt.verify(secret))

    def test_different_data_types(self):
        """Non-string inputs (int, float, list) verify against an equal value only."""
        # Each row: (value to harden, equal probe, different probe). The equal
        # probe is a separate literal so that lists are compared as distinct objects.
        cases = (
            (12345, 12345, "123456"),          # integer
            (3.14159, 3.14159, "3.1415"),      # float
            ([1, 2, 3], [1, 2, 3], [1, 2]),    # list vs. a shorter list
        )
        for value, same, different in cases:
            hardened = ContextHardenedObject(value)
            self.assertTrue(hardened.verify(same))
            self.assertFalse(hardened.verify(different))

    def test_empty_string(self):
        """The empty string is accepted as input and verifies against itself."""
        self.assertTrue(ContextHardenedObject("").verify(""))

    def test_unicode_characters(self):
        """Non-ASCII text (Chinese characters here) round-trips through verify."""
        text = "你好世界"
        self.assertTrue(ContextHardenedObject(text).verify(text))

    def test_long_string(self):
        """A 2048-character input verifies against itself."""
        text = "A" * 2048
        self.assertTrue(ContextHardenedObject(text).verify(text))



class TestContextHardenedKey(unittest.TestCase):
    """Behavioral tests for ContextHardenedKey and the generate_key helper."""

    def test_key_generation_and_verification(self):
        """An explicitly supplied key verifies against itself and rejects another value."""
        raw = generate_key()
        wrapped = ContextHardenedKey(raw)
        self.assertTrue(wrapped.verify(raw))
        self.assertFalse(wrapped.verify("wrong_key"))

    def test_key_generation_without_input(self):
        """With no argument the object still exposes a non-None key that verifies."""
        wrapped = ContextHardenedKey()
        self.assertIsNotNone(wrapped.key)
        self.assertTrue(wrapped.verify(wrapped.key))

    def test_from_hardened_representation(self):
        """Rebuilding from (salt, hash) discards the key but reproduces the representation."""
        wrapped = ContextHardenedKey(generate_key())
        salt, digest = wrapped.get_hardened_representation()

        rebuilt = ContextHardenedKey.from_hardened_representation(salt, digest)
        # The plaintext key must not survive reconstruction.
        self.assertIsNone(rebuilt.key)
        # With no plaintext available, this test can only check that the rebuilt
        # object reports the same (salt, hash) pair it was constructed from.
        representation = rebuilt.get_hardened_representation()
        self.assertEqual(representation[0], salt)
        self.assertEqual(representation[1], digest)

    def test_key_reuse(self):
        """Two consecutive generate_key() calls must not return the same key."""
        first, second = generate_key(), generate_key()
        self.assertNotEqual(first, second)

    def test_multiple_hardened_objects(self):
        """Independently constructed keys produce different hardened representations."""
        first, second = ContextHardenedKey(), ContextHardenedKey()
        self.assertNotEqual(
            first.get_hardened_representation(),
            second.get_hardened_representation(),
        )


if __name__ == "__main__":
    # Running this file directly executes the TestCase classes defined above.
    unittest.main()