"""
AIVA Swarm Liaison - Bidirectional Worker Communication Bridge
================================================================

Core liaison that enables bidirectional communication between AIVA's daemon
and external worker agents (Claude Code, Gemini Flash, Gemini Pro).

Architecture:
    - Redis pub/sub for real-time messaging
    - Redis lists for task queues (LPUSH/BRPOP reliable queuing)
    - Redis hashes with TTL for worker registry (auto-expire on missed heartbeat)
    - All connection params from elestio_config.RedisConfig

Channel Layout:
    aiva:swarm:tasks:{worker_type}     -- Task queues per worker type (list)
    aiva:swarm:results:{task_id}       -- Per-task result channel (list)
    aiva:swarm:workers                 -- Worker registry (hash)
    aiva:swarm:worker:{worker_id}      -- Worker detail + heartbeat (hash w/ TTL)
    aiva:swarm:messages:{worker_id}    -- Direct message inbox per worker (list)
    aiva:swarm:broadcast:{worker_type} -- Broadcast channel per type (pub/sub)
    aiva:swarm:broadcast:all           -- Global broadcast channel (pub/sub)

NO SQLITE. All storage uses Elestio Redis.

VERIFICATION_STAMP
Story: AIVA-SWARM-001
Verified By: Claude Opus 4.6
Verified At: 2026-02-11
Component: Swarm Liaison (bidirectional worker communication)
"""

import sys
import json
import time
import uuid
import logging
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum

# Elestio config
sys.path.append('/mnt/e/genesis-system/data/genesis-memory')
from elestio_config import RedisConfig

logger = logging.getLogger("AIVA.SwarmLiaison")


# =============================================================================
# DATA STRUCTURES
# =============================================================================

class WorkerType(Enum):
    """Kinds of worker agents AIVA can dispatch to.

    The string values are used verbatim as Redis queue-name suffixes
    (``aiva:swarm:tasks:{value}``), so they must stay stable.
    """

    CLAUDE_CODE = "claude_code"
    GEMINI_FLASH = "gemini_flash"
    GEMINI_PRO = "gemini_pro"
    GENERIC = "generic"  # catch-all default queue


class TaskStatus(Enum):
    """Lifecycle states a swarm task can be in.

    The string values are stored verbatim in task metadata records,
    so they must remain stable.
    """

    PENDING = "PENDING"
    DISPATCHED = "DISPATCHED"    # set when pushed onto a worker queue
    IN_PROGRESS = "IN_PROGRESS"  # set locally by the consuming worker
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    TIMEOUT = "TIMEOUT"
    CANCELLED = "CANCELLED"


class TaskPriority(Enum):
    """Numeric task priority levels; a lower number means more urgent."""

    HIGH = 1
    MEDIUM = 5
    LOW = 10


@dataclass
class SwarmTask:
    """A unit of work queued for a swarm worker.

    ``task_id`` and ``created_at`` are generated automatically when the
    caller leaves them empty, so ``SwarmTask(task_id="", description=...)``
    yields a fresh ``swarm-<hex12>`` identifier.
    """
    task_id: str
    description: str
    context: str = ""
    worker_type: str = "generic"
    priority: int = 5
    timeout_seconds: int = 300
    status: str = "PENDING"
    created_at: str = ""
    dispatched_at: Optional[str] = None
    metadata: Dict = field(default_factory=dict)

    def __post_init__(self):
        # Fill in any identity/timestamp field the caller left blank.
        self.task_id = self.task_id or f"swarm-{uuid.uuid4().hex[:12]}"
        self.created_at = self.created_at or datetime.now().isoformat()

    def to_dict(self) -> Dict:
        """Serialize to a plain, JSON-friendly dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> "SwarmTask":
        """Build a task from a dict, silently ignoring unknown keys."""
        known = cls.__dataclass_fields__
        return cls(**{key: data[key] for key in data if key in known})


@dataclass
class SwarmResult:
    """Outcome a worker reports back for a dispatched task."""
    task_id: str
    status: str = "COMPLETED"
    result: Any = None
    error: Optional[str] = None
    worker_id: str = ""
    duration_ms: float = 0.0
    tokens_used: int = 0
    completed_at: str = ""
    metadata: Dict = field(default_factory=dict)

    def __post_init__(self):
        # Stamp a completion time if the worker did not supply one.
        self.completed_at = self.completed_at or datetime.now().isoformat()

    def to_dict(self) -> Dict:
        """Serialize to a plain, JSON-friendly dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> "SwarmResult":
        """Build a result from a dict, silently ignoring unknown keys."""
        known = cls.__dataclass_fields__
        return cls(**{key: data[key] for key in data if key in known})


