"""
core/evolution/scar_aggregator.py

Story 8.05: ScarAggregator — L3 Failure Pattern Collector

Collects, deduplicates, and clusters failure patterns from Qdrant L3
(genesis_scars collection). Produces ScarReport with clustered analysis
of recurring failure modes to feed the Nightly Epoch.

VERIFICATION_STAMP
Story: 8.05
Verified By: parallel-builder (claude-sonnet-4-6)
Verified At: 2026-02-25
Tests: 11/11
Coverage: 100%
"""

from __future__ import annotations

import json
import math
import os
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any


# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------


@dataclass
class ScarCluster:
    """A cluster of similar failure scars grouped by cosine similarity."""

    cluster_id: str
    representative_scar: str  # Scar text of the highest-severity member
    member_count: int
    avg_severity: float


@dataclass
class ScarReport:
    """Aggregated report of all scar clusters from the lookback window."""

    total_scars: int
    clusters: list[ScarCluster] = field(default_factory=list)
    new_since_last_epoch: int = 0


# ---------------------------------------------------------------------------
# Internal types for processing
# ---------------------------------------------------------------------------


@dataclass
class _ScarRecord:
    """Internal representation of a scar retrieved from Qdrant."""

    scar_id: str
    text: str
    severity: float
    timestamp: str  # ISO 8601 string
    vector: list[float]


# ---------------------------------------------------------------------------
# ScarAggregator
# ---------------------------------------------------------------------------

# Default log path (overridable in tests via tmp_path injection)
_DEFAULT_LOG_PATH = "/mnt/e/genesis-system/data/observability/scar_aggregation_log.jsonl"

# Cosine similarity threshold — scars >= this value join the same cluster
CLUSTER_THRESHOLD = 0.85


