"""
PM-009: Tiered Executor Engine
Main execution loop with tier escalation for Genesis TRUE Method.

Acceptance Criteria:
- [x] GIVEN atomic task WHEN execute() THEN starts Tier 1
- [x] AND attempts 20x Gemini fresh sessions
- [x] AND if fail THEN escalates to Tier 2 (Sonnet)
- [x] AND if fail THEN escalates to Tier 3 (Opus)
- [x] AND if all fail THEN marks for human review

Dependencies: PM-007, PM-008
"""

import os
import json
import logging
from datetime import datetime
from typing import Optional, Dict, Any, List, Callable
from dataclasses import dataclass, field
from enum import Enum

from core.model_tier_loader import load_model_tiers, TierConfig
from core.learning_accumulator import get_learning_accumulator, LearningAccumulator
from core.gemini_session_spawner import GeminiSessionSpawner, SessionResult
from core.claude_session_spawner import ClaudeSessionSpawner, ClaudeSessionResult
from core.perf_metrics import get_metrics_collector, PerformanceMetricsCollector
from core.cost_tracker_v2 import get_cost_tracker, CostTracker

logger = logging.getLogger(__name__)


class ExecutionStatus(Enum):
    """Status of task execution.

    Lifecycle: PENDING -> IN_PROGRESS -> (SUCCESS | ESCALATED -> ... ->
    HUMAN_REVIEW). Values are the serialized strings used by
    ExecutionResult.to_dict().
    """
    PENDING = "pending"            # not yet started
    IN_PROGRESS = "in_progress"    # currently executing a tier
    SUCCESS = "success"            # a tier produced (validated) output
    ESCALATED = "escalated"        # a tier exhausted its attempts; moving up
    FAILED = "failed"              # terminal failure (not set by TieredExecutor.execute itself)
    HUMAN_REVIEW = "human_review"  # all tiers failed; queued for a human


