#!/usr/bin/env python3
"""
GENESIS AUTONOMOUS QUEUE MANAGER
=================================
Self-managing task queue with intelligent prioritization.

Features:
    - Priority-based task scheduling
    - Dependency resolution
    - Automatic retry with backoff
    - Task aging (priority boost over time)
    - Load-based throttling
    - Work stealing between agents

Usage:
    queue = AutonomousQueue()
    queue.enqueue(task)
    queue.start_processing()
"""

"""
RULE 7 COMPLIANT: Uses Elestio PostgreSQL via genesis_db module.
"""
import json
import heapq
import threading
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Dict, List, Any, Optional, Callable, Set
import logging

# RULE 7: Use PostgreSQL via genesis_db (no sqlite3)
from core.genesis_db import connection, ensure_table

logger = logging.getLogger(__name__)


class TaskState(Enum):
    """Task execution states (stored as strings for persistence)."""
    PENDING = "pending"      # enqueued; dependencies not yet satisfied (or awaiting retry)
    READY = "ready"          # dependencies satisfied; waiting in the heap
    RUNNING = "running"      # dequeued and assigned to a worker
    COMPLETED = "completed"  # finished successfully (terminal)
    FAILED = "failed"        # retries exhausted (terminal)
    BLOCKED = "blocked"      # waiting on incomplete dependencies
    CANCELLED = "cancelled"  # cancelled before it started running (terminal)


class TaskPriority(Enum):
    """Priority levels; a LOWER numeric value means HIGHER scheduling priority
    (the queue is a min-heap on the priority score)."""
    CRITICAL = 0
    HIGH = 1
    NORMAL = 2
    LOW = 3
    BACKGROUND = 4


@dataclass(order=True)
class QueuedTask:
    """A task in the queue.

    Heap ordering compares only ``priority_score`` (lower runs first);
    every other field is excluded from comparison via ``compare=False``.
    Timestamps are ISO-8601 strings.
    """
    priority_score: float = field(compare=True)
    task_id: str = field(compare=False)
    title: str = field(compare=False)
    description: str = field(compare=False)
    priority: TaskPriority = field(compare=False, default=TaskPriority.NORMAL)
    state: TaskState = field(compare=False, default=TaskState.PENDING)
    created_at: str = field(compare=False, default="")
    scheduled_at: Optional[str] = field(compare=False, default=None)  # earliest retry time
    started_at: Optional[str] = field(compare=False, default=None)
    completed_at: Optional[str] = field(compare=False, default=None)
    dependencies: List[str] = field(compare=False, default_factory=list)  # task_ids that must complete first
    assigned_agent: Optional[str] = field(compare=False, default=None)
    retry_count: int = field(compare=False, default=0)
    max_retries: int = field(compare=False, default=3)
    timeout_seconds: int = field(compare=False, default=300)
    metadata: Dict[str, Any] = field(compare=False, default_factory=dict)
    result: Any = field(compare=False, default=None)
    error: Optional[str] = field(compare=False, default=None)

    def __post_init__(self):
        # Fill in generated defaults for fields the caller left blank.
        if not self.task_id:
            self.task_id = f"task_{uuid.uuid4().hex[:8]}"
        if not self.created_at:
            self.created_at = datetime.now().isoformat()
        if self.priority_score == 0:
            # No explicit score supplied: derive it from the priority level.
            # (For CRITICAL, priority.value is also 0, so this is a no-op.)
            self.priority_score = self.priority.value

    def to_dict(self) -> Dict:
        """Serialize the task for persistence / status output.

        Fix: previously ``timeout_seconds`` and ``metadata`` were omitted,
        so a persisted task lost its handler type (``metadata["type"]``)
        and timeout on restore.  Only the heap-internal ``priority_score``
        is intentionally excluded.
        """
        return {
            "task_id": self.task_id,
            "title": self.title,
            "description": self.description,
            "priority": self.priority.value,
            "state": self.state.value,
            "created_at": self.created_at,
            "scheduled_at": self.scheduled_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "dependencies": self.dependencies,
            "assigned_agent": self.assigned_agent,
            "retry_count": self.retry_count,
            "max_retries": self.max_retries,
            "timeout_seconds": self.timeout_seconds,
            "metadata": self.metadata,
            "result": self.result,
            "error": self.error
        }


