"""
AIVA Outcome Tracker
Tracks actual outcomes vs expected outcomes for autonomous decisions.
Feeds learning back into confidence scorer for continuous improvement.

CRITICAL: Uses PostgreSQL via Elestio config (NO SQLite)
"""

import json
import logging
import sys
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple

# Add genesis-memory to path for Elestio config
GENESIS_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(GENESIS_ROOT / "data" / "genesis-memory"))

from elestio_config import PostgresConfig
import psycopg2
from psycopg2.extras import RealDictCursor, Json

logger = logging.getLogger(__name__)


@dataclass
class OutcomeComparison:
    """
    Side-by-side view of a single resolved decision.

    Pairs what was predicted with what actually happened, together with the
    confidence held at prediction time and the timing of the resolution.
    Instances are built by OutcomeTracker.compare_outcomes() from one row of
    the aiva_outcome_tracking table.
    """
    # Caller-supplied unique identifier of the decision.
    decision_id: str
    # Category of the decision (e.g. 'email_classification').
    task_type: str
    # Prediction recorded up front via record_prediction().
    expected_outcome: Dict[str, Any]
    # Outcome reported later via record_actual().
    actual_outcome: Dict[str, Any]
    # Confidence in [0, 1] held when the prediction was made.
    confidence_at_decision: float
    # Success flag supplied when the actual outcome was recorded.
    was_correct: bool
    # Match score in [0, 1]; despite the name, 1.0 is a perfect match and
    # 0.0 means the outcome was marked as a failure.
    deviation_score: float
    # Timestamp at which the prediction row was recorded.
    prediction_time: datetime
    # Timestamp at which the actual outcome was recorded.
    resolution_time: datetime
    # resolution_time minus prediction_time, expressed in minutes.
    time_to_resolve_minutes: float


@dataclass
class AccuracyStats:
    """
    Rolled-up accuracy figures for one task type over a look-back window.

    Produced by OutcomeTracker.get_accuracy_stats() from the resolved
    predictions of that task type inside the window.
    """
    # Task type these figures describe.
    task_type: str
    # Number of resolved predictions inside the window.
    total_predictions: int
    # How many of those were marked correct.
    correct_predictions: int
    # correct_predictions / total_predictions.
    accuracy_rate: float
    # Mean confidence_at_decision across the sample.
    avg_confidence: float
    # Mean deviation_score across the sample (1.0 = perfect match).
    avg_deviation: float
    # How well confidence matches actual accuracy:
    # 1 - |avg_confidence - accuracy_rate|, so 1.0 is perfectly calibrated.
    confidence_calibration: float
    # Length of the look-back window, in days.
    window_days: int
    # (earliest, latest) recorded_at among the sampled rows.
    sample_period: Tuple[datetime, datetime]


@dataclass
class CalibrationReport:
    """
    Cross-task summary of how well recorded confidence tracked real outcomes.

    Produced by OutcomeTracker.get_calibration_report(); every rate is a
    fraction of the resolved decisions inside the report's look-back window.
    """
    # Fraction of resolved decisions that were marked correct.
    overall_accuracy: float
    # Number of resolved decisions in the window.
    total_decisions: int
    # task_type -> fraction of that task's decisions marked correct.
    per_task_accuracy: Dict[str, float]
    # Bucket label (e.g. '0.8-1.0') -> dict with 'total', 'correct',
    # 'accuracy' and 'avg_confidence' for decisions in that confidence band.
    confidence_buckets: Dict[str, Dict[str, Any]]
    # 1.0 = perfectly calibrated: share of decisions whose bucket's mean
    # confidence is within the calibration tolerance of its accuracy.
    calibration_score: float
    # Share of decisions in buckets where confidence exceeds accuracy.
    overconfidence_rate: float
    # Share of decisions in buckets where accuracy exceeds confidence.
    underconfidence_rate: float
    # When the report was built.
    generated_at: datetime


