#!/usr/bin/env python3
"""
RLM OVERNIGHT WORKER — Runs until 4am AEST (18:00 UTC)
Processes session logs -> extracts learnings -> writes KG entities -> generates preference pairs
"""
import os
import sys
import json
import time
import datetime
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

# Load env
try:
    from dotenv import load_dotenv
    load_dotenv(Path(__file__).parent.parent / '.env')
except ImportError:
    pass

import google.generativeai as genai
genai.configure(api_key=os.getenv('GEMINI_API_KEY'))

# Filesystem layout: every path hangs off the hard-coded Genesis checkout.
GENESIS_ROOT = Path("E:/genesis-system")
KG_ENTITIES = GENESIS_ROOT.joinpath("KNOWLEDGE_GRAPH", "entities")
KG_AXIOMS = GENESIS_ROOT.joinpath("KNOWLEDGE_GRAPH", "axioms")
SESSION_LOGS = GENESIS_ROOT.joinpath("data", "observability")
RESULTS_FILE = GENESIS_ROOT.joinpath("OVERNIGHT_RESULTS.md")
OVERNIGHT_LOGS = GENESIS_ROOT.joinpath("data", "overnight_logs")

# Created eagerly at import time (no-op if it already exists).
OVERNIGHT_LOGS.mkdir(exist_ok=True, parents=True)

# Shared Gemini model handle; extract_learnings() calls generate_content on it.
model = genai.GenerativeModel('gemini-2.0-flash')

# Number of events.jsonl lines already consumed. process_session_logs() slices
# from this index so events seen in an earlier cycle are not re-analysed.
_last_events_count = 0


def get_cutoff_time(*, now=None, cutoff_hour=18):
    """Return the next 4am AEST cutoff as a naive UTC datetime.

    4am AEST (UTC+10) is 18:00 UTC on the previous calendar day. If the
    current UTC time is before 18:00 the cutoff is 18:00 UTC today;
    otherwise (at or after 18:00) it is 18:00 UTC tomorrow.

    Args:
        now: naive UTC datetime treated as "current time"; defaults to
            ``datetime.datetime.utcnow()``. Exposed so rollover is testable.
        cutoff_hour: UTC hour of the cutoff (18 corresponds to 4am AEST).

    Returns:
        A naive datetime (no tzinfo), so it can be compared directly with
        ``datetime.datetime.utcnow()`` as main() does.
    """
    if now is None:
        now = datetime.datetime.utcnow()
    cutoff = now.replace(hour=cutoff_hour, minute=0, second=0, microsecond=0)
    if now >= cutoff:
        # Today's cutoff has already passed; the next one is tomorrow.
        cutoff += datetime.timedelta(days=1)
    return cutoff


def _parse_learning_json(text: str) -> dict:
    """Decode the JSON object in a model reply and normalise surprise_score.

    Slices from the first '{' to the last '}', which tolerates markdown code
    fences (with or without a ``json`` tag, on one line or several) and any
    stray prose around the object. Raises ValueError if no object decodes.
    """
    text = text.strip()
    start = text.find('{')
    end = text.rfind('}')
    if start == -1 or end < start:
        raise ValueError(f"no JSON object in model reply: {text[:80]!r}")
    learning = json.loads(text[start:end + 1])
    # Callers compare surprise_score against floats and format it with ':.2f';
    # a string ("0.7") or null would raise TypeError there, so coerce it here
    # and drop it when it is not numeric (callers then treat it as 0).
    if "surprise_score" in learning:
        try:
            learning["surprise_score"] = float(learning["surprise_score"])
        except (TypeError, ValueError):
            del learning["surprise_score"]
    return learning


