"""
Genesis Metrics Collection
==========================
Prometheus-style metrics collection for Genesis memory system.

Features:
- Counter, Gauge, Histogram, Summary metrics
- Labels for dimensional data
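- Persistence of metric snapshots to a JSON file on disk (write-only; this module does not reload them)
- JSON-serializable snapshots suitable for serving from an HTTP scrape endpoint (no server is started by this module)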
- Memory operation tracking

Usage:
    from metrics import GenesisMetrics, Counter, Histogram

    # Use pre-configured metrics
    GenesisMetrics.memory_operations.inc(labels={"tier": "semantic", "op": "store"})
    GenesisMetrics.memory_latency.observe(0.05, labels={"tier": "working"})

    # Get all metrics
    snapshot = GenesisMetrics.snapshot()
"""

import json
import time
import threading
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass, field, asdict
from collections import defaultdict
import math

# Default metrics directory.
# MetricsRegistry falls back to "<this dir>/metrics_snapshot.json" when it is
# constructed without an explicit persist_path.
# NOTE(review): this is a hard-coded absolute path on drive E:; it presumably
# only suits the original development machine -- confirm before deploying.
DEFAULT_METRICS_DIR = Path("E:/genesis-system/data/metrics")


@dataclass
class MetricValue:
    """A single metric value with labels and the time it was created."""
    # The numeric sample itself.
    value: float
    # Label name -> label value; each instance gets its own empty dict by default.
    labels: Dict[str, str] = field(default_factory=dict)
    # Defaults to time.time() evaluated when the instance is constructed.
    timestamp: float = field(default_factory=time.time)


class Counter:
    """
    Counter metric that only increments.

    A separate running total is kept for every distinct label set. Label
    names and values are backslash-escaped when the internal series key is
    built, so a value containing ``|`` or ``=`` neither collides with a
    different label set nor comes back altered from ``to_dict()``.

    Usage:
        requests = Counter("http_requests_total", "Total HTTP requests")
        requests.inc()
        requests.inc(labels={"method": "GET", "status": "200"})
    """

    def __init__(self, name: str, description: str = ""):
        self.name = name
        self.description = description
        # Series key (see _label_key) -> accumulated total.
        self._values: Dict[str, float] = defaultdict(float)
        self._lock = threading.Lock()

    def inc(self, amount: float = 1.0, labels: Optional[Dict[str, str]] = None) -> None:
        """Increment the counter.

        Args:
            amount: Non-negative amount to add.
            labels: Optional label set selecting the series to increment.

        Raises:
            ValueError: If ``amount`` is negative.
        """
        if amount < 0:
            raise ValueError("Counter can only be incremented")

        label_key = self._label_key(labels)
        with self._lock:
            self._values[label_key] += amount

    def get(self, labels: Optional[Dict[str, str]] = None) -> float:
        """Get current counter value (0.0 for a series never incremented)."""
        label_key = self._label_key(labels)
        with self._lock:
            # .get() avoids creating an empty series as a side effect of reading.
            return self._values.get(label_key, 0.0)

    @staticmethod
    def _escape_label(text: Any) -> str:
        """Escape the key separators so they can appear inside names/values."""
        # Backslash must be escaped first so the escapes added below stay intact.
        return f"{text}".replace("\\", "\\\\").replace("|", "\\|").replace("=", "\\=")

    def _label_key(self, labels: Optional[Dict[str, str]]) -> str:
        """Create a hashable key from labels.

        Labels are sorted so insertion order does not matter. For labels
        without ``|``, ``=`` or ``\\`` the key is the plain ``k=v|k=v`` form.
        """
        if not labels:
            return ""
        return "|".join(
            f"{self._escape_label(k)}={self._escape_label(v)}"
            for k, v in sorted(labels.items())
        )

    def to_dict(self) -> Dict:
        """Export counter as dict with one ``{"labels", "value"}`` entry per series."""
        with self._lock:
            return {
                "type": "counter",
                "name": self.name,
                "description": self.description,
                "values": [
                    {"labels": self._parse_label_key(k), "value": v}
                    for k, v in self._values.items()
                ]
            }

    def _parse_label_key(self, key: str) -> Dict[str, str]:
        """Parse a key produced by ``_label_key`` back to a label dict."""
        if not key:
            return {}
        labels: Dict[str, str] = {}
        name: Optional[str] = None
        buffer: List[str] = []
        escaped = False
        for ch in key:
            if escaped:
                # Character following a backslash is always literal.
                buffer.append(ch)
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == "=" and name is None:
                # First unescaped '=' ends the label name.
                name = "".join(buffer)
                buffer = []
            elif ch == "|":
                # Unescaped '|' ends a pair; fragments without '=' are dropped.
                if name is not None:
                    labels[name] = "".join(buffer)
                name = None
                buffer = []
            else:
                buffer.append(ch)
        if name is not None:
            labels[name] = "".join(buffer)
        return labels


class Gauge:
    """
    Gauge metric that can increase and decrease.

    One current value is kept per distinct label set. Label names and values
    are backslash-escaped when the internal series key is built, so a value
    containing ``|`` or ``=`` neither collides with a different label set nor
    comes back altered from ``to_dict()``.

    Usage:
        active_connections = Gauge("active_connections", "Current connections")
        active_connections.set(10)
        active_connections.inc()
        active_connections.dec()
    """

    def __init__(self, name: str, description: str = ""):
        self.name = name
        self.description = description
        # Series key (see _label_key) -> current value.
        self._values: Dict[str, float] = defaultdict(float)
        self._lock = threading.Lock()

    def set(self, value: float, labels: Optional[Dict[str, str]] = None) -> None:
        """Set gauge to specific value for the given label set."""
        label_key = self._label_key(labels)
        with self._lock:
            self._values[label_key] = value

    def inc(self, amount: float = 1.0, labels: Optional[Dict[str, str]] = None) -> None:
        """Increment gauge by ``amount`` (a new series starts from 0.0)."""
        label_key = self._label_key(labels)
        with self._lock:
            self._values[label_key] += amount

    def dec(self, amount: float = 1.0, labels: Optional[Dict[str, str]] = None) -> None:
        """Decrement gauge by ``amount`` (a new series starts from 0.0)."""
        label_key = self._label_key(labels)
        with self._lock:
            self._values[label_key] -= amount

    def get(self, labels: Optional[Dict[str, str]] = None) -> float:
        """Get current gauge value (0.0 for a series never written)."""
        label_key = self._label_key(labels)
        with self._lock:
            # .get() avoids creating an empty series as a side effect of reading.
            return self._values.get(label_key, 0.0)

    @staticmethod
    def _escape_label(text: Any) -> str:
        """Escape the key separators so they can appear inside names/values."""
        # Backslash must be escaped first so the escapes added below stay intact.
        return f"{text}".replace("\\", "\\\\").replace("|", "\\|").replace("=", "\\=")

    def _label_key(self, labels: Optional[Dict[str, str]]) -> str:
        """Create a hashable, order-independent key from labels.

        For labels without ``|``, ``=`` or ``\\`` the key is the plain
        ``k=v|k=v`` form.
        """
        if not labels:
            return ""
        return "|".join(
            f"{self._escape_label(k)}={self._escape_label(v)}"
            for k, v in sorted(labels.items())
        )

    def to_dict(self) -> Dict:
        """Export gauge as dict with one ``{"labels", "value"}`` entry per series."""
        with self._lock:
            return {
                "type": "gauge",
                "name": self.name,
                "description": self.description,
                "values": [
                    {"labels": self._parse_label_key(k), "value": v}
                    for k, v in self._values.items()
                ]
            }

    def _parse_label_key(self, key: str) -> Dict[str, str]:
        """Parse a key produced by ``_label_key`` back to a label dict."""
        if not key:
            return {}
        labels: Dict[str, str] = {}
        name: Optional[str] = None
        buffer: List[str] = []
        escaped = False
        for ch in key:
            if escaped:
                # Character following a backslash is always literal.
                buffer.append(ch)
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == "=" and name is None:
                # First unescaped '=' ends the label name.
                name = "".join(buffer)
                buffer = []
            elif ch == "|":
                # Unescaped '|' ends a pair; fragments without '=' are dropped.
                if name is not None:
                    labels[name] = "".join(buffer)
                name = None
                buffer = []
            else:
                buffer.append(ch)
        if name is not None:
            labels[name] = "".join(buffer)
        return labels


class Histogram:
    """
    Histogram metric for measuring distributions.

    Only running aggregates are kept per label set (count, sum, min, max and
    cumulative bucket counts), so memory use does not grow with the number of
    observations and exporting does not rescan past values.

    Label names and values are backslash-escaped when the internal series key
    is built, so values containing ``|`` or ``=`` neither collide with another
    label set nor come back altered from ``to_dict()``.

    Usage:
        latency = Histogram("request_latency", "Request latency",
                           buckets=[0.01, 0.05, 0.1, 0.5, 1.0, 5.0])
        latency.observe(0.05)
    """

    DEFAULT_BUCKETS = [0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]

    def __init__(
        self,
        name: str,
        description: str = "",
        buckets: Optional[List[float]] = None
    ):
        self.name = name
        self.description = description
        # Upper bounds in ascending order; None or an empty list selects the defaults.
        self.buckets = sorted(buckets or self.DEFAULT_BUCKETS)
        # Series key -> {"count", "sum", "min", "max", "bucket_counts"}, where
        # bucket_counts[i] is the number of observations <= self.buckets[i].
        self._series: Dict[str, Dict[str, Any]] = {}
        self._lock = threading.Lock()

    def observe(self, value: float, labels: Optional[Dict[str, str]] = None) -> None:
        """Record an observation in the series selected by ``labels``."""
        label_key = self._label_key(labels)
        with self._lock:
            series = self._series.get(label_key)
            if series is None:
                series = {
                    "count": 0,
                    "sum": 0,
                    "min": value,
                    "max": value,
                    "bucket_counts": [0] * len(self.buckets),
                }
                self._series[label_key] = series

            series["count"] += 1
            series["sum"] += value
            if value < series["min"]:
                series["min"] = value
            if value > series["max"]:
                series["max"] = value

            # Buckets are cumulative: a value counts in every bucket whose
            # upper bound is >= the value.
            counts = series["bucket_counts"]
            for index, upper_bound in enumerate(self.buckets):
                if value <= upper_bound:
                    counts[index] += 1

    def get_stats(self, labels: Optional[Dict[str, str]] = None) -> Dict:
        """Get histogram statistics for one label set.

        Returns:
            ``{"count": 0, "sum": 0, "buckets": {}}`` when nothing has been
            observed; otherwise count, sum, mean, min, max and a ``buckets``
            dict keyed by the numeric upper bound plus ``"+Inf"``.
        """
        label_key = self._label_key(labels)
        with self._lock:
            series = self._series.get(label_key)
            if series is None:
                return {"count": 0, "sum": 0, "buckets": {}}

            bucket_counts: Dict[Any, int] = dict(
                zip(self.buckets, series["bucket_counts"])
            )
            bucket_counts["+Inf"] = series["count"]

            return {
                "count": series["count"],
                "sum": series["sum"],
                "mean": series["sum"] / series["count"],
                "min": series["min"],
                "max": series["max"],
                "buckets": bucket_counts
            }

    @staticmethod
    def _escape_label(text: Any) -> str:
        """Escape the key separators so they can appear inside names/values."""
        # Backslash must be escaped first so the escapes added below stay intact.
        return f"{text}".replace("\\", "\\\\").replace("|", "\\|").replace("=", "\\=")

    def _label_key(self, labels: Optional[Dict[str, str]]) -> str:
        """Create a hashable, order-independent key from labels."""
        if not labels:
            return ""
        return "|".join(
            f"{self._escape_label(k)}={self._escape_label(v)}"
            for k, v in sorted(labels.items())
        )

    def to_dict(self) -> Dict:
        """Export histogram as dict with one ``{"labels", "stats"}`` entry per series."""
        with self._lock:
            return {
                "type": "histogram",
                "name": self.name,
                "description": self.description,
                "buckets": self.buckets,
                "values": [
                    {"labels": self._parse_label_key(k), "stats": self._export_stats(s)}
                    for k, s in self._series.items()
                ]
            }

    def _export_stats(self, series: Dict[str, Any]) -> Dict:
        """Build the JSON-friendly stats for one series (bucket keys are strings)."""
        bucket_counts = {
            str(upper_bound): hits
            for upper_bound, hits in zip(self.buckets, series["bucket_counts"])
        }
        bucket_counts["+Inf"] = series["count"]

        return {
            "count": series["count"],
            "sum": series["sum"],
            "mean": series["sum"] / series["count"],
            "buckets": bucket_counts
        }

    def _parse_label_key(self, key: str) -> Dict[str, str]:
        """Parse a key produced by ``_label_key`` back to a label dict."""
        if not key:
            return {}
        labels: Dict[str, str] = {}
        name: Optional[str] = None
        buffer: List[str] = []
        escaped = False
        for ch in key:
            if escaped:
                # Character following a backslash is always literal.
                buffer.append(ch)
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == "=" and name is None:
                # First unescaped '=' ends the label name.
                name = "".join(buffer)
                buffer = []
            elif ch == "|":
                # Unescaped '|' ends a pair; fragments without '=' are dropped.
                if name is not None:
                    labels[name] = "".join(buffer)
                name = None
                buffer = []
            else:
                buffer.append(ch)
        if name is not None:
            labels[name] = "".join(buffer)
        return labels


class Summary:
    """
    Summary metric for calculating quantiles.

    Quantiles are computed over a sliding window holding the most recent
    ``max_observations`` values of each label set.

    Label names and values are backslash-escaped when the internal series key
    is built, so values containing ``|`` or ``=`` neither collide with another
    label set nor come back altered from ``to_dict()``.

    Usage:
        latency = Summary("request_latency", "Request latency",
                         quantiles=[0.5, 0.9, 0.99])
        latency.observe(0.05)
    """

    DEFAULT_QUANTILES = [0.5, 0.9, 0.95, 0.99]

    def __init__(
        self,
        name: str,
        description: str = "",
        quantiles: Optional[List[float]] = None,
        max_observations: int = 1000
    ):
        """Create a summary.

        Args:
            name: Metric name.
            description: Human-readable description.
            quantiles: Quantiles in [0, 1]; None or an empty list selects
                ``DEFAULT_QUANTILES``.
            max_observations: Window size per label set; must be at least 1.

        Raises:
            ValueError: If a quantile is outside [0, 1] or
                ``max_observations`` is smaller than 1.
        """
        # Copy so callers (or this instance) never mutate the class default.
        chosen_quantiles = list(quantiles or self.DEFAULT_QUANTILES)
        for q in chosen_quantiles:
            # Out-of-range quantiles would index past the window or wrap
            # around to a negative index; fail early instead.
            if not 0 <= q <= 1:
                raise ValueError(f"Quantile must be within [0, 1], got {q!r}")
        if max_observations < 1:
            raise ValueError(
                f"max_observations must be at least 1, got {max_observations!r}"
            )

        self.name = name
        self.description = description
        self.quantiles = chosen_quantiles
        self.max_observations = max_observations
        # Series key -> most recent observations, oldest first.
        self._observations: Dict[str, List[float]] = defaultdict(list)
        self._lock = threading.Lock()

    def observe(self, value: float, labels: Optional[Dict[str, str]] = None) -> None:
        """Record an observation, discarding the oldest ones beyond the window."""
        label_key = self._label_key(labels)
        with self._lock:
            window = self._observations[label_key]
            window.append(value)
            # Trim in place; also copes with max_observations being lowered
            # after construction.
            overflow = len(window) - self.max_observations
            if overflow > 0:
                del window[:overflow]

    def get_quantiles(self, labels: Optional[Dict[str, float]] = None) -> Dict[float, float]:
        """Get quantile values for one label set ({} when nothing was observed)."""
        label_key = self._label_key(labels)
        with self._lock:
            values = self._observations.get(label_key, [])
            if not values:
                return {}
            return self._quantile_values(values)

    def _quantile_values(self, values: List[float]) -> Dict[float, float]:
        """Map each configured quantile to a value of non-empty ``values``.

        The element at index ``int(q * (n - 1))`` of the sorted values is
        used; no interpolation is performed.
        """
        ordered = sorted(values)
        last_index = len(ordered) - 1
        return {q: ordered[int(q * last_index)] for q in self.quantiles}

    @staticmethod
    def _escape_label(text: Any) -> str:
        """Escape the key separators so they can appear inside names/values."""
        # Backslash must be escaped first so the escapes added below stay intact.
        return f"{text}".replace("\\", "\\\\").replace("|", "\\|").replace("=", "\\=")

    def _label_key(self, labels: Optional[Dict[str, str]]) -> str:
        """Create a hashable, order-independent key from labels."""
        if not labels:
            return ""
        return "|".join(
            f"{self._escape_label(k)}={self._escape_label(v)}"
            for k, v in sorted(labels.items())
        )

    def to_dict(self) -> Dict:
        """Export summary as dict with one ``{"labels", "stats"}`` entry per series."""
        with self._lock:
            return {
                "type": "summary",
                "name": self.name,
                "description": self.description,
                "quantiles": self.quantiles,
                "values": [
                    {"labels": self._parse_label_key(k), "stats": self._compute_stats(v)}
                    for k, v in self._observations.items()
                ]
            }

    def _compute_stats(self, values: List[float]) -> Dict:
        """Compute count, sum, mean and string-keyed quantiles for ``values``."""
        if not values:
            return {"count": 0, "sum": 0}

        n = len(values)
        total = sum(values)
        return {
            "count": n,
            "sum": total,
            "mean": total / n,
            # String keys keep the export JSON-friendly.
            "quantiles": {
                str(q): picked for q, picked in self._quantile_values(values).items()
            }
        }

    def _parse_label_key(self, key: str) -> Dict[str, str]:
        """Parse a key produced by ``_label_key`` back to a label dict."""
        if not key:
            return {}
        labels: Dict[str, str] = {}
        name: Optional[str] = None
        buffer: List[str] = []
        escaped = False
        for ch in key:
            if escaped:
                # Character following a backslash is always literal.
                buffer.append(ch)
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == "=" and name is None:
                # First unescaped '=' ends the label name.
                name = "".join(buffer)
                buffer = []
            elif ch == "|":
                # Unescaped '|' ends a pair; fragments without '=' are dropped.
                if name is not None:
                    labels[name] = "".join(buffer)
                name = None
                buffer = []
            else:
                buffer.append(ch)
        if name is not None:
            labels[name] = "".join(buffer)
        return labels


class MetricsRegistry:
    """Central registry for all metrics.

    Metrics are stored by name; registering a second metric under an existing
    name replaces the earlier one. The registry can export a snapshot of all
    registered metrics and write that snapshot to ``persist_path`` as JSON.
    """

    def __init__(self, persist_path: Optional[Path] = None):
        """Create a registry.

        Args:
            persist_path: Snapshot file location (a ``Path`` or path string).
                A falsy value selects ``metrics_snapshot.json`` inside
                ``DEFAULT_METRICS_DIR``.
        """
        # Coerce so a plain string path works with the Path API used in persist().
        self.persist_path = (
            Path(persist_path) if persist_path
            else DEFAULT_METRICS_DIR / "metrics_snapshot.json"
        )
        self._metrics: Dict[str, Any] = {}
        self._lock = threading.Lock()
        self._start_time = time.time()

    def register(self, metric: Any) -> Any:
        """Register a metric under ``metric.name`` and return it."""
        with self._lock:
            self._metrics[metric.name] = metric
        return metric

    def counter(self, name: str, description: str = "") -> Counter:
        """Create and register a counter."""
        metric = Counter(name, description)
        return self.register(metric)

    def gauge(self, name: str, description: str = "") -> Gauge:
        """Create and register a gauge."""
        metric = Gauge(name, description)
        return self.register(metric)

    def histogram(
        self,
        name: str,
        description: str = "",
        buckets: Optional[List[float]] = None
    ) -> Histogram:
        """Create and register a histogram."""
        metric = Histogram(name, description, buckets)
        return self.register(metric)

    def summary(
        self,
        name: str,
        description: str = "",
        quantiles: Optional[List[float]] = None
    ) -> Summary:
        """Create and register a summary."""
        metric = Summary(name, description, quantiles)
        return self.register(metric)

    def snapshot(self) -> Dict:
        """Get snapshot of all metrics.

        Returns:
            A dict with ``timestamp`` (UTC, ISO 8601 with a trailing ``Z``),
            ``uptime_seconds`` since the registry was created, and
            ``metrics`` mapping each metric name to its ``to_dict()`` export.
        """
        # Local import: the module-level import only brings in `datetime`.
        from datetime import timezone

        # Copy the registrations under the lock, then serialize outside it so
        # a slow export does not block register().
        with self._lock:
            registered = list(self._metrics.items())

        # datetime.utcnow() is deprecated; dropping tzinfo keeps the exact
        # "<naive isoformat>Z" text the previous implementation produced.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        return {
            "timestamp": now_utc.isoformat() + "Z",
            "uptime_seconds": time.time() - self._start_time,
            "metrics": {
                name: metric.to_dict()
                for name, metric in registered
            }
        }

    def persist(self) -> None:
        """Save metrics to disk.

        The snapshot is written to a sibling ``.tmp`` file and then moved
        over ``persist_path``, so an interrupted write does not leave a
        truncated snapshot behind. This is best effort: any failure is
        reported on stdout and not raised.
        """
        try:
            # Path() tolerates persist_path having been reassigned to a string.
            target = Path(self.persist_path)
            target.parent.mkdir(parents=True, exist_ok=True)
            snapshot = self.snapshot()

            staging = target.with_name(target.name + ".tmp")
            with open(staging, 'w', encoding="utf-8") as f:
                json.dump(snapshot, f, indent=2)
            staging.replace(target)
        except Exception as e:
            print(f"[!] Metrics persist error: {e}")


# Global registry
# Created at import time with the default persist path; the pre-configured
# GenesisMetrics definitions below register themselves on this instance.
_registry = MetricsRegistry()


class GenesisMetrics:
    """Ready-made metrics for the Genesis memory system.

    Every attribute is created through, and registered on, the module-level
    registry while this class body is executed.
    """

    # --- Memory operations -------------------------------------------------
    memory_operations = _registry.counter(
        name="genesis_memory_operations_total",
        description="Total memory operations by tier and operation type",
    )

    memory_latency = _registry.histogram(
        name="genesis_memory_latency_seconds",
        description="Memory operation latency",
        buckets=[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0],
    )

    # --- Memory state ------------------------------------------------------
    memory_count = _registry.gauge(
        name="genesis_memory_count",
        description="Current memory count by tier",
    )

    memory_size_bytes = _registry.gauge(
        name="genesis_memory_size_bytes",
        description="Memory storage size in bytes by tier",
    )

    # --- Surprise detection ------------------------------------------------
    surprise_scores = _registry.summary(
        name="genesis_surprise_score",
        description="Surprise score distribution by domain",
    )

    # --- MCP sync ----------------------------------------------------------
    mcp_sync_operations = _registry.counter(
        name="genesis_mcp_sync_operations_total",
        description="MCP sync operations by status",
    )

    mcp_sync_latency = _registry.histogram(
        name="genesis_mcp_sync_latency_seconds",
        description="MCP sync latency",
    )

    # --- Circuit breakers --------------------------------------------------
    circuit_breaker_state = _registry.gauge(
        name="genesis_circuit_breaker_state",
        description="Circuit breaker state (0=closed, 1=half-open, 2=open)",
    )

    circuit_breaker_failures = _registry.counter(
        name="genesis_circuit_breaker_failures_total",
        description="Circuit breaker failure count",
    )

    # --- Cache -------------------------------------------------------------
    cache_hits = _registry.counter(
        name="genesis_cache_hits_total",
        description="Cache hit count by tier",
    )

    cache_misses = _registry.counter(
        name="genesis_cache_misses_total",
        description="Cache miss count by tier",
    )

    @classmethod
    def snapshot(cls) -> Dict:
        """Return the module-level registry's snapshot of all metrics."""
        current = _registry.snapshot()
        return current

    @classmethod
    def persist(cls) -> None:
        """Ask the module-level registry to write its snapshot to disk."""
        _registry.persist()


# Context manager for timing operations
class TimedOperation:
    """Context manager for timing and recording operations.

    The elapsed time of the ``with`` body is measured with
    ``time.perf_counter()`` and passed to ``histogram.observe`` together with
    ``labels``. A wall clock (``time.time()``) is unsuitable for measuring
    durations because it can be adjusted while the block is running.

    Attributes:
        histogram: Object whose ``observe(duration, labels)`` is called on exit.
        labels: Labels forwarded to ``observe`` (may be None).
        start_time: Wall-clock time (``time.time()``) at which the block was
            entered, or None before entry; informational only.
        duration: Elapsed seconds of the last completed block, or None if no
            block has completed yet.
    """

    def __init__(
        self,
        histogram: Histogram,
        labels: Optional[Dict[str, str]] = None
    ):
        self.histogram = histogram
        self.labels = labels
        self.start_time = None
        self.duration = None
        # perf_counter reading taken in __enter__; only differences of such
        # readings are meaningful.
        self._started_at = None

    def __enter__(self):
        """Start timing and return this context manager."""
        self.start_time = time.time()
        self._started_at = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the elapsed time; exceptions from the block propagate."""
        self.duration = time.perf_counter() - self._started_at
        # The observation is recorded whether or not the block raised.
        self.histogram.observe(self.duration, self.labels)
        return False


# CLI interface
if __name__ == "__main__":
    import sys

    # Shared usage line printed for both "no command" and "unknown command".
    _usage = "Usage: python metrics.py [demo|snapshot|persist]"
    _command = sys.argv[1] if len(sys.argv) > 1 else None

    if _command is None:
        print("Genesis Metrics Collection")
        print(_usage)

    elif _command == "demo":
        print("=== Genesis Metrics Demo ===\n")

        # Simulated memory traffic: one operation of each kind per tier,
        # followed by ten identical latency samples for that tier.
        for tier in ("working", "episodic", "semantic"):
            for op in ("store", "get", "delete"):
                GenesisMetrics.memory_operations.inc(
                    labels={"tier": tier, "op": op}
                )

            # str hashes vary between interpreter runs unless PYTHONHASHSEED
            # is fixed, so the demo values differ from run to run.
            tier_latency = 0.001 + (0.1 * (hash(tier) % 10) / 10)
            for _ in range(10):
                GenesisMetrics.memory_latency.observe(
                    tier_latency,
                    labels={"tier": tier}
                )

        # Memory counts per tier.
        for tier_name, stored in (
            ("working", 150),
            ("episodic", 5000),
            ("semantic", 500),
        ):
            GenesisMetrics.memory_count.set(stored, labels={"tier": tier_name})

        # Twenty identical surprise scores per domain.
        for domain in ("tech", "security", "general"):
            domain_score = 0.2 + (hash(domain) % 80) / 100
            for _ in range(20):
                GenesisMetrics.surprise_scores.observe(
                    domain_score,
                    labels={"domain": domain}
                )

        # Show the resulting snapshot, then write it to disk.
        print(json.dumps(GenesisMetrics.snapshot(), indent=2, default=str))

        GenesisMetrics.persist()
        print(f"\n[OK] Metrics saved to {_registry.persist_path}")

    elif _command == "snapshot":
        print(json.dumps(GenesisMetrics.snapshot(), indent=2, default=str))

    elif _command == "persist":
        GenesisMetrics.persist()
        print(f"Metrics saved to {_registry.persist_path}")

    else:
        print(f"Unknown command: {_command}")
        print(_usage)
