#!/usr/bin/env python3
"""
GENESIS CONTEXT WINDOW OPTIMIZER
================================
Smart context compression that maximizes relevant information per token.

Strategies:
    1. Semantic Compression: Keep meaning, reduce tokens
    2. Priority Ordering: Most relevant content first
    3. Chunking: Break large content into digestible pieces
    4. Summarization: Compress history while preserving key points
    5. Dynamic Allocation: Adjust context budget per task type

Usage:
    optimizer = ContextOptimizer(max_tokens=100000)
    optimized = optimizer.optimize(blocks, task_type="code_review")
"""

import hashlib
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
from collections import OrderedDict


@dataclass
class ContentBlock:
    """One unit of context with the metadata the optimizer needs."""
    content: str
    block_type: str  # code, text, data, system, user
    priority: float  # 0-1, higher = more important
    tokens: int  # estimated token cost of `content`
    source: Optional[str] = None
    timestamp: Optional[str] = None
    hash: str = field(default="")

    def __post_init__(self):
        # Derive a short, stable fingerprint of the content unless the
        # caller supplied one explicitly (used for order restoration).
        if self.hash:
            return
        digest = hashlib.md5(self.content.encode()).hexdigest()
        self.hash = digest[:8]


@dataclass
class ContextBudget:
    """Budget allocation for different content types.

    All fields are absolute token counts; the per-type fields are carved
    out of ``total`` by the optimizer's per-task ratio tables.
    """
    total: int  # Overall context window size in tokens
    system: int  # Tokens allotted to system prompts
    user: int  # Tokens allotted to user messages
    code: int  # Tokens allotted to code content
    data: int  # Tokens allotted to data content
    history: int  # Tokens allotted to conversation history
    reserved: int  # Buffer for output


class TokenEstimator:
    """Heuristic token counter based on average characters per token."""

    def __init__(self, chars_per_token: float = 4.0):
        # Average number of characters a single token represents.
        self.chars_per_token = chars_per_token
        # Code tokenizes less densely than prose, so scale it up.
        self.code_multiplier = 1.2

    def estimate(self, text: str, content_type: str = "text") -> int:
        """Return an approximate token count for *text*."""
        if not text:
            return 0

        approx = len(text) / self.chars_per_token
        multiplier = self.code_multiplier if content_type == "code" else 1.0
        return int(approx * multiplier)

    def estimate_blocks(self, blocks: List["ContentBlock"]) -> int:
        """Total the pre-computed token counts across *blocks*."""
        total = 0
        for blk in blocks:
            total += blk.tokens
        return total


class SemanticCompressor:
    """Shrinks content while trying to keep its meaning intact."""

    def __init__(self):
        # (pattern, replacement) pairs applied to every content type.
        self.compression_patterns = [
            # Remove excessive whitespace
            (r'\n\s*\n\s*\n', '\n\n'),
            # Compress indentation
            (r'    ', '  '),
            # Remove trailing whitespace
            (r' +\n', '\n'),
            # Remove empty lines in code
            (r'(\n\s*){3,}', '\n\n'),
        ]

        # Extra pairs applied to code only.
        self.code_patterns = [
            # Remove excessive comments (keep docstrings)
            (r'#[^\n]*\n\s*#[^\n]*\n', '# ...\n'),
            # Compress empty methods
            (r'def \w+\([^)]*\):\s*pass', 'def ...: pass'),
        ]

    def compress(self, content: str, content_type: str = "text",
                target_ratio: float = 0.8) -> str:
        """Compress *content* toward ``target_ratio`` of its original length.

        Pattern-based rewrites are applied first; if the result is still
        longer than the target, it is truncated at sensible boundaries.
        """
        result = content

        # General patterns always run; code patterns are appended for code.
        active = list(self.compression_patterns)
        if content_type == "code":
            active.extend(self.code_patterns)
        for pattern, replacement in active:
            result = re.sub(pattern, replacement, result)

        # Fall back to truncation when patterns were not enough.
        limit = int(len(content) * target_ratio)
        if len(result) > limit:
            result = self._smart_truncate(result, limit, content_type)

        return result

    def _smart_truncate(self, content: str, target_len: int,
                       content_type: str) -> str:
        """Cut *content* to *target_len* at line or sentence boundaries."""
        if len(content) <= target_len:
            return content

        # Code: keep whole lines and mark where the cut happened.
        if content_type == "code":
            kept = []
            used = 0
            for line in content.split('\n'):
                if used + len(line) + 1 > target_len:
                    kept.append("# ... (truncated)")
                    break
                kept.append(line)
                used += len(line) + 1
            return '\n'.join(kept)

        # Text: keep whole sentences and mark the cut with an ellipsis.
        kept = []
        used = 0
        for sentence in re.split(r'(?<=[.!?])\s+', content):
            if used + len(sentence) > target_len:
                kept.append("...")
                break
            kept.append(sentence)
            used += len(sentence) + 1
        return ' '.join(kept)


