#!/usr/bin/env python3
"""
PRD GATE - 100-Question Enforcement Layer
==========================================

This module ENFORCES the 100-Question Gate mandate.

NO complex task can execute without:
1. An approved Meta-PRD (100+ questions answered by human)
2. OR an approved Sub-PRD (inherits from approved Meta-PRD)
3. OR classification as "simple task" (bypass allowed)

THERE IS NO BYPASS for complex tasks without PRD coverage.

Usage:
    from core.prd_gate import PRDGate, PRDGateError, check_task_clearance

    # Before ANY task execution
    clearance = check_task_clearance(task_description)
    if not clearance.approved:
        raise PRDGateError(clearance.reason)

Authority: Kinan (Creator)
Effective: 2026-01-24
"""

import sys
import json
import hashlib
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Optional, List, Dict, Any
from enum import Enum

# Elestio PostgreSQL (NO SQLITE - per GLOBAL_GENESIS_RULES)
sys.path.insert(0, str(Path(__file__).parent.parent / "data" / "genesis-memory"))
try:
    from elestio_config import PostgresConfig
    import psycopg2
    from psycopg2.extras import RealDictCursor
    POSTGRES_AVAILABLE = True
except ImportError:
    POSTGRES_AVAILABLE = False


class TaskComplexity(Enum):
    """Task complexity classification produced by PRDGate.classify_task.

    The classifier is keyword / word-count based; it does not actually count
    steps. The step counts noted below describe the intended meaning only.
    """
    SIMPLE = "simple"           # Always cleared by the gate; no PRD needed
    MODERATE = "moderate"       # Intended: 2-5 steps. Blocked unless an approved PRD covers it
    COMPLEX = "complex"         # Intended: 5+ steps. Blocked unless an approved PRD covers it
    META = "meta"               # System-level. NOTE(review): gate currently accepts ANY approved covering PRD here, not only a Meta-PRD
    CUSTOMER_FACING = "customer_facing"  # Touches customers: needs a covering PRD, and the clearance always sets requires_hitl


class PRDStatus(Enum):
    """PRD lifecycle status, stored as the ``status`` column of prd_registry."""
    DRAFT = "draft"                             # Set by PRDGate.register_prd when a PRD is first inserted
    PENDING_QUESTIONS = "pending_questions"
    QUESTIONS_COMPLETE = "questions_complete"
    PENDING_APPROVAL = "pending_approval"
    APPROVED = "approved"                       # Set by PRDGate.approve_prd; the only status that lets a PRD cover tasks
    REJECTED = "rejected"
    ARCHIVED = "archived"


@dataclass
class TaskClearance:
    """Outcome of one PRD gate check (built by PRDGate.check_clearance)."""
    approved: bool                          # True when the task may proceed
    task_hash: str                          # task identifier (PRDGate uses a 16-char SHA-256 hex prefix)
    complexity: TaskComplexity
    prd_id: Optional[str] = None            # covering PRD, when one was found
    prd_title: Optional[str] = None
    reason: str = ""                        # human-readable explanation of the decision
    requires_hitl: bool = False             # human sign-off needed before acting
    hitl_checkpoint: Optional[str] = None
    checked_at: str = ""                    # naive local ISO-8601 time; stamped automatically if left blank

    def __post_init__(self):
        # Keep a caller-supplied timestamp; otherwise record "now".
        self.checked_at = self.checked_at or datetime.now().isoformat()


@dataclass
class PRDRecord:
    """One prd_registry row, as materialised by PRDGate.find_covering_prd."""
    id: str
    title: str
    status: PRDStatus
    prd_type: str  # "meta" or "sub"
    parent_prd_id: Optional[str]  # presumably the Meta-PRD a Sub-PRD inherits from — confirm
    question_count: int  # number of questions the PRD must have answered
    questions_answered: int  # PRDGate.approve_prd requires this >= question_count
    approved_by: Optional[str]
    approved_at: Optional[str]  # NOTE(review): declared str, but may carry the DB driver's timestamp object depending on how the record was built
    created_at: str
    file_path: str
    coverage_domains: List[str]  # Task domains this PRD covers; each is matched as a case-insensitive substring of the task text


