"""
PM-006: Failure Learning Accumulator
Capture failure lessons for transfer between fresh sessions in Genesis.

Acceptance Criteria:
- [x] GIVEN failure WHEN captured THEN lesson extracted
- [x] AND lesson has: attempt_number, error, hypothesized_fix
- [x] AND stored discretely (NO context accumulation)

Dependencies: None
"""

import hashlib
import json
import logging
import os
from collections import Counter
from dataclasses import dataclass, asdict, field, fields
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, Dict, Any, List

try:
    import redis
except ImportError:
    redis = None

logger = logging.getLogger(__name__)


@dataclass
class FailureLesson:
    """A discrete lesson learned from a failure.

    Instances are serialized to JSON for discrete storage (Redis or memory)
    and rendered via to_prompt_context() for injection into fresh sessions.
    """
    lesson_id: str       # unique identifier for this lesson
    task_id: str         # task the failed attempt belonged to
    attempt_number: int  # which attempt (1-based) produced the failure
    tier: int            # execution tier (1, 2, or 3)

    # Error details
    error_type: str                      # exception class name
    error_message: str                   # error text (truncated by the capturer)
    error_context: Optional[str] = None  # extra context about the error, if any

    # Hypothesized fix
    hypothesized_fix: str = ""
    fix_approach: Optional[str] = None

    # Metadata
    model: str = ""
    # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated (3.12+)
    # and produced a naive datetime. The ISO string now carries "+00:00".
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    duration_ms: Optional[int] = None
    tokens_used: Optional[int] = None

    # Learning categorization
    category: str = "general"  # syntax, logic, api, timeout, etc.
    severity: str = "medium"  # low, medium, high, critical

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of the lesson."""
        return asdict(self)

    def to_json(self) -> str:
        """Serialize the lesson to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FailureLesson":
        """Build a lesson from a dict.

        Unknown keys are ignored so records written by a newer/older version
        of this schema still load instead of raising TypeError.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})

    @classmethod
    def from_json(cls, json_str: str) -> "FailureLesson":
        """Deserialize a lesson from a JSON string."""
        return cls.from_dict(json.loads(json_str))

    def to_prompt_context(self) -> str:
        """
        Convert lesson to a discrete prompt context string.
        This is passed to fresh sessions, NOT appended to history.
        """
        context = f"""PREVIOUS ATTEMPT #{self.attempt_number} FAILED:
