#!/usr/bin/env python3
"""
GENESIS CONTINUOUS LEARNING LOOP
=================================
Creates feedback loops that learn from task success/failure patterns
to improve future execution.

Learning Mechanisms:
    1. Pattern Recognition: Identify success/failure patterns
    2. Strategy Adaptation: Adjust approach based on outcomes
    3. Prompt Evolution: Improve prompts that lead to success
    4. Agent Selection Learning: Learn which agents excel at what
    5. Error Prediction: Predict likely failures before they happen

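Usage:
    loop = LearningLoop()
    loop.record_outcome(task_id="t-001", task_type="code_generation",
                        task_title="Add parser", complexity="moderate",
                        agent_used="gemini-flash", success=True,
                        duration=2.5, cost=0.0025)
    recommendations = loop.get_recommendations("code_generation")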
"""

import json
import math
import os
import re
from collections import Counter, defaultdict
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
import hashlib


@dataclass
class TaskOutcome:
    """One task execution result as kept in the outcome history.

    ``to_dict`` output is what gets written (one JSON object per line) to the
    outcome history file, so all fields are kept JSON-compatible.
    """
    # What the task was.
    task_id: str
    task_type: str
    task_title: str
    complexity: str
    # Who ran it and how it went.
    agent_used: str
    success: bool
    duration: float
    cost: float
    # Failure details; both default to None.
    error_type: Optional[str] = field(default=None)
    error_message: Optional[str] = field(default=None)
    # Truncated SHA-256 hex digest of the prompt text, when one was supplied.
    prompt_hash: Optional[str] = field(default=None)
    # Local wall-clock time of creation, ISO-8601 without a UTC offset.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict:
        """Return every field as a plain ``dict`` (via ``dataclasses.asdict``)."""
        as_mapping = asdict(self)
        return as_mapping


@dataclass
class LearningInsight:
    """A human-readable finding derived from analysing recorded outcomes."""
    insight_type: str        # short category label for the finding
    description: str         # what was observed
    confidence: float        # score used to rank insights
    evidence_count: int      # number of outcomes backing the finding
    recommendation: str      # suggested follow-up action
    created_at: str = field(
        default_factory=lambda: datetime.now().isoformat()
    )                        # local creation time, ISO-8601


@dataclass
class Strategy:
    """A learned execution strategy: when it applies and what it prescribes."""
    strategy_id: str
    name: str
    # Conditions under which the strategy should be applied.
    conditions: Dict[str, Any]
    # Actions to take once the conditions hold.
    actions: Dict[str, Any]
    # Running usage statistics.
    success_rate: float = field(default=0.0)
    usage_count: int = field(default=0)
    last_used: Optional[str] = field(default=None)


class PatternRecognizer:
    """Recognizes patterns in task execution history.

    The analysis is stateless: every pattern is recomputed from the outcomes
    passed to ``analyze_outcomes``. ``self.patterns`` is initialised here but
    is not written by any method of this class.
    """

    def __init__(self):
        self.patterns: Dict[str, Dict] = {}

    @staticmethod
    def _success_rate(items: List[Any]) -> float:
        """Return the fraction of *items* with a truthy ``success`` (items must be non-empty)."""
        return sum(1 for o in items if o.success) / len(items)

    def analyze_outcomes(self, outcomes: "List[TaskOutcome]") -> List[Dict]:
        """Analyze outcomes to find patterns.

        Args:
            outcomes: Outcome records. Only ``task_type``, ``agent_used``,
                ``success``, ``timestamp``, ``error_type``, ``complexity`` and
                ``duration`` are read.

        Returns:
            Pattern dicts, each tagged by its ``"pattern"`` key, in this order:
            one ``task_type_performance`` per task type; at most one
            ``temporal_performance`` (the hour of day with the best success
            rate); at most one ``common_errors`` (only if some failed outcome
            has an ``error_type``); one ``complexity_performance`` per
            complexity level. An empty input yields an empty list.
        """
        patterns = self._task_type_patterns(outcomes)

        temporal = self._temporal_pattern(outcomes)
        if temporal is not None:
            patterns.append(temporal)

        errors = self._error_pattern(outcomes)
        if errors is not None:
            patterns.append(errors)

        patterns.extend(self._complexity_patterns(outcomes))
        return patterns

    def _task_type_patterns(self, outcomes) -> List[Dict]:
        """Summarise success rate and best-performing agent for each task type."""
        by_type = defaultdict(list)
        for o in outcomes:
            by_type[o.task_type].append(o)

        patterns = []
        for task_type, type_outcomes in by_type.items():
            by_agent = defaultdict(list)
            for o in type_outcomes:
                by_agent[o.agent_used].append(o)

            # Strict '>' starting from 0: an agent needs at least one success
            # to be named, and on ties the first agent seen is kept.
            best_agent = None
            best_rate = 0
            for agent, agent_outcomes in by_agent.items():
                rate = self._success_rate(agent_outcomes)
                if rate > best_rate:
                    best_rate = rate
                    best_agent = agent

            patterns.append({
                "pattern": "task_type_performance",
                "task_type": task_type,
                "success_rate": self._success_rate(type_outcomes),
                "best_agent": best_agent,
                "best_agent_rate": best_rate,
                "sample_size": len(type_outcomes)
            })
        return patterns

    def _temporal_pattern(self, outcomes) -> Optional[Dict]:
        """Return the hour-of-day with the highest success rate, or None if no timestamp parses."""
        by_hour = defaultdict(list)
        for o in outcomes:
            try:
                dt = datetime.fromisoformat(o.timestamp)
            except (TypeError, ValueError):
                # Missing (None) or malformed timestamp: leave the outcome out
                # of the hourly grouping instead of failing the whole analysis.
                continue
            by_hour[dt.hour].append(o)

        if not by_hour:
            return None

        best_hour = max(by_hour, key=lambda h: self._success_rate(by_hour[h]))
        return {
            "pattern": "temporal_performance",
            "best_hour": best_hour,
            "success_rate": self._success_rate(by_hour[best_hour]),
            "sample_size": len(by_hour[best_hour])
        }

    @staticmethod
    def _error_pattern(outcomes) -> Optional[Dict]:
        """Return the distribution of error types among failures, or None if there are none."""
        errors = [o for o in outcomes if not o.success and o.error_type]
        if not errors:
            return None
        error_counts = Counter(o.error_type for o in errors)
        return {
            "pattern": "common_errors",
            "error_distribution": dict(error_counts),
            "most_common": error_counts.most_common(3)
        }

    def _complexity_patterns(self, outcomes) -> List[Dict]:
        """Summarise success rate and mean duration for each complexity level."""
        by_complexity = defaultdict(list)
        for o in outcomes:
            by_complexity[o.complexity].append(o)

        return [
            {
                "pattern": "complexity_performance",
                "complexity": complexity,
                "success_rate": self._success_rate(comp_outcomes),
                "avg_duration": sum(o.duration for o in comp_outcomes) / len(comp_outcomes),
                "sample_size": len(comp_outcomes)
            }
            for complexity, comp_outcomes in by_complexity.items()
        ]


class PromptEvolver:
    """Evolves prompts based on outcome success.

    Prompts are keyed by a 16-hex-character prefix of their SHA-256 digest.
    The registry maps that key to a dict holding a prompt sample, task type,
    success/failure counts and first-seen time. It is persisted as JSON at
    ``storage_path`` and rewritten after every recorded outcome.
    """

    def __init__(self, storage_path: Optional[Path] = None):
        """Create the evolver and load any previously saved registry.

        Args:
            storage_path: JSON file backing the registry (``str`` or ``Path``).
                Defaults to ``data/prompt_evolution.json`` relative to the
                current working directory.
        """
        self.storage_path = (
            Path(storage_path) if storage_path else Path("data/prompt_evolution.json")
        )
        self.prompt_registry: Dict[str, Dict] = {}
        self._load()

    def _load(self):
        """Load the registry from disk; start empty if it is missing, unreadable or malformed."""
        if not self.storage_path.exists():
            return
        try:
            data = json.loads(self.storage_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable file, bad encoding or invalid JSON (JSONDecodeError and
            # UnicodeDecodeError are ValueError subclasses): keep an empty
            # registry rather than failing construction.
            return
        # Only a JSON object is a usable registry; anything else (e.g. a list)
        # would crash record_prompt later, so ignore it.
        if isinstance(data, dict):
            self.prompt_registry = data

    def _save(self):
        """Persist the registry by writing a temp file and renaming it over the target."""
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.storage_path.with_name(self.storage_path.name + ".tmp")
        tmp_path.write_text(json.dumps(self.prompt_registry, indent=2), encoding="utf-8")
        # os.replace is an atomic rename on POSIX, so a crash mid-write cannot
        # leave a truncated registry that _load would then throw away.
        os.replace(tmp_path, self.storage_path)

    def hash_prompt(self, prompt: str) -> str:
        """Return the first 16 hex characters of the prompt's UTF-8 SHA-256 digest."""
        return hashlib.sha256(prompt.encode()).hexdigest()[:16]

    def record_prompt(self, prompt: str, task_type: str, success: bool):
        """Count one success or failure for *prompt* and persist the registry.

        The first time a prompt is seen, its first 500 characters and
        *task_type* are stored alongside the counters.
        """
        prompt_hash = self.hash_prompt(prompt)

        if prompt_hash not in self.prompt_registry:
            self.prompt_registry[prompt_hash] = {
                "prompt_sample": prompt[:500],  # Store first 500 chars
                "task_type": task_type,
                "successes": 0,
                "failures": 0,
                "first_seen": datetime.now().isoformat()
            }

        counter = "successes" if success else "failures"
        self.prompt_registry[prompt_hash][counter] += 1

        self._save()

    def get_successful_patterns(self, task_type: Optional[str] = None) -> List[Dict]:
        """Return prompts with at least 3 uses and a success rate of 70% or more.

        Args:
            task_type: When given, only prompts recorded for this task type
                are considered.

        Returns:
            Dicts with ``prompt_sample``, ``success_rate`` and ``sample_size``,
            highest success rate first.
        """
        successful = []

        for data in self.prompt_registry.values():
            total = data["successes"] + data["failures"]
            if total < 3:  # Minimum sample size
                continue
            success_rate = data["successes"] / total
            if success_rate < 0.7:  # 70% success threshold
                continue
            if task_type is not None and data["task_type"] != task_type:
                continue
            successful.append({
                "prompt_sample": data["prompt_sample"],
                "success_rate": success_rate,
                "sample_size": total
            })

        return sorted(successful, key=lambda x: -x["success_rate"])

    def suggest_improvements(self, prompt: str, task_type: str) -> List[str]:
        """Suggest improvements based on successful patterns for *task_type*.

        Note: *prompt* is currently not inspected; suggestions depend only on
        the recorded history for the task type.
        """
        successful = self.get_successful_patterns(task_type)
        if not successful:
            return ["No successful patterns found for this task type yet."]

        best = successful[0]
        return [f"Consider patterns from prompts with {best['success_rate']:.0%} success rate"]


class AgentSelector:
    """Keeps per-(task type, agent) tallies and picks the top agent per type.

    ``agent_performance[task_type][agent]`` is a dict with ``successes``,
    ``failures`` and ``total_duration`` keys, created lazily on first access.
    """

    def __init__(self):
        def _fresh_tally():
            return {"successes": 0, "failures": 0, "total_duration": 0}

        self.agent_performance: Dict[str, Dict[str, Dict]] = defaultdict(
            lambda: defaultdict(_fresh_tally)
        )

    def record(self, agent: str, task_type: str, success: bool, duration: float):
        """Add one run of *agent* on *task_type* to the running tallies."""
        tally = self.agent_performance[task_type][agent]
        outcome_key = "successes" if success else "failures"
        tally[outcome_key] += 1
        tally["total_duration"] += duration

    def get_best_agent(self, task_type: str) -> Tuple[Optional[str], float]:
        """Return ``(agent, score)`` for the highest-scoring agent on *task_type*.

        Agents with fewer than two recorded runs are skipped. The score is
        ``success_rate * (1 + 1 / (1 + avg_duration))``: the success rate with
        a bonus for short average durations. ``(None, 0.0)`` is returned when
        the task type is unknown or no agent scores above zero.
        """
        # Test membership first so a lookup never creates an empty entry.
        if task_type not in self.agent_performance:
            return None, 0.0

        winner, top_score = None, 0.0
        for candidate, tally in self.agent_performance[task_type].items():
            runs = tally["successes"] + tally["failures"]
            if runs < 2:  # Minimum samples
                continue
            rate = tally["successes"] / runs
            mean_duration = tally["total_duration"] / runs
            candidate_score = rate * (1 + 1 / (1 + mean_duration))
            # Strict '>' keeps the earliest-recorded agent on ties.
            if candidate_score > top_score:
                winner, top_score = candidate, candidate_score
        return winner, top_score

    def get_recommendations(self) -> Dict[str, str]:
        """Map each known task type to its best agent, omitting types without one."""
        picks = {}
        for task_type in self.agent_performance:
            chosen = self.get_best_agent(task_type)[0]
            if chosen:
                picks[task_type] = chosen
        return picks


class ErrorPredictor:
    """Predicts likely errors before they happen, from remembered failures.

    Failed outcomes are stored per task type. A prediction is produced for
    each distinct error type whose past failures match the queried
    configuration on complexity and/or agent.
    """

    # Canned advice per error type; unknown types get _DEFAULT_MITIGATION.
    _MITIGATIONS: Dict[str, str] = {
        "timeout": "Increase timeout or break task into smaller parts",
        "rate_limit": "Add delay between requests or use different agent",
        "context_overflow": "Reduce context or use agent with larger context window",
        "parse_error": "Add more specific format instructions in prompt",
        "api_error": "Implement retry with exponential backoff",
        "validation_error": "Add input validation before execution"
    }
    _DEFAULT_MITIGATION = "Review error logs and adjust approach"

    def __init__(self):
        self.error_patterns: Dict[str, List[Dict]] = defaultdict(list)

    def learn_from_error(self, outcome: "TaskOutcome"):
        """Remember a failed outcome; successes and failures without an error_type are ignored."""
        if outcome.success or not outcome.error_type:
            return

        self.error_patterns[outcome.task_type].append({
            "complexity": outcome.complexity,
            "agent": outcome.agent_used,
            "error_type": outcome.error_type,
            "error_message": outcome.error_message
        })

    def predict_errors(self, task_type: str, complexity: str, agent: str) -> List[Dict]:
        """Predict likely errors for a task configuration.

        Each remembered failure for *task_type* scores 0.5 for a matching
        complexity plus 0.5 for a matching agent; failures scoring below 0.5
        are ignored.

        Returns:
            One dict per distinct error type with ``error_type``,
            ``likelihood`` (the highest match score among its failures:
            0.5 or 1.0 — a similarity score, not a calibrated probability),
            ``mitigation`` and ``occurrences`` (number of matching failures).
            Sorted by likelihood, highest first; ties keep first-seen order.
        """
        # Membership test avoids creating an empty entry in the defaultdict.
        if task_type not in self.error_patterns:
            return []

        by_error: Dict[str, Dict] = {}
        for pattern in self.error_patterns[task_type]:
            similarity = 0.0
            if pattern["complexity"] == complexity:
                similarity += 0.5
            if pattern["agent"] == agent:
                similarity += 0.5
            if similarity < 0.5:
                continue

            error_type = pattern["error_type"]
            entry = by_error.get(error_type)
            if entry is None:
                by_error[error_type] = {
                    "error_type": error_type,
                    "likelihood": similarity,
                    "mitigation": self._suggest_mitigation(error_type),
                    "occurrences": 1
                }
            else:
                # Repeated failures of the same type are merged rather than
                # emitted again, which previously produced duplicate entries.
                entry["occurrences"] += 1
                entry["likelihood"] = max(entry["likelihood"], similarity)

        # sorted() is stable, so equal likelihoods keep first-seen order.
        return sorted(by_error.values(), key=lambda p: -p["likelihood"])

    def _suggest_mitigation(self, error_type: str) -> str:
        """Return canned mitigation advice for *error_type*, or a generic hint."""
        return self._MITIGATIONS.get(error_type, self._DEFAULT_MITIGATION)


class LearningLoop:
    """
    Main learning loop that coordinates all learning components.

    Implements continuous learning through:
    1. Recording all task outcomes
    2. Recognizing patterns
    3. Evolving strategies (``self.strategies`` is declared but not yet
       populated anywhere in this class)
    4. Generating recommendations

    Persistence: every outcome is appended to ``<storage_dir>/outcomes.jsonl``
    and prompt statistics are kept in ``<storage_dir>/prompts.json``. The
    agent selector and error predictor are in-memory only, so the saved
    outcome history is replayed into them at start-up. Insights are not
    persisted; they are regenerated from the loaded history.
    """

    def __init__(self, storage_dir: Optional[Path] = None):
        """Create the loop, ensure the storage directory exists, and load saved history.

        Args:
            storage_dir: Directory for learning data (``str`` or ``Path``).
                Defaults to ``data/learning`` under the parent of the
                directory containing this file.
        """
        self.storage_dir = (
            Path(storage_dir) if storage_dir
            else Path(__file__).parent.parent / "data" / "learning"
        )
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.outcomes: List[TaskOutcome] = []
        self.strategies: Dict[str, Strategy] = {}
        self.insights: List[LearningInsight] = []

        self.pattern_recognizer = PatternRecognizer()
        self.prompt_evolver = PromptEvolver(self.storage_dir / "prompts.json")
        self.agent_selector = AgentSelector()
        self.error_predictor = ErrorPredictor()

        self._load_history()

    def _load_history(self):
        """Load saved outcomes and replay them into the in-memory learners.

        Lines that are blank, not valid JSON, or whose fields do not match
        ``TaskOutcome`` are skipped.
        """
        history_path = self.storage_dir / "outcomes.jsonl"
        if not history_path.exists():
            return

        # errors="replace" turns undecodable bytes into U+FFFD so a damaged
        # line fails JSON parsing (and is skipped) instead of aborting the load.
        with history_path.open(encoding="utf-8", errors="replace") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    outcome = TaskOutcome(**json.loads(line))
                except (ValueError, TypeError):
                    # ValueError: invalid JSON; TypeError: not an object, or
                    # missing/unexpected fields for TaskOutcome.
                    continue
                self.outcomes.append(outcome)
                # Without this replay the selector and predictor start empty
                # after every restart and no agent would be recommended.
                self._learn(outcome)

        if self.outcomes:
            self._generate_insights()

    def _save_outcome(self, outcome: TaskOutcome):
        """Append *outcome* as one JSON line to the history file."""
        history_path = self.storage_dir / "outcomes.jsonl"
        with open(history_path, 'a', encoding="utf-8") as f:
            f.write(json.dumps(outcome.to_dict()) + '\n')

    def _learn(self, outcome: TaskOutcome):
        """Feed one outcome into the in-memory learners (agent selector and error predictor)."""
        self.agent_selector.record(
            outcome.agent_used, outcome.task_type, outcome.success, outcome.duration
        )
        if not outcome.success and outcome.error_type:
            self.error_predictor.learn_from_error(outcome)

    def record_outcome(
        self,
        task_id: str,
        task_type: str,
        task_title: str,
        complexity: str,
        agent_used: str,
        success: bool,
        duration: float,
        cost: float,
        error_type: Optional[str] = None,
        error_message: Optional[str] = None,
        prompt: Optional[str] = None
    ):
        """Record a task outcome, persist it, and update every learner.

        Args:
            prompt: Optional prompt text. When given, its hash is stored on
                the outcome and the prompt evolver's statistics are updated.
            All other arguments populate the corresponding ``TaskOutcome``
            fields.
        """
        outcome = TaskOutcome(
            task_id=task_id,
            task_type=task_type,
            task_title=task_title,
            complexity=complexity,
            agent_used=agent_used,
            success=success,
            duration=duration,
            cost=cost,
            error_type=error_type,
            error_message=error_message,
            prompt_hash=self.prompt_evolver.hash_prompt(prompt) if prompt else None
        )

        self.outcomes.append(outcome)
        self._save_outcome(outcome)
        self._learn(outcome)

        if prompt:
            self.prompt_evolver.record_prompt(prompt, task_type, success)

        # Re-derive insights every 10 outcomes (the count includes loaded history).
        if len(self.outcomes) % 10 == 0:
            self._generate_insights()

    def _generate_insights(self):
        """Generate insights from the most recent 100 outcomes.

        Currently flags task types with a success rate below 50% over at
        least 5 samples. Because this runs repeatedly over overlapping
        windows, an insight with the same type and description replaces the
        earlier one instead of being appended again.
        """
        recent = self.outcomes[-100:]  # Last 100 outcomes
        patterns = self.pattern_recognizer.analyze_outcomes(recent)

        for pattern in patterns:
            if pattern["pattern"] != "task_type_performance":
                continue
            if pattern["success_rate"] >= 0.5 or pattern["sample_size"] < 5:
                continue

            best_agent = pattern["best_agent"]
            if best_agent:
                recommendation = f"Consider using {best_agent} for these tasks"
            else:
                # No agent succeeded even once; naming one would print "None".
                recommendation = (
                    "No agent has succeeded at these tasks yet; "
                    "review the task definition or try a different agent"
                )

            self._add_insight(LearningInsight(
                insight_type="low_success_rate",
                description=f"Task type '{pattern['task_type']}' has low success rate",
                confidence=min(pattern["sample_size"] / 10, 1.0),
                evidence_count=pattern["sample_size"],
                recommendation=recommendation
            ))

    def _add_insight(self, insight: LearningInsight):
        """Store *insight*, replacing any existing one with the same type and description."""
        key = (insight.insight_type, insight.description)
        for index, existing in enumerate(self.insights):
            if (existing.insight_type, existing.description) == key:
                self.insights[index] = insight
                return
        self.insights.append(insight)

    def get_recommendations(
        self,
        task_type: str,
        complexity: str = "moderate"
    ) -> Dict:
        """Get recommendations for a new task.

        Returns:
            A dict with ``agent`` and ``agent_confidence`` (the agent
            selector's score: success rate with a speed bonus, so it may
            exceed 1.0), ``predicted_errors``, ``success_probability`` and
            ``estimated_duration`` (historical averages for the task type;
            0.5 and None when there is no history), ``prompt_suggestions``
            and ``warnings``.
        """
        recommendations = {
            "agent": None,
            "agent_confidence": 0.0,
            "predicted_errors": [],
            "success_probability": 0.5,
            "estimated_duration": None,
            "prompt_suggestions": [],
            "warnings": []
        }

        # Agent recommendation
        agent, confidence = self.agent_selector.get_best_agent(task_type)
        if agent:
            recommendations["agent"] = agent
            recommendations["agent_confidence"] = confidence

            # Error predictions are tied to the recommended agent.
            predictions = self.error_predictor.predict_errors(task_type, complexity, agent)
            recommendations["predicted_errors"] = predictions

            high_risk = [p for p in predictions if p["likelihood"] >= 0.7]
            if high_risk:
                recommendations["warnings"].append(
                    f"High risk of {high_risk[0]['error_type']}"
                )

        # Success probability based on history
        type_outcomes = [o for o in self.outcomes if o.task_type == task_type]
        if type_outcomes:
            recommendations["success_probability"] = (
                sum(1 for o in type_outcomes if o.success) / len(type_outcomes)
            )
            recommendations["estimated_duration"] = (
                sum(o.duration for o in type_outcomes) / len(type_outcomes)
            )

        # Prompt suggestions
        recommendations["prompt_suggestions"] = self.prompt_evolver.suggest_improvements(
            "", task_type
        )

        return recommendations

    def get_insights(self, limit: int = 10) -> List[LearningInsight]:
        """Return up to *limit* insights, highest confidence first."""
        return sorted(
            self.insights,
            key=lambda x: x.confidence,
            reverse=True
        )[:limit]

    def get_summary(self) -> Dict:
        """Get learning loop summary.

        Returns ``{"status": "no_data", ...}`` when nothing has been
        recorded; otherwise totals, the success rate over the last 100
        outcomes, insight count, per-task-type agent picks, the number of
        successful prompt patterns, and a warming-up/active status.
        """
        if not self.outcomes:
            return {"status": "no_data", "message": "No outcomes recorded yet"}

        recent = self.outcomes[-100:]
        success_rate = sum(1 for o in recent if o.success) / len(recent)

        return {
            "total_outcomes": len(self.outcomes),
            "recent_success_rate": success_rate,
            "insights_generated": len(self.insights),
            "agent_recommendations": self.agent_selector.get_recommendations(),
            "successful_patterns": len(self.prompt_evolver.get_successful_patterns()),
            "learning_status": "active" if len(self.outcomes) > 10 else "warming_up"
        }


def main():
    """Command-line entry point for inspecting or demoing the learning loop.

    Runs the first matching action among ``--summary``, ``--insights``,
    ``--recommend`` and ``--demo``; with no action flag a short status is
    printed. ``--storage-dir`` selects the data directory (default: the
    loop's built-in location).
    """
    import argparse
    parser = argparse.ArgumentParser(description="Genesis Learning Loop")
    parser.add_argument("--summary", action="store_true", help="Show learning summary")
    parser.add_argument("--insights", action="store_true", help="Show insights")
    parser.add_argument("--recommend", help="Get recommendations for task type")
    parser.add_argument("--demo", action="store_true", help="Run demo with sample data")
    parser.add_argument(
        "--storage-dir",
        type=Path,
        default=None,
        help="Directory holding learning data (default: LearningLoop's built-in location)",
    )
    args = parser.parse_args()

    loop = LearningLoop(args.storage_dir)

    if args.summary:
        print("Learning Loop Summary:")
        print(json.dumps(loop.get_summary(), indent=2))
        return

    if args.insights:
        print("Recent Insights:")
        for insight in loop.get_insights():
            print(f"\n[{insight.insight_type}] {insight.description}")
            print(f"  Confidence: {insight.confidence:.0%}")
            print(f"  Recommendation: {insight.recommendation}")
        return

    if args.recommend:
        print(f"Recommendations for '{args.recommend}':")
        print(json.dumps(loop.get_recommendations(args.recommend), indent=2))
        return

    if args.demo:
        print("Running demo with sample outcomes...")

        # NOTE: demo outcomes are appended to the history in the storage
        # directory exactly like real ones; pass --storage-dir with a scratch
        # directory to keep them out of real learning data.
        sample_data = [
            ("code_generation", "gemini-flash", True, 2.5),
            ("code_generation", "gemini-flash", True, 3.0),
            ("code_generation", "claude-sonnet", True, 4.5),
            ("architecture", "claude-opus", True, 10.0),
            ("architecture", "gemini-flash", False, 5.0),
            ("architecture", "claude-opus", True, 8.0),
            ("monitoring", "aiva-qwen", True, 0.5),
            ("monitoring", "aiva-qwen", True, 0.3),
            ("research", "gemini-pro", True, 15.0),
            ("research", "gemini-pro", True, 12.0),
        ]

        for i, (task_type, agent, success, duration) in enumerate(sample_data):
            loop.record_outcome(
                task_id=f"demo-{i:03d}",
                task_type=task_type,
                task_title=f"Demo {task_type} task",
                complexity="moderate",
                agent_used=agent,
                success=success,
                duration=duration,
                cost=duration * 0.001
            )

        print("\nSummary after demo:")
        print(json.dumps(loop.get_summary(), indent=2))

        print("\nRecommendations for 'architecture':")
        print(json.dumps(loop.get_recommendations("architecture"), indent=2))
        return

    # Default: show status
    print("Genesis Learning Loop")
    print("=" * 40)
    summary = loop.get_summary()
    print(f"Status: {summary.get('learning_status', 'unknown')}")
    print(f"Total outcomes: {summary.get('total_outcomes', 0)}")
    recent_rate = summary.get('recent_success_rate')
    # Compare against None: a genuine 0% success rate must still be shown.
    if recent_rate is not None:
        print(f"Recent success rate: {recent_rate:.0%}")


if __name__ == "__main__":
    main()