class PRDGateError(Exception):
    """Signals that the PRD gate refused clearance for a task.

    The exception message is the ``reason`` of the rejected clearance.
    """


class PRDGate:
    """
    The 100-Question PRD Gate Enforcer.

    This class is the SINGLE POINT OF ENFORCEMENT for PRD requirements.
    All task execution MUST pass through here.

    State lives in PostgreSQL behind a single connection, ``self.db_conn``.
    If PostgreSQL is unavailable or initialisation fails, ``db_conn`` is None
    and the gate degrades:

    - no PRD can ever be found, so every non-simple task is blocked;
    - nothing is audited;
    - the write and HITL helpers return their failure values
      (False / -1 / None / []).

    Every database call runs inside ``with self.db_conn:``. The transaction is
    committed on success and rolled back on error. Without the rollback, one
    failed statement would leave the shared connection in PostgreSQL's
    "current transaction is aborted" state, and every later query (including
    PRD lookups) would fail too. Diagnostics go to stderr so they never mix
    with the CLI's JSON output on stdout.
    """

    # Complexity thresholds
    SIMPLE_TASK_MAX_WORDS = 50
    SIMPLE_TASK_INDICATORS = [
        "what is", "how do i", "show me", "list", "check", "status",
        "help", "explain", "describe", "get", "fetch", "read"
    ]

    COMPLEX_TASK_INDICATORS = [
        "implement", "build", "create", "develop", "design", "architect",
        "refactor", "migrate", "integrate", "deploy", "launch", "scale",
        "automate", "pipeline", "workflow", "system", "feature", "module"
    ]

    CUSTOMER_FACING_INDICATORS = [
        "customer", "client", "user", "outreach", "email", "call",
        "pitch", "proposal", "demo", "lead", "prospect", "campaign",
        "send", "publish", "post", "advertise", "market"
    ]

    META_PRD_INDICATORS = [
        "revenue", "strategy", "architecture", "platform", "infrastructure",
        "core", "foundation", "framework", "system-wide", "global"
    ]

    def __init__(self):
        self.db_conn = None
        self._ensure_database()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _task_hash(task: str) -> str:
        """Return the task identifier: the first 16 hex chars of SHA-256(task)."""
        return hashlib.sha256(task.encode()).hexdigest()[:16]

    def _close_connection(self) -> None:
        """Best-effort close of ``db_conn``; always leaves it set to None."""
        conn, self.db_conn = self.db_conn, None
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:  # connection already unusable; nothing else to do
            print(f"Database close error: {e}", file=sys.stderr)

    @staticmethod
    def _coverage_domains(raw: Any) -> List[str]:
        """
        Normalise a ``coverage_domains`` column value to a list of strings.

        The driver may hand back an already-decoded list or a JSON string.
        NULL, malformed JSON, non-list values and non-string entries yield no
        domains. This keeps one bad row from aborting the lookup for all rows.
        """
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError:
                return []
        if not isinstance(raw, list):
            return []
        return [d for d in raw if isinstance(d, str)]

    @staticmethod
    def _record_from_row(row: Dict[str, Any], domains: List[str]) -> PRDRecord:
        """Build a PRDRecord from a prd_registry row (RealDictCursor dict)."""
        approved_at = row.get('approved_at')
        return PRDRecord(
            id=row['id'],
            title=row['title'],
            status=PRDStatus(row['status']),
            prd_type=row['prd_type'],
            parent_prd_id=row.get('parent_prd_id'),
            # Nullable columns: fall back to the dataclass's declared types.
            question_count=row.get('question_count') or 0,
            questions_answered=row.get('questions_answered') or 0,
            approved_by=row.get('approved_by'),
            # Stringified like created_at, matching the str-typed field.
            approved_at=str(approved_at) if approved_at is not None else None,
            created_at=str(row.get('created_at', '')),
            file_path=row.get('file_path') or '',
            coverage_domains=domains,
        )

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------

    def _ensure_database(self):
        """
        Connect to PostgreSQL and create the gate's tables if they are missing.

        On any failure, a connection that was already opened is closed and
        ``db_conn`` is left as None (memory-only mode).
        """
        if not POSTGRES_AVAILABLE:
            print("WARNING: PostgreSQL not available. PRD Gate running in memory-only mode.",
                  file=sys.stderr)
            return

        try:
            self.db_conn = psycopg2.connect(**PostgresConfig.get_connection_params())

            # All three CREATEs commit together, or roll back together on error.
            with self.db_conn, self.db_conn.cursor() as cursor:
                # Create PRD registry table
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS prd_registry (
                        id VARCHAR(64) PRIMARY KEY,
                        title VARCHAR(500) NOT NULL,
                        status VARCHAR(50) NOT NULL,
                        prd_type VARCHAR(20) NOT NULL,
                        parent_prd_id VARCHAR(64),
                        question_count INTEGER DEFAULT 0,
                        questions_answered INTEGER DEFAULT 0,
                        approved_by VARCHAR(100),
                        approved_at TIMESTAMP,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        file_path VARCHAR(500),
                        coverage_domains JSONB DEFAULT '[]'
                    )
                """)

                # Create task clearance audit table
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS task_clearance_audit (
                        id SERIAL PRIMARY KEY,
                        task_hash VARCHAR(64) NOT NULL,
                        task_preview VARCHAR(500),
                        complexity VARCHAR(50) NOT NULL,
                        approved BOOLEAN NOT NULL,
                        prd_id VARCHAR(64),
                        reason TEXT,
                        requires_hitl BOOLEAN DEFAULT FALSE,
                        checked_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # Create HITL approval queue table
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS hitl_approval_queue (
                        id SERIAL PRIMARY KEY,
                        task_hash VARCHAR(64) NOT NULL,
                        task_description TEXT NOT NULL,
                        output_preview TEXT,
                        prd_id VARCHAR(64),
                        status VARCHAR(50) DEFAULT 'pending',
                        submitted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        reviewed_at TIMESTAMP,
                        approved_by VARCHAR(100),
                        rejection_reason TEXT
                    )
                """)

        except Exception as e:
            print(f"Database initialization error: {e}", file=sys.stderr)
            # Don't leak a connection that opened but failed during DDL.
            self._close_connection()

    # ------------------------------------------------------------------
    # Classification and clearance
    # ------------------------------------------------------------------

    def classify_task(self, task: str) -> TaskComplexity:
        """
        Classify task complexity based on content analysis.

        This is the FIRST LINE of the gate: it decides what level of PRD
        coverage is required. Rules are applied in this order:

        1. Any customer-facing keyword (plain substring match) -> CUSTOMER_FACING.
        2. Any meta keyword -> META.
        3. At most SIMPLE_TASK_MAX_WORDS words and starting with a simple
           indicator -> SIMPLE.
        4. Three or more complex keywords, or more than 200 words -> COMPLEX.
        5. At least one complex keyword, or more than 100 words -> MODERATE.
        6. Otherwise -> SIMPLE.
        """
        task_lower = task.lower()
        word_count = len(task.split())

        # Check for customer-facing indicators (highest priority)
        if any(ind in task_lower for ind in self.CUSTOMER_FACING_INDICATORS):
            return TaskComplexity.CUSTOMER_FACING

        # Check for meta-level indicators
        if any(ind in task_lower for ind in self.META_PRD_INDICATORS):
            return TaskComplexity.META

        # Check for simple task patterns
        if word_count <= self.SIMPLE_TASK_MAX_WORDS:
            if any(task_lower.startswith(ind) for ind in self.SIMPLE_TASK_INDICATORS):
                return TaskComplexity.SIMPLE

        # Check for complex indicators
        complex_count = sum(1 for ind in self.COMPLEX_TASK_INDICATORS if ind in task_lower)

        if complex_count >= 3 or word_count > 200:
            return TaskComplexity.COMPLEX
        elif complex_count >= 1 or word_count > 100:
            return TaskComplexity.MODERATE
        else:
            return TaskComplexity.SIMPLE

    def find_covering_prd(self, task: str) -> Optional[PRDRecord]:
        """
        Find an approved PRD that covers this task's domain.

        Approved PRDs are scanned newest first. The first one with a coverage
        domain that occurs as a case-insensitive substring of the task wins.

        Blank domains ("" or whitespace only) are ignored. An empty string is
        a substring of every task, so one stray blank entry would otherwise
        make that PRD cover *everything*.

        Returns None if no approved PRD covers this task, if the database is
        unavailable, or on a database error (the gate then blocks).
        """
        if not self.db_conn:
            return None

        try:
            # ``with conn`` ends the read transaction, so the shared connection
            # is not left idle-in-transaction; it rolls back on error.
            with self.db_conn, self.db_conn.cursor(cursor_factory=RealDictCursor) as cursor:
                # Get all approved PRDs
                cursor.execute("""
                    SELECT * FROM prd_registry
                    WHERE status = 'approved'
                    ORDER BY created_at DESC
                """)
                prds = cursor.fetchall()

            task_lower = task.lower()

            for prd in prds:
                domains = self._coverage_domains(prd.get('coverage_domains'))
                # Check if any non-blank domain keyword matches the task
                if any(d.strip() and d.lower() in task_lower for d in domains):
                    return self._record_from_row(prd, domains)

            return None

        except Exception as e:
            print(f"PRD lookup error: {e}", file=sys.stderr)
            return None

    def check_clearance(self, task: str) -> TaskClearance:
        """
        THE MAIN GATE CHECK.

        This method MUST be called before ANY task execution.

        - SIMPLE tasks are always approved.
        - CUSTOMER_FACING tasks are approved only when an approved PRD covers
          them, and always carry requires_hitl=True.
        - MODERATE / COMPLEX / META tasks are approved only when an approved
          PRD covers them.

        Every decision is written to the audit table (when a DB is available).
        Returns the TaskClearance; this method never raises for a block.
        """
        task_hash = self._task_hash(task)
        complexity = self.classify_task(task)

        if complexity == TaskComplexity.SIMPLE:
            # SIMPLE tasks always pass (no PRD needed)
            clearance = TaskClearance(
                approved=True,
                task_hash=task_hash,
                complexity=complexity,
                reason="Simple task - no PRD required"
            )
        elif complexity == TaskComplexity.CUSTOMER_FACING:
            # CUSTOMER_FACING always requires HITL, even with PRD
            prd = self.find_covering_prd(task)
            clearance = TaskClearance(
                approved=prd is not None,
                task_hash=task_hash,
                complexity=complexity,
                prd_id=prd.id if prd else None,
                prd_title=prd.title if prd else None,
                reason="Customer-facing output requires PRD + HITL approval" if not prd else "PRD approved - HITL checkpoint required before customer contact",
                requires_hitl=True,
                hitl_checkpoint="pre_customer_contact"
            )
        else:
            # MODERATE/COMPLEX/META require PRD coverage
            prd = self.find_covering_prd(task)
            if prd:
                clearance = TaskClearance(
                    approved=True,
                    task_hash=task_hash,
                    complexity=complexity,
                    prd_id=prd.id,
                    prd_title=prd.title,
                    reason=f"Covered by approved PRD: {prd.title}"
                )
            else:
                clearance = TaskClearance(
                    approved=False,
                    task_hash=task_hash,
                    complexity=complexity,
                    reason=f"BLOCKED: {complexity.value} task requires approved PRD coverage. No matching PRD found."
                )

        self._audit_clearance(clearance, task[:500])
        return clearance

    def _audit_clearance(self, clearance: TaskClearance, task_preview: str):
        """Log a clearance decision to task_clearance_audit (best effort)."""
        if not self.db_conn:
            return

        try:
            with self.db_conn, self.db_conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO task_clearance_audit
                    (task_hash, task_preview, complexity, approved, prd_id, reason, requires_hitl)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                """, (
                    clearance.task_hash,
                    task_preview,
                    clearance.complexity.value,
                    clearance.approved,
                    clearance.prd_id,
                    clearance.reason,
                    clearance.requires_hitl
                ))
        except Exception as e:
            # Auditing must not change the gate decision; report and continue.
            print(f"Audit logging error: {e}", file=sys.stderr)

    # ------------------------------------------------------------------
    # PRD registry management
    # ------------------------------------------------------------------

    def register_prd(
        self,
        prd_id: str,
        title: str,
        prd_type: str,
        file_path: str,
        coverage_domains: List[str],
        parent_prd_id: Optional[str] = None,
        question_count: int = 0,
        questions_answered: int = 0
    ) -> bool:
        """
        Insert a PRD in DRAFT status, or update an existing one.

        On an id conflict, the title, question counts, file path and coverage
        domains are updated. Status, prd_type, parent and approval fields are
        left untouched, so an already-approved PRD stays approved.

        Returns True on success, False if the DB is unavailable or on error.
        """
        if not self.db_conn:
            return False

        try:
            with self.db_conn, self.db_conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO prd_registry
                    (id, title, status, prd_type, parent_prd_id, question_count,
                     questions_answered, file_path, coverage_domains)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (id) DO UPDATE SET
                        title = EXCLUDED.title,
                        question_count = EXCLUDED.question_count,
                        questions_answered = EXCLUDED.questions_answered,
                        file_path = EXCLUDED.file_path,
                        coverage_domains = EXCLUDED.coverage_domains
                """, (
                    prd_id,
                    title,
                    PRDStatus.DRAFT.value,
                    prd_type,
                    parent_prd_id,
                    question_count,
                    questions_answered,
                    file_path,
                    json.dumps(coverage_domains)
                ))
            return True
        except Exception as e:
            print(f"PRD registration error: {e}", file=sys.stderr)
            return False

    def approve_prd(self, prd_id: str, approved_by: str = "kinan") -> bool:
        """
        Mark a PRD as approved, but only if questions_answered >= question_count.

        NOTE: a PRD whose question_count is 0 satisfies this check. The current
        status is not checked either.

        Returns True if a row was updated; False if no row qualified, the DB is
        unavailable, or an error occurred.
        """
        if not self.db_conn:
            return False

        try:
            with self.db_conn, self.db_conn.cursor() as cursor:
                cursor.execute("""
                    UPDATE prd_registry
                    SET status = %s, approved_by = %s, approved_at = CURRENT_TIMESTAMP
                    WHERE id = %s AND questions_answered >= question_count
                """, (PRDStatus.APPROVED.value, approved_by, prd_id))
                return cursor.rowcount > 0
        except Exception as e:
            print(f"PRD approval error: {e}", file=sys.stderr)
            return False

    # ------------------------------------------------------------------
    # Human-in-the-loop queue
    # ------------------------------------------------------------------

    def submit_for_hitl(self, task: str, output_preview: str, prd_id: Optional[str] = None) -> int:
        """
        Submit output for human-in-the-loop approval before customer contact.

        Returns the new queue row id, or -1 if the DB is unavailable or on error.
        """
        if not self.db_conn:
            return -1

        try:
            task_hash = self._task_hash(task)
            with self.db_conn, self.db_conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO hitl_approval_queue
                    (task_hash, task_description, output_preview, prd_id)
                    VALUES (%s, %s, %s, %s)
                    RETURNING id
                """, (task_hash, task, output_preview, prd_id))
                return cursor.fetchone()[0]
        except Exception as e:
            print(f"HITL submission error: {e}", file=sys.stderr)
            return -1

    def check_hitl_approval(self, queue_id: int) -> Optional[Dict[str, Any]]:
        """
        Return the hitl_approval_queue row for *queue_id* as a dict.

        Returns None if the row is missing, the DB is unavailable, or on error.
        """
        if not self.db_conn:
            return None

        try:
            with self.db_conn, self.db_conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM hitl_approval_queue WHERE id = %s
                """, (queue_id,))
                result = cursor.fetchone()
            return dict(result) if result else None
        except Exception as e:
            print(f"HITL check error: {e}", file=sys.stderr)
            return None

    def get_pending_hitl(self) -> List[Dict[str, Any]]:
        """
        Return all queue rows with status 'pending', oldest submission first.

        Returns [] if the DB is unavailable or on error.
        """
        if not self.db_conn:
            return []

        try:
            with self.db_conn, self.db_conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute("""
                    SELECT * FROM hitl_approval_queue
                    WHERE status = 'pending'
                    ORDER BY submitted_at ASC
                """)
                results = cursor.fetchall()
            return [dict(r) for r in results]
        except Exception as e:
            print(f"Pending HITL fetch error: {e}", file=sys.stderr)
            return []