class DependencyResolver:
    """
    Tracks the task dependency graph and answers readiness queries.

    Forward edges map a task to the tasks it waits on; reverse edges map
    a task to the tasks waiting on it.
    """

    def __init__(self):
        self._graph: Dict[str, Set[str]] = defaultdict(set)          # task -> its dependencies
        self._reverse_graph: Dict[str, Set[str]] = defaultdict(set)  # task -> its dependents
        self._completed: Set[str] = set()

    def add_task(self, task_id: str, dependencies: List[str]):
        """Register a task together with the tasks it depends on."""
        for dependency in dependencies:
            self._graph[task_id].add(dependency)
            self._reverse_graph[dependency].add(task_id)

    def mark_completed(self, task_id: str):
        """Record that a task has finished."""
        self._completed.add(task_id)

    def is_ready(self, task_id: str) -> bool:
        """True when every dependency of ``task_id`` has completed."""
        return not (self._graph.get(task_id, set()) - self._completed)

    def get_ready_tasks(self, pending_tasks: List[str]) -> List[str]:
        """Filter ``pending_tasks`` down to those whose dependencies are met."""
        return [candidate for candidate in pending_tasks if self.is_ready(candidate)]

    def get_blocked_by(self, task_id: str) -> List[str]:
        """List the still-incomplete dependencies holding a task back."""
        blockers = []
        for dependency in self._graph.get(task_id, set()):
            if dependency not in self._completed:
                blockers.append(dependency)
        return blockers

    def get_dependents(self, task_id: str) -> List[str]:
        """List the tasks that wait on ``task_id``."""
        return list(self._reverse_graph.get(task_id, set()))

    def detect_cycle(self) -> Optional[List[str]]:
        """Return one dependency cycle as a task-id path, or None if acyclic."""
        seen: Set[str] = set()
        on_stack: Set[str] = set()
        trail: List[str] = []

        def visit(current: str) -> bool:
            # Depth-first walk; a neighbour already on the recursion stack
            # means we closed a loop.
            seen.add(current)
            on_stack.add(current)

            for neighbour in self._graph.get(current, set()):
                if neighbour not in seen:
                    if visit(neighbour):
                        trail.append(neighbour)
                        return True
                elif neighbour in on_stack:
                    trail.append(neighbour)
                    return True

            on_stack.remove(current)
            return False

        for root in list(self._graph.keys()):
            if root not in seen:
                if visit(root):
                    trail.append(root)
                    return trail[::-1]

        return None


class PriorityAger:
    """
    Gradually raises the effective priority of waiting tasks so that
    low-priority work cannot starve indefinitely.
    """

    def __init__(self, age_factor: float = 0.1, max_boost: float = 2.0):
        self.age_factor = age_factor  # priority units gained per minute of waiting
        self.max_boost = max_boost    # cap on the total aging boost

    def calculate_aged_priority(self, task: QueuedTask) -> float:
        """Return the task's priority score after applying the aging boost.

        Lower scores run sooner, so the boost is subtracted from the base
        priority value; the result is clamped at zero.
        """
        waited = datetime.now() - datetime.fromisoformat(task.created_at)
        minutes_waited = waited.total_seconds() / 60
        boost = min(self.age_factor * minutes_waited, self.max_boost)
        return max(0, task.priority.value - boost)


