#!/usr/bin/env python3
"""
RLM Bloodstream - Memory Digestion Cron
========================================
Deduplicates KG entities, generates bloodstream index, and maintains memory health.

Runs as:
- Cron job (daily maintenance)
- Manual trigger (python memory_digestion_cron.py)
- Post-session cleanup

Functions:
1. Deduplicate entities across all JSONL files
2. Generate bloodstream index (entity catalog + stats)
3. Archive old/stale entities
4. Detect orphaned entities (no relationships)
5. Generate health report

Output:
- data/bloodstream_index.json (full catalog + metadata)
- KNOWLEDGE_GRAPH/indexes/bloodstream_index.json (backup)
- logs/bloodstream_digestion_{date}.log

Usage:
    python memory_digestion_cron.py
    python memory_digestion_cron.py --deep-clean
    python memory_digestion_cron.py --dry-run
"""

import argparse
import hashlib
import json
import logging
import os
import shutil
import sys
from collections import Counter, defaultdict
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Set

# =============================================================================
# Configuration
# =============================================================================

# Repository root: this file sits three directory levels below it —
# TODO confirm if this file is ever relocated.
GENESIS_ROOT = Path(__file__).parent.parent.parent
# Knowledge-graph source directories (JSONL files per category).
KG_DIR = GENESIS_ROOT / "KNOWLEDGE_GRAPH"
ENTITIES_DIR = KG_DIR / "entities"
AXIOMS_DIR = KG_DIR / "axioms"
RELATIONSHIPS_DIR = KG_DIR / "relationships"
INDEXES_DIR = KG_DIR / "indexes"
DATA_DIR = GENESIS_ROOT / "data"
LOG_DIR = GENESIS_ROOT / "logs" / "bloodstream"

# Output files
OUTPUT_INDEX = DATA_DIR / "bloodstream_index.json"     # primary index
BACKUP_INDEX = INDEXES_DIR / "bloodstream_index.json"  # backup copy kept with the KG

# Deduplication settings
STALE_ENTITY_DAYS = 180  # Archive entities older than 6 months with no updates
# NOTE(review): BACKUP_RETENTION_DAYS is not referenced anywhere in this file —
# confirm whether it is consumed elsewhere or is dead configuration.
BACKUP_RETENTION_DAYS = 30

# Setup logging.
# Import-time side effect: creates the log directory and opens a
# timestamped log file for this run.
LOG_DIR.mkdir(parents=True, exist_ok=True)
log_file = LOG_DIR / f"digestion_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

# Log to both the per-run file and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("bloodstream-digestion")

# =============================================================================
# Helper Functions
# =============================================================================

def load_all_jsonl(directory: Path) -> List[Dict[str, Any]]:
    """Load every record from all *.jsonl files in *directory*.

    Each parsed record is tagged with its provenance (_source_file and
    _source_line) so later stages can report where an entity came from.
    Malformed JSON lines are logged and skipped rather than aborting.
    """
    records: List[Dict[str, Any]] = []
    if not directory.exists():
        logger.warning(f"Directory does not exist: {directory}")
        return records

    for path in directory.glob("*.jsonl"):
        with path.open("r", encoding="utf-8") as handle:
            for number, raw in enumerate(handle, 1):
                text = raw.strip()
                if not text:
                    continue
                try:
                    record = json.loads(text)
                except json.JSONDecodeError as e:
                    logger.warning(f"JSON decode error in {path.name}:{number} - {e}")
                    continue
                record["_source_file"] = path.name
                record["_source_line"] = number
                records.append(record)

    logger.info(f"Loaded {len(records)} items from {directory.name}/")
    return records

def get_entity_key(entity: Dict[str, Any]) -> str:
    """Generate a unique, deterministic key for deduplication.

    Checks identifier fields in priority order (id > entityId > name > title)
    and normalizes the first non-empty value (strip + lowercase) so that
    case/whitespace variants of the same identifier collapse to one key.

    Args:
        entity: Raw entity record loaded from a JSONL file.

    Returns:
        Normalized key string suitable for use in a dedup map.
    """
    # Priority: id > entityId > name > title
    for field in ["id", "entityId", "name", "title"]:
        if field in entity and entity[field]:
            return str(entity[field]).strip().lower()

    # Fallback: derive a key from the content. Use a stable digest rather
    # than the builtin hash(), which is randomized per process
    # (PYTHONHASHSEED) and would yield a different key for the same entity
    # on every run — breaking comparisons of persisted indexes across runs.
    content = str(entity.get("description", "")) + str(entity.get("content", ""))
    return f"generated_{hashlib.sha1(content.encode('utf-8')).hexdigest()}"