# Lazily-created, module-level gate shared by the helpers below.
_prd_gate = None


def get_prd_gate() -> PRDGate:
    """
    Return the shared PRDGate, constructing it on first use.

    Construction runs PRDGate.__init__, which may open a database connection,
    so it is deferred until a caller actually needs the gate.
    NOTE: there is no locking; concurrent first calls could each build a gate.
    """
    global _prd_gate
    if _prd_gate is not None:
        return _prd_gate
    _prd_gate = PRDGate()
    return _prd_gate


def check_task_clearance(task: str) -> TaskClearance:
    """
    Run the PRD gate on *task* using the shared gate instance.

    This is non-raising: callers inspect ``.approved`` on the result. Typical
    use before executing any task:

        clearance = check_task_clearance("Build the payment system")
        if not clearance.approved:
            raise PRDGateError(clearance.reason)
    """
    shared_gate = get_prd_gate()
    result = shared_gate.check_clearance(task)
    return result


def require_prd_clearance(task: str) -> TaskClearance:
    """
    Enforcing variant of check_task_clearance, for the execution layer.

    Returns the clearance when it is approved. Otherwise raises PRDGateError
    carrying the gate's reason.
    """
    result = check_task_clearance(task)
    if result.approved:
        return result
    raise PRDGateError(result.reason)


