"""
RLM Gateway - Integration Layer for Reinforcement Learning from Memory

This module connects the 5 RLM modules to AIVA's live decision-making loop:
1. Preference Learning (rlm_01)
2. Reward Model (rlm_02)
3. PPO Engine (rlm_03)
4. DPO Trainer (rlm_04)
5. Constitutional AI (rlm_05)

The gateway:
- Intercepts AIVA outputs before sending to users
- Scores outputs using the trained reward model
- Revises outputs using Constitutional AI
- Collects human feedback for training
- Triggers training when sufficient data accumulated
- Manages A/B testing for safe policy deployment

Author: Genesis-OS RLM Integration
Version: 1.0.0
Date: 2026-02-16
"""

import asyncio
import json
import logging
import os
import random
import sys
import time
from datetime import datetime
from typing import Dict, List, Optional, Tuple

# Add paths for Genesis modules
sys.path.append('/mnt/e/genesis-system')
sys.path.append('/mnt/e/genesis-system/data/genesis-memory')
sys.path.append('/mnt/e/genesis-system/AIVA/queen_outputs/rlm') # Add RLM modules path

# Elestio infrastructure
from elestio_config import PostgresConfig, RedisConfig

# Import RLM modules
from rlm_01_preference_learning import PreferenceDataset as PreferenceLearningDataset
from rlm_02_reward_model import RewardModel as RealRewardModel, PreferenceDataset as RewardModelPreferenceDataset, RewardInference
from rlm_05_constitutional_ai import Constitution, SelfCritique, RevisionLoop

try:
    import psycopg2
    from psycopg2.extras import Json, RealDictCursor
except ImportError:
    psycopg2 = None
    print("Warning: psycopg2 not installed. RLM Gateway will run in mock mode.")

try:
    import redis
except ImportError:
    redis = None
    print("Warning: redis not installed. RLM Gateway will run in mock mode.")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("RLMGateway")


