"""
AIVA Queen - Advanced Working Memory System
============================================

A cognitive architecture implementation for AIVA's working memory system,
based on Baddeley's multi-component model with modern AI enhancements.

Components:
-----------
1. AttentionMechanism - Focus allocation and relevance filtering
2. WorkingMemoryBuffer - Limited capacity buffer (7+/-2 items per Miller's Law)
3. CentralExecutive - Coordination and control of memory operations
4. PhonologicalLoop - Verbal/text processing and rehearsal
5. VisuospatialSketchpad - Visual and spatial information processing
6. EpisodicBuffer - Cross-modal integration and binding

Author: AIVA Queen Cognitive System
Version: 1.0.0
"""

import time
import math
import json
import hashlib
import threading
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import (
    Dict, List, Any, Optional, Tuple, Callable,
    TypeVar, Generic, Set, Union
)
from enum import Enum, auto
from collections import deque, OrderedDict
from datetime import datetime, timedelta
import heapq
import uuid
from contextlib import contextmanager

# Configure module logging.
# NOTE(review): basicConfig() mutates the *root* logger at import time, which
# can override the host application's logging configuration — consider leaving
# configuration to the application and only creating the named logger here.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("AIVA.WorkingMemory")


# =============================================================================
# ENUMS AND CONSTANTS
# =============================================================================

class ModalityType(Enum):
    """Information modalities the memory system can represent.

    Member values are written out explicitly but match the numbering the
    previous ``auto()`` declaration produced (1..6), so serialized data
    and value comparisons remain compatible.
    """
    VERBAL = 1      # text, speech, language
    VISUAL = 2      # images, diagrams, spatial layouts
    AUDITORY = 3    # sound patterns, non-verbal audio
    SEMANTIC = 4    # abstract concepts and meanings
    PROCEDURAL = 5  # action sequences and procedures
    EPISODIC = 6    # event-based memories with context


class AttentionState(Enum):
    """Operating states of the attention mechanism.

    Explicit values mirror the numbering formerly produced by ``auto()``
    (1..5) so that existing comparisons and serialized names/values are
    unaffected.
    """
    FOCUSED = 1    # single-target focused attention
    DIVIDED = 2    # multi-target divided attention
    SELECTIVE = 3  # filtering out irrelevant information
    SUSTAINED = 4  # maintaining attention over time
    SHIFTING = 5   # transitioning between targets


class ProcessingPriority(Enum):
    """Priority levels for memory processing.

    Values are on a 0-100 scale; they are multiplied by activation to form
    an item's effective priority (used when deciding what to displace).
    """
    CRITICAL = 100   # highest — last to be displaced
    HIGH = 75
    NORMAL = 50      # default priority for new items
    LOW = 25
    BACKGROUND = 10  # lowest — first candidate for displacement


# Constants based on cognitive science research
MILLER_CAPACITY_MIN = 5   # lower bound of Miller's 7+/-2 span (7 - 2)
MILLER_CAPACITY_MAX = 9   # upper bound of Miller's 7+/-2 span (7 + 2)
DEFAULT_DECAY_RATE = 0.1  # exponential activation decay rate, per second
ATTENTION_REFRESH_RATE = 0.05  # attention refresh interval, seconds
REHEARSAL_BOOST = 0.3     # activation boost applied by a rehearsal pass
MAX_CHUNK_SIZE = 4        # maximum items bound into a single chunk


# =============================================================================
# DATA CLASSES
# =============================================================================