class PriorityRanker:
    """Ranks content blocks by relevance for the current task."""

    def __init__(self):
        # Baseline priority per block type; unknown types default to 0.5.
        self.type_priorities = {
            "system": 0.9,
            "user": 0.8,
            "code": 0.7,
            "error": 0.85,
            "data": 0.5,
            "history": 0.4,
            "metadata": 0.3
        }

        # Each of these present in the content adds a small boost.
        self.priority_keywords = [
            "error", "exception", "critical", "important",
            "required", "must", "todo", "fixme", "bug"
        ]

    def rank(self, blocks: List["ContentBlock"],
            task_type: Optional[str] = None) -> List["ContentBlock"]:
        """Score each block's priority in place and sort by it.

        Args:
            blocks: Blocks to score; their ``priority`` fields are updated.
            task_type: Optional task hint that boosts matching blocks.

        Returns:
            The same blocks, highest priority first (capped at 1.0).
        """
        for block in blocks:
            score = self.type_priorities.get(block.block_type, 0.5)

            # Boost for each priority keyword found in the content.
            content_lower = block.content.lower()
            keyword_matches = sum(
                1 for kw in self.priority_keywords
                if kw in content_lower
            )
            score += keyword_matches * 0.05

            # Task-specific relevance boosts.
            if task_type:
                if task_type == "code_review" and block.block_type == "code":
                    score += 0.2
                elif task_type == "debugging" and "error" in content_lower:
                    score += 0.3
                elif task_type == "research" and block.block_type == "data":
                    score += 0.2

            # Recency boost for blocks carrying a parseable ISO timestamp.
            if block.timestamp:
                try:
                    dt = datetime.fromisoformat(block.timestamp)
                    age_hours = (datetime.now() - dt).total_seconds() / 3600
                    if age_hours < 1:
                        score += 0.1
                    elif age_hours < 24:
                        score += 0.05
                except (ValueError, TypeError):
                    # Malformed timestamp: skip the recency boost. A bare
                    # except here would also swallow KeyboardInterrupt and
                    # hide unrelated bugs.
                    pass

            block.priority = min(score, 1.0)

        return sorted(blocks, key=lambda b: -b.priority)