def extract_learnings(content: str, source: str) -> "dict | None":
    """Ask Gemini to distil one RLM learning from *content*.

    Only the first 3000 characters of *content* are sent. The model is asked
    for a JSON object with surprise_score, key_learning, axiom,
    preference_correct, preference_incorrect and entity_type.

    Args:
        content: text to analyse (callers pass JSON-dumped log records).
        source: human-readable label embedded in the prompt.

    Returns:
        The parsed dict (surprise_score coerced to float when present), or
        None if the API call or JSON parsing fails; the error is printed.
    """
    prompt = f"""Analyze this Genesis system content and extract key learnings.

Content from {source}:
{content[:3000]}

Return a JSON object with:
{{
  "surprise_score": 0.0-1.0 (how surprising/novel is this?),
  "key_learning": "one sentence summary of the key learning",
  "axiom": "hardwired rule derived from this content",
  "preference_correct": "what should have happened or the right approach",
  "preference_incorrect": "what went wrong (if anything, else 'N/A')",
  "entity_type": "failure|success|insight|pattern|system_state"
}}

Return ONLY the JSON, no other text."""

    try:
        resp = model.generate_content(prompt)
        return _parse_learning_json(resp.text)
    except Exception as e:
        # Deliberate best effort: any API or parse failure just skips this
        # learning for the current cycle instead of aborting the worker.
        print(f"  [extract_learnings] Error: {e}")
        return None


def write_kg_entity(learning: dict, source: str, *, axioms_dir=None):
    """Append *learning* as one JSON line to the KG axioms file.

    Args:
        learning: dict as returned by extract_learnings(); missing keys fall
            back to the defaults below.
        source: label stored on the entity; its first 20 characters (with
            '/' and ' ' replaced by '_') form the id suffix.
        axioms_dir: directory holding ``genesis_evolution_learnings.jsonl``;
            defaults to KG_AXIOMS. Created if it does not exist yet.

    Returns:
        The entity dict that was written.
    """
    # One timestamp for both fields so the id and the date always agree
    # (two utcnow() calls could straddle a second boundary).
    now = datetime.datetime.utcnow()
    id_suffix = source[:20].replace('/', '_').replace(' ', '_')
    entity = {
        "id": f"rlm_{now.strftime('%Y%m%d_%H%M%S')}_{id_suffix}",
        "date": now.isoformat(),
        "source": source,
        "type": learning.get("entity_type", "insight"),
        "surprise_score": learning.get("surprise_score", 0.5),
        "key_learning": learning.get("key_learning", ""),
        "axiom": learning.get("axiom", ""),
        "preference_correct": learning.get("preference_correct", ""),
        "preference_incorrect": learning.get("preference_incorrect", "N/A"),
    }

    if axioms_dir is None:
        axioms_dir = KG_AXIOMS
    axioms_dir = Path(axioms_dir)
    # Appending into a missing directory raises FileNotFoundError; make sure
    # it exists (only OVERNIGHT_LOGS is created at import time).
    axioms_dir.mkdir(parents=True, exist_ok=True)

    axioms_file = axioms_dir / "genesis_evolution_learnings.jsonl"
    with open(axioms_file, 'a', encoding='utf-8') as f:
        f.write(json.dumps(entity) + '\n')

    print(f"  [KG] Entity written: {entity['id']} (surprise={entity['surprise_score']:.2f})")
    return entity


