#!/usr/bin/env python3
"""
Turbo Transcript Engine
=======================
High-performance asynchronous YouTube transcript extraction engine.

Features:
- Async concurrency (asyncio.gather with semaphore)
- Exponential backoff retry logic
- Multi-tier fallback (youtube-transcript-api -> yt-dlp -> Supadata)
- Disk caching (7-day TTL) for instant re-reads
- Windows-style default paths (rooted at e:/genesis-system)

Usage:
    import asyncio
    from core.youtube.transcript_engine import TranscriptEngine
    
    engine = TranscriptEngine()
    results = asyncio.run(engine.extract_batch(["dQw4w9WgXcQ", "jNQXAC9IVRw"]))
    
Author: Genesis System
"""

import asyncio
import json
import logging
import os
import re
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, List, Optional, Any

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("yt_turbo_engine")

# Default on-disk locations (a Windows drive path is hard-coded as the root).
GENESIS_ROOT = Path("e:/genesis-system")
CACHE_DIR = GENESIS_ROOT.joinpath("data", "youtube", "transcript_cache")
# Cache files whose mtime is older than this many days are treated as expired.
CACHE_TTL_DAYS = 7

@dataclass
class TranscriptSegment:
    """A single timed caption line.

    Positional field order is (text, start, duration, language). ``start`` and
    ``duration`` are taken verbatim from the source library — presumably
    seconds; confirm per extraction tier.
    """

    text: str
    start: float
    duration: float
    language: str = field(default="en")

@dataclass
class TranscriptResult:
    """Outcome of extracting one video's transcript.

    ``success`` is False when every extraction tier failed, in which case
    ``error`` describes why. ``cached`` is True when the result was loaded
    from the on-disk cache instead of being freshly extracted.
    """

    video_id: str
    success: bool
    method: str = ""  # tier that produced the text, e.g. "youtube_transcript_api"
    text: str = ""
    # Quoted forward reference so this annotation does not depend on definition order.
    segments: List["TranscriptSegment"] = field(default_factory=list)
    language: str = "en"
    duration_secs: float = 0
    extraction_time_ms: int = 0
    error: Optional[str] = None
    # Naive UTC ISO-8601 timestamp with a trailing "Z". datetime.utcnow() is
    # deprecated since Python 3.12, so derive the same string from an aware
    # UTC datetime with its tzinfo dropped (isoformat then omits "+00:00").
    extracted_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    )
    cached: bool = False

    @property
    def full_text(self) -> str:
        """Return ``text`` if non-empty, else the segment texts joined by single spaces."""
        if self.text:
            return self.text
        return " ".join(seg.text for seg in self.segments)