class RetryPolicy:
    """
    Exponential-backoff retry policy for failed tasks.
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0
    ):
        self.max_retries = max_retries
        self.base_delay = base_delay            # delay before the first retry (seconds)
        self.max_delay = max_delay              # ceiling on any single backoff delay
        self.exponential_base = exponential_base

    def should_retry(self, task: QueuedTask) -> bool:
        """Whether the task still has retry attempts left."""
        return task.retry_count < self.max_retries

    def get_delay(self, retry_count: int) -> float:
        """Backoff delay (seconds) before retry number ``retry_count``."""
        uncapped = self.base_delay * self.exponential_base ** retry_count
        return min(uncapped, self.max_delay)

    def prepare_retry(self, task: QueuedTask) -> QueuedTask:
        """Reset a failed task so it can run again after a backoff delay."""
        task.retry_count += 1
        task.state = TaskState.PENDING
        task.error = None
        wait = self.get_delay(task.retry_count)
        task.scheduled_at = (datetime.now() + timedelta(seconds=wait)).isoformat()
        return task


class AutonomousQueue:
    """
    Self-managing task queue with intelligent processing.

    Tasks live in a min-heap ordered by ``priority_score`` (lower runs
    first).  Dependencies gate entry into the heap, priorities age over
    time to prevent starvation, and failures are retried with exponential
    backoff.  Public methods are guarded by a single re-entrant lock.

    Fixes over the previous revision:
      * Retried tasks (left PENDING with a future ``scheduled_at`` by
        ``fail()``) were never re-added to the heap and thus never ran
        again; ``_requeue_due_retries()`` now re-admits them.
      * Cancelled READY tasks stayed in the heap and could still be
        dequeued and executed; ``dequeue()`` now discards them lazily.
    """

    def __init__(
        self,
        max_concurrent: int = 5,
        persist: bool = True
    ):
        self.max_concurrent = max_concurrent
        self.persist = persist

        self._heap: List[QueuedTask] = []          # READY tasks awaiting dequeue
        self._tasks: Dict[str, QueuedTask] = {}    # every known task, by id
        self._running: Dict[str, QueuedTask] = {}  # tasks currently executing
        self._lock = threading.RLock()
        self._condition = threading.Condition(self._lock)

        self.dependency_resolver = DependencyResolver()
        self.priority_ager = PriorityAger()
        self.retry_policy = RetryPolicy()

        self._processing = False
        self._processor_threads: List[threading.Thread] = []
        self._task_handlers: Dict[str, Callable] = {}

        # Persistence via PostgreSQL (RULE 7)
        if persist:
            self._init_db()

        # Lifetime counters, exposed through get_queue_status().
        self._metrics = {
            "enqueued": 0,
            "completed": 0,
            "failed": 0,
            "retried": 0
        }

    def _init_db(self):
        """Initialize persistence database via PostgreSQL (RULE 7)."""
        ensure_table('queue_tasks', '''
            task_id TEXT PRIMARY KEY,
            data JSONB NOT NULL,
            state TEXT NOT NULL,
            created_at TIMESTAMPTZ NOT NULL,
            updated_at TIMESTAMPTZ NOT NULL
        ''')

    def _persist_task(self, task: QueuedTask):
        """Upsert task state to PostgreSQL (RULE 7); best-effort, never raises."""
        if not self.persist:
            return
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO queue_tasks (task_id, data, state, created_at, updated_at)
                    VALUES (%s, %s, %s, %s, %s)
                    ON CONFLICT (task_id) DO UPDATE SET
                        data = EXCLUDED.data,
                        state = EXCLUDED.state,
                        updated_at = EXCLUDED.updated_at
                """, (
                    task.task_id,
                    json.dumps(task.to_dict()),
                    task.state.value,
                    task.created_at,
                    datetime.now().isoformat()
                ))
        except Exception as e:
            # Persistence is advisory; the in-memory queue remains authoritative.
            logger.warning(f"Failed to persist task: {e}")

    def register_handler(self, task_type: str, handler: Callable):
        """Register a handler for a task type (matched on metadata["type"])."""
        self._task_handlers[task_type] = handler

    def enqueue(
        self,
        title: str,
        description: str = "",
        priority: TaskPriority = TaskPriority.NORMAL,
        dependencies: List[str] = None,
        metadata: Dict = None,
        task_id: str = None
    ) -> str:
        """
        Add task to queue.

        Tasks whose dependencies are already satisfied go straight onto the
        heap as READY; others stay PENDING until unblocked.

        Returns:
            Task ID
        """
        with self._lock:
            task = QueuedTask(
                priority_score=priority.value,
                task_id=task_id or f"task_{uuid.uuid4().hex[:8]}",
                title=title,
                description=description,
                priority=priority,
                state=TaskState.PENDING,
                dependencies=dependencies or [],
                metadata=metadata or {}
            )

            self._tasks[task.task_id] = task
            self.dependency_resolver.add_task(task.task_id, task.dependencies)

            if self.dependency_resolver.is_ready(task.task_id):
                task.state = TaskState.READY
                heapq.heappush(self._heap, task)

            self._persist_task(task)
            self._metrics["enqueued"] += 1

            self._condition.notify()

            return task.task_id

    def dequeue(self, agent_id: str = None) -> Optional[QueuedTask]:
        """
        Get next task from queue.

        Waits (in 1s intervals, while processing is active) for work to
        appear.  Returns None when the queue is empty/stopped or the
        concurrency limit is reached.
        """
        with self._lock:
            # Re-admit retried tasks whose backoff delay has expired.
            self._requeue_due_retries()

            # Wait if queue is empty
            while not self._heap and self._processing:
                self._condition.wait(timeout=1.0)
                self._requeue_due_retries()
                self._update_priorities()

            # Check concurrent limit
            if len(self._running) >= self.max_concurrent:
                return None

            # Pop until a runnable task is found: cancelled tasks remain in
            # the heap (no efficient removal) and are discarded lazily here.
            while self._heap:
                task = heapq.heappop(self._heap)
                if task.state != TaskState.READY:
                    continue

                task.state = TaskState.RUNNING
                task.started_at = datetime.now().isoformat()
                task.assigned_agent = agent_id

                self._running[task.task_id] = task
                self._persist_task(task)

                return task

            return None

    def _requeue_due_retries(self):
        """Re-admit retried tasks whose backoff delay has elapsed.

        ``fail()`` leaves a retried task PENDING with a future
        ``scheduled_at``; once that time passes (and its dependencies are
        still satisfied) the task goes back onto the heap as READY.
        Caller must hold ``self._lock``.
        """
        now = datetime.now().isoformat()
        for task in self._tasks.values():
            if (task.state == TaskState.PENDING
                    and task.scheduled_at
                    and task.scheduled_at <= now
                    and self.dependency_resolver.is_ready(task.task_id)):
                task.scheduled_at = None
                task.state = TaskState.READY
                heapq.heappush(self._heap, task)
                self._persist_task(task)

    def complete(self, task_id: str, result: Any = None):
        """Mark a running task as completed and unblock its dependents."""
        with self._lock:
            if task_id not in self._running:
                return

            task = self._running.pop(task_id)
            task.state = TaskState.COMPLETED
            task.completed_at = datetime.now().isoformat()
            task.result = result

            self._tasks[task_id] = task
            self.dependency_resolver.mark_completed(task_id)

            # Check if any blocked tasks can now run
            self._unblock_dependents(task_id)

            self._persist_task(task)
            self._metrics["completed"] += 1

            self._condition.notify()

    def fail(self, task_id: str, error: str):
        """Mark a running task as failed; schedule a retry if attempts remain."""
        with self._lock:
            if task_id not in self._running:
                return

            task = self._running.pop(task_id)
            task.error = error

            if self.retry_policy.should_retry(task):
                task = self.retry_policy.prepare_retry(task)
                self._tasks[task_id] = task
                self._metrics["retried"] += 1
                # _requeue_due_retries() re-adds it to the heap once
                # task.scheduled_at has passed.
            else:
                task.state = TaskState.FAILED
                task.completed_at = datetime.now().isoformat()
                self._tasks[task_id] = task
                self._metrics["failed"] += 1

            self._persist_task(task)
            self._condition.notify()

    def cancel(self, task_id: str) -> bool:
        """Cancel a task that has not started running.

        A cancelled task that is already on the heap is discarded lazily
        by dequeue(); it will never execute.
        """
        with self._lock:
            if task_id in self._running:
                return False  # Can't cancel running task

            if task_id in self._tasks:
                task = self._tasks[task_id]
                task.state = TaskState.CANCELLED
                self._tasks[task_id] = task
                self._persist_task(task)
                return True

            return False

    def _unblock_dependents(self, completed_task_id: str):
        """Promote waiting dependents of a completed task to READY."""
        dependents = self.dependency_resolver.get_dependents(completed_task_id)

        for dep_id in dependents:
            if dep_id in self._tasks:
                task = self._tasks[dep_id]
                # Check PENDING or BLOCKED - both indicate waiting for dependencies
                if task.state in (TaskState.BLOCKED, TaskState.PENDING):
                    if self.dependency_resolver.is_ready(dep_id):
                        task.state = TaskState.READY
                        heapq.heappush(self._heap, task)
                        self._persist_task(task)

    def _update_priorities(self):
        """Recompute aged priority scores and rebuild the heap.

        Caller must hold ``self._lock``.
        """
        if not self._heap:
            return

        updated = []
        for task in self._heap:
            task.priority_score = self.priority_ager.calculate_aged_priority(task)
            updated.append(task)

        heapq.heapify(updated)
        self._heap = updated

    def get_task(self, task_id: str) -> Optional[QueuedTask]:
        """Get task by ID."""
        with self._lock:
            return self._tasks.get(task_id)

    def get_queue_status(self) -> Dict:
        """Snapshot of queue counts, broken down by state and priority."""
        with self._lock:
            by_state = defaultdict(int)
            by_priority = defaultdict(int)

            for task in self._tasks.values():
                by_state[task.state.value] += 1
                by_priority[task.priority.value] += 1

            return {
                "total_tasks": len(self._tasks),
                "queued": len(self._heap),
                "running": len(self._running),
                "by_state": dict(by_state),
                "by_priority": dict(by_priority),
                "metrics": self._metrics.copy()
            }

    def get_pending_tasks(self, limit: int = 50) -> List[Dict]:
        """Pending/ready tasks as dicts, best priority first, capped at limit."""
        with self._lock:
            pending = [
                t for t in self._tasks.values()
                if t.state in (TaskState.PENDING, TaskState.READY)
            ]
            pending.sort(key=lambda t: t.priority_score)
            return [t.to_dict() for t in pending[:limit]]

    def start_processing(self, num_workers: int = None):
        """Start daemon worker threads (defaults to max_concurrent workers)."""
        if self._processing:
            return

        self._processing = True
        num_workers = num_workers or self.max_concurrent

        for i in range(num_workers):
            thread = threading.Thread(
                target=self._worker_loop,
                args=(f"worker-{i}",),
                daemon=True
            )
            thread.start()
            self._processor_threads.append(thread)

    def stop_processing(self):
        """Signal workers to stop and join them (5s timeout each)."""
        self._processing = False
        with self._lock:
            self._condition.notify_all()

        for thread in self._processor_threads:
            thread.join(timeout=5)
        self._processor_threads.clear()

    def _worker_loop(self, worker_id: str):
        """Worker thread loop: dequeue, dispatch to handler, record outcome."""
        while self._processing:
            task = self.dequeue(agent_id=worker_id)

            if not task:
                time.sleep(0.5)
                continue

            try:
                # Dispatch on metadata["type"]; unregistered types complete
                # immediately with a marker result.
                task_type = task.metadata.get("type", "default")
                handler = self._task_handlers.get(task_type)

                if handler:
                    result = handler(task)
                    self.complete(task.task_id, result)
                else:
                    # No handler - just complete
                    self.complete(task.task_id, {"status": "no_handler"})

            except Exception as e:
                self.fail(task.task_id, str(e))

    def batch_enqueue(self, tasks: List[Dict]) -> List[str]:
        """Enqueue multiple tasks from spec dicts; returns their IDs in order."""
        task_ids = []
        for task_spec in tasks:
            task_id = self.enqueue(
                title=task_spec.get("title", ""),
                description=task_spec.get("description", ""),
                priority=TaskPriority(task_spec.get("priority", 2)),
                dependencies=task_spec.get("dependencies", []),
                metadata=task_spec.get("metadata", {})
            )
            task_ids.append(task_id)
        return task_ids

    def clear_completed(self):
        """Remove completed and cancelled tasks from memory."""
        with self._lock:
            to_remove = [
                t_id for t_id, t in self._tasks.items()
                if t.state in (TaskState.COMPLETED, TaskState.CANCELLED)
            ]
            for t_id in to_remove:
                del self._tasks[t_id]


