"""
AIVA Reflection Loop - PM-031

5-minute consolidation loop for learning and memory updates.
Consolidates recent learnings and updates embeddings.
"""

import asyncio
import hashlib
import json
import logging
import os
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

# Configure logging at import time so the module works when run standalone.
# NOTE(review): this configures the *root* logger as an import side effect;
# confirm no embedding application needs to own logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def _utc_now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    # datetime.utcnow() is deprecated and returns a *naive* datetime;
    # timezone-aware timestamps serialize with an explicit "+00:00" offset.
    return datetime.now(timezone.utc).isoformat()


@dataclass
class ReflectionEntry:
    """Record of one completed reflection cycle (serialized to JSONL logs)."""
    reflection_id: str          # unique id, "ref_<epoch-millis>"
    timestamp: str              # ISO-8601 UTC completion time
    learnings: List[str]        # learning strings gathered this cycle
    tasks_reviewed: int         # pending items reviewed this cycle
    entities_updated: int       # entities stored via the memory bridge
    embeddings_generated: int   # embeddings produced this cycle
    duration_ms: float          # wall-clock duration of the cycle


class ReflectionLoop:
    """
    5-minute reflection loop for AIVA.

    Consolidates learnings, updates PostgreSQL, and generates embeddings.

    Usage:
        loop = ReflectionLoop(memory_bridge)
        await loop.start()  # Runs every 5 minutes
    """

    def __init__(
        self,
        memory_bridge=None,
        interval_seconds: int = 300,  # 5 minutes
        log_dir: str = "logs"
    ):
        """
        Initialize the reflection loop.

        Args:
            memory_bridge: MemoryBridge for storage operations. When None,
                entity consolidation, embedding updates, and persistence
                become no-ops.
            interval_seconds: How often to run (default 5 min).
            log_dir: Directory for reflection logs (created if missing).
        """
        self.memory_bridge = memory_bridge
        self.interval_seconds = interval_seconds
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self._running = False                                # loop enabled flag
        self._task: Optional[asyncio.Task] = None            # background loop task
        self.reflection_history: List[ReflectionEntry] = []  # completed cycles
        self.pending_learnings: List[Dict] = []              # items queued for next cycle

        logger.info(f"ReflectionLoop initialized: interval={interval_seconds}s")

    async def start(self) -> None:
        """Start the background reflection loop (idempotent)."""
        if self._running:
            logger.warning("Reflection loop already running")
            return

        self._running = True
        self._task = asyncio.create_task(self._loop())
        logger.info("Reflection loop started")

    async def stop(self) -> None:
        """Stop the reflection loop and wait for the task to finish."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass  # expected when the task is cancelled
        logger.info("Reflection loop stopped")

    async def _loop(self) -> None:
        """Main loop: reflect, then sleep for the configured interval."""
        while self._running:
            try:
                await self.reflect()
                await asyncio.sleep(self.interval_seconds)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; back off briefly before retrying.
                # logger.exception records the traceback, unlike logger.error.
                logger.exception(f"Reflection loop error: {e}")
                await asyncio.sleep(60)

    async def reflect(self) -> ReflectionEntry:
        """
        Execute a single reflection cycle.

        Gathers pending learnings, consolidates them into entities, updates
        embeddings, and persists the reflection. Failures are logged and
        swallowed so the periodic loop keeps running; a partially-filled
        entry is still recorded and appended to the history/JSONL log.

        Returns:
            ReflectionEntry with the cycle's results.
        """
        start_time = time.time()
        reflection_id = f"ref_{int(time.time() * 1000)}"

        logger.info(f"Starting reflection cycle {reflection_id}")

        learnings: List[str] = []
        tasks_reviewed = 0
        entities_updated = 0
        embeddings_generated = 0

        try:
            # 1. Gather recent learnings
            learnings = await self._gather_learnings()
            tasks_reviewed = len(self.pending_learnings)

            # 2. Consolidate into entities
            entities_updated = await self._consolidate_entities(learnings)

            # 3. Update Qdrant embeddings
            embeddings_generated = await self._update_embeddings(learnings)

            # 4. Write to PostgreSQL
            await self._persist_reflection(reflection_id, learnings)

            # Clear pending learnings only after a fully successful cycle,
            # so a failed cycle retries the same items next time.
            self.pending_learnings.clear()

        except Exception as e:
            logger.exception(f"Reflection failed: {e}")

        duration = (time.time() - start_time) * 1000

        entry = ReflectionEntry(
            reflection_id=reflection_id,
            timestamp=_utc_now_iso(),
            learnings=learnings,
            tasks_reviewed=tasks_reviewed,
            entities_updated=entities_updated,
            embeddings_generated=embeddings_generated,
            duration_ms=duration
        )

        self.reflection_history.append(entry)
        self._log_reflection(entry)

        logger.info(
            f"Reflection {reflection_id} complete: "
            f"{len(learnings)} learnings, {entities_updated} entities, "
            f"{embeddings_generated} embeddings in {duration:.0f}ms"
        )

        return entry

    async def _gather_learnings(self) -> List[str]:
        """Convert pending items into learning strings.

        Items with a "learning" key pass through unchanged; items with an
        "error" key are rephrased as error learnings. Items with neither
        key are ignored.
        """
        learnings = []
        for item in self.pending_learnings:
            if "learning" in item:
                learnings.append(item["learning"])
            elif "error" in item:
                learnings.append(f"Error encountered: {item['error']}")
        return learnings

    async def _consolidate_entities(self, learnings: List[str]) -> int:
        """Store each learning as an entity via the memory bridge.

        Returns:
            Number of entities stored; 0 when there is no bridge, no
            learnings, or the bridge raised.
        """
        if not self.memory_bridge or not learnings:
            return 0

        try:
            for learning in learnings:
                # Content-derived id so identical learnings map to the same
                # entity across runs (builtin str hash is salted per process
                # and is NOT stable between interpreter invocations).
                digest = hashlib.sha1(learning.encode("utf-8")).hexdigest()[:10]
                self.memory_bridge.store_memory(
                    content={"learning": learning, "type": "reflection"},
                    memory_type="entity",
                    metadata={
                        "entity_type": "learning",
                        "name": "Learning from reflection",
                        "id": f"learn_{digest}"
                    }
                )
            return len(learnings)
        except Exception as e:
            logger.exception(f"Entity consolidation failed: {e}")
            return 0

    async def _update_embeddings(self, learnings: List[str]) -> int:
        """Update Qdrant with new embeddings.

        Placeholder: in production this would generate embeddings via an
        embedding model and store them in Qdrant. Currently it only reports
        how many embeddings would be generated.
        """
        if not self.memory_bridge or not learnings:
            return 0
        return len(learnings)

    async def _persist_reflection(self, reflection_id: str, learnings: List[str]) -> None:
        """Persist the reflection summary via the memory bridge (best-effort)."""
        if not self.memory_bridge:
            return

        try:
            self.memory_bridge.store_memory(
                content={
                    "reflection_id": reflection_id,
                    "learnings": learnings,
                    "timestamp": _utc_now_iso()
                },
                memory_type="entity",
                metadata={
                    "entity_type": "reflection",
                    "name": f"Reflection {reflection_id}"
                }
            )
        except Exception as e:
            logger.exception(f"Failed to persist reflection: {e}")

    def _log_reflection(self, entry: ReflectionEntry) -> None:
        """Append the entry as one JSON line to <log_dir>/reflection_loop.jsonl."""
        log_file = self.log_dir / "reflection_loop.jsonl"
        try:
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(asdict(entry)) + "\n")
        except Exception as e:
            logger.exception(f"Failed to log reflection: {e}")

    def add_learning(self, learning: str, context: Optional[Dict] = None) -> None:
        """
        Queue a learning for the next reflection cycle.

        Args:
            learning: The learning to add
            context: Additional context
        """
        self.pending_learnings.append({
            "learning": learning,
            "context": context or {},
            "added_at": _utc_now_iso()
        })

    def add_error(self, error: str, task_id: Optional[str] = None) -> None:
        """
        Queue an error for reflection.

        Args:
            error: The error message
            task_id: Related task ID
        """
        self.pending_learnings.append({
            "error": error,
            "task_id": task_id,
            "added_at": _utc_now_iso()
        })

    def get_recent_reflections(self, limit: int = 10) -> List[Dict]:
        """Return up to `limit` most recent reflections as dicts (oldest first)."""
        return [asdict(r) for r in self.reflection_history[-limit:]]

    def get_stats(self) -> Dict:
        """Return aggregate statistics for the loop."""
        if not self.reflection_history:
            return {
                "total_reflections": 0,
                "total_learnings": 0,
                "avg_duration_ms": 0
            }

        return {
            "total_reflections": len(self.reflection_history),
            "total_learnings": sum(len(r.learnings) for r in self.reflection_history),
            "total_entities_updated": sum(r.entities_updated for r in self.reflection_history),
            "avg_duration_ms": sum(r.duration_ms for r in self.reflection_history) / len(self.reflection_history),
            "pending_learnings": len(self.pending_learnings),
            "is_running": self._running
        }


# Singleton instance (module-level; not thread-safe to initialize concurrently)
_reflection_loop: Optional[ReflectionLoop] = None


def get_reflection_loop(memory_bridge=None) -> ReflectionLoop:
    """Get or create the singleton ReflectionLoop.

    Args:
        memory_bridge: MemoryBridge for storage operations. Used when the
            singleton is first created; if the singleton already exists
            without a bridge, the bridge is attached late instead of being
            silently ignored.

    Returns:
        The process-wide ReflectionLoop instance.
    """
    global _reflection_loop
    if _reflection_loop is None:
        _reflection_loop = ReflectionLoop(memory_bridge)
    elif memory_bridge is not None and _reflection_loop.memory_bridge is None:
        # Previously a bridge passed after first creation was dropped.
        _reflection_loop.memory_bridge = memory_bridge
    return _reflection_loop


if __name__ == "__main__":
    # Example usage: queue a few items and run a single reflection cycle.
    # (asyncio is already imported at module level; no local re-import needed.)

    async def main():
        loop = ReflectionLoop()

        # Add some learnings
        loop.add_learning("Claude Code is effective for code generation")
        loop.add_learning("Triple gate validation prevents errors")
        loop.add_error("Redis connection timeout", "task_123")

        # Run one reflection cycle directly, without starting the timer loop
        entry = await loop.reflect()

        print("\nReflection Complete:")
        print(f"  ID: {entry.reflection_id}")
        print(f"  Learnings: {len(entry.learnings)}")
        print(f"  Duration: {entry.duration_ms:.0f}ms")
        print(f"\nStats: {loop.get_stats()}")

    asyncio.run(main())
