#!/usr/bin/env python3
"""
GENESIS ROLLBACK SYSTEM
========================
Tracks changes and enables automatic rollback on verification failure.

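Features:
    - Checkpoints that snapshot tracked files (content hashes for every
      file, full text for small files) and record the current git commit
    - Rollback of a whole checkpoint or of a single file, falling back to
      git for files whose text was not stored
    - Per-file change detection against a checkpoint (modified/deleted)
    - Intended to be called after a failed verification step (see Usage)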

Usage:
    rollback = RollbackSystem()
    rollback.create_checkpoint("before feature X")
    # ... make changes ...
    if verification_failed:
        rollback.rollback_to_last()
"""

import json
import os
import shutil
import subprocess
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
import hashlib


@dataclass
class FileSnapshot:
    """Recorded state of one tracked file at the moment a checkpoint was taken.

    RollbackSystem serializes these fields one-for-one into each
    checkpoint's JSON data file.
    """

    # Path relative to the project root; also used as the key in Checkpoint.files
    path: str
    # SHA-256 hex digest of the file's text
    content_hash: str
    # Full text, or None when the file was too large or content was not requested
    content: Optional[str] = field(default=None)
    # Length of the text in characters (not bytes)
    size: int = field(default=0)
    # Whether the file existed when the snapshot was taken
    exists: bool = field(default=True)


@dataclass
class Checkpoint:
    """Metadata, and optionally per-file snapshots, for one saved system state."""

    # Timestamp-derived identifier; also the stem of the checkpoint's JSON data file
    checkpoint_id: str
    # Free-form, human-readable label
    description: str
    # Creation time as an ISO-8601 string
    created_at: str
    # Snapshots keyed by relative path. Empty for checkpoints read back from
    # the index; their file data is loaded on demand from the data file.
    files: Dict[str, FileSnapshot]
    # HEAD commit at checkpoint time, or None when it could not be determined
    git_commit: Optional[str] = field(default=None)
    # Arbitrary extra data persisted alongside the checkpoint in the index
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class RollbackResult:
    """Outcome of a rollback call: overall status, counters and per-file notes."""

    # True only when no file failed to restore
    success: bool
    # One-line, human-readable summary
    message: str
    files_restored: int = field(default=0)
    files_failed: int = field(default=0)
    # Per-file status lines (may be empty)
    details: List[str] = field(default_factory=list)


class GitIntegration:
    """Thin, best-effort wrapper around the ``git`` CLI for one working tree.

    No method raises for git-level problems. If git is not installed,
    ``repo_path`` is not a usable directory, the output cannot be decoded,
    or git exits non-zero, the method returns ``None`` / ``[]`` / ``False``.
    """

    def __init__(self, repo_path: Path):
        # Working directory for every git invocation
        self.repo_path = repo_path

    def _run(self, *args: str) -> Optional[subprocess.CompletedProcess]:
        """Run ``git <args>`` in ``repo_path``; return the finished process, or None if it could not run.

        Only launch/decoding failures are caught: OSError for a missing binary
        or bad cwd, ValueError for undecodable output, and SubprocessError.
        A bare ``except`` would also swallow KeyboardInterrupt and SystemExit.
        """
        try:
            return subprocess.run(
                ["git", *args],
                cwd=self.repo_path,
                capture_output=True,
                text=True,
            )
        except (OSError, ValueError, subprocess.SubprocessError):
            return None

    def get_current_commit(self) -> Optional[str]:
        """Return the full hash of HEAD, or None (not a repo, no commits, git unavailable)."""
        result = self._run("rev-parse", "HEAD")
        if result is not None and result.returncode == 0:
            return result.stdout.strip()
        return None

    def get_changed_files(self, since_commit: Optional[str] = None) -> List[str]:
        """Return paths reported by ``git diff --name-only [since_commit]``; [] on failure."""
        args = ["diff", "--name-only"]
        if since_commit:
            args.append(since_commit)
        result = self._run(*args)
        if result is None or result.returncode != 0:
            return []
        # Split on newlines only; stripping could cut a trailing space off a filename
        return [name for name in result.stdout.split("\n") if name]

    def get_file_at_commit(self, file_path: str, commit: str) -> Optional[str]:
        """Return the text of ``file_path`` as of ``commit``, or None if git cannot show it."""
        result = self._run("show", f"{commit}:{file_path}")
        if result is not None and result.returncode == 0:
            return result.stdout
        return None

    def _stash_head(self) -> Optional[str]:
        """Return the commit hash of the newest stash entry, or None if there is none."""
        result = self._run("rev-parse", "-q", "--verify", "refs/stash")
        if result is not None and result.returncode == 0:
            return result.stdout.strip()
        return None

    def create_stash(self, message: str = "Rollback checkpoint") -> Optional[str]:
        """Stash local changes and return ``"stash@{0}"``, or None if nothing was stashed.

        ``git stash push`` exits 0 even when there are no local changes. In that
        case ``stash@{0}`` would refer to an older, unrelated entry. The stash
        head is therefore compared before and after the push.
        """
        before = self._stash_head()
        result = self._run("stash", "push", "-m", message)
        if result is None or result.returncode != 0:
            return None
        after = self._stash_head()
        if after is None or after == before:
            return None
        return "stash@{0}"

    def apply_stash(self, stash_ref: str = "stash@{0}") -> bool:
        """Apply ``stash_ref`` to the working tree (keeping the stash); True on success."""
        result = self._run("stash", "apply", stash_ref)
        return result is not None and result.returncode == 0

    def checkout_file(self, file_path: str, commit: str) -> bool:
        """Overwrite ``file_path`` with its version from ``commit``; True on success."""
        result = self._run("checkout", commit, "--", file_path)
        return result is not None and result.returncode == 0


class RollbackSystem:
    """
    Manages checkpoints and rollbacks for the Genesis system.

    Each checkpoint is stored as ``<checkpoints_dir>/<checkpoint_id>.json``.
    That file holds per-file hashes plus the full text of small files. An
    ``index.json`` in the same directory lists checkpoint metadata in
    creation order. Files whose text was not stored are restored from the
    git commit recorded when the checkpoint was created, if there is one.
    """

    # Files with at least this many characters are snapshotted by hash only;
    # their content must come from git on rollback.
    MAX_INLINE_CHARS = 100000

    def __init__(self, genesis_root: Optional[Path] = None, max_checkpoints: int = 20):
        """
        Args:
            genesis_root: Root directory whose files are tracked (Path or str).
                Defaults to the grandparent directory of this module.
            max_checkpoints: Number of checkpoints retained. The oldest are
                deleted when a new checkpoint is created.
        """
        self.genesis_root = Path(genesis_root) if genesis_root else Path(__file__).parent.parent
        self.checkpoints_dir = self.genesis_root / "data" / "checkpoints"
        self.checkpoints_dir.mkdir(parents=True, exist_ok=True)

        self.max_checkpoints = max_checkpoints
        self.checkpoints: List[Checkpoint] = []
        self.git = GitIntegration(self.genesis_root)

        # Glob patterns, relative to genesis_root, tracked by default
        self.tracked_patterns = [
            "core/*.py",
            "skills/*.py",
            "loop/*.py",
            "*.json",
        ]

        self._load_checkpoints()

    # ------------------------------------------------------------------ storage

    @staticmethod
    def _write_text_atomic(path: Path, text: str) -> None:
        """Write ``text`` to ``path`` via a sibling temp file and ``os.replace``.

        This ensures an interrupted write does not leave a truncated index or
        checkpoint file behind.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text(text)
        os.replace(tmp_path, path)

    def _load_checkpoints(self) -> None:
        """Populate ``self.checkpoints`` from ``index.json``; file data is loaded on demand.

        A missing, unreadable or corrupt index leaves the list empty. A
        malformed entry is skipped on its own rather than discarding every
        entry after it.
        """
        index_path = self.checkpoints_dir / "index.json"
        if not index_path.exists():
            return
        try:
            data = json.loads(index_path.read_text())
        except (OSError, ValueError):
            return
        entries = data.get("checkpoints", []) if isinstance(data, dict) else []
        for cp_data in entries:
            try:
                checkpoint = Checkpoint(
                    checkpoint_id=cp_data["checkpoint_id"],
                    description=cp_data["description"],
                    created_at=cp_data["created_at"],
                    files={},  # loaded on demand by _load_checkpoint_files
                    git_commit=cp_data.get("git_commit"),
                    metadata=cp_data.get("metadata", {}),
                )
            except (KeyError, TypeError, AttributeError):
                continue
            self.checkpoints.append(checkpoint)

    def _save_index(self) -> None:
        """Write checkpoint metadata (not file data) to ``index.json``."""
        data = {
            "checkpoints": [
                {
                    "checkpoint_id": cp.checkpoint_id,
                    "description": cp.description,
                    "created_at": cp.created_at,
                    "git_commit": cp.git_commit,
                    "metadata": cp.metadata
                }
                for cp in self.checkpoints
            ]
        }
        self._write_text_atomic(self.checkpoints_dir / "index.json", json.dumps(data, indent=2))

    def _hash_content(self, content: str) -> str:
        """Hash file content (SHA-256 hex digest of its UTF-8 encoding)."""
        return hashlib.sha256(content.encode()).hexdigest()

    def _get_tracked_files(self) -> List[Path]:
        """Return existing files under genesis_root matching ``tracked_patterns``."""
        files = []
        for pattern in self.tracked_patterns:
            files.extend(self.genesis_root.glob(pattern))
        return [f for f in files if f.is_file()]

    def _find_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """Return the checkpoint with ``checkpoint_id``, or None if it is unknown."""
        return next((cp for cp in self.checkpoints if cp.checkpoint_id == checkpoint_id), None)

    # --------------------------------------------------------------- checkpoints

    def _new_checkpoint_id(self, now: datetime) -> str:
        """Return a timestamp-based ID that no existing checkpoint or data file uses.

        IDs have one-second resolution. A second checkpoint created in the same
        second gets a numeric suffix (e.g. ``20240101_120000_1``). Without it,
        the new checkpoint would overwrite the first one's data file, and
        pruning the older entry would delete the new one's data.
        """
        base_id = now.strftime("%Y%m%d_%H%M%S")
        taken = {cp.checkpoint_id for cp in self.checkpoints}
        candidate, suffix = base_id, 0
        while candidate in taken or (self.checkpoints_dir / f"{candidate}.json").exists():
            suffix += 1
            candidate = f"{base_id}_{suffix}"
        return candidate

    def _snapshot_file(self, file_path: Path, include_content: bool) -> Optional[FileSnapshot]:
        """Snapshot one file, or return None if it is outside genesis_root or unreadable as text.

        A relative path that is not already under genesis_root is interpreted
        relative to genesis_root, not the process's working directory. Paths
        that escape the root via ``..`` are rejected.
        """
        try:
            rel_path = file_path.relative_to(self.genesis_root)
            full_path = file_path
        except ValueError:
            if file_path.is_absolute() or ".." in file_path.parts:
                return None
            rel_path, full_path = file_path, self.genesis_root / file_path
        try:
            content = full_path.read_text()
        except (OSError, ValueError):
            # Missing, unreadable or non-text files are skipped (best effort)
            return None
        size = len(content)
        return FileSnapshot(
            path=str(rel_path),
            content_hash=self._hash_content(content),
            content=content if (include_content and size < self.MAX_INLINE_CHARS) else None,
            size=size,
            exists=True
        )

    def create_checkpoint(
        self,
        description: str,
        files: Optional[List[Path]] = None,
        include_content: bool = True
    ) -> Checkpoint:
        """
        Create a new checkpoint and prune the oldest ones beyond ``max_checkpoints``.

        Args:
            description: Human-readable description
            files: Specific files to track, as absolute paths under genesis_root
                or paths relative to it (str or Path). None or empty means
                the files matched by ``tracked_patterns``. Files that cannot
                be read as text are skipped.
            include_content: Store the full text of files shorter than
                ``MAX_INLINE_CHARS``; otherwise only hashes are kept

        Returns:
            Created checkpoint (with ``files`` populated)
        """
        now = datetime.now()
        checkpoint_id = self._new_checkpoint_id(now)

        file_snapshots: Dict[str, FileSnapshot] = {}
        for file_path in (files or self._get_tracked_files()):
            snapshot = self._snapshot_file(Path(file_path), include_content)
            if snapshot is not None:
                file_snapshots[snapshot.path] = snapshot

        checkpoint = Checkpoint(
            checkpoint_id=checkpoint_id,
            description=description,
            created_at=now.isoformat(),
            files=file_snapshots,
            git_commit=self.git.get_current_commit()
        )

        cp_data = {
            "checkpoint_id": checkpoint_id,
            "description": description,
            "created_at": checkpoint.created_at,
            "git_commit": checkpoint.git_commit,
            "files": {
                path: {
                    "path": snap.path,
                    "content_hash": snap.content_hash,
                    "content": snap.content,
                    "size": snap.size,
                    "exists": snap.exists
                }
                for path, snap in file_snapshots.items()
            }
        }
        self._write_text_atomic(
            self.checkpoints_dir / f"{checkpoint_id}.json", json.dumps(cp_data, indent=2)
        )

        self.checkpoints.append(checkpoint)

        # Prune the oldest checkpoints, but never the one just created
        while len(self.checkpoints) > max(self.max_checkpoints, 1):
            old = self.checkpoints.pop(0)
            old_path = self.checkpoints_dir / f"{old.checkpoint_id}.json"
            if old_path.exists():
                old_path.unlink()

        self._save_index()

        return checkpoint

    def _load_checkpoint_files(self, checkpoint: Checkpoint) -> Dict[str, FileSnapshot]:
        """Load the per-file snapshots stored for ``checkpoint``.

        Returns an empty dict if the data file is missing, unreadable, or not
        a valid JSON object.
        """
        cp_path = self.checkpoints_dir / f"{checkpoint.checkpoint_id}.json"
        try:
            data = json.loads(cp_path.read_text())
        except (OSError, ValueError):
            return {}
        if not isinstance(data, dict):
            return {}

        return {
            path: FileSnapshot(
                path=snap_data["path"],
                content_hash=snap_data["content_hash"],
                content=snap_data.get("content"),
                size=snap_data["size"],
                exists=snap_data["exists"]
            )
            for path, snap_data in data.get("files", {}).items()
        }

    # ----------------------------------------------------------------- rollback

    def _restore_snapshot(
        self, rel_path: str, snapshot: FileSnapshot, checkpoint: Checkpoint
    ) -> Tuple[bool, str]:
        """Restore one file from ``snapshot``; return ``(succeeded, detail line)``.

        Stored text is written back directly. An empty string is valid
        content and must not be treated as missing. Without stored text, the
        file is checked out from the checkpoint's git commit. The result is
        verified against the snapshot hash, because the working tree may have
        differed from that commit when the checkpoint was taken.
        Filesystem errors propagate to the caller.
        """
        full_path = self.genesis_root / rel_path

        if snapshot.content is not None:
            full_path.parent.mkdir(parents=True, exist_ok=True)
            full_path.write_text(snapshot.content)
            return True, f"Restored: {rel_path}"

        if not checkpoint.git_commit:
            return False, f"No content available: {rel_path}"

        if not self.git.checkout_file(rel_path, checkpoint.git_commit):
            return False, f"Could not restore: {rel_path}"

        if self._hash_content(full_path.read_text()) != snapshot.content_hash:
            return False, f"Restored from git, but content differs from checkpoint: {rel_path}"

        return True, f"Restored from git: {rel_path}"

    def rollback_to(self, checkpoint_id: str) -> RollbackResult:
        """
        Rollback to a specific checkpoint.

        Every file in the checkpoint is attempted, even after a failure.

        Args:
            checkpoint_id: ID of checkpoint to rollback to

        Returns:
            RollbackResult with status and one detail line per file
        """
        checkpoint = self._find_checkpoint(checkpoint_id)
        if checkpoint is None:
            return RollbackResult(
                success=False,
                message=f"Checkpoint {checkpoint_id} not found"
            )

        files = self._load_checkpoint_files(checkpoint)
        if not files:
            return RollbackResult(
                success=False,
                message="Could not load checkpoint data"
            )

        restored = 0
        failed = 0
        details: List[str] = []

        for path, snapshot in files.items():
            try:
                ok, detail = self._restore_snapshot(path, snapshot, checkpoint)
            except Exception as e:  # one bad file must not stop the rest of the rollback
                ok, detail = False, f"Error restoring {path}: {e}"
            if ok:
                restored += 1
            else:
                failed += 1
            details.append(detail)

        return RollbackResult(
            success=failed == 0,
            message=f"Restored {restored} files" + (f", {failed} failed" if failed else ""),
            files_restored=restored,
            files_failed=failed,
            details=details
        )

    def rollback_to_last(self) -> RollbackResult:
        """Rollback to the most recent checkpoint."""
        if not self.checkpoints:
            return RollbackResult(
                success=False,
                message="No checkpoints available"
            )

        return self.rollback_to(self.checkpoints[-1].checkpoint_id)

    def _relative_key(self, file_path) -> Optional[str]:
        """Convert a user-supplied path into the snapshot key format (relative to genesis_root).

        Relative paths are normalised (``./core/a.py`` -> ``core/a.py``).
        Returns None for an absolute path that is not inside genesis_root.
        """
        candidate = Path(file_path)
        if not candidate.is_absolute():
            return str(candidate)
        for root in (self.genesis_root, self.genesis_root.resolve()):
            try:
                return str(candidate.relative_to(root))
            except ValueError:
                continue
        return None

    def rollback_file(self, file_path: str, checkpoint_id: Optional[str] = None) -> RollbackResult:
        """Rollback a single file.

        Args:
            file_path: Path relative to genesis_root, or an absolute path inside it
            checkpoint_id: Checkpoint to restore from; falsy means the latest one

        Returns:
            RollbackResult for that one file
        """
        if checkpoint_id:
            checkpoint = self._find_checkpoint(checkpoint_id)
        else:
            checkpoint = self.checkpoints[-1] if self.checkpoints else None

        if not checkpoint:
            return RollbackResult(success=False, message="No checkpoint found")

        rel_path = self._relative_key(file_path)
        if rel_path is None:
            return RollbackResult(
                success=False,
                message=f"File {file_path} is outside {self.genesis_root}"
            )

        snapshot = self._load_checkpoint_files(checkpoint).get(rel_path)
        if snapshot is None:
            return RollbackResult(
                success=False,
                message=f"File {rel_path} not found in checkpoint"
            )

        try:
            ok, detail = self._restore_snapshot(rel_path, snapshot, checkpoint)
        except Exception as e:
            return RollbackResult(
                success=False,
                message=f"Error: {e}",
                files_failed=1
            )

        if ok:
            return RollbackResult(
                success=True,
                message=f"Restored {rel_path}",
                files_restored=1
            )
        return RollbackResult(success=False, message=detail, files_failed=1)

    # --------------------------------------------------------------- inspection

    def get_diff(self, checkpoint_id: str) -> Dict[str, str]:
        """Compare the current files with a checkpoint.

        Returns:
            Mapping of relative path to status for every changed file:
            ``"DELETED"``, ``"MODIFIED"``, ``"MODIFIED (no content stored)"``
            (changed, and only a hash is available to restore from), or
            ``"UNREADABLE"``. Unchanged files are omitted. An unknown
            checkpoint yields an empty dict.
        """
        checkpoint = self._find_checkpoint(checkpoint_id)
        if not checkpoint:
            return {}

        diffs: Dict[str, str] = {}
        for path, snapshot in self._load_checkpoint_files(checkpoint).items():
            full_path = self.genesis_root / path

            if not full_path.exists():
                diffs[path] = "DELETED"
                continue

            try:
                current = full_path.read_text()
            except (OSError, ValueError):
                diffs[path] = "UNREADABLE"
                continue

            if self._hash_content(current) != snapshot.content_hash:
                # Empty stored content ("") still counts as stored content
                diffs[path] = (
                    "MODIFIED" if snapshot.content is not None else "MODIFIED (no content stored)"
                )

        return diffs

    def list_checkpoints(self) -> List[Dict]:
        """List all checkpoints, newest first, with git commits shortened to 8 characters."""
        return [
            {
                "id": cp.checkpoint_id,
                "description": cp.description,
                "created_at": cp.created_at,
                "git_commit": cp.git_commit[:8] if cp.git_commit else None
            }
            for cp in reversed(self.checkpoints)
        ]


def main() -> int:
    """Command-line entry point for the rollback system.

    Commands:
        list      Show checkpoints, newest first.
        create    Create a checkpoint (``--description`` optional).
        rollback  Restore ``--file`` only, checkpoint ``--id``, or the latest checkpoint.
        diff      Show files changed since checkpoint ``--id``.
        status    Show checkpoint count, storage directory and latest checkpoint.

    Returns:
        Process exit status: 0 on success; 1 when a rollback fails or the
        requested checkpoint does not exist; 2 when a required option is missing.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Genesis Rollback System")
    parser.add_argument("command", choices=["list", "create", "rollback", "diff", "status"])
    parser.add_argument("--id", help="Checkpoint ID")
    parser.add_argument("--description", "-d", help="Checkpoint description")
    parser.add_argument("--file", help="Specific file for rollback")
    args = parser.parse_args()

    rollback = RollbackSystem()

    if args.command == "list":
        print("Checkpoints:")
        print("=" * 60)
        for cp in rollback.list_checkpoints():
            print(f"  {cp['id']} - {cp['description']}")
            print(f"    Created: {cp['created_at']}")
            if cp['git_commit']:
                print(f"    Git: {cp['git_commit']}")
            print()
        return 0

    if args.command == "create":
        desc = args.description or f"Manual checkpoint at {datetime.now()}"
        cp = rollback.create_checkpoint(desc)
        print(f"Created checkpoint: {cp.checkpoint_id}")
        print(f"Files tracked: {len(cp.files)}")
        return 0

    if args.command == "rollback":
        if args.file:
            result = rollback.rollback_file(args.file, args.id)
        elif args.id:
            result = rollback.rollback_to(args.id)
        else:
            result = rollback.rollback_to_last()

        status = "✓" if result.success else "✗"
        print(f"{status} {result.message}")
        max_shown = 10
        for detail in result.details[:max_shown]:
            print(f"  {detail}")
        if len(result.details) > max_shown:
            print(f"  ... and {len(result.details) - max_shown} more")
        # A non-zero exit lets scripts detect that the rollback did not fully succeed
        return 0 if result.success else 1

    if args.command == "diff":
        if not args.id:
            print("--id required for diff", file=sys.stderr)
            return 2
        # get_diff returns {} for an unknown id, which would read as "no changes"
        if not any(cp.checkpoint_id == args.id for cp in rollback.checkpoints):
            print(f"Checkpoint {args.id} not found", file=sys.stderr)
            return 1
        diffs = rollback.get_diff(args.id)
        if diffs:
            print("Changes since checkpoint:")
            for path, status in diffs.items():
                print(f"  {status}: {path}")
        else:
            print("No changes detected")
        return 0

    if args.command == "status":
        print("Rollback System Status")
        print("=" * 40)
        print(f"Checkpoints: {len(rollback.checkpoints)}")
        print(f"Storage: {rollback.checkpoints_dir}")
        if rollback.checkpoints:
            last = rollback.checkpoints[-1]
            print(f"Last checkpoint: {last.checkpoint_id}")

    return 0


if __name__ == "__main__":
    # Propagate main()'s exit status (None is treated as 0) so callers can detect failures
    raise SystemExit(main())