@dataclass
class MemoryItem:
    """Represents a single item in working memory.

    Activation rises when the item is accessed (refresh) and falls over
    time (decay); effective priority combines activation with the base
    priority level.
    """
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    content: Any = None
    modality: ModalityType = ModalityType.SEMANTIC
    activation: float = 1.0  # current strength, kept in [0, 1] by refresh()
    priority: ProcessingPriority = ProcessingPriority.NORMAL
    created_at: float = field(default_factory=time.time)
    last_accessed: float = field(default_factory=time.time)
    access_count: int = 0
    associations: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    chunk_id: Optional[str] = None  # id of the owning Chunk, if any
    is_rehearsed: bool = False
    # Timestamp of the last decay() application; see decay() for why this is
    # needed. Excluded from __eq__/__repr__ to keep the dataclass surface
    # backward compatible.
    last_decayed: float = field(
        default_factory=time.time, compare=False, repr=False
    )

    def refresh(self) -> None:
        """Refresh the memory item's activation (called on access)."""
        self.last_accessed = time.time()
        self.access_count += 1
        # Small bounded boost; activation never exceeds 1.0.
        self.activation = min(1.0, self.activation + 0.1)

    def decay(self, rate: float = DEFAULT_DECAY_RATE) -> None:
        """Apply exponential decay to the item's activation.

        Decay covers only the interval since the item was last accessed OR
        last decayed, whichever is more recent.

        Bug fix: the previous implementation measured elapsed time from
        ``last_accessed`` alone, so every decay() call re-applied decay for
        the entire interval since the last access — repeated calls (e.g.
        from retrieve_all) compounded decay incorrectly.

        Args:
            rate: Decay rate per second of elapsed time.
        """
        now = time.time()
        anchor = max(self.last_accessed, self.last_decayed)
        elapsed = max(0.0, now - anchor)
        self.activation *= math.exp(-rate * elapsed)
        self.last_decayed = now

    def get_effective_priority(self) -> float:
        """Calculate effective priority considering activation and base priority."""
        return self.priority.value * self.activation

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the memory item to a dictionary.

        The key set is unchanged from earlier versions (``last_decayed`` is
        a transient bookkeeping field and is intentionally not serialized).
        """
        return {
            "id": self.id,
            "content": self.content,
            "modality": self.modality.name,
            "activation": self.activation,
            "priority": self.priority.name,
            "created_at": self.created_at,
            "last_accessed": self.last_accessed,
            "access_count": self.access_count,
            "associations": self.associations,
            "metadata": self.metadata,
            "chunk_id": self.chunk_id,
            "is_rehearsed": self.is_rehearsed
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MemoryItem':
        """Deserialize a memory item from a dictionary.

        ``chunk_id``/``is_rehearsed`` are optional for backward
        compatibility; the decay anchor is initialized from the serialized
        ``last_accessed`` so the first decay() after loading is correct.
        """
        return cls(
            id=data["id"],
            content=data["content"],
            modality=ModalityType[data["modality"]],
            activation=data["activation"],
            priority=ProcessingPriority[data["priority"]],
            created_at=data["created_at"],
            last_accessed=data["last_accessed"],
            access_count=data["access_count"],
            associations=data["associations"],
            metadata=data["metadata"],
            chunk_id=data.get("chunk_id"),
            is_rehearsed=data.get("is_rehearsed", False),
            last_decayed=data.get("last_decayed", data["last_accessed"])
        )


@dataclass
class AttentionFocus:
    """Snapshot of what the attention system is currently directed at."""
    target_ids: List[str] = field(default_factory=list)  # focused item ids
    state: AttentionState = AttentionState.FOCUSED
    intensity: float = 1.0  # attention strength, floored at 0.1 by decay
    duration: float = 0.0   # cached by decay_intensity()
    start_time: float = field(default_factory=time.time)
    filter_criteria: Optional[Dict[str, Any]] = None

    def get_duration(self) -> float:
        """Return seconds elapsed since this focus began."""
        now = time.time()
        return now - self.start_time

    def decay_intensity(self, rate: float = 0.05) -> None:
        """Reduce intensity as the focus ages, never dropping below 0.1."""
        self.duration = self.get_duration()
        reduction = rate * self.duration / 60
        self.intensity = max(0.1, self.intensity - reduction)


@dataclass
class Chunk:
    """A group of memory items bound together and treated as a single unit."""
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    item_ids: List[str] = field(default_factory=list)  # member item ids
    label: str = ""         # optional human-readable name
    coherence: float = 1.0  # how strongly the members belong together
    created_at: float = field(default_factory=time.time)

    def get_size(self) -> int:
        """Return how many member items this chunk binds together."""
        return len(self.item_ids)


@dataclass
class ProcessingResult:
    """Result of a memory processing operation.

    Returned by buffer/attention operations (add, remove, chunk, focus, ...)
    to report success plus the affected item or chunk ids.
    """
    success: bool            # whether the operation completed
    operation: str           # short operation name, e.g. "add", "chunk"
    item_ids: List[str] = field(default_factory=list)  # affected item/chunk ids
    message: str = ""        # human-readable outcome description
    metadata: Dict[str, Any] = field(default_factory=dict)  # extra details
    timestamp: float = field(default_factory=time.time)


# =============================================================================
# ATTENTION MECHANISM
# =============================================================================

class AttentionMechanism:
    """
    Manages attention allocation and focus within the working memory system.

    Implements selective, divided, and sustained attention patterns based
    on cognitive science research. Controls which memory items receive
    processing resources.

    Thread-safety: public methods acquire an internal RLock, so a single
    instance may be shared between threads.
    """

    def __init__(
        self,
        max_focus_targets: int = 4,
        decay_rate: float = 0.05,
        refresh_interval: float = ATTENTION_REFRESH_RATE
    ):
        """
        Args:
            max_focus_targets: Maximum simultaneous targets in DIVIDED mode
                (every other state is limited to a single target).
            decay_rate: Intensity decay rate applied by sustain().
            refresh_interval: Interval between attention refresh cycles.
        """
        self.max_focus_targets = max_focus_targets
        self.decay_rate = decay_rate
        self.refresh_interval = refresh_interval

        self._current_focus: AttentionFocus = AttentionFocus()
        self._attention_history: deque = deque(maxlen=100)  # bounded history of past foci
        self._salience_weights: Dict[str, float] = {}  # last salience computed per item id
        self._inhibited_items: Set[str] = set()
        # Inhibition expiry timestamp per item id. Needed so the cleanup
        # thread scheduled by an earlier, shorter inhibit() call cannot
        # release an item that was later re-inhibited for longer.
        self._inhibition_expiry: Dict[str, float] = {}
        self._lock = threading.RLock()

        # Attention metrics
        self._metrics = {
            "focus_shifts": 0,
            "sustained_time": 0.0,
            "items_inhibited": 0,
            "attention_lapses": 0
        }

        logger.info("AttentionMechanism initialized")

    def focus_on(
        self,
        target_ids: List[str],
        state: AttentionState = AttentionState.FOCUSED,
        intensity: float = 1.0,
        filter_criteria: Optional[Dict[str, Any]] = None
    ) -> ProcessingResult:
        """
        Direct attention to specific memory items.

        Only DIVIDED attention may hold multiple targets (up to
        max_focus_targets); all other states keep a single target.

        Args:
            target_ids: IDs of items to focus on
            state: Attention state to enter
            intensity: Initial attention intensity, clamped to [0.0, 1.0]
            filter_criteria: Optional filtering criteria later applied by
                filter_by_attention()

        Returns:
            ProcessingResult indicating success/failure
        """
        with self._lock:
            # Archive the previous focus (bounded deque drops the oldest).
            if self._current_focus.target_ids:
                self._attention_history.append({
                    "focus": self._current_focus,
                    "ended_at": time.time()
                })

            # Limit targets based on state
            max_targets = self.max_focus_targets if state == AttentionState.DIVIDED else 1
            limited_targets = target_ids[:max_targets]

            # Create new focus with clamped intensity
            self._current_focus = AttentionFocus(
                target_ids=limited_targets,
                state=state,
                intensity=min(1.0, max(0.0, intensity)),
                filter_criteria=filter_criteria
            )

            self._metrics["focus_shifts"] += 1

            logger.debug(f"Attention focused on {len(limited_targets)} items in {state.name} mode")

            return ProcessingResult(
                success=True,
                operation="focus",
                item_ids=limited_targets,
                message=f"Focused on {len(limited_targets)} items"
            )

    def calculate_salience(
        self,
        item: MemoryItem,
        context: Optional[Dict[str, Any]] = None
    ) -> float:
        """
        Calculate the salience (attention-worthiness) of a memory item.

        Salience is determined by:
        - Recency of access (exponential recency bonus)
        - Priority level (scaled to [0, 1])
        - Relevance to current context (+0.2 per matching metadata pair)
        - Activation strength

        Inhibited items always score 0.0 and are not written to the
        salience-weight cache. Locked because this reads/writes shared
        state and may be called directly by external threads (the RLock is
        re-entrant for calls from filter_by_attention).
        """
        with self._lock:
            base_salience = item.activation * (item.priority.value / 100)

            # Recency bonus (items accessed recently are more salient)
            recency = time.time() - item.last_accessed
            recency_factor = math.exp(-0.01 * recency)

            # Context relevance (if context provided)
            context_factor = 1.0
            if context and item.metadata:
                matches = sum(
                    1 for k, v in context.items()
                    if k in item.metadata and item.metadata[k] == v
                )
                context_factor = 1.0 + (matches * 0.2)

            # Inhibited items receive no attention at all
            if item.id in self._inhibited_items:
                return 0.0

            salience = base_salience * recency_factor * context_factor
            self._salience_weights[item.id] = salience

            return salience

    def filter_by_attention(
        self,
        items: List[MemoryItem],
        threshold: float = 0.3
    ) -> List[MemoryItem]:
        """
        Filter items based on current attention state and salience.

        An item passes the focus' filter_criteria when every criterion
        matches either one of its attributes or a metadata entry. The
        salience threshold is scaled by the current focus intensity, so a
        weakly-focused system lets more items through.

        Args:
            items: List of memory items to filter
            threshold: Minimum salience threshold

        Returns:
            Filtered list of attention-worthy items
        """
        with self._lock:
            # Apply filter criteria if set
            filtered = items
            if self._current_focus.filter_criteria:
                criteria = self._current_focus.filter_criteria
                filtered = [
                    item for item in items
                    if all(
                        getattr(item, k, None) == v or item.metadata.get(k) == v
                        for k, v in criteria.items()
                    )
                ]

            # Calculate salience for each surviving item
            salience_items = [
                (item, self.calculate_salience(item))
                for item in filtered
            ]

            # Apply intensity-adjusted threshold
            adj_threshold = threshold * self._current_focus.intensity

            return [
                item for item, salience in salience_items
                if salience >= adj_threshold
            ]

    def inhibit(self, item_ids: List[str], duration: float = 5.0) -> None:
        """
        Temporarily inhibit items from receiving attention.

        Re-inhibiting an already-inhibited item extends (never shortens)
        its inhibition window.

        Bug fix: the cleanup thread previously removed its ids
        unconditionally after sleeping, so a later inhibit() with a longer
        duration was cancelled early by an earlier call's thread. Each id
        now carries an expiry timestamp and is only released once that
        expiry has passed.

        Args:
            item_ids: IDs of items to inhibit
            duration: Duration of inhibition in seconds
        """
        with self._lock:
            expiry = time.time() + duration
            for item_id in item_ids:
                previous = self._inhibition_expiry.get(item_id, 0.0)
                self._inhibition_expiry[item_id] = max(previous, expiry)
            self._inhibited_items.update(item_ids)
            self._metrics["items_inhibited"] += len(item_ids)

            # Schedule removal of inhibition
            def remove_inhibition():
                time.sleep(duration)
                with self._lock:
                    now = time.time()
                    for item_id in item_ids:
                        # Only release ids whose window was not extended.
                        if self._inhibition_expiry.get(item_id, 0.0) <= now:
                            self._inhibited_items.discard(item_id)
                            self._inhibition_expiry.pop(item_id, None)

            thread = threading.Thread(target=remove_inhibition, daemon=True)
            thread.start()

    def release_inhibition(self, item_ids: Optional[List[str]] = None) -> None:
        """Release inhibition on specific items, or on all items when
        item_ids is None (expiry bookkeeping is cleared alongside)."""
        with self._lock:
            if item_ids:
                self._inhibited_items.difference_update(item_ids)
                for item_id in item_ids:
                    self._inhibition_expiry.pop(item_id, None)
            else:
                self._inhibited_items.clear()
                self._inhibition_expiry.clear()

    def sustain(self) -> None:
        """Attempt to sustain current attention focus, applying intensity
        decay; intensity below 0.2 counts as an attention lapse."""
        with self._lock:
            self._current_focus.decay_intensity(self.decay_rate)

            if self._current_focus.intensity < 0.2:
                self._metrics["attention_lapses"] += 1
                logger.debug("Attention lapse detected")

    def shift(self, new_target_ids: List[str]) -> ProcessingResult:
        """Shift attention to new targets (transitions through SHIFTING,
        then delegates to focus_on with its default FOCUSED state)."""
        with self._lock:
            self._current_focus.state = AttentionState.SHIFTING
            return self.focus_on(new_target_ids)

    def get_current_focus(self) -> AttentionFocus:
        """Get the current attention focus.

        NOTE: returns the live internal object, not a copy — callers must
        not mutate it.
        """
        with self._lock:
            return self._current_focus

    def get_focused_ids(self) -> List[str]:
        """Get a copy of the IDs of currently focused items."""
        with self._lock:
            return self._current_focus.target_ids.copy()

    def get_metrics(self) -> Dict[str, Any]:
        """Get a snapshot of attention metrics (sustained_time is refreshed
        from the current focus duration on each call)."""
        with self._lock:
            self._metrics["sustained_time"] = self._current_focus.get_duration()
            return self._metrics.copy()

    def reset(self) -> None:
        """Reset the attention mechanism to a fresh, uninhibited state."""
        with self._lock:
            self._current_focus = AttentionFocus()
            self._inhibited_items.clear()
            self._inhibition_expiry.clear()
            self._salience_weights.clear()
            logger.info("AttentionMechanism reset")

# =============================================================================
# WORKING MEMORY BUFFER
# =============================================================================

class WorkingMemoryBuffer:
    """
    Limited capacity buffer implementing Miller's Law (7 +/- 2 items).

    Manages the core storage of active memory items with automatic
    displacement of least-activated items when capacity is exceeded.
    Supports chunking to effectively increase capacity (each chunk counts
    as a single slot).

    Thread-safety: public methods acquire an internal RLock.
    """

    def __init__(
        self,
        capacity: int = 7,
        decay_rate: float = DEFAULT_DECAY_RATE,
        enable_chunking: bool = True
    ):
        """
        Args:
            capacity: Requested capacity, clamped to [5, 9] per Miller's Law.
            decay_rate: Decay rate applied by retrieve_all().
            enable_chunking: Whether chunk() is permitted.
        """
        # Enforce Miller's Law bounds
        self.capacity = max(MILLER_CAPACITY_MIN, min(MILLER_CAPACITY_MAX, capacity))
        self.decay_rate = decay_rate
        self.enable_chunking = enable_chunking

        # OrderedDict gives LRU ordering: retrieve() moves items to the end.
        self._items: Dict[str, MemoryItem] = OrderedDict()
        self._chunks: Dict[str, Chunk] = {}
        self._lock = threading.RLock()

        # Performance metrics
        self._metrics = {
            "items_added": 0,
            "items_displaced": 0,
            "chunks_created": 0,
            "retrieval_count": 0,
            "avg_activation": 0.0
        }

        logger.info(f"WorkingMemoryBuffer initialized with capacity={self.capacity}")

    def add(
        self,
        content: Any,
        modality: ModalityType = ModalityType.SEMANTIC,
        priority: ProcessingPriority = ProcessingPriority.NORMAL,
        metadata: Optional[Dict[str, Any]] = None
    ) -> ProcessingResult:
        """
        Add an item to the working memory buffer.

        If capacity is exceeded, the weakest (lowest effective-priority)
        unchunked items are displaced until a slot is free. If everything
        remaining is chunked and nothing can be displaced, the item is
        still added (the buffer may briefly exceed capacity).

        Args:
            content: The content to store
            modality: Type of information
            priority: Processing priority
            metadata: Optional metadata

        Returns:
            ProcessingResult with the new item's ID and any displaced IDs
            in metadata["displaced_ids"]
        """
        with self._lock:
            # Create new memory item
            item = MemoryItem(
                content=content,
                modality=modality,
                priority=priority,
                metadata=metadata or {}
            )

            displaced_ids = []

            # Check capacity (counting chunks as single items)
            effective_count = self._get_effective_count()
            while effective_count >= self.capacity:
                displaced = self._displace_weakest()
                if displaced:
                    displaced_ids.append(displaced.id)
                    effective_count = self._get_effective_count()
                else:
                    # Nothing displaceable (all items chunked) — stop trying.
                    break

            # Add the new item
            self._items[item.id] = item
            self._metrics["items_added"] += 1

            logger.debug(f"Added item {item.id[:8]} to buffer (effective count: {self._get_effective_count()})")

            return ProcessingResult(
                success=True,
                operation="add",
                item_ids=[item.id],
                message=f"Added item, displaced {len(displaced_ids)} items",
                metadata={"displaced_ids": displaced_ids}
            )

    def retrieve(self, item_id: str) -> Optional[MemoryItem]:
        """
        Retrieve an item from the buffer by ID, refreshing its activation.

        Args:
            item_id: ID of the item to retrieve

        Returns:
            The memory item if found, None otherwise
        """
        with self._lock:
            item = self._items.get(item_id)
            if item:
                item.refresh()
                self._metrics["retrieval_count"] += 1

                # Move to end of OrderedDict (most recently used)
                self._items.move_to_end(item_id)

            return item

    def retrieve_all(self, apply_decay: bool = True) -> List[MemoryItem]:
        """
        Retrieve all items from the buffer (does not refresh activation).

        Args:
            apply_decay: Whether to apply decay to activation levels first

        Returns:
            List of all memory items
        """
        with self._lock:
            if apply_decay:
                for item in self._items.values():
                    item.decay(self.decay_rate)

            return list(self._items.values())

    def update(self, item_id: str, updates: Dict[str, Any]) -> ProcessingResult:
        """
        Update properties of an existing memory item.

        Keys that match item attributes are set directly; unknown keys are
        stored in the item's metadata. The item is refreshed afterwards.

        Args:
            item_id: ID of the item to update
            updates: Dictionary of properties to update

        Returns:
            ProcessingResult indicating success/failure
        """
        with self._lock:
            item = self._items.get(item_id)
            if not item:
                return ProcessingResult(
                    success=False,
                    operation="update",
                    message=f"Item {item_id} not found"
                )

            for key, value in updates.items():
                if hasattr(item, key):
                    setattr(item, key, value)
                else:
                    item.metadata[key] = value

            item.refresh()

            return ProcessingResult(
                success=True,
                operation="update",
                item_ids=[item_id],
                message="Item updated successfully"
            )

    def remove(self, item_id: str) -> ProcessingResult:
        """
        Remove an item from the buffer.

        Bug fix: removing a chunk member could previously leave a chunk
        with 0 or 1 members, which still occupied a full capacity slot
        forever. Chunks left with fewer than 2 members are now dissolved
        and any surviving member is un-chunked.

        Args:
            item_id: ID of the item to remove

        Returns:
            ProcessingResult indicating success/failure
        """
        with self._lock:
            if item_id in self._items:
                del self._items[item_id]

                # Remove from any chunks, dissolving degenerate ones.
                for chunk_id in list(self._chunks.keys()):
                    chunk = self._chunks[chunk_id]
                    if item_id in chunk.item_ids:
                        chunk.item_ids.remove(item_id)
                        if len(chunk.item_ids) < 2:
                            # A chunk needs at least 2 members to exist.
                            for remaining_id in chunk.item_ids:
                                if remaining_id in self._items:
                                    self._items[remaining_id].chunk_id = None
                            del self._chunks[chunk_id]

                return ProcessingResult(
                    success=True,
                    operation="remove",
                    item_ids=[item_id],
                    message="Item removed"
                )

            return ProcessingResult(
                success=False,
                operation="remove",
                message=f"Item {item_id} not found"
            )

    def chunk(
        self,
        item_ids: List[str],
        label: str = ""
    ) -> ProcessingResult:
        """
        Group items into a chunk (effectively increasing capacity).

        Only items that exist in the buffer AND are not already chunked are
        eligible; at least 2 eligible items are required, and at most
        MAX_CHUNK_SIZE are used.

        Bug fix: previously an item already belonging to a chunk could be
        re-chunked, leaving its id in two chunks' member lists and
        inflating the effective count (every chunk counts as one slot).

        Args:
            item_ids: IDs of items to chunk together
            label: Optional label for the chunk

        Returns:
            ProcessingResult with the chunk ID
        """
        if not self.enable_chunking:
            return ProcessingResult(
                success=False,
                operation="chunk",
                message="Chunking is disabled"
            )

        with self._lock:
            # Verify all items exist and are not already in a chunk
            valid_ids = [
                id for id in item_ids
                if id in self._items and self._items[id].chunk_id is None
            ]
            if len(valid_ids) < 2:
                return ProcessingResult(
                    success=False,
                    operation="chunk",
                    message="Need at least 2 valid items to chunk"
                )

            # Limit chunk size
            valid_ids = valid_ids[:MAX_CHUNK_SIZE]

            # Create chunk
            chunk = Chunk(
                item_ids=valid_ids,
                label=label or f"chunk_{len(self._chunks)}"
            )

            # Mark items as chunked
            for item_id in valid_ids:
                self._items[item_id].chunk_id = chunk.id

            self._chunks[chunk.id] = chunk
            self._metrics["chunks_created"] += 1

            logger.debug(f"Created chunk {chunk.id[:8]} with {len(valid_ids)} items")

            return ProcessingResult(
                success=True,
                operation="chunk",
                item_ids=[chunk.id],
                message=f"Created chunk with {len(valid_ids)} items"
            )

    def unchunk(self, chunk_id: str) -> ProcessingResult:
        """
        Dissolve a chunk back into individual items.

        Args:
            chunk_id: ID of the chunk to dissolve

        Returns:
            ProcessingResult with the freed item IDs
        """
        with self._lock:
            chunk = self._chunks.get(chunk_id)
            if not chunk:
                return ProcessingResult(
                    success=False,
                    operation="unchunk",
                    message=f"Chunk {chunk_id} not found"
                )

            # Unmark items
            for item_id in chunk.item_ids:
                if item_id in self._items:
                    self._items[item_id].chunk_id = None

            freed_ids = chunk.item_ids.copy()
            del self._chunks[chunk_id]

            return ProcessingResult(
                success=True,
                operation="unchunk",
                item_ids=freed_ids,
                message=f"Unchunked {len(freed_ids)} items"
            )

    def _get_effective_count(self) -> int:
        """
        Calculate effective item count: unchunked items count individually,
        each chunk counts as a single slot.
        """
        chunked_ids = set()
        for chunk in self._chunks.values():
            chunked_ids.update(chunk.item_ids)

        unchunked_count = sum(
            1 for id in self._items.keys()
            if id not in chunked_ids
        )

        return unchunked_count + len(self._chunks)

    def _displace_weakest(self) -> Optional[MemoryItem]:
        """
        Remove and return the unchunked item with the lowest effective
        priority (priority value x activation). Chunked items are never
        displaced; returns None when nothing is displaceable.
        """
        if not self._items:
            return None

        # Find weakest non-chunked item
        weakest_id = None
        weakest_priority = float('inf')

        for item_id, item in self._items.items():
            if item.chunk_id is None:  # Don't displace chunked items
                effective_priority = item.get_effective_priority()
                if effective_priority < weakest_priority:
                    weakest_priority = effective_priority
                    weakest_id = item_id

        if weakest_id:
            displaced = self._items.pop(weakest_id)
            self._metrics["items_displaced"] += 1
            logger.debug(f"Displaced item {weakest_id[:8]} (priority: {weakest_priority:.2f})")
            return displaced

        return None

    def get_capacity_status(self) -> Dict[str, Any]:
        """Get current capacity status (counts, slots, utilization ratio)."""
        with self._lock:
            effective = self._get_effective_count()
            return {
                "capacity": self.capacity,
                "total_items": len(self._items),
                "effective_count": effective,
                "chunks": len(self._chunks),
                "available_slots": self.capacity - effective,
                "utilization": effective / self.capacity
            }

    def get_metrics(self) -> Dict[str, Any]:
        """Get a snapshot of buffer performance metrics (avg_activation is
        recomputed over current items on each call)."""
        with self._lock:
            if self._items:
                self._metrics["avg_activation"] = sum(
                    i.activation for i in self._items.values()
                ) / len(self._items)
            return self._metrics.copy()

    def clear(self) -> None:
        """Clear all items and chunks from the buffer (metrics are kept)."""
        with self._lock:
            self._items.clear()
            self._chunks.clear()
            logger.info("WorkingMemoryBuffer cleared")

# =============================================================================
# PHONOLOGICAL LOOP
# =============================================================================

class PhonologicalLoop:
    """
    Processes and maintains verbal/textual information through rehearsal.

    Implements the articulatory rehearsal process that maintains
    verbal information in working memory through repetition.
    Based on the phonological loop component of Baddeley's model.

    Items decay after ``decay_time`` seconds without rehearsal. When the
    word length effect is enabled, the effective store capacity shrinks
    as the average word length of the encoded text grows.
    """

    def __init__(
        self,
        rehearsal_capacity: int = 6,
        decay_time: float = 2.0,  # seconds before decay without rehearsal
        word_length_effect: bool = True
    ):
        self.rehearsal_capacity = rehearsal_capacity
        self.decay_time = decay_time
        self.word_length_effect = word_length_effect

        # FIFO store of phonological representations; maxlen is a hard cap.
        self._store: deque = deque(maxlen=rehearsal_capacity)
        # Min-heap of (-priority, item_id): higher priority pops first.
        self._rehearsal_queue: List[Tuple[float, str]] = []
        self._subvocal_rate: float = 1.5  # items per second
        # item_id -> timestamp of the most recent rehearsal (or encoding).
        self._last_rehearsal: Dict[str, float] = {}
        self._lock = threading.RLock()

        # Metrics
        self._metrics = {
            "items_processed": 0,
            "rehearsals_performed": 0,
            "decay_events": 0,
            "avg_word_length": 0.0
        }

        logger.info("PhonologicalLoop initialized")

    def encode(
        self,
        text: str,
        item_id: str,
        priority: float = 0.5
    ) -> ProcessingResult:
        """
        Encode verbal/textual information into the phonological store.

        Args:
            text: Text content to encode
            item_id: Associated memory item ID
            priority: Rehearsal priority (0.0-1.0)

        Returns:
            ProcessingResult indicating success/failure
        """
        with self._lock:
            # Word length effect: longer words leave room for fewer items.
            # Clamp to rehearsal_capacity: for short words the formula
            # (4 / avg_word_length > 1) used to exceed the deque's maxlen,
            # letting the deque evict silently and leak _last_rehearsal
            # entries for the displaced items.
            if self.word_length_effect:
                avg_word_length = self._calculate_word_length(text)
                effective_capacity = min(
                    self.rehearsal_capacity,
                    max(
                        3,
                        int(self.rehearsal_capacity * (4 / max(avg_word_length, 1)))
                    )
                )
            else:
                effective_capacity = self.rehearsal_capacity

            # Create phonological representation
            phonological_rep = {
                "id": item_id,
                "text": text,
                "encoded_at": time.time(),
                "phonemes": self._extract_phonemes(text),
                "syllable_count": self._count_syllables(text)
            }

            # Re-encoding an existing id replaces its old representation so
            # retrieve() never returns a stale duplicate.
            for existing in list(self._store):
                if existing["id"] == item_id:
                    self._store.remove(existing)

            # Displace oldest items until the new one fits, cleaning up
            # their rehearsal timestamps as they go.
            while len(self._store) >= effective_capacity:
                oldest = self._store.popleft()
                self._last_rehearsal.pop(oldest["id"], None)

            self._store.append(phonological_rep)

            # Schedule for rehearsal (priority negated: heapq is a min-heap).
            heapq.heappush(self._rehearsal_queue, (-priority, item_id))
            self._last_rehearsal[item_id] = time.time()

            self._metrics["items_processed"] += 1

            return ProcessingResult(
                success=True,
                operation="encode_phonological",
                item_ids=[item_id],
                message=f"Encoded text ({len(text)} chars, {phonological_rep['syllable_count']} syllables)"
            )

    def rehearse(self, item_id: Optional[str] = None) -> ProcessingResult:
        """
        Perform subvocal rehearsal to maintain items in the loop.

        Args:
            item_id: Specific item to rehearse, or None for automatic selection

        Returns:
            ProcessingResult indicating what was rehearsed
        """
        with self._lock:
            if item_id:
                # Rehearse one named item if it is still in the store.
                for rep in self._store:
                    if rep["id"] == item_id:
                        self._last_rehearsal[item_id] = time.time()
                        self._metrics["rehearsals_performed"] += 1
                        return ProcessingResult(
                            success=True,
                            operation="rehearse",
                            item_ids=[item_id],
                            message="Rehearsed specific item"
                        )
                return ProcessingResult(
                    success=False,
                    operation="rehearse",
                    message=f"Item {item_id} not in phonological store"
                )

            # Automatic mode: refresh every item that has used up 80% of its
            # decay window since it was last rehearsed.
            rehearsed = []
            current_time = time.time()

            for rep in self._store:
                rep_id = rep["id"]  # distinct name: do not shadow the parameter
                last = self._last_rehearsal.get(rep_id, rep["encoded_at"])

                if current_time - last >= self.decay_time * 0.8:
                    self._last_rehearsal[rep_id] = current_time
                    rehearsed.append(rep_id)
                    self._metrics["rehearsals_performed"] += 1

            return ProcessingResult(
                success=True,
                operation="rehearse",
                item_ids=rehearsed,
                message=f"Rehearsed {len(rehearsed)} items"
            )

    def check_decay(self) -> List[str]:
        """
        Check for and process decayed items.

        Items whose last rehearsal is older than ``decay_time`` are removed
        from the store; their rehearsal timestamps and any pending
        rehearsal-queue entries are cleaned up with them.

        Returns:
            List of item IDs that have decayed
        """
        with self._lock:
            current_time = time.time()
            decayed = []

            for rep in list(self._store):  # copy: we mutate while scanning
                rep_id = rep["id"]
                last = self._last_rehearsal.get(rep_id, rep["encoded_at"])

                if current_time - last >= self.decay_time:
                    decayed.append(rep_id)
                    self._store.remove(rep)
                    self._last_rehearsal.pop(rep_id, None)
                    self._metrics["decay_events"] += 1

            if decayed:
                # Purge queue entries for decayed items so the heap cannot
                # grow without bound across encode/decay cycles.
                dead = set(decayed)
                self._rehearsal_queue = [
                    entry for entry in self._rehearsal_queue
                    if entry[1] not in dead
                ]
                heapq.heapify(self._rehearsal_queue)

            return decayed

    def retrieve(self, item_id: str) -> Optional[Dict[str, Any]]:
        """Retrieve a phonological representation by item ID.

        A successful retrieval counts as a rehearsal: it refreshes the
        item's decay clock.
        """
        with self._lock:
            for rep in self._store:
                if rep["id"] == item_id:
                    self._last_rehearsal[item_id] = time.time()
                    return rep
            return None

    def _extract_phonemes(self, text: str) -> List[str]:
        """
        Extract simplified phoneme representation from text.
        (Simplified - full implementation would use phonetic library)

        Maps each alphabetic character to "V" (vowel) or "C" (consonant);
        non-alphabetic characters are skipped.
        """
        vowels = set("aeiouAEIOU")
        phonemes = []

        for char in text.lower():
            if char.isalpha():
                if char in vowels:
                    phonemes.append("V")
                else:
                    phonemes.append("C")

        return phonemes

    def _count_syllables(self, text: str) -> int:
        """Estimate syllable count for word length effect calculation.

        Counts runs of consecutive vowels across the whole text; always
        returns at least 1.
        """
        vowels = "aeiouAEIOU"
        count = 0
        prev_vowel = False

        for char in text:
            is_vowel = char in vowels
            if is_vowel and not prev_vowel:
                count += 1
            prev_vowel = is_vowel

        return max(1, count)

    def _calculate_word_length(self, text: str) -> float:
        """Calculate average word length for capacity adjustment.

        Also records the average in the metrics as a side effect.
        Returns 0.0 for whitespace-only input.
        """
        words = text.split()
        if not words:
            return 0.0

        avg = sum(len(w) for w in words) / len(words)
        self._metrics["avg_word_length"] = avg
        return avg

    def get_store_contents(self) -> List[Dict[str, Any]]:
        """Get all items in the phonological store (shallow copy)."""
        with self._lock:
            return list(self._store)

    def get_metrics(self) -> Dict[str, Any]:
        """Get a copy of the phonological loop metrics."""
        with self._lock:
            return self._metrics.copy()

    def clear(self) -> None:
        """Clear the phonological store and all rehearsal bookkeeping."""
        with self._lock:
            self._store.clear()
            self._rehearsal_queue.clear()
            self._last_rehearsal.clear()
            logger.info("PhonologicalLoop cleared")


# =============================================================================
# VISUOSPATIAL SKETCHPAD
# =============================================================================

class VisuospatialSketchpad:
    """
    Processes and maintains visual and spatial information.

    Handles imagery, spatial relationships, and visual pattern
    recognition. Implements the visual cache and inner scribe
    components from Baddeley's model.

    The visual cache holds at most ``visual_capacity`` items; spatial
    coordinates are normalized to ``spatial_resolution`` steps per unit,
    and visual activations decay exponentially at ``decay_rate``.
    """

    def __init__(
        self,
        visual_capacity: int = 4,
        spatial_resolution: int = 100,
        decay_rate: float = 0.15
    ):
        self.visual_capacity = visual_capacity
        self.spatial_resolution = spatial_resolution
        self.decay_rate = decay_rate

        self._visual_cache: Dict[str, Dict[str, Any]] = {}  # Static visual info
        self._spatial_index: Dict[str, Tuple[float, float, float]] = {}  # 3D coordinates
        self._movement_traces: deque = deque(maxlen=50)  # Movement sequences
        self._lock = threading.RLock()

        # Metrics
        self._metrics = {
            "images_processed": 0,
            "spatial_operations": 0,
            "movement_traces": 0,
            "mental_rotations": 0
        }

        logger.info("VisuospatialSketchpad initialized")

    def encode_visual(
        self,
        item_id: str,
        visual_data: Dict[str, Any],
        spatial_location: Optional[Tuple[float, float, float]] = None
    ) -> ProcessingResult:
        """
        Encode visual information into the visual cache.

        Args:
            item_id: Associated memory item ID
            visual_data: Visual representation (features, colors, shapes)
            spatial_location: Optional 3D spatial coordinates

        Returns:
            ProcessingResult indicating success/failure
        """
        with self._lock:
            # Evict only when adding a genuinely new item at capacity;
            # re-encoding an already-cached id overwrites in place and must
            # not displace an unrelated item.
            if (item_id not in self._visual_cache
                    and len(self._visual_cache) >= self.visual_capacity):
                # Remove oldest item (dicts preserve insertion order)
                oldest_id = next(iter(self._visual_cache))
                del self._visual_cache[oldest_id]
                self._spatial_index.pop(oldest_id, None)

            # Create visual representation
            visual_rep = {
                "id": item_id,
                "data": visual_data,
                "encoded_at": time.time(),
                "last_accessed": time.time(),
                "activation": 1.0,
                "features": self._extract_visual_features(visual_data)
            }

            self._visual_cache[item_id] = visual_rep

            if spatial_location is not None:
                self._spatial_index[item_id] = spatial_location

            self._metrics["images_processed"] += 1

            return ProcessingResult(
                success=True,
                operation="encode_visual",
                item_ids=[item_id],
                message="Visual information encoded"
            )

    def encode_spatial(
        self,
        item_id: str,
        location: Tuple[float, float, float]
    ) -> ProcessingResult:
        """
        Encode spatial location for an item.

        Args:
            item_id: Associated memory item ID
            location: 3D spatial coordinates (x, y, z)

        Returns:
            ProcessingResult indicating success/failure
        """
        with self._lock:
            # Snap each coordinate to the configured resolution grid.
            normalized = tuple(
                round(coord * self.spatial_resolution) / self.spatial_resolution
                for coord in location
            )

            self._spatial_index[item_id] = normalized
            self._metrics["spatial_operations"] += 1

            return ProcessingResult(
                success=True,
                operation="encode_spatial",
                item_ids=[item_id],
                message=f"Spatial location encoded: {normalized}"
            )

    def trace_movement(
        self,
        item_id: str,
        path: List[Tuple[float, float, float]]
    ) -> ProcessingResult:
        """
        Record a movement trace (inner scribe function).

        Args:
            item_id: Associated memory item ID
            path: Sequence of 3D coordinates representing movement

        Returns:
            ProcessingResult indicating success/failure
        """
        with self._lock:
            trace = {
                "id": item_id,
                "path": path,
                "recorded_at": time.time(),
                "duration": len(path) * 0.1  # Estimated duration (0.1s/point)
            }

            self._movement_traces.append(trace)
            self._metrics["movement_traces"] += 1

            return ProcessingResult(
                success=True,
                operation="trace_movement",
                item_ids=[item_id],
                message=f"Movement trace recorded ({len(path)} points)"
            )

    def mental_rotation(
        self,
        item_id: str,
        rotation: Tuple[float, float, float]
    ) -> ProcessingResult:
        """
        Apply mental rotation to a visual representation.

        The rotation is recorded as a transformation on the cached
        representation rather than mutating the visual data itself.

        Args:
            item_id: ID of item to rotate
            rotation: Rotation angles (x, y, z) in degrees

        Returns:
            ProcessingResult indicating success/failure
        """
        with self._lock:
            if item_id not in self._visual_cache:
                return ProcessingResult(
                    success=False,
                    operation="mental_rotation",
                    message=f"Item {item_id} not in visual cache"
                )

            visual_rep = self._visual_cache[item_id]

            # Store rotation as an appended transformation record.
            if "transformations" not in visual_rep:
                visual_rep["transformations"] = []

            visual_rep["transformations"].append({
                "type": "rotation",
                "values": rotation,
                "applied_at": time.time()
            })

            visual_rep["last_accessed"] = time.time()
            self._metrics["mental_rotations"] += 1

            return ProcessingResult(
                success=True,
                operation="mental_rotation",
                item_ids=[item_id],
                message=f"Applied rotation {rotation}"
            )

    def find_nearby(
        self,
        location: Tuple[float, float, float],
        radius: float = 0.2
    ) -> List[str]:
        """
        Find items within a spatial radius of a location.

        Args:
            location: Center point
            radius: Search radius (Euclidean distance, inclusive)

        Returns:
            List of item IDs within the radius
        """
        with self._lock:
            nearby = []

            for item_id, item_loc in self._spatial_index.items():
                distance = math.sqrt(sum(
                    (a - b) ** 2
                    for a, b in zip(location, item_loc)
                ))
                if distance <= radius:
                    nearby.append(item_id)

            return nearby

    def get_spatial_map(self) -> Dict[str, Tuple[float, float, float]]:
        """Get a copy of the current spatial index."""
        with self._lock:
            return self._spatial_index.copy()

    def _extract_visual_features(self, visual_data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract key visual features from visual data.

        Feature flags are derived purely from the presence of known keys;
        "complexity" is a crude proxy (length of the stringified data).
        """
        features = {
            "has_color": "color" in visual_data or "colors" in visual_data,
            "has_shape": "shape" in visual_data or "shapes" in visual_data,
            "has_size": "size" in visual_data or "dimensions" in visual_data,
            "complexity": len(str(visual_data))
        }
        return features

    def apply_decay(self) -> List[str]:
        """Apply decay to visual representations.

        Each pass decays activation only for the interval since the later
        of the last decay pass and the last access; previously each pass
        re-applied the full time since ``last_accessed``, compounding the
        decay across repeated calls. Items whose activation falls below
        0.1 are evicted from both the cache and the spatial index.

        Returns:
            List of item IDs removed by this pass.
        """
        with self._lock:
            current_time = time.time()
            decayed = []

            for item_id, visual_rep in list(self._visual_cache.items()):
                # An access after the last decay pass restarts the interval
                # (accessing an item refreshes it).
                reference = max(
                    visual_rep.get("last_decay", visual_rep["last_accessed"]),
                    visual_rep["last_accessed"]
                )
                elapsed = current_time - reference
                visual_rep["activation"] *= math.exp(-self.decay_rate * elapsed)
                visual_rep["last_decay"] = current_time

                if visual_rep["activation"] < 0.1:
                    del self._visual_cache[item_id]
                    self._spatial_index.pop(item_id, None)
                    decayed.append(item_id)

            return decayed

    def get_metrics(self) -> Dict[str, Any]:
        """Get visuospatial metrics plus current store sizes."""
        with self._lock:
            metrics = self._metrics.copy()
            metrics["visual_cache_size"] = len(self._visual_cache)
            metrics["spatial_index_size"] = len(self._spatial_index)
            return metrics

    def clear(self) -> None:
        """Clear the visuospatial sketchpad."""
        with self._lock:
            self._visual_cache.clear()
            self._spatial_index.clear()
            self._movement_traces.clear()
            logger.info("VisuospatialSketchpad cleared")


