"""
core/coherence/swarm_worker_base.py

SwarmWorkerBase — ABC that claims tasks from Redis Stream via XREADGROUP
and sends XACK after successful completion.

Workers subclass this and implement `process(task: dict) -> Optional[dict]`.
The base class handles all stream mechanics:

  * XREADGROUP GROUP {group} {consumer_id} COUNT 1 BLOCK 5000
      — receives exactly one new message at a time (">")
  * XACK on success — delivery is at-least-once, so workers must be
      idempotent (the same task may be processed more than once)
  * No XACK on failure — task re-enters the Pending Entry List (PEL)
      and will be re-claimed after PEL_TIMEOUT_MS (60 s)
  * XAUTOCLAIM on startup — recovers orphaned pending entries from crashed workers

Usage::

    class MyWorker(SwarmWorkerBase):
        async def process(self, task: dict) -> Optional[dict]:
            # ... do work ...
            return {"result": "ok"}

    worker = MyWorker(redis_client, staging_area=staging)
    await worker.run_worker_loop(group="genesis_workers", consumer_id="worker-1")

# VERIFICATION_STAMP
# Story: 6.04
# Verified By: parallel-builder
# Verified At: 2026-02-25
# Tests: 15/15
# Coverage: 100%
"""

from __future__ import annotations

import asyncio
import logging
import time
from abc import ABC, abstractmethod
from typing import Optional

from core.coherence.task_dag_pusher import STREAM_KEY, DEFAULT_GROUP

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Minimum idle time (milliseconds) a pending entry must reach before
# XAUTOCLAIM may transfer it to another consumer — 60 seconds.
PEL_TIMEOUT_MS: int = 60_000


# ---------------------------------------------------------------------------
# Main ABC
# ---------------------------------------------------------------------------