def process_session_logs() -> int:
    """Analyse events appended to events.jsonl since the previous call.

    The module-level ``_last_events_count`` is used as a line cursor. Only the
    200 most recent new events are sent to Gemini, and one KG entity is
    written when the returned surprise_score is at least 0.3. The cursor is
    advanced before the Gemini call, so a failed extraction does not retry
    the same events.

    Returns:
        Number of KG entities written (0 or 1).
    """
    global _last_events_count
    events_file = SESSION_LOGS / "events.jsonl"

    if not events_file.exists():
        print("  [events] events.jsonl not found, skipping")
        return 0

    with open(events_file, 'r', encoding='utf-8', errors='replace') as f:
        lines = f.readlines()

    # A final line without a newline may be a write still in progress. If it
    # does not parse yet, keep it out of this cycle and out of the cursor so
    # the completed line is picked up next time instead of being lost.
    if lines and not lines[-1].endswith('\n'):
        try:
            json.loads(lines[-1])
        except ValueError:
            lines.pop()

    total = len(lines)
    if total < _last_events_count:
        # The file was truncated or rotated. Without this reset the slice
        # below would stay empty forever and no new event would be processed.
        print(f"  [events] events.jsonl shrank ({_last_events_count} -> {total} lines), restarting from the top")
        _last_events_count = 0

    new_lines = lines[_last_events_count:]
    _last_events_count = total

    if not new_lines:
        print(f"  [events] No new events (total: {total})")
        return 0

    print(f"  [events] {len(new_lines)} new events (total: {total})")

    # Parse events and build a content block
    events = []
    for line in new_lines[-200:]:  # Cap at 200 most recent new events
        try:
            events.append(json.loads(line))
        except ValueError:
            continue  # blank or malformed line

    if not events:
        return 0

    content = json.dumps(events, indent=2)
    learning = extract_learnings(content, f"session_events ({len(events)} events)")

    if learning and learning.get("surprise_score", 0) >= 0.3:
        write_kg_entity(learning, "session_events")
        return 1

    score = learning.get("surprise_score", 0) if learning else "extract_failed"
    print(f"  [events] Learning skipped (surprise_score={score})")
    return 0


# Line count of sessions.jsonl at the last analysis. The summary pass is
# skipped while the file has not changed, so the same sessions are not
# re-submitted (and duplicate KG entities written) on every 6th cycle.
_last_sessions_count = 0


