#!/usr/bin/env python3
"""
Ingest ALL Knowledge Graph entities and axioms into SubAIVA KnowledgeStore.
Batches of 20 documents per POST to swarm/ingest endpoint.
"""

import json
import os
import sys
import time
import requests
from pathlib import Path

# API endpoints on the SubAIVA worker.
ENDPOINT = "https://subaiva.kinan-ae7.workers.dev/api/swarm/ingest"
STATS_ENDPOINT = "https://subaiva.kinan-ae7.workers.dev/api/swarm/stats"
# Auth token: prefer the environment so the credential is not hard-coded in
# source; falls back to the historical dev token for backward compatibility.
TOKEN = os.environ.get("SUBAIVA_TOKEN", "dev-token-genesis-2026")
HEADERS = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {TOKEN}",
}
# Documents sent per POST to the ingest endpoint.
BATCH_SIZE = 20
# Server-side concurrency hint included in each ingest payload.
CONCURRENCY = 50
# Max content size per document to avoid timeouts (100KB)
MAX_CONTENT_SIZE = 100_000

# Root of the Knowledge Graph tree being ingested.
KG_ROOT = Path("/mnt/e/genesis-system/KNOWLEDGE_GRAPH")


def collect_files():
    """Return (category_label, path) pairs for every KG file to ingest.

    Order: entity JSONL files, then axiom JSONL files, then every
    markdown file anywhere under the KG root — each group sorted by path.
    """
    # JSONL sources live in fixed subdirectories, each with its own label.
    jsonl_sources = (
        ("entity", KG_ROOT / "entities"),
        ("axiom", KG_ROOT / "axioms"),
    )

    collected = []
    for label, directory in jsonl_sources:
        collected.extend((label, path) for path in sorted(directory.glob("*.jsonl")))

    # Markdown files are picked up recursively from the whole KG tree.
    collected.extend(("markdown", path) for path in sorted(KG_ROOT.rglob("*.md")))

    return collected


def file_to_document(category_label: str, filepath: Path) -> list[dict]:
    """Convert a file into one or more document dicts for the ingest API.

    Files larger than MAX_CONTENT_SIZE are split line-wise into chunks,
    each emitted as a separate "(Part i/N)" document.

    Args:
        category_label: Caller-supplied coarse label. Currently unused —
            category is derived from the file path instead; the parameter
            is retained for interface compatibility.
        filepath: File to read (UTF-8; undecodable bytes are replaced).

    Returns:
        List of dicts with keys source/sourceType/title/content/category.
    """
    content = filepath.read_text(encoding="utf-8", errors="replace")
    rel_path = str(filepath.relative_to(KG_ROOT))

    # JSONL files are labeled "json"; everything else is treated as markdown.
    source_type = "json" if filepath.suffix == ".jsonl" else "markdown"

    # Derive category from path keywords; first match wins (axiom > entity > ...).
    path_str = str(filepath)
    if "axiom" in path_str:
        category = "axiom"
    elif "entity" in path_str or "entities" in path_str:
        category = "entity"
    elif "research" in path_str:
        category = "research"
    elif "gap" in path_str:
        category = "gap"
    else:
        category = "genesis"

    # Human-readable title from the filename stem.
    title = filepath.stem.replace("_", " ").title()

    common = {
        "source": f"KNOWLEDGE_GRAPH/{rel_path}",
        "sourceType": source_type,
        "category": category,
    }

    if len(content) <= MAX_CONTENT_SIZE:
        return [{**common, "title": title, "content": content}]

    # Split on line boundaries so JSONL records are never cut mid-line.
    chunks = []
    current_chunk = ""
    for line in content.splitlines(keepends=True):
        # BUGFIX: only flush a non-empty chunk. The original flushed even an
        # empty current_chunk when a single line exceeded MAX_CONTENT_SIZE,
        # producing an empty leading document. A single oversized line now
        # simply becomes one oversized chunk.
        if current_chunk and len(current_chunk) + len(line) > MAX_CONTENT_SIZE:
            chunks.append(current_chunk)
            current_chunk = line
        else:
            current_chunk += line
    if current_chunk:
        chunks.append(current_chunk)

    return [
        {**common, "title": f"{title} (Part {i}/{len(chunks)})", "content": chunk}
        for i, chunk in enumerate(chunks, 1)
    ]