# =============================================================================
# EPISODIC BUFFER
# =============================================================================

class EpisodicBuffer:
    """
    Integrates information across modalities into coherent episodes.

    Serves as a binding interface between working memory subsystems
    and long-term memory. Maintains temporary multi-modal representations.

    Holds at most ``capacity`` episodes (oldest evicted first); pairwise
    binding strengths between component items live in a shared table and
    back the per-episode coherence score.
    """

    def __init__(
        self,
        capacity: int = 4,
        binding_threshold: float = 0.5,
        temporal_window: float = 30.0  # seconds
    ):
        self.capacity = capacity
        self.binding_threshold = binding_threshold
        self.temporal_window = temporal_window

        self._episodes: OrderedDict[str, Dict[str, Any]] = OrderedDict()
        # (id_a, id_b) sorted pair -> binding strength in [0, 1].
        self._binding_strength: Dict[Tuple[str, str], float] = {}
        self._temporal_context: deque = deque(maxlen=100)
        self._lock = threading.RLock()

        # Metrics
        self._metrics = {
            "episodes_created": 0,
            "bindings_formed": 0,
            "integrations": 0,
            "retrievals": 0
        }

        logger.info("EpisodicBuffer initialized")

    def create_episode(
        self,
        components: Dict[str, str],  # modality -> item_id mapping
        context: Optional[Dict[str, Any]] = None,
        label: str = ""
    ) -> ProcessingResult:
        """
        Create a new episode binding multiple memory components.

        Args:
            components: Mapping of modality names to item IDs
            context: Optional contextual information
            label: Optional episode label

        Returns:
            ProcessingResult with the episode ID
        """
        with self._lock:
            episode_id = str(uuid.uuid4())

            # Evict oldest episodes until there is room for the new one.
            while len(self._episodes) >= self.capacity:
                oldest_id = next(iter(self._episodes))
                self._remove_episode(oldest_id)

            episode = {
                "id": episode_id,
                "components": components,
                "context": context or {},
                # NOTE(review): the auto-label uses the current episode
                # count, which can repeat after evictions — confirm labels
                # need not be unique.
                "label": label or f"episode_{len(self._episodes)}",
                "created_at": time.time(),
                "last_accessed": time.time(),
                "access_count": 0,
                "coherence": 0.0  # filled in below, after bindings exist
            }

            self._episodes[episode_id] = episode

            # Bind every pair of components. Never weaken an existing
            # binding: keep the max of its current strength and the
            # threshold (previously this reset strengthened pairs).
            component_ids = list(components.values())
            for i, id1 in enumerate(component_ids):
                for id2 in component_ids[i+1:]:
                    binding_key = tuple(sorted([id1, id2]))
                    self._binding_strength[binding_key] = max(
                        self._binding_strength.get(binding_key, 0.0),
                        self.binding_threshold
                    )
                    self._metrics["bindings_formed"] += 1

            # Coherence must be computed after the bindings are formed;
            # computing it first always yielded 0 for fresh components.
            episode["coherence"] = self._calculate_coherence(components)

            # Add to temporal context
            self._temporal_context.append({
                "episode_id": episode_id,
                "timestamp": time.time()
            })

            self._metrics["episodes_created"] += 1

            logger.debug(f"Created episode {episode_id[:8]} with {len(components)} components")

            return ProcessingResult(
                success=True,
                operation="create_episode",
                item_ids=[episode_id],
                message=f"Episode created with {len(components)} bound components",
                metadata={"coherence": episode["coherence"]}
            )

    def integrate(
        self,
        item_id: str,
        modality: str,
        episode_id: Optional[str] = None
    ) -> ProcessingResult:
        """
        Integrate a new item into an existing or new episode.

        If no episode is specified, the most recent episode inside the
        temporal window is used; when none qualifies, a fresh episode is
        created for the item.

        Args:
            item_id: ID of item to integrate
            modality: Modality of the item
            episode_id: Optional specific episode to integrate into

        Returns:
            ProcessingResult indicating the integration result
        """
        with self._lock:
            if episode_id:
                if episode_id not in self._episodes:
                    return ProcessingResult(
                        success=False,
                        operation="integrate",
                        message=f"Episode {episode_id} not found"
                    )
                episode = self._episodes[episode_id]
            else:
                # Find most recent compatible episode within temporal window
                current_time = time.time()
                compatible_episode = None

                for ep_id, episode in reversed(self._episodes.items()):
                    if current_time - episode["created_at"] <= self.temporal_window:
                        compatible_episode = episode
                        break

                if not compatible_episode:
                    # No recent episode — start a new one around this item.
                    return self.create_episode({modality: item_id})

                episode = compatible_episode

            # Add (or replace) this modality's component in the episode.
            episode["components"][modality] = item_id
            episode["last_accessed"] = time.time()
            episode["coherence"] = self._calculate_coherence(episode["components"])

            # Strengthen bindings between the new item and existing components.
            for existing_id in episode["components"].values():
                if existing_id != item_id:
                    binding_key = tuple(sorted([existing_id, item_id]))
                    current = self._binding_strength.get(binding_key, 0.0)
                    self._binding_strength[binding_key] = min(1.0, current + 0.1)

            self._metrics["integrations"] += 1

            return ProcessingResult(
                success=True,
                operation="integrate",
                item_ids=[episode["id"]],
                message=f"Integrated {modality} into episode"
            )

    def retrieve_episode(self, episode_id: str) -> Optional[Dict[str, Any]]:
        """Retrieve an episode by ID.

        A hit refreshes its access metadata and moves it to the
        most-recent end of the buffer.
        """
        with self._lock:
            episode = self._episodes.get(episode_id)
            if episode:
                episode["last_accessed"] = time.time()
                episode["access_count"] += 1
                self._episodes.move_to_end(episode_id)
                self._metrics["retrievals"] += 1
            return episode

    def find_by_component(self, item_id: str) -> List[Dict[str, Any]]:
        """Find all episodes containing a specific component item."""
        with self._lock:
            matching = []
            for episode in self._episodes.values():
                if item_id in episode["components"].values():
                    matching.append(episode)
            return matching

    def get_binding_strength(self, id1: str, id2: str) -> float:
        """Get the binding strength between two items (0.0 if unbound)."""
        with self._lock:
            binding_key = tuple(sorted([id1, id2]))
            return self._binding_strength.get(binding_key, 0.0)

    def strengthen_binding(self, id1: str, id2: str, amount: float = 0.1) -> None:
        """Strengthen the binding between two items, capped at 1.0."""
        with self._lock:
            binding_key = tuple(sorted([id1, id2]))
            current = self._binding_strength.get(binding_key, 0.0)
            self._binding_strength[binding_key] = min(1.0, current + amount)

    def _calculate_coherence(self, components: Dict[str, str]) -> float:
        """Calculate episode coherence as the mean pairwise binding strength.

        Single-component (or empty) episodes are trivially coherent (1.0).
        """
        if len(components) < 2:
            return 1.0

        component_ids = list(components.values())
        total_strength = 0.0
        pair_count = 0

        for i, id1 in enumerate(component_ids):
            for id2 in component_ids[i+1:]:
                binding_key = tuple(sorted([id1, id2]))
                total_strength += self._binding_strength.get(binding_key, 0.0)
                pair_count += 1

        return total_strength / pair_count if pair_count > 0 else 0.0

    def _remove_episode(self, episode_id: str) -> None:
        """Remove an episode and its bindings.

        NOTE(review): any binding touching one of this episode's
        components is dropped, even if another live episode still uses
        that pair — confirm this cross-episode invalidation is intended.
        """
        if episode_id in self._episodes:
            episode = self._episodes[episode_id]
            component_ids = list(episode["components"].values())

            # Remove associated bindings
            keys_to_remove = []
            for key in self._binding_strength:
                if any(comp_id in key for comp_id in component_ids):
                    keys_to_remove.append(key)

            for key in keys_to_remove:
                del self._binding_strength[key]

            del self._episodes[episode_id]

    def get_recent_episodes(self, limit: int = 5) -> List[Dict[str, Any]]:
        """Get the most recent episodes (up to ``limit``).

        Guard against limit <= 0: a plain [-limit:] slice with limit=0
        would return the entire buffer instead of nothing.
        """
        with self._lock:
            if limit <= 0:
                return []
            return list(self._episodes.values())[-limit:]

    def get_metrics(self) -> Dict[str, Any]:
        """Get episodic buffer metrics plus current store sizes."""
        with self._lock:
            metrics = self._metrics.copy()
            metrics["active_episodes"] = len(self._episodes)
            metrics["active_bindings"] = len(self._binding_strength)
            return metrics

    def clear(self) -> None:
        """Clear the episodic buffer, bindings, and temporal context."""
        with self._lock:
            self._episodes.clear()
            self._binding_strength.clear()
            self._temporal_context.clear()
            logger.info("EpisodicBuffer cleared")