# Process-wide singleton, created lazily by get_queue().
_queue: Optional[AutonomousQueue] = None


def get_queue() -> AutonomousQueue:
    """Return the global queue, creating it on first use."""
    global _queue
    if _queue is None:
        _queue = AutonomousQueue()
    return _queue


def main():
    """CLI entry point for the autonomous queue.

    Commands: add (enqueue one task), status (JSON snapshot), list
    (pending tasks), demo (self-driving demonstration), start (run
    workers until interrupted).
    """
    import argparse
    parser = argparse.ArgumentParser(description="Genesis Autonomous Queue")
    parser.add_argument("command", choices=["add", "status", "list", "demo", "start"])
    parser.add_argument("--title", help="Task title")
    parser.add_argument("--priority", choices=["critical", "high", "normal", "low", "background"], default="normal")
    parser.add_argument("--workers", type=int, default=3)
    args = parser.parse_args()

    queue = AutonomousQueue()
    command = args.command

    if command == "add":
        if not args.title:
            print("--title required")
            return

        # CLI choices are the lower-cased enum member names.
        chosen_priority = TaskPriority[args.priority.upper()]
        new_id = queue.enqueue(title=args.title, priority=chosen_priority)
        print(f"Enqueued: {new_id}")

    elif command == "status":
        print(json.dumps(queue.get_queue_status(), indent=2))

    elif command == "list":
        pending = queue.get_pending_tasks()
        print(f"Pending tasks ({len(pending)}):")
        for entry in pending:
            print(f"  [{entry['priority']}] {entry['task_id']}: {entry['title']}")

    elif command == "demo":
        print("Queue Demo")
        print("=" * 40)

        # A small dependency chain plus one background task.
        setup_id = queue.enqueue("Setup database", priority=TaskPriority.HIGH)
        tables_id = queue.enqueue("Create tables", dependencies=[setup_id])
        load_id = queue.enqueue("Load data", dependencies=[tables_id])
        queue.enqueue("Run tests", dependencies=[load_id], priority=TaskPriority.NORMAL)
        queue.enqueue("Background cleanup", priority=TaskPriority.BACKGROUND)

        print(f"Added {queue.get_queue_status()['total_tasks']} tasks")

        def handler(task):
            print(f"  Executing: {task.title}")
            time.sleep(0.5)
            return {"status": "done"}

        queue.register_handler("default", handler)
        queue.start_processing(num_workers=2)

        time.sleep(5)
        queue.stop_processing()

        print(f"\nFinal status: {json.dumps(queue.get_queue_status(), indent=2)}")

    elif command == "start":
        print(f"Starting queue processor with {args.workers} workers...")
        queue.start_processing(num_workers=args.workers)

        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nStopping...")
            queue.stop_processing()


if __name__ == "__main__":
    main()