class TranscriptEngine:
    """Async engine for high-speed YouTube transcript extraction.

    For each video, three tiers are tried in order: youtube-transcript-api,
    a yt-dlp subtitle download, then the Supadata HTTP API. If every tier
    fails, the whole chain is retried with exponential backoff. Successful
    results are cached as JSON under ``CACHE_DIR`` and reused while the cache
    file is younger than ``CACHE_TTL_DAYS`` days.
    """

    def __init__(self, use_cache: bool = True, max_concurrent: int = 10, max_retries: int = 3):
        """Create an engine.

        Args:
            use_cache: Read from and write to the on-disk JSON cache.
            max_concurrent: Maximum number of videos extracted at the same time.
            max_retries: Number of passes over the full tier chain per video
                (values below 1 are treated as 1).
        """
        self.use_cache = use_cache
        self.max_concurrent = max_concurrent
        self.max_retries = max_retries
        # Created lazily by _get_semaphore() for the running event loop. An
        # asyncio.Semaphore binds to the first loop it waits on, so creating it
        # here would break reuse of one engine across separate asyncio.run() calls.
        self.semaphore: Optional[asyncio.Semaphore] = None
        self._semaphore_loop: Optional[asyncio.AbstractEventLoop] = None

        CACHE_DIR.mkdir(parents=True, exist_ok=True)

        # Environment variable first, then the project secrets file.
        self.supadata_api_key = os.environ.get("SUPADATA_API_KEY") or self._read_secret("SUPADATA_API_KEY")

    @staticmethod
    def _read_secret(name: str) -> Optional[str]:
        """Return ``name``'s value from GENESIS_ROOT/config/secrets.env, or None.

        Lines look like ``NAME=value``; surrounding whitespace and quotes are
        stripped from the value. If the name appears more than once, the last
        occurrence wins.
        """
        secrets_file = GENESIS_ROOT / "config" / "secrets.env"
        if not secrets_file.exists():
            return None

        prefix = f"{name}="
        value = None
        # utf-8-sig tolerates a BOM, which would otherwise break startswith() on line 1.
        with open(secrets_file, encoding="utf-8-sig") as f:
            for line in f:
                if line.startswith(prefix):
                    value = line.strip().split("=", 1)[1].strip().strip("'\"")
        return value

    def _get_semaphore(self) -> asyncio.Semaphore:
        """Return the concurrency semaphore for the running loop, creating it if needed."""
        loop = asyncio.get_running_loop()
        if self.semaphore is None or self._semaphore_loop is not loop:
            self.semaphore = asyncio.Semaphore(self.max_concurrent)
            self._semaphore_loop = loop
        return self.semaphore

    def _get_cache_path(self, video_id: str) -> Path:
        """Return the JSON cache file path for ``video_id``."""
        return CACHE_DIR / f"{video_id}.json"

    def _read_cache(self, video_id: str) -> Optional["TranscriptResult"]:
        """Return the cached result for ``video_id``, or None.

        Returns None when caching is disabled, no cache file exists, the file's
        mtime is older than CACHE_TTL_DAYS, or the file cannot be parsed (the
        last case is logged as a warning).
        """
        if not self.use_cache:
            return None

        cache_path = self._get_cache_path(video_id)
        if not cache_path.exists():
            return None

        try:
            # TTL is measured from the file's last modification time.
            mtime = datetime.fromtimestamp(cache_path.stat().st_mtime)
            if datetime.now() - mtime > timedelta(days=CACHE_TTL_DAYS):
                logger.debug(f"Cache expired for {video_id}")
                return None

            with open(cache_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])]

            return TranscriptResult(
                video_id=video_id,
                success=True,
                method=data.get("method", "cache"),
                text=data.get("text", ""),
                segments=segments,
                language=data.get("language", "en"),
                duration_secs=data.get("duration_secs", 0),
                extraction_time_ms=0,
                extracted_at=data.get("extracted_at", ""),
                cached=True,
            )
        except Exception as e:
            logger.warning(f"Failed to read cache for {video_id}: {e}")
            return None

    def _write_cache(self, result: "TranscriptResult") -> None:
        """Persist a successful result to the JSON cache.

        Does nothing when caching is disabled or the result failed. Writes to
        a temporary file and renames it into place, so readers do not see a
        partially written cache file. Errors are logged, not raised.
        """
        if not self.use_cache or not result.success:
            return

        cache_path = self._get_cache_path(result.video_id)
        tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(asdict(result), f, indent=2)
            os.replace(tmp_path, cache_path)
        except Exception as e:
            logger.warning(f"Failed to write cache for {result.video_id}: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass

    async def _extract_tier1_api(self, video_id: str, languages: List[str]) -> "TranscriptResult":
        """youtube-transcript-api (Free, Fast).

        First asks for a transcript in one of ``languages``. If that fails, it
        lists the video's transcripts and uses ``find_transcript(languages)``.
        If no language matches, it uses the first transcript listed. Supports
        both the static API (``get_transcript``/``list_transcripts``) and the
        instance API (``fetch``/``list``).

        Raises:
            Exception: prefixed "Tier 1 API failed" on any failure.
        """
        start_time = time.time()
        try:
            from youtube_transcript_api import YouTubeTranscriptApi

            loop = asyncio.get_running_loop()
            legacy_api = hasattr(YouTubeTranscriptApi, "get_transcript")
            api = YouTubeTranscriptApi if legacy_api else YouTubeTranscriptApi()

            def fetch_preferred():
                if legacy_api:
                    return api.get_transcript(video_id, languages=languages)
                return api.fetch(video_id, languages=languages)

            def list_all():
                if legacy_api:
                    return api.list_transcripts(video_id)
                return api.list(video_id)

            language = languages[0]
            try:
                transcript_data = await loop.run_in_executor(None, fetch_preferred)
            except Exception:  # never a bare except: it would swallow task cancellation
                transcript_list = await loop.run_in_executor(None, list_all)
                try:
                    transcript = transcript_list.find_transcript(languages)
                except Exception:
                    # find_transcript signals "no match" by raising, not by
                    # returning None, so fall back to the first listed transcript.
                    transcript = next(iter(transcript_list), None)
                if transcript is None:
                    raise Exception("No transcripts available")
                language = getattr(transcript, "language_code", language)
                transcript_data = await loop.run_in_executor(None, transcript.fetch)

            # Prefer the fetched transcript's own language code when one is exposed.
            language = getattr(transcript_data, "language_code", language)

            segments = []
            for item in transcript_data:
                # Items are objects in some library versions, dicts in others.
                if hasattr(item, "text"):
                    segments.append(TranscriptSegment(item.text, item.start, item.duration, language))
                else:
                    segments.append(
                        TranscriptSegment(item.get("text", ""), item.get("start", 0), item.get("duration", 0), language)
                    )

            return TranscriptResult(
                video_id=video_id,
                success=True,
                method="youtube_transcript_api",
                text=" ".join(s.text for s in segments),
                segments=segments,
                language=language,
                duration_secs=sum(s.duration for s in segments),
                extraction_time_ms=int((time.time() - start_time) * 1000),
            )
        except Exception as e:
            raise Exception(f"Tier 1 API failed: {e}") from e

    @staticmethod
    def _find_subtitle_file(download_dir: Path, video_id: str, langs: List[str]):
        """Return ``(path, lang)`` for the first existing .vtt/.srt file in ``langs`` order, else ``(None, None)``."""
        for lang in langs:
            for ext in (".vtt", ".srt"):
                candidate = download_dir / f"{video_id}.{lang}{ext}"
                if candidate.exists():
                    return candidate, lang
        return None, None

    @staticmethod
    def _subtitle_to_text(content: str) -> str:
        """Flatten WebVTT/SRT subtitle content into one line of plain text.

        Drops blank lines, numeric cue indices, timing lines and the
        WEBVTT/Kind/Language header lines. Strips inline tags, decodes HTML
        entities (e.g. ``&amp;``), and skips a line if it equals one of the
        last three kept lines.
        """
        import html

        text_lines: List[str] = []
        for raw_line in content.splitlines():
            line = raw_line.strip()
            if (
                not line
                or line.isdigit()
                or "-->" in line
                or line.startswith(("WEBVTT", "Kind:", "Language:"))
            ):
                continue
            # Remove markup tags first, then decode entities, so escaped text
            # like "&lt;b&gt;" survives as literal "<b>".
            clean_line = html.unescape(re.sub(r"<[^>]+>", "", line)).strip()
            if clean_line and clean_line not in text_lines[-3:]:  # basic dedupe
                text_lines.append(clean_line)
        return " ".join(text_lines)

    async def _extract_tier2_ytdlp(self, video_id: str, languages: List[str]) -> "TranscriptResult":
        """yt-dlp subtitle download (Free, Slower).

        Downloads manual and automatic subtitles (no media) into
        ``CACHE_DIR/ytdlp_subs``. It then converts the first file found for
        ``languages`` (then "en") to plain text.

        Raises:
            Exception: prefixed "Tier 2 yt-dlp failed" on any failure.
        """
        start_time = time.time()
        try:
            import yt_dlp

            download_dir = CACHE_DIR / "ytdlp_subs"
            download_dir.mkdir(parents=True, exist_ok=True)
            # Requested languages, then English; duplicates dropped, order kept.
            wanted_langs = list(dict.fromkeys(languages + ["en"]))

            ydl_opts = {
                "writesubtitles": True,
                "writeautomaticsub": True,
                "subtitleslangs": wanted_langs,
                "skip_download": True,
                "outtmpl": str(download_dir / f"{video_id}.%(ext)s"),
                "quiet": True,
                "no_warnings": True,
            }
            url = f"https://www.youtube.com/watch?v={video_id}"

            def download() -> None:
                # The context manager closes yt-dlp's resources after the download.
                with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                    ydl.download([url])

            await asyncio.get_running_loop().run_in_executor(None, download)

            sub_file, sub_lang = self._find_subtitle_file(download_dir, video_id, wanted_langs)
            if sub_file is None:
                raise Exception("yt-dlp downloaded no subtitles")

            # utf-8-sig drops a leading BOM that would otherwise hide the WEBVTT header.
            text = self._subtitle_to_text(sub_file.read_text(encoding="utf-8-sig"))

            return TranscriptResult(
                video_id=video_id,
                success=True,
                method="yt_dlp",
                text=text,
                language=sub_lang,
                extraction_time_ms=int((time.time() - start_time) * 1000),
            )
        except Exception as e:
            raise Exception(f"Tier 2 yt-dlp failed: {e}") from e

    async def _extract_tier3_supadata(self, video_id: str) -> "TranscriptResult":
        """Supadata.ai (Paid fallback).

        Accepts either a non-empty list of ``{"text": ...}`` chunks under
        ``content`` or a plain-string ``transcript`` field.

        Raises:
            Exception: "No valid Supadata API key" when no usable key is set,
                otherwise prefixed "Tier 3 Supadata failed".
        """
        # NOTE(review): keys starting with "sd_" are rejected as invalid. Confirm
        # this prefix marks a placeholder and not real Supadata keys; otherwise
        # this tier can never run.
        if not self.supadata_api_key or self.supadata_api_key.startswith("sd_"):
            raise Exception("No valid Supadata API key")

        start_time = time.time()
        try:
            import aiohttp

            headers = {"x-api-key": self.supadata_api_key}
            url = "https://api.supadata.ai/v1/youtube/transcript"
            params = {"videoId": video_id}
            timeout = aiohttp.ClientTimeout(total=30)  # whole-request bound, in seconds

            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(url, headers=headers, params=params) as resp:
                    if resp.status != 200:
                        raise Exception(f"Supadata HTTP {resp.status}")
                    data = await resp.json()

            content = data.get("content", [])
            if isinstance(content, list) and content:
                text = " ".join(c.get("text", "") for c in content if isinstance(c, dict))
            elif isinstance(data.get("transcript"), str):
                text = data["transcript"]
            else:
                raise Exception("Unexpected Supadata format")

            return TranscriptResult(
                video_id=video_id,
                success=True,
                method="supadata",
                text=text,
                extraction_time_ms=int((time.time() - start_time) * 1000),
            )
        except Exception as e:
            raise Exception(f"Tier 3 Supadata failed: {e}") from e

    async def extract_single(self, video_id: str, languages: Optional[List[str]] = None) -> "TranscriptResult":
        """Extract one transcript using the cache, tier fallback and retries.

        Args:
            video_id: YouTube video ID.
            languages: Preferred language codes, most preferred first.
                Defaults to ["en", "en-US", "en-GB"].

        Returns:
            A cached result if one is fresh. Otherwise the first successful
            tier result, which is also written to the cache. If every attempt
            fails, a result with ``success=False``; its ``error`` lists every
            tier's failure from the final attempt.
        """
        languages = languages or ["en", "en-US", "en-GB"]
        attempts = max(1, self.max_retries)

        # The concurrency slot is held for the whole call, including backoff sleeps.
        async with self._get_semaphore():
            cached = self._read_cache(video_id)
            if cached:
                return cached

            start_time = time.time()
            last_error = ""
            tiers = (
                (self._extract_tier1_api, (video_id, languages)),
                (self._extract_tier2_ytdlp, (video_id, languages)),
                (self._extract_tier3_supadata, (video_id,)),
            )

            for attempt in range(attempts):
                errors = []
                for extract, args in tiers:
                    try:
                        result = await extract(*args)
                    except Exception as e:
                        errors.append(str(e))
                        logger.debug(f"[{video_id}] {e}")
                        continue
                    self._write_cache(result)
                    return result

                # Every tier failed: keep all reasons, not just the last tier's.
                last_error = " | ".join(errors)

                if attempt < attempts - 1:
                    delay = 2 ** attempt  # exponential backoff: 1s, 2s, 4s, ...
                    logger.debug(f"[{video_id}] Attempt {attempt+1} failed. Retrying in {delay}s...")
                    await asyncio.sleep(delay)

            return TranscriptResult(
                video_id=video_id,
                success=False,
                error=f"All methods exhausted after {attempts} attempts. Last error: {last_error}",
                extraction_time_ms=int((time.time() - start_time) * 1000),
            )

    async def extract_batch(self, video_ids: List[str], languages: Optional[List[str]] = None) -> Dict[str, "TranscriptResult"]:
        """Extract transcripts for many videos concurrently.

        Duplicate IDs are extracted once. At most ``max_concurrent`` run at a
        time. Successful fresh results are cached by ``extract_single``.

        Returns:
            Mapping of video ID to TranscriptResult, in first-seen ID order.
        """
        unique_ids = list(dict.fromkeys(video_ids))
        results = await asyncio.gather(*(self.extract_single(vid, languages) for vid in unique_ids))
        return {res.video_id: res for res in results}

if __name__ == "__main__":
    # Smoke test: extract transcripts for the IDs given on the command line
    # (or two default public video IDs) and print one status line per video.
    import sys

    video_ids = sys.argv[1:] or ["dQw4w9WgXcQ", "jNQXAC9IVRw"]

    batch = asyncio.run(TranscriptEngine(max_concurrent=5).extract_batch(video_ids))
    ok_total = sum(1 for outcome in batch.values() if outcome.success)
    print(f"Extracted {ok_total}/{len(video_ids)}")

    for vid, outcome in batch.items():
        if not outcome.success:
            print(f"[{vid}] FAIL - {outcome.error}")
            continue
        print(f"[{vid}] OK ({outcome.method}) - {len(outcome.full_text)} chars")