# =============================================================================
# CENTRAL EXECUTIVE
# =============================================================================

class CentralExecutive:
    """
    Coordinates and controls all working memory operations.

    Acts as the supervisory system that manages attention,
    coordinates between subsystems, and controls information flow.
    Implements executive functions like task switching, inhibition,
    and updating.

    All public methods are serialized by a reentrant lock. A daemon
    thread performs periodic maintenance (decay, rehearsal, attention
    sustain) by calling coordinate_subsystems() every ~100ms.
    """

    def __init__(
        self,
        buffer: WorkingMemoryBuffer,
        attention: AttentionMechanism,
        phonological: PhonologicalLoop,
        visuospatial: VisuospatialSketchpad,
        episodic: EpisodicBuffer
    ):
        self.buffer = buffer
        self.attention = attention
        self.phonological = phonological
        self.visuospatial = visuospatial
        self.episodic = episodic

        # Task bookkeeping: only the 10 most recently suspended tasks are
        # retained (deque maxlen silently drops the oldest).
        self._task_stack: deque = deque(maxlen=10)
        self._current_task: Optional[Dict[str, Any]] = None
        self._goal_stack: List[Dict[str, Any]] = []
        self._inhibition_set: Set[str] = set()
        # Heap entries are (-priority, seq, name, callable). The monotonic
        # seq breaks priority/name ties so heapq never falls through to
        # comparing the callables themselves (which raises TypeError).
        self._processing_queue: List[Tuple[float, int, str, Callable]] = []
        self._op_seq = 0  # monotonic tie-breaker for the heap above
        self._lock = threading.RLock()

        # Executive metrics
        self._metrics = {
            "task_switches": 0,
            "operations_executed": 0,
            "inhibitions_applied": 0,
            "coordination_events": 0
        }

        # Background processing thread (daemon: never blocks interpreter exit)
        self._running = True
        self._process_thread = threading.Thread(
            target=self._background_processor,
            daemon=True
        )
        self._process_thread.start()

        logger.info("CentralExecutive initialized")

    def encode(
        self,
        content: Any,
        modality: ModalityType = ModalityType.SEMANTIC,
        priority: ProcessingPriority = ProcessingPriority.NORMAL,
        metadata: Optional[Dict[str, Any]] = None
    ) -> ProcessingResult:
        """
        Encode new information into working memory with full coordination.

        Routes information to appropriate subsystems based on modality.

        Args:
            content: Content to encode
            modality: Type of information
            priority: Processing priority
            metadata: Optional metadata

        Returns:
            ProcessingResult with encoding results
        """
        with self._lock:
            # Add to main buffer first; subsystem routing keys off its ID.
            result = self.buffer.add(content, modality, priority, metadata)

            if not result.success:
                return result

            item_id = result.item_ids[0]

            # Route to appropriate subsystem. Non-matching content types
            # (e.g. non-str VERBAL) stay in the main buffer only.
            if modality in (ModalityType.VERBAL, ModalityType.AUDITORY):
                if isinstance(content, str):
                    # priority.value / 100 assumes priorities on a 0-100
                    # scale — TODO confirm against ProcessingPriority.
                    self.phonological.encode(content, item_id, priority.value / 100)

            elif modality == ModalityType.VISUAL:
                if isinstance(content, dict):
                    self.visuospatial.encode_visual(item_id, content)

            # Shift attention to the newly encoded item.
            self.attention.focus_on([item_id], intensity=priority.value / 100)

            self._metrics["operations_executed"] += 1

            return ProcessingResult(
                success=True,
                operation="encode",
                item_ids=[item_id],
                message=f"Encoded {modality.name} content",
                metadata={"modality": modality.name}
            )

    def retrieve(
        self,
        query: Optional[str] = None,
        modality: Optional[ModalityType] = None,
        top_k: int = 5
    ) -> List[MemoryItem]:
        """
        Retrieve items from working memory with attention filtering.

        Args:
            query: Optional text query for filtering (substring match,
                case-insensitive; only string contents can match)
            modality: Optional modality filter
            top_k: Maximum items to return

        Returns:
            List of matching memory items, highest effective priority first
        """
        with self._lock:
            # Get all items from buffer (decay applied on read)
            items = self.buffer.retrieve_all(apply_decay=True)

            # Filter by modality
            if modality:
                items = [i for i in items if i.modality == modality]

            # Filter by query (non-string contents never match a query)
            if query and isinstance(query, str):
                query_lower = query.lower()
                items = [
                    i for i in items
                    if isinstance(i.content, str) and query_lower in i.content.lower()
                ]

            # Apply attention filter
            filtered = self.attention.filter_by_attention(items)

            # Sort by effective priority and return top_k
            sorted_items = sorted(
                filtered,
                key=lambda x: x.get_effective_priority(),
                reverse=True
            )

            self._metrics["operations_executed"] += 1

            return sorted_items[:top_k]

    def update_item(
        self,
        item_id: str,
        updates: Dict[str, Any]
    ) -> ProcessingResult:
        """
        Update an existing memory item.

        Args:
            item_id: ID of item to update
            updates: Properties to update

        Returns:
            ProcessingResult indicating success/failure
        """
        with self._lock:
            result = self.buffer.update(item_id, updates)

            if result.success:
                # Re-focus attention on updated item
                self.attention.focus_on([item_id])
                self._metrics["operations_executed"] += 1

            return result

    def switch_task(
        self,
        new_task: Dict[str, Any]
    ) -> ProcessingResult:
        """
        Switch to a new task, saving current task state.

        Args:
            new_task: New task specification

        Returns:
            ProcessingResult indicating switch result
        """
        with self._lock:
            # Suspend the current task, capturing its attention state so
            # it can be restored by resume_previous_task().
            if self._current_task:
                self._current_task["suspended_at"] = time.time()
                self._current_task["attention_state"] = self.attention.get_current_focus()
                self._task_stack.append(self._current_task)

            # Set new task
            self._current_task = {
                "task": new_task,
                "started_at": time.time(),
                "status": "active"
            }

            # Clear attention for new task
            self.attention.reset()

            self._metrics["task_switches"] += 1

            logger.debug(f"Switched to new task: {new_task.get('name', 'unnamed')}")

            return ProcessingResult(
                success=True,
                operation="switch_task",
                message="Task switched successfully"
            )

    def resume_previous_task(self) -> ProcessingResult:
        """Resume the most recently suspended task.

        The task that was active at the time of the call (if any) is
        suspended and pushed back onto the task stack — previously it
        was silently discarded — so it can itself be resumed later.
        """
        with self._lock:
            if not self._task_stack:
                return ProcessingResult(
                    success=False,
                    operation="resume_task",
                    message="No suspended tasks to resume"
                )

            # Pop the task to resume BEFORE pushing the current one, or
            # we would immediately get the current task back.
            previous = self._task_stack.pop()

            # Suspend (not drop) the currently active task.
            if self._current_task:
                self._current_task["suspended_at"] = time.time()
                self._current_task["attention_state"] = self.attention.get_current_focus()
                self._task_stack.append(self._current_task)

            # Restore previous task
            self._current_task = previous
            self._current_task["resumed_at"] = time.time()
            self._current_task["status"] = "resumed"

            # Restore the attention state captured when it was suspended
            if "attention_state" in self._current_task:
                focus = self._current_task["attention_state"]
                self.attention.focus_on(
                    focus.target_ids,
                    focus.state,
                    focus.intensity
                )

            self._metrics["task_switches"] += 1

            return ProcessingResult(
                success=True,
                operation="resume_task",
                message="Previous task resumed"
            )

    def set_goal(
        self,
        goal: Dict[str, Any],
        priority: ProcessingPriority = ProcessingPriority.NORMAL
    ) -> ProcessingResult:
        """
        Set a goal that guides processing.

        Goals are kept sorted by descending priority; ties keep
        insertion order (new goal goes after existing equals).

        Args:
            goal: Goal specification
            priority: Goal priority

        Returns:
            ProcessingResult indicating success
        """
        with self._lock:
            goal_entry = {
                "goal": goal,
                "priority": priority,
                "created_at": time.time(),
                "status": "active"
            }

            # Insert before the first strictly lower-priority goal,
            # or append at the end if there is none.
            insert_idx = next(
                (
                    i for i, existing in enumerate(self._goal_stack)
                    if existing["priority"].value < priority.value
                ),
                len(self._goal_stack),
            )

            self._goal_stack.insert(insert_idx, goal_entry)

            return ProcessingResult(
                success=True,
                operation="set_goal",
                message=f"Goal set with priority {priority.name}"
            )

    def inhibit(
        self,
        item_ids: List[str],
        duration: float = 5.0
    ) -> ProcessingResult:
        """
        Inhibit specific items from processing.

        Args:
            item_ids: IDs of items to inhibit
            duration: Inhibition duration in seconds

        Returns:
            ProcessingResult indicating success
        """
        with self._lock:
            self.attention.inhibit(item_ids, duration)
            self._inhibition_set.update(item_ids)
            self._metrics["inhibitions_applied"] += len(item_ids)

            # Schedule removal on a throwaway daemon thread; the sleep
            # happens outside the lock, only the cleanup re-acquires it.
            def clear_inhibition():
                time.sleep(duration)
                with self._lock:
                    self._inhibition_set.difference_update(item_ids)

            threading.Thread(target=clear_inhibition, daemon=True).start()

            return ProcessingResult(
                success=True,
                operation="inhibit",
                item_ids=item_ids,
                message=f"Inhibited {len(item_ids)} items for {duration}s"
            )

    def coordinate_subsystems(self) -> ProcessingResult:
        """
        Coordinate processing across all subsystems.

        Performs maintenance operations:
        - Apply decay
        - Check rehearsal needs
        - Update bindings
        - Garbage collection
        """
        with self._lock:
            # Decay and cleanup in phonological loop
            decayed_phonological = self.phonological.check_decay()
            self.phonological.rehearse()

            # Decay in visuospatial
            decayed_visual = self.visuospatial.apply_decay()

            # Sustain attention
            self.attention.sustain()

            # Create episode from active items
            focused_ids = self.attention.get_focused_ids()
            if len(focused_ids) >= 2:
                items = [self.buffer.retrieve(id) for id in focused_ids]
                valid_items = [i for i in items if i is not None]

                if valid_items:
                    components = {
                        item.modality.name.lower(): item.id
                        for item in valid_items
                    }
                    # NOTE(review): only the first (component_id, modality)
                    # pair is passed to integrate(); presumably integrate()
                    # accumulates across calls — confirm whether every
                    # component should be integrated here instead.
                    self.episodic.integrate(
                        list(components.values())[0],
                        list(components.keys())[0]
                    )

            self._metrics["coordination_events"] += 1

            return ProcessingResult(
                success=True,
                operation="coordinate",
                message=f"Coordination complete. Decayed: {len(decayed_phonological)} phonological, {len(decayed_visual)} visual"
            )

    def _background_processor(self) -> None:
        """Background thread: drain scheduled operations and run maintenance."""
        while self._running:
            try:
                # Process queue items (highest scheduled priority first)
                with self._lock:
                    if self._processing_queue:
                        _, _, name, func = heapq.heappop(self._processing_queue)
                        try:
                            func()
                        except Exception as e:
                            logger.error(f"Background processing error in {name}: {e}")

                # Periodic coordination
                self.coordinate_subsystems()

                time.sleep(0.1)  # 100ms cycle

            except Exception as e:
                # Keep the maintenance loop alive; back off briefly.
                logger.error(f"Background processor error: {e}")
                time.sleep(1.0)

    def schedule_operation(
        self,
        name: str,
        operation: Callable,
        priority: float = 0.5
    ) -> None:
        """Schedule an operation for background processing.

        Higher priority runs first; equal priorities run in FIFO order
        thanks to the monotonic sequence number (which also prevents
        heapq from ever comparing the callables on a tie).
        """
        with self._lock:
            self._op_seq += 1
            heapq.heappush(
                self._processing_queue,
                (-priority, self._op_seq, name, operation)
            )

    def get_state(self) -> Dict[str, Any]:
        """Get the current state of the central executive."""
        with self._lock:
            return {
                "current_task": self._current_task,
                "task_stack_depth": len(self._task_stack),
                "goal_count": len(self._goal_stack),
                "inhibited_items": len(self._inhibition_set),
                "pending_operations": len(self._processing_queue),
                "attention": {
                    "state": self.attention.get_current_focus().state.name,
                    "intensity": self.attention.get_current_focus().intensity,
                    "focused_count": len(self.attention.get_focused_ids())
                }
            }

    def get_metrics(self) -> Dict[str, Any]:
        """Get executive metrics plus subsystem metrics."""
        with self._lock:
            metrics = {
                "executive": self._metrics.copy(),
                "buffer": self.buffer.get_metrics(),
                "attention": self.attention.get_metrics(),
                "phonological": self.phonological.get_metrics(),
                "visuospatial": self.visuospatial.get_metrics(),
                "episodic": self.episodic.get_metrics()
            }
            return metrics

    def shutdown(self) -> None:
        """Shutdown the central executive and stop background processing."""
        self._running = False
        # Wait for the worker to notice the flag; it sleeps in 100ms
        # cycles, so this normally returns quickly. The timeout keeps
        # shutdown bounded even if a maintenance call blocks.
        if self._process_thread.is_alive():
            self._process_thread.join(timeout=2.0)
        logger.info("CentralExecutive shutdown initiated")

    def clear_all(self) -> None:
        """Clear all subsystems."""
        with self._lock:
            self.buffer.clear()
            self.attention.reset()
            self.phonological.clear()
            self.visuospatial.clear()
            self.episodic.clear()
            self._task_stack.clear()
            self._current_task = None
            self._goal_stack.clear()
            self._inhibition_set.clear()
            self._processing_queue.clear()
            self._op_seq = 0
            logger.info("All subsystems cleared")


