"""RLMCaptureAgent — async agent loop for capturing call transcripts.

Story 3.05 — AIVA RLM Nexus PRD v2 — Track A
Story 3.06 — Transcript Fetch + Speaker Labeling (extends 3.05)
"""
import asyncio
import json
import logging
import os
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import httpx

logger = logging.getLogger(__name__)

TELNYX_API_BASE = "https://api.telnyx.com/v2/ai/assistants"

EVENTS_DIR = Path("/mnt/e/genesis-system/data/observability")
MAX_RUNTIME_SECONDS = 3600  # Safety guard against zombie agents
POLL_INTERVAL_SECONDS = 30  # Check every 30s


class RLMCaptureAgent:
    """
    Async capture loop that polls for call-ended signal.
    Designed to run as a background asyncio task during active calls.

    The agent polls Redis every POLL_INTERVAL_SECONDS seconds looking for
    the key ``aiva:state:{session_id}`` to contain ``{"status": "ended"}``.
    When detected — or when MAX_RUNTIME_SECONDS elapses — the loop exits
    cleanly and logs a stop event.  Any Redis failure is treated as
    fail-safe: the agent keeps running (returns False from _is_call_ended).
    """

    def __init__(
        self,
        session_id: str,
        call_control_id: str,
        redis_client=None,
        call_direction: str = "inbound",
    ) -> None:
        self.session_id = session_id
        self.call_control_id = call_control_id
        self._redis = redis_client
        self.call_direction = call_direction
        self._running: bool = True
        self._started_at: Optional[float] = None
        self._consecutive_failures: int = 0
        self._chunk_index: int = 0
        # Track the last raw transcript text to detect new content
        self._last_transcript_text: str = ""

    async def run(self) -> None:
        """
        Main loop.  Runs until call ended or max runtime exceeded.

        - Polls every POLL_INTERVAL_SECONDS (30 s)
        - Checks Redis ``aiva:state:{session_id}`` for ``{"status": "ended"}``
          to stop
        - Max runtime: MAX_RUNTIME_SECONDS (3600 s) — zombie guard
        - Returns normally (no exception) after stopping
        """
        self._started_at = time.monotonic()
        self._log_event("agent_started")

        while self._running:
            # Zombie guard: hard cap on total runtime
            elapsed = time.monotonic() - self._started_at
            if elapsed >= MAX_RUNTIME_SECONDS:
                self._log_event("zombie_guard_triggered", {"elapsed_seconds": elapsed})
                break

            # Poll Redis for call-ended signal
            if await self._is_call_ended():
                self._log_event("call_ended_detected")
                break

            # Capture transcript chunk for this poll cycle
            await self._capture_cycle()

            # Wait for the next poll cycle
            await asyncio.sleep(POLL_INTERVAL_SECONDS)

        self._running = False
        self._log_event(
            "agent_stopped",
            {"total_seconds": time.monotonic() - self._started_at},
        )

    async def _is_call_ended(self) -> bool:
        """
        Returns True if Redis state shows 'ended'.
        Returns False if key is missing (call still active) or on any error.
        """
        if not self._redis:
            return False
        try:
            key = f"aiva:state:{self.session_id}"
            raw = await self._redis.get(key)
            if not raw:
                return False  # Key absent → call still active
            state = json.loads(raw)
            return state.get("status") == "ended"
        except Exception:
            # Redis failure → fail-safe: keep running
            return False

    async def stop(self) -> None:
        """Graceful shutdown.  Sets _running = False.  Idempotent."""
        self._running = False

    async def _fetch_and_label_chunk(self) -> Optional[dict]:
        """
        Fetches the latest transcript from the Telnyx AI Assistant transcript API
        and labels the speaker based on the call direction.

        Speaker labels:
            "AIVA"     — AI assistant turns (role == "assistant")
            "KINAN"    — human turns when call_direction == "outbound"
            "CUSTOMER" — human turns when call_direction != "outbound" (inbound)

        Returns:
            {
                "t": float,           # unix timestamp of fetch
                "speaker": str,       # "AIVA" | "KINAN" | "CUSTOMER"
                "text": str,          # transcript text
                "chunk_index": int,   # monotonically increasing counter
            }
            Or None if no new content was available since the last fetch.

        This method NEVER raises — all errors are caught and logged internally.
        """
        api_key = os.environ.get("TELNYX_API_KEY", "")
        url = (
            f"{TELNYX_API_BASE}/{self.session_id}"
            f"/transcript?call_control_id={self.call_control_id}"
        )
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(url, headers=headers)
                response.raise_for_status()
                payload = response.json()
        except httpx.TimeoutException:
            self._log_event("fetch_timeout", {"url": url})
            return None
        except httpx.HTTPStatusError as exc:
            self._log_event(
                "fetch_http_error",
                {"status_code": exc.response.status_code, "url": url},
            )
            return None
        except Exception:
            self._log_event("fetch_error", {"url": url})
            return None

        # Extract transcript text from the Telnyx response.
        # Telnyx returns: {"data": {"transcript": "...", "messages": [...]}}
        # We support both shapes gracefully.
        try:
            data = payload.get("data", payload)

            # Prefer the structured messages array (role + content)
            messages = data.get("messages") if isinstance(data, dict) else None
            if messages and isinstance(messages, list) and len(messages) > 0:
                # Take the most recent message
                latest = messages[-1]
                role = latest.get("role", "user")
                text = latest.get("content", "")
            else:
                # Fall back to flat transcript string
                transcript_text = (
                    data.get("transcript", "") if isinstance(data, dict) else ""
                )
                if not transcript_text:
                    return None
                # Determine role from transcript text heuristic or default
                role = "user"
                text = transcript_text

            if not text:
                return None

            # Deduplicate: skip if text hasn't changed since last fetch
            if text == self._last_transcript_text:
                return None

            self._last_transcript_text = text

            # Label the speaker
            if role == "assistant":
                speaker = "AIVA"
            elif self.call_direction == "outbound":
                speaker = "KINAN"
            else:
                speaker = "CUSTOMER"

            chunk = {
                "t": time.time(),
                "speaker": speaker,
                "text": text,
                "chunk_index": self._chunk_index,
            }
            self._chunk_index += 1
            return chunk

        except Exception:
            self._log_event("fetch_parse_error")
            return None

    async def _capture_cycle(self) -> None:
        """
        One capture cycle called every 30 s from the run() loop.

        1. Call _fetch_and_label_chunk() to get the latest transcript chunk.
        2. If a chunk is returned: RPUSH it (as JSON) to
           ``aiva:transcript:{session_id}`` in Redis.
        3. If _fetch_and_label_chunk() returns None: no-op (normal — no new content).
        4. Any exception (from fetch OR from Redis RPUSH) is caught, logged as a
           warning, and the agent continues running — NEVER raises.
        5. After 3 consecutive failed cycles: log at ERROR level (escalation signal).
        6. A successful cycle resets the consecutive-failure counter to 0.
        """
        try:
            chunk = await self._fetch_and_label_chunk()
        except Exception as exc:
            # _fetch_and_label_chunk should never raise, but guard anyway
            self._consecutive_failures += 1
            logger.warning(
                "rlm_capture: _fetch_and_label_chunk raised unexpectedly "
                "(session=%s, failures=%d): %s",
                self.session_id,
                self._consecutive_failures,
                exc,
            )
            if self._consecutive_failures >= 3:
                logger.error(
                    "rlm_capture: 3+ consecutive failures (session=%s, failures=%d) — "
                    "escalation signal",
                    self.session_id,
                    self._consecutive_failures,
                )
            return

        if chunk is None:
            # Normal: no new transcript content since last poll
            return

        # Attempt Redis RPUSH
        redis_key = f"aiva:transcript:{self.session_id}"
        try:
            await self._redis.rpush(redis_key, json.dumps(chunk))
            # Success — reset failure counter
            self._consecutive_failures = 0
        except Exception as exc:
            self._consecutive_failures += 1
            logger.warning(
                "rlm_capture: Redis RPUSH failed for key=%s "
                "(session=%s, failures=%d): %s",
                redis_key,
                self.session_id,
                self._consecutive_failures,
                exc,
            )
            if self._consecutive_failures >= 3:
                logger.error(
                    "rlm_capture: 3+ consecutive failures (session=%s, failures=%d) — "
                    "escalation signal",
                    self.session_id,
                    self._consecutive_failures,
                )

    def _log_event(self, event_type: str, extra: Optional[dict] = None) -> None:
        """Append a JSONL event to the observability log.  Never raises."""
        try:
            EVENTS_DIR.mkdir(parents=True, exist_ok=True)
            event: dict = {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "event_type": f"rlm_capture_{event_type}",
                "session_id": self.session_id,
                "call_control_id": self.call_control_id,
            }
            if extra:
                event.update(extra)
            with open(EVENTS_DIR / "events.jsonl", "a") as f:
                f.write(json.dumps(event) + "\n")
        except Exception:
            pass  # Observability must never break the agent


# VERIFICATION_STAMP
# Story: 3.07
# Verified By: parallel-builder
# Verified At: 2026-02-25
# Tests: 26/26 (test_story_3_07.py) + 53/53 regression (3.05 + 3.06)
# Coverage: 100%
# Prior stamps — 3.06: 26/26; 3.05: 11/11
