"""
batch_runner.py — Process leads in parallel with async concurrency control.

Usage:
    python -m scripts.heatmap_audit.batch_runner \\
        --csv data/LEADS/leads.csv \\
        --output data/heatmap_reports/ \\
        --concurrency 10

Supports the Zoho CSV format with columns:
    Last Name, Company, Email, Phone, Website, Street, City, State, ...

Also supports a simpler format with columns:
    business_name, email, phone, website_url, industry
"""

import argparse
import asyncio
import csv
import json
import logging
import os
import re
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional

import aiohttp

# Add project root to path for imports
PROJECT_ROOT = str(Path(__file__).resolve().parents[2])
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from scripts.heatmap_audit.scraper import scrape_all
from scripts.heatmap_audit.analyzer import analyze, calculate_overall_score
from scripts.heatmap_audit.pdf_generator import generate_pdf

logger = logging.getLogger(__name__)


def _sanitize_filename(name: str) -> str:
    """Convert business name to safe filename."""
    safe = re.sub(r'[^\w\s-]', '', name.strip())
    safe = re.sub(r'[\s]+', '_', safe)
    safe = safe[:80]  # Limit length
    return safe.lower()


def _parse_csv_row(row: dict) -> Optional[dict]:
    """
    Parse a CSV row into a standardized lead dict.
    Supports both Zoho format and simple format.
    """
    # Normalize keys to lowercase
    normalized = {k.strip().lower().replace(" ", "_"): v.strip() for k, v in row.items() if v}

    # Try Zoho format first
    business_name = (
        normalized.get("company")
        or normalized.get("business_name")
        or normalized.get("last_name")
        or normalized.get("name")
        or ""
    )

    website = (
        normalized.get("website")
        or normalized.get("website_url")
        or normalized.get("url")
        or ""
    )

    email = normalized.get("email", "")
    phone = normalized.get("phone", "")
    industry = normalized.get("industry", "")

    city = normalized.get("city", "")
    state = normalized.get("state", "")
    location = f"{city}, {state}".strip(", ") if (city or state) else "Australia"

    if not business_name:
        return None

    return {
        "business_name": business_name,
        "website": website,
        "email": email,
        "phone": phone,
        "industry": industry,
        "location": location or "Australia",
    }


async def process_lead(
    lead: dict,
    session: aiohttp.ClientSession,
    output_dir: str,
    phone_cta: str = "",
    semaphore: Optional[asyncio.Semaphore] = None,
) -> dict:
    """
    Run the full pipeline for one lead: scrape, analyze, generate a PDF.

    Args:
        lead: Standardized lead dict (see _parse_csv_row).
        session: Shared aiohttp session used for all HTTP calls.
        output_dir: Directory the generated PDF is written into.
        phone_cta: Optional phone number embedded in the PDF CTA.
        semaphore: Concurrency gate shared across leads; when omitted, a
            private single-slot semaphore is used.

    Returns:
        Result dict with status ("success"/"error"), overall_score,
        pdf_path, contact fields and an error message when failed.
    """
    gate = semaphore or asyncio.Semaphore(1)
    name = lead.get("business_name", "Unknown")
    url = lead.get("website", "")

    async with gate:
        outcome = {
            "business_name": name,
            "website": url,
            "email": lead.get("email", ""),
            "status": "pending",
            "overall_score": None,
            "pdf_path": None,
            "error": None,
        }

        try:
            if url:
                # Full scrape of the live site.
                site_data = await scrape_all(
                    url=url,
                    business_name=name,
                    location=lead.get("location", "Australia"),
                    session=session,
                )
            else:
                # No website on file: build a minimal scrape result so the
                # report can still present "no website found" as the main issue.
                site_data = {
                    "business_name": name,
                    "url": "",
                    "location": lead.get("location", "Australia"),
                    "pagespeed": {},
                    "brave": {},
                    "website": {"is_reachable": False, "has_ssl": False, "status_code": 0},
                }

                # Best effort: still query Brave search for online presence.
                try:
                    from scripts.heatmap_audit.scraper import fetch_brave_search
                    brave_data = await fetch_brave_search(
                        session, name, lead.get("location", "Australia")
                    )
                    if brave_data:
                        site_data["brave"] = brave_data
                except Exception:
                    pass

            # Score the scraped data, then render the report.
            scores, recommendations = analyze(site_data)
            overall = scores.get("overall", {}).get("score", 0)

            pdf_path = os.path.join(
                output_dir, f"{_sanitize_filename(name)}_ai_score.pdf"
            )
            generate_pdf(
                business_name=name,
                scores=scores,
                recommendations=recommendations,
                output_path=pdf_path,
                phone_cta=phone_cta,
                website_url=url,
            )

            outcome.update(
                status="success",
                overall_score=overall,
                pdf_path=pdf_path,
            )

        except Exception as e:
            outcome["status"] = "error"
            outcome["error"] = str(e)
            logger.error("Error processing %s: %s", name, e)

        return outcome