Error Type: {self.error_type}
Error: {self.error_message}
Hypothesized Fix: {self.hypothesized_fix}"""

        if self.fix_approach:
            context += f"\nSuggested Approach: {self.fix_approach}"

        return context

class LearningAccumulator:
    """
    Accumulate and manage failure lessons.

    Key Principle: Store lessons DISCRETELY for transfer to fresh sessions.
    Do NOT accumulate context - each session starts fresh with only lessons.

    Features:
    - Extract structured lessons from failures
    - Categorize errors for pattern detection
    - Generate hypothesized fixes
    - Format lessons for prompt injection
    """

    REDIS_KEY_PREFIX = "genesis:lessons:"
    MAX_LESSONS_PER_TASK = 40  # Max across all tiers (20+10+10)

    # Ordered (category, keywords) rules for categorize_error. The first
    # category whose keyword occurs in the error text wins, so order matters.
    _CATEGORY_RULES = (
        ("syntax", ("syntax", "indent", "parse")),
        ("timeout", ("timeout", "timed out")),
        ("api", ("api", "rate limit", "quota")),
        ("import", ("import", "module", "not found")),
        ("type", ("type", "attribute")),
        ("test", ("assertion", "test", "expect")),
        ("resource", ("memory", "overflow", "context")),
        ("permission", ("permission", "access", "denied")),
    )

    # Base fix suggestion per category (see hypothesize_fix).
    _CATEGORY_FIXES = {
        "syntax": "Check indentation, brackets, and quotes. Verify Python syntax.",
        "timeout": "Reduce task complexity. Consider breaking into smaller steps.",
        "api": "Check API credentials and rate limits. Add retry with backoff.",
        "import": "Verify module is installed. Check import path and spelling.",
        "type": "Check variable types. Ensure correct attribute access.",
        "test": "Review test assertions. Check expected vs actual values.",
        "resource": "Reduce context size. Summarize or truncate inputs.",
        "permission": "Check file/API permissions. Verify access credentials.",
        "general": "Review the error message. Check related code sections."
    }

    # Severity per category; anything unlisted defaults to "medium".
    _CATEGORY_SEVERITY = {
        "syntax": "low",        # easy to fix
        "import": "low",        # easy to fix
        "type": "medium",
        "test": "medium",
        "timeout": "high",
        "resource": "high",
        "api": "critical",
        "permission": "critical",
    }

    def __init__(self, redis_url: Optional[str] = None):
        """
        Initialize LearningAccumulator.

        Args:
            redis_url: Redis connection URL. Defaults to env REDIS_URL.
        """
        self.redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6379")
        self._redis: Optional[Any] = None

        # In-memory fallback storage (task_id -> lessons). Written alongside
        # Redis so lessons survive in-process even when Redis is down.
        self._memory_store: Dict[str, List["FailureLesson"]] = {}

    @property
    def redis_client(self) -> Optional[Any]:
        """Lazily create the Redis client; returns None if unavailable."""
        if self._redis is None and redis is not None:
            try:
                client = redis.from_url(self.redis_url)
                client.ping()  # fail fast so we fall back to memory storage
                self._redis = client
            except Exception as e:
                logger.warning("Redis connection failed: %s", e)
                self._redis = None
        return self._redis

    def _generate_lesson_id(self, task_id: str, attempt: int) -> str:
        """Generate a unique lesson ID: lesson_<task>_<attempt>_<hash8>."""
        # Timezone-aware now() replaces the deprecated datetime.utcnow();
        # the strftime output format is unchanged.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f")
        hash_input = f"{task_id}:{attempt}:{timestamp}"
        # MD5 only derives a short non-cryptographic ID suffix here.
        short_hash = hashlib.md5(hash_input.encode()).hexdigest()[:8]
        return f"lesson_{task_id[:20]}_{attempt}_{short_hash}"

    def categorize_error(self, error_type: str, error_message: str) -> str:
        """
        Categorize an error for pattern detection.

        Args:
            error_type: Exception type name
            error_message: Error message

        Returns:
            Category string ("general" if no rule matches)
        """
        error_lower = (error_type + error_message).lower()
        for category, keywords in self._CATEGORY_RULES:
            if any(kw in error_lower for kw in keywords):
                return category
        return "general"

    def hypothesize_fix(self,
                       error_type: str,
                       error_message: str,
                       category: str) -> str:
        """
        Generate a hypothesized fix based on error.

        Args:
            error_type: Exception type (currently unused; kept for API stability)
            error_message: Error message
            category: Error category

        Returns:
            Hypothesized fix string
        """
        base_fix = self._CATEGORY_FIXES.get(category, self._CATEGORY_FIXES["general"])

        # Append targeted hints triggered by common phrases in the message.
        message_lower = error_message.lower()
        if "undefined" in message_lower:
            base_fix += " Check variable definitions before use."
        if "null" in message_lower or "none" in message_lower:
            base_fix += " Add null/None checks."

        return base_fix

    def capture_failure(self,
                       task_id: str,
                       attempt_number: int,
                       tier: int,
                       error_type: str,
                       error_message: str,
                       model: str = "",
                       error_context: Optional[str] = None,
                       fix_approach: Optional[str] = None,
                       duration_ms: Optional[int] = None,
                       tokens_used: Optional[int] = None) -> "FailureLesson":
        """
        Capture a failure as a discrete lesson.

        Args:
            task_id: Task identifier
            attempt_number: Which attempt this was
            tier: Execution tier (1, 2, or 3)
            error_type: Type of error (exception name)
            error_message: Error message
            model: Model that generated the failure
            error_context: Additional context about the error
            fix_approach: Specific fix approach to try
            duration_ms: Duration of failed attempt
            tokens_used: Tokens used in failed attempt

        Returns:
            FailureLesson object (already persisted via _store_lesson)
        """
        category = self.categorize_error(error_type, error_message)
        hypothesized_fix = self.hypothesize_fix(error_type, error_message, category)

        lesson = FailureLesson(
            lesson_id=self._generate_lesson_id(task_id, attempt_number),
            task_id=task_id,
            attempt_number=attempt_number,
            tier=tier,
            error_type=error_type,
            error_message=error_message[:500],  # Truncate long messages
            error_context=error_context[:200] if error_context else None,
            hypothesized_fix=hypothesized_fix,
            fix_approach=fix_approach,
            model=model,
            category=category,
            severity=self._assess_severity(error_type, category),
            duration_ms=duration_ms,
            tokens_used=tokens_used
        )

        # Store the lesson
        self._store_lesson(task_id, lesson)

        logger.info(
            "Captured lesson %s: %s error at attempt %s",
            lesson.lesson_id, category, attempt_number
        )

        return lesson

    def _assess_severity(self, error_type: str, category: str) -> str:
        """Assess error severity from category (error_type kept for API)."""
        return self._CATEGORY_SEVERITY.get(category, "medium")

    def _store_lesson(self, task_id: str, lesson: "FailureLesson") -> None:
        """Persist a lesson to Redis (if available) and the memory store."""
        if self.redis_client:
            try:
                key = f"{self.REDIS_KEY_PREFIX}{task_id}"
                self.redis_client.rpush(key, lesson.to_json())
                # Keep only the newest MAX_LESSONS_PER_TASK entries.
                self.redis_client.ltrim(key, -self.MAX_LESSONS_PER_TASK, -1)
            except Exception as e:
                logger.warning("Failed to store lesson in Redis: %s", e)

        # Also store in memory as the same-process fallback.
        bucket = self._memory_store.setdefault(task_id, [])
        bucket.append(lesson)
        if len(bucket) > self.MAX_LESSONS_PER_TASK:
            del bucket[:-self.MAX_LESSONS_PER_TASK]

    def get_lessons(self, task_id: str) -> List["FailureLesson"]:
        """
        Get all lessons for a task.

        Args:
            task_id: Task identifier

        Returns:
            List of FailureLesson objects, from Redis when reachable,
            otherwise from the in-memory store.
        """
        if self.redis_client:
            try:
                key = f"{self.REDIS_KEY_PREFIX}{task_id}"
                raw_lessons = self.redis_client.lrange(key, 0, -1)
                return [
                    FailureLesson.from_json(
                        raw.decode() if isinstance(raw, bytes) else raw
                    )
                    for raw in raw_lessons
                ]
            except Exception as e:
                logger.warning("Failed to get lessons from Redis: %s", e)

        # Fall back to memory
        return self._memory_store.get(task_id, [])

    def get_recent_lessons(self, task_id: str, count: int = 5) -> List["FailureLesson"]:
        """Get the most recent ``count`` lessons for a task.

        Bug fix: ``count <= 0`` now returns []. The previous
        ``lessons[-count:]`` evaluated ``lessons[-0:]`` for count=0 and
        returned ALL lessons instead of none.
        """
        if count <= 0:
            return []
        return self.get_lessons(task_id)[-count:]

    def format_lessons_for_prompt(self,
                                 task_id: str,
                                 max_lessons: int = 5) -> str:
        """
        Format lessons as discrete prompt context.

        This is injected into fresh session prompts, NOT accumulated in history.

        Args:
            task_id: Task identifier
            max_lessons: Maximum number of lessons to include

        Returns:
            Formatted prompt context string ("" when there are no lessons)
        """
        lessons = self.get_recent_lessons(task_id, max_lessons)

        if not lessons:
            return ""

        prompt_parts = [
            "## LEARNINGS FROM PREVIOUS ATTEMPTS",
            "The following attempts have been made. Learn from these failures:",
            ""
        ]

        for lesson in lessons:
            prompt_parts.append(lesson.to_prompt_context())
            prompt_parts.append("")

        prompt_parts.extend([
            "## INSTRUCTIONS",
            "- Do NOT repeat the same mistakes",
            "- Apply the hypothesized fixes",
            "- Try a DIFFERENT approach if similar approaches have failed",
            ""
        ])

        return "\n".join(prompt_parts)

    def clear_lessons(self, task_id: str) -> None:
        """Clear all lessons for a task from both Redis and memory."""
        if self.redis_client:
            try:
                self.redis_client.delete(f"{self.REDIS_KEY_PREFIX}{task_id}")
            except Exception as e:
                logger.warning("Failed to clear lessons from Redis: %s", e)

        self._memory_store.pop(task_id, None)

    def get_statistics(self, task_id: str) -> Dict[str, Any]:
        """Get aggregate statistics about the lessons stored for a task."""
        lessons = self.get_lessons(task_id)

        if not lessons:
            return {"total_lessons": 0}

        categories = Counter(lesson.category for lesson in lessons)
        tiers = Counter(lesson.tier for lesson in lessons)

        return {
            "total_lessons": len(lessons),
            "by_category": dict(categories),
            "by_tier": {tier: tiers.get(tier, 0) for tier in (1, 2, 3)},
            # lessons is non-empty here, so categories has at least one entry.
            "most_common_error": categories.most_common(1)[0][0]
        }


# Process-wide singleton, created lazily by get_learning_accumulator().
_accumulator: Optional[LearningAccumulator] = None


def get_learning_accumulator() -> LearningAccumulator:
    """Return the global LearningAccumulator, creating it on first use."""
    global _accumulator
    if _accumulator is not None:
        return _accumulator
    _accumulator = LearningAccumulator()
    return _accumulator


if __name__ == "__main__":
    # Smoke-test the LearningAccumulator end to end.
    logging.basicConfig(level=logging.INFO)

    acc = LearningAccumulator()

    # (attempt_number, error_type, error_message) samples to capture.
    demo_failures = [
        (1, "SyntaxError", "invalid syntax on line 42"),
        (2, "TypeError", "'NoneType' object has no attribute 'get'"),
    ]
    for attempt, err_type, err_msg in demo_failures:
        acc.capture_failure(
            task_id="test-task-001",
            attempt_number=attempt,
            tier=1,
            error_type=err_type,
            error_message=err_msg,
            model="gemini-flash",
        )

    # Show the prompt-ready lesson block and aggregate stats.
    print("Formatted Lessons for Prompt:")
    print(acc.format_lessons_for_prompt("test-task-001"))

    print("\nStatistics:")
    print(json.dumps(acc.get_statistics("test-task-001"), indent=2))
