#!/usr/bin/env python3
"""
GENESIS QWEN STORAGE (PostgreSQL)
==================================
PostgreSQL-based storage for Qwen metrics and state.

RULE 6 COMPLIANCE: NO SQLite - Uses Elestio PostgreSQL only.

Stores:
- Usage metrics (requests, tokens, latency)
- Circuit breaker state
- Model warmth history
- Health check history

Usage:
    from core.qwen.storage import QwenMetricsStore

    store = QwenMetricsStore()
    store.record_usage(tokens_used=100, latency_seconds=1.5)
    metrics = store.get_usage_summary()
"""

import json
import sys
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List

# Add genesis-memory path for elestio_config
sys.path.insert(0, "/mnt/e/genesis-system/data/genesis-memory")

try:
    import psycopg2
    from psycopg2.extras import RealDictCursor
    PSYCOPG2_AVAILABLE = True
except ImportError:
    PSYCOPG2_AVAILABLE = False

try:
    from elestio_config import PostgresConfig
    ELESTIO_AVAILABLE = True
except ImportError:
    ELESTIO_AVAILABLE = False


# NO SQLITE - Rule 6 Compliance Check
# sqlite3 ships with the standard library, so a spec is normally found.
# This lookup only acknowledges that; sqlite3 is never imported or used here.
import importlib.util

if importlib.util.find_spec("sqlite3") is not None:
    pass  # Awareness check only


@dataclass
class UsageRecord:
    """Single usage record.

    Built by ``QwenMetricsStore.record_usage``. It is buffered in memory when
    the PostgreSQL insert is not possible. Field order defines the positional
    constructor generated by ``@dataclass``; do not reorder.
    """
    timestamp: str  # ISO-8601 string; record_usage fills it with datetime.now().isoformat()
    tokens_used: int
    latency_seconds: float
    success: bool
    model: str
    prompt_preview: str = ""  # record_usage truncates this to the first 100 characters


