import pytest
from rwl.orchestrator import Orchestrator
from rwl.models import Task, Dependency

# Minimal concrete Task implementation used by the tests below
class DummyTask(Task):
    """Trivial Task whose execution always succeeds without doing any work."""

    def execute(self):
        """Flag this task as finished and report success to the caller."""
        self.status = "completed"
        return True


def test_circular_dependency_detection():
    """
    Verify that a three-task dependency cycle makes resolution fail.

    Task1 -> Task2 -> Task3 -> Task1 closes a loop, so resolve_dependencies()
    is expected to raise ValueError whose message contains
    "Circular dependency detected".
    """
    orchestrator = Orchestrator()

    cycle = [
        DummyTask(name="Task1", dependencies=[Dependency(task_name="Task2")]),
        DummyTask(name="Task2", dependencies=[Dependency(task_name="Task3")]),
        # Task3 pointing back at Task1 is what closes the cycle.
        DummyTask(name="Task3", dependencies=[Dependency(task_name="Task1")]),
    ]
    for task in cycle:
        orchestrator.register_task(task)

    # `match` is a regex search; the literal below has no metacharacters.
    with pytest.raises(ValueError, match="Circular dependency detected"):
        orchestrator.resolve_dependencies()


def test_self_dependency_detection():
    """
    Verify that a task listing itself as a dependency is rejected.

    resolve_dependencies() is expected to raise ValueError whose message
    contains "Self-dependency detected".
    """
    orchestrator = Orchestrator()

    self_referencing = DummyTask(
        name="Task1",
        dependencies=[Dependency(task_name="Task1")],  # depends on itself
    )
    orchestrator.register_task(self_referencing)

    with pytest.raises(ValueError, match="Self-dependency detected"):
        orchestrator.resolve_dependencies()


def test_nonexistent_dependency_handling():
    """
    Verify that depending on an unregistered task name is reported.

    Only Task1 is registered, yet it names "NonExistentTask" as a dependency,
    so resolve_dependencies() is expected to raise ValueError whose message
    contains "Task NonExistentTask not found".
    """
    orchestrator = Orchestrator()

    orphan = DummyTask(
        name="Task1",
        dependencies=[Dependency(task_name="NonExistentTask")],  # never registered
    )
    orchestrator.register_task(orphan)

    with pytest.raises(ValueError, match="Task NonExistentTask not found"):
        orchestrator.resolve_dependencies()


def test_diamond_dependency_resolution():
    """
    Test that a diamond-shaped dependency graph resolves and executes.

    Shape: TaskD depends on TaskB and TaskC, and both of those depend on
    TaskA. TaskA is therefore reachable from TaskD along two paths; that
    shared ancestor must not be mistaken for a cycle, and executing the
    graph must terminate with every task completed.
    """
    orchestrator = Orchestrator()

    task_a = DummyTask(name="TaskA")
    task_b = DummyTask(name="TaskB", dependencies=[Dependency(task_name="TaskA")])
    task_c = DummyTask(name="TaskC", dependencies=[Dependency(task_name="TaskA")])
    task_d = DummyTask(
        name="TaskD",
        dependencies=[Dependency(task_name="TaskB"), Dependency(task_name="TaskC")],
    )

    for task in (task_a, task_b, task_c, task_d):
        orchestrator.register_task(task)

    # Guard only the resolution step. If the try also wrapped execute_tasks(),
    # a ValueError raised there would be reported as a resolution failure.
    try:
        orchestrator.resolve_dependencies()
    except ValueError as exc:
        pytest.fail(f"Dependency resolution failed unexpectedly: {exc}")

    # D depends on B and C, which each depend on A; A has no dependencies.
    assert orchestrator.task_dependencies["TaskD"] == {task_b, task_c}
    assert orchestrator.task_dependencies["TaskB"] == {task_a}
    assert orchestrator.task_dependencies["TaskC"] == {task_a}
    assert orchestrator.task_dependencies["TaskA"] == set()

    # Executing the graph must terminate (no infinite loop) and run every task.
    orchestrator.execute_tasks()
    assert task_a.status == "completed"
    assert task_b.status == "completed"
    assert task_c.status == "completed"
    assert task_d.status == "completed"


def test_valid_linear_dependencies():
    """
    Test that a valid linear dependency chain resolves and executes.

    Chain: Task3 depends on Task2, which depends on Task1. Resolution must
    succeed, record exactly the direct dependency of each task, and
    execution must mark every task completed.
    """
    orchestrator = Orchestrator()

    task1 = DummyTask(name="Task1")
    task2 = DummyTask(name="Task2", dependencies=[Dependency(task_name="Task1")])
    task3 = DummyTask(name="Task3", dependencies=[Dependency(task_name="Task2")])

    for task in (task1, task2, task3):
        orchestrator.register_task(task)

    # Guard only the resolution step. If the try also wrapped execute_tasks(),
    # a ValueError raised there would be reported as a resolution failure.
    try:
        orchestrator.resolve_dependencies()
    except ValueError as exc:
        pytest.fail(f"Dependency resolution failed unexpectedly: {exc}")

    # Each task records only its direct predecessor in the chain.
    assert orchestrator.task_dependencies["Task3"] == {task2}
    assert orchestrator.task_dependencies["Task2"] == {task1}
    assert orchestrator.task_dependencies["Task1"] == set()

    orchestrator.execute_tasks()
    assert task1.status == "completed"
    assert task2.status == "completed"
    assert task3.status == "completed"