"""
Genesis Memory System - Redis Working Memory Store
===================================================
Handles working memory, session state, and real-time streams.
Built on Elestio Redis with streams and pub/sub support.
"""

import json
import os
import time
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional, Tuple

import redis


@dataclass
class WorkingMemory:
    """Represents an item in working memory.

    Hydrated from the per-item Redis hash ``wm:<agent_id>:detail:<memory_id>``.
    All timestamps are Unix epoch seconds (floats from ``time.time()``).
    """
    memory_id: str    # unique ID; also the member name in the agent's importance zset
    content: str      # raw memory text
    importance: float # score used for eviction ordering (lowest scores are evicted first)
    created_at: float # epoch seconds when the item was first stored
    accessed_at: float # epoch seconds of the most recent read
    source: str       # origin tag, e.g. "observation" (the write-path default)
    agent_id: str     # owning agent's ID

@dataclass
class SessionState:
    """Represents an agent's session state.

    Mirrors the Redis hash ``session:<agent_id>``; list-valued fields are
    stored there as JSON-encoded strings.
    """
    agent_id: str                  # owning agent's ID
    current_task: Optional[str]    # free-text task description, if any
    context_window: List[str]      # recent context items (write path keeps the last 50)
    attention_focus: Optional[str] # current focus topic, if any
    active_entities: List[str]     # entities in play (write path keeps the first 20)
    conversation_turn: int         # bumped by update_session(increment_turn=True)
    last_activity: float           # epoch seconds of the last session update

