import unittest
import threading
import time
from /mnt/e/genesis-system/core/rwl/hardening/concurrency_safe import ThreadSafeCounter, ThreadSafeList, ThreadSafeDict

class TestThreadSafeCounter(unittest.TestCase):
    """Concurrency tests for ``ThreadSafeCounter``."""

    @staticmethod
    def _run_concurrently(task, num_threads):
        """Run ``task(thread_index)`` on ``num_threads`` threads and wait for them.

        A barrier releases every worker at the same moment so the calls
        really overlap. An exception raised inside a worker thread does not
        propagate to the thread that started it, so such exceptions are
        caught here and returned. The caller is expected to assert that the
        returned list is empty.

        Args:
            task: Callable taking the worker's index (``0..num_threads-1``).
            num_threads: Number of worker threads to start.

        Returns:
            A list of exceptions raised by workers (empty on success).
        """
        # The timeout keeps already-started workers from waiting forever if a
        # later thread fails to start; the resulting error is recorded below.
        barrier = threading.Barrier(num_threads, timeout=10)
        errors = []
        errors_lock = threading.Lock()

        def worker(thread_index):
            try:
                barrier.wait()
                task(thread_index)
            except Exception as exc:  # surfaced to the test via the return value
                with errors_lock:
                    errors.append(exc)

        threads = [
            threading.Thread(target=worker, args=(index,))
            for index in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        return errors

    def test_increment(self):
        """Concurrent increments from many threads must not lose updates."""
        counter = ThreadSafeCounter()
        num_threads = 10
        increments_per_thread = 1000

        def increment_task(_thread_index):
            for _ in range(increments_per_thread):
                counter.increment()

        errors = self._run_concurrently(increment_task, num_threads)

        self.assertEqual(errors, [])
        self.assertEqual(counter.value(), num_threads * increments_per_thread)

    def test_decrement(self):
        """Concurrent decrements from many threads must not lose updates."""
        counter = ThreadSafeCounter()
        num_threads = 10
        decrements_per_thread = 1000
        start_value = num_threads * decrements_per_thread
        # Seed through the public API (``increment`` accepts an amount, as
        # test_reset also relies on) instead of writing a private attribute.
        # Starting this high means the count never has to go below zero.
        counter.increment(start_value)
        self.assertEqual(counter.value(), start_value)

        def decrement_task(_thread_index):
            for _ in range(decrements_per_thread):
                counter.decrement()

        errors = self._run_concurrently(decrement_task, num_threads)

        self.assertEqual(errors, [])
        self.assertEqual(
            counter.value(), start_value - num_threads * decrements_per_thread
        )

    def test_reset(self):
        """``reset`` brings the counter back to zero."""
        counter = ThreadSafeCounter()
        counter.increment(100)
        counter.reset()
        self.assertEqual(counter.value(), 0)

class TestThreadSafeList(unittest.TestCase):
    """Concurrency tests for ``ThreadSafeList``."""

    @staticmethod
    def _run_concurrently(task, num_threads):
        """Run ``task(thread_index)`` on ``num_threads`` threads and wait for them.

        A barrier releases every worker at the same moment so the calls
        really overlap. An exception raised inside a worker thread does not
        propagate to the thread that started it, so such exceptions are
        caught here and returned. The caller is expected to assert that the
        returned list is empty.

        Args:
            task: Callable taking the worker's index (``0..num_threads-1``).
            num_threads: Number of worker threads to start.

        Returns:
            A list of exceptions raised by workers (empty on success).
        """
        # The timeout keeps already-started workers from waiting forever if a
        # later thread fails to start; the resulting error is recorded below.
        barrier = threading.Barrier(num_threads, timeout=10)
        errors = []
        errors_lock = threading.Lock()

        def worker(thread_index):
            try:
                barrier.wait()
                task(thread_index)
            except Exception as exc:  # surfaced to the test via the return value
                with errors_lock:
                    errors.append(exc)

        threads = [
            threading.Thread(target=worker, args=(index,))
            for index in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        return errors

    def test_append(self):
        """Concurrent appends from many threads must all be kept."""
        safe_list = ThreadSafeList()
        num_threads = 10
        appends_per_thread = 100

        def append_task(_thread_index):
            for i in range(appends_per_thread):
                safe_list.append(i)

        errors = self._run_concurrently(append_task, num_threads)

        self.assertEqual(errors, [])
        self.assertEqual(len(safe_list), num_threads * appends_per_thread)

    def test_remove(self):
        """Threads removing disjoint values must each succeed exactly once."""
        safe_list = ThreadSafeList()
        num_threads = 5
        removes_per_thread = 20
        # Values that no thread removes; they make "removed too much" visible.
        untouched = 50
        total = num_threads * removes_per_thread + untouched
        # Populate through the public API rather than assigning a private
        # attribute (which also aliased the list used for the expectation).
        for value in range(total):
            safe_list.append(value)

        def remove_task(thread_index):
            # Each thread owns its own slice of values, so every remove targets
            # an element that is still present. Previously all threads removed
            # the same values 0..N-1. This assumes remove() takes a value, like
            # list.remove — TODO confirm against ThreadSafeList.
            start = thread_index * removes_per_thread
            for value in range(start, start + removes_per_thread):
                safe_list.remove(value)

        errors = self._run_concurrently(remove_task, num_threads)

        self.assertEqual(errors, [])
        self.assertEqual(len(safe_list), untouched)

    def test_clear(self):
        """``clear`` empties the list."""
        safe_list = ThreadSafeList()
        safe_list.append(1)
        safe_list.append(2)
        safe_list.clear()
        self.assertEqual(len(safe_list), 0)

class TestThreadSafeDict(unittest.TestCase):
    """Concurrency tests for ``ThreadSafeDict``."""

    @staticmethod
    def _run_concurrently(task, num_threads):
        """Run ``task(thread_index)`` on ``num_threads`` threads and wait for them.

        A barrier releases every worker at the same moment so the calls
        really overlap. An exception raised inside a worker thread does not
        propagate to the thread that started it, so such exceptions are
        caught here and returned. The caller is expected to assert that the
        returned list is empty.

        Args:
            task: Callable taking the worker's index (``0..num_threads-1``).
            num_threads: Number of worker threads to start.

        Returns:
            A list of exceptions raised by workers (empty on success).
        """
        # The timeout keeps already-started workers from waiting forever if a
        # later thread fails to start; the resulting error is recorded below.
        barrier = threading.Barrier(num_threads, timeout=10)
        errors = []
        errors_lock = threading.Lock()

        def worker(thread_index):
            try:
                barrier.wait()
                task(thread_index)
            except Exception as exc:  # surfaced to the test via the return value
                with errors_lock:
                    errors.append(exc)

        threads = [
            threading.Thread(target=worker, args=(index,))
            for index in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        return errors

    def test_set_and_get(self):
        """Many threads writing the same keys leave one consistent entry per key."""
        safe_dict = ThreadSafeDict()
        num_threads = 10
        sets_per_thread = 100
        key_prefix = "key_"

        def set_task(_thread_index):
            for i in range(sets_per_thread):
                safe_dict.set(key_prefix + str(i), i)

        errors = self._run_concurrently(set_task, num_threads)

        self.assertEqual(errors, [])
        # Every thread writes the same keys, so the size is one thread's worth.
        expected_size = sets_per_thread
        self.assertEqual(len(safe_dict), expected_size)
        for i in range(sets_per_thread):
            self.assertEqual(safe_dict.get(key_prefix + str(i)), i)

    def test_delete(self):
        """``delete`` removes only the given key."""
        safe_dict = ThreadSafeDict()
        safe_dict.set("key1", "value1")
        safe_dict.set("key2", "value2")
        safe_dict.delete("key1")
        self.assertIsNone(safe_dict.get("key1"))
        self.assertEqual(safe_dict.get("key2"), "value2")
        self.assertEqual(len(safe_dict), 1)

    def test_clear(self):
        """``clear`` empties the dict."""
        safe_dict = ThreadSafeDict()
        safe_dict.set("key1", "value1")
        safe_dict.set("key2", "value2")
        safe_dict.clear()
        self.assertEqual(len(safe_dict), 0)

if __name__ == "__main__":
    # Running this file directly hands control to the unittest CLI runner.
    unittest.main()