class _PlaceholderConstitutionalAI:
    """Heuristic fallback Constitutional AI for self-critique and revision.

    Uses simple substring checks against the loaded constitution instead of
    the trained rlm_05 modules. Kept as a degraded-mode stand-in.
    """

    def __init__(self, constitution_path: str = "/mnt/e/genesis-system/AIVA/constitution.json"):
        # Path to the JSON constitution; missing/invalid files degrade gracefully.
        self.constitution_path = constitution_path
        self.constitution = self._load_constitution()
        self.enabled = True

    def _load_constitution(self) -> Dict:
        """Load constitutional principles from JSON file.

        Returns an empty principle set when the file is missing OR malformed,
        so the gateway keeps running in degraded mode instead of crashing.
        """
        try:
            with open(self.constitution_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            # Previously only FileNotFoundError was caught; a corrupt JSON
            # file would have crashed the gateway at startup.
            logger.warning(f"Constitution not loadable ({e}): {self.constitution_path}")
            return {"principles": [], "enforcement": {}}

    async def critique_output(self, output: str) -> Tuple[str, List[str], str]:
        """
        Critique output against constitutional principles.

        Returns:
            (critique_text, violated_principle_ids, severity)
        """
        violated = []
        critiques = []
        max_severity = "low"

        # Check each principle; a missing 'severity' key defaults to "low"
        # so an incomplete constitution entry cannot raise KeyError here.
        for principle in self.constitution.get("principles", []):
            violation = self._check_principle(output, principle)
            if violation:
                violated.append(principle["id"])
                critiques.append(f"Violates {principle['name']}: {violation}")
                severity = principle.get("severity", "low")
                if self._severity_rank(severity) > self._severity_rank(max_severity):
                    max_severity = severity

        critique_text = "\n".join(critiques) if critiques else "No violations detected"
        return critique_text, violated, max_severity

    def _check_principle(self, output: str, principle: Dict) -> Optional[str]:
        """
        Check if output violates a specific principle.

        Returns violation description if violated, None otherwise.
        """
        principle_id = principle["id"]

        # P003: Safety - C: drive and AIVA server protection
        if principle_id == "p003":
            if "C:" in output or "C:\\" in output:
                return "Output references C: drive which is forbidden"
            if "152.53.201.152" in output or "aiva-mini" in output:
                return "Output references AIVA server which is untouchable"

        # P002: Honesty - Detect fabrication keywords
        if principle_id == "p002":
            fabrication_keywords = [
                "probably costs", "typically around", "usually priced at",
                "based on industry standards", "most competitors", "I estimate"
            ]
            for keyword in fabrication_keywords:
                if keyword.lower() in output.lower():
                    return f"Output contains fabrication indicator: '{keyword}'"

        # P005: Revenue Focus - Detect revenue-negative patterns
        if principle_id == "p005":
            revenue_negative = [
                "let's refactor for elegance", "internal tool", "1% efficiency gain",
                "no customer value", "purely for maintenance"
            ]
            for pattern in revenue_negative:
                if pattern.lower() in output.lower():
                    return f"Output suggests revenue-negative action: '{pattern}'"

        # P004: Efficiency - Detect token waste
        if principle_id == "p004":
            if len(output) > 5000 and "dispatching agent" not in output.lower():
                return "Output is >5000 chars without delegating to agents"

        return None

    def _severity_rank(self, severity: str) -> int:
        """Convert severity string to numeric rank (unknown values rank as 1/low)."""
        return {"low": 1, "medium": 2, "high": 3, "critical": 4}.get(severity, 1)

    async def revise_output(self, output: str, critique: str, violated_principles: List[str]) -> str:
        """
        Revise output to comply with constitutional principles.

        In production, this would call an LLM (Claude/Gemini) with the
        original output, the critique, the violated principles and examples
        of compliant outputs. For now it prepends a warning instead.
        """
        if not violated_principles:
            return output

        # TODO: Call an LLM for actual revision; the prompt would include the
        # original output, the critique text and the violated principle IDs.
        return f"[CONSTITUTIONAL WARNING: {critique}]\n\n{output}"

class ConstitutionalAI:
    """Wrapper for Constitutional AI modules for self-critique and revision."""

    def __init__(self, constitution_path: str = "/mnt/e/genesis-system/AIVA/constitution.json"):
        # Build the rlm_05 stack: constitution -> critique engine -> revision loop.
        self.constitution = Constitution(config_path=constitution_path)
        self.critique_engine = SelfCritique(constitution=self.constitution)
        self.revision_loop = RevisionLoop(critique_engine=self.critique_engine)

    async def critique_output(self, output: str, prompt: str = "") -> Tuple[str, List[str], str]:
        """
        Critique output against constitutional principles using the actual RLM module.

        Returns:
            (critique_text, violated_principle_ids, severity)
        """
        results = self.critique_engine.critique(prompt=prompt, response=output)

        violations = [r for r in results if r.violated]
        violated_ids = [r.principle_id for r in violations]
        messages = [f"Violates {r.principle_name}: {r.explanation}" for r in violations]

        # Highest severity among violations; ties keep the earliest result,
        # matching a strictly-greater comparison.
        top_severity = max((r.severity for r in violations),
                           key=lambda s: s.value, default=None)

        summary = "\n".join(messages) if messages else "No violations detected"

        if top_severity is not None:
            severity = top_severity.name.lower()
        elif violated_ids:
            # Defensive: violations present but no severity recorded.
            severity = "low"
        else:
            severity = "none"

        return summary, violated_ids, severity

    async def revise_output(self, output: str, critique: str, violated_principles: List[str], prompt: str = "") -> str:
        """
        Revise output to comply with constitutional principles using the actual RLM module.
        """
        # The revision loop re-runs the critique engine internally, so only
        # the original prompt and response are needed here.
        result = self.revision_loop.revise(prompt=prompt, response=output)

        if not result.is_compliant:
            # Surface remaining issues to the caller alongside the best revision.
            warning = (
                f"[CONSTITUTIONAL WARNING: Remaining issues detected: "
                f"{len(result.remaining_issues)}. Final verdict: {result.final_verdict}]\n\n"
            )
            return warning + result.revised_response
        return result.revised_response


class _PlaceholderRewardModel:
    """Heuristic reward model used when no trained checkpoint is available."""

    def __init__(self, checkpoint_path: Optional[str] = None):
        # Optional path to trained weights; loading is not implemented yet.
        self.checkpoint_path = checkpoint_path
        self.model = self._load_model()

    def _load_model(self):
        """Load trained reward model from checkpoint.

        Returns None (mock mode) both when the checkpoint is missing and as a
        placeholder until real weight loading is implemented.
        """
        if not self.checkpoint_path or not os.path.exists(self.checkpoint_path):
            logger.warning("No reward model checkpoint found. Using mock scoring.")
            return None

        # TODO: Load actual model weights
        return None

    async def score_output(self, output: str, context: Optional[str] = None) -> float:
        """
        Score output quality using trained reward model.

        Returns:
            Float score 0.0 to 1.0 (higher = better)
        """
        if self.model is None:
            # Mock scoring based on simple heuristics; lowercase once.
            lowered = output.lower()
            score = 0.5

            # Boost for action verbs
            action_verbs = ["dispatching", "creating", "executing", "deploying", "building"]
            if any(verb in lowered for verb in action_verbs):
                score += 0.2

            # Boost for specificity
            if any(char.isdigit() for char in output):
                score += 0.1

            # Penalize vagueness. BUGFIX: "I think" could never match the
            # lowercased output because of its capital letter; use "i think".
            vague_phrases = ["might", "could", "maybe", "perhaps", "i think"]
            if any(phrase in lowered for phrase in vague_phrases):
                score -= 0.2

            return max(0.0, min(1.0, score))

        # TODO: Real model inference
        return 0.5

class RewardModel:
    """Wrapper around rlm_02's RewardInference with a heuristic fallback.

    If the real reward model cannot be initialized, the gateway degrades to
    _PlaceholderRewardModel heuristics; if real inference fails at scoring
    time, a neutral score is returned so the live pipeline never crashes.
    """

    def __init__(self, checkpoint_path: Optional[str] = None):
        # TODO: load real trained weights from checkpoint_path. For now a
        # freshly-constructed model gives untrained (but well-formed) scores.
        try:
            dummy_model = RealRewardModel(embedding_dim=768)  # default embedding_dim
            self.inference_engine = RewardInference(model=dummy_model)
            self.enabled = True
            logger.info("Real RewardModel (via RewardInference) initialized.")
        except Exception as e:
            logger.error(f"Failed to initialize Real RewardModel, falling back to mock: {e}")
            self.inference_engine = _PlaceholderRewardModel(checkpoint_path=checkpoint_path)
            self.enabled = False

    async def score_output(self, output: str, context: Optional[str] = None) -> float:
        """Score output quality (0.0 to 1.0) using the wrapped engine.

        Args:
            output: Text to score.
            context: Optional context; used as the prompt for real inference.

        Returns:
            Float score; neutral 0.5 when real inference raises at runtime,
            so scoring failures cannot abort the caller's pipeline.
        """
        if not self.enabled:
            return await self.inference_engine.score_output(output, context)

        try:
            # RewardInference scores a (prompt, response) pair; use the
            # context as the prompt, or empty string when absent.
            result = await self.inference_engine.score_async(prompt=context or "", response=output)
            return result.score
        except Exception as e:
            # Degrade gracefully: previously this exception propagated and
            # broke process_output/process_interaction entirely.
            logger.error(f"Reward scoring failed, returning neutral score: {e}")
            return 0.5


class RLMGateway:
    """
    Central gateway for all RLM functionality.

    Responsibilities:
    1. Intercept AIVA outputs before sending
    2. Score outputs using reward model
    3. Revise outputs using Constitutional AI
    4. Collect human feedback
    5. Trigger training when thresholds met
    6. Manage A/B testing for policy deployment
    """

    def __init__(self):
        """Wire up RLM components, storage connections and rollout config."""
        # Constitutional AI is always active; the reward model becomes
        # meaningful once trained.
        self.constitutional_ai = ConstitutionalAI()
        self.reward_model = RewardModel()

        # Infrastructure connections (either may be None in mock mode).
        self.db_conn = self._init_db()
        self.redis_conn = self._init_redis()

        # Training and rollout configuration.
        self.training_threshold = 100  # new preferences required before a training run
        self.ab_test_enabled = False
        self.ab_test_new_policy_ratio = 0.2  # share of traffic routed to the new policy

        # Policy version bookkeeping for A/B testing.
        self.current_policy_version = "v1_baseline"
        self.new_policy_version = None

        logger.info("RLM Gateway initialized")

    def _init_db(self):
        """Open a PostgreSQL connection; return None when unavailable (mock mode)."""
        if psycopg2 is None:
            logger.warning("PostgreSQL not available. Running in mock mode.")
            return None

        try:
            connection = psycopg2.connect(**PostgresConfig.get_connection_params())
            logger.info("Connected to PostgreSQL")
            return connection
        except Exception as e:
            # Any connection failure downgrades the gateway to mock mode.
            logger.error(f"Failed to connect to PostgreSQL: {e}")
            return None

    def _init_redis(self):
        """Open a Redis connection; return None when unavailable (mock mode)."""
        if redis is None:
            logger.warning("Redis not available. Running in mock mode.")
            return None

        try:
            client = redis.Redis(**RedisConfig.get_connection_params())
            client.ping()  # fail fast if the server is unreachable
            logger.info("Connected to Redis")
            return client
        except Exception as e:
            # Any connection failure downgrades the gateway to mock mode.
            logger.error(f"Failed to connect to Redis: {e}")
            return None

    async def process_output(self, output: str, context: Optional[str] = None) -> Tuple[str, Dict]:
        """
        Main entry point: process AIVA output before sending.

        Steps:
        1. Score output with reward model
        2. Check constitutional compliance
        3. Revise if necessary (and log the violation)
        4. Select policy version for A/B testing

        Args:
            output: AIVA's raw output
            context: Optional context (input that led to this output)

        Returns:
            (final_output, metadata)
        """
        start_time = time.time()
        metadata = {
            "original_length": len(output),
            "processing_time": 0,
            "reward_score": 0.0,
            "constitutional_check": {},
            "revised": False,
            "policy_version": self.current_policy_version
        }

        # Step 1: Score output
        reward_score = await self.reward_model.score_output(output, context)
        metadata["reward_score"] = reward_score

        # Step 2: Constitutional check
        critique, violated_principles, severity = await self.constitutional_ai.critique_output(
            output, prompt=context or ""
        )
        metadata["constitutional_check"] = {
            "violated_principles": violated_principles,
            "severity": severity,
            "critique": critique
        }

        # Step 3: Revise if necessary.
        # BUGFIX: `enabled` was a flag on the old placeholder critic; the real
        # SelfCritique engine may not define it, which would raise
        # AttributeError here. Default to enabled when the flag is absent.
        final_output = output
        revision_allowed = getattr(self.constitutional_ai.critique_engine, "enabled", True)
        if violated_principles and revision_allowed:
            final_output = await self.constitutional_ai.revise_output(
                output, critique, violated_principles, prompt=context or ""
            )
            metadata["revised"] = True

            # Log constitutional violation for auditing/training
            await self._log_constitutional_violation(
                original_output=output,
                critique=critique,
                revised_output=final_output,
                violated_principles=violated_principles,
                severity=severity
            )

        # Step 4: A/B test policy selection
        if self.ab_test_enabled and random.random() < self.ab_test_new_policy_ratio:
            metadata["policy_version"] = self.new_policy_version

        metadata["processing_time"] = time.time() - start_time
        metadata["final_length"] = len(final_output)

        return final_output, metadata

    async def collect_feedback(self, output: str, feedback: str, context: Optional[str] = None):
        """
        Collect human feedback on an output and store it for training.

        Args:
            output: The output that was shown to the user
            feedback: User feedback ('good', 'bad', 'regenerate', or free text)
            context: Optional context (what led to this output)
        """
        if self.db_conn is None:
            logger.warning("No DB connection. Feedback not stored.")
            return

        try:
            with self.db_conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO rlm_ab_test_results
                    (policy_version, input_text, output_text, user_feedback, created_at)
                    VALUES (%s, %s, %s, %s, NOW())
                """, (
                    self.current_policy_version,
                    context or "",
                    output,
                    feedback
                ))
            self.db_conn.commit()
            logger.info(f"Feedback collected: {feedback}")

            # Check if training threshold met
            await self._check_training_trigger()

        except Exception as e:
            logger.error(f"Failed to collect feedback: {e}")
            # BUGFIX: without rollback, a failed statement leaves the shared
            # psycopg2 connection in an aborted transaction, breaking every
            # subsequent query on it.
            try:
                self.db_conn.rollback()
            except Exception:
                pass

    async def collect_preference(self, output_a: str, output_b: str, choice: int,
                                   input_text: Optional[str] = None, confidence: float = 1.0,
                                   annotator_id: str = "auto", metadata: Optional[Dict] = None):
        """
        Collect preference comparison: which output is better?

        Args:
            output_a: First output option
            output_b: Second output option
            choice: 1 (A preferred), -1 (B preferred), 0 (tie)
            input_text: Optional input that generated these outputs
            confidence: Confidence in preference (0.0 to 1.0)
            annotator_id: ID of the annotator (default: "auto")
            metadata: Optional dictionary for additional metadata
        """
        if self.db_conn is None:
            logger.warning("No DB connection. Preference not stored.")
            return

        try:
            with self.db_conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO pl_preference_pairs
                    (input_text, output_a, output_b, preference, confidence, annotator_id, metadata, created_at)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
                """, (
                    input_text or "",
                    output_a,
                    output_b,
                    choice,
                    confidence,
                    annotator_id,
                    json.dumps(metadata) if metadata else None  # stored in a JSONB column
                ))
            self.db_conn.commit()
            logger.info(f"Preference collected: {choice} (confidence: {confidence})")

            # Preference pairs feed reward-model training; check the threshold
            await self._check_training_trigger()

        except Exception as e:
            logger.error(f"Failed to collect preference: {e}")
            # BUGFIX: roll back so the shared connection is not stuck in an
            # aborted transaction after a failed INSERT.
            try:
                self.db_conn.rollback()
            except Exception:
                pass

    async def _log_constitutional_violation(self, original_output: str, critique: str,
                                             revised_output: str, violated_principles: List[str],
                                             severity: str):
        """Log constitutional violation to database.

        Writes the critique record to cai_critique_log and bumps the per-day
        violation counter for every violated principle. Best-effort: failures
        are logged, rolled back and swallowed so the output pipeline continues.
        """
        if self.db_conn is None:
            return

        try:
            with self.db_conn.cursor() as cur:
                # psycopg2 adapts the Python list to a PostgreSQL array column
                cur.execute("""
                    INSERT INTO cai_critique_log
                    (original_output, critique, revised_output, violated_principles, severity, created_at)
                    VALUES (%s, %s, %s, %s, %s, NOW())
                """, (
                    original_output,
                    critique,
                    revised_output,
                    violated_principles,
                    severity
                ))

                # Update daily violation counts (upsert keyed on principle+date)
                for principle_id in violated_principles:
                    cur.execute("""
                        INSERT INTO cai_principle_violations (principle_id, violation_date, count)
                        VALUES (%s, CURRENT_DATE, 1)
                        ON CONFLICT (principle_id, violation_date)
                        DO UPDATE SET count = cai_principle_violations.count + 1
                    """, (principle_id,))

            self.db_conn.commit()
            logger.warning(f"Constitutional violation logged: {severity} - {', '.join(violated_principles)}")

            # Alert on critical violations
            if severity == "critical":
                logger.critical(f"CRITICAL CONSTITUTIONAL VIOLATION: {critique}")
                # TODO: Send Telegram alert to Kinan

        except Exception as e:
            logger.error(f"Failed to log constitutional violation: {e}")
            # BUGFIX: clear the aborted transaction so later statements on
            # this shared connection can still run.
            try:
                self.db_conn.rollback()
            except Exception:
                pass

    async def _check_training_trigger(self):
        """Trigger training automatically once enough new preference pairs exist.

        Counts pl_preference_pairs rows created since the last completed
        reward-model training run and fires trigger_training() when the
        configured threshold is reached.
        """
        if self.db_conn is None:
            return

        try:
            with self.db_conn.cursor() as cur:
                # Count preferences collected since last completed training run
                cur.execute("""
                    SELECT COUNT(*) FROM pl_preference_pairs
                    WHERE created_at > (
                        SELECT COALESCE(MAX(created_at), '1970-01-01')
                        FROM rlm_training_triggers
                        WHERE training_module = 'reward_model' AND status = 'completed'
                    )
                """)
                new_prefs = cur.fetchone()[0]

            if new_prefs >= self.training_threshold:
                logger.info(f"Training threshold met: {new_prefs} new preferences. Triggering training.")
                await self.trigger_training()

        except Exception as e:
            logger.error(f"Failed to check training trigger: {e}")
            # BUGFIX: even a failed SELECT aborts the transaction on this
            # shared connection; roll back so later queries succeed.
            try:
                self.db_conn.rollback()
            except Exception:
                pass

    async def trigger_training(self, module: str = "reward_model", manual: bool = False):
        """
        Trigger RLM training run.

        Args:
            module: Which module to train ('reward_model', 'ppo', 'dpo', 'cai')
            manual: Whether this is a manual trigger (vs automatic threshold)
        """
        if self.db_conn is None:
            logger.warning("No DB connection. Cannot trigger training.")
            return

        trigger_type = "manual" if manual else "threshold"

        try:
            with self.db_conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO rlm_training_triggers
                    (trigger_type, trigger_reason, training_module, status, created_at)
                    VALUES (%s, %s, %s, 'queued', NOW())
                    RETURNING id
                """, (
                    trigger_type,
                    f"Training triggered: {trigger_type}",
                    module
                ))
                trigger_id = cur.fetchone()[0]
            self.db_conn.commit()

            logger.info(f"Training trigger created: ID={trigger_id}, module={module}")

            # TODO: Actually kick off training job
            # This would spawn a background agent to run the training
            # For now, just log the trigger

        except Exception as e:
            logger.error(f"Failed to trigger training: {e}")
            # BUGFIX: roll back so the shared connection is not left in an
            # aborted transaction after a failed INSERT.
            try:
                self.db_conn.rollback()
            except Exception:
                pass

    async def process_interaction(
        self,
        call_id: str,
        transcript: str,
        caller_number: str = "",
        call_duration_seconds: int = 0,
        outcome: str = "completed"
    ) -> Dict:
        """
        Process a completed voice call interaction through the full RLM pipeline.

        Called by the Telnyx webhook handler on call.hangup events.
        Runs in SHADOW MODE: logs data only, does not alter call behaviour.

        Pipeline (in execution order):
        1. Run Constitutional AI check on transcript (log violations, no revision)
        2. Compute reward score for the transcript
        3. Infer outcome label (positive/negative/neutral)
        4. Persist interaction + score records to PostgreSQL (if connected)
        5. Check training trigger threshold

        Args:
            call_id: Telnyx call session ID
            transcript: Full concatenated transcript text
            caller_number: Caller phone number (for context)
            call_duration_seconds: Total call length
            outcome: 'completed', 'transferred', 'voicemail', 'error'

        Returns:
            Dict with rlm_processed=True and metadata; on any exception the
            dict is still returned with rlm_processed=False and 'error' set.
        """
        # Result skeleton; fields are filled in as the pipeline progresses so
        # partial results survive a mid-pipeline failure.
        result = {
            "rlm_processed": True,
            "call_id": call_id,
            "shadow_mode": True,
            "interaction_id": None,
            "reward_score": None,
            "cai_violations": [],
            "outcome_label": None,
            "error": None,
        }

        try:
            # --- Step 1: Constitutional AI check ---
            # No prompt is passed: the transcript is critiqued standalone.
            critique, violated_principles, severity = await self.constitutional_ai.critique_output(transcript)
            result["cai_violations"] = violated_principles
            result["cai_severity"] = severity

            if violated_principles:
                await self._log_constitutional_violation(
                    original_output=transcript,
                    critique=critique,
                    revised_output=transcript,  # shadow mode: no revision
                    violated_principles=violated_principles,
                    severity=severity
                )

            # --- Step 2: Reward scoring ---
            reward_score = await self.reward_model.score_output(transcript, context=f"call_id={call_id}")
            result["reward_score"] = reward_score

            # --- Step 3: Infer outcome label ---
            outcome_label = self._infer_outcome_label(transcript, outcome, call_duration_seconds)
            result["outcome_label"] = outcome_label

            # --- Step 4: Persist to PostgreSQL (shadow mode) ---
            if self.db_conn is not None:
                interaction_id = await self._store_aiva_interaction(
                    call_id=call_id,
                    transcript=transcript,
                    caller_number=caller_number,
                    call_duration_seconds=call_duration_seconds,
                    outcome=outcome,
                    reward_score=reward_score,
                    outcome_label=outcome_label,
                    cai_violations=violated_principles,
                    cai_severity=severity,
                )
                result["interaction_id"] = interaction_id

                # --- Step 5: Check training threshold ---
                await self._check_training_trigger()

                logger.info(
                    f"RLM shadow processed call {call_id}: "
                    f"reward={reward_score:.3f}, outcome={outcome_label}, "
                    f"cai_violations={len(violated_principles)}"
                )
            else:
                logger.warning(f"RLM shadow mode (no DB): call {call_id} scored {reward_score:.3f}")

        except Exception as e:
            # Webhook handlers must never see an exception from shadow
            # processing; report the failure through the result dict instead.
            result["error"] = str(e)
            result["rlm_processed"] = False
            logger.error(f"RLM process_interaction failed for call {call_id}: {e}", exc_info=True)

        return result

    def _infer_outcome_label(self, transcript: str, outcome: str, duration_seconds: int) -> str:
        """
        Infer a positive/negative/neutral label from call signals.

        Positive signals: caller says thank you, booking confirmed, long call
        Negative signals: caller frustrated, hung up quickly, escalation keywords
        """
        text = transcript.lower()

        # Explicit positive signals
        positive_keywords = [
            "thank you", "thanks", "perfect", "great", "excellent",
            "book", "appointment", "schedule", "confirm", "yes please",
            "that's helpful", "sounds good", "brilliant", "amazing"
        ]
        # Explicit negative signals
        negative_keywords = [
            "forget it", "never mind", "not helpful", "terrible",
            "useless", "waste", "frustrated", "angry", "cancel",
            "wrong number", "do not call"
        ]

        positive_count = sum(1 for kw in positive_keywords if kw in text)
        negative_count = sum(1 for kw in negative_keywords if kw in text)

        # Duration heuristic: calls < 30s likely abandoned
        if duration_seconds < 30:
            negative_count += 1

        # Calls > 2 minutes with positive signals = strong positive
        if duration_seconds > 120 and positive_count > 0:
            positive_count += 1

        if outcome == "error":
            return "negative"
        elif positive_count > negative_count:
            return "positive"
        elif negative_count > positive_count:
            return "negative"
        else:
            return "neutral"

    async def _store_aiva_interaction(
        self,
        call_id: str,
        transcript: str,
        caller_number: str,
        call_duration_seconds: int,
        outcome: str,
        reward_score: float,
        outcome_label: str,
        cai_violations: List[str],
        cai_severity: str,
    ) -> Optional[int]:
        """Persist interaction to aiva_interactions and aiva_feedback_scores tables.

        Creates the aiva_rlm schema and its tables on every call (idempotent
        DDL). Upserts the interaction keyed on call_id, then appends a new
        score row referencing it. Best-effort: on any failure the transaction
        is rolled back and None is returned.

        Returns:
            The aiva_interactions row id, or None if persistence failed.
        """
        # NOTE(review): running CREATE SCHEMA/TABLE per call works because the
        # statements are IF NOT EXISTS, but this could move to a one-time
        # startup migration.
        interaction_id = None
        try:
            with self.db_conn.cursor() as cur:
                # Ensure schema exists (idempotent)
                cur.execute("CREATE SCHEMA IF NOT EXISTS aiva_rlm")

                # aiva_interactions: canonical record of each processed call
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS aiva_rlm.aiva_interactions (
                        id SERIAL PRIMARY KEY,
                        call_id VARCHAR(255) UNIQUE NOT NULL,
                        caller_number VARCHAR(50),
                        transcript TEXT,
                        call_duration_seconds INTEGER DEFAULT 0,
                        outcome VARCHAR(50),
                        outcome_label VARCHAR(20),
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # aiva_feedback_scores: RLM scoring result per interaction
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS aiva_rlm.aiva_feedback_scores (
                        id SERIAL PRIMARY KEY,
                        interaction_id INTEGER REFERENCES aiva_rlm.aiva_interactions(id),
                        call_id VARCHAR(255) NOT NULL,
                        reward_score DECIMAL(6,4),
                        cai_violations TEXT[],
                        cai_severity VARCHAR(20),
                        policy_version VARCHAR(50),
                        shadow_mode BOOLEAN DEFAULT TRUE,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # aiva_preference_pairs: for future A/B comparison and Bradley-Terry
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS aiva_rlm.aiva_preference_pairs (
                        id SERIAL PRIMARY KEY,
                        call_id_a VARCHAR(255),
                        call_id_b VARCHAR(255),
                        preferred VARCHAR(10),
                        confidence DECIMAL(4,3) DEFAULT 1.0,
                        annotator VARCHAR(50) DEFAULT 'auto',
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # Insert interaction; re-processing the same call_id updates
                # only the outcome fields, leaving the original transcript.
                cur.execute("""
                    INSERT INTO aiva_rlm.aiva_interactions
                    (call_id, caller_number, transcript, call_duration_seconds, outcome, outcome_label)
                    VALUES (%s, %s, %s, %s, %s, %s)
                    ON CONFLICT (call_id) DO UPDATE SET
                        outcome = EXCLUDED.outcome,
                        outcome_label = EXCLUDED.outcome_label
                    RETURNING id
                """, (call_id, caller_number, transcript, call_duration_seconds, outcome, outcome_label))

                # ON CONFLICT DO UPDATE returns a row either way; the guard is
                # purely defensive.
                row = cur.fetchone()
                interaction_id = row[0] if row else None

                # Insert score record (new row per processing run, even for a
                # re-processed call). cai_violations list maps to TEXT[].
                cur.execute("""
                    INSERT INTO aiva_rlm.aiva_feedback_scores
                    (interaction_id, call_id, reward_score, cai_violations, cai_severity, policy_version)
                    VALUES (%s, %s, %s, %s, %s, %s)
                """, (
                    interaction_id,
                    call_id,
                    reward_score,
                    cai_violations,
                    cai_severity,
                    self.current_policy_version
                ))

            self.db_conn.commit()

        except Exception as e:
            logger.error(f"Failed to store AIVA interaction {call_id}: {e}")
            # Clear the aborted transaction so the shared connection stays usable.
            try:
                self.db_conn.rollback()
            except Exception:
                pass

        return interaction_id

    async def enable_ab_test(self, new_policy_version: str, traffic_ratio: float = 0.2):
        """
        Enable A/B testing with a new policy version.

        Args:
            new_policy_version: Name of the new policy to test
            traffic_ratio: Fraction of traffic to send to new policy (0.0 to 1.0)

        Raises:
            ValueError: If traffic_ratio is outside [0.0, 1.0] — an invalid
                ratio would silently break traffic splitting downstream.
        """
        if not 0.0 <= traffic_ratio <= 1.0:
            raise ValueError(f"traffic_ratio must be within [0.0, 1.0], got {traffic_ratio}")
        self.ab_test_enabled = True
        self.new_policy_version = new_policy_version
        self.ab_test_new_policy_ratio = traffic_ratio
        logger.info(f"A/B test enabled: {traffic_ratio*100}% traffic to {new_policy_version}")

    async def get_ab_test_results(self, hours: int = 24) -> Dict:
        """Fetch A/B comparison rows and decide the winning policy.

        Args:
            hours: Look-back window in hours (currently unused by the query).

        Returns:
            Dict with "results" (list of row dicts) and "winner", or an empty
            dict when no DB connection exists or the query fails.
        """
        if self.db_conn is None:
            return {}

        try:
            with self.db_conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM v_ab_test_comparison
                """)
                rows = cursor.fetchall()

            return {
                "results": [dict(r) for r in rows],
                "winner": self._determine_ab_winner(rows)
            }

        except Exception as exc:
            logger.error(f"Failed to get A/B test results: {exc}")
            return {}

    def _determine_ab_winner(self, results: List) -> str:
        """Determine which policy won the A/B test."""
        if len(results) < 2:
            return "insufficient_data"

        # Compare average reward scores
        old_policy = next((r for r in results if r["policy_version"] == self.current_policy_version), None)
        new_policy = next((r for r in results if r["policy_version"] == self.new_policy_version), None)

        if not old_policy or not new_policy:
            return "insufficient_data"

        old_score = old_policy.get("avg_reward", 0)
        new_score = new_policy.get("avg_reward", 0)

        if new_score > old_score + 0.05:  # 5% improvement threshold
            return "new_policy"
        elif old_score > new_score + 0.05:
            return "old_policy"
        else:
            return "inconclusive"

    async def promote_policy(self, policy_version: str):
        """Roll a policy out to 100% of traffic and end any running A/B test.

        Args:
            policy_version: Policy name to make the production default.
        """
        self.ab_test_enabled = False
        self.current_policy_version = policy_version
        logger.info(f"Policy promoted to production: {policy_version}")

    def close(self):
        """Clean up database connections.

        Each connection is closed independently so a failure closing one does
        not leak the other; handles are cleared afterwards so repeated calls
        are safe no-ops.
        """
        for attr in ("db_conn", "redis_conn"):
            conn = getattr(self, attr, None)
            if not conn:
                continue
            try:
                conn.close()
            except Exception as exc:
                # Best-effort cleanup: log and keep closing the rest.
                logger.warning(f"Error closing {attr}: {exc}")
            finally:
                setattr(self, attr, None)


# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================

_gateway_instance = None

def get_gateway() -> RLMGateway:
    """Get or create singleton RLM Gateway instance.

    The gateway is constructed lazily on first call; subsequent calls
    return the same shared instance.
    """
    global _gateway_instance
    if _gateway_instance is not None:
        return _gateway_instance
    _gateway_instance = RLMGateway()
    return _gateway_instance


async def process_aiva_output(output: str, context: Optional[str] = None) -> Tuple[str, Dict]:
    """
    Convenience function: process an AIVA output through RLM Gateway.

    Delegates to the singleton gateway's process_output.

    Usage:
        final_output, metadata = await process_aiva_output(aiva_response)
        print(final_output)  # Send this to user instead of raw response
    """
    return await get_gateway().process_output(output, context)


async def collect_feedback(output: str, feedback: str, context: Optional[str] = None):
    """
    Convenience function: collect user feedback on an output.

    Forwards the feedback to the singleton gateway for storage/training.

    Usage:
        await collect_feedback(output, "good")  # User liked it
        await collect_feedback(output, "bad")   # User disliked it
    """
    await get_gateway().collect_feedback(output, feedback, context)


# =============================================================================
# MAIN (for testing)
# =============================================================================

async def main():
    """Smoke-test the RLM Gateway: constitutional check, preference collection, reward scoring."""
    gw = RLMGateway()

    def _section(title, leading_newline=False):
        # Print a banner identical to the original test output format.
        if leading_newline:
            print("\n" + "=" * 80)
        else:
            print("=" * 80)
        print(title)
        print("=" * 80)

    # Test 1: Process output with constitutional check
    test_output = "Let me write some files to C:\\Users\\P3\\.claude-worktrees\\ to help with this task."
    final_output, metadata = await gw.process_output(test_output)

    _section("TEST 1: Constitutional Check")
    print(f"Original: {test_output}")
    print(f"\nFinal: {final_output}")
    print(f"\nMetadata: {json.dumps(metadata, indent=2)}")

    # Test 2: Collect preference
    await gw.collect_preference(
        output_a="I'll dispatch 3 agents to build this (2 hours)",
        output_b="I can help you with that if you'd like",
        choice=1,  # A is better
        input_text="Build the RLM Gateway",
        confidence=0.9
    )

    _section("TEST 2: Preference Collection", leading_newline=True)
    print("Preference stored in database")

    # Test 3: Reward scoring
    good_output = "Dispatching 3 Opus agents to build RLM Gateway (2-3 hours). Estimated cost: $0.50. Will integrate with PostgreSQL, Redis, and Constitutional AI."
    bad_output = "I think I could probably help with that maybe."

    score_good = await gw.reward_model.score_output(good_output)
    score_bad = await gw.reward_model.score_output(bad_output)

    _section("TEST 3: Reward Scoring", leading_newline=True)
    print(f"Good output score: {score_good:.2f}")
    print(f"Bad output score: {score_bad:.2f}")

    gw.close()


# Script entry point: run the async smoke tests in main().
if __name__ == "__main__":
    asyncio.run(main())
