#!/usr/bin/env python3
"""
GENESIS STATE MACHINE ENGINE
==============================
Robust finite state machine implementation for complex workflows.

Features:
    - Declarative state definitions
    - Transition guards and actions
    - State entry/exit hooks
    - Hierarchical states (nested)
    - History states
    - Event-driven transitions
    - Persistence and restoration

Usage:
    machine = StateMachine("order_flow")
    machine.add_state("pending", initial=True)
    machine.add_state("processing")
    machine.add_state("complete", final=True)

    machine.add_transition("pending", "processing", "start")
    machine.add_transition("processing", "complete", "finish")

    machine.trigger("start")
    machine.trigger("finish")
"""

import asyncio
import json
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Dict, List, Any, Optional, Callable, Set, Union, Tuple


class StateType(Enum):
    """Types of states.

    In this module ``INITIAL`` and ``FINAL`` are assigned by
    ``StateMachine.add_state`` from its ``initial`` / ``final`` flags,
    ``COMPOSITE`` is assigned to a state once a child registers under it, and
    ``HISTORY`` is checked when a transition enters its target state.
    """
    NORMAL = "normal"
    INITIAL = "initial"
    FINAL = "final"
    COMPOSITE = "composite"  # Has child states
    HISTORY = "history"      # Remembers last active child


@dataclass
class State:
    """A single named node of a state machine.

    Identity is carried by ``name`` alone: hashing and equality both ignore
    every other field, and a ``State`` also compares equal to any object that
    equals its name (for example the plain string ``"pending"``).
    """
    name: str
    state_type: StateType = StateType.NORMAL
    # Hooks are invoked as hook(machine, state).
    on_enter: Optional[Callable[['StateMachine', 'State'], None]] = None
    on_exit: Optional[Callable[['StateMachine', 'State'], None]] = None
    parent: Optional[str] = None
    data: Dict[str, Any] = field(default_factory=dict)

    # Only meaningful for composite states: names of the nested states and
    # the one entered by default.
    children: List[str] = field(default_factory=list)
    initial_child: Optional[str] = None

    def __eq__(self, other):
        # Compare name-to-name for States; otherwise compare our name with
        # the other object directly.
        other_name = other.name if isinstance(other, State) else other
        return self.name == other_name

    def __hash__(self):
        return hash(self.name)


@dataclass
class Transition:
    """Definition of a state transition.

    ``StateMachine`` stores transitions per event name, ordered by descending
    ``priority``. When the event is triggered it uses the first transition
    whose ``source`` is the current state (or one of its ancestors) and whose
    ``guard``, if any, returns a truthy value.
    """
    source: str  # Name of the state the transition leaves
    target: str  # Name of the state the transition enters
    event: str   # Event name that selects this transition
    # Optional predicate called as guard(machine, event_data)
    guard: Optional[Callable[['StateMachine', Dict], bool]] = None
    # Optional side effect called as action(machine, event_data)
    action: Optional[Callable[['StateMachine', Dict], None]] = None
    internal: bool = False  # Internal transition (no exit/enter)
    priority: int = 0  # Higher values are tried first

    def __hash__(self):
        # Hash only the (source, target, event) triple; the callables and
        # flags do not contribute.
        return hash((self.source, self.target, self.event))


@dataclass
class TransitionEvent:
    """Record of a state transition.

    ``StateMachine.trigger`` appends one of these to the machine's event log
    for every triggered event, whether or not a transition took place.
    """
    timestamp: str  # ISO-8601 string produced by datetime.now().isoformat()
    source: str     # State name before the event ("" when there was none)
    target: str     # State name of the transition target ("" when none matched)
    event: str      # Name of the triggered event
    data: Dict[str, Any] = field(default_factory=dict)  # Payload passed to trigger()
    success: bool = True         # False when no transition matched or it raised
    error: Optional[str] = None  # Failure description when success is False


class TransitionError(Exception):
    """Error during state transition.

    Raised by ``StateMachine.trigger`` when a matched transition fails while
    executing; the underlying exception is chained as ``__cause__``.
    """
    pass


class InvalidStateError(Exception):
    """Invalid state reference.

    Raised while executing a transition whose target state name has not been
    registered with the machine.
    """
    pass


