"""
AIVA Task Queue Consumer - PM-027

AIVA pulls tasks from Redis queue with priority ordering.
Supports task lifecycle: pending -> in_progress -> completed/failed.
"""

import json
import logging
import os
import time
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

# Configure logging
# NOTE(review): basicConfig at import time configures the process-wide root
# logger (and is a no-op if the host application configured logging first) —
# consider leaving this to the application entry point.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class TaskPriority(Enum):
    """Task priority levels.

    Lower value = more urgent; queue consumers scan priorities in ascending
    definition order, so CRITICAL is drained first.
    """
    CRITICAL = 1    # Immediate execution
    HIGH = 2        # Execute within 1 minute
    NORMAL = 3      # Execute within 5 minutes
    LOW = 4         # Execute when idle
    BACKGROUND = 5  # Long-running background tasks


class TaskStatus(Enum):
    """Task execution status (lifecycle: pending -> in_progress -> completed/failed)."""
    PENDING = "pending"          # Queued, not yet picked up
    IN_PROGRESS = "in_progress"  # Popped by a consumer, currently executing
    COMPLETED = "completed"      # Finished successfully
    FAILED = "failed"            # Errored (terminal once retries are exhausted)
    TIMEOUT = "timeout"          # NOTE(review): declared but never assigned in this file
    CANCELLED = "cancelled"      # NOTE(review): declared but never assigned in this file


@dataclass
class Task:
    """A unit of work flowing through the queue.

    Lifecycle: pending -> in_progress -> completed/failed (see TaskStatus).
    Serializes to/from plain dicts and JSON strings for Redis storage;
    enum fields are flattened to their ``.value`` on the way out and
    rehydrated on the way in.
    """
    task_id: str
    task_type: str
    description: str
    params: Dict
    priority: TaskPriority = TaskPriority.NORMAL
    status: TaskStatus = TaskStatus.PENDING
    # NOTE(review): datetime.utcnow() is naive and deprecated in 3.12+;
    # kept for timestamp-format compatibility with existing stored records.
    created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    result: Optional[Any] = None
    error: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 3
    timeout_seconds: int = 300
    metadata: Dict = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Return a JSON-safe dict with enum fields flattened to their values."""
        data = asdict(self)
        data["priority"] = self.priority.value
        data["status"] = self.status.value
        return data

    @classmethod
    def from_dict(cls, data: Dict) -> "Task":
        """Rebuild a Task from a dict produced by to_dict().

        Fix: works on a shallow copy so the caller's dict is NOT mutated
        (the old code replaced the caller's "priority"/"status" values with
        enum objects in place), and tolerates dicts that omit those keys by
        falling back to the dataclass defaults instead of raising KeyError.
        """
        payload = dict(data)
        if "priority" in payload:
            payload["priority"] = TaskPriority(payload["priority"])
        if "status" in payload:
            payload["status"] = TaskStatus(payload["status"])
        return cls(**payload)

    def to_json(self) -> str:
        """Serialize to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> "Task":
        """Deserialize from a JSON string produced by to_json()."""
        return cls.from_dict(json.loads(json_str))