# CLI for manual operations.
# A failed register/approve (or missing arguments) exits with status 1.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="PRD Gate Management")
    parser.add_argument("command", choices=["check", "register", "approve", "pending-hitl"])
    parser.add_argument("--task", type=str, help="Task to check")
    parser.add_argument("--prd-id", type=str, help="PRD ID")
    parser.add_argument("--title", type=str, help="PRD title")
    parser.add_argument("--domains", type=str, help="Coverage domains (comma-separated)")
    args = parser.parse_args()

    gate = get_prd_gate()

    if args.command == "check":
        if not args.task:
            print("Error: --task required")
            sys.exit(1)
        clearance = gate.check_clearance(args.task)
        print(json.dumps(asdict(clearance), indent=2, default=str))

    elif args.command == "register":
        # Drop blank entries (e.g. "a,,b" or a trailing comma). An empty domain
        # is a substring of every task and would make this PRD cover everything.
        domains = [d.strip() for d in (args.domains or "").split(",") if d.strip()]
        if not all([args.prd_id, args.title, domains]):
            print("Error: --prd-id, --title, --domains required")
            sys.exit(1)
        success = gate.register_prd(
            prd_id=args.prd_id,
            title=args.title,
            prd_type="meta",
            file_path="",
            coverage_domains=domains,
            question_count=100,
            questions_answered=100
        )
        print(f"Registration: {'success' if success else 'failed'}")
        if not success:
            sys.exit(1)

    elif args.command == "approve":
        if not args.prd_id:
            print("Error: --prd-id required")
            sys.exit(1)
        success = gate.approve_prd(args.prd_id)
        print(f"Approval: {'success' if success else 'failed'}")
        if not success:
            sys.exit(1)

    elif args.command == "pending-hitl":
        pending = gate.get_pending_hitl()
        print(f"Pending HITL approvals: {len(pending)}")
        for item in pending:
            print(f"  [{item['id']}] {item['task_description'][:100]}...")
