#!/usr/bin/env python3
"""
YouTube Transcript Extractor
=============================
Extracts transcripts for videos identified by the watch history fetcher.

Two-tier extraction:
    1. PRIMARY: youtube-transcript-api (free, no API key needed)
    2. FALLBACK: Supadata.ai API (paid, handles edge cases, AI-generated transcripts)

Stores results in:
    - PostgreSQL: yt_transcripts table (structured storage)
    - Qdrant: Semantic vector embeddings for similarity search
    - Local JSON: backup copies for offline access

Usage:
    # Extract transcripts for specific video IDs
    python youtube_transcript_extractor.py --video-ids dQw4w9WgXcQ abc123 xyz789

    # Extract all videos from today's watch history (reads from PostgreSQL)
    python youtube_transcript_extractor.py --today

    # Extract all videos watched on a specific date
    python youtube_transcript_extractor.py --date 2026-02-15

    # Dry run (don't store, just print)
    python youtube_transcript_extractor.py --today --dry-run
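
    # Force re-extraction even if a transcript is already stored
    python youtube_transcript_extractor.py --today --force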

Author: Genesis System
Version: 1.0.0
"""

import argparse
import hashlib
import json
import logging
import os
import sys
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import List, Dict, Any, Optional

# Add Genesis paths
sys.path.insert(0, "/mnt/e/genesis-system/data/genesis-memory")
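# The path above is expected to provide elestio_config (PostgresConfig / QdrantConfig) used below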

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger("yt_transcript_extractor")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

SUPADATA_API_KEY = os.environ.get("SUPADATA_API_KEY", "")
SUPADATA_BASE_URL = "https://api.supadata.ai/v1"

TRANSCRIPT_BACKUP_DIR = Path("/mnt/e/genesis-system/data/youtube_transcripts")
QDRANT_COLLECTION = "yt_transcripts"
EMBEDDING_MODEL = "models/text-embedding-004"  # Gemini embedding model (768 dims) used for Qdrant vectors


# ---------------------------------------------------------------------------
# Database layer
# ---------------------------------------------------------------------------

def get_db_connection():
    """Get PostgreSQL connection using Elestio config."""
    import psycopg2
    from elestio_config import PostgresConfig
    return psycopg2.connect(**PostgresConfig.get_connection_params())


def ensure_transcript_schema(conn):
    """Create the transcripts table if it doesn't exist."""
    with conn.cursor() as cur:
        cur.execute("""
            CREATE TABLE IF NOT EXISTS yt_transcripts (
                id SERIAL PRIMARY KEY,
                video_id VARCHAR(20) NOT NULL,
                transcript TEXT,
                language VARCHAR(10) DEFAULT 'en',
                word_count INTEGER,
                extraction_method VARCHAR(20) DEFAULT 'youtube_api',
                extracted_topics TEXT[],
                extracted_insights JSONB,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                UNIQUE(video_id)
            );

            CREATE INDEX IF NOT EXISTS idx_yt_transcript_video
                ON yt_transcripts(video_id);
            CREATE INDEX IF NOT EXISTS idx_yt_transcript_date
                ON yt_transcripts(created_at);
        """)
        conn.commit()
    logger.info("Transcript schema verified/created.")


def get_existing_transcripts(conn, video_ids: List[str]) -> set:
    """Check which video IDs already have transcripts."""
    if not video_ids:
        return set()
    with conn.cursor() as cur:
        cur.execute(
            "SELECT video_id FROM yt_transcripts WHERE video_id = ANY(%s)",
            (video_ids,)
        )
        return {row[0] for row in cur.fetchall()}


def store_transcript(conn, record: Dict[str, Any]) -> bool:
    """Store a transcript record in PostgreSQL."""
    try:
        with conn.cursor() as cur:
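            # Upsert: ON CONFLICT overwrites any existing row for this video_id,
            # so re-extracting a video updates the stored transcript in place.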
            cur.execute("""
                INSERT INTO yt_transcripts
                    (video_id, transcript, language, word_count,
                     extraction_method, extracted_topics, extracted_insights)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (video_id) DO UPDATE SET
                    transcript = EXCLUDED.transcript,
                    language = EXCLUDED.language,
                    word_count = EXCLUDED.word_count,
                    extraction_method = EXCLUDED.extraction_method,
                    extracted_topics = EXCLUDED.extracted_topics,
                    extracted_insights = EXCLUDED.extracted_insights
            """, (
                record["video_id"],
                record["transcript"],
                record.get("language", "en"),
                record.get("word_count", 0),
                record.get("extraction_method", "unknown"),
                record.get("extracted_topics", []),
                json.dumps(record.get("extracted_insights", {})),
            ))
            conn.commit()
            return True
    except Exception as e:
        logger.error(f"Failed to store transcript for {record['video_id']}: {e}")
        conn.rollback()
        return False


def get_video_ids_for_date(conn, date_str: Optional[str] = None) -> List[str]:
    """Get video IDs from watch history for a given date."""
    if date_str:
        target = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    else:
        target = datetime.now(timezone.utc)

    start = target.replace(hour=0, minute=0, second=0, microsecond=0)
    end = start + timedelta(days=1)
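    # watched_at is filtered over the half-open UTC window [start, end)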

    with conn.cursor() as cur:
        cur.execute("""
            SELECT DISTINCT video_id FROM yt_watch_history
            WHERE watched_at >= %s AND watched_at < %s
            ORDER BY watched_at
        """, (start, end))
        return [row[0] for row in cur.fetchall()]


# ---------------------------------------------------------------------------
# Tier 1: youtube-transcript-api (free)
# ---------------------------------------------------------------------------

def extract_via_youtube_api(video_id: str) -> Optional[Dict[str, Any]]:
    """
    Extract transcript using the youtube-transcript-api package.
    Free, no API key needed, works for most videos with captions.
    """
    try:
        from youtube_transcript_api import YouTubeTranscriptApi
        from youtube_transcript_api.formatters import TextFormatter
    except ImportError:
        logger.error(
            "youtube-transcript-api required. Install with:\n"
            "  pip install youtube-transcript-api"
        )
        return None

    try:
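        # youtube-transcript-api >= 1.0 exposes an instance API (fetch()/list());
        # pre-1.0 releases used classmethods such as get_transcript() instead.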
        ytt_api = YouTubeTranscriptApi()

        # Try English first, then any available language
        transcript_data = None
        language = "en"

        try:
            transcript_data = ytt_api.fetch(video_id, languages=["en", "en-US", "en-AU", "en-GB"])
        except Exception:
            try:
                # Get list of available transcripts and pick the first one
                transcript_list = ytt_api.list(video_id)
                if transcript_list:
                    # Try to find any transcript
                    for t in transcript_list:
                        try:
                            transcript_data = ytt_api.fetch(video_id, languages=[t.language_code])
                            language = t.language_code
                            break
                        except Exception:
                            continue
            except Exception:
                pass

        if not transcript_data:
            logger.warning(f"No transcript available for {video_id} via youtube-transcript-api")
            return None

        # Format as plain text
        formatter = TextFormatter()
        full_text = formatter.format_transcript(transcript_data)

        # Also build timestamped version
        segments = []
        for entry in transcript_data:
            segments.append({
                "text": entry.text if hasattr(entry, 'text') else str(entry.get("text", "")),
                "start": entry.start if hasattr(entry, 'start') else entry.get("start", 0),
                "duration": entry.duration if hasattr(entry, 'duration') else entry.get("duration", 0),
            })

        word_count = len(full_text.split())

        return {
            "video_id": video_id,
            "transcript": full_text,
            "segments": segments,
            "language": language,
            "word_count": word_count,
            "extraction_method": "youtube_transcript_api",
        }

    except Exception as e:
        logger.warning(f"youtube-transcript-api failed for {video_id}: {e}")
        return None


# ---------------------------------------------------------------------------
# Tier 2: Supadata API (paid fallback)
# ---------------------------------------------------------------------------

def extract_via_supadata(video_id: str) -> Optional[Dict[str, Any]]:
    """
    Extract transcript using Supadata.ai API.
    Paid service, handles edge cases, can generate AI transcripts.
    """
    try:
        import requests
    except ImportError:
        logger.error("requests library required: pip install requests")
        return None

    if not SUPADATA_API_KEY:
        logger.warning("SUPADATA_API_KEY not set; skipping Supadata fallback.")
        return None

    headers = {
        "x-api-key": SUPADATA_API_KEY,
        "Content-Type": "application/json",
    }

    params = {
        "url": f"https://www.youtube.com/watch?v={video_id}",
        "mode": "auto",  # Try native first, fallback to AI
        "text": "true",  # Get plain text
    }

    try:
        response = requests.get(
            f"{SUPADATA_BASE_URL}/transcript",
            headers=headers,
            params=params,
            timeout=60,
        )

        if response.status_code == 202:
            # Async processing - poll for result
            data = response.json()
            job_id = data.get("jobId")
            if job_id:
                return _poll_supadata_job(job_id, video_id, headers)

        if response.status_code == 206:
            logger.warning(f"Supadata: transcript unavailable for {video_id}")
            return None

        if response.status_code == 404:
            logger.warning(f"Supadata: video not found {video_id}")
            return None

        if response.status_code == 429:
            logger.warning("Supadata: rate limited, waiting...")
            time.sleep(5)
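            # No retry is attempted here; the video is reported as failed for this run.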
            return None

        if response.status_code != 200:
            logger.warning(f"Supadata error {response.status_code}: {response.text[:200]}")
            return None

        data = response.json()
        content = data.get("content", "")

        if not content:
            logger.warning(f"Supadata returned empty content for {video_id}")
            return None

        # content can be a string (text mode) or list (segment mode)
        if isinstance(content, list):
            full_text = " ".join(
                seg.get("text", "") for seg in content
            )
            segments = content
        else:
            full_text = content
            segments = []

        word_count = len(full_text.split())
        language = data.get("lang", "en")

        return {
            "video_id": video_id,
            "transcript": full_text,
            "segments": segments,
            "language": language,
            "word_count": word_count,
            "extraction_method": "supadata",
        }

    except Exception as e:
        logger.error(f"Supadata failed for {video_id}: {e}")
        return None


def _poll_supadata_job(
    job_id: str, video_id: str, headers: dict, max_attempts: int = 60
) -> Optional[Dict[str, Any]]:
    """Poll Supadata async job until complete."""
    import requests

    for attempt in range(max_attempts):
        time.sleep(1)
        try:
            resp = requests.get(
                f"{SUPADATA_BASE_URL}/transcript/{job_id}",
                headers=headers,
                timeout=30,
            )
            if resp.status_code == 200:
                data = resp.json()
                status = data.get("status")
                if status in ("queued", "active"):
                    continue
                if status == "failed":
                    logger.warning(f"Supadata job failed for {video_id}")
                    return None

                content = data.get("content", "")
                if isinstance(content, list):
                    full_text = " ".join(s.get("text", "") for s in content)
                else:
                    full_text = content

                return {
                    "video_id": video_id,
                    "transcript": full_text,
                    "segments": content if isinstance(content, list) else [],
                    "language": data.get("lang", "en"),
                    "word_count": len(full_text.split()),
                    "extraction_method": "supadata",
                }
            elif resp.status_code == 202:
                continue
        except Exception as e:
            logger.warning(f"Job poll error: {e}")

    logger.warning(f"Supadata job timed out for {video_id}")
    return None


# ---------------------------------------------------------------------------
# Qdrant vector embedding
# ---------------------------------------------------------------------------

def embed_transcript_to_qdrant(video_id: str, transcript: str, metadata: Dict[str, Any]):
    """
    Embed transcript chunks into Qdrant for semantic search.
    Uses Gemini embedding or falls back to simple hashing.
    """
    try:
        from qdrant_client import QdrantClient
        from qdrant_client.models import VectorParams, Distance, PointStruct
        from elestio_config import QdrantConfig
    except ImportError:
        logger.warning("Qdrant client not available. Skipping vector embedding.")
        return False

    try:
        config = QdrantConfig()
        client = QdrantClient(url=config.url, api_key=config.api_key)

        # Ensure collection exists
        collections = [c.name for c in client.get_collections().collections]
        if QDRANT_COLLECTION not in collections:
            client.create_collection(
                collection_name=QDRANT_COLLECTION,
                vectors_config=VectorParams(
                    size=768,  # Gemini embedding dimension
                    distance=Distance.COSINE,
                ),
            )
            logger.info(f"Created Qdrant collection: {QDRANT_COLLECTION}")

        # Chunk the transcript (max ~500 words per chunk)
        chunks = _chunk_text(transcript, max_words=500)

        # Generate embeddings
        vectors = _generate_embeddings(chunks)
        if not vectors:
            logger.warning(f"No embeddings generated for {video_id}. Skipping Qdrant.")
            return False

        # Upsert points
        points = []
        for i, (chunk, vector) in enumerate(zip(chunks, vectors)):
            point_id = _generate_point_id(video_id, i)
            points.append(PointStruct(
                id=point_id,
                vector=vector,
                payload={
                    "video_id": video_id,
                    "chunk_index": i,
                    "text": chunk,
                    "title": metadata.get("title", ""),
                    "channel": metadata.get("channel_name", ""),
                    "watched_at": metadata.get("watched_at", ""),
                    "source": "youtube_pipeline",
                },
            ))

        client.upsert(collection_name=QDRANT_COLLECTION, points=points)
        logger.info(f"Embedded {len(points)} chunks for {video_id} into Qdrant.")
        return True

    except Exception as e:
        logger.error(f"Qdrant embedding failed for {video_id}: {e}")
        return False


def _chunk_text(text: str, max_words: int = 500) -> List[str]:
    """Split text into chunks of approximately max_words."""
    words = text.split()
    chunks = []
    for i in range(0, len(words), max_words):
        chunk = " ".join(words[i:i + max_words])
        if chunk.strip():
            chunks.append(chunk)
    return chunks if chunks else [text[:2000]]  # At least one chunk


def _generate_embeddings(chunks: List[str]) -> List[List[float]]:
    """
    Generate vector embeddings for text chunks.
    Tries Gemini embedding API first, falls back to simple hash-based vectors.
    """
    try:
        import google.generativeai as genai
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise RuntimeError("GEMINI_API_KEY not set")
        genai.configure(api_key=api_key)

        vectors = []
        for chunk in chunks:
            result = genai.embed_content(
                model="models/text-embedding-004",
                content=chunk,
                task_type="retrieval_document",
            )
            vectors.append(result["embedding"])
        return vectors

    except Exception as e:
        logger.warning(f"Gemini embedding failed: {e}. Using fallback hash vectors.")
        # Fallback: deterministic hash-based vectors (not ideal but functional)
        return [_hash_vector(chunk, dim=768) for chunk in chunks]


def _hash_vector(text: str, dim: int = 768) -> List[float]:
    """Generate a deterministic pseudo-vector from a text hash. Fallback only."""
    h = hashlib.sha512(text.encode()).digest()
    # Repeat the digest until it covers the full vector dimension
    expanded = h * ((dim // len(h)) + 1)
    # Map each byte to [-1, 1]; reinterpreting raw hash bytes as IEEE floats
    # can yield NaN/inf values, which Qdrant rejects.
    values = [(b / 127.5) - 1.0 for b in expanded[:dim]]
    # Normalize to unit length for cosine distance
    norm = sum(v * v for v in values) ** 0.5
    if norm == 0:
        return [0.0] * dim
    return [v / norm for v in values]


def _generate_point_id(video_id: str, chunk_index: int) -> int:
    """Generate a deterministic integer ID for a Qdrant point."""
    raw = f"{video_id}_{chunk_index}".encode()
    h = hashlib.md5(raw).hexdigest()
    return int(h[:16], 16) & 0x7FFFFFFFFFFFFFFF  # Positive 64-bit int
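

# Example (sketch): how the stored vectors could be queried for semantic search.
# Assumes the same Gemini embedding model and an existing QdrantClient `client`;
# this is illustrative only and not part of the pipeline.
#
#   query_vec = _generate_embeddings(["kubernetes cost optimization"])[0]
#   hits = client.search(collection_name=QDRANT_COLLECTION, query_vector=query_vec, limit=5)
#   for hit in hits:
#       print(hit.payload["video_id"], hit.payload["text"][:80])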


# ---------------------------------------------------------------------------
# Local backup
# ---------------------------------------------------------------------------

def save_transcript_backup(record: Dict[str, Any]):
    """Save a JSON backup of the transcript to local filesystem."""
    TRANSCRIPT_BACKUP_DIR.mkdir(parents=True, exist_ok=True)
    filepath = TRANSCRIPT_BACKUP_DIR / f"{record['video_id']}.json"
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(record, f, indent=2, ensure_ascii=False)
    logger.debug(f"Backup saved: {filepath}")


# ---------------------------------------------------------------------------
# Main extraction pipeline
# ---------------------------------------------------------------------------

def extract_transcripts(
    video_ids: List[str],
    skip_existing: bool = True,
    embed_vectors: bool = True,
    dry_run: bool = False,
) -> Dict[str, Any]:
    """
    Extract transcripts for a list of video IDs.

    Returns:
        Summary dict with counts and results.
    """
    conn = get_db_connection()
    try:
        ensure_transcript_schema(conn)

        # Skip videos that already have transcripts
        if skip_existing:
            existing = get_existing_transcripts(conn, video_ids)
            pending = [vid for vid in video_ids if vid not in existing]
            logger.info(
                f"Videos: {len(video_ids)} total, {len(existing)} already have transcripts, "
                f"{len(pending)} to extract."
            )
        else:
            pending = video_ids

        # Also load metadata from watch history
        metadata_map = _load_video_metadata(conn, video_ids)

        results = {
            "total": len(video_ids),
            "skipped_existing": len(video_ids) - len(pending),
            "extracted": 0,
            "failed": 0,
            "methods": {"youtube_transcript_api": 0, "supadata": 0},
            "video_results": [],
        }

        for i, video_id in enumerate(pending, 1):
            logger.info(f"[{i}/{len(pending)}] Extracting transcript for {video_id}...")

            # Tier 1: youtube-transcript-api (free)
            record = extract_via_youtube_api(video_id)

            # Tier 2: Supadata (paid fallback)
            if record is None:
                logger.info(f"  Falling back to Supadata for {video_id}...")
                record = extract_via_supadata(video_id)

            if record is None:
                logger.warning(f"  FAILED: No transcript available for {video_id}")
                results["failed"] += 1
                results["video_results"].append({
                    "video_id": video_id,
                    "status": "failed",
                    "reason": "no_transcript_available",
                })
                continue

            results["extracted"] += 1
            results["methods"][record["extraction_method"]] = (
                results["methods"].get(record["extraction_method"], 0) + 1
            )

            if dry_run:
                logger.info(
                    f"  OK: {record['word_count']} words via {record['extraction_method']}"
                )
                results["video_results"].append({
                    "video_id": video_id,
                    "status": "extracted_dry_run",
                    "word_count": record["word_count"],
                    "method": record["extraction_method"],
                })
                continue

            # Store in PostgreSQL
            store_transcript(conn, record)

            # Save local backup
            save_transcript_backup(record)

            # Embed in Qdrant
            if embed_vectors:
                meta = metadata_map.get(video_id, {})
                embed_transcript_to_qdrant(video_id, record["transcript"], meta)

            results["video_results"].append({
                "video_id": video_id,
                "status": "success",
                "word_count": record["word_count"],
                "method": record["extraction_method"],
            })

            # Rate limiting between requests
            if i < len(pending):
                time.sleep(0.5)

        return results

    finally:
        conn.close()
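

# Example (sketch): programmatic use from another Genesis module. The returned
# summary dict carries the counts and per-video results built above.
#
#   summary = extract_transcripts(["dQw4w9WgXcQ"], dry_run=True)
#   print(summary["extracted"], summary["failed"])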


def _load_video_metadata(conn, video_ids: List[str]) -> Dict[str, Dict]:
    """Load video metadata from watch history table."""
    if not video_ids:
        return {}

    metadata = {}
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT DISTINCT ON (video_id)
                    video_id, title, channel_name, channel_id, watched_at
                FROM yt_watch_history
                WHERE video_id = ANY(%s)
                ORDER BY video_id, watched_at DESC
            """, (video_ids,))
            for row in cur.fetchall():
                metadata[row[0]] = {
                    "title": row[1],
                    "channel_name": row[2],
                    "channel_id": row[3],
                    "watched_at": row[4].isoformat() if row[4] else "",
                }
    except Exception as e:
        logger.warning(f"Failed to load video metadata: {e}")

    return metadata


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(
        description="Extract YouTube transcripts for Genesis memory pipeline"
    )
    parser.add_argument(
        "--video-ids",
        nargs="+",
        help="Specific video IDs to extract"
    )
    parser.add_argument(
        "--today",
        action="store_true",
        help="Extract transcripts for all videos from today's watch history"
    )
    parser.add_argument(
        "--date",
        help="Extract transcripts for videos watched on a specific date (YYYY-MM-DD)"
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Re-extract even if a transcript already exists (existing videos are skipped by default)"
    )
    parser.add_argument(
        "--no-vectors",
        action="store_true",
        help="Skip Qdrant vector embedding"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Extract but don't store anything"
    )

    args = parser.parse_args()

    # Determine video IDs
    video_ids = []
    if args.video_ids:
        video_ids = args.video_ids
    elif args.today or args.date:
        conn = get_db_connection()
        try:
            video_ids = get_video_ids_for_date(conn, args.date)
        finally:
            conn.close()
    else:
        parser.error("Must specify --video-ids, --today, or --date")

    if not video_ids:
        logger.warning("No video IDs to process.")
        print(json.dumps({"total": 0, "extracted": 0, "failed": 0}))
        return

    logger.info(f"Processing {len(video_ids)} videos...")

    results = extract_transcripts(
        video_ids=video_ids,
        skip_existing=not args.force,
        embed_vectors=not args.no_vectors,
        dry_run=args.dry_run,
    )

    print(json.dumps(results, indent=2, default=str))


if __name__ == "__main__":
    main()