class RedisTaskQueue:
    """Redis-backed task queue with per-priority FIFO ordering.

    Redis layout (all keys share ``queue_prefix``):
      - ``<prefix>:p<N>``        sorted set per priority, scored by enqueue time
      - ``<prefix>:processing``  set of in-flight task JSON blobs
      - ``<prefix>:completed``   capped list (last 1000) of finished tasks
      - ``<prefix>:failed``      capped list (last 1000) of exhausted tasks

    If the redis package is missing or the server is unreachable, the queue
    is disabled: push() returns False, pop() returns None.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        queue_prefix: str = "aiva:tasks"
    ):
        """Record connection settings and connect immediately."""
        self.host = host
        self.port = port
        self.db = db
        self.queue_prefix = queue_prefix
        self._client = None  # redis.Redis instance, or None when disabled
        self._connect()

    def _connect(self) -> bool:
        """Establish the Redis connection.

        Returns:
            True on success; False when redis is unavailable (the queue then
            stays disabled with self._client = None).
        """
        try:
            import redis
            self._client = redis.Redis(
                host=self.host,
                port=self.port,
                db=self.db,
                decode_responses=True  # work with str, not bytes
            )
            self._client.ping()  # fail fast if the server is unreachable
            logger.info(f"Connected to Redis at {self.host}:{self.port}")
            return True
        except ImportError:
            # Fix: the old message claimed "using in-memory queue", but no
            # in-memory fallback exists — the queue is simply disabled.
            logger.warning("redis package not installed; task queue disabled")
            self._client = None
            return False
        except Exception as e:
            logger.error(f"Redis connection failed: {e}")
            self._client = None
            return False

    def _get_queue_key(self, priority: TaskPriority) -> str:
        """Return the sorted-set key for one priority level."""
        return f"{self.queue_prefix}:p{priority.value}"

    def _get_processing_key(self) -> str:
        """Return the key of the in-flight task set."""
        return f"{self.queue_prefix}:processing"

    def _remove_from_processing(self, task_id: str) -> None:
        """Remove the processing-set entry whose JSON carries *task_id*.

        Members are matched by task_id because the stored JSON snapshot
        (taken at pop time) differs from the task's current serialized form.
        """
        processing_key = self._get_processing_key()
        for member in self._client.smembers(processing_key):
            if Task.from_json(member).task_id == task_id:
                self._client.srem(processing_key, member)
                break

    def push(self, task: Task) -> bool:
        """
        Push a task to the queue for its priority level.

        Args:
            task: Task to add

        Returns:
            True if successful; False on Redis error or when disabled
        """
        queue_key = self._get_queue_key(task.priority)

        if self._client:
            try:
                # Sorted set scored by enqueue timestamp -> FIFO within priority.
                score = time.time()
                self._client.zadd(queue_key, {task.to_json(): score})
                logger.debug(f"Task {task.task_id} pushed to {queue_key}")
                return True
            except Exception as e:
                logger.error(f"Failed to push task: {e}")
                return False

        return False

    def pop(self, priority: Optional[TaskPriority] = None) -> Optional[Task]:
        """
        Pop the next task from the queue, highest priority first.

        Args:
            priority: Specific priority to pop from (None = scan all, most
                urgent first)

        Returns:
            The next Task — already marked IN_PROGRESS and tracked in the
            processing set — or None when empty/disabled
        """
        if not self._client:
            return None

        # Enum definition order is CRITICAL..BACKGROUND, i.e. most urgent first.
        priorities = [priority] if priority else list(TaskPriority)

        for p in priorities:
            queue_key = self._get_queue_key(p)
            try:
                # ZPOPMIN returns the lowest score = oldest entry.
                result = self._client.zpopmin(queue_key, count=1)
                if result:
                    task_json, _ = result[0]
                    task = Task.from_json(task_json)

                    # Transition to in-progress before tracking it.
                    task.status = TaskStatus.IN_PROGRESS
                    task.started_at = datetime.utcnow().isoformat()

                    # Snapshot the in-flight task in the processing set.
                    self._client.sadd(self._get_processing_key(), task.to_json())

                    logger.debug(f"Task {task.task_id} popped from {queue_key}")
                    return task
            except Exception as e:
                logger.error(f"Failed to pop task: {e}")

        return None

    def complete(self, task: Task, result: Any = None) -> bool:
        """
        Mark a task as completed and archive it.

        Args:
            task: Task to complete (mutated in place: status, completed_at,
                result)
            result: Execution result to store on the task

        Returns:
            True if the Redis bookkeeping succeeded
        """
        if not self._client:
            return False

        task.status = TaskStatus.COMPLETED
        task.completed_at = datetime.utcnow().isoformat()
        task.result = result

        try:
            self._remove_from_processing(task.task_id)

            # Archive, keeping only the most recent 1000 entries.
            completed_key = f"{self.queue_prefix}:completed"
            self._client.lpush(completed_key, task.to_json())
            self._client.ltrim(completed_key, 0, 999)

            logger.info(f"Task {task.task_id} completed")
            return True
        except Exception as e:
            logger.error(f"Failed to complete task: {e}")
            return False

    def fail(self, task: Task, error: str) -> bool:
        """
        Mark a task as failed; requeue it while retries remain, else archive it.

        Args:
            task: Task that failed (mutated in place)
            error: Error message to record

        Returns:
            True if the Redis bookkeeping succeeded
        """
        task.status = TaskStatus.FAILED
        task.error = error
        task.completed_at = datetime.utcnow().isoformat()

        if not self._client:
            return False

        try:
            self._remove_from_processing(task.task_id)

            # Requeue while retry budget remains.
            if task.retry_count < task.max_retries:
                task.retry_count += 1
                task.status = TaskStatus.PENDING
                task.started_at = None
                task.error = None
                # Fix: also clear completed_at — the old code left a stale
                # completion timestamp on the requeued PENDING task.
                task.completed_at = None
                self.push(task)
                logger.info(f"Task {task.task_id} queued for retry ({task.retry_count}/{task.max_retries})")
            else:
                # Retries exhausted: archive in the capped failed list.
                failed_key = f"{self.queue_prefix}:failed"
                self._client.lpush(failed_key, task.to_json())
                self._client.ltrim(failed_key, 0, 999)
                logger.warning(f"Task {task.task_id} failed after {task.max_retries} retries")

            return True
        except Exception as e:
            logger.error(f"Failed to mark task as failed: {e}")
            return False

    def get_queue_depth(self) -> Dict[str, int]:
        """Return pending-task counts keyed by priority name ({} when disabled)."""
        depths = {}
        if self._client:
            for p in TaskPriority:
                queue_key = self._get_queue_key(p)
                depths[p.name] = self._client.zcard(queue_key)
        return depths

    def get_processing_count(self) -> int:
        """Return the number of in-flight tasks (0 when disabled)."""
        if self._client:
            return self._client.scard(self._get_processing_key())
        return 0


class TaskConsumer:
    """
    Consumes tasks from the queue.

    Usage:
        consumer = TaskConsumer()
        while True:
            task = consumer.pop_task()
            if task:
                result = process_task(task)
                consumer.complete_task(task, result)
    """

    def __init__(
        self,
        queue: Optional[RedisTaskQueue] = None,
        skill_registry=None,
        permission_manager=None
    ):
        """
        Initialize task consumer.

        Args:
            queue: RedisTaskQueue instance (a default one is created if None)
            skill_registry: SkillRegistry used to execute "skill" tasks
            permission_manager: PermissionManager for validation
                (NOTE(review): stored but never consulted in this class)
        """
        self.queue = queue or RedisTaskQueue()
        self.skill_registry = skill_registry
        self.permission_manager = permission_manager
        self.current_task: Optional[Task] = None
        self._running = False
        logger.info("TaskConsumer initialized")

    def pop_task(self) -> Optional[Task]:
        """
        Pop the next task to process and remember it as the current task.

        Returns:
            Next task or None if queue empty
        """
        task = self.queue.pop()
        if task:
            self.current_task = task
            logger.info(f"Popped task: {task.task_id} ({task.task_type})")
        return task

    def complete_task(self, task: Task, result: Any = None) -> bool:
        """
        Mark a task as completed, clearing current_task if it matches.

        Args:
            task: Task to complete
            result: Execution result

        Returns:
            True if successful
        """
        success = self.queue.complete(task, result)
        if success and self.current_task and self.current_task.task_id == task.task_id:
            self.current_task = None
        return success

    def fail_task(self, task: Task, error: str) -> bool:
        """
        Mark a task as failed, clearing current_task if it matches.

        Args:
            task: Task that failed
            error: Error message

        Returns:
            True if successful
        """
        success = self.queue.fail(task, error)
        if success and self.current_task and self.current_task.task_id == task.task_id:
            self.current_task = None
        return success

    def submit_task(
        self,
        task_type: str,
        description: str,
        params: Dict,
        priority: TaskPriority = TaskPriority.NORMAL,
        **kwargs
    ) -> Task:
        """
        Submit a new task to the queue.

        Args:
            task_type: Type of task
            description: Human-readable description
            params: Task parameters
            priority: Task priority
            **kwargs: Additional Task fields (e.g. max_retries, metadata)

        Returns:
            Created Task (already pushed to the queue)
        """
        # Fix: the old ID used hash(description), which is salted per process
        # and collides for identical descriptions within the same millisecond.
        # A uuid4 suffix keeps IDs time-sortable yet unique.
        task = Task(
            task_id=f"task_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}",
            task_type=task_type,
            description=description,
            params=params,
            priority=priority,
            **kwargs
        )

        self.queue.push(task)
        logger.info(f"Submitted task: {task.task_id}")
        return task

    def get_status(self) -> Dict:
        """Return a snapshot of consumer and queue state."""
        return {
            "running": self._running,
            "current_task": self.current_task.to_dict() if self.current_task else None,
            "queue_depth": self.queue.get_queue_depth(),
            "processing_count": self.queue.get_processing_count()
        }

    def process_one(self) -> Optional[Dict]:
        """
        Pop and process a single task.

        Returns:
            Result dict ({"task_id", "success"[, "error"]}) or None if the
            queue was empty
        """
        task = self.pop_task()
        if not task:
            return None

        try:
            # "skill" tasks are dispatched through the skill registry.
            if task.task_type == "skill" and self.skill_registry:
                skill_id = task.params.get("skill_id")
                skill_params = task.params.get("skill_params", {})
                result = self.skill_registry.execute_skill(skill_id, skill_params)
                if result.success:
                    self.complete_task(task, result.result)
                else:
                    # Fix: result.error may be None, but fail_task expects a
                    # string — record a concrete message instead.
                    self.fail_task(task, result.error or "skill execution failed")
                return {"task_id": task.task_id, "success": result.success}

            # Any other task type is acknowledged as generically processed.
            self.complete_task(task, {"processed": True})
            return {"task_id": task.task_id, "success": True}

        except Exception as e:
            self.fail_task(task, str(e))
            return {"task_id": task.task_id, "success": False, "error": str(e)}


# Module-level singleton holder.
_consumer: Optional[TaskConsumer] = None


def get_task_consumer() -> TaskConsumer:
    """Return the shared TaskConsumer, building it lazily on first use."""
    global _consumer
    if _consumer is not None:
        return _consumer
    _consumer = TaskConsumer()
    return _consumer


if __name__ == "__main__":
    # Demo: enqueue two tasks, print status, then drain the queue.
    consumer = TaskConsumer()

    for description, action, prio in (
        ("Test task 1", "hello", TaskPriority.HIGH),
        ("Test task 2", "world", TaskPriority.NORMAL),
    ):
        consumer.submit_task(
            task_type="test",
            description=description,
            params={"action": action},
            priority=prio,
        )

    print(f"\nQueue status: {consumer.get_status()}")

    # Drain until the queue reports empty.
    while True:
        outcome = consumer.process_one()
        if not outcome:
            break
        print(f"Processed: {outcome}")

    print(f"\nFinal status: {consumer.get_status()}")
