"""
Retriever Orchestrator — fan-out to all registered retrievers, merge, rank.
PRD: _bmad-output/RLM_NERVOUS_SYSTEM_PRD.md (Story 2.5)
"""

from __future__ import annotations

import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed, wait
from typing import Dict, List

from core.nervous_system.contracts import (
    IRetriever,
    RetrievalRequest,
    RetrievalResult,
    RetrievedChunk,
)

logger = logging.getLogger("nervous_system.orchestrator")


class RetrieverOrchestrator:
    """Fans out queries to registered retrievers, merges and ranks results.

    Retrievers are registered under their ``source_name`` and queried in
    parallel against a shared wall-clock budget (``timeout_ms``). Chunks
    from all successful sources are deduplicated and ranked by
    ``relevance_score`` before being returned.
    """

    def __init__(self, timeout_ms: int = 2000):
        # source_name -> retriever backend
        self._retrievers: Dict[str, IRetriever] = {}
        # Shared wall-clock budget for one whole fan-out, in seconds.
        self._timeout_s = timeout_ms / 1000.0

    def register_retriever(self, retriever: IRetriever) -> None:
        """Register a retriever backend, keyed by its ``source_name``.

        Registering a second retriever with the same ``source_name``
        silently replaces the first.
        """
        self._retrievers[retriever.source_name] = retriever

    def query(self, request: RetrievalRequest) -> RetrievalResult:
        """Execute retrieval across all matching sources in parallel.

        A registered retriever participates when its name starts with any
        entry in ``request.sources`` (prefix match). Retrievers that raise
        or exceed the timeout are reported in ``sources_failed``; the rest
        contribute chunks that are merged, deduplicated, and ranked down
        to ``request.top_k``.
        """
        start = time.monotonic()
        all_chunks: List[RetrievedChunk] = []
        sources_queried: List[str] = []
        sources_failed: List[str] = []

        # Filter to the sources the caller asked for (prefix match, so
        # e.g. "mem" would select a retriever named "memory_store").
        active = {
            name: ret for name, ret in self._retrievers.items()
            if any(name.startswith(s) for s in request.sources)
        }

        if not active:
            # Nothing matched the request: preserve the existing contract
            # of reporting every registered source as failed.
            return RetrievalResult(
                chunks=[], latency_ms=0,
                sources_queried=[], sources_failed=list(self._retrievers.keys()),
            )

        def _fetch(name: str, retriever: IRetriever) -> tuple:
            # Runs on a worker thread. Never lets an exception escape so
            # every outcome can be attributed to a source by name.
            try:
                return (name, retriever.retrieve(request), None)
            except Exception as exc:
                return (name, [], exc)

        pool = ThreadPoolExecutor(max_workers=len(active))
        try:
            futures = {
                pool.submit(_fetch, name, ret): name
                for name, ret in active.items()
            }
            # wait() never raises on timeout — unlike as_completed(), whose
            # pre-3.11 concurrent.futures.TimeoutError is NOT the builtin
            # TimeoutError and so would slip past an `except TimeoutError`.
            # It also hands back every already-finished future, so results
            # that completed just before the deadline are never dropped.
            done, not_done = wait(futures, timeout=self._timeout_s)

            for future in done:
                name, chunks, error = future.result()
                if error is not None:
                    # Lazy %-args: only rendered if the record is emitted.
                    logger.warning("Retriever %s failed: %s", name, error)
                    sources_failed.append(name)
                else:
                    sources_queried.append(name)
                    all_chunks.extend(chunks)

            if not_done:
                logger.warning("Retrieval timed out after %ss", self._timeout_s)
            for future in not_done:
                sources_failed.append(futures[future])
                # Cancels futures still queued; already-running ones are
                # unaffected and finish on their worker threads.
                future.cancel()
        finally:
            # wait=False so a slow retriever cannot hold the response past
            # the timeout budget (a `with` block would join all workers).
            pool.shutdown(wait=False)

        ranked = self._merge_and_rank(all_chunks, request.top_k)
        elapsed_ms = (time.monotonic() - start) * 1000

        return RetrievalResult(
            chunks=ranked,
            latency_ms=round(elapsed_ms, 1),
            sources_queried=sources_queried,
            sources_failed=sources_failed,
        )

    def health(self) -> Dict[str, bool]:
        """Return the health status of every registered retriever."""
        return {name: ret.health_check() for name, ret in self._retrievers.items()}

    @staticmethod
    def _merge_and_rank(
        chunks: List[RetrievedChunk], top_k: int
    ) -> List[RetrievedChunk]:
        """Deduplicate by content prefix, rank by relevance_score, take top_k.

        The first occurrence of a duplicate wins (its score is kept). The
        200-char prefix string itself is the dedup key — hash() of it would
        risk silently merging distinct chunks on a hash collision.
        """
        seen: set = set()
        unique: List[RetrievedChunk] = []
        for chunk in chunks:
            key = chunk.content[:200]
            if key in seen:
                continue
            seen.add(key)
            unique.append(chunk)
        unique.sort(key=lambda c: c.relevance_score, reverse=True)
        return unique[:top_k]