class RedisStore:
    """
    Redis-backed working memory and session management.

    Features:
    - Working memory with importance-based eviction
    - Session state tracking per agent
    - Real-time memory event streams
    - Agent coordination messaging
    - Embedding cache

    Keys are namespaced by the PREFIX_* constants so the store can share a
    Redis database with other applications.  Per-agent state carries a TTL
    (SESSION_TTL) and expires if the agent goes idle.
    """

    # Key prefixes
    PREFIX_SESSION = "session"
    PREFIX_WORKING_MEMORY = "wm"
    PREFIX_STREAM = "stream"
    PREFIX_CACHE = "cache"
    PREFIX_QUEUE = "queue"

    # Limits
    MAX_WORKING_MEMORY = 100   # max items in an agent's working-memory zset
    MAX_ATTENTION_STACK = 20   # max depth of an agent's attention stack
    SESSION_TTL = 3600  # 1 hour
    CACHE_TTL = 86400  # 24 hours

    def __init__(
        self,
        host: str = os.environ.get("REDIS_HOST", "redis-genesis-u50607.vm.elestio.app"),
        port: int = int(os.environ.get("REDIS_PORT", "26379")),
        username: str = os.environ.get("REDIS_USERNAME", "default"),
        password: str = os.environ.get("REDIS_PASSWORD", "e2ZyYYr4oWRdASI2CaLc-"),
        decode_responses: bool = True,
        socket_timeout: int = 5
    ):
        """
        Open a Redis connection and verify it with PING (fail fast at
        construction rather than on first use).

        Defaults come from the REDIS_HOST / REDIS_PORT / REDIS_USERNAME /
        REDIS_PASSWORD environment variables when those are set.

        SECURITY NOTE(review): the hard-coded credential fallbacks are kept
        only for backward compatibility with existing callers; rotate the
        password and remove the literals once every deployment sets
        REDIS_PASSWORD.

        Raises:
            redis.exceptions.ConnectionError / TimeoutError: if the server
                cannot be reached within ``socket_timeout`` seconds.
        """
        self.client = redis.Redis(
            host=host,
            port=port,
            username=username,
            password=password,
            decode_responses=decode_responses,
            socket_timeout=socket_timeout,
            socket_connect_timeout=socket_timeout
        )

        # Verify connection
        self.client.ping()

    # =========================================================================
    # SESSION MANAGEMENT
    # =========================================================================

    def get_session(self, agent_id: str) -> Optional[SessionState]:
        """
        Get current session state for an agent.

        Returns None when no session hash exists (never created or TTL
        expired).  List fields are stored as JSON strings and decoded here.
        """
        key = f"{self.PREFIX_SESSION}:{agent_id}"
        data = self.client.hgetall(key)

        if not data:
            return None

        return SessionState(
            agent_id=agent_id,
            current_task=data.get("current_task"),
            context_window=json.loads(data.get("context_window", "[]")),
            attention_focus=data.get("attention_focus"),
            active_entities=json.loads(data.get("active_entities", "[]")),
            conversation_turn=int(data.get("conversation_turn", 0)),
            last_activity=float(data.get("last_activity", time.time()))
        )

    def update_session(
        self,
        agent_id: str,
        current_task: Optional[str] = None,
        context_window: Optional[List[str]] = None,
        attention_focus: Optional[str] = None,
        active_entities: Optional[List[str]] = None,
        increment_turn: bool = False
    ) -> SessionState:
        """
        Update agent session state.

        Only fields passed as non-None are written; others are left as-is in
        the Redis hash.  Each call refreshes last_activity and the session
        TTL.  Returns the full session state after the update.
        """
        key = f"{self.PREFIX_SESSION}:{agent_id}"

        # Existing state is only needed to compute the next turn number.
        existing = self.get_session(agent_id)

        updates: Dict[str, Any] = {
            "last_activity": time.time()
        }

        if current_task is not None:
            updates["current_task"] = current_task

        if context_window is not None:
            updates["context_window"] = json.dumps(context_window[-50:])  # Keep last 50

        if attention_focus is not None:
            updates["attention_focus"] = attention_focus

        if active_entities is not None:
            updates["active_entities"] = json.dumps(active_entities[:20])  # Keep top 20

        if increment_turn:
            turn = (existing.conversation_turn if existing else 0) + 1
            updates["conversation_turn"] = turn

        # Update hash and refresh the idle-expiry window.
        self.client.hset(key, mapping=updates)
        self.client.expire(key, self.SESSION_TTL)

        return self.get_session(agent_id)

    def clear_session(self, agent_id: str) -> bool:
        """Clear an agent's session. Returns True if a session key was deleted."""
        key = f"{self.PREFIX_SESSION}:{agent_id}"
        return bool(self.client.delete(key))

    # =========================================================================
    # WORKING MEMORY
    # =========================================================================

    def add_to_working_memory(
        self,
        agent_id: str,
        memory_id: str,
        content: str,
        importance: float = 0.5,
        source: str = "observation"
    ) -> bool:
        """
        Add an item to working memory.

        Uses a sorted set with importance as score; item details live in a
        companion hash.  Evicts the lowest-importance items when over
        MAX_WORKING_MEMORY, and emits an "add" event to the memory stream.
        """
        key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:memories"
        detail_key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:detail:{memory_id}"

        now = time.time()

        # Store details. access_count starts at 0 and is incremented on every
        # read so get_promotable_memories can measure recall frequency.
        self.client.hset(detail_key, mapping={
            "content": content,
            "importance": importance,
            "created_at": now,
            "accessed_at": now,
            "access_count": 0,
            "source": source,
            "agent_id": agent_id
        })
        self.client.expire(detail_key, self.SESSION_TTL)

        # Add to sorted set (score = importance * recency_factor)
        recency_factor = 1.0  # Could decay over time
        score = importance * recency_factor
        self.client.zadd(key, {memory_id: score})

        # Trim to max size: keep only the MAX_WORKING_MEMORY highest scores.
        # Note: detail hashes of evicted IDs are left to expire via their TTL.
        self.client.zremrangebyrank(key, 0, -(self.MAX_WORKING_MEMORY + 1))
        self.client.expire(key, self.SESSION_TTL)

        # Emit event
        self._emit_memory_event("add", memory_id, agent_id, {
            "tier": "working",
            "importance": importance
        })

        return True

    def get_working_memory(
        self,
        agent_id: str,
        limit: int = 20,
        min_importance: float = 0.0
    ) -> List[WorkingMemory]:
        """
        Get items from working memory, sorted by importance (descending).

        Reading an item counts as an access: its accessed_at timestamp is
        refreshed and its access_count incremented (used later by
        get_promotable_memories).
        """
        key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:memories"

        # Get memory IDs with scores, highest first.
        items = self.client.zrevrangebyscore(
            key,
            "+inf",
            min_importance,
            start=0,
            num=limit,
            withscores=True
        )

        results = []
        for memory_id, score in items:
            detail_key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:detail:{memory_id}"
            details = self.client.hgetall(detail_key)

            if details:
                # Record the access in one round trip. HINCRBY treats a
                # missing field as 0, so pre-existing items without an
                # access_count field are handled transparently.
                pipe = self.client.pipeline()
                pipe.hset(detail_key, "accessed_at", time.time())
                pipe.hincrby(detail_key, "access_count", 1)
                pipe.execute()

                results.append(WorkingMemory(
                    memory_id=memory_id,
                    content=details.get("content", ""),
                    importance=float(details.get("importance", 0.5)),
                    created_at=float(details.get("created_at", 0)),
                    accessed_at=float(details.get("accessed_at", 0)),
                    source=details.get("source", "unknown"),
                    agent_id=details.get("agent_id", agent_id)
                ))

        return results

    def remove_from_working_memory(self, agent_id: str, memory_id: str) -> bool:
        """
        Remove an item from working memory (zset entry and detail hash).

        Emits a "remove" event only when the item actually existed; returns
        True in that case.
        """
        key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:memories"
        detail_key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:detail:{memory_id}"

        removed = self.client.zrem(key, memory_id)
        self.client.delete(detail_key)

        if removed:
            self._emit_memory_event("remove", memory_id, agent_id, {"tier": "working"})

        return bool(removed)

    def search_working_memory(
        self,
        agent_id: str,
        query: str,
        limit: int = 10
    ) -> List[WorkingMemory]:
        """
        Simple case-insensitive substring search in working memory.

        Scans all of the agent's working memory, so results are bounded by
        MAX_WORKING_MEMORY; matches are returned most-important first.
        """
        all_memories = self.get_working_memory(agent_id, limit=self.MAX_WORKING_MEMORY)
        query_lower = query.lower()

        matching = [
            m for m in all_memories
            if query_lower in m.content.lower()
        ]

        return sorted(matching, key=lambda m: m.importance, reverse=True)[:limit]

    def get_promotable_memories(
        self,
        agent_id: str,
        access_threshold: int = 3,
        age_threshold_seconds: int = 300
    ) -> List[str]:
        """
        Get memory IDs that should be promoted to episodic tier.

        Criteria (both must hold):
        - Accessed at least ``access_threshold`` times (frequently recalled)
        - Older than ``age_threshold_seconds`` (proven relevance)

        Fix(review): the previous implementation compared a 0/1 proxy against
        ``>= 0`` — always true — so ``access_threshold`` was ignored and every
        sufficiently old memory was promoted.  Access counts are now tracked
        explicitly (see get_working_memory) and compared to the threshold.
        """
        key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:memories"
        all_ids = self.client.zrange(key, 0, -1)

        promotable = []
        now = time.time()

        for memory_id in all_ids:
            detail_key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:detail:{memory_id}"
            details = self.client.hgetall(detail_key)

            if not details:
                continue

            created = float(details.get("created_at", now))
            accessed = float(details.get("accessed_at", now))

            # Check age first: young memories haven't proven relevance yet.
            age = now - created
            if age < age_threshold_seconds:
                continue

            # Check access frequency. Items written before access counting
            # existed fall back to the old created/accessed comparison.
            if "access_count" in details:
                access_count = int(details["access_count"])
            else:
                access_count = 1 if accessed > created else 0

            if access_count >= access_threshold:
                promotable.append(memory_id)

        return promotable

    # =========================================================================
    # ATTENTION STACK
    # =========================================================================

    def push_attention(self, agent_id: str, topic: str) -> int:
        """
        Push a topic onto the attention stack (a Redis list, newest first).

        The stack is trimmed to MAX_ATTENTION_STACK entries; returns the
        stack depth after the push.
        """
        key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:attention"

        self.client.lpush(key, topic)
        self.client.ltrim(key, 0, self.MAX_ATTENTION_STACK - 1)
        self.client.expire(key, self.SESSION_TTL)

        return self.client.llen(key)

    def pop_attention(self, agent_id: str) -> Optional[str]:
        """Pop the most recent topic from the attention stack (None if empty)."""
        key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:attention"
        return self.client.lpop(key)

    def get_attention_stack(self, agent_id: str) -> List[str]:
        """Get the full attention stack, most recent topic first."""
        key = f"{self.PREFIX_WORKING_MEMORY}:{agent_id}:attention"
        return self.client.lrange(key, 0, -1)

    # =========================================================================
    # MEMORY STREAMS
    # =========================================================================

    def _emit_memory_event(
        self,
        event_type: str,
        memory_id: str,
        agent_id: str,
        metadata: Optional[Dict] = None
    ) -> str:
        """
        Emit a memory event to the shared event stream.

        The stream is capped at ~10k entries via XADD maxlen; returns the
        new event's stream ID.
        """
        key = f"{self.PREFIX_STREAM}:memory:events"

        fields = {
            "event_type": event_type,
            "memory_id": memory_id,
            "agent_id": agent_id,
            "timestamp": time.time()
        }

        if metadata:
            fields["metadata"] = json.dumps(metadata)

        return self.client.xadd(key, fields, maxlen=10000)

    def read_memory_events(
        self,
        last_id: str = "0",
        count: int = 100,
        block: int = 0
    ) -> List[Tuple[str, Dict]]:
        """
        Read memory events newer than ``last_id`` from the stream.

        With block > 0, waits up to that many milliseconds for new events
        (XREAD); otherwise returns immediately via an exclusive XRANGE.
        Either way the result is a list of (event_id, fields) tuples.
        """
        key = f"{self.PREFIX_STREAM}:memory:events"

        if block > 0:
            # Blocking read
            result = self.client.xread({key: last_id}, count=count, block=block)
            if result:
                return [(eid, data) for eid, data in result[0][1]]
            return []
        else:
            # Non-blocking read; "(" makes the range exclusive of last_id.
            return self.client.xrange(key, min=f"({last_id}", count=count)

    def subscribe_to_events(self, callback, last_id: str = "$"):
        """
        Subscribe to memory events, invoking ``callback(event_id, fields)``
        for each one.

        Blocks forever in a poll loop (1s XREAD timeout per iteration); run
        it in a dedicated thread.  The default last_id "$" means "only
        events emitted after subscribing".
        """
        key = f"{self.PREFIX_STREAM}:memory:events"

        while True:
            result = self.client.xread({key: last_id}, count=10, block=1000)
            if result:
                for stream_key, messages in result:
                    for event_id, data in messages:
                        callback(event_id, data)
                        last_id = event_id

    # =========================================================================
    # AGENT COORDINATION
    # =========================================================================

    def send_agent_message(
        self,
        from_agent: str,
        to_agent: str,
        message_type: str,
        payload: Dict
    ) -> str:
        """
        Send a message between agents via the coordination stream.

        Use to_agent "broadcast" or "*" to address all agents.  Returns the
        message's stream ID; the stream is capped at ~5k entries.
        """
        key = f"{self.PREFIX_STREAM}:agent:coordination"

        return self.client.xadd(key, {
            "from_agent": from_agent,
            "to_agent": to_agent,
            "message_type": message_type,
            "payload": json.dumps(payload),
            "timestamp": time.time()
        }, maxlen=5000)

    def read_agent_messages(
        self,
        agent_id: str,
        last_id: str = "0",
        count: int = 50
    ) -> List[Tuple[str, Dict]]:
        """
        Read messages newer than ``last_id`` addressed to this agent.

        Includes broadcasts ("broadcast" or "*").  Note that ``count`` limits
        the raw read, so fewer than ``count`` relevant messages may be
        returned even when more exist further along the stream.
        """
        key = f"{self.PREFIX_STREAM}:agent:coordination"

        messages = self.client.xrange(key, min=f"({last_id}", count=count)

        # Filter for this agent or broadcast
        relevant = []
        for event_id, data in messages:
            to = data.get("to_agent", "")
            if to == agent_id or to == "broadcast" or to == "*":
                relevant.append((event_id, data))

        return relevant

    # =========================================================================
    # CACHING
    # =========================================================================

    def cache_embedding(self, content_hash: str, embedding: List[float]) -> bool:
        """Cache an embedding vector (JSON-encoded) for CACHE_TTL seconds."""
        key = f"{self.PREFIX_CACHE}:embedding:{content_hash}"
        self.client.set(key, json.dumps(embedding), ex=self.CACHE_TTL)
        return True

    def get_cached_embedding(self, content_hash: str) -> Optional[List[float]]:
        """Get a cached embedding, or None on a cache miss."""
        key = f"{self.PREFIX_CACHE}:embedding:{content_hash}"
        data = self.client.get(key)

        if data:
            return json.loads(data)
        return None

    def cache_episode(self, episode_id: str, episode_data: Dict) -> bool:
        """Cache an episode for quick access (short 5-minute TTL)."""
        key = f"{self.PREFIX_CACHE}:episode:{episode_id}"
        self.client.set(key, json.dumps(episode_data), ex=300)  # 5 min cache
        return True

    def get_cached_episode(self, episode_id: str) -> Optional[Dict]:
        """Get a cached episode, or None on a cache miss."""
        key = f"{self.PREFIX_CACHE}:episode:{episode_id}"
        data = self.client.get(key)

        if data:
            return json.loads(data)
        return None

    # =========================================================================
    # CONSOLIDATION QUEUE
    # =========================================================================

    def enqueue_for_consolidation(self, episode_id: str) -> int:
        """Add an episode to the consolidation queue; returns queue length."""
        key = f"{self.PREFIX_QUEUE}:consolidation"
        return self.client.lpush(key, episode_id)

    def dequeue_for_consolidation(self, count: int = 10) -> List[str]:
        """
        Pop up to ``count`` episodes pending consolidation (FIFO order).

        Returns fewer items (possibly none) when the queue drains early.
        """
        key = f"{self.PREFIX_QUEUE}:consolidation"
        items = []

        for _ in range(count):
            item = self.client.rpop(key)
            if item:
                items.append(item)
            else:
                break

        return items

    # =========================================================================
    # STATISTICS
    # =========================================================================

    def get_stats(self) -> Dict[str, Any]:
        """
        Get Redis store statistics (server info, key counts, stream lengths).

        Fix(review): key counting now uses cursor-based SCAN instead of KEYS,
        which blocks the Redis server on large keyspaces.
        """
        info = self.client.info()

        # Count keys by prefix (non-blocking, cursor-based iteration).
        session_count = sum(1 for _ in self.client.scan_iter(match=f"{self.PREFIX_SESSION}:*"))
        wm_count = sum(1 for _ in self.client.scan_iter(match=f"{self.PREFIX_WORKING_MEMORY}:*:memories"))
        cache_count = sum(1 for _ in self.client.scan_iter(match=f"{self.PREFIX_CACHE}:*"))

        # Stream lengths
        memory_stream_len = self.client.xlen(f"{self.PREFIX_STREAM}:memory:events")
        agent_stream_len = self.client.xlen(f"{self.PREFIX_STREAM}:agent:coordination")

        return {
            "redis_version": info.get("redis_version"),
            "connected_clients": info.get("connected_clients"),
            "used_memory_human": info.get("used_memory_human"),
            "sessions": session_count,
            "working_memory_agents": wm_count,
            "cached_items": cache_count,
            "memory_events": memory_stream_len,
            "agent_messages": agent_stream_len
        }

    def health_check(self) -> bool:
        """Check Redis connection health; never raises."""
        try:
            return bool(self.client.ping())
        except Exception:
            return False

    def close(self):
        """Close the Redis connection (and its pooled sockets)."""
        self.client.close()


