import unittest
from unittest.mock import MagicMock, patch
import sys
import os

# Add the project root to the Python path to ensure imports work correctly
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))

from core.rwl import orchestrator
from core.rwl import task

class OrchestratorUnitTests(unittest.TestCase):
    """Unit tests for ``core.rwl.orchestrator.Orchestrator``.

    Tasks are represented by mock objects whose ``execute`` return value or
    side effect is configured per test. Several tests patch
    ``core.rwl.orchestrator.Task`` and use the patched class's
    ``return_value`` as a task double; when a second, distinct task is needed,
    a separate ``MagicMock`` is created, because ``return_value`` is always
    the same object.
    """

    @patch('core.rwl.orchestrator.Task')
    def test_add_task(self, MockTask):
        """Test adding a task to the orchestrator."""
        orc = orchestrator.Orchestrator()
        mock_task_instance = MockTask.return_value
        task_name = "test_task"
        orc.add_task(task_name, mock_task_instance)
        self.assertIn(task_name, orc.tasks)
        # The very object passed in must be stored, not a copy.
        self.assertIs(orc.tasks[task_name], mock_task_instance)

    def test_execute_empty_graph(self):
        """Test execution with an empty dependency graph."""
        orc = orchestrator.Orchestrator()
        result = orc.execute()
        self.assertEqual(result, {})

    @patch('core.rwl.orchestrator.Task')
    def test_execute_single_task(self, MockTask):
        """Test execution with a single task and no dependencies."""
        orc = orchestrator.Orchestrator()
        mock_task_instance = MockTask.return_value
        mock_task_instance.execute.return_value = "Task executed successfully"
        task_name = "task1"
        orc.add_task(task_name, mock_task_instance)
        orc.graph[task_name] = []  # No dependencies
        result = orc.execute()
        self.assertEqual(result[task_name], "Task executed successfully")
        mock_task_instance.execute.assert_called_once()

    @patch('core.rwl.orchestrator.Task')
    def test_execute_task_with_dependency(self, MockTask):
        """Test that a task is executed only after the task it depends on."""
        orc = orchestrator.Orchestrator()
        call_order = []

        def record(name, value):
            # Build a side effect that logs which task ran before returning
            # ``value``, so the relative execution order can be asserted.
            def _side_effect(*args, **kwargs):
                call_order.append(name)
                return value
            return _side_effect

        task1_name = "task1"
        mock_task1 = MockTask.return_value
        mock_task1.execute.side_effect = record(task1_name, "Task 1 executed")
        orc.add_task(task1_name, mock_task1)

        task2_name = "task2"
        # A separate double: MockTask.return_value is already used for task1.
        mock_task2 = MagicMock()
        mock_task2.execute.side_effect = record(task2_name, "Task 2 executed")
        orc.add_task(task2_name, mock_task2)

        orc.graph[task2_name] = [task1_name]  # task2 depends on task1
        orc.graph[task1_name] = []

        result = orc.execute()
        self.assertEqual(result[task1_name], "Task 1 executed")
        self.assertEqual(result[task2_name], "Task 2 executed")

        # The dependency (task1) must run before task2, and each exactly once.
        self.assertEqual(call_order, [task1_name, task2_name])

    @patch('core.rwl.orchestrator.Task')
    def test_execute_task_failure(self, MockTask):
        """Test handling a task that raises an exception during execution."""
        orc = orchestrator.Orchestrator()
        mock_task_instance = MockTask.return_value
        mock_task_instance.execute.side_effect = Exception("Task failed")
        task_name = "task1"
        orc.add_task(task_name, mock_task_instance)
        orc.graph[task_name] = []

        with self.assertRaises(Exception) as context:
            orc.execute()
        self.assertEqual(str(context.exception), "Task failed")

    @patch('core.rwl.orchestrator.Task')
    def test_add_dependency(self, MockTask):
        """Test adding dependencies between tasks."""
        orc = orchestrator.Orchestrator()
        mock_task1 = MockTask.return_value
        # Use a distinct double so the two registered tasks are different objects.
        mock_task2 = MagicMock()

        task1_name = "task1"
        task2_name = "task2"

        orc.add_task(task1_name, mock_task1)
        orc.add_task(task2_name, mock_task2)

        orc.add_dependency(task2_name, task1_name)  # task2 depends on task1
        self.assertIn(task1_name, orc.graph[task2_name])

    def test_invalid_dependency(self):
        """Test that a dependency on an unregistered task raises ValueError."""
        orc = orchestrator.Orchestrator()
        with self.assertRaises(ValueError) as context:
            orc.add_dependency("task2", "task1")  # task1 and task2 not yet added
        # The error is expected to name the dependency argument ("task1").
        self.assertEqual(str(context.exception), "task1 is not a registered task.")

    @patch('core.rwl.orchestrator.Task')
    def test_execute_with_circular_dependency(self, MockTask):
        """Test that a dependency cycle is rejected with ValueError on execute."""
        orc = orchestrator.Orchestrator()

        mock_task1 = MockTask.return_value
        mock_task1.execute.return_value = "Task 1 executed"
        task1_name = "task1"
        orc.add_task(task1_name, mock_task1)

        mock_task2 = MagicMock()
        mock_task2.execute.return_value = "Task 2 executed"
        task2_name = "task2"
        orc.add_task(task2_name, mock_task2)

        orc.graph[task1_name] = [task2_name]  # task1 depends on task2
        orc.graph[task2_name] = [task1_name]  # task2 depends on task1

        with self.assertRaises(ValueError) as context:
            orc.execute()
        self.assertEqual(str(context.exception), "Circular dependency detected.")

    def test_reset(self):
        """Test the reset method to clear task results."""
        orc = orchestrator.Orchestrator()
        orc.results["task1"] = "some result"
        orc.reset()
        self.assertEqual(orc.results, {})


if __name__ == '__main__':
    # Executed directly (not imported): hand off to unittest's CLI runner,
    # which discovers the TestCase classes defined in this module.
    unittest.main(verbosity=1)