@dataclass
class ExecutionResult:
    """Outcome of one tiered execution run for a single task.

    Bundles the terminal status with the path taken through the tiers
    (attempt counts), accumulated cost, timing, error details, and a
    count of lessons captured along the way.
    """
    task_id: str
    status: ExecutionStatus
    output: Optional[str] = None

    # Path taken through the tiers
    final_tier: int = 0
    total_attempts: int = 0
    tier_attempts: Dict[int, int] = field(default_factory=dict)

    # Accumulated cost (USD), overall and per tier
    total_cost: float = 0.0
    tier_costs: Dict[int, float] = field(default_factory=dict)

    # Wall-clock timing (ISO-8601 strings for the timestamps)
    total_duration_ms: int = 0
    started_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
    completed_at: Optional[str] = None

    # Error details from the last failure, if any
    error_type: Optional[str] = None
    error_message: Optional[str] = None

    # Number of lessons captured by the learning accumulator
    lessons_learned: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum flattened to its value)."""
        return dict(
            task_id=self.task_id,
            status=self.status.value,
            output=self.output,
            final_tier=self.final_tier,
            total_attempts=self.total_attempts,
            tier_attempts=self.tier_attempts,
            total_cost=self.total_cost,
            tier_costs=self.tier_costs,
            total_duration_ms=self.total_duration_ms,
            started_at=self.started_at,
            completed_at=self.completed_at,
            error_type=self.error_type,
            error_message=self.error_message,
            lessons_learned=self.lessons_learned,
        )


class TieredExecutor:
    """
    Main execution engine with tier escalation.

    Execution Flow:
    1. Start at Tier 1 (Gemini Flash) - 20 attempts
    2. If all fail, escalate to Tier 2 (Claude Sonnet) - 10 attempts
    3. If all fail, escalate to Tier 3 (Claude Opus) - 10 attempts
    4. If all fail, mark for human review

    Each attempt is a FRESH session with learnings from previous attempts.
    (Per-tier attempt counts come from the loaded tier configuration; the
    numbers above reflect the default config.)
    """

    def __init__(self,
                 gemini_spawner: Optional[GeminiSessionSpawner] = None,
                 claude_spawner: Optional[ClaudeSessionSpawner] = None,
                 learning_accumulator: Optional[LearningAccumulator] = None,
                 metrics_collector: Optional[PerformanceMetricsCollector] = None,
                 cost_tracker: Optional[CostTracker] = None,
                 validation_callback: Optional[Callable[[str, str], bool]] = None):
        """
        Initialize TieredExecutor.

        Args:
            gemini_spawner: GeminiSessionSpawner instance (Tier 1)
            claude_spawner: ClaudeSessionSpawner instance (Tiers 2-3)
            learning_accumulator: LearningAccumulator instance
            metrics_collector: PerformanceMetricsCollector instance
            cost_tracker: CostTracker instance
            validation_callback: Optional callback to validate output (task_id, output) -> bool
        """
        self.tier_config = load_model_tiers()
        self.gemini_spawner = gemini_spawner or GeminiSessionSpawner()
        self.claude_spawner = claude_spawner or ClaudeSessionSpawner()
        self.learning_accumulator = learning_accumulator or get_learning_accumulator()
        self.metrics_collector = metrics_collector or get_metrics_collector()
        # NOTE(review): cost_tracker is stored but never referenced by this
        # class; presumably the spawners use it internally — confirm, or wire
        # it into _execute_tier.
        self.cost_tracker = cost_tracker or get_cost_tracker()
        self.validation_callback = validation_callback

        # Tasks that exhausted all tiers and await a human decision.
        self.human_review_queue: List[Dict[str, Any]] = []

    def execute(self,
               task_id: str,
               task_description: str,
               additional_context: Optional[str] = None,
               start_tier: int = 1,
               skip_tiers: Optional[List[int]] = None) -> ExecutionResult:
        """
        Execute a task using the tiered approach.

        Args:
            task_id: Unique task identifier
            task_description: Description of the task to complete
            additional_context: Optional additional context
            start_tier: Starting tier (default 1)
            skip_tiers: Tiers to skip (for testing)

        Returns:
            ExecutionResult with outcome: SUCCESS if any tier produced
            (validated) output, otherwise HUMAN_REVIEW after all tiers
            are exhausted.
        """
        skip_tiers = skip_tiers or []
        # NOTE: datetime.utcnow() is deprecated in Python 3.12; kept here so
        # the naive-ISO timestamp format consumed downstream stays stable.
        start_time = datetime.utcnow()

        result = ExecutionResult(
            task_id=task_id,
            status=ExecutionStatus.IN_PROGRESS,
            started_at=start_time.isoformat()
        )

        logger.info(f"Starting tiered execution for task {task_id}")

        # Execute through tiers (hard-coded 3-tier ladder per module contract).
        for tier in range(start_tier, 4):  # Tiers 1, 2, 3
            if tier in skip_tiers:
                logger.info(f"Skipping tier {tier} for task {task_id}")
                continue

            tier_config = self.tier_config.get_tier(tier)
            if not tier_config:
                logger.warning(f"No configuration for tier {tier}")
                continue

            logger.info(f"Executing tier {tier} ({tier_config.name}) for task {task_id}")

            tier_result = self._execute_tier(
                task_id=task_id,
                task_description=task_description,
                tier=tier,
                tier_config=tier_config,
                additional_context=additional_context
            )

            # Update result tracking
            result.tier_attempts[tier] = tier_result["attempts"]
            result.tier_costs[tier] = tier_result["cost"]
            result.total_attempts += tier_result["attempts"]
            result.total_cost += tier_result["cost"]
            result.total_duration_ms += tier_result["duration_ms"]
            result.final_tier = tier

            if tier_result["success"]:
                result.status = ExecutionStatus.SUCCESS
                result.output = tier_result["output"]
                result.completed_at = datetime.utcnow().isoformat()
                # FIX: a success after an earlier escalation must not carry
                # stale error info from the failed tier.
                result.error_type = None
                result.error_message = None
                # FIX: lessons_learned was previously populated only on the
                # human-review path; report it for successful runs too.
                result.lessons_learned = len(self.learning_accumulator.get_lessons(task_id))

                logger.info(f"Task {task_id} succeeded at tier {tier} after "
                           f"{result.total_attempts} total attempts")

                # Log success metrics
                self._log_success(task_id, result)
                return result

            logger.info(f"Tier {tier} failed for task {task_id} after "
                       f"{tier_result['attempts']} attempts. Escalating...")

            result.status = ExecutionStatus.ESCALATED
            result.error_type = tier_result.get("error_type")
            result.error_message = tier_result.get("error_message")

        # All tiers failed (or were skipped/unconfigured) - mark for human
        # review. If no tier actually ran, error_type/error_message stay None.
        result.status = ExecutionStatus.HUMAN_REVIEW
        result.completed_at = datetime.utcnow().isoformat()
        result.lessons_learned = len(self.learning_accumulator.get_lessons(task_id))

        logger.warning(f"Task {task_id} failed all tiers. Marked for human review.")

        self._mark_for_human_review(task_id, task_description, result)

        return result

    def _execute_tier(self,
                     task_id: str,
                     task_description: str,
                     tier: int,
                     tier_config: TierConfig,
                     additional_context: Optional[str] = None) -> Dict[str, Any]:
        """
        Execute a single tier with all its attempts.

        Args:
            task_id: Task identifier
            task_description: Task description
            tier: Tier number (1 routes to Gemini, 2+ to Claude)
            tier_config: Tier configuration (supplies max_attempts)
            additional_context: Additional context

        Returns:
            Dict with success, output, attempts, cost, duration_ms
            (plus error_type / error_message on failure)
        """
        total_cost = 0.0
        total_duration_ms = 0
        last_error_type = None
        last_error_message = None

        for attempt in range(1, tier_config.max_attempts + 1):
            logger.debug(f"Tier {tier} attempt {attempt}/{tier_config.max_attempts} for {task_id}")

            # Spawn appropriate session based on tier: Tier 1 is Gemini,
            # everything above goes to a Claude session.
            if tier == 1:
                session_result = self.gemini_spawner.spawn_session(
                    task_id=task_id,
                    task_description=task_description,
                    attempt=attempt,
                    tier=tier,
                    additional_context=additional_context
                )
            else:
                session_result = self.claude_spawner.spawn_session(
                    task_id=task_id,
                    task_description=task_description,
                    attempt=attempt,
                    tier=tier,
                    additional_context=additional_context
                )

            # Track metrics (cost/duration accumulate across all attempts
            # of this tier, including failed ones).
            total_cost += session_result.cost
            total_duration_ms += session_result.duration_ms

            # Log to metrics collector
            self.metrics_collector.log_session(
                model=session_result.model,
                tier=tier,
                attempt=attempt,
                duration_ms=session_result.duration_ms,
                input_tokens=session_result.input_tokens,
                output_tokens=session_result.output_tokens,
                success=session_result.success,
                task_id=task_id,
                cost=session_result.cost,
                error_type=session_result.error_type,
                error_message=session_result.error_message
            )

            if session_result.success:
                # Validate output if callback provided
                if self.validation_callback:
                    try:
                        is_valid = self.validation_callback(task_id, session_result.output)
                        if not is_valid:
                            logger.info(f"Output validation failed for {task_id} at attempt {attempt}")
                            # Capture as learning
                            self.learning_accumulator.capture_failure(
                                task_id=task_id,
                                attempt_number=attempt,
                                tier=tier,
                                error_type="ValidationError",
                                error_message="Output failed validation check",
                                model=session_result.model
                            )
                            # FIX: record the validation failure so an
                            # exhausted tier reports it instead of a stale
                            # or missing session error.
                            last_error_type = "ValidationError"
                            last_error_message = "Output failed validation check"
                            continue
                    except Exception as e:
                        # Deliberate best-effort: a broken validator must not
                        # discard an otherwise successful attempt.
                        logger.warning(f"Validation callback error: {e}")

                return {
                    "success": True,
                    "output": session_result.output,
                    "attempts": attempt,
                    "cost": total_cost,
                    "duration_ms": total_duration_ms
                }

            last_error_type = session_result.error_type
            last_error_message = session_result.error_message

        # All attempts exhausted for this tier.
        return {
            "success": False,
            "output": None,
            "attempts": tier_config.max_attempts,
            "cost": total_cost,
            "duration_ms": total_duration_ms,
            "error_type": last_error_type,
            "error_message": last_error_message
        }

    def _log_success(self, task_id: str, result: ExecutionResult) -> None:
        """Log successful execution for analytics."""
        logger.info(
            f"SUCCESS: {task_id} "
            f"tier={result.final_tier} "
            f"attempts={result.total_attempts} "
            f"cost=${result.total_cost:.4f} "
            f"duration={result.total_duration_ms}ms"
        )

    def _mark_for_human_review(self,
                              task_id: str,
                              task_description: str,
                              result: ExecutionResult) -> None:
        """Append a failed task to the in-memory human review queue."""
        review_item = {
            # Description truncated to keep queue entries bounded.
            "task_id": task_id,
            "task_description": task_description[:500],
            "marked_at": datetime.utcnow().isoformat(),
            "total_attempts": result.total_attempts,
            "total_cost": result.total_cost,
            "lessons_learned": result.lessons_learned,
            "final_error": f"{result.error_type}: {result.error_message}"
        }

        self.human_review_queue.append(review_item)

        logger.warning(
            f"HUMAN REVIEW REQUIRED: {task_id} "
            f"after {result.total_attempts} attempts "
            f"cost=${result.total_cost:.4f}"
        )

    def get_human_review_queue(self) -> List[Dict[str, Any]]:
        """Get tasks marked for human review (shallow copy of the queue)."""
        return self.human_review_queue.copy()

    def clear_human_review(self, task_id: str) -> bool:
        """Remove a task from human review queue.

        Returns:
            True if at least one entry was removed, False otherwise.
        """
        initial_len = len(self.human_review_queue)
        self.human_review_queue = [
            item for item in self.human_review_queue
            if item["task_id"] != task_id
        ]
        return len(self.human_review_queue) < initial_len

    def get_execution_summary(self) -> Dict[str, Any]:
        """Get summary of execution statistics and the active tier config."""
        return {
            "human_review_queue_size": len(self.human_review_queue),
            "tier_config": {
                tier: config.to_dict()
                for tier, config in self.tier_config.tiers.items()
            }
        }


# Singleton instance, created lazily on first get_tiered_executor() call.
_executor: Optional[TieredExecutor] = None


def get_tiered_executor() -> TieredExecutor:
    """Get or create global TieredExecutor instance.

    Not thread-safe: concurrent first calls may race and construct two
    executors (the last assignment wins).
    """
    global _executor
    if _executor is None:
        _executor = TieredExecutor()
    return _executor


def execute_task(task_id: str,
                task_description: str,
                **kwargs) -> ExecutionResult:
    """
    Convenience wrapper around the shared TieredExecutor.

    Args:
        task_id: Task identifier
        task_description: Task description
        **kwargs: Forwarded verbatim to TieredExecutor.execute()

    Returns:
        ExecutionResult
    """
    return get_tiered_executor().execute(task_id, task_description, **kwargs)


if __name__ == "__main__":
    # Manual smoke test for the TieredExecutor.
    logging.basicConfig(level=logging.INFO)

    runner = TieredExecutor()

    print("Execution Summary:")
    print(json.dumps(runner.get_execution_summary(), indent=2))

    # Test execution (will likely fail without API keys)
    outcome = runner.execute(
        task_id="test-001",
        task_description="What is 2 + 2? Reply with just the number."
    )

    print("\nExecution Result:")
    print(json.dumps(outcome.to_dict(), indent=2))