class ScarAggregator:
    """
    Collects, deduplicates, and clusters failure scars from Qdrant L3.

    Usage
    -----
    agg = ScarAggregator(qdrant_client=client, last_epoch_timestamp="2026-02-24T02:00:00Z")
    report = agg.aggregate(lookback_days=7)
    top = agg.get_top_clusters(n=5)
    """

    def __init__(
        self,
        qdrant_client: Any = None,
        last_epoch_timestamp: str | None = None,
        log_path: str | None = None,
    ) -> None:
        """
        Parameters
        ----------
        qdrant_client:
            Injected Qdrant client. When None, a real QdrantClient is
            constructed lazily on first call to aggregate(). In tests,
            pass a mock.
        last_epoch_timestamp:
            ISO 8601 timestamp of the last completed Nightly Epoch.
            Used to compute new_since_last_epoch. When None, defaults
            to 7 days ago (conservative fallback).
        log_path:
            Override path for the JSONL aggregation log. Useful in tests
            (pass tmp_path / "scar_aggregation_log.jsonl").
        """
        self._client = qdrant_client
        self._last_epoch_ts: str | None = last_epoch_timestamp
        self._log_path: str = log_path or _DEFAULT_LOG_PATH
        self._last_report: ScarReport | None = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def aggregate(self, lookback_days: int = 7) -> ScarReport:
        """
        Query Qdrant for scars within the lookback window, cluster them,
        write the report to the JSONL log, and return the ScarReport.

        Parameters
        ----------
        lookback_days:
            How many days back to query (default 7).

        Returns
        -------
        ScarReport
        """
        cutoff = datetime.now(tz=timezone.utc) - timedelta(days=lookback_days)
        scars = self._fetch_scars(cutoff)

        if not scars:
            report = ScarReport(total_scars=0, clusters=[], new_since_last_epoch=0)
            self._write_log(report)
            self._last_report = report
            return report

        clusters = self._cluster_scars(scars)
        new_count = self._count_new_since_last_epoch(scars)

        report = ScarReport(
            total_scars=len(scars),
            clusters=clusters,
            new_since_last_epoch=new_count,
        )

        self._write_log(report)
        self._last_report = report
        return report

    def get_top_clusters(self, n: int = 5) -> list[ScarCluster]:
        """
        Return the top-N clusters sorted by member_count DESC.

        Must call aggregate() first; raises RuntimeError if not yet called.

        Parameters
        ----------
        n:
            Maximum number of clusters to return.

        Returns
        -------
        list[ScarCluster] — at most n items, sorted by member_count DESC.
        """
        if self._last_report is None:
            raise RuntimeError(
                "No report available. Call aggregate() before get_top_clusters()."
            )
        sorted_clusters = sorted(
            self._last_report.clusters,
            key=lambda c: c.member_count,
            reverse=True,
        )
        return sorted_clusters[:n]

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _fetch_scars(self, cutoff: datetime) -> list[_ScarRecord]:
        """
        Retrieve all scars from Qdrant genesis_scars collection
        whose timestamp >= cutoff.

        Returns a list of _ScarRecord objects.
        """
        client = self._get_client()
        try:
            # Scroll through all points in the collection
            results, _ = client.scroll(
                collection_name="genesis_scars",
                limit=10_000,
                with_payload=True,
                with_vectors=True,
            )
        except Exception:
            return []

        scars: list[_ScarRecord] = []
        for point in results:
            payload = point.payload or {}
            ts_str = payload.get("timestamp", "")
            try:
                ts = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
                if ts < cutoff:
                    continue
            except (ValueError, AttributeError):
                # Malformed timestamp — skip the record
                continue

            vector = point.vector or []
            if not isinstance(vector, list):
                # Some Qdrant clients return dicts for named vectors
                vector = list(vector.values())[0] if vector else []

            scars.append(
                _ScarRecord(
                    scar_id=str(point.id),
                    text=payload.get("text", ""),
                    severity=float(payload.get("severity", 0.5)),
                    timestamp=ts_str,
                    vector=vector,
                )
            )

        return scars

    def _cluster_scars(self, scars: list[_ScarRecord]) -> list[ScarCluster]:
        """
        Greedy cosine-similarity clustering.

        For each scar (in order):
        - Check against the representative vector of each existing cluster.
        - If similarity >= CLUSTER_THRESHOLD → add to that cluster.
        - Otherwise → start a new cluster.

        Returns a list of ScarCluster dataclasses.
        """
        # Each cluster: list of _ScarRecord members
        clusters: list[list[_ScarRecord]] = []
        representative_vectors: list[list[float]] = []

        for scar in scars:
            placed = False
            for idx, rep_vec in enumerate(representative_vectors):
                sim = self._compute_cosine_similarity(scar.vector, rep_vec)
                if sim >= CLUSTER_THRESHOLD:
                    clusters[idx].append(scar)
                    placed = True
                    break
            if not placed:
                clusters.append([scar])
                representative_vectors.append(scar.vector)

        result: list[ScarCluster] = []
        for idx, members in enumerate(clusters):
            # Representative = member with the highest severity
            rep = max(members, key=lambda s: s.severity)
            avg_sev = sum(m.severity for m in members) / len(members)
            result.append(
                ScarCluster(
                    cluster_id=f"cluster_{idx:04d}",
                    representative_scar=rep.text,
                    member_count=len(members),
                    avg_severity=round(avg_sev, 6),
                )
            )

        return result

    def _count_new_since_last_epoch(self, scars: list[_ScarRecord]) -> int:
        """
        Count scars whose timestamp is strictly after the last epoch timestamp.

        If last_epoch_timestamp is None, uses 7 days ago as a conservative
        boundary (same as the default lookback).
        """
        if self._last_epoch_ts:
            try:
                epoch_dt = datetime.fromisoformat(
                    self._last_epoch_ts.replace("Z", "+00:00")
                )
            except ValueError:
                epoch_dt = datetime.now(tz=timezone.utc) - timedelta(days=7)
        else:
            epoch_dt = datetime.now(tz=timezone.utc) - timedelta(days=7)

        count = 0
        for scar in scars:
            try:
                scar_dt = datetime.fromisoformat(
                    scar.timestamp.replace("Z", "+00:00")
                )
                if scar_dt > epoch_dt:
                    count += 1
            except (ValueError, AttributeError):
                pass

        return count

    def _write_log(self, report: ScarReport) -> None:
        """
        Append the ScarReport as a JSONL entry to the aggregation log.
        Creates parent directories if needed.
        """
        log_dir = os.path.dirname(self._log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        entry = {
            "timestamp": datetime.now(tz=timezone.utc).isoformat(),
            "total_scars": report.total_scars,
            "cluster_count": len(report.clusters),
            "new_since_last_epoch": report.new_since_last_epoch,
            "clusters": [
                {
                    "cluster_id": c.cluster_id,
                    "representative_scar": c.representative_scar,
                    "member_count": c.member_count,
                    "avg_severity": c.avg_severity,
                }
                for c in report.clusters
            ],
        }

        with open(self._log_path, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(entry) + "\n")

    def _get_client(self) -> Any:
        """
        Return the injected client, or lazily construct a real QdrantClient.
        """
        if self._client is not None:
            return self._client

        # Lazy import — keeps the module usable even if qdrant_client is not
        # installed in the test environment.
        try:
            from qdrant_client import QdrantClient  # type: ignore

            self._client = QdrantClient(
                host="qdrant-b3knu-u50607.vm.elestio.app",
                port=6333,
            )
        except ImportError as exc:
            raise RuntimeError(
                "qdrant_client package not installed and no mock was injected."
            ) from exc

        return self._client

    # ------------------------------------------------------------------
    # Static / pure helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
        """
        Compute cosine similarity between two vectors.

        Returns a float in [0, 1] for unit vectors, or [-1, 1] in general.
        Returns 0.0 for zero-length vectors to avoid division by zero.
        """
        if not vec_a or not vec_b or len(vec_a) != len(vec_b):
            return 0.0

        dot = sum(a * b for a, b in zip(vec_a, vec_b))
        norm_a = math.sqrt(sum(a * a for a in vec_a))
        norm_b = math.sqrt(sum(b * b for b in vec_b))

        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0

        return dot / (norm_a * norm_b)