# =============================================================================
# ADVANCED WORKING MEMORY SYSTEM (FACADE)
# =============================================================================

class AdvancedWorkingMemory:
    """
    Complete advanced working memory system for AIVA Queen.

    Provides a unified interface to the cognitive architecture
    including attention, storage, verbal processing, visual processing,
    and episodic binding.

    Usage:
        memory = AdvancedWorkingMemory()

        # Store information
        memory.store("The meeting is at 3pm", modality=ModalityType.VERBAL)

        # Retrieve with filtering
        items = memory.retrieve(query="meeting")

        # Create multi-modal episode
        memory.create_episode({
            "verbal": "Project presentation",
            "visual": {"slides": 10, "charts": 3}
        })
    """

    def __init__(
        self,
        buffer_capacity: int = 7,
        enable_chunking: bool = True,
        decay_rate: float = DEFAULT_DECAY_RATE
    ):
        """
        Initialize the advanced working memory system.

        Args:
            buffer_capacity: Base capacity (7+/-2 per Miller's Law)
            enable_chunking: Enable chunking for capacity expansion
            decay_rate: Memory decay rate
        """
        # Initialize subsystems first; the executive below takes
        # references to all of them and starts its background thread.
        self._buffer = WorkingMemoryBuffer(
            capacity=buffer_capacity,
            decay_rate=decay_rate,
            enable_chunking=enable_chunking
        )

        self._attention = AttentionMechanism()
        self._phonological = PhonologicalLoop()
        self._visuospatial = VisuospatialSketchpad()
        self._episodic = EpisodicBuffer()

        # Initialize central executive (coordinator)
        self._executive = CentralExecutive(
            buffer=self._buffer,
            attention=self._attention,
            phonological=self._phonological,
            visuospatial=self._visuospatial,
            episodic=self._episodic
        )

        # Facade-level lock; only export_state/import_state use it —
        # other methods rely on the subsystems' own locks.
        self._lock = threading.RLock()

        logger.info("AdvancedWorkingMemory system initialized")

    # =========================================================================
    # PRIMARY INTERFACE
    # =========================================================================

    def store(
        self,
        content: Any,
        modality: ModalityType = ModalityType.SEMANTIC,
        priority: ProcessingPriority = ProcessingPriority.NORMAL,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Store information in working memory.

        Args:
            content: Content to store
            modality: Type of information
            priority: Processing priority
            metadata: Optional metadata

        Returns:
            ID of the stored item, or "" if encoding failed
            (e.g. the buffer rejected the item).
        """
        result = self._executive.encode(content, modality, priority, metadata)
        return result.item_ids[0] if result.success else ""

    def retrieve(
        self,
        query: Optional[str] = None,
        modality: Optional[ModalityType] = None,
        top_k: int = 5
    ) -> List[MemoryItem]:
        """
        Retrieve items from working memory.

        Args:
            query: Optional search query
            modality: Optional modality filter
            top_k: Maximum items to return

        Returns:
            List of matching memory items
        """
        return self._executive.retrieve(query, modality, top_k)

    def update(self, item_id: str, updates: Dict[str, Any]) -> bool:
        """
        Update an existing memory item.

        Args:
            item_id: ID of item to update
            updates: Properties to update

        Returns:
            True if successful, False otherwise
        """
        result = self._executive.update_item(item_id, updates)
        return result.success

    def remove(self, item_id: str) -> bool:
        """
        Remove an item from working memory.

        NOTE(review): removes from the central buffer only; any traces
        of this id in the phonological loop, visuospatial sketchpad or
        episodic buffer are left behind (presumably to decay) — confirm
        this is intended.

        Args:
            item_id: ID of item to remove

        Returns:
            True if successful, False otherwise
        """
        result = self._buffer.remove(item_id)
        return result.success

    def focus(self, item_ids: List[str]) -> None:
        """Direct attention to specific items."""
        self._attention.focus_on(item_ids)

    def chunk(self, item_ids: List[str], label: str = "") -> Optional[str]:
        """
        Chunk items together to save capacity.

        Args:
            item_ids: IDs of items to chunk
            label: Optional chunk label

        Returns:
            Chunk ID if successful, None otherwise
        """
        result = self._buffer.chunk(item_ids, label)
        return result.item_ids[0] if result.success else None

    def create_episode(
        self,
        components: Dict[str, Any],
        label: str = ""
    ) -> Optional[str]:
        """
        Create a multi-modal episode.

        Args:
            components: Mapping of modality names to content
            label: Optional episode label

        Returns:
            Episode ID if successful, None otherwise
        """
        # First, store each component. Unknown modality names fall back
        # to SEMANTIC; components whose store() fails are skipped
        # silently (best-effort).
        stored_components: Dict[str, str] = {}

        for modality_name, content in components.items():
            try:
                modality = ModalityType[modality_name.upper()]
            except KeyError:
                modality = ModalityType.SEMANTIC

            item_id = self.store(content, modality=modality)
            if item_id:
                stored_components[modality_name] = item_id

        # Create episode
        result = self._episodic.create_episode(stored_components, label=label)
        return result.item_ids[0] if result.success else None

    # =========================================================================
    # TASK MANAGEMENT
    # =========================================================================

    def switch_task(self, task_name: str, task_data: Optional[Dict] = None) -> bool:
        """Switch to a new task."""
        result = self._executive.switch_task({
            "name": task_name,
            "data": task_data or {}
        })
        return result.success

    def resume_task(self) -> bool:
        """Resume the previous task."""
        result = self._executive.resume_previous_task()
        return result.success

    def set_goal(
        self,
        goal_description: str,
        priority: ProcessingPriority = ProcessingPriority.NORMAL
    ) -> bool:
        """Set a processing goal."""
        result = self._executive.set_goal(
            {"description": goal_description},
            priority
        )
        return result.success

    # =========================================================================
    # SPECIALIZED PROCESSING
    # =========================================================================

    def process_verbal(
        self,
        text: str,
        priority: float = 0.5
    ) -> str:
        """
        Process verbal/textual information through the phonological loop.

        NOTE(review): store() with VERBAL modality already routes string
        content into the phonological loop via the executive; the encode
        below is a second call for the same item_id with the explicit
        priority. Presumably PhonologicalLoop.encode upserts, making the
        explicit priority win — confirm against its implementation.

        Args:
            text: Text to process
            priority: Rehearsal priority

        Returns:
            Item ID
        """
        item_id = self.store(text, modality=ModalityType.VERBAL)
        self._phonological.encode(text, item_id, priority)
        return item_id

    def process_visual(
        self,
        visual_data: Dict[str, Any],
        location: Optional[Tuple[float, float, float]] = None
    ) -> str:
        """
        Process visual information through the visuospatial sketchpad.

        NOTE(review): store() with VISUAL modality already encodes dict
        content into the sketchpad (without a location); the call below
        re-encodes with the explicit location. Presumably encode_visual
        upserts by item_id — confirm.

        Args:
            visual_data: Visual representation
            location: Optional spatial location

        Returns:
            Item ID
        """
        item_id = self.store(visual_data, modality=ModalityType.VISUAL)
        self._visuospatial.encode_visual(item_id, visual_data, location)
        return item_id

    def rehearse(self, item_id: Optional[str] = None) -> bool:
        """
        Trigger rehearsal to maintain items in memory.

        Args:
            item_id: Specific item to rehearse, or None for automatic

        Returns:
            True if rehearsal occurred
        """
        result = self._phonological.rehearse(item_id)
        return result.success

    def find_nearby(
        self,
        location: Tuple[float, float, float],
        radius: float = 0.2
    ) -> List[str]:
        """Find items near a spatial location."""
        return self._visuospatial.find_nearby(location, radius)

    # =========================================================================
    # ATTENTION CONTROL
    # =========================================================================

    def inhibit(self, item_ids: List[str], duration: float = 5.0) -> None:
        """Temporarily inhibit items from processing."""
        self._executive.inhibit(item_ids, duration)

    def get_focused_items(self) -> List[MemoryItem]:
        """Get currently focused items.

        Retrieves up to 100 items and intersects with the focused id
        set; focused items beyond the top 100 would be missed.
        """
        focused_ids = self._attention.get_focused_ids()
        return [
            item for item in self.retrieve(top_k=100)
            if item.id in focused_ids
        ]

    # =========================================================================
    # STATUS AND METRICS
    # =========================================================================

    def get_capacity_status(self) -> Dict[str, Any]:
        """Get current capacity status."""
        return self._buffer.get_capacity_status()

    def get_state(self) -> Dict[str, Any]:
        """Get complete system state.

        NOTE(review): reaches into subsystem privates (_visual_cache,
        _episodes) for the size counts — consider public accessors.
        """
        return {
            "executive": self._executive.get_state(),
            "capacity": self._buffer.get_capacity_status(),
            "attention": {
                "state": self._attention.get_current_focus().state.name,
                "focused_count": len(self._attention.get_focused_ids())
            },
            "phonological_store_size": len(self._phonological.get_store_contents()),
            "visuospatial_cache_size": len(self._visuospatial._visual_cache),
            "active_episodes": len(self._episodic._episodes)
        }

    def get_metrics(self) -> Dict[str, Any]:
        """Get comprehensive system metrics."""
        return self._executive.get_metrics()

    # =========================================================================
    # SERIALIZATION
    # =========================================================================

    def export_state(self) -> Dict[str, Any]:
        """Export the complete memory state for persistence.

        Covers buffer items, chunks and episodes; attention state,
        phonological/visuospatial contents and binding strengths are
        NOT exported and will be lost on a round-trip.
        """
        with self._lock:
            return {
                "buffer_items": [
                    item.to_dict() for item in self._buffer.retrieve_all(apply_decay=False)
                ],
                "chunks": [
                    {
                        "id": chunk.id,
                        "item_ids": chunk.item_ids,
                        "label": chunk.label,
                        "coherence": chunk.coherence
                    }
                    for chunk in self._buffer._chunks.values()
                ],
                "episodes": [
                    {
                        "id": ep["id"],
                        "components": ep["components"],
                        "context": ep["context"],
                        "label": ep["label"],
                        "coherence": ep["coherence"]
                    }
                    for ep in self._episodic._episodes.values()
                ],
                "exported_at": time.time()
            }

    def import_state(self, state: Dict[str, Any]) -> bool:
        """Import a previously exported memory state.

        Clears current contents first; on failure the system may be
        left partially restored (the partial data is not rolled back).

        Returns:
            True on success, False on any error (logged).
        """
        with self._lock:
            try:
                # Clear current state
                self.clear()

                # Restore buffer items directly into the buffer's store
                # (bypasses capacity checks and attention routing).
                for item_dict in state.get("buffer_items", []):
                    item = MemoryItem.from_dict(item_dict)
                    self._buffer._items[item.id] = item

                # Restore chunks
                for chunk_data in state.get("chunks", []):
                    chunk = Chunk(
                        id=chunk_data["id"],
                        item_ids=chunk_data["item_ids"],
                        label=chunk_data["label"],
                        coherence=chunk_data["coherence"]
                    )
                    self._buffer._chunks[chunk.id] = chunk

                # Restore episodes (timestamps/access counts are reset,
                # not restored from the export).
                for ep_data in state.get("episodes", []):
                    self._episodic._episodes[ep_data["id"]] = {
                        "id": ep_data["id"],
                        "components": ep_data["components"],
                        "context": ep_data["context"],
                        "label": ep_data["label"],
                        "coherence": ep_data["coherence"],
                        "created_at": time.time(),
                        "last_accessed": time.time(),
                        "access_count": 0
                    }

                logger.info("Memory state imported successfully")
                return True

            except Exception as e:
                # Boundary handler: log and report failure to the caller.
                logger.error(f"Failed to import memory state: {e}")
                return False

    def clear(self) -> None:
        """Clear all memory contents."""
        self._executive.clear_all()

    def shutdown(self) -> None:
        """Shutdown the memory system."""
        self._executive.shutdown()
        logger.info("AdvancedWorkingMemory shutdown complete")


# =============================================================================
# EXAMPLE USAGE AND TESTING
# =============================================================================

def _run_demo() -> None:
    """Exercise the AdvancedWorkingMemory facade end to end."""
    banner = "=" * 70
    print(banner)
    print("AIVA Queen - Advanced Working Memory System Test")
    print(banner)

    # Initialize the system
    memory = AdvancedWorkingMemory(
        buffer_capacity=7,
        enable_chunking=True
    )

    print("\n[1] Storing verbal information...")
    id1 = memory.process_verbal("The quarterly meeting is scheduled for 3pm tomorrow")
    id2 = memory.process_verbal("Project Alpha deadline is next Friday")
    id3 = memory.store(
        "Budget approval pending from finance department",
        modality=ModalityType.VERBAL,
        priority=ProcessingPriority.HIGH
    )
    print(f"    Stored items: {id1[:8]}, {id2[:8]}, {id3[:8]}")

    print("\n[2] Storing visual information...")
    chart = {
        "type": "chart",
        "chart_type": "bar",
        "data_points": 12,
        "colors": ["blue", "green", "red"]
    }
    visual_id = memory.process_visual(chart, location=(0.5, 0.5, 0.0))
    print(f"    Stored visual: {visual_id[:8]}")

    print("\n[3] Checking capacity status...")
    status = memory.get_capacity_status()
    print(f"    Capacity: {status['effective_count']}/{status['capacity']}")
    print(f"    Utilization: {status['utilization']:.1%}")

    print("\n[4] Testing chunking...")
    chunk_id = memory.chunk([id1, id2], label="meeting_info")
    if chunk_id:
        print(f"    Created chunk: {chunk_id[:8]}")
        status = memory.get_capacity_status()
        print(f"    New effective count: {status['effective_count']}")

    print("\n[5] Retrieving information...")
    items = memory.retrieve(query="meeting")
    print(f"    Found {len(items)} items matching 'meeting':")
    for item in items:
        text = str(item.content)
        content_preview = text[:50] + "..." if len(text) > 50 else text
        print(f"      - {content_preview}")

    print("\n[6] Creating multi-modal episode...")
    episode_id = memory.create_episode(
        {
            "verbal": "Project presentation slides ready",
            "visual": {"slides": 15, "animations": 5}
        },
        label="presentation_prep"
    )
    if episode_id:
        print(f"    Episode created: {episode_id[:8]}")

    print("\n[7] Task switching test...")
    memory.switch_task("email_review")
    memory.store("Reply to John about budget", priority=ProcessingPriority.HIGH)
    print("    Switched to 'email_review' task")

    memory.resume_task()
    print("    Resumed previous task")

    print("\n[8] Rehearsal test...")
    memory.rehearse()
    print("    Performed rehearsal cycle")

    print("\n[9] System state...")
    state = memory.get_state()
    print(f"    Executive state: {state['executive']['current_task'] is not None}")
    print(f"    Attention: {state['attention']['state']}")
    print(f"    Phonological store: {state['phonological_store_size']} items")
    print(f"    Active episodes: {state['active_episodes']}")

    print("\n[10] Metrics...")
    metrics = memory.get_metrics()
    print(f"    Buffer - Items added: {metrics['buffer']['items_added']}")
    print(f"    Attention - Focus shifts: {metrics['attention']['focus_shifts']}")
    print(f"    Phonological - Rehearsals: {metrics['phonological']['rehearsals_performed']}")
    print(f"    Episodes created: {metrics['episodic']['episodes_created']}")

    print("\n[11] Export/Import test...")
    exported = memory.export_state()
    print(f"    Exported {len(exported['buffer_items'])} items")

    memory.clear()
    print("    Cleared memory")

    memory.import_state(exported)
    print("    Imported state")

    items = memory.retrieve(top_k=10)
    print(f"    Restored items: {len(items)}")

    print("\n[12] Shutting down...")
    memory.shutdown()
    print("    Shutdown complete")

    print("\n" + banner)
    print("Advanced Working Memory System Test Complete")
    print(banner)


if __name__ == "__main__":
    _run_demo()