async def run_batch(
    csv_path: str,
    output_dir: str,
    concurrency: int = 10,
    phone_cta: str = "",
    limit: Optional[int] = None,
) -> dict:
    """
    Process all leads from CSV in parallel.

    Reads and parses the CSV up front, then fans leads out to
    process_lead() under a shared semaphore, printing a progress line
    after each batch. Side effects: creates output_dir, writes one PDF
    per successful lead, errors.log (only when errors occurred) and
    batch_summary.json into output_dir, and prints a summary banner.

    Args:
        csv_path: Path to leads CSV file
        output_dir: Directory for output PDFs
        concurrency: Max concurrent scraping tasks
        phone_cta: Phone number for CTA in PDFs
        limit: Max number of leads to process (None = all)

    Returns:
        Summary dict with counts and results.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Parse CSV
    # "utf-8-sig" strips the BOM that Excel/Zoho exports often prepend,
    # which would otherwise corrupt the first header name.
    leads = []
    with open(csv_path, "r", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row in reader:
            parsed = _parse_csv_row(row)
            if parsed:
                leads.append(parsed)

    # NOTE(review): limit=0 is falsy, so it behaves like None (process all)
    # — confirm that is intended.
    if limit:
        leads = leads[:limit]

    total = len(leads)
    logger.info("Loaded %d leads from %s", total, csv_path)

    if total == 0:
        return {"total": 0, "success": 0, "error": 0, "skipped": 0, "results": []}

    # Set up async session and semaphore
    # The semaphore caps concurrent lead pipelines; the connector allows
    # 2x connections since a single lead may hold several at once.
    # NOTE(review): ssl=False disables certificate verification for every
    # request — presumably so sites with broken SSL can still be audited;
    # confirm this is intentional.
    semaphore = asyncio.Semaphore(concurrency)
    connector = aiohttp.TCPConnector(ssl=False, limit=concurrency * 2)
    timeout = aiohttp.ClientTimeout(total=60)

    results = []
    success_count = 0
    error_count = 0
    start_time = time.time()

    error_log_path = os.path.join(output_dir, "errors.log")
    summary_path = os.path.join(output_dir, "batch_summary.json")

    async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
        # Process in batches to avoid overwhelming resources
        batch_size = concurrency * 3  # Queue up 3x concurrency
        for batch_start in range(0, total, batch_size):
            batch_end = min(batch_start + batch_size, total)
            batch_leads = leads[batch_start:batch_end]

            tasks = [
                process_lead(
                    lead=lead,
                    session=session,
                    output_dir=output_dir,
                    phone_cta=phone_cta,
                    semaphore=semaphore,
                )
                for lead in batch_leads
            ]

            # return_exceptions=True keeps one failed task from cancelling
            # the whole batch; raised exceptions come back as values.
            batch_results = await asyncio.gather(*tasks, return_exceptions=True)

            for r in batch_results:
                if isinstance(r, Exception):
                    # NOTE(review): this entry has no business_name, so it
                    # shows as "unknown" in errors.log and as null in the
                    # summary results.
                    error_count += 1
                    results.append({
                        "status": "error",
                        "error": str(r),
                    })
                elif r.get("status") == "success":
                    success_count += 1
                    results.append(r)
                else:
                    error_count += 1
                    results.append(r)

            # Progress line after each batch (flushed so it shows even
            # when stdout is piped).
            processed = batch_end
            elapsed = time.time() - start_time
            rate = processed / elapsed if elapsed > 0 else 0

            print(
                f"  [{processed}/{total}] "
                f"{success_count} success, {error_count} errors "
                f"({rate:.1f} leads/sec)",
                flush=True,
            )

    elapsed_total = time.time() - start_time

    # Write error log
    # TODO(review): consider encoding="utf-8" here and for the summary —
    # business names may contain non-ASCII characters.
    errors = [r for r in results if r.get("status") == "error"]
    if errors:
        with open(error_log_path, "w") as f:
            for err in errors:
                f.write(
                    f"{err.get('business_name', 'unknown')}: {err.get('error', 'unknown error')}\n"
                )
        logger.info("Error log written to %s", error_log_path)

    # Write summary JSON
    summary = {
        "timestamp": datetime.now().isoformat(),
        "csv_path": csv_path,
        "output_dir": output_dir,
        "total_leads": total,
        "success": success_count,
        "errors": error_count,
        "elapsed_seconds": round(elapsed_total, 1),
        "rate_per_second": round(total / elapsed_total, 2) if elapsed_total > 0 else 0,
        "score_distribution": _score_distribution(results),
        # Slimmed-down per-lead entries; the full result dicts (email,
        # error text) are only in the returned summary, not the JSON file.
        "results": [
            {
                "business_name": r.get("business_name"),
                "overall_score": r.get("overall_score"),
                "pdf_path": r.get("pdf_path"),
                "status": r.get("status"),
            }
            for r in results
        ],
    }

    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)
    logger.info("Batch summary written to %s", summary_path)

    print(f"\n{'='*60}")
    print(f"  BATCH COMPLETE")
    print(f"  Total: {total} | Success: {success_count} | Errors: {error_count}")
    print(f"  Time: {elapsed_total:.1f}s | Rate: {summary['rate_per_second']} leads/sec")
    print(f"  PDFs: {output_dir}")
    print(f"  Summary: {summary_path}")
    if errors:
        print(f"  Errors: {error_log_path}")
    print(f"{'='*60}")

    return summary


def _score_distribution(results: list) -> dict:
    """Calculate score distribution for summary."""
    dist = {"critical_0_20": 0, "poor_21_40": 0, "average_41_60": 0, "good_61_80": 0, "excellent_81_100": 0}
    for r in results:
        score = r.get("overall_score")
        if score is None:
            continue
        if score <= 20:
            dist["critical_0_20"] += 1
        elif score <= 40:
            dist["poor_21_40"] += 1
        elif score <= 60:
            dist["average_41_60"] += 1
        elif score <= 80:
            dist["good_61_80"] += 1
        else:
            dist["excellent_81_100"] += 1
    return dist


def main():
    """CLI entry point: parse arguments, configure logging, run the batch."""
    parser = argparse.ArgumentParser(
        description="Heatmap Audit Generator — Batch process leads into AI Score PDFs"
    )
    parser.add_argument("--csv", required=True, help="Path to leads CSV file")
    parser.add_argument(
        "--output",
        default="data/heatmap_reports/",
        help="Output directory for PDFs (default: data/heatmap_reports/)",
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        default=10,
        help="Max concurrent scraping tasks (default: 10)",
    )
    parser.add_argument("--phone", default="", help="Phone number for CTA in PDFs")
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Max number of leads to process (default: all)",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Enable verbose logging"
    )
    opts = parser.parse_args()

    # Setup logging before anything else so early messages are captured.
    logging.basicConfig(
        level=logging.DEBUG if opts.verbose else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    # Fail fast on a missing CSV rather than inside the async pipeline.
    if not os.path.exists(opts.csv):
        print(f"ERROR: CSV file not found: {opts.csv}")
        sys.exit(1)

    banner = "=" * 60
    print("\nHeatmap Audit Generator v1.0")
    print("Powered by Sunaiva Digital")
    print(banner)
    print(f"  CSV: {opts.csv}")
    print(f"  Output: {opts.output}")
    print(f"  Concurrency: {opts.concurrency}")
    if opts.phone:
        print(f"  CTA Phone: {opts.phone}")
    if opts.limit:
        print(f"  Limit: {opts.limit} leads")
    print(f"{banner}\n")

    asyncio.run(
        run_batch(
            csv_path=opts.csv,
            output_dir=opts.output,
            concurrency=opts.concurrency,
            phone_cta=opts.phone,
            limit=opts.limit,
        )
    )


if __name__ == "__main__":
    main()
