"""
AIVA Task Dispatcher - Intelligent ROI-Aware Task Routing
============================================================

Higher-level task management on top of the SwarmLiaison.
Routes tasks to the optimal worker based on complexity vs. cost,
manages task lifecycle with retries, and stores audit trails in PostgreSQL.

Routing Logic (ROI-aware):
    Simple tasks       -> Gemini Flash  (cheapest, fastest)
    Medium complexity  -> Gemini Pro    (balanced)
    Complex reasoning  -> Claude Code   (most capable, most expensive)

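Task Lifecycle:
    PENDING -> DISPATCHED -> IN_PROGRESS -> COMPLETED / FAILED / TIMEOUT
    (while a failed task is being re-dispatched, its row is marked RETRYING)

Persistence:
    Task state changes are logged to PostgreSQL (aiva_swarm_tasks table)
    for audit trail and analytics. Logging is best-effort: if PostgreSQL is
    unavailable or a write fails, a warning is logged and dispatching continues.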

NO SQLITE. All storage uses Elestio PostgreSQL/Redis.

VERIFICATION_STAMP
Story: AIVA-SWARM-002
Verified By: Claude Opus 4.6
Verified At: 2026-02-11
Component: Task Dispatcher (ROI-aware routing + lifecycle management)
"""

import json
import logging
import re
import sys
import time
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

# Elestio config
sys.path.append('/mnt/e/genesis-system/data/genesis-memory')
from elestio_config import PostgresConfig

from .swarm_liaison import (
    SwarmLiaison,
    SwarmTask,
    SwarmResult,
    TaskStatus,
    TaskPriority,
    WorkerType,
    get_swarm_liaison
)

# Module logger: every dispatcher diagnostic (routing decisions, retries and
# non-fatal PostgreSQL failures) is emitted through this named logger.
logger = logging.getLogger("AIVA.TaskDispatcher")


# =============================================================================
# COMPLEXITY CLASSIFICATION
# =============================================================================

class TaskComplexity(Enum):
    """Complexity tiers used to choose which worker type handles a task."""

    # Quick lookups, formatting, simple code
    SIMPLE = "simple"
    # Multi-step tasks, moderate reasoning
    MEDIUM = "medium"
    # Architecture, debugging, deep reasoning
    COMPLEX = "complex"


# Lowercase keyword lists per complexity tier, consumed by
# TaskDispatcher.classify_complexity(). Tiers are checked from COMPLEX down
# to SIMPLE; a description with no matching keyword is treated as MEDIUM.
COMPLEXITY_KEYWORDS = {
    # Architecture, debugging, deep reasoning
    TaskComplexity.COMPLEX: [
        "architect", "design", "refactor", "debug",
        "security", "optimize", "strategy", "patent",
        "complex", "analyze", "investigate", "root cause",
        "performance", "concurrent", "distributed", "migration",
    ],
    # Multi-step build / integration work
    TaskComplexity.MEDIUM: [
        "implement", "create", "build", "test", "integrate",
        "api", "endpoint", "database", "query", "dashboard",
        "configure", "deploy", "script",
    ],
    # Lookups and other one-shot requests
    TaskComplexity.SIMPLE: [
        "list", "check", "status", "format", "count", "summarize",
        "lookup", "fetch", "ping", "verify", "simple", "quick",
    ],
}

# Approximate cost per 1K tokens, keyed by WorkerType string value.
# NOTE(review): not read by the routing code in this module (routing uses
# COMPLEXITY_ROUTING); presumably kept for reporting/analytics — confirm.
WORKER_COST_MAP = {
    worker.value: cost_per_1k
    for worker, cost_per_1k in (
        (WorkerType.GEMINI_FLASH, 0.0001),  # cheapest
        (WorkerType.GEMINI_PRO, 0.001),     # 10x Flash
        (WorkerType.CLAUDE_CODE, 0.015),    # most expensive
        (WorkerType.GENERIC, 0.001),        # default
    )
}

# Preferred worker type (string value) for each complexity tier.
COMPLEXITY_ROUTING = {
    tier: worker.value
    for tier, worker in (
        (TaskComplexity.SIMPLE, WorkerType.GEMINI_FLASH),
        (TaskComplexity.MEDIUM, WorkerType.GEMINI_PRO),
        (TaskComplexity.COMPLEX, WorkerType.CLAUDE_CODE),
    )
}


# =============================================================================
# TASK DISPATCHER
# =============================================================================

class TaskDispatcher:
    """
    Intelligent task routing and lifecycle management.

    Features:
    - ROI-aware routing: each complexity tier maps to a preferred worker type
      (COMPLEXITY_ROUTING); if that type has no registered workers, the
      cheapest type that does is used instead
    - Priority is passed through to the liaison (1=HIGH, 5=MEDIUM, 10=LOW);
      queue ordering itself is up to the SwarmLiaison
    - Retry with exponential backoff (max 2 retries) in await_result()
    - Batch dispatch and batch result collection
    - PostgreSQL audit trail for task state changes. Persistence is
      best-effort: database errors are logged as warnings and never reach
      callers.
    """

    MAX_RETRIES = 2
    INITIAL_RETRY_DELAY = 5    # seconds
    RETRY_BACKOFF_FACTOR = 2   # exponential

    # Statuses after which a task is finished; used to stamp completed_at.
    # Derived from TaskStatus so the SQL cannot drift from the enum values.
    _TERMINAL_STATUSES = frozenset({
        TaskStatus.COMPLETED.value,
        TaskStatus.FAILED.value,
        TaskStatus.TIMEOUT.value,
    })

    # Interim status written to the original task's row while a retry runs.
    _RETRYING_STATUS = "RETRYING"

    def __init__(
        self,
        liaison: Optional[SwarmLiaison] = None,
    ):
        """
        Initialize the Task Dispatcher.

        Args:
            liaison: SwarmLiaison instance (uses singleton if not provided)
        """
        self.liaison = liaison or get_swarm_liaison()
        self._pg_conn = None
        self._init_postgres()

    def _init_postgres(self):
        """
        Open the PostgreSQL connection and make sure the audit table exists.

        Any failure, including psycopg2 not being importable, is non-fatal.
        The dispatcher keeps working without an audit trail, and every
        PostgreSQL helper becomes a no-op because ``self._pg_conn`` is None.
        """
        try:
            import psycopg2
            params = PostgresConfig.get_connection_params()
            self._pg_conn = psycopg2.connect(**params, connect_timeout=10)
            # Every statement here is independent, so autocommit avoids
            # leaving the session in an aborted transaction after one
            # failed write.
            self._pg_conn.autocommit = True
            self._ensure_tables()
            logger.info("TaskDispatcher PostgreSQL connected")
        except Exception as e:
            # Don't leak a half-initialised connection.
            self._close_pg()
            logger.warning(f"TaskDispatcher PostgreSQL unavailable (non-fatal): {e}")

    def _close_pg(self):
        """Close the PostgreSQL connection, if any, and mark it as gone."""
        conn, self._pg_conn = self._pg_conn, None
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:
            logger.debug(f"Ignoring error while closing PostgreSQL connection: {e}")

    def _ensure_tables(self):
        """
        Create the ``aiva_swarm_tasks`` table and its status index if missing.

        ``task_id`` is declared UNIQUE, so PostgreSQL already maintains a
        unique index on it. A second, separate task_id index would only
        duplicate that index and slow down every write, so none is created.
        """
        if not self._pg_conn:
            return

        try:
            with self._pg_conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS aiva_swarm_tasks (
                        id SERIAL PRIMARY KEY,
                        task_id VARCHAR(64) UNIQUE NOT NULL,
                        description TEXT NOT NULL,
                        worker_type VARCHAR(32) NOT NULL,
                        complexity VARCHAR(16),
                        priority INTEGER DEFAULT 5,
                        status VARCHAR(20) DEFAULT 'PENDING',
                        retry_count INTEGER DEFAULT 0,
                        result_summary TEXT,
                        error_message TEXT,
                        tokens_used INTEGER DEFAULT 0,
                        duration_ms FLOAT DEFAULT 0,
                        dispatched_at TIMESTAMP,
                        completed_at TIMESTAMP,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        metadata JSONB DEFAULT '{}'
                    )
                """)
                # Index for status lookups (analytics GROUP BY / filters).
                cur.execute("""
                    CREATE INDEX IF NOT EXISTS idx_swarm_tasks_status
                    ON aiva_swarm_tasks(status)
                """)
            logger.info("TaskDispatcher tables verified")
        except Exception as e:
            logger.warning(f"Table creation warning: {e}")

    # =========================================================================
    # COMPLEXITY CLASSIFICATION
    # =========================================================================

    def classify_complexity(self, description: str) -> TaskComplexity:
        """
        Classify task complexity based on description keywords.

        Tiers are checked from most to least complex, and the first tier with
        a matching keyword wins, so "quick security review" is COMPLEX.
        A keyword only matches at the start of a word. "debug" matches
        "debugging", but "ping" does not match inside "mapping", and "api"
        does not match inside "rapid".

        Args:
            description: Task description (None or empty is treated as "")

        Returns:
            TaskComplexity level; MEDIUM when no keyword matches
        """
        desc_lower = (description or "").lower()

        for complexity in (TaskComplexity.COMPLEX, TaskComplexity.MEDIUM, TaskComplexity.SIMPLE):
            keywords = COMPLEXITY_KEYWORDS[complexity]
            # \b anchors each keyword at a word start. Trailing letters are
            # still allowed, so stems like "architect" keep matching
            # inflected forms such as "architecture".
            pattern = r"\b(?:" + "|".join(map(re.escape, keywords)) + r")"
            if re.search(pattern, desc_lower):
                return complexity

        # Default to medium if no keywords match
        return TaskComplexity.MEDIUM

    def select_worker_type(
        self,
        complexity: TaskComplexity,
        preferred_type: Optional[str] = None
    ) -> str:
        """
        Select the optimal worker type based on complexity and availability.

        If the optimal type has no registered workers, this falls back to the
        cheapest type that does. That can route a COMPLEX task to a cheaper
        model when the optimal pool is empty.

        Args:
            complexity: Task complexity
            preferred_type: Override the automatic selection

        Returns:
            Worker type string
        """
        if preferred_type:
            return preferred_type

        optimal = COMPLEXITY_ROUTING.get(complexity, WorkerType.GENERIC.value)

        if self.liaison.get_worker_count(optimal) > 0:
            return optimal

        # Fall back: try each type from cheapest to most expensive
        fallback_order = (
            WorkerType.GEMINI_FLASH.value,
            WorkerType.GEMINI_PRO.value,
            WorkerType.CLAUDE_CODE.value,
            WorkerType.GENERIC.value,
        )

        for wtype in fallback_order:
            if wtype == optimal:
                continue  # already found empty above
            if self.liaison.get_worker_count(wtype) > 0:
                logger.info(
                    f"Worker fallback: {optimal} unavailable, using {wtype}"
                )
                return wtype

        # No workers available, dispatch to optimal queue anyway
        # (worker may come online later)
        logger.warning(f"No workers available, queuing for {optimal}")
        return optimal

    # =========================================================================
    # DISPATCH
    # =========================================================================

    @staticmethod
    def _parse_complexity(complexity: Any) -> TaskComplexity:
        """
        Convert a complexity override into a TaskComplexity.

        Accepts a TaskComplexity member or a string, matched case-insensitively
        and ignoring surrounding whitespace.

        Raises:
            ValueError: if the value names no known complexity level
        """
        if isinstance(complexity, TaskComplexity):
            return complexity
        try:
            return TaskComplexity(str(complexity).strip().lower())
        except ValueError:
            valid = ", ".join(level.value for level in TaskComplexity)
            raise ValueError(
                f"Invalid complexity {complexity!r}; expected one of: {valid}"
            ) from None

    def dispatch(
        self,
        description: str,
        complexity: Optional[str] = None,
        worker_type: Optional[str] = None,
        priority: int = 5,
        timeout: int = 300,
        context: str = "",
        metadata: Optional[Dict] = None
    ) -> Optional[SwarmTask]:
        """
        Dispatch a task with intelligent routing.

        Classifies complexity, selects a worker, dispatches via the liaison,
        and logs the task to PostgreSQL.

        Args:
            description: What needs to be done
            complexity: Override complexity ("simple", "medium", "complex";
                case-insensitive) or a TaskComplexity member
            worker_type: Override worker type
            priority: 1=HIGH, 5=MEDIUM, 10=LOW
            timeout: Max seconds for completion
            context: Additional context
            metadata: Optional metadata (copied, never mutated)

        Returns:
            SwarmTask if dispatched, None if the liaison refused the task

        Raises:
            ValueError: if ``complexity`` is not a known level
        """
        if complexity:
            task_complexity = self._parse_complexity(complexity)
        else:
            task_complexity = self.classify_complexity(description)

        selected_worker = self.select_worker_type(task_complexity, worker_type)

        task = self.liaison.dispatch_task(
            task_description=description,
            worker_type=selected_worker,
            priority=priority,
            timeout=timeout,
            context=context,
            metadata={
                **(metadata or {}),
                "complexity": task_complexity.value,
                "routed_by": "task_dispatcher"
            }
        )

        if task:
            self._log_task_to_pg(task, task_complexity)
            logger.info(
                f"Dispatched: {task.task_id} -> {selected_worker} "
                f"(complexity={task_complexity.value}, priority={priority})"
            )

        return task

    def dispatch_batch(
        self,
        tasks: List[Dict]
    ) -> List[SwarmTask]:
        """
        Dispatch multiple tasks with per-task routing.

        Each task dict should have:
          - description (required; specs without one are skipped)
          - priority (optional, default 5)
          - timeout (optional, default 300)
          - context (optional)
          - complexity (optional, auto-detected)
          - worker_type (optional, auto-selected)
          - metadata (optional)

        A spec rejected with ValueError (e.g. an unknown complexity) is logged
        and skipped. It does not abort the batch, so tasks that were already
        dispatched are still returned to the caller.

        Args:
            tasks: List of task specification dicts

        Returns:
            List of dispatched SwarmTask objects
        """
        dispatched: List[SwarmTask] = []

        for index, spec in enumerate(tasks):
            description = spec.get("description")
            if not description:
                logger.warning(f"Batch spec #{index} has no description; skipped")
                continue
            try:
                task = self.dispatch(
                    description=description,
                    complexity=spec.get("complexity"),
                    worker_type=spec.get("worker_type"),
                    priority=spec.get("priority", 5),
                    timeout=spec.get("timeout", 300),
                    context=spec.get("context", ""),
                    metadata=spec.get("metadata")
                )
            except ValueError as e:
                logger.warning(f"Batch spec #{index} rejected: {e}")
                continue
            if task:
                dispatched.append(task)

        logger.info(f"Batch dispatched: {len(dispatched)}/{len(tasks)} tasks")
        return dispatched

    # =========================================================================
    # RESULT COLLECTION WITH RETRY
    # =========================================================================

    @staticmethod
    def _is_completed(result: Optional[SwarmResult]) -> bool:
        """True when a result exists and reports COMPLETED."""
        return result is not None and result.status == TaskStatus.COMPLETED.value

    @staticmethod
    def _describe_outcome(result: Optional[SwarmResult]) -> str:
        """Short human-readable outcome of one attempt, for error messages."""
        if result is None:
            return TaskStatus.TIMEOUT.value
        if result.error:
            return f"{result.status} ({result.error})"
        return str(result.status)

    def _record_completion(self, task_id: str, result: SwarmResult):
        """Mark ``task_id`` COMPLETED in PostgreSQL with a summary of ``result``."""
        self._update_task_status(
            task_id,
            TaskStatus.COMPLETED.value,
            # Summary is capped at 500 characters to keep the audit row small.
            result_summary=str(result.result)[:500] if result.result else None,
            tokens_used=result.tokens_used,
            duration_ms=result.duration_ms
        )

    def await_result(
        self,
        task: SwarmTask,
        timeout: Optional[int] = None
    ) -> Optional[SwarmResult]:
        """
        Wait for a task result with automatic retry on failure.

        If the task fails or times out, it is re-dispatched up to MAX_RETRIES
        times with exponential backoff. Each retry is a new liaison task with
        ``retry_of`` and ``retry_count`` metadata. Audit rows are updated only
        for the ORIGINAL task_id. This call blocks (time.sleep) between
        attempts.

        Args:
            task: The dispatched SwarmTask
            timeout: Per-attempt collection timeout (defaults to the task's
                timeout; a falsy value also falls back to it)

        Returns:
            The COMPLETED SwarmResult on success. Otherwise the last result
            collected, which may be a failed result, or None if every attempt
            timed out.
        """
        effective_timeout = timeout or task.timeout_seconds

        result = self.liaison.collect_result(task.task_id, timeout=effective_timeout)
        if self._is_completed(result):
            self._record_completion(task.task_id, result)
            return result

        # Outcome of the most recent attempt, reported if all retries fail.
        last_outcome = self._describe_outcome(result)
        base_metadata = task.metadata or {}

        for retry_count in range(1, self.MAX_RETRIES + 1):
            delay = self.INITIAL_RETRY_DELAY * (self.RETRY_BACKOFF_FACTOR ** (retry_count - 1))

            logger.warning(
                f"Task {task.task_id} failed/timed out. "
                f"Retry {retry_count}/{self.MAX_RETRIES} after {delay}s"
            )

            time.sleep(delay)

            # Re-dispatch with same parameters
            retry_task = self.liaison.dispatch_task(
                task_description=task.description,
                worker_type=task.worker_type,
                priority=task.priority,
                timeout=task.timeout_seconds,
                context=task.context,
                metadata={**base_metadata, "retry_of": task.task_id, "retry_count": retry_count}
            )

            if not retry_task:
                last_outcome = "DISPATCH_FAILED"
                continue

            self._update_task_status(
                task.task_id,
                self._RETRYING_STATUS,
                error_message=f"Retry {retry_count}, new task: {retry_task.task_id}"
            )

            result = self.liaison.collect_result(
                retry_task.task_id,
                timeout=effective_timeout
            )

            if self._is_completed(result):
                self._record_completion(task.task_id, result)
                return result

            last_outcome = self._describe_outcome(result)

        # All retries exhausted
        error_msg = (
            f"Task failed after {self.MAX_RETRIES} retries. "
            f"Last status: {last_outcome}"
        )
        self._update_task_status(
            task.task_id,
            TaskStatus.FAILED.value,
            error_message=error_msg
        )
        logger.error(f"Task {task.task_id}: {error_msg}")

        return result  # Return last result even if failed

    def await_batch(
        self,
        tasks: List[SwarmTask],
        timeout: int = 600
    ) -> Dict[str, Optional[SwarmResult]]:
        """
        Collect results for a batch of dispatched tasks (no retries).

        Every requested task gets a final audit status. A task is marked
        TIMEOUT when the liaison returned no result for it, whether the entry
        was None or missing from the returned dict entirely.

        Args:
            tasks: List of dispatched SwarmTask objects
            timeout: Total timeout for all results

        Returns:
            Dict mapping task_id -> SwarmResult, exactly as returned by the
            liaison ({} for an empty batch)
        """
        if not tasks:
            return {}

        task_ids = [t.task_id for t in tasks]
        results = self.liaison.collect_results_batch(task_ids, timeout=timeout)

        # Walk the requested ids, not just the returned keys, so omitted
        # tasks are not left in DISPATCHED forever. dict.fromkeys de-dupes
        # while keeping order.
        for task_id in dict.fromkeys(task_ids):
            result = results.get(task_id)
            if self._is_completed(result):
                self._record_completion(task_id, result)
            elif result is not None:
                self._update_task_status(
                    task_id,
                    result.status,
                    error_message=result.error
                )
            else:
                self._update_task_status(
                    task_id,
                    TaskStatus.TIMEOUT.value,
                    error_message="Result collection timed out"
                )

        return results

    # =========================================================================
    # POSTGRESQL LOGGING
    # =========================================================================

    def _log_task_to_pg(self, task: SwarmTask, complexity: TaskComplexity):
        """
        Insert a newly dispatched task into ``aiva_swarm_tasks`` (best-effort).

        If the task_id already exists, only status and dispatched_at are
        refreshed. The description is truncated to 1000 characters.
        NOTE(review): assumes ``task.dispatched_at`` is a datetime (or
        something PostgreSQL accepts as TIMESTAMP) — confirm in SwarmTask.
        """
        if not self._pg_conn:
            return

        try:
            with self._pg_conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO aiva_swarm_tasks
                    (task_id, description, worker_type, complexity, priority,
                     status, dispatched_at, metadata)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (task_id) DO UPDATE SET
                        status = EXCLUDED.status,
                        dispatched_at = EXCLUDED.dispatched_at
                """, (
                    task.task_id,
                    task.description[:1000],
                    task.worker_type,
                    complexity.value,
                    task.priority,
                    task.status,
                    task.dispatched_at,
                    json.dumps(task.metadata or {})
                ))
        except Exception as e:
            logger.warning(f"PG log failed (non-fatal): {e}")

    def _update_task_status(
        self,
        task_id: str,
        status: str,
        result_summary: Optional[str] = None,
        error_message: Optional[str] = None,
        tokens_used: int = 0,
        duration_ms: float = 0
    ):
        """
        Update a task's audit row in PostgreSQL (best-effort).

        Existing result_summary/error_message are kept when None is passed.
        tokens_used and duration_ms only ever grow. completed_at is stamped
        when ``status`` is terminal (COMPLETED/FAILED/TIMEOUT).
        """
        if not self._pg_conn:
            return

        try:
            with self._pg_conn.cursor() as cur:
                cur.execute("""
                    UPDATE aiva_swarm_tasks
                    SET status = %s,
                        result_summary = COALESCE(%s, result_summary),
                        error_message = COALESCE(%s, error_message),
                        tokens_used = GREATEST(tokens_used, %s),
                        duration_ms = GREATEST(duration_ms, %s),
                        completed_at = CASE
                            WHEN %s THEN NOW()
                            ELSE completed_at
                        END
                    WHERE task_id = %s
                """, (
                    status,
                    result_summary,
                    error_message,
                    tokens_used,
                    duration_ms,
                    # Terminal check is done in Python against TaskStatus
                    # values and passed as a SQL boolean.
                    status in self._TERMINAL_STATUSES,
                    task_id
                ))
        except Exception as e:
            logger.warning(f"PG status update failed (non-fatal): {e}")

    # =========================================================================
    # ANALYTICS
    # =========================================================================

    def get_task_stats(self) -> Dict[str, Any]:
        """
        Get task dispatch statistics from PostgreSQL.

        Returns:
            Dict with counts by status, worker type, and complexity, plus the
            average duration and total tokens of COMPLETED tasks, or
            ``{"error": ...}`` when PostgreSQL is unavailable or a query fails
        """
        if not self._pg_conn:
            return {"error": "PostgreSQL not connected"}

        try:
            with self._pg_conn.cursor() as cur:
                # By status
                cur.execute("""
                    SELECT status, COUNT(*) as cnt
                    FROM aiva_swarm_tasks
                    GROUP BY status
                """)
                by_status = {row[0]: row[1] for row in cur.fetchall()}

                # By worker type
                cur.execute("""
                    SELECT worker_type, COUNT(*) as cnt
                    FROM aiva_swarm_tasks
                    GROUP BY worker_type
                """)
                by_worker = {row[0]: row[1] for row in cur.fetchall()}

                # By complexity
                cur.execute("""
                    SELECT complexity, COUNT(*) as cnt
                    FROM aiva_swarm_tasks
                    WHERE complexity IS NOT NULL
                    GROUP BY complexity
                """)
                by_complexity = {row[0]: row[1] for row in cur.fetchall()}

                # Average duration / token total for completed tasks
                cur.execute("""
                    SELECT AVG(duration_ms), SUM(tokens_used)
                    FROM aiva_swarm_tasks
                    WHERE status = %s
                """, (TaskStatus.COMPLETED.value,))
                avg_row = cur.fetchone()
                avg_duration = round(avg_row[0] or 0, 2)
                total_tokens = avg_row[1] or 0

                return {
                    "by_status": by_status,
                    "by_worker_type": by_worker,
                    "by_complexity": by_complexity,
                    "avg_duration_ms": avg_duration,
                    "total_tokens": total_tokens,
                    "timestamp": datetime.now().isoformat()
                }

        except Exception as e:
            logger.warning(f"PG stats query failed (non-fatal): {e}")
            return {"error": str(e)}

    def get_status(self) -> Dict[str, Any]:
        """
        Get dispatcher status including liaison health.

        Returns:
            Combined status dict
        """
        return {
            "dispatcher": {
                "postgres_connected": self._pg_conn is not None,
                "max_retries": self.MAX_RETRIES,
                "routing": {k.value: v for k, v in COMPLEXITY_ROUTING.items()}
            },
            "liaison": self.liaison.get_status(),
            "stats": self.get_task_stats()
        }

    # =========================================================================
    # CLEANUP
    # =========================================================================

    def close(self):
        """
        Close the PostgreSQL connection and the liaison.

        Afterwards ``get_status`` reports ``postgres_connected: False``, and the
        PostgreSQL helpers become no-ops instead of failing against a closed
        connection.
        NOTE(review): when the liaison came from get_swarm_liaison() it may be
        shared with other users; closing it here presumably affects them too.
        """
        self._close_pg()
        self.liaison.close()
        logger.info("TaskDispatcher closed")


# Lazily created, process-wide dispatcher returned by get_task_dispatcher().
_dispatcher_instance: Optional[TaskDispatcher] = None


def get_task_dispatcher() -> TaskDispatcher:
    """
    Return the shared TaskDispatcher, building it on first use.

    The instance is constructed with default arguments, so it uses the
    liaison from get_swarm_liaison(). It is then cached in
    ``_dispatcher_instance``. There is no lock around creation, and a
    dispatcher that was close()d is still returned on later calls.
    """
    global _dispatcher_instance
    if _dispatcher_instance is not None:
        return _dispatcher_instance
    _dispatcher_instance = TaskDispatcher()
    return _dispatcher_instance