class QwenMetricsStore:
    """
    PostgreSQL storage for Qwen metrics.

    Uses Elestio PostgreSQL (Rule 6 compliant).
    Falls back to in-memory storage if PostgreSQL unavailable.

    Every database operation runs inside ``with self._conn:``. psycopg2
    commits the transaction when the block succeeds and rolls it back when
    it raises. The rollback matters: without it, a single failed statement
    leaves the connection in an aborted transaction and every later query
    on it fails as well. Read-only queries are committed too, so the
    connection is never left idle inside an open transaction.
    """

    TABLE_NAME = "qwen_metrics"
    STATE_TABLE = "qwen_state"

    def __init__(self):
        self._conn = None
        # In-memory stores used when PostgreSQL is unavailable or a query fails.
        self._fallback_records: List[UsageRecord] = []
        self._fallback_state: Dict[str, Any] = {}

        if PSYCOPG2_AVAILABLE and ELESTIO_AVAILABLE:
            self._init_postgres()
        else:
            print("WARNING: PostgreSQL unavailable, using in-memory fallback")

    def _init_postgres(self):
        """Initialize PostgreSQL connection and tables.

        On any failure the store stays in in-memory fallback mode
        (``self._conn`` is None). A connection opened before the failure
        (e.g. table creation failed) is closed rather than leaked.
        """
        conn = None
        try:
            conn = psycopg2.connect(**PostgresConfig.get_connection_params())
            self._conn = conn
            self._create_tables()
        except Exception as e:
            print(f"PostgreSQL connection failed: {e}")
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass  # Best effort: the connection is being discarded anyway.
            self._conn = None

    def _create_tables(self):
        """Create required tables and the timestamp index if not exist."""
        if not self._conn:
            return

        with self._conn, self._conn.cursor() as cur:
            # Usage metrics table
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {self.TABLE_NAME} (
                    id SERIAL PRIMARY KEY,
                    timestamp TIMESTAMPTZ DEFAULT NOW(),
                    tokens_used INTEGER,
                    latency_seconds FLOAT,
                    success BOOLEAN,
                    model VARCHAR(255),
                    prompt_preview TEXT
                )
            """)

            # State table (circuit breaker, warmth, etc.)
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {self.STATE_TABLE} (
                    key VARCHAR(255) PRIMARY KEY,
                    value JSONB,
                    updated_at TIMESTAMPTZ DEFAULT NOW()
                )
            """)

            # Index for time-based queries
            cur.execute(f"""
                CREATE INDEX IF NOT EXISTS idx_{self.TABLE_NAME}_timestamp
                ON {self.TABLE_NAME} (timestamp)
            """)

    def record_usage(
        self,
        tokens_used: int,
        latency_seconds: float,
        success: bool = True,
        model: str = "huihui_ai/qwenlong-l1.5-abliterated:30b-a3b",
        prompt_preview: str = ""
    ):
        """Record a usage metric.

        Args:
            tokens_used: Tokens consumed by the request.
            latency_seconds: Request latency in seconds.
            success: Whether the request succeeded.
            model: Model identifier.
            prompt_preview: Prompt text; only the first 100 characters are kept.

        The row is inserted into PostgreSQL when connected. The record is
        kept in memory instead if there is no connection or the insert fails.
        """
        record = UsageRecord(
            timestamp=datetime.now().isoformat(),
            tokens_used=tokens_used,
            latency_seconds=latency_seconds,
            success=success,
            model=model,
            prompt_preview=prompt_preview[:100],
        )

        if self.is_postgres_connected():
            try:
                with self._conn, self._conn.cursor() as cur:
                    cur.execute(f"""
                        INSERT INTO {self.TABLE_NAME}
                        (tokens_used, latency_seconds, success, model, prompt_preview)
                        VALUES (%s, %s, %s, %s, %s)
                    """, (
                        record.tokens_used,
                        record.latency_seconds,
                        record.success,
                        record.model,
                        record.prompt_preview,
                    ))
                return
            except Exception as e:
                print(f"Failed to record usage: {e}")

        self._fallback_records.append(record)

    def get_usage_summary(self, hours: int = 24) -> Dict[str, Any]:
        """Get usage summary for the past N hours.

        Args:
            hours: Size of the look-back window, in hours.

        Returns:
            Dict with ``period_hours``, ``total_requests``, ``total_tokens``,
            ``avg_latency_seconds``, ``successful``, ``failed`` and
            ``success_rate``. When the in-memory fallback is used, a
            ``storage`` key set to ``"in_memory_fallback"`` is added.
        """
        if self.is_postgres_connected():
            try:
                with self._conn, self._conn.cursor(
                    cursor_factory=RealDictCursor
                ) as cur:
                    # The window is a bound parameter, never formatted into
                    # the SQL text, so no caller value can alter the query.
                    cur.execute(f"""
                        SELECT
                            COUNT(*) as total_requests,
                            SUM(tokens_used) as total_tokens,
                            AVG(latency_seconds) as avg_latency,
                            SUM(CASE WHEN success THEN 1 ELSE 0 END) as successful,
                            SUM(CASE WHEN NOT success THEN 1 ELSE 0 END) as failed
                        FROM {self.TABLE_NAME}
                        WHERE timestamp > NOW() - (%s::double precision * INTERVAL '1 hour')
                    """, (hours,))
                    row = cur.fetchone()

                # SUM/AVG yield NULL on an empty window (and for NULL
                # success values), so coerce to 0 before doing arithmetic.
                total = row["total_requests"] or 0
                successful = row["successful"] or 0
                return {
                    "period_hours": hours,
                    "total_requests": total,
                    "total_tokens": row["total_tokens"] or 0,
                    "avg_latency_seconds": round(row["avg_latency"] or 0, 3),
                    "successful": successful,
                    "failed": row["failed"] or 0,
                    "success_rate": successful / total if total > 0 else 0,
                }
            except Exception as e:
                print(f"Failed to get summary: {e}")

        return self._fallback_summary(hours)

    def _fallback_summary(self, hours: int) -> Dict[str, Any]:
        """Summarise in-memory records newer than ``hours`` hours ago.

        Uses the same keys as the PostgreSQL summary plus ``storage``.
        Timestamps are the naive local ISO strings written by record_usage,
        so they are compared against a naive ``datetime.now()`` cutoff.
        """
        cutoff = datetime.now() - timedelta(hours=hours)
        recent = [
            r for r in self._fallback_records
            if datetime.fromisoformat(r.timestamp) > cutoff
        ]
        total = len(recent)
        successful = sum(1 for r in recent if r.success)
        avg_latency = (
            round(sum(r.latency_seconds for r in recent) / total, 3)
            if total else 0
        )

        return {
            "period_hours": hours,
            "total_requests": total,
            "total_tokens": sum(r.tokens_used for r in recent),
            "avg_latency_seconds": avg_latency,
            "successful": successful,
            "failed": total - successful,
            "success_rate": successful / total if total else 0,
            "storage": "in_memory_fallback",
        }

    def save_state(self, key: str, value: Dict[str, Any]):
        """Save state to PostgreSQL (upsert by key).

        Falls back to the in-memory state dict if there is no connection or
        the write fails (including when ``value`` is not JSON-serializable).
        """
        if self.is_postgres_connected():
            try:
                with self._conn, self._conn.cursor() as cur:
                    cur.execute(f"""
                        INSERT INTO {self.STATE_TABLE} (key, value, updated_at)
                        VALUES (%s, %s, NOW())
                        ON CONFLICT (key) DO UPDATE
                        SET value = EXCLUDED.value, updated_at = NOW()
                    """, (key, json.dumps(value)))
                return
            except Exception as e:
                print(f"Failed to save state: {e}")

        self._fallback_state[key] = value

    def load_state(self, key: str) -> Optional[Dict[str, Any]]:
        """Load state from PostgreSQL.

        Returns the stored value, or None if the key is absent. If there is
        no connection or the query fails, the in-memory state is consulted.
        """
        if self.is_postgres_connected():
            try:
                with self._conn, self._conn.cursor(
                    cursor_factory=RealDictCursor
                ) as cur:
                    cur.execute(f"""
                        SELECT value FROM {self.STATE_TABLE} WHERE key = %s
                    """, (key,))
                    row = cur.fetchone()
                return row["value"] if row else None
            except Exception as e:
                print(f"Failed to load state: {e}")

        return self._fallback_state.get(key)

    def close(self):
        """Close PostgreSQL connection; the store then uses the in-memory fallback."""
        if self._conn:
            try:
                self._conn.close()
            finally:
                # Drop the reference even if close() raised.
                self._conn = None

    def is_postgres_connected(self) -> bool:
        """Check if PostgreSQL is connected.

        psycopg2 sets ``connection.closed`` to non-zero once the connection
        is closed or detected as broken, so such a connection is reported as
        disconnected.
        """
        return self._conn is not None and not self._conn.closed


# Lazily created, module-wide store instance shared by get_metrics_store().
_store: Optional[QwenMetricsStore] = None


def get_metrics_store() -> QwenMetricsStore:
    """Return the shared QwenMetricsStore, constructing it on first call."""
    global _store
    if _store is not None:
        return _store
    _store = QwenMetricsStore()
    return _store


# VERIFICATION_STAMP
# Story: STORY-009
# Verified By: CLAUDE
# Verified At: 2026-01-22
# Tests: Pending
# Coverage: Pending
# RULE 6 COMPLIANCE: NO SQLite imports or usage