@dataclass
class WorkerInfo:
    """Worker registration info.

    ``worker_id`` and the registration/heartbeat timestamps are filled in
    automatically when left empty.
    """
    worker_id: str
    worker_type: str
    capabilities: List[str] = field(default_factory=list)
    registered_at: str = ""
    last_heartbeat: str = ""
    status: str = "active"
    metadata: Dict = field(default_factory=dict)

    def __post_init__(self):
        if not self.worker_id:
            self.worker_id = f"worker-{uuid.uuid4().hex[:8]}"
        if not self.registered_at:
            self.registered_at = datetime.now().isoformat()
        if not self.last_heartbeat:
            # A freshly registered worker counts as just-heartbeated.
            self.last_heartbeat = self.registered_at

    def to_dict(self) -> Dict:
        """Serialize to a plain dict (capabilities/metadata stay native)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> "WorkerInfo":
        """Build from a dict (e.g. a Redis hash), ignoring unknown keys.

        JSON-encoded ``capabilities``/``metadata`` strings (as stored by
        the registry) are decoded back to native objects.

        Bug fix: operate on a copy so the caller's dict is never mutated.
        """
        data = dict(data)  # defensive copy — do not mutate caller's mapping
        if isinstance(data.get("capabilities"), str):
            data["capabilities"] = json.loads(data["capabilities"])
        if isinstance(data.get("metadata"), str):
            data["metadata"] = json.loads(data["metadata"])
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})


# =============================================================================
# SWARM LIAISON
# =============================================================================

class SwarmLiaison:
    """
    Bidirectional communication bridge between AIVA and worker agents.

    Manages:
    - Task dispatch to workers via Redis queues
    - Result collection from workers
    - Worker registry with heartbeat-based health
    - Direct messaging and broadcast
    """

    # Redis key prefixes
    TASK_QUEUE_PREFIX = "aiva:swarm:tasks"
    RESULT_PREFIX = "aiva:swarm:results"
    WORKER_REGISTRY = "aiva:swarm:workers"
    WORKER_DETAIL_PREFIX = "aiva:swarm:worker"
    MESSAGE_PREFIX = "aiva:swarm:messages"
    BROADCAST_PREFIX = "aiva:swarm:broadcast"

    # Worker heartbeat TTL (seconds) - if no heartbeat in this time, worker expires
    WORKER_TTL = 120  # 2 minutes

    # Default timeouts
    DEFAULT_TASK_TIMEOUT = 300  # 5 minutes
    RESULT_POLL_INTERVAL = 0.5  # seconds between polls

    def __init__(self):
        """Initialize the Swarm Liaison with Elestio Redis."""
        self._redis = None
        self._connected = False
        self._connect()

    def _connect(self):
        """Establish Redis connection using Elestio config."""
        try:
            import redis
            params = RedisConfig.get_connection_params()
            self._redis = redis.Redis(**params)
            self._redis.ping()
            self._connected = True
            logger.info(
                f"Swarm Liaison connected to Redis: "
                f"{params['host']}:{params['port']}"
            )
        except Exception as e:
            self._connected = False
            logger.warning(f"Swarm Liaison Redis connection failed: {e}")

    @property
    def is_connected(self) -> bool:
        """Check if Redis connection is healthy."""
        if not self._redis or not self._connected:
            return False
        try:
            self._redis.ping()
            return True
        except Exception:
            self._connected = False
            return False

    # =========================================================================
    # TASK DISPATCH
    # =========================================================================

    def dispatch_task(
        self,
        task_description: str,
        worker_type: str = "generic",
        priority: int = 5,
        timeout: int = 300,
        context: str = "",
        metadata: Optional[Dict] = None
    ) -> Optional[SwarmTask]:
        """
        Dispatch a task to a worker queue.

        Uses Redis LPUSH to add task to the appropriate worker-type queue.
        Workers consume tasks via BRPOP for reliable FIFO ordering.

        Args:
            task_description: What the worker should do
            worker_type: Target worker type (claude_code, gemini_flash, etc.)
            priority: 1=HIGH, 5=MEDIUM, 10=LOW
            timeout: Max seconds for task completion
            context: Additional context for the task
            metadata: Optional metadata dict

        Returns:
            SwarmTask if dispatched successfully, None on failure
        """
        if not self.is_connected:
            logger.error("Cannot dispatch task: Redis not connected")
            return None

        task = SwarmTask(
            task_id="",  # auto-generated
            description=task_description,
            context=context,
            worker_type=worker_type,
            priority=priority,
            timeout_seconds=timeout,
            status=TaskStatus.DISPATCHED.value,
            dispatched_at=datetime.now().isoformat(),
            metadata=metadata or {}
        )

        try:
            queue_key = f"{self.TASK_QUEUE_PREFIX}:{worker_type}"
            task_json = json.dumps(task.to_dict())

            # LPUSH for FIFO when consumed with BRPOP
            self._redis.lpush(queue_key, task_json)

            # Also store task metadata for status tracking
            task_meta_key = f"{self.RESULT_PREFIX}:{task.task_id}:meta"
            self._redis.set(
                task_meta_key,
                json.dumps({"status": TaskStatus.DISPATCHED.value, "task": task.to_dict()}),
                ex=timeout + 60  # TTL slightly longer than task timeout
            )

            logger.info(
                f"Task dispatched: {task.task_id} -> {worker_type} "
                f"(priority={priority}, timeout={timeout}s)"
            )
            return task

        except Exception as e:
            logger.error(f"Failed to dispatch task: {e}")
            return None

    def dispatch_batch(
        self,
        tasks: List[Dict],
        worker_type: str = "generic"
    ) -> List[SwarmTask]:
        """
        Dispatch multiple tasks in a pipeline for efficiency.

        Args:
            tasks: List of dicts with 'description', 'priority', 'timeout', 'context'
            worker_type: Default worker type for all tasks

        Returns:
            List of dispatched SwarmTask objects
        """
        if not self.is_connected:
            logger.error("Cannot dispatch batch: Redis not connected")
            return []

        dispatched = []
        try:
            pipe = self._redis.pipeline()
            queue_key = f"{self.TASK_QUEUE_PREFIX}:{worker_type}"

            for task_spec in tasks:
                task = SwarmTask(
                    task_id="",
                    description=task_spec.get("description", ""),
                    context=task_spec.get("context", ""),
                    worker_type=task_spec.get("worker_type", worker_type),
                    priority=task_spec.get("priority", 5),
                    timeout_seconds=task_spec.get("timeout", self.DEFAULT_TASK_TIMEOUT),
                    status=TaskStatus.DISPATCHED.value,
                    dispatched_at=datetime.now().isoformat(),
                    metadata=task_spec.get("metadata", {})
                )

                actual_queue = f"{self.TASK_QUEUE_PREFIX}:{task.worker_type}"
                pipe.lpush(actual_queue, json.dumps(task.to_dict()))

                task_meta_key = f"{self.RESULT_PREFIX}:{task.task_id}:meta"
                pipe.set(
                    task_meta_key,
                    json.dumps({"status": TaskStatus.DISPATCHED.value, "task": task.to_dict()}),
                    ex=task.timeout_seconds + 60
                )

                dispatched.append(task)

            pipe.execute()
            logger.info(f"Batch dispatched: {len(dispatched)} tasks")

        except Exception as e:
            logger.error(f"Batch dispatch failed: {e}")

        return dispatched

    # =========================================================================
    # RESULT COLLECTION
    # =========================================================================

    def collect_result(
        self,
        task_id: str,
        timeout: int = 300
    ) -> Optional[SwarmResult]:
        """
        Wait for and collect the result of a specific task.

        Uses Redis BRPOP on the task-specific result channel for
        blocking wait with timeout.

        Args:
            task_id: Task ID to wait for
            timeout: Max seconds to wait

        Returns:
            SwarmResult if received, None on timeout/error
        """
        if not self.is_connected:
            logger.error("Cannot collect result: Redis not connected")
            return None

        result_key = f"{self.RESULT_PREFIX}:{task_id}"

        try:
            # BRPOP blocks until result is available or timeout
            raw = self._redis.brpop(result_key, timeout=timeout)

            if raw is None:
                logger.warning(f"Result timeout for task {task_id} after {timeout}s")
                return None

            # raw is a tuple (key, value)
            result_data = json.loads(raw[1])
            result = SwarmResult.from_dict(result_data)

            logger.info(
                f"Result collected: {task_id} -> {result.status} "
                f"(duration={result.duration_ms:.0f}ms)"
            )
            return result

        except Exception as e:
            logger.error(f"Failed to collect result for {task_id}: {e}")
            return None

    def collect_results_batch(
        self,
        task_ids: List[str],
        timeout: int = 300
    ) -> Dict[str, Optional[SwarmResult]]:
        """
        Collect results for multiple tasks.

        Polls all task result channels with a shared timeout budget.

        Args:
            task_ids: List of task IDs
            timeout: Total timeout budget for all results

        Returns:
            Dict mapping task_id -> SwarmResult (None if not received)
        """
        if not self.is_connected:
            return {tid: None for tid in task_ids}

        results: Dict[str, Optional[SwarmResult]] = {}
        pending = set(task_ids)
        start = time.time()

        while pending and (time.time() - start) < timeout:
            for task_id in list(pending):
                result_key = f"{self.RESULT_PREFIX}:{task_id}"
                try:
                    raw = self._redis.rpop(result_key)
                    if raw:
                        result_data = json.loads(raw)
                        results[task_id] = SwarmResult.from_dict(result_data)
                        pending.discard(task_id)
                except Exception as e:
                    logger.debug(f"Poll error for {task_id}: {e}")

            if pending:
                time.sleep(self.RESULT_POLL_INTERVAL)

        # Mark remaining as not received
        for task_id in pending:
            results[task_id] = None
            logger.warning(f"Result not received for task {task_id}")

        return results

    def post_result(self, result: SwarmResult) -> bool:
        """
        Post a result back for a task (called by workers).

        Args:
            result: The SwarmResult to post

        Returns:
            True if posted successfully
        """
        if not self.is_connected:
            return False

        try:
            result_key = f"{self.RESULT_PREFIX}:{result.task_id}"
            result_json = json.dumps(result.to_dict())

            # LPUSH so the collector's BRPOP gets it
            self._redis.lpush(result_key, result_json)

            # Set TTL on the result key so it doesn't persist forever
            self._redis.expire(result_key, 3600)  # 1 hour TTL

            # Update task metadata
            task_meta_key = f"{self.RESULT_PREFIX}:{result.task_id}:meta"
            self._redis.set(
                task_meta_key,
                json.dumps({"status": result.status, "result": result.to_dict()}),
                ex=3600
            )

            logger.info(f"Result posted for task {result.task_id}: {result.status}")
            return True

        except Exception as e:
            logger.error(f"Failed to post result: {e}")
            return False

    def consume_task(
        self,
        worker_type: str,
        timeout: int = 30
    ) -> Optional[SwarmTask]:
        """
        Consume a task from the queue (called by workers).

        Uses BRPOP for reliable FIFO consumption.

        Args:
            worker_type: Which queue to consume from
            timeout: Max seconds to wait for a task

        Returns:
            SwarmTask if available, None on timeout
        """
        if not self.is_connected:
            return None

        queue_key = f"{self.TASK_QUEUE_PREFIX}:{worker_type}"

        try:
            raw = self._redis.brpop(queue_key, timeout=timeout)
            if raw is None:
                return None

            task_data = json.loads(raw[1])
            task = SwarmTask.from_dict(task_data)
            task.status = TaskStatus.IN_PROGRESS.value

            logger.info(f"Task consumed: {task.task_id} by {worker_type}")
            return task

        except Exception as e:
            logger.error(f"Failed to consume task: {e}")
            return None

    # =========================================================================
    # WORKER REGISTRY
    # =========================================================================

    def register_worker(
        self,
        worker_id: str,
        worker_type: str,
        capabilities: Optional[List[str]] = None,
        metadata: Optional[Dict] = None
    ) -> bool:
        """
        Register a worker in the swarm registry.

        Workers are stored in a Redis hash (registry) plus a per-worker
        detail key with TTL that expires if heartbeat stops.

        Args:
            worker_id: Unique worker identifier
            worker_type: Worker type (claude_code, gemini_flash, etc.)
            capabilities: List of capability strings
            metadata: Optional metadata

        Returns:
            True if registered successfully
        """
        if not self.is_connected:
            return False

        worker = WorkerInfo(
            worker_id=worker_id,
            worker_type=worker_type,
            capabilities=capabilities or [],
            metadata=metadata or {}
        )

        try:
            # Add to registry hash (worker_id -> type mapping)
            self._redis.hset(
                self.WORKER_REGISTRY,
                worker_id,
                worker_type
            )

            # Store detailed worker info with TTL
            detail_key = f"{self.WORKER_DETAIL_PREFIX}:{worker_id}"
            worker_data = worker.to_dict()
            worker_data["capabilities"] = json.dumps(worker_data["capabilities"])
            worker_data["metadata"] = json.dumps(worker_data["metadata"])

            self._redis.hset(detail_key, mapping=worker_data)
            self._redis.expire(detail_key, self.WORKER_TTL)

            logger.info(
                f"Worker registered: {worker_id} (type={worker_type}, "
                f"capabilities={capabilities})"
            )
            return True

        except Exception as e:
            logger.error(f"Failed to register worker: {e}")
            return False

    def heartbeat(self, worker_id: str) -> bool:
        """
        Worker heartbeat - refreshes TTL to keep worker alive.

        Workers should call this every WORKER_TTL/2 seconds.

        Args:
            worker_id: Worker identifier

        Returns:
            True if heartbeat acknowledged
        """
        if not self.is_connected:
            return False

        try:
            detail_key = f"{self.WORKER_DETAIL_PREFIX}:{worker_id}"

            # Check worker exists
            if not self._redis.exists(detail_key):
                logger.warning(f"Heartbeat for unknown worker: {worker_id}")
                return False

            # Update heartbeat timestamp and refresh TTL
            now = datetime.now().isoformat()
            self._redis.hset(detail_key, "last_heartbeat", now)
            self._redis.expire(detail_key, self.WORKER_TTL)

            return True

        except Exception as e:
            logger.error(f"Heartbeat failed for {worker_id}: {e}")
            return False

    def deregister_worker(self, worker_id: str) -> bool:
        """
        Remove a worker from the registry.

        Args:
            worker_id: Worker to remove

        Returns:
            True if removed
        """
        if not self.is_connected:
            return False

        try:
            self._redis.hdel(self.WORKER_REGISTRY, worker_id)
            detail_key = f"{self.WORKER_DETAIL_PREFIX}:{worker_id}"
            self._redis.delete(detail_key)
            logger.info(f"Worker deregistered: {worker_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to deregister worker: {e}")
            return False

    def get_available_workers(
        self,
        worker_type: Optional[str] = None
    ) -> List[WorkerInfo]:
        """
        Get all available workers, optionally filtered by type.

        Only returns workers whose detail keys haven't expired
        (i.e., they've sent a heartbeat recently).

        Args:
            worker_type: Filter by type, or None for all

        Returns:
            List of WorkerInfo for available workers
        """
        if not self.is_connected:
            return []

        workers = []
        try:
            # Get all registered workers
            registry = self._redis.hgetall(self.WORKER_REGISTRY)

            for wid, wtype in registry.items():
                # Filter by type if specified
                if worker_type and wtype != worker_type:
                    continue

                # Check if worker detail key still exists (heartbeat alive)
                detail_key = f"{self.WORKER_DETAIL_PREFIX}:{wid}"
                detail = self._redis.hgetall(detail_key)

                if detail:
                    workers.append(WorkerInfo.from_dict(detail))
                else:
                    # Worker expired, clean up registry
                    self._redis.hdel(self.WORKER_REGISTRY, wid)

        except Exception as e:
            logger.error(f"Failed to get workers: {e}")

        return workers

    def get_worker_count(self, worker_type: Optional[str] = None) -> int:
        """Get count of available workers."""
        return len(self.get_available_workers(worker_type))

    # =========================================================================
    # MESSAGING
    # =========================================================================

    def send_message(
        self,
        worker_id: str,
        message: Dict
    ) -> bool:
        """
        Send a direct message to a specific worker.

        Args:
            worker_id: Target worker
            message: Message dict (must include 'type' and 'content')

        Returns:
            True if sent
        """
        if not self.is_connected:
            return False

        try:
            msg_key = f"{self.MESSAGE_PREFIX}:{worker_id}"
            envelope = {
                "from": "aiva_mother",
                "to": worker_id,
                "timestamp": datetime.now().isoformat(),
                "message": message
            }
            self._redis.lpush(msg_key, json.dumps(envelope))
            # Set reasonable TTL on the message queue
            self._redis.expire(msg_key, 3600)

            logger.debug(f"Message sent to {worker_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to send message to {worker_id}: {e}")
            return False

    def broadcast(
        self,
        message: Dict,
        worker_type: Optional[str] = None
    ) -> int:
        """
        Broadcast a message to all workers of a given type (or all workers).

        Uses Redis pub/sub for real-time delivery.

        Args:
            message: Message dict
            worker_type: Target type, or None for all workers

        Returns:
            Number of subscribers that received the message
        """
        if not self.is_connected:
            return 0

        try:
            envelope = {
                "from": "aiva_mother",
                "timestamp": datetime.now().isoformat(),
                "target_type": worker_type or "all",
                "message": message
            }

            if worker_type:
                channel = f"{self.BROADCAST_PREFIX}:{worker_type}"
            else:
                channel = f"{self.BROADCAST_PREFIX}:all"

            recipients = self._redis.publish(channel, json.dumps(envelope))
            logger.info(f"Broadcast to {channel}: {recipients} recipients")
            return recipients

        except Exception as e:
            logger.error(f"Broadcast failed: {e}")
            return 0

    def receive_messages(
        self,
        worker_id: str,
        timeout: int = 5,
        max_messages: int = 10
    ) -> List[Dict]:
        """
        Receive pending messages for a worker (called by workers).

        Args:
            worker_id: Worker's ID
            timeout: Max wait time if no messages
            max_messages: Max messages to retrieve

        Returns:
            List of message dicts
        """
        if not self.is_connected:
            return []

        messages = []
        msg_key = f"{self.MESSAGE_PREFIX}:{worker_id}"

        try:
            for _ in range(max_messages):
                raw = self._redis.rpop(msg_key)
                if raw is None:
                    break
                messages.append(json.loads(raw))

        except Exception as e:
            logger.error(f"Failed to receive messages for {worker_id}: {e}")

        return messages

    # =========================================================================
    # QUEUE STATUS
    # =========================================================================

    def get_queue_depths(self) -> Dict[str, int]:
        """
        Get the depth of each worker-type task queue.

        Returns:
            Dict mapping worker_type -> queue depth
        """
        if not self.is_connected:
            return {}

        depths = {}
        for wtype in WorkerType:
            queue_key = f"{self.TASK_QUEUE_PREFIX}:{wtype.value}"
            try:
                depth = self._redis.llen(queue_key)
                if depth > 0:
                    depths[wtype.value] = depth
            except Exception:
                pass

        return depths

    def get_status(self) -> Dict[str, Any]:
        """
        Get overall swarm liaison status.

        Returns:
            Dict with connection status, queue depths, worker counts
        """
        status = {
            "connected": self.is_connected,
            "timestamp": datetime.now().isoformat()
        }

        if self.is_connected:
            status["queue_depths"] = self.get_queue_depths()
            status["workers"] = {}
            for wtype in WorkerType:
                count = self.get_worker_count(wtype.value)
                if count > 0:
                    status["workers"][wtype.value] = count
            status["total_workers"] = sum(status["workers"].values())

        return status

    # =========================================================================
    # CLEANUP
    # =========================================================================

    def close(self):
        """Close Redis connection."""
        if self._redis:
            try:
                self._redis.close()
            except Exception:
                pass
            self._connected = False
            logger.info("Swarm Liaison closed")


# Module-level singleton
_liaison_instance: Optional[SwarmLiaison] = None


def get_swarm_liaison() -> SwarmLiaison:
    """Return the shared SwarmLiaison singleton, creating it on first call."""
    global _liaison_instance
    if _liaison_instance is not None:
        return _liaison_instance
    _liaison_instance = SwarmLiaison()
    return _liaison_instance
