#!/usr/bin/env python3
"""
GENESIS METRICS AGGREGATOR
===========================
Centralized metrics collection, aggregation, and analysis.

Features:
    - Multi-source metric collection
    - Time-series aggregation
    - Statistical analysis
    - Alerting thresholds
    - Export formats (JSON, Prometheus)

Usage:
    metrics = MetricsAggregator()
    metrics.record("task.completed", 1, labels={"agent": "claude"})
    summary = metrics.get_summary("1h")
"""

"""
RULE 7 COMPLIANT: Uses Elestio PostgreSQL via genesis_db module.
"""
import json
import threading
import time
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional, Callable, Tuple
from enum import Enum
import statistics

# RULE 7: Use PostgreSQL via genesis_db (no sqlite3)
from core.genesis_db import connection, ensure_table

logger = logging.getLogger(__name__)


class MetricType(Enum):
    """Kinds of metrics; the string value is what gets persisted to PostgreSQL."""
    COUNTER = "counter"        # Monotonically increasing; record() accumulates deltas
    GAUGE = "gauge"            # Point-in-time value; each record() overwrites
    HISTOGRAM = "histogram"    # Distribution of values; samples buffered for stats
    SUMMARY = "summary"        # Quantiles (declared but not specially handled here)


class AlertLevel(Enum):
    """Alert severity levels, lowest to highest; value is used in exports."""
    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"


@dataclass
class MetricPoint:
    """One observation of a named metric at a specific moment in time."""
    name: str
    value: float
    timestamp: float
    metric_type: MetricType = MetricType.GAUGE
    labels: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Serialize to a plain dict suitable for JSON export."""
        return dict(
            name=self.name,
            value=self.value,
            timestamp=self.timestamp,
            type=self.metric_type.value,
            labels=self.labels,
        )


@dataclass
class AlertRule:
    """An alerting rule, evaluated against every recorded value of `metric`."""
    name: str
    metric: str
    condition: str  # "gt", "lt", "eq", "ne", "gte", "lte"
    threshold: float
    level: AlertLevel
    message: str
    cooldown_seconds: int = 300  # minimum seconds between consecutive firings
    last_triggered: Optional[float] = None  # epoch seconds of last firing; None if never


@dataclass
class Alert:
    """A concrete alert produced when an AlertRule's condition fires."""
    rule_name: str
    metric: str
    value: float
    threshold: float
    level: AlertLevel
    message: str
    timestamp: float

    def to_dict(self) -> Dict:
        """Serialize for JSON export; the timestamp is rendered as ISO-8601."""
        fired_at = datetime.fromtimestamp(self.timestamp).isoformat()
        return {
            "rule": self.rule_name,
            "metric": self.metric,
            "value": self.value,
            "threshold": self.threshold,
            "level": self.level.value,
            "message": self.message,
            "timestamp": fired_at,
        }