def deduplicate_entities(entities: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Dict[str, int]]:
    """
    Deduplicate entities, keeping the most recent version.

    Entities are keyed via get_entity_key(); when two records share a key,
    the one with the lexically greater timestamp string wins (assumes a
    consistent ISO-8601 format, where string order matches chronological
    order). Records with no timestamp compare as "".

    Args:
        entities: All loaded entity records (possibly with duplicates).

    Returns:
        (unique_entities, duplicate_counts) where duplicate_counts maps each
        key to the number of duplicate records discarded for it.
    """
    entity_map: Dict[str, Dict[str, Any]] = {}  # key -> surviving entity
    duplicate_counts: Dict[str, int] = defaultdict(int)

    for entity in entities:
        key = get_entity_key(entity)

        if key in entity_map:
            # Duplicate found - keep the newer one
            duplicate_counts[key] += 1
            existing = entity_map[key]

            # Compare timestamps (fall back to created_at, then "")
            existing_ts = existing.get("timestamp") or existing.get("created_at") or ""
            new_ts = entity.get("timestamp") or entity.get("created_at") or ""

            if new_ts > existing_ts:
                entity_map[key] = entity  # Replace with newer
                logger.debug(f"Duplicate: {key} - kept newer version")
        else:
            entity_map[key] = entity

    unique_entities = list(entity_map.values())
    # Report the total number of records removed (sum over keys), not just
    # how many distinct keys had duplicates — the previous count understated
    # removals whenever one key had several duplicates.
    removed = sum(duplicate_counts.values())
    logger.info(f"Deduplicated: {len(entities)} → {len(unique_entities)} (removed {removed} duplicates)")

    return unique_entities, dict(duplicate_counts)

def detect_orphaned_entities(
    entities: List[Dict[str, Any]],
    relationships: List[Dict[str, Any]]
) -> List[str]:
    """Return keys of entities that no relationship references.

    An entity is "orphaned" when its key appears on neither the "from" nor
    the "to" side of any relationship — a candidate for archiving.
    """
    # Collect every id mentioned on either end of a relationship.
    referenced: Set[str] = {
        rel.get(side, "") for rel in relationships for side in ("from", "to")
    }

    # Keep the keys that never appear in any relationship.
    orphans = [
        key for key in (get_entity_key(entity) for entity in entities)
        if key not in referenced
    ]

    logger.info(f"Orphaned entities: {len(orphans)}/{len(entities)}")
    return orphans

def detect_stale_entities(entities: List[Dict[str, Any]], days: int) -> List[Dict[str, Any]]:
    """Return entities whose timestamp is older than *days* days.

    Records with no timestamp, or one that cannot be parsed as ISO-8601,
    are skipped (never flagged as stale).
    """
    cutoff = datetime.now() - timedelta(days=days)
    stale: List[Dict[str, Any]] = []

    for entity in entities:
        raw_ts = entity.get("timestamp") or entity.get("created_at") or entity.get("createdAt")
        if not raw_ts:
            continue
        try:
            # Map a trailing "Z" to "+00:00" so fromisoformat accepts it,
            # then drop tzinfo to compare against the naive cutoff.
            parsed = datetime.fromisoformat(raw_ts.replace("Z", "+00:00"))
            if parsed.replace(tzinfo=None) < cutoff:
                stale.append(entity)
        except Exception:
            logger.debug(f"Could not parse timestamp: {raw_ts}")

    logger.info(f"Stale entities (>{days} days): {len(stale)}/{len(entities)}")
    return stale

