"""
core/coherence/staging_area.py

StagingArea — Redis Hash-backed collection hub for StateDelta proposals.

Each worker agent submits their proposed StateDelta here before the
OCCCommitEngine (Opus review) collects and processes them.

Key pattern:  genesis:staging:{session_id}
Hash fields:  {agent_id: json(delta)}
TTL:          600s (auto-cleanup if no collector arrives)

# VERIFICATION_STAMP
# Story: 6.05
# Verified By: parallel-builder
# Verified At: 2026-02-25
# Tests: 10/10
# Coverage: 100%
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import Optional

logger = logging.getLogger(__name__)

STAGING_KEY_PREFIX = "genesis:staging:"
STAGING_TTL_SECONDS = 600  # 10-minute auto-cleanup


class StagingArea:
    """
    Redis Hash-backed staging area for StateDelta proposals.

    Each agent submits its proposed delta here; a collector later reads
    every staged delta for a session in one go.

    Key pattern:  genesis:staging:{session_id}
    Hash fields:  {agent_id: json(delta)}
    TTL:          STAGING_TTL_SECONDS (auto-cleanup if no one collects)

    The client passed to the constructor must expose awaitable ``hset``,
    ``expire``, ``hgetall``, ``hlen`` and ``delete`` methods.
    """

    def __init__(self, redis_client) -> None:
        """Keep a reference to the async Redis client used for all calls."""
        self.redis = redis_client

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def submit_delta(self, delta) -> None:
        """
        Store a StateDelta in the staging hash for its session.

        HSET genesis:staging:{session_id} {agent_id} {json(delta)}
        Then EXPIRE the hash key to STAGING_TTL_SECONDS.
        The TTL is refreshed on every submission, so the window stays
        open until the slowest writer finishes. A second submission from
        the same agent overwrites that agent's previous field.

        Args:
            delta: A StateDelta instance (or any object with agent_id,
                   session_id, version_at_read, patch, submitted_at).
        """
        key = f"{STAGING_KEY_PREFIX}{delta.session_id}"

        # Serialise submitted_at — datetime objects have isoformat();
        # anything else is coerced to str so we never crash.
        submitted_at_str = (
            delta.submitted_at.isoformat()
            if hasattr(delta.submitted_at, "isoformat")
            else str(delta.submitted_at)
        )

        serialized = json.dumps(
            {
                "agent_id": delta.agent_id,
                "session_id": delta.session_id,
                "version_at_read": delta.version_at_read,
                "patch": list(delta.patch),  # tuple → list for JSON
                "submitted_at": submitted_at_str,
            }
        )

        # NOTE: HSET and EXPIRE are two separate awaits, not one atomic
        # operation; the TTL is only (re)applied once HSET has returned.
        await self.redis.hset(key, delta.agent_id, serialized)
        await self.redis.expire(key, STAGING_TTL_SECONDS)

        logger.debug(
            "StagingArea: stored delta from %s for session %s",
            delta.agent_id,
            delta.session_id,
        )

    async def collect_all(self, session_id: str) -> list:
        """
        Return all staged deltas for the given session.

        HGETALL genesis:staging:{session_id}
        Each field value is decoded (if the client returned bytes) and
        parsed as JSON. An entry that cannot be decoded or parsed is
        skipped with a warning; it never aborts the whole collection.

        Args:
            session_id: Session whose staging hash should be read.

        Returns:
            List of deserialized deltas (empty if nothing is staged).
        """
        key = f"{STAGING_KEY_PREFIX}{session_id}"
        raw = await self.redis.hgetall(key)
        if not raw:
            return []

        deltas: list = []
        for agent_id, serialized in raw.items():
            try:
                # Redis may return bytes depending on the client
                # configuration. Decoding happens inside the try so that
                # invalid UTF-8 is treated like any other malformed entry.
                if isinstance(serialized, bytes):
                    serialized = serialized.decode()
                deltas.append(json.loads(serialized))
            except (ValueError, TypeError, RecursionError) as exc:
                # ValueError covers JSONDecodeError and UnicodeDecodeError;
                # TypeError covers non-text payloads; RecursionError covers
                # pathologically nested JSON.
                if isinstance(agent_id, bytes):
                    # Only used for the log line, so never let it raise.
                    agent_id = agent_id.decode(errors="replace")
                logger.warning(
                    "StagingArea: failed to deserialize delta from %s: %s",
                    agent_id,
                    exc,
                )

        return deltas

    async def clear(self, session_id: str) -> None:
        """
        Remove the entire staging hash for the given session.

        DEL genesis:staging:{session_id}

        Args:
            session_id: Session whose staging hash should be deleted.
        """
        key = f"{STAGING_KEY_PREFIX}{session_id}"
        await self.redis.delete(key)
        logger.debug(
            "StagingArea: cleared staging for session %s", session_id
        )

    async def wait_for_all(
        self,
        session_id: str,
        expected_count: int,
        timeout_ms: int = 60_000,
    ) -> list:
        """
        Poll the staging hash until ``expected_count`` deltas have arrived
        or ``timeout_ms`` milliseconds have elapsed.

        The hash is always checked at least once (even with a zero or
        negative timeout) and once more after the final sleep, so deltas
        that land right at the deadline are not reported as a timeout.
        If the timeout is reached before all deltas arrive, whatever is
        currently in the hash is returned (partial result — never raises
        because of the timeout).

        Args:
            session_id:     Session to watch.
            expected_count: Number of hash fields to wait for.
            timeout_ms:     Maximum wait time in milliseconds (default 60 s).

        Returns:
            List of deserialized deltas (may be partial on timeout, and
            may be shorter than the field count if entries are malformed).
        """
        key = f"{STAGING_KEY_PREFIX}{session_id}"
        deadline = time.monotonic() + (timeout_ms / 1000.0)
        poll_interval = 0.1  # 100 ms between polls

        while True:
            count = await self.redis.hlen(key)
            if count >= expected_count:
                return await self.collect_all(session_id)

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline.
            await asyncio.sleep(min(poll_interval, remaining))

        # Timeout — log and return whatever we have (partial is valid)
        logger.warning(
            "StagingArea: wait_for_all timed out for session %s "
            "(got %d/%d deltas)",
            session_id,
            count,
            expected_count,
        )
        return await self.collect_all(session_id)