class OutcomeTracker:
    """
    Tracks outcome predictions and actuals for AIVA's autonomous decisions.
    Provides calibration metrics to improve confidence scoring over time.

    Every database operation opens its own connection through
    ``_connection()``, which commits on success, rolls back on error and
    always closes the connection afterwards. Schema creation in ``__init__``
    re-raises failures; every other public method logs the error and returns
    a neutral value (``False``, ``None``, ``[]`` or an empty
    ``CalibrationReport``).

    Look-back windows (``window_days``) are evaluated with the database clock,
    the same clock that fills ``recorded_at`` via ``DEFAULT NOW()``, so the
    window is not shifted by clock or timezone differences between this
    process and the database server.
    """

    # A confidence bucket whose mean confidence differs from its observed
    # accuracy by more than this is counted as over- or under-confident.
    _CALIBRATION_TOLERANCE = 0.1

    def __init__(self):
        """Load PostgreSQL connection parameters and ensure the schema exists."""
        self.conn_params = PostgresConfig.get_connection_params()
        self._ensure_schema()

    def _get_connection(self):
        """Open a new PostgreSQL connection; the caller is responsible for closing it."""
        return psycopg2.connect(**self.conn_params)

    @contextmanager
    def _connection(self):
        """
        Yield a connection scoped to one transaction, then close it.

        psycopg2's ``with conn:`` only commits (or rolls back if the block
        raises); it does NOT close the connection. Relying on it alone would
        leak one connection per call, so the ``finally`` closes it explicitly.
        """
        conn = self._get_connection()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _as_dict(value: Any) -> Dict[str, Any]:
        """Copy a decoded JSONB value into a new dict, mapping SQL/JSON NULL to {}."""
        return {} if value is None else dict(value)

    @staticmethod
    def _ratio(numerator: float, denominator: float) -> float:
        """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
        return numerator / denominator if denominator else 0.0

    def _ensure_schema(self):
        """
        Create the outcome tracking table, its indexes and the updated_at
        trigger if they do not already exist.

        Raises:
            Exception: any database error is logged and re-raised, so a
                tracker is never constructed against a missing schema.
        """
        # NOTE: "EXECUTE FUNCTION" in CREATE TRIGGER requires PostgreSQL 11+;
        # older servers only accept "EXECUTE PROCEDURE".
        schema_sql = """
        CREATE TABLE IF NOT EXISTS aiva_outcome_tracking (
            id SERIAL PRIMARY KEY,
            decision_id VARCHAR(255) UNIQUE NOT NULL,
            task_type VARCHAR(100) NOT NULL,
            expected_outcome JSONB NOT NULL,
            actual_outcome JSONB,
            confidence_at_decision FLOAT NOT NULL CHECK (confidence_at_decision BETWEEN 0 AND 1),
            was_correct BOOLEAN,
            deviation_score FLOAT CHECK (deviation_score BETWEEN 0 AND 1),
            recorded_at TIMESTAMP NOT NULL DEFAULT NOW(),
            resolved_at TIMESTAMP,
            metadata JSONB DEFAULT '{}'::jsonb,
            created_at TIMESTAMP DEFAULT NOW(),
            updated_at TIMESTAMP DEFAULT NOW()
        );

        CREATE INDEX IF NOT EXISTS idx_outcome_task_type ON aiva_outcome_tracking(task_type);
        CREATE INDEX IF NOT EXISTS idx_outcome_recorded_at ON aiva_outcome_tracking(recorded_at);
        CREATE INDEX IF NOT EXISTS idx_outcome_resolved_at ON aiva_outcome_tracking(resolved_at);
        CREATE INDEX IF NOT EXISTS idx_outcome_confidence ON aiva_outcome_tracking(confidence_at_decision);
        CREATE INDEX IF NOT EXISTS idx_outcome_was_correct ON aiva_outcome_tracking(was_correct);

        -- Trigger to update updated_at timestamp
        CREATE OR REPLACE FUNCTION update_outcome_tracking_timestamp()
        RETURNS TRIGGER AS $$
        BEGIN
            NEW.updated_at = NOW();
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;

        DROP TRIGGER IF EXISTS trigger_update_outcome_tracking_timestamp ON aiva_outcome_tracking;
        CREATE TRIGGER trigger_update_outcome_tracking_timestamp
            BEFORE UPDATE ON aiva_outcome_tracking
            FOR EACH ROW
            EXECUTE FUNCTION update_outcome_tracking_timestamp();
        """

        try:
            with self._connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(schema_sql)
            logger.info("Outcome tracking schema ensured")
        except Exception as e:
            logger.error("Failed to ensure schema: %s", e)
            raise

    def record_prediction(
        self,
        decision_id: str,
        task_type: str,
        expected_outcome: Dict[str, Any],
        confidence_score: float,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Record a prediction for a decision.

        Re-recording an existing decision_id overwrites its task type,
        expected outcome, confidence and metadata; any already-recorded
        actual outcome / resolution columns are left untouched.

        Args:
            decision_id: Unique identifier for this decision
            task_type: Type of task (e.g., 'email_classification', 'priority_assignment')
            expected_outcome: What we predict will happen (as dict)
            confidence_score: Confidence in this prediction (0-1)
            metadata: Optional additional context

        Returns:
            True if recorded successfully, False if the database write failed

        Raises:
            ValueError: if confidence_score is outside [0, 1] (including NaN)
        """
        if not 0 <= confidence_score <= 1:
            raise ValueError(f"Confidence score must be between 0 and 1, got {confidence_score}")

        insert_sql = """
        INSERT INTO aiva_outcome_tracking
            (decision_id, task_type, expected_outcome, confidence_at_decision, metadata)
        VALUES (%s, %s, %s, %s, %s)
        ON CONFLICT (decision_id) DO UPDATE SET
            task_type = EXCLUDED.task_type,
            expected_outcome = EXCLUDED.expected_outcome,
            confidence_at_decision = EXCLUDED.confidence_at_decision,
            metadata = EXCLUDED.metadata,
            updated_at = NOW()
        """

        try:
            with self._connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(
                        insert_sql,
                        (
                            decision_id,
                            task_type,
                            Json(expected_outcome),
                            confidence_score,
                            Json(metadata or {}),
                        ),
                    )
            logger.info("Recorded prediction for decision %s", decision_id)
            return True
        except Exception as e:
            logger.exception("Failed to record prediction: %s", e)
            return False

    def record_actual(
        self,
        decision_id: str,
        actual_outcome: Dict[str, Any],
        success: bool
    ) -> bool:
        """
        Record the actual outcome for a decision.

        Looks up the stored prediction, scores the match with
        _calculate_deviation(), and stores the actual outcome, success flag,
        score and resolution time in the same transaction.

        Args:
            decision_id: Unique identifier for the decision
            actual_outcome: What actually happened (as dict)
            success: Whether the outcome matched expectations

        Returns:
            True if recorded successfully; False if no prediction exists for
            decision_id or the database operation failed
        """
        select_sql = """
        SELECT expected_outcome
        FROM aiva_outcome_tracking
        WHERE decision_id = %s
        """

        update_sql = """
        UPDATE aiva_outcome_tracking
        SET actual_outcome = %s,
            was_correct = %s,
            deviation_score = %s,
            resolved_at = NOW()
        WHERE decision_id = %s
        """

        try:
            with self._connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(select_sql, (decision_id,))
                    row = cur.fetchone()

                    if not row:
                        logger.error("No prediction found for decision %s", decision_id)
                        return False

                    expected = self._as_dict(row['expected_outcome'])
                    deviation = self._calculate_deviation(expected, actual_outcome, success)

                    cur.execute(
                        update_sql,
                        (Json(actual_outcome), success, deviation, decision_id),
                    )

            logger.info("Recorded actual outcome for decision %s", decision_id)
            return True
        except Exception as e:
            logger.exception("Failed to record actual outcome: %s", e)
            return False

    def _calculate_deviation(
        self,
        expected: Dict[str, Any],
        actual: Dict[str, Any],
        success: bool
    ) -> float:
        """
        Score how closely an actual outcome matched the expected one.

        Despite the name, higher is better: the result lies in [0, 1], where
        0.0 means the outcome was marked as a failure and 1.0 a perfect match.

        Rules:
          * success is False            -> 0.0
          * either dict is empty        -> 1.0 (nothing to compare)
          * no keys in common           -> 0.5 (different structure, but success)
          * otherwise                   -> mean of _field_similarity() over
                                           the keys present in both dicts

        Returns:
            Float between 0 (completely wrong) and 1 (perfect match)
        """
        if not success:
            return 0.0
        if not expected or not actual:
            return 1.0

        common_keys = set(expected) & set(actual)
        if not common_keys:
            return 0.5

        total_similarity = sum(
            self._field_similarity(expected[key], actual[key]) for key in common_keys
        )
        return total_similarity / len(common_keys)

    @staticmethod
    def _field_similarity(exp_val: Any, act_val: Any) -> float:
        """
        Similarity in [0, 1] between one expected and one actual field value.

          * equal values   -> 1.0
          * two numbers    -> 1 - min(|exp - act| / |exp|, 1), or 0.0 when
                              exp is 0 (no relative scale to measure against)
          * two strings    -> 0.7 if one contains the other, ignoring case,
                              otherwise 0.3
          * anything else  -> 0.0 (e.g. mismatched types, unequal lists)
        """
        if exp_val == act_val:
            return 1.0
        if isinstance(exp_val, (int, float)) and isinstance(act_val, (int, float)):
            if exp_val == 0:
                # Both-zero already matched above, so a non-zero actual
                # against a zero expectation gets no credit.
                return 0.0
            return 1.0 - min(abs(exp_val - act_val) / abs(exp_val), 1.0)
        if isinstance(exp_val, str) and isinstance(act_val, str):
            exp_lower, act_lower = exp_val.lower(), act_val.lower()
            return 0.7 if exp_lower in act_lower or act_lower in exp_lower else 0.3
        return 0.0

    def compare_outcomes(self, decision_id: str) -> Optional["OutcomeComparison"]:
        """
        Get comparison between expected and actual outcomes.

        Args:
            decision_id: Unique identifier for the decision

        Returns:
            OutcomeComparison object, or None if the decision is unknown,
            not yet resolved, or the query failed
        """
        select_sql = """
        SELECT
            decision_id,
            task_type,
            expected_outcome,
            actual_outcome,
            confidence_at_decision,
            was_correct,
            deviation_score,
            recorded_at,
            resolved_at
        FROM aiva_outcome_tracking
        WHERE decision_id = %s AND resolved_at IS NOT NULL
        """

        try:
            with self._connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(select_sql, (decision_id,))
                    row = cur.fetchone()

                    if not row:
                        return None

                    minutes_to_resolve = (
                        (row['resolved_at'] - row['recorded_at']).total_seconds() / 60
                    )

                    return OutcomeComparison(
                        decision_id=row['decision_id'],
                        task_type=row['task_type'],
                        expected_outcome=self._as_dict(row['expected_outcome']),
                        actual_outcome=self._as_dict(row['actual_outcome']),
                        confidence_at_decision=row['confidence_at_decision'],
                        was_correct=row['was_correct'],
                        deviation_score=row['deviation_score'],
                        prediction_time=row['recorded_at'],
                        resolution_time=row['resolved_at'],
                        time_to_resolve_minutes=minutes_to_resolve,
                    )
        except Exception as e:
            logger.exception("Failed to compare outcomes: %s", e)
            return None

    def get_accuracy_stats(
        self,
        task_type: str,
        window_days: int = 30
    ) -> Optional["AccuracyStats"]:
        """
        Get accuracy statistics for a task type.

        Only resolved predictions recorded within the last ``window_days``
        days (measured on the database clock) are included.

        Args:
            task_type: Type of task to analyze
            window_days: Number of days to look back

        Returns:
            AccuracyStats object, or None if there is no data or the query failed
        """
        stats_sql = """
        SELECT
            COUNT(*) as total,
            SUM(CASE WHEN was_correct THEN 1 ELSE 0 END) as correct,
            AVG(confidence_at_decision) as avg_confidence,
            AVG(deviation_score) as avg_deviation,
            MIN(recorded_at) as earliest,
            MAX(recorded_at) as latest
        FROM aiva_outcome_tracking
        WHERE task_type = %s
          AND resolved_at IS NOT NULL
          AND recorded_at >= NOW() - (%s * INTERVAL '1 day')
        """

        try:
            with self._connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(stats_sql, (task_type, window_days))
                    row = cur.fetchone()

                    if not row or not row['total']:
                        return None

                    total = row['total']
                    correct = row['correct'] or 0
                    accuracy = self._ratio(correct, total)
                    avg_conf = float(row['avg_confidence'] or 0.0)

                    # Calibration: 1.0 when mean confidence equals observed
                    # accuracy, decreasing linearly with their absolute gap.
                    calibration = 1.0 - abs(avg_conf - accuracy)

                    return AccuracyStats(
                        task_type=task_type,
                        total_predictions=total,
                        correct_predictions=correct,
                        accuracy_rate=accuracy,
                        avg_confidence=avg_conf,
                        avg_deviation=float(row['avg_deviation'] or 0.0),
                        confidence_calibration=calibration,
                        window_days=window_days,
                        sample_period=(row['earliest'], row['latest']),
                    )
        except Exception as e:
            logger.exception("Failed to get accuracy stats: %s", e)
            return None

    def get_calibration_report(self, window_days: int = 30) -> "CalibrationReport":
        """
        Get comprehensive calibration report.

        Groups resolved decisions from the last ``window_days`` days (database
        clock) into fixed confidence buckets and compares each bucket's mean
        confidence with its observed accuracy. Buckets whose gap exceeds
        ``_CALIBRATION_TOLERANCE`` count as over-/under-confident; the rest
        count as calibrated.

        Args:
            window_days: Number of days to look back

        Returns:
            CalibrationReport with detailed calibration metrics; an all-zero
            report if the queries fail
        """
        overall_sql = """
        SELECT
            COUNT(*) as total,
            SUM(CASE WHEN was_correct THEN 1 ELSE 0 END) as correct
        FROM aiva_outcome_tracking
        WHERE resolved_at IS NOT NULL
          AND recorded_at >= NOW() - (%s * INTERVAL '1 day')
        """

        per_task_sql = """
        SELECT
            task_type,
            COUNT(*) as total,
            SUM(CASE WHEN was_correct THEN 1 ELSE 0 END) as correct
        FROM aiva_outcome_tracking
        WHERE resolved_at IS NOT NULL
          AND recorded_at >= NOW() - (%s * INTERVAL '1 day')
        GROUP BY task_type
        """

        bucket_sql = """
        SELECT
            CASE
                WHEN confidence_at_decision < 0.3 THEN '0.0-0.3'
                WHEN confidence_at_decision < 0.6 THEN '0.3-0.6'
                WHEN confidence_at_decision < 0.8 THEN '0.6-0.8'
                ELSE '0.8-1.0'
            END as bucket,
            COUNT(*) as total,
            SUM(CASE WHEN was_correct THEN 1 ELSE 0 END) as correct,
            AVG(confidence_at_decision) as avg_confidence
        FROM aiva_outcome_tracking
        WHERE resolved_at IS NOT NULL
          AND recorded_at >= NOW() - (%s * INTERVAL '1 day')
        GROUP BY bucket
        ORDER BY bucket
        """

        window_params = (window_days,)

        try:
            with self._connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(overall_sql, window_params)
                    overall = cur.fetchone()
                    total_decisions = overall['total'] or 0
                    overall_accuracy = self._ratio(overall['correct'] or 0, total_decisions)

                    cur.execute(per_task_sql, window_params)
                    per_task_accuracy = {
                        row['task_type']: self._ratio(row['correct'] or 0, row['total'])
                        for row in cur.fetchall()
                    }

                    cur.execute(bucket_sql, window_params)
                    bucket_rows = cur.fetchall()

            confidence_buckets = {}
            overconfident_count = 0
            underconfident_count = 0
            calibrated_count = 0

            for row in bucket_rows:
                bucket_total = row['total']
                bucket_correct = row['correct'] or 0
                bucket_accuracy = self._ratio(bucket_correct, bucket_total)
                bucket_conf = float(row['avg_confidence'] or 0.0)

                confidence_buckets[row['bucket']] = {
                    'total': bucket_total,
                    'correct': bucket_correct,
                    'accuracy': bucket_accuracy,
                    'avg_confidence': bucket_conf,
                }

                # Positive gap: confidence exceeds accuracy (overconfident).
                gap = bucket_conf - bucket_accuracy
                if gap > self._CALIBRATION_TOLERANCE:
                    overconfident_count += bucket_total
                elif gap < -self._CALIBRATION_TOLERANCE:
                    underconfident_count += bucket_total
                else:
                    calibrated_count += bucket_total

            return CalibrationReport(
                overall_accuracy=overall_accuracy,
                total_decisions=total_decisions,
                per_task_accuracy=per_task_accuracy,
                confidence_buckets=confidence_buckets,
                # Perfect calibration = every decision sits in a calibrated bucket.
                calibration_score=self._ratio(calibrated_count, total_decisions),
                overconfidence_rate=self._ratio(overconfident_count, total_decisions),
                underconfidence_rate=self._ratio(underconfident_count, total_decisions),
                generated_at=datetime.now(),
            )
        except Exception as e:
            logger.exception("Failed to generate calibration report: %s", e)
            # Return empty report on error
            return CalibrationReport(
                overall_accuracy=0.0,
                total_decisions=0,
                per_task_accuracy={},
                confidence_buckets={},
                calibration_score=0.0,
                overconfidence_rate=0.0,
                underconfidence_rate=0.0,
                generated_at=datetime.now(),
            )

    def get_learning_signals(
        self,
        task_type: Optional[str] = None,
        min_confidence: float = 0.0,
        max_confidence: float = 1.0,
        window_days: int = 7,
        limit: int = 1000,
    ) -> List[Dict[str, Any]]:
        """
        Get learning signals for confidence scorer training.
        Returns recent resolved decisions, newest first, for model improvement.

        Args:
            task_type: Filter by task type (optional)
            min_confidence: Minimum confidence threshold (inclusive)
            max_confidence: Maximum confidence threshold (inclusive)
            window_days: Days to look back (measured on the database clock)
            limit: Maximum number of signals to return

        Returns:
            List of learning signal dictionaries; empty on error. NULL JSONB
            values (e.g. missing metadata) are returned as empty dicts.
        """
        query_sql = """
        SELECT
            decision_id,
            task_type,
            expected_outcome,
            actual_outcome,
            confidence_at_decision,
            was_correct,
            deviation_score,
            metadata
        FROM aiva_outcome_tracking
        WHERE resolved_at IS NOT NULL
          AND recorded_at >= NOW() - (%s * INTERVAL '1 day')
          AND confidence_at_decision >= %s
          AND confidence_at_decision <= %s
        """

        params: List[Any] = [window_days, min_confidence, max_confidence]

        if task_type:
            query_sql += " AND task_type = %s"
            params.append(task_type)

        query_sql += " ORDER BY recorded_at DESC LIMIT %s"
        params.append(limit)

        try:
            with self._connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    cur.execute(query_sql, params)
                    rows = cur.fetchall()

            return [
                {
                    'decision_id': row['decision_id'],
                    'task_type': row['task_type'],
                    'expected_outcome': self._as_dict(row['expected_outcome']),
                    'actual_outcome': self._as_dict(row['actual_outcome']),
                    'confidence_at_decision': row['confidence_at_decision'],
                    'was_correct': row['was_correct'],
                    'deviation_score': row['deviation_score'],
                    'metadata': self._as_dict(row['metadata']),
                }
                for row in rows
            ]
        except Exception as e:
            logger.exception("Failed to get learning signals: %s", e)
            return []


# VERIFICATION_STAMP
# Component: AIVA Outcome Tracker
# Verified By: parallel-builder
# Verified At: 2026-02-11T00:00:00Z
# Tests: Pending (black box + white box tests required)
# Coverage: Pending
# Storage: PostgreSQL via Elestio config (NO SQLite)
# Compliance: GLOBAL_GENESIS_RULES.md Rule 7