class ContextChunker:
    """Splits oversized content into token-bounded chunks."""

    def __init__(self, max_chunk_tokens: int = 2000):
        # Upper bound on estimated tokens per emitted chunk.
        self.max_chunk_tokens = max_chunk_tokens
        self.estimator = TokenEstimator()

    def chunk(self, content: str, content_type: str = "text") -> List[str]:
        """Return *content* as a list of chunks within the token bound."""
        total = self.estimator.estimate(content, content_type)

        if total <= self.max_chunk_tokens:
            return [content]

        # Derive a per-chunk character target from the token estimate.
        pieces = (total // self.max_chunk_tokens) + 1
        target = len(content) // pieces

        splitter = self._chunk_code if content_type == "code" else self._chunk_text
        return splitter(content, target)

    def _chunk_code(self, code: str, chunk_size: int) -> List[str]:
        """Split code, preferring breaks at def/class boundaries."""
        chunks = []
        buffer = []
        size = 0
        boundaries = ('def ', 'class ', 'async def ')

        for line in code.split('\n'):
            # Once over the size target, break at the next definition.
            if size > chunk_size and line.startswith(boundaries):
                if buffer:
                    chunks.append('\n'.join(buffer))
                buffer = [line]
                size = len(line)
            else:
                buffer.append(line)
                size += len(line)

        if buffer:
            chunks.append('\n'.join(buffer))

        return chunks

    def _chunk_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text at paragraph (blank-line) boundaries."""
        chunks = []
        buffer = []
        size = 0

        for para in re.split(r'\n\n+', text):
            if buffer and size + len(para) > chunk_size:
                chunks.append('\n\n'.join(buffer))
                buffer = [para]
                size = len(para)
            else:
                buffer.append(para)
                size += len(para)

        if buffer:
            chunks.append('\n\n'.join(buffer))

        return chunks


class HistorySummarizer:
    """Condenses older conversation messages into a single summary."""

    def __init__(self):
        # Number of trailing messages kept verbatim.
        self.max_recent = 5
        # Target compression ratio for summarized history (informational).
        self.summary_token_ratio = 0.2

    def summarize(self, messages: List[Dict],
                 token_budget: int) -> List[Dict]:
        """Collapse all but the most recent messages into one system note.

        Args:
            messages: Chat messages as ``{"role", "content"}`` dicts.
            token_budget: Accepted for interface compatibility; the
                current heuristic does not consult it.

        Returns:
            A summary message followed by the newest messages, or the
            input unchanged when it is already short enough.
        """
        if len(messages) <= self.max_recent:
            return messages

        older, recent = messages[:-self.max_recent], messages[-self.max_recent:]

        condensed = []
        for entry in older:
            who = entry.get("role", "user")
            text = entry.get("content", "")

            if len(text) > 200:
                # Long message: keep only its opening sentence.
                opener = re.split(r'[.!?]', text)[0]
                condensed.append(f"[{who}]: {opener}...")
            else:
                condensed.append(f"[{who}]: {text}")

        digest = {
            "role": "system",
            # Only the 10 most recent summary lines are retained.
            "content": f"[Conversation summary - {len(older)} earlier messages]:\n" +
                      "\n".join(condensed[-10:])
        }

        return [digest] + recent


class ContextOptimizer:
    """
    Main context optimizer that coordinates all optimization strategies.

    Features:
    - Dynamic budget allocation based on task type
    - Priority-based content selection
    - Semantic compression
    - Smart chunking
    - History summarization
    """

    # Budget fractions per task type:
    # (system, user, code, data, history, reserved).
    _BUDGET_RATIOS = {
        "default": (0.1, 0.3, 0.3, 0.1, 0.1, 0.1),
        "code_review": (0.05, 0.15, 0.5, 0.05, 0.15, 0.1),
        "research": (0.1, 0.2, 0.1, 0.4, 0.1, 0.1),
        "debugging": (0.1, 0.2, 0.35, 0.2, 0.05, 0.1),
    }

    def __init__(self, max_tokens: int = 100000):
        self.max_tokens = max_tokens
        self.estimator = TokenEstimator()
        self.compressor = SemanticCompressor()
        self.ranker = PriorityRanker()
        self.chunker = ContextChunker()
        self.summarizer = HistorySummarizer()

        # Pre-build one ContextBudget per task type from the ratio table.
        self.budgets = {
            task: ContextBudget(
                total=max_tokens,
                system=int(max_tokens * system),
                user=int(max_tokens * user),
                code=int(max_tokens * code),
                data=int(max_tokens * data),
                history=int(max_tokens * history),
                reserved=int(max_tokens * reserved),
            )
            for task, (system, user, code, data, history, reserved)
            in self._BUDGET_RATIOS.items()
        }

    def optimize(
        self,
        blocks: List["ContentBlock"],
        task_type: str = "default",
        preserve_order: bool = False
    ) -> List["ContentBlock"]:
        """
        Optimize content blocks to fit within the context budget.

        Blocks are ranked by priority, then admitted per content type
        until that type's budget is exhausted; near-miss blocks are
        compressed when at least ~30% of their content can be kept.

        Args:
            blocks: Content blocks to optimize. May be mutated: priorities
                are rescored, and compressed blocks get new content/tokens.
            task_type: Type of task (affects budget allocation).
            preserve_order: If True, maintain original order.

        Returns:
            Optimized list of content blocks.
        """
        budget = self.budgets.get(task_type, self.budgets["default"])

        # Rank so the most relevant blocks are admitted first.
        ranked = self.ranker.rank(blocks, task_type)

        # Group blocks by content type.
        by_type: Dict[str, List["ContentBlock"]] = {}
        for block in ranked:
            by_type.setdefault(block.block_type, []).append(block)

        type_budgets = {
            "system": budget.system,
            "user": budget.user,
            "code": budget.code,
            "data": budget.data,
            "history": budget.history
        }

        optimized = []
        for block_type, type_blocks in by_type.items():
            # Unknown block types fall back to the user budget.
            type_budget = type_budgets.get(block_type, budget.user)
            type_used = 0

            for block in type_blocks:
                if type_used + block.tokens > type_budget:
                    remaining = type_budget - type_used
                    # Guard the division: a zero-token block over budget
                    # (can't happen once admits are checked, but be safe).
                    ratio = remaining / block.tokens if block.tokens else 0.0
                    if ratio <= 0.3:
                        continue  # Not worth keeping less than ~30%.

                    block.content = self.compressor.compress(
                        block.content,
                        block.block_type,
                        ratio
                    )
                    block.tokens = self.estimator.estimate(
                        block.content, block.block_type
                    )
                    # Compression is best-effort; re-check before admitting
                    # so an overshoot cannot blow the type budget.
                    if type_used + block.tokens > type_budget:
                        continue

                optimized.append(block)
                type_used += block.tokens

        # Restore the caller's original ordering if requested.
        if preserve_order:
            order_map = {b.hash: i for i, b in enumerate(blocks)}
            optimized.sort(key=lambda b: order_map.get(b.hash, float('inf')))

        return optimized

    def optimize_messages(
        self,
        messages: List[Dict],
        task_type: str = "default"
    ) -> List[Dict]:
        """Optimize a list of chat messages, preserving their order.

        Args:
            messages: Chat messages as ``{"role", "content"}`` dicts.
            task_type: Task hint forwarded to :meth:`optimize`.

        Returns:
            The surviving messages as ``{"role", "content"}`` dicts.
        """
        # Convert messages to blocks.
        blocks = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            # Fenced code implies a code block regardless of role.
            if "```" in content:
                block_type = "code"
            elif role == "system":
                block_type = "system"
            else:
                block_type = "user"

            blocks.append(ContentBlock(
                content=content,
                block_type=block_type,
                priority=0.5,
                tokens=self.estimator.estimate(content),
                source=role
            ))

        optimized_blocks = self.optimize(blocks, task_type, preserve_order=True)

        # Convert surviving blocks back to messages.
        return [
            {"role": b.source or "user", "content": b.content}
            for b in optimized_blocks
        ]

    def get_budget_status(self, task_type: str = "default") -> Dict:
        """Return the token allocation table for *task_type*."""
        budget = self.budgets.get(task_type, self.budgets["default"])
        return {
            "total": budget.total,
            "allocations": {
                "system": budget.system,
                "user": budget.user,
                "code": budget.code,
                "data": budget.data,
                "history": budget.history,
                "reserved": budget.reserved
            }
        }

    def estimate_fit(self, content: str, content_type: str = "text") -> Dict:
        """Estimate whether *content* fits the usable context window.

        Returns:
            Dict with the token estimate, the available budget, whether
            the content fits, the compression ratio needed (0 if none),
            and the fraction of the available budget it would consume.
        """
        tokens = self.estimator.estimate(content, content_type)
        available = self.max_tokens - self.budgets["default"].reserved

        # Guard a degenerate configuration with no usable budget, which
        # would otherwise divide by zero below.
        if available <= 0:
            return {
                "tokens": tokens,
                "available": available,
                "fits": tokens == 0,
                "compression_needed": 1.0 if tokens else 0,
                "percentage_used": float("inf") if tokens else 0,
            }

        return {
            "tokens": tokens,
            "available": available,
            "fits": tokens <= available,
            "compression_needed": max(0, tokens - available) / tokens if tokens > available else 0,
            "percentage_used": tokens / available
        }


def main():
    """Command-line entry point: analyze a file, run the demo, or dump config."""
    import argparse
    arg_parser = argparse.ArgumentParser(description="Genesis Context Optimizer")
    arg_parser.add_argument("--file", help="File to analyze")
    arg_parser.add_argument("--max-tokens", type=int, default=100000,
                            help="Max context tokens")
    arg_parser.add_argument("--task-type", default="default",
                            choices=["default", "code_review", "research", "debugging"])
    arg_parser.add_argument("--demo", action="store_true", help="Run demo")
    opts = arg_parser.parse_args()

    optimizer = ContextOptimizer(max_tokens=opts.max_tokens)

    if opts.file:
        # Report whether the given file fits the context window.
        target = Path(opts.file)
        if target.exists():
            kind = "code" if target.suffix == ".py" else "text"
            fit = optimizer.estimate_fit(target.read_text(), kind)
            print(f"File: {target}")
            print(f"Tokens: {fit['tokens']}")
            print(f"Available: {fit['available']}")
            print(f"Fits: {'Yes' if fit['fits'] else 'No'}")
            if fit['compression_needed'] > 0:
                print(f"Compression needed: {fit['compression_needed']:.0%}")
        return

    if opts.demo:
        print("Context Optimizer Demo")
        print("=" * 40)

        # Sample blocks described as (content, type, priority, tokens).
        demo_specs = [
            ("You are a helpful AI assistant for code review.",
             "system", 0.9, 15),
            ("Please review this Python code for bugs and improvements.",
             "user", 0.8, 12),
            ("""
def calculate(x, y):
    result = x / y
    return result

def process(data):
    for item in data:
        print(item)
""",
             "code", 0.7, 50),
            ("Earlier in the conversation, we discussed error handling.",
             "history", 0.4, 10),
        ]
        demo_blocks = [
            ContentBlock(content=text, block_type=btype, priority=score, tokens=count)
            for text, btype, score, count in demo_specs
        ]

        print("\nOriginal blocks:")
        for blk in demo_blocks:
            print(f"  [{blk.block_type}] {blk.tokens} tokens - priority {blk.priority}")

        optimized = optimizer.optimize(demo_blocks, task_type="code_review")

        print("\nOptimized blocks:")
        for blk in optimized:
            print(f"  [{blk.block_type}] {blk.tokens} tokens - priority {blk.priority:.2f}")

        print("\nBudget status for 'code_review':")
        print(json.dumps(optimizer.get_budget_status("code_review"), indent=2))
        return

    # No flags given: show the configured budget tables.
    print("Context Optimizer Configuration")
    print("=" * 40)
    print(f"Max tokens: {optimizer.max_tokens}")
    print("\nTask type budgets:")
    for task_name in optimizer.budgets:
        status = optimizer.get_budget_status(task_name)
        print(f"\n  {task_name}:")
        for alloc_type, tokens in status["allocations"].items():
            pct = tokens / status["total"] * 100
            print(f"    {alloc_type}: {tokens} ({pct:.0f}%)")


if __name__ == "__main__":
    main()