class StateMachine:
    """
    Finite State Machine with advanced features.

    States and transitions are registered with ``add_state`` and
    ``add_transition``; events are fired with ``trigger``. The machine keeps a
    free-form ``context`` dict, an in-memory event log, and can optionally
    persist its current state, context, history and recent events as JSON.

    Mutating operations that change the registered states/transitions or the
    current state run under a re-entrant lock.
    """

    def __init__(
        self,
        name: str,
        persist_path: Union[str, Path, None] = None,
        auto_persist: bool = False
    ):
        """
        Create a machine.

        Args:
            name: Machine name; also used to validate persisted snapshots.
            persist_path: Optional JSON file used by ``save`` / auto-persist.
                Strings are converted to ``Path``. If the file already exists
                its contents are restored immediately.
            auto_persist: When true (and ``persist_path`` is set) the machine
                is saved after every successful ``trigger``.
        """
        self.name = name
        # Accept plain strings too; everything below relies on Path methods.
        if persist_path and not isinstance(persist_path, Path):
            persist_path = Path(persist_path)
        self.persist_path = persist_path or None
        self.auto_persist = auto_persist

        self._states: Dict[str, State] = {}
        self._transitions: Dict[str, List[Transition]] = {}  # Keyed by event
        self._current_state: Optional[str] = None
        self._initial_state: Optional[str] = None
        self._history: Dict[str, str] = {}  # Parent -> last active child

        self._context: Dict[str, Any] = {}
        self._event_log: List[TransitionEvent] = []
        self._lock = threading.RLock()

        # Callbacks
        self._on_transition: List[Callable[[TransitionEvent], None]] = []
        self._on_state_change: List[Callable[[str, str], None]] = []

        # Restore if persist path exists
        if self.persist_path and self.persist_path.exists():
            self._restore()

    @property
    def current_state(self) -> Optional[str]:
        """Get current state name."""
        return self._current_state

    @property
    def current_state_obj(self) -> Optional['State']:
        """Get current state object (None if unset or not registered)."""
        if self._current_state:
            return self._states.get(self._current_state)
        return None

    @property
    def context(self) -> Dict[str, Any]:
        """Get machine context data (the live dict, not a copy)."""
        return self._context

    def add_state(
        self,
        name: str,
        initial: bool = False,
        final: bool = False,
        on_enter: Callable = None,
        on_exit: Callable = None,
        parent: str = None,
        data: Dict = None
    ) -> 'StateMachine':
        """
        Add a state to the machine.

        Args:
            name: Unique state name; re-adding a name replaces the state.
            initial: Mark as initial. For a top-level state this sets the
                machine's initial state; for a nested state it marks the
                parent's default child.
            final: Mark as final (ignored when ``initial`` is also true).
            on_enter: Hook called as ``on_enter(machine, state)``.
            on_exit: Hook called as ``on_exit(machine, state)``.
            parent: Name of an already-registered parent state, which then
                becomes composite.
            data: Arbitrary data attached to the state.

        Returns:
            The machine itself, for chaining.
        """
        with self._lock:
            state_type = StateType.NORMAL
            if initial:
                state_type = StateType.INITIAL
                # A nested "initial" flag selects the parent's default child;
                # it must not displace an already-chosen machine-level
                # initial state (which reset() returns to).
                if not parent or self._initial_state is None:
                    self._initial_state = name
            elif final:
                state_type = StateType.FINAL

            state = State(
                name=name,
                state_type=state_type,
                on_enter=on_enter,
                on_exit=on_exit,
                parent=parent,
                data=data or {}
            )

            self._states[name] = state

            # Register with parent if composite
            if parent and parent in self._states:
                parent_state = self._states[parent]
                # Re-adding a state must not list it twice.
                if name not in parent_state.children:
                    parent_state.children.append(name)
                parent_state.state_type = StateType.COMPOSITE
                if initial:
                    parent_state.initial_child = name

            # Auto-set initial state if first
            if not self._current_state and initial:
                self._current_state = name
                if on_enter:
                    on_enter(self, state)

            return self

    def add_transition(
        self,
        source: str,
        target: str,
        event: str,
        guard: Callable = None,
        action: Callable = None,
        internal: bool = False,
        priority: int = 0
    ) -> 'StateMachine':
        """
        Add a transition between states.

        Args:
            source: State (or ancestor of the active state) the event applies to.
            target: State entered when the transition fires.
            event: Event name that selects the transition.
            guard: Optional predicate ``guard(machine, data)``.
            action: Optional side effect ``action(machine, data)``.
            internal: If true, only the action runs; no exit/enter, no state change.
            priority: Higher priorities are evaluated first for the same event.

        Returns:
            The machine itself, for chaining.
        """
        with self._lock:
            transition = Transition(
                source=source,
                target=target,
                event=event,
                guard=guard,
                action=action,
                internal=internal,
                priority=priority
            )

            candidates = self._transitions.setdefault(event, [])
            candidates.append(transition)

            # Sort by priority (higher first); the sort is stable, so equal
            # priorities keep registration order.
            candidates.sort(key=lambda t: -t.priority)

            return self

    def can_trigger(self, event: str, data: Dict = None) -> bool:
        """Check if an event can be triggered (guards are evaluated)."""
        with self._lock:
            transition = self._find_transition(event, data or {})
            return transition is not None

    def trigger(self, event: str, data: Dict = None) -> bool:
        """
        Trigger an event, potentially causing a state transition.

        Returns True if transition occurred, False if no transition matched
        (a failed event is logged in that case).

        Raises:
            TransitionError: If executing the matched transition raised, or if
                the transition succeeded but auto-persisting the machine
                failed. In the latter case the state change has already
                happened and is logged as successful.
        """
        with self._lock:
            data = data or {}

            # Find valid transition
            transition = self._find_transition(event, data)

            if not transition:
                self._log_event(
                    self._current_state or "", "", event, data,
                    success=False,
                    error=f"No valid transition for event '{event}' from state '{self._current_state}'"
                )
                return False

            # Only the transition itself is guarded here, so that a failure in
            # the follow-up steps is never recorded as a failed transition.
            try:
                self._execute_transition(transition, data)
            except Exception as e:
                self._log_event(
                    transition.source, transition.target, event, data,
                    success=False, error=str(e)
                )
                raise TransitionError(f"Transition failed: {e}") from e

            event_record = self._log_event(
                transition.source, transition.target, event, data, success=True
            )

            # Notify callbacks; a misbehaving observer must not break the machine.
            for callback in self._on_transition:
                try:
                    callback(event_record)
                except Exception:
                    pass

            # Auto-persist
            if self.auto_persist and self.persist_path:
                try:
                    self._persist()
                except Exception as e:
                    raise TransitionError(
                        f"Transition succeeded but persisting state failed: {e}"
                    ) from e

            return True

    def _log_event(
        self,
        source: str,
        target: str,
        event: str,
        data: Dict,
        success: bool,
        error: Optional[str] = None
    ) -> 'TransitionEvent':
        """Append a timestamped TransitionEvent to the log and return it."""
        record = TransitionEvent(
            timestamp=datetime.now().isoformat(),
            source=source,
            target=target,
            event=event,
            data=data,
            success=success,
            error=error
        )
        self._event_log.append(record)
        return record

    def _guard_passes(self, transition: 'Transition', data: Dict) -> bool:
        """Evaluate a transition's guard; a raising guard counts as not passing."""
        if not transition.guard:
            return True
        try:
            return bool(transition.guard(self, data))
        except Exception:
            return False

    def _find_transition(self, event: str, data: Dict) -> Optional['Transition']:
        """Find the first (highest-priority) valid transition for the event."""
        for transition in self._transitions.get(event, ()):
            # Source must be the current state or one of its ancestors.
            if not self._state_matches(transition.source):
                continue
            if self._guard_passes(transition, data):
                return transition

        return None

    def _state_matches(self, state_name: str) -> bool:
        """Check if state_name matches current state or a parent."""
        if state_name == self._current_state:
            return True

        # Check if current state is a child of state_name
        current = self._states.get(self._current_state)
        while current and current.parent:
            if current.parent == state_name:
                return True
            current = self._states.get(current.parent)

        return False

    def _exit_chain(self, source_name: str) -> List['State']:
        """
        Return the states to exit, innermost first.

        The chain runs from the active state up through its ancestors to the
        transition's source. If the source is not reachable that way, only the
        declared source state (when registered) is returned.
        """
        chain: List[State] = []
        seen: Set[str] = set()
        name = self._current_state
        while name and name not in seen:
            seen.add(name)
            state = self._states.get(name)
            if state is None:
                break
            chain.append(state)
            if name == source_name:
                return chain
            name = state.parent

        source_state = self._states.get(source_name)
        return [source_state] if source_state else []

    def _execute_transition(self, transition: 'Transition', data: Dict):
        """
        Execute a state transition.

        Order: exit hooks (active state up to the transition source, recording
        history for each parent), transition action, state update, enter hook
        of the state finally entered, then state-change callbacks. Internal
        transitions run only the action.

        Raises:
            InvalidStateError: If the target state is not registered.
        """
        target_state = self._states.get(transition.target)

        if not target_state:
            raise InvalidStateError(f"Target state not found: {transition.target}")

        old_state = self._current_state

        # Internal transitions don't trigger exit/enter
        if transition.internal:
            if transition.action:
                transition.action(self, data)
            return

        # Exit the active state and its ancestors up to the transition source,
        # so a transition declared on a composite parent still runs the active
        # child's exit hook and remembers it for history states.
        for leaving in self._exit_chain(transition.source):
            if leaving.on_exit:
                leaving.on_exit(self, leaving)
            if leaving.parent:
                self._history[leaving.parent] = leaving.name

        # Execute action
        if transition.action:
            transition.action(self, data)

        # Enter target state
        self._current_state = transition.target

        # Handle composite state - enter initial child
        if target_state.state_type == StateType.COMPOSITE:
            initial_child = target_state.initial_child or (
                target_state.children[0] if target_state.children else None
            )
            if initial_child:
                self._current_state = initial_child
                target_state = self._states.get(initial_child)

        # Handle history state - resume the parent's last active child
        if target_state and target_state.state_type == StateType.HISTORY:
            parent = target_state.parent
            if parent and parent in self._history:
                self._current_state = self._history[parent]
                target_state = self._states.get(self._current_state)

        if target_state and target_state.on_enter:
            target_state.on_enter(self, target_state)

        # Notify state change; observer errors are ignored on purpose.
        for callback in self._on_state_change:
            try:
                callback(old_state or "", self._current_state or "")
            except Exception:
                pass

    def reset(self):
        """
        Reset machine to initial state.

        Runs the current state's exit hook, clears history, and runs the
        initial state's enter hook. Does nothing if no initial state is set.
        """
        with self._lock:
            if self._initial_state:
                old_state = self._current_state

                # Exit current state
                if old_state:
                    current = self._states.get(old_state)
                    if current and current.on_exit:
                        current.on_exit(self, current)

                self._current_state = self._initial_state
                self._history.clear()

                # Enter initial state
                initial = self._states.get(self._initial_state)
                if initial and initial.on_enter:
                    initial.on_enter(self, initial)

    def set_context(self, key: str, value: Any):
        """Set context data."""
        self._context[key] = value

    def get_context(self, key: str, default: Any = None) -> Any:
        """Get context data."""
        return self._context.get(key, default)

    def is_in_state(self, state_name: str) -> bool:
        """Check if currently in a state (or child of composite state)."""
        return self._state_matches(state_name)

    def is_final(self) -> bool:
        """Check if machine is in a final state (always a real bool)."""
        current = self._states.get(self._current_state)
        return current is not None and current.state_type == StateType.FINAL

    def get_available_events(self) -> List[str]:
        """
        Get events that can be triggered from current state.

        Guards are evaluated with an empty data dict; a guard that raises is
        treated as not passing, exactly as ``trigger`` treats it.
        """
        return [
            event
            for event, transitions in self._transitions.items()
            if any(
                self._state_matches(t.source) and self._guard_passes(t, {})
                for t in transitions
            )
        ]

    def get_event_log(self, limit: int = None) -> List['TransitionEvent']:
        """
        Get transition event log.

        Args:
            limit: If given, return only the most recent ``limit`` events;
                zero or a negative value yields an empty list.

        Returns:
            A new list; modifying it does not affect the machine.
        """
        log = self._event_log.copy()
        if limit is None:
            return log
        # log[-0:] would be the whole list, so non-positive limits are explicit.
        return log[-limit:] if limit > 0 else []

    def on_transition(self, callback: Callable[['TransitionEvent'], None]):
        """Register transition callback."""
        self._on_transition.append(callback)

    def on_state_change(self, callback: Callable[[str, str], None]):
        """Register state change callback (old_state, new_state)."""
        self._on_state_change.append(callback)

    def _persist(self):
        """
        Save machine state to disk.

        The snapshot is written to a sibling temporary file and then moved
        over ``persist_path``, so a failure while serialising (for example a
        non-JSON-serialisable context value) leaves the previous snapshot
        untouched. Errors are propagated to the caller.
        """
        if not self.persist_path:
            return

        self.persist_path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            "name": self.name,
            "current_state": self._current_state,
            "context": self._context,
            "history": self._history,
            "event_log": [
                {
                    "timestamp": e.timestamp,
                    "source": e.source,
                    "target": e.target,
                    "event": e.event,
                    "data": e.data,
                    "success": e.success,
                    "error": e.error
                }
                for e in self._event_log[-100:]  # Keep last 100
            ]
        }

        tmp_path = self.persist_path.with_name(self.persist_path.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            tmp_path.replace(self.persist_path)
        except Exception:
            # Do not leave a half-written temporary file behind.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def _restore(self):
        """
        Restore machine state from disk.

        Best effort: an unreadable, malformed or foreign (different ``name``)
        snapshot is ignored and the machine is left exactly as it was. The
        snapshot is fully parsed before anything is applied.
        """
        if not self.persist_path or not self.persist_path.exists():
            return

        try:
            with open(self.persist_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if data.get("name") != self.name:
                return

            current_state = data.get("current_state")
            # "or" also covers explicit JSON nulls.
            context = data.get("context") or {}
            history = data.get("history") or {}
            event_log = [
                TransitionEvent(**e)
                for e in data.get("event_log") or []
            ]
        except Exception:
            # Deliberately tolerant: a bad snapshot must not prevent the
            # machine from being constructed.
            return

        self._current_state = current_state
        self._context = context
        self._history = history
        self._event_log = event_log

    def save(self):
        """Manually save state."""
        self._persist()

    def get_status(self) -> Dict:
        """Get machine status as a plain dict."""
        return {
            "name": self.name,
            "current_state": self._current_state,
            "is_final": self.is_final(),
            "available_events": self.get_available_events(),
            "states": list(self._states.keys()),
            "event_count": len(self._event_log)
        }

    def visualize(self) -> str:
        """Generate ASCII visualization of the state machine."""
        lines = [
            f"State Machine: {self.name}",
            "=" * 40,
            "",
            "States:"
        ]

        for name, state in self._states.items():
            indicator = ""
            if state.state_type == StateType.INITIAL:
                indicator = " [INITIAL]"
            elif state.state_type == StateType.FINAL:
                indicator = " [FINAL]"
            elif state.state_type == StateType.COMPOSITE:
                indicator = f" [COMPOSITE: {', '.join(state.children)}]"

            current = " <-- CURRENT" if name == self._current_state else ""
            lines.append(f"  - {name}{indicator}{current}")

        lines.extend(["", "Transitions:"])

        for event, transitions in self._transitions.items():
            for t in transitions:
                guard = " [guarded]" if t.guard else ""
                action = " [has action]" if t.action else ""
                lines.append(f"  {t.source} --({event})--> {t.target}{guard}{action}")

        return "\n".join(lines)


class StateMachineBuilder:
    """Fluent builder for state machines.

    ``state()`` registers a state and selects it; the following ``on_enter``,
    ``on_exit`` and ``transition`` calls apply to that selected state until
    the next ``state()`` call. Every method returns the builder so calls can
    be chained, and ``build()`` hands back the configured machine.
    """

    def __init__(self, name: str):
        self._machine = StateMachine(name)
        # Name of the state most recently added through state().
        self._current_state: Optional[str] = None

    def state(
        self,
        name: str,
        initial: bool = False,
        final: bool = False
    ) -> 'StateMachineBuilder':
        """Register a state and make it the target of subsequent calls."""
        self._machine.add_state(name, initial=initial, final=final)
        self._current_state = name
        return self

    def on_enter(self, callback: Callable) -> 'StateMachineBuilder':
        """Attach an enter hook to the selected state (no-op if none)."""
        selected = self._current_state
        if not selected:
            return self
        self._machine._states[selected].on_enter = callback
        return self

    def on_exit(self, callback: Callable) -> 'StateMachineBuilder':
        """Attach an exit hook to the selected state (no-op if none)."""
        selected = self._current_state
        if not selected:
            return self
        self._machine._states[selected].on_exit = callback
        return self

    def transition(
        self,
        event: str,
        target: str,
        guard: Callable = None,
        action: Callable = None
    ) -> 'StateMachineBuilder':
        """Add a transition leaving the selected state (no-op if none)."""
        origin = self._current_state
        if not origin:
            return self
        self._machine.add_transition(
            source=origin,
            target=target,
            event=event,
            guard=guard,
            action=action,
        )
        return self

    def build(self) -> StateMachine:
        """Return the machine assembled so far."""
        return self._machine


class WorkflowEngine:
    """
    Manages multiple state machines for complex workflows.

    The registry is guarded by a re-entrant lock; every method that reads or
    writes it takes that lock, and methods that iterate work on a snapshot so
    machines may be registered or removed while a listing is in progress.
    """

    def __init__(self):
        self._machines: Dict[str, 'StateMachine'] = {}
        self._lock = threading.RLock()

    def create(
        self,
        name: str,
        persist_path: Path = None,
        auto_persist: bool = True
    ) -> 'StateMachine':
        """
        Create and register a new state machine.

        An existing machine registered under the same name is replaced.

        Args:
            name: Registry key and machine name.
            persist_path: Optional persistence file passed to the machine.
            auto_persist: Passed to the machine; defaults to True here.

        Returns:
            The newly created machine.
        """
        with self._lock:
            machine = StateMachine(name, persist_path, auto_persist)
            self._machines[name] = machine
            return machine

    def get(self, name: str) -> Optional['StateMachine']:
        """Get a state machine by name, or None if it is not registered."""
        with self._lock:
            return self._machines.get(name)

    def remove(self, name: str) -> bool:
        """Remove a state machine; returns True if it was registered."""
        with self._lock:
            if name in self._machines:
                del self._machines[name]
                return True
            return False

    def list_machines(self) -> List[str]:
        """List all machine names (a snapshot, in registration order)."""
        with self._lock:
            return list(self._machines)

    def get_all_status(self) -> Dict[str, Dict]:
        """
        Get status of all machines.

        The set of machines is snapshotted under the lock; each machine's
        ``get_status()`` is then called outside it, so a change to the
        registry during the call cannot break the iteration.
        """
        with self._lock:
            snapshot = list(self._machines.items())
        return {name: machine.get_status() for name, machine in snapshot}


# Module-wide engine instance; stays None until get_workflow_engine() is
# called for the first time.
_workflow_engine: Optional[WorkflowEngine] = None


def get_workflow_engine() -> WorkflowEngine:
    """Return the shared WorkflowEngine, creating it on first use."""
    global _workflow_engine
    engine = _workflow_engine
    if engine is None:
        engine = _workflow_engine = WorkflowEngine()
    return engine


def main():
    """CLI and demo for state machine.

    Parses one positional ``command`` argument and runs the matching demo,
    printing its progress to stdout:

    - ``demo``:  traffic light, builder pattern, guarded transitions and the
      ASCII visualization.
    - ``order``: an order workflow assembled with ``StateMachineBuilder``.
    - ``task``:  a task board workflow followed by a dump of its event log.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Genesis State Machine")
    parser.add_argument("command", choices=["demo", "order", "task"])
    args = parser.parse_args()

    if args.command == "demo":
        print("State Machine Demo")
        print("=" * 40)

        # Simple traffic light
        print("\n1. Traffic Light State Machine:")

        light = StateMachine("traffic_light")
        light.add_state("red", initial=True, on_enter=lambda m, s: print(f"    >> Entering {s.name}"))
        light.add_state("green", on_enter=lambda m, s: print(f"    >> Entering {s.name}"))
        light.add_state("yellow", on_enter=lambda m, s: print(f"    >> Entering {s.name}"))

        # The same "timer" event is registered from every state, so it cycles
        # red -> green -> yellow -> red.
        light.add_transition("red", "green", "timer")
        light.add_transition("green", "yellow", "timer")
        light.add_transition("yellow", "red", "timer")

        print(f"  Initial state: {light.current_state}")
        for _ in range(4):
            light.trigger("timer")
            print(f"  Current state: {light.current_state}")

        # Builder pattern
        print("\n2. Builder Pattern:")

        # Each .transition() applies to the state added by the preceding .state().
        door = (StateMachineBuilder("door")
            .state("closed", initial=True)
            .transition("open", "open")
            .state("open")
            .transition("close", "closed")
            .transition("lock", "locked")
            .state("locked")
            .transition("unlock", "closed")
            .build())

        print(f"  Initial: {door.current_state}")
        door.trigger("open")
        print(f"  After 'open': {door.current_state}")
        door.trigger("lock")
        print(f"  After 'lock': {door.current_state}")
        door.trigger("unlock")
        print(f"  After 'unlock': {door.current_state}")

        # Guards
        print("\n3. Guarded Transitions:")

        counter = StateMachine("counter")
        counter.set_context("count", 0)
        counter.add_state("counting", initial=True)
        counter.add_state("done", final=True)

        def increment(machine, data):
            # Transition action: bump the counter kept in the machine context.
            machine.set_context("count", machine.get_context("count") + 1)

        def check_limit(machine, data):
            # Guard: allow the self-transition only while count is below 3.
            return machine.get_context("count") < 3

        counter.add_transition("counting", "counting", "increment",
                              guard=check_limit, action=increment)
        counter.add_transition("counting", "done", "finish")

        # Five attempts against a guard that admits three.
        for i in range(5):
            success = counter.trigger("increment")
            print(f"  Increment {i+1}: {'OK' if success else 'blocked'}, count={counter.get_context('count')}")

        counter.trigger("finish")
        print(f"  Final state: {counter.current_state}, is_final: {counter.is_final()}")

        # Visualization
        print("\n4. Visualization:")
        print(light.visualize())

    elif args.command == "order":
        print("Order Workflow Demo")
        print("=" * 40)

        order = (StateMachineBuilder("order")
            .state("created", initial=True)
            .transition("submit", "pending")
            .state("pending")
            .transition("pay", "paid")
            .transition("cancel", "cancelled")
            .state("paid")
            .transition("ship", "shipped")
            .transition("refund", "refunded")
            .state("shipped")
            .transition("deliver", "delivered")
            .state("delivered", final=True)
            .state("cancelled", final=True)
            .state("refunded", final=True)
            .build())

        # Print every successful transition as it happens.
        order.on_transition(lambda e: print(f"  [{e.event}] {e.source} -> {e.target}"))

        print(f"\nStarting state: {order.current_state}")
        print(f"Available events: {order.get_available_events()}")

        order.trigger("submit")
        order.trigger("pay")
        order.trigger("ship")
        order.trigger("deliver")

        print(f"\nFinal state: {order.current_state}")
        print(f"Is complete: {order.is_final()}")

        print("\n" + order.visualize())

    elif args.command == "task":
        print("Task State Machine Demo")
        print("=" * 40)

        task = StateMachine("task")
        task.add_state("todo", initial=True)
        task.add_state("in_progress")
        task.add_state("review")
        task.add_state("done", final=True)
        task.add_state("blocked")

        task.add_transition("todo", "in_progress", "start")
        task.add_transition("in_progress", "review", "submit")
        task.add_transition("in_progress", "blocked", "block")
        task.add_transition("blocked", "in_progress", "unblock")
        task.add_transition("review", "in_progress", "reject")
        task.add_transition("review", "done", "approve")

        print(task.visualize())

        print("\nSimulating workflow:")
        events = ["start", "block", "unblock", "submit", "reject", "submit", "approve"]
        for event in events:
            old = task.current_state
            task.trigger(event)
            print(f"  {old} --({event})--> {task.current_state}")

        print(f"\nEvent log:")
        for e in task.get_event_log():
            # Timestamps are ISO strings; the first 19 characters drop the
            # fractional seconds.
            print(f"  [{e.timestamp[:19]}] {e.event}: {e.source} -> {e.target}")


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