def post_batch(batch: list[dict], batch_num: int, total_batches: int) -> dict:
    """POST one batch of documents to the ingest endpoint.

    Never raises: timeouts and any other failure are reported as a synthetic
    all-failed result dict so the caller's running tally keeps working.
    """
    prefix = f"  Batch {batch_num}/{total_batches}"
    # Template for a failure outcome: every document in the batch counts as failed.
    all_failed = {"succeeded": 0, "failed": len(batch), "totalTokens": 0}

    try:
        resp = requests.post(
            ENDPOINT,
            headers=HEADERS,
            json={"documents": batch, "concurrency": CONCURRENCY},
            timeout=300,
        )
        resp.raise_for_status()
        result = resp.json()
        ok = result.get("succeeded", 0)
        bad = result.get("failed", 0)
        tokens = result.get("totalTokens", 0)
        print(f"{prefix}: {ok} ok, {bad} fail, {tokens} tokens")
        return result
    except requests.exceptions.Timeout:
        print(f"{prefix}: TIMEOUT (300s)")
        return {**all_failed, "error": "timeout"}
    except Exception as e:
        # Deliberately broad: a bad batch must not abort the whole run.
        print(f"{prefix}: ERROR - {e}")
        return {**all_failed, "error": str(e)}


def get_stats():
    """Fetch current swarm stats from the worker; return None on any failure."""
    try:
        response = requests.get(STATS_ENDPOINT, headers=HEADERS, timeout=30)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Best-effort: stats are informational only, so swallow and report.
        print(f"Stats error: {e}")
        return None


def main():
    """Run the full KG -> KnowledgeStore ingestion; return True if no failures."""
    banner = "=" * 60
    print(banner)
    print("GENESIS KG -> SubAIVA KnowledgeStore Ingestion")
    print(banner)

    # Discover everything to ingest and summarize by category.
    files = collect_files()
    counts = {}
    for label, _ in files:
        counts[label] = counts.get(label, 0) + 1
    print(f"\nFiles found: {len(files)}")
    print(f"  Entity JSONL: {counts.get('entity', 0)}")
    print(f"  Axiom JSONL:  {counts.get('axiom', 0)}")
    print(f"  Markdown:     {counts.get('markdown', 0)}")

    # Convert files to documents; oversized files are chunked by the converter.
    all_docs = []
    for label, path in files:
        try:
            all_docs.extend(file_to_document(label, path))
        except Exception as e:
            print(f"  ERROR reading {path.name}: {e}")

    print(f"\nTotal documents (after chunking): {len(all_docs)}")
    size_bytes = sum(len(doc["content"]) for doc in all_docs)
    print(f"Total content size: {size_bytes / 1024 / 1024:.2f} MB")

    # Group documents into fixed-size batches.
    batches = [
        all_docs[start:start + BATCH_SIZE]
        for start in range(0, len(all_docs), BATCH_SIZE)
    ]
    print(f"Batches: {len(batches)} (batch size: {BATCH_SIZE})")
    print()

    ok_total = 0
    fail_total = 0
    token_total = 0
    failed_docs = []
    started = time.time()

    for num, batch in enumerate(batches, 1):
        preview = ", ".join(doc["title"][:50] for doc in batch[:3])
        print(f"Batch {num}/{len(batches)}: {len(batch)} docs [{preview}...]")

        outcome = post_batch(batch, num, len(batches))
        ok_total += outcome.get("succeeded", 0)
        fail_total += outcome.get("failed", 0)
        token_total += outcome.get("totalTokens", 0)

        # Remember any per-document failures reported by the worker.
        for entry in outcome.get("results", []):
            if not entry.get("success", False):
                failed_docs.append(entry.get("id", "unknown"))

        # Brief pause between batches so the worker isn't overwhelmed.
        if num < len(batches):
            time.sleep(1)

    elapsed = time.time() - started

    print()
    print(banner)
    print("INGESTION COMPLETE")
    print(banner)
    print(f"Time elapsed:    {elapsed:.1f}s")
    print(f"Total documents: {len(all_docs)}")
    print(f"Succeeded:       {ok_total}")
    print(f"Failed:          {fail_total}")
    print(f"Total tokens:    {token_total}")
    if failed_docs:
        print(f"Failed doc IDs:  {failed_docs[:10]}")

    # Informational post-run stats from the worker.
    print()
    print("Fetching final stats...")
    stats = get_stats()
    if stats:
        print(json.dumps(stats, indent=2))

    return fail_total == 0


if __name__ == "__main__":
    # Exit code 0 only when every document ingested successfully.
    sys.exit(0 if main() else 1)