# Lazily-created module-level singleton.
_store: Optional[RedisStore] = None


def get_redis_store() -> RedisStore:
    """Return the shared RedisStore, constructing it on first use."""
    global _store
    if _store is not None:
        return _store
    _store = RedisStore()
    return _store


if __name__ == "__main__":
    # Manual smoke test -- requires a reachable Redis instance.
    agent = "test_agent"
    store = get_redis_store()

    print("Redis Working Memory Store Test")
    print("=" * 50)

    # Session round-trip
    session = store.update_session(
        agent,
        current_task="Testing Redis store",
        attention_focus="memory_system",
    )
    print(f"Session created: {session.agent_id}")

    # Working-memory insert and read-back
    store.add_to_working_memory(
        agent,
        "mem_001",
        "This is a test memory in working memory",
        importance=0.8,
    )

    memories = store.get_working_memory(agent, limit=5)
    print(f"Working memories: {len(memories)}")
    for entry in memories:
        print(f"  - {entry.memory_id}: {entry.content[:40]}...")

    # Attention stack
    store.push_attention(agent, "current_topic")
    attention = store.get_attention_stack(agent)
    print(f"Attention stack: {attention}")

    # Store-wide statistics
    stats = store.get_stats()
    print(f"\nStats: {stats}")

    store.close()