class MetricStore:
    """
    Time-series metric storage with PostgreSQL backend (RULE 7).

    Every method is best-effort: database failures are logged at WARNING
    level and swallowed so metric persistence can never crash a caller.
    Readers fall back to empty results (or None) on failure.
    """

    def __init__(self, db_path: Path = None, max_age_hours: int = 168):
        # RULE 7: db_path is ignored - uses PostgreSQL via genesis_db
        # max_age_hours (default 168 h = 7 days) is the retention window
        # enforced by cleanup().
        self.max_age_hours = max_age_hours
        self._init_db()

    def _init_db(self):
        """Initialize database (RULE 7: PostgreSQL).

        Creates the metrics_store table via ensure_table, then adds two
        lookup indexes in a separate best-effort step (presumably
        ensure_table covers only the table DDL — verify against
        core.genesis_db).
        """
        ensure_table('metrics_store', '''
            id SERIAL PRIMARY KEY,
            name TEXT NOT NULL,
            value REAL NOT NULL,
            timestamp REAL NOT NULL,
            metric_type TEXT NOT NULL,
            labels JSONB
        ''')
        try:
            with connection() as conn:
                cursor = conn.cursor()
                # query()/get_latest() filter on name and timestamp, so
                # index both columns.
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_metrics_name ON metrics_store(name)")
                cursor.execute("CREATE INDEX IF NOT EXISTS idx_metrics_timestamp ON metrics_store(timestamp)")
        except Exception as e:
            logger.warning(f"Index creation warning: {e}")

    def store(self, point: MetricPoint):
        """Store a metric point (RULE 7: PostgreSQL).

        Labels are serialized with json.dumps for the JSONB column.
        Failures are logged and swallowed; the point is simply dropped.
        """
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO metrics_store (name, value, timestamp, metric_type, labels)
                    VALUES (%s, %s, %s, %s, %s)
                """, (
                    point.name,
                    point.value,
                    point.timestamp,
                    point.metric_type.value,
                    json.dumps(point.labels)
                ))
        except Exception as e:
            logger.warning(f"Failed to store metric: {e}")

    def query(
        self,
        name: str,
        since_hours: float = 1.0,
        labels: Dict[str, str] = None
    ) -> List[MetricPoint]:
        """Query metrics by name and time range (RULE 7: PostgreSQL).

        Returns points ordered oldest-first. When `labels` is given, a
        point must match every requested key/value pair to be included;
        this filtering happens client-side, after the fetch.
        Returns [] on any database error.
        """
        cutoff = time.time() - (since_hours * 3600)

        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    SELECT name, value, timestamp, metric_type, labels
                    FROM metrics_store
                    WHERE name = %s AND timestamp >= %s
                    ORDER BY timestamp ASC
                """, (name, cutoff))

                points = []
                for row in cursor.fetchall():
                    # row[4] is the JSONB column; assumes the driver
                    # decodes it to a dict — TODO confirm (stored via
                    # json.dumps in store()).
                    point_labels = row[4] if row[4] else {}

                    # Filter by labels if specified
                    if labels:
                        if not all(point_labels.get(k) == v for k, v in labels.items()):
                            continue

                    points.append(MetricPoint(
                        name=row[0],
                        value=row[1],
                        timestamp=row[2],
                        metric_type=MetricType(row[3]),
                        labels=point_labels
                    ))

                return points

        except Exception as e:
            logger.warning(f"Failed to query metrics: {e}")
            return []

    def get_latest(self, name: str) -> Optional[MetricPoint]:
        """Get the latest value for a metric (RULE 7: PostgreSQL).

        Returns None when the metric has no rows or on database error.
        """
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    SELECT name, value, timestamp, metric_type, labels
                    FROM metrics_store
                    WHERE name = %s
                    ORDER BY timestamp DESC
                    LIMIT 1
                """, (name,))

                row = cursor.fetchone()
                if row:
                    return MetricPoint(
                        name=row[0],
                        value=row[1],
                        timestamp=row[2],
                        metric_type=MetricType(row[3]),
                        labels=row[4] if row[4] else {}
                    )
                return None

        except Exception as e:
            logger.warning(f"Failed to get latest metric: {e}")
            return None

    def get_all_names(self) -> List[str]:
        """Get all unique metric names (RULE 7: PostgreSQL); [] on error."""
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT DISTINCT name FROM metrics_store")
                return [row[0] for row in cursor.fetchall()]
        except Exception as e:
            logger.warning(f"Failed to get metric names: {e}")
            return []

    def cleanup(self):
        """Remove metrics older than max_age_hours (RULE 7: PostgreSQL)."""
        cutoff = time.time() - (self.max_age_hours * 3600)
        try:
            with connection() as conn:
                cursor = conn.cursor()
                cursor.execute("DELETE FROM metrics_store WHERE timestamp < %s", (cutoff,))
        except Exception as e:
            logger.warning(f"Failed to cleanup metrics: {e}")


class MetricAggregations:
    """Stateless statistical helpers over lists of metric samples.

    Every helper tolerates an empty input by returning 0 instead of
    raising, so callers never need to pre-check.
    """

    @staticmethod
    def sum(values: List[float]) -> float:
        """Total of all samples (0 for an empty list)."""
        return sum(values)

    @staticmethod
    def avg(values: List[float]) -> float:
        """Arithmetic mean, or 0 when there are no samples."""
        if not values:
            return 0
        return statistics.mean(values)

    @staticmethod
    def min(values: List[float]) -> float:
        """Smallest sample, or 0 when there are no samples."""
        if not values:
            return 0
        return min(values)

    @staticmethod
    def max(values: List[float]) -> float:
        """Largest sample, or 0 when there are no samples."""
        if not values:
            return 0
        return max(values)

    @staticmethod
    def count(values: List[float]) -> int:
        """Number of samples."""
        return len(values)

    @staticmethod
    def stddev(values: List[float]) -> float:
        """Sample standard deviation; 0 when fewer than two samples."""
        if len(values) < 2:
            return 0
        return statistics.stdev(values)

    @staticmethod
    def percentile(values: List[float], p: float) -> float:
        """Approximate p-th percentile via index into the sorted samples."""
        if not values:
            return 0
        ordered = sorted(values)
        pos = int(len(ordered) * p / 100)
        return ordered[min(pos, len(ordered) - 1)]

    @staticmethod
    def rate(values: List[float], time_seconds: float) -> float:
        """Change per second between the first and last sample.

        Returns 0 when there are fewer than two samples or the time
        span is non-positive.
        """
        if len(values) < 2 or time_seconds <= 0:
            return 0
        return (values[-1] - values[0]) / time_seconds


class MetricsAggregator:
    """
    Central metrics aggregation system.

    Combines fast in-memory state (latest value per name, running
    counter totals, bounded histogram buffers) with optional PostgreSQL
    persistence via MetricStore, threshold-based alerting, and
    Prometheus/JSON export.
    """

    def __init__(self, persist: bool = True):
        # When persist is False, summaries are served from memory only.
        self.persist = persist
        self.store = MetricStore() if persist else None

        # In-memory buffers for fast access
        self._latest: Dict[str, MetricPoint] = {}  # last point per metric name
        self._counters: Dict[str, float] = defaultdict(float)  # running totals
        self._histograms: Dict[str, List[float]] = defaultdict(list)  # recent samples
        self._lock = threading.RLock()  # guards the three buffers above

        # Alerting
        # NOTE(review): these three lists are read/mutated without
        # holding _lock (see _check_alerts, add_alert_rule) — racy if
        # rules are added while metrics are being recorded.
        self._alert_rules: List[AlertRule] = []
        self._alerts: List[Alert] = []
        self._alert_handlers: List[Callable[[Alert], None]] = []

        # Background cleanup
        self._running = False
        self._cleanup_thread: Optional[threading.Thread] = None

    def record(
        self,
        name: str,
        value: float,
        metric_type: MetricType = MetricType.GAUGE,
        labels: Dict[str, str] = None
    ):
        """Record a metric value.

        For COUNTER metrics `value` is a delta: the running total is
        keyed by name only (labels do NOT split the counter) and the
        persisted point carries the accumulated total. HISTOGRAM samples
        are buffered in memory, capped at the most recent 1000. The
        point is also written to the store (if persisting) and checked
        against all alert rules.
        """
        point = MetricPoint(
            name=name,
            value=value,
            timestamp=time.time(),
            metric_type=metric_type,
            labels=labels or {}
        )

        with self._lock:
            # Update in-memory state
            if metric_type == MetricType.COUNTER:
                self._counters[name] += value
                point.value = self._counters[name]  # persist the total, not the delta
            elif metric_type == MetricType.HISTOGRAM:
                self._histograms[name].append(value)
                # Limit histogram size
                if len(self._histograms[name]) > 1000:
                    self._histograms[name] = self._histograms[name][-1000:]

            self._latest[name] = point

        # Persist
        if self.store:
            self.store.store(point)

        # Check alerts
        # Runs outside the lock; for counters this checks the
        # accumulated total, not the raw delta.
        self._check_alerts(name, point.value)

    def increment(self, name: str, value: float = 1.0, labels: Dict[str, str] = None):
        """Increment a counter by `value` (default 1)."""
        self.record(name, value, MetricType.COUNTER, labels)

    def gauge(self, name: str, value: float, labels: Dict[str, str] = None):
        """Set a gauge value."""
        self.record(name, value, MetricType.GAUGE, labels)

    def histogram(self, name: str, value: float, labels: Dict[str, str] = None):
        """Record a histogram value."""
        self.record(name, value, MetricType.HISTOGRAM, labels)

    def timing(self, name: str):
        """Context manager for timing operations; records elapsed ms as a histogram."""
        return TimingContext(self, name)

    def get_latest(self, name: str) -> Optional[float]:
        """Get the latest in-memory value for a metric, or None if never recorded."""
        with self._lock:
            point = self._latest.get(name)
            return point.value if point else None

    def get_summary(self, name: str, hours: float = 1.0) -> Dict:
        """Get statistical summary for a metric.

        With a store, values come from the persisted time series over
        the last `hours`. Without one, falls back to the in-memory
        histogram buffer (or just the single latest value). Returns
        {"error": "no data"} when there is nothing to summarize.
        """
        if not self.store:
            # Use in-memory only
            with self._lock:
                if name in self._histograms:
                    values = self._histograms[name]
                elif name in self._latest:
                    values = [self._latest[name].value]
                else:
                    values = []
        else:
            points = self.store.query(name, since_hours=hours)
            values = [p.value for p in points]

        if not values:
            return {"error": "no data"}

        return {
            "metric": name,
            "period_hours": hours,
            "count": len(values),
            "sum": MetricAggregations.sum(values),
            "avg": round(MetricAggregations.avg(values), 4),
            "min": MetricAggregations.min(values),
            "max": MetricAggregations.max(values),
            "stddev": round(MetricAggregations.stddev(values), 4),
            "p50": MetricAggregations.percentile(values, 50),
            "p95": MetricAggregations.percentile(values, 95),
            "p99": MetricAggregations.percentile(values, 99)
        }

    def get_all_summaries(self, hours: float = 1.0) -> Dict[str, Dict]:
        """Get summaries for all known metric names (store-backed or in-memory)."""
        if self.store:
            names = self.store.get_all_names()
        else:
            with self._lock:
                names = list(set(list(self._latest.keys()) + list(self._histograms.keys())))

        return {name: self.get_summary(name, hours) for name in names}

    def add_alert_rule(
        self,
        name: str,
        metric: str,
        condition: str,
        threshold: float,
        level: AlertLevel = AlertLevel.WARNING,
        message: str = ""
    ):
        """Add an alerting rule.

        `condition` is one of "gt", "lt", "eq", "ne", "gte", "lte"
        (anything else never triggers). An empty `message` defaults to
        "<metric> <condition> <threshold>".
        """
        rule = AlertRule(
            name=name,
            metric=metric,
            condition=condition,
            threshold=threshold,
            level=level,
            message=message or f"{metric} {condition} {threshold}"
        )
        self._alert_rules.append(rule)

    def on_alert(self, handler: Callable[[Alert], None]):
        """Register a callback invoked (synchronously) for every triggered alert."""
        self._alert_handlers.append(handler)

    def _check_alerts(self, metric: str, value: float):
        """Check whether any rules for `metric` fire on `value`.

        Honors each rule's cooldown, records fired alerts, and notifies
        handlers; handler exceptions are swallowed so one bad handler
        cannot block the rest.
        """
        now = time.time()

        for rule in self._alert_rules:
            if rule.metric != metric:
                continue

            # Check cooldown: skip rules that fired too recently.
            if rule.last_triggered:
                if now - rule.last_triggered < rule.cooldown_seconds:
                    continue

            # Check condition
            triggered = False
            if rule.condition == "gt" and value > rule.threshold:
                triggered = True
            elif rule.condition == "lt" and value < rule.threshold:
                triggered = True
            elif rule.condition == "eq" and value == rule.threshold:
                triggered = True
            elif rule.condition == "ne" and value != rule.threshold:
                triggered = True
            elif rule.condition == "gte" and value >= rule.threshold:
                triggered = True
            elif rule.condition == "lte" and value <= rule.threshold:
                triggered = True

            if triggered:
                rule.last_triggered = now
                alert = Alert(
                    rule_name=rule.name,
                    metric=metric,
                    value=value,
                    threshold=rule.threshold,
                    level=rule.level,
                    message=rule.message,
                    timestamp=now
                )
                self._alerts.append(alert)

                # Notify handlers
                for handler in self._alert_handlers:
                    try:
                        handler(alert)
                    except Exception:
                        pass

    def get_alerts(self, since_hours: float = 24.0) -> List[Alert]:
        """Get alerts triggered within the last `since_hours` hours."""
        cutoff = time.time() - (since_hours * 3600)
        return [a for a in self._alerts if a.timestamp >= cutoff]

    def export_prometheus(self) -> str:
        """Export the latest value of every metric in Prometheus text format.

        NOTE(review): label values are interpolated without escaping
        quotes/backslashes/newlines — fine for simple labels, but not
        fully spec-compliant; confirm label contents are benign.
        """
        lines = []

        with self._lock:
            for name, point in self._latest.items():
                # Sanitize name for Prometheus
                prom_name = name.replace(".", "_").replace("-", "_")

                # Format labels
                label_str = ""
                if point.labels:
                    label_parts = [f'{k}="{v}"' for k, v in point.labels.items()]
                    label_str = "{" + ",".join(label_parts) + "}"

                lines.append(f"# TYPE {prom_name} {point.metric_type.value}")
                lines.append(f"{prom_name}{label_str} {point.value}")

        return "\n".join(lines)

    def export_json(self) -> str:
        """Export latest metrics, counter totals, and the last 100 alerts as JSON."""
        with self._lock:
            data = {
                "timestamp": datetime.now().isoformat(),
                "metrics": {
                    name: point.to_dict()
                    for name, point in self._latest.items()
                },
                "counters": dict(self._counters),
                "alerts": [a.to_dict() for a in self._alerts[-100:]]
            }
        return json.dumps(data, indent=2)

    def start_cleanup(self, interval_hours: int = 1):
        """Start the background cleanup thread (no-op if already running)."""
        if self._running:
            return

        self._running = True
        self._cleanup_thread = threading.Thread(
            target=self._cleanup_loop,
            args=(interval_hours,),
            daemon=True  # must not block interpreter exit
        )
        self._cleanup_thread.start()

    def stop_cleanup(self):
        """Signal the cleanup thread to stop and wait briefly for it.

        The loop sleeps the full interval before re-checking _running,
        so join() will usually hit its 5 s timeout; the daemon thread is
        then abandoned and dies with the process.
        """
        self._running = False
        if self._cleanup_thread:
            self._cleanup_thread.join(timeout=5)

    def _cleanup_loop(self, interval_hours: int):
        """Background loop: prune old stored metrics and trim in-memory alerts."""
        while self._running:
            time.sleep(interval_hours * 3600)
            if self.store:
                self.store.cleanup()

            # Trim alerts
            cutoff = time.time() - (7 * 24 * 3600)  # 7 days
            self._alerts = [a for a in self._alerts if a.timestamp >= cutoff]


class TimingContext:
    """Context manager that records elapsed wall time as a histogram metric.

    On exit the duration is reported in milliseconds under `name`, even
    when the body raised; the exception is never suppressed.
    """

    def __init__(self, aggregator: MetricsAggregator, name: str):
        self.aggregator = aggregator
        self.name = name
        self.start_time = 0.0

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed_ms = (time.time() - self.start_time) * 1000
        self.aggregator.histogram(self.name, elapsed_ms)
        return False  # propagate any exception from the with-body


# Global instance
_metrics: Optional[MetricsAggregator] = None


def get_metrics() -> MetricsAggregator:
    """Return the process-wide MetricsAggregator, creating it on first use.

    Creation is not lock-guarded; concurrent first calls could each
    build an instance, with the last assignment winning the global slot.
    """
    global _metrics
    if _metrics is not None:
        return _metrics
    _metrics = MetricsAggregator()
    return _metrics


# Convenience functions
def record(name: str, value: float, metric_type: MetricType = MetricType.GAUGE, labels: Dict[str, str] = None):
    """Record a value on the global aggregator (see MetricsAggregator.record)."""
    aggregator = get_metrics()
    aggregator.record(name, value, metric_type, labels)


def increment(name: str, value: float = 1.0, labels: Dict[str, str] = None):
    """Increment a counter on the global aggregator."""
    aggregator = get_metrics()
    aggregator.increment(name, value, labels)


def gauge(name: str, value: float, labels: Dict[str, str] = None):
    """Set a gauge on the global aggregator."""
    aggregator = get_metrics()
    aggregator.gauge(name, value, labels)


def timing(name: str):
    """Return a TimingContext bound to the global aggregator."""
    return get_metrics().timing(name)


def main():
    """CLI for metrics aggregator.

    Commands:
        demo     - record sample metrics, trigger an alert, print results
        summary  - print stats for one metric (--metric) or all metrics
        export   - dump current metrics (--format json|prometheus)
        alerts   - list alerts from the last --hours hours
    """
    import argparse
    parser = argparse.ArgumentParser(description="Genesis Metrics Aggregator")
    parser.add_argument("command", choices=["demo", "summary", "export", "alerts"])
    parser.add_argument("--metric", help="Metric name")
    parser.add_argument("--hours", type=float, default=1.0)
    parser.add_argument("--format", choices=["json", "prometheus"], default="json")
    args = parser.parse_args()

    # Fresh aggregator per invocation; persisted metrics (PostgreSQL)
    # survive across runs, in-memory state does not.
    metrics = MetricsAggregator()

    if args.command == "demo":
        print("Metrics Aggregator Demo")
        print("=" * 40)

        # Record some test metrics (takes ~1 s due to the sleeps)
        for i in range(10):
            metrics.gauge("system.cpu", 20 + i * 5)
            metrics.increment("tasks.completed")
            metrics.histogram("task.duration", 100 + i * 10)
            time.sleep(0.1)

        # Add alert rule
        metrics.add_alert_rule(
            name="high_cpu",
            metric="system.cpu",
            condition="gt",
            threshold=50,
            level=AlertLevel.WARNING,
            message="CPU usage is high"
        )

        # Trigger alert: 75 > 50 fires the rule just added
        metrics.gauge("system.cpu", 75)

        print("\nSummary:")
        print(json.dumps(metrics.get_all_summaries(1), indent=2))

        print("\nAlerts:")
        for alert in metrics.get_alerts():
            print(f"  [{alert.level.value}] {alert.message}: {alert.value}")

    elif args.command == "summary":
        if args.metric:
            summary = metrics.get_summary(args.metric, args.hours)
        else:
            summary = metrics.get_all_summaries(args.hours)
        print(json.dumps(summary, indent=2))

    elif args.command == "export":
        if args.format == "prometheus":
            print(metrics.export_prometheus())
        else:
            print(metrics.export_json())

    elif args.command == "alerts":
        # Only shows in-memory alerts, which are empty for a fresh
        # aggregator — alerts are not persisted to the database.
        alerts = metrics.get_alerts(args.hours)
        for alert in alerts:
            print(f"[{alert.level.value}] {alert.rule_name}: {alert.message}")
            print(f"  Value: {alert.value}, Threshold: {alert.threshold}")


if __name__ == "__main__":
    main()