def generate_statistics(
    entities: List[Dict[str, Any]],
    axioms: List[Dict[str, Any]],
    relationships: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Build the statistics section of the bloodstream index.

    Covers totals, entity/axiom type distributions, source-file
    distribution, and a coarse temporal breakdown (30/90-day buckets).
    """
    # Type and provenance distributions.
    type_counts = Counter(e.get("entityType", "unknown") for e in entities)
    axiom_counts = Counter(a.get("type", "unknown") for a in axioms)
    file_counts = Counter(e.get("_source_file", "unknown") for e in entities)

    # Recency buckets (an entity in the 30-day bucket is also in the 90-day one).
    within_30 = sum(1 for e in entities if is_recent(e, 30))
    within_90 = sum(1 for e in entities if is_recent(e, 90))

    return {
        "total_entities": len(entities),
        "total_axioms": len(axioms),
        "total_relationships": len(relationships),
        "entity_types": dict(type_counts),
        "axiom_types": dict(axiom_counts),
        "source_files": dict(file_counts),
        "temporal": {
            "recent_30_days": within_30,
            "recent_90_days": within_90,
            "older_than_90_days": len(entities) - within_90,
        },
        "generated_at": datetime.now().isoformat(),
    }

def is_recent(entity: Dict[str, Any], days: int) -> bool:
    """Check if entity was created/updated within the last *days* days.

    Looks for a timestamp under "timestamp", "created_at", or "createdAt"
    (first non-empty wins). Entities with no timestamp, or one that cannot
    be parsed as ISO-8601, are treated as not recent.

    Args:
        entity: Entity record; timestamps are ISO-8601 strings, optionally
            with a trailing "Z" for UTC.
        days: Size of the recency window.

    Returns:
        True if the entity's timestamp is newer than now - days.
    """
    cutoff = datetime.now() - timedelta(days=days)
    timestamp = entity.get("timestamp") or entity.get("created_at") or entity.get("createdAt")

    if timestamp:
        try:
            # Map trailing "Z" to "+00:00" so fromisoformat accepts it;
            # then compare naive-to-naive (offset is discarded, matching
            # the rest of this module's timestamp handling).
            created = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
            return created.replace(tzinfo=None) > cutoff
        except (ValueError, AttributeError, TypeError):
            # Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit; narrow to parse/type failures only.
            pass

    return False

def build_entity_index(entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build a lightweight per-entity catalog for the bloodstream index.

    Each entry keeps just enough metadata to identify and locate the
    entity without embedding its full content.
    """
    return [
        {
            "id": get_entity_key(entity),
            "name": entity.get("name") or entity.get("title"),
            "type": entity.get("entityType") or entity.get("type"),
            "source_file": entity.get("_source_file"),
            "timestamp": entity.get("timestamp") or entity.get("created_at"),
            "observations_count": len(entity.get("observations", [])),
        }
        for entity in entities
    ]

# =============================================================================
# Main Processing
# =============================================================================

def run_digestion(deep_clean: bool = False, dry_run: bool = False):
    """Run the full memory digestion pipeline.

    Stages: load KG JSONL data → deduplicate entities → detect orphans and
    stale entries → compute statistics → build the index → write outputs
    (skipped on dry_run) → optionally archive stale entities (deep_clean)
    → print a human-readable summary.

    Args:
        deep_clean: Also archive stale entities after writing outputs.
        dry_run: Skip all file writes; report only.
    """

    logger.info("=" * 60)
    logger.info("RLM Bloodstream Memory Digestion - Starting")
    logger.info("=" * 60)

    # Load all knowledge from the three KG JSONL directories
    entities = load_all_jsonl(ENTITIES_DIR)
    axioms = load_all_jsonl(AXIOMS_DIR)
    relationships = load_all_jsonl(RELATIONSHIPS_DIR)

    logger.info(f"Loaded: {len(entities)} entities, {len(axioms)} axioms, {len(relationships)} relationships")

    # Deduplicate entities (newest timestamp wins per key)
    unique_entities, duplicate_counts = deduplicate_entities(entities)

    # Detect orphans (no relationships) and stale entities (> STALE_ENTITY_DAYS)
    orphans = detect_orphaned_entities(unique_entities, relationships)
    stale = detect_stale_entities(unique_entities, STALE_ENTITY_DAYS)

    # Generate statistics over the deduplicated set
    stats = generate_statistics(unique_entities, axioms, relationships)

    # Build entity index (lightweight catalog entries)
    entity_index = build_entity_index(unique_entities)

    # Compile full index document written to disk below
    bloodstream_index = {
        "version": "1.0.0",
        "generated_at": datetime.now().isoformat(),
        "statistics": stats,
        "entities": entity_index,
        "orphaned_entities": orphans,
        "stale_entities": [get_entity_key(e) for e in stale],
        "duplicate_counts": duplicate_counts,
        "health": {
            "total_duplicates_removed": sum(duplicate_counts.values()),
            "orphaned_count": len(orphans),
            "stale_count": len(stale),
            "health_score": calculate_health_score(stats, orphans, stale)
        }
    }

    # Write outputs (skipped entirely in dry-run mode)
    if not dry_run:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        INDEXES_DIR.mkdir(parents=True, exist_ok=True)

        with open(OUTPUT_INDEX, "w", encoding="utf-8") as f:
            json.dump(bloodstream_index, f, indent=2)
        logger.info(f"Index written: {OUTPUT_INDEX}")

        # Backup to KG indexes (same content, second location)
        with open(BACKUP_INDEX, "w", encoding="utf-8") as f:
            json.dump(bloodstream_index, f, indent=2)
        logger.info(f"Backup written: {BACKUP_INDEX}")

        # Deep clean: Archive stale entities (copies them; does not delete sources)
        if deep_clean and stale:
            archive_stale_entities(stale)

    # Print summary report to stdout (runs even on dry-run)
    print_summary_report(bloodstream_index)

    logger.info("=" * 60)
    logger.info("RLM Bloodstream Memory Digestion - Complete")
    logger.info("=" * 60)

def calculate_health_score(stats: Dict, orphans: List, stale: List) -> float:
    """Calculate overall memory health score (0.0 - 1.0).

    Orphans reduce the score by up to 0.3 and stale entities by up to 0.2,
    each in proportion to their share of total entities. An empty graph
    counts as perfectly healthy.
    """
    total = stats["total_entities"]
    if total == 0:
        return 1.0

    # Weighted penalties, subtracted in sequence (same order as before).
    score = 1.0 - len(orphans) / total * 0.3 - len(stale) / total * 0.2
    # Clamp into [0.0, 1.0].
    return min(1.0, max(0.0, score))

def archive_stale_entities(stale: List[Dict[str, Any]]):
    """Write stale entities to a dated archive directory under the KG root.

    This only copies records into the archive file; nothing in this
    function modifies the original JSONL sources.
    """
    stamp = datetime.now().strftime('%Y%m%d')
    archive_dir = KG_DIR / "archive" / f"stale_{stamp}"
    archive_dir.mkdir(parents=True, exist_ok=True)

    archive_file = archive_dir / "archived_entities.jsonl"
    lines = [json.dumps(entity) + "\n" for entity in stale]
    with open(archive_file, "w", encoding="utf-8") as f:
        f.writelines(lines)

    logger.info(f"Archived {len(stale)} stale entities to: {archive_file}")

def print_summary_report(index: Dict[str, Any]):
    """Print the human-readable health report for a bloodstream index.

    Reads the "statistics" and "health" sections of *index* and echoes
    them to stdout in a fixed, banner-framed layout.
    """
    stats = index["statistics"]
    health = index["health"]
    banner = "=" * 60

    print("\n" + banner)
    print("BLOODSTREAM MEMORY HEALTH REPORT")
    print(banner)
    print(f"Generated: {index['generated_at']}")
    print()

    print("📊 TOTALS:")
    print(f"  Entities:      {stats['total_entities']}")
    print(f"  Axioms:        {stats['total_axioms']}")
    print(f"  Relationships: {stats['total_relationships']}")
    print()

    print("🔍 HEALTH METRICS:")
    print(f"  Health Score:  {health['health_score']:.2%}")
    print(f"  Duplicates:    {health['total_duplicates_removed']} removed")
    print(f"  Orphans:       {health['orphaned_count']} entities")
    print(f"  Stale (>180d): {health['stale_count']} entities")
    print()

    print("📅 TEMPORAL DISTRIBUTION:")
    temporal = stats['temporal']
    print(f"  Last 30 days:  {temporal['recent_30_days']}")
    print(f"  Last 90 days:  {temporal['recent_90_days']}")
    print(f"  Older:         {temporal['older_than_90_days']}")
    print()

    print("📁 TOP ENTITY TYPES:")
    # Sort types by descending count; show the five most common.
    ranked = sorted(stats['entity_types'].items(), key=lambda kv: kv[1], reverse=True)
    for type_name, type_count in ranked[:5]:
        print(f"  {type_name:20} {type_count}")
    print()

    print(banner)
    print()

# =============================================================================
# Main Entry Point
# =============================================================================

def main():
    """CLI entry point: parse command-line flags and run the digestion."""
    parser = argparse.ArgumentParser(
        description="RLM Bloodstream Memory Digestion Cron"
    )
    # All three options are simple boolean flags.
    for flag, help_text in (
        ("--deep-clean", "Archive stale entities (>180 days)"),
        ("--dry-run", "Run without writing outputs (for testing)"),
        ("--debug", "Enable debug logging"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)

    args = parser.parse_args()

    if args.debug:
        logger.setLevel(logging.DEBUG)

    if args.dry_run:
        logger.info("DRY RUN MODE - No files will be written")

    run_digestion(deep_clean=args.deep_clean, dry_run=args.dry_run)

if __name__ == "__main__":
    main()