class SwarmWorkerBase(ABC):
    """
    Base class for all Gemini swarm workers.

    Claims tasks from a Redis Stream via XREADGROUP with Consumer Groups.
    Delivery is at-least-once: an entry is XACK'd only after ``process``
    (and the optional staging hand-off) succeeded, so the same task can be
    processed more than once and ``process`` should be idempotent.

    Subclasses must implement :meth:`process`.

    Consumer Group mechanics
    ------------------------
    * Each worker instance calls ``run_worker_loop`` with a unique
      ``consumer_id`` (e.g. ``"worker-1"``, ``"worker-2"``).
    * Redis delivers each new message to exactly ONE consumer in the group.
    * Un-ACK'd messages stay in the Pending Entry List (PEL).
    * A recovery pass runs on startup and then about every
      ``PEL_TIMEOUT_MS``: :meth:`_reclaim_pending` uses ``XAUTOCLAIM`` to
      take over entries idle longer than ``PEL_TIMEOUT_MS`` (e.g. orphaned
      by crashed workers), then :meth:`_drain_own_pending` re-processes
      every entry pending on this consumer — reclaimed entries as well as
      this worker's own earlier failures.

    Args:
        redis_client: Async Redis client supporting:
                      xreadgroup, xack, xautoclaim
        staging_area: Optional StagingArea — when provided and ``process``
                      returns a non-None result, the result is forwarded via
                      ``staging_area.submit_delta(result)``.
    """

    # Seconds to wait after a failed XREADGROUP before retrying, so a
    # persistent error (Redis down, missing group) does not busy-spin the
    # event loop. Subclasses/tests may override.
    _read_error_backoff_s: float = 1.0

    # Max entries fetched per XREADGROUP call when re-reading this
    # consumer's own PEL during a recovery pass.
    _pending_batch_size: int = 100

    def __init__(self, redis_client, staging_area=None) -> None:
        self.redis = redis_client
        self.staging = staging_area
        self._running: bool = False

    # ------------------------------------------------------------------
    # Abstract interface — subclasses implement this
    # ------------------------------------------------------------------

    @abstractmethod
    async def process(self, task: dict) -> Optional[dict]:
        """
        Process a single task claimed from the stream.

        Args:
            task: Dict of task fields decoded from the Redis Stream entry.
                  Keys and values are ``str`` when Redis returned bytes
                  (bytes are decoded); other value types pass through.

        Returns:
            Result dict on success, or ``None``.
            When a non-None result is returned AND a ``staging_area`` was
            provided at construction time, the result is forwarded to
            ``staging_area.submit_delta(result)``.

        Raises:
            Any exception signals processing failure.
            The task entry is NOT ACK'd; it stays pending and is retried on
            a later recovery pass (roughly every ``PEL_TIMEOUT_MS``).
        """
        ...  # pragma: no cover

    # ------------------------------------------------------------------
    # Main worker loop
    # ------------------------------------------------------------------

    async def run_worker_loop(
        self,
        group: str = DEFAULT_GROUP,
        consumer_id: str = "worker-1",
    ) -> None:
        """
        Main loop: continuously claims and processes tasks from the stream.

        Flow per iteration::

            (every ~PEL_TIMEOUT_MS) recovery pass: XAUTOCLAIM + re-read own PEL
            XREADGROUP GROUP {group} {consumer_id}
                       COUNT 1 BLOCK 5000
                       STREAMS genesis:swarm:tasks >
              → process(task)
              → on success:  XACK genesis:swarm:tasks {group} {entry_id}
              → on failure:  log error, no XACK (retried by a later
                             recovery pass)

        A recovery pass also runs once before the first read, so orphaned
        entries from crashed consumers are processed on startup.

        Args:
            group:       Consumer Group name (default: ``"genesis_workers"``).
            consumer_id: Unique name for this consumer instance
                         (default: ``"worker-1"``).
        """
        self._running = True
        recovery_interval_s = PEL_TIMEOUT_MS / 1000.0

        # Recover orphaned pending entries (and our own leftovers) first
        await self._recover_pending(group, consumer_id)
        last_recovery = time.monotonic()

        while self._running:
            # Periodic recovery: without it a failed task would stay in the
            # PEL until some worker happens to restart.
            if time.monotonic() - last_recovery >= recovery_interval_s:
                await self._recover_pending(group, consumer_id)
                last_recovery = time.monotonic()

            # ----------------------------------------------------------------
            # Claim one new message (">") — block up to 5 s for new entries
            # ----------------------------------------------------------------
            try:
                entries = await self.redis.xreadgroup(
                    group,
                    consumer_id,
                    {STREAM_KEY: ">"},
                    count=1,
                    block=5000,
                )
            except Exception as exc:
                logger.error("SwarmWorkerBase: xreadgroup failed: %s", exc)
                # Back off instead of immediately retrying a failing call
                await asyncio.sleep(self._read_error_backoff_s)
                continue

            if not entries:
                # Timeout with no new messages — loop again
                continue

            for _stream_name, messages in entries:
                for entry_id, fields in messages or ():
                    await self._handle_entry(group, entry_id, fields)

    # ------------------------------------------------------------------
    # Per-entry handling (shared by the main loop and recovery)
    # ------------------------------------------------------------------

    async def _handle_entry(self, group: str, entry_id, fields: Optional[dict]) -> bool:
        """
        Decode, process and ACK a single stream entry.

        Exceptions raised by ``process``, the staging hand-off or XACK are
        logged rather than propagated, so one bad task cannot stop the loop.

        Args:
            group:    Consumer Group name used for XACK.
            entry_id: Stream entry ID (``bytes`` or ``str``).
            fields:   Raw field mapping as returned by the client; an empty
                      or ``None`` mapping marks an entry that was deleted
                      from the stream while still pending.

        Returns:
            ``True`` when the entry was ACK'd, ``False`` when it was left
            pending.
        """
        # Normalise entry_id to str once
        eid = entry_id.decode() if isinstance(entry_id, bytes) else entry_id

        if not fields:
            # XADD always stores at least one field, so an empty payload means
            # the entry was deleted (XDEL/XTRIM) while pending. ACK it so it
            # is not re-delivered on every recovery pass.
            try:
                await self.redis.xack(STREAM_KEY, group, eid)
            except Exception as exc:
                logger.error(
                    "SwarmWorkerBase: xack failed for deleted entry %s: %s", eid, exc
                )
                return False
            logger.warning("SwarmWorkerBase: ACK'd deleted pending entry %s", eid)
            return True

        # Decode bytes → str (Redis may return bytes)
        task: dict = {
            (k.decode() if isinstance(k, bytes) else k): (
                v.decode() if isinstance(v, bytes) else v
            )
            for k, v in fields.items()
        }

        try:
            result = await self.process(task)
            # Forward result to staging area if configured
            if self.staging is not None and result is not None:
                await self.staging.submit_delta(result)
        except Exception as exc:
            # Do NOT ACK — task stays in the PEL for a later recovery pass
            logger.error(
                "SwarmWorkerBase: process() failed for entry %s: %s",
                eid,
                exc,
                exc_info=True,
            )
            return False

        try:
            # Acknowledge: task is done, remove from PEL
            await self.redis.xack(STREAM_KEY, group, eid)
        except Exception as exc:
            # Work is done but un-ACK'd: it will be re-processed later
            # (at-least-once), hence the idempotency requirement.
            logger.error("SwarmWorkerBase: xack failed for entry %s: %s", eid, exc)
            return False

        logger.debug("SwarmWorkerBase: processed and ACK'd entry %s", eid)
        return True

    # ------------------------------------------------------------------
    # Pending Entry List recovery
    # ------------------------------------------------------------------

    async def _recover_pending(self, group: str, consumer_id: str) -> None:
        """Claim stale entries from other consumers, then re-process our own PEL."""
        await self._reclaim_pending(group, consumer_id)
        await self._drain_own_pending(group, consumer_id)

    async def _drain_own_pending(self, group: str, consumer_id: str) -> int:
        """
        Re-process every entry currently pending on ``consumer_id``.

        XREADGROUP with an explicit ID (instead of ``">"``) returns this
        consumer's own pending entries with IDs greater than that ID. The
        cursor advances past every returned entry, so an entry that fails
        again stays pending for the next recovery pass instead of being
        retried in a tight loop here. Stops early if :meth:`stop` is called.

        Args:
            group:       Consumer Group name.
            consumer_id: Name of this consumer.

        Returns:
            Number of pending entries ACK'd during this drain.
        """
        cursor = "0"
        acked = 0
        while self._running:
            try:
                entries = await self.redis.xreadgroup(
                    group,
                    consumer_id,
                    {STREAM_KEY: cursor},
                    count=self._pending_batch_size,
                )
                messages = entries[0][1] if entries else None
            except Exception as exc:
                logger.warning(
                    "SwarmWorkerBase: reading pending entries failed: %s", exc
                )
                break

            if not messages:
                break

            previous_cursor = cursor
            for entry_id, fields in messages:
                if entry_id is None:
                    continue
                cursor = entry_id.decode() if isinstance(entry_id, bytes) else entry_id
                if await self._handle_entry(group, cursor, fields):
                    acked += 1

            # Guard against a reply that does not advance the cursor
            if cursor == previous_cursor:
                break

        return acked

    async def _reclaim_pending(self, group: str, consumer_id: str) -> int:
        """
        Claim orphaned pending entries older than ``PEL_TIMEOUT_MS``.

        Uses ``XAUTOCLAIM`` to transfer stale pending entries (e.g. from
        crashed workers) to this consumer. Claiming only changes ownership;
        the entries are processed afterwards by :meth:`_drain_own_pending`.

        XAUTOCLAIM semantics::

            XAUTOCLAIM genesis:swarm:tasks {group} {consumer_id}
                       {PEL_TIMEOUT_MS} 0-0

        Args:
            group:       Consumer Group name.
            consumer_id: Name of this consumer (entries transferred to it).

        Returns:
            Number of claimed entries (0 on failure or nothing to claim).
        """
        try:
            result = await self.redis.xautoclaim(
                STREAM_KEY,
                group,
                consumer_id,
                min_idle_time=PEL_TIMEOUT_MS,
                start_id="0-0",
            )
            # xautoclaim returns (next_start_id, [(entry_id, fields), ...], [deleted_ids])
            # Some client versions return only (next_start_id, entries).
            if result and len(result) >= 2:
                reclaimed = result[1] if result[1] is not None else []
                count = len(reclaimed)
                if count > 0:
                    logger.info(
                        "SwarmWorkerBase: reclaimed %d orphaned pending entries",
                        count,
                    )
                return count
            return 0

        except Exception as exc:
            logger.warning("SwarmWorkerBase: xautoclaim failed: %s", exc)
            return 0

    # ------------------------------------------------------------------
    # Graceful shutdown
    # ------------------------------------------------------------------

    def stop(self) -> None:
        """
        Signal graceful shutdown.

        Sets ``_running = False``. The worker loop exits after the current
        iteration completes (including the current BLOCK wait); an ongoing
        pending-entry drain stops before fetching its next batch.
        """
        self._running = False
        logger.info("SwarmWorkerBase: stop() called — will exit after current iteration")