def process_sessions_log() -> int:
    """Analyse the 5 most recent entries of sessions.jsonl when it has changed.

    One KG entity is written when the returned surprise_score is at least 0.4.
    If the file's line count equals the count seen at the previous analysis,
    nothing is sent to Gemini.

    Returns:
        Number of KG entities written (0 or 1).
    """
    global _last_sessions_count
    from collections import deque

    sessions_file = SESSION_LOGS / "sessions.jsonl"
    if not sessions_file.exists():
        return 0

    # Stream the file, keeping only the last 5 lines in memory.
    total = 0
    recent = deque(maxlen=5)
    with open(sessions_file, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            total += 1
            recent.append(line)

    if total == _last_sessions_count:
        print(f"  [sessions] No new sessions (total: {total})")
        return 0
    _last_sessions_count = total

    sessions = []
    for line in recent:
        try:
            sessions.append(json.loads(line.strip()))
        except ValueError:
            continue  # blank or malformed line

    if not sessions:
        return 0

    content = json.dumps(sessions, indent=2)
    learning = extract_learnings(content, "session_summaries")

    if learning and learning.get("surprise_score", 0) >= 0.4:
        write_kg_entity(learning, "session_summaries")
        return 1

    return 0


def write_status_report(results: list, *, path=None):
    """Rewrite the overnight markdown report from the per-cycle results.

    Args:
        results: cycle dicts with keys "cycle", "time" (ISO string) and
            "entities", plus "error" for failed cycles. Only the last 30 are
            listed in the cycle log; totals cover all of them.
        path: destination file; defaults to RESULTS_FILE.
    """
    if path is None:
        path = RESULTS_FILE

    total_entities = sum(r.get('entities', 0) for r in results)
    total_cycles = len(results)
    error_count = sum(1 for r in results if 'error' in r)

    # Take the clock once so the UTC and AEST stamps describe the same instant.
    # AEST is a fixed UTC+10 offset; converting the full timestamp keeps the
    # real minutes instead of always showing ":00".
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    now_aest = now_utc.astimezone(datetime.timezone(datetime.timedelta(hours=10)))

    lines = [
        "# OVERNIGHT SPRINT RESULTS",
        f"Generated: {now_utc.strftime('%Y-%m-%d %H:%M UTC')} ({now_aest.strftime('%H:%M')} AEST)",
        "",
        "## RLM Processing Summary",
        f"- Total cycles: {total_cycles}",
        f"- KG entities written: {total_entities}",
        f"- Errors encountered: {error_count}",
        "- Target cutoff: 4am AEST (18:00 UTC)",
        "",
        "## Cycle Log",
    ]
    for r in results[-30:]:
        cycle_num = r.get('cycle', '?')
        t = r.get('time', '')[:19]
        entities = r.get('entities', 0)
        err = f" ERROR: {r['error'][:60]}" if 'error' in r else ""
        lines.append(f"- Cycle {cycle_num} | {t} UTC | {entities} entities{err}")

    # Single write keeps the window in which a reader sees a partial file small.
    with open(path, 'w', encoding='utf-8') as f:
        f.write("\n".join(lines) + "\n")

    print(f"  [report] {Path(path).name} updated ({total_cycles} cycles, {total_entities} entities total)")


def main():
    """Run RLM processing cycles every 5 minutes until the 4am AEST cutoff.

    Exits with status 1 if GEMINI_API_KEY is unset. Each cycle processes new
    events (plus the session summaries on every 6th cycle), records exactly
    one result dict, and rewrites OVERNIGHT_RESULTS.md. Exceptions raised in
    a cycle are printed and stored in that cycle's result so the loop keeps
    running; the final report is written once the cutoff is reached.
    """
    import traceback

    print("=" * 60)
    print("RLM OVERNIGHT WORKER STARTED")
    print(f"Time: {datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')}")
    print("Will run until 4am AEST (18:00 UTC)")
    print("=" * 60)

    cutoff = get_cutoff_time()
    print(f"Cutoff time: {cutoff.strftime('%Y-%m-%d %H:%M UTC')}")
    print(f"GEMINI_API_KEY: {'SET' if os.getenv('GEMINI_API_KEY') else 'MISSING'}")
    print()

    if not os.getenv('GEMINI_API_KEY'):
        print("ERROR: GEMINI_API_KEY not set. Cannot proceed.")
        print("Set via: export GEMINI_API_KEY=your_key")
        sys.exit(1)

    results = []
    cycle = 0

    while datetime.datetime.utcnow() < cutoff:
        cycle += 1
        started = datetime.datetime.utcnow()
        hours_left = (cutoff - started).total_seconds() / 3600
        print(f"\n{'='*50}")
        print(f"RLM Cycle {cycle} | {started.strftime('%H:%M UTC')} | {hours_left:.1f}h remaining")
        print(f"{'='*50}")

        cycle_result = {
            "cycle": cycle,
            "time": started.isoformat(),
            "entities": 0
        }

        try:
            # Accumulate as we go so entities already written are still
            # counted if a later step of the cycle raises.
            cycle_result["entities"] += process_session_logs()

            # Every 6th cycle, also process sessions summary
            if cycle % 6 == 0:
                cycle_result["entities"] += process_sessions_log()

            print(f"  [cycle] Total entities this cycle: {cycle_result['entities']}")
        except Exception as e:
            print(f"  [ERROR] {e}")
            traceback.print_exc()
            cycle_result["error"] = str(e)

        # Record the cycle exactly once, outside the try: previously a failure
        # inside write_status_report appended the same cycle a second time,
        # and failed cycles never refreshed the report.
        results.append(cycle_result)
        try:
            write_status_report(results)
        except Exception as e:
            print(f"  [report] Failed to update OVERNIGHT_RESULTS.md: {e}")

        # Sleep 5 minutes between cycles
        sleep_until = datetime.datetime.utcnow() + datetime.timedelta(minutes=5)
        if sleep_until >= cutoff:
            print("  [sleep] Near cutoff — exiting now")
            break

        print(f"  [sleep] Sleeping 5 min... Next cycle at {sleep_until.strftime('%H:%M UTC')}")
        time.sleep(300)

    print()
    print("=" * 60)
    print("4am AEST reached — RLM worker shutting down")
    write_status_report(results)
    print(f"Final: {sum(r.get('entities', 0) for r in results)} KG entities written over {len(results)} cycles")
    print("Results written to OVERNIGHT_RESULTS.md")
    print("=" * 60)


if __name__ == "__main__":
    main()
