"""
AIVA Queen Continual Learning System
=====================================

A production-grade implementation of continual/lifelong learning mechanisms
to prevent catastrophic forgetting and enable sustainable knowledge acquisition.

Components:
1. CatastrophicForgettingPreventer - Master controller for forgetting prevention
2. ElasticWeightConsolidation - EWC implementation for weight importance
3. ProgressiveNetworks - Dynamic network capacity expansion
4. ExperienceReplay - Strategic replay of past experiences
5. TaskInferencer - Infer current task context
6. CapacityManager - Manage and allocate model capacity

Author: Genesis AIVA Queen
Version: 1.0.0
"""

import json
import time
import math
import random
import hashlib
import logging
import threading
from abc import ABC, abstractmethod
from enum import Enum, auto
from dataclasses import dataclass, field, asdict
from typing import (
    Dict, List, Optional, Any, Tuple, Callable, Set,
    Union, TypeVar, Generic, Iterator
)
from collections import defaultdict, deque
from datetime import datetime, timedelta
from pathlib import Path
import heapq
import uuid

# Module-wide logger. logging.basicConfig() does nothing when the root logger
# already has handlers, so an application that configured logging before
# importing this module keeps its own setup.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("ContinualLearning")


# ==============================================================================
# Data Models and Enums
# ==============================================================================

class TaskType(Enum):
    """
    Kinds of learning task.

    Members take ``auto()`` values (1, 2, ... in declaration order).
    ``UNKNOWN`` is the default task type of an ``Experience``.
    """

    CLASSIFICATION = auto()
    REGRESSION = auto()
    GENERATION = auto()
    REASONING = auto()
    RETRIEVAL = auto()
    CONVERSATION = auto()
    TOOL_USE = auto()
    CODE_GENERATION = auto()
    # Fallback when the kind of task is not known.
    UNKNOWN = auto()


class ConsolidationStrategy(Enum):
    """
    Memory-consolidation strategies, keyed by a descriptive string value.

    The string values allow lookup by name, e.g.
    ``ConsolidationStrategy("experience_replay")``.
    """

    EWC = "elastic_weight_consolidation"
    SI = "synaptic_intelligence"
    MAS = "memory_aware_synapses"
    PROGRESSIVE = "progressive_networks"
    REPLAY = "experience_replay"
    # Combination of several of the strategies above.
    HYBRID = "hybrid_approach"


class CapacityState(Enum):
    """
    Lifecycle state of a ``CapacityBlock``.

    Members take ``auto()`` values (1, 2, ... in declaration order);
    ``AVAILABLE`` is the default state of a new ``CapacityBlock``.
    """

    AVAILABLE = auto()
    ALLOCATED = auto()
    RESERVED = auto()
    FROZEN = auto()
    DEPRECATED = auto()


@dataclass
class Experience:
    """
    A single learning experience, e.g. one entry of a replay buffer.

    Attributes:
        experience_id: Unique identifier (a random UUID4 string by default).
        task_id: Identifier of the task that produced the experience.
        task_type: Kind of task; ``TaskType.UNKNOWN`` when not given.
        input_data: Input of arbitrary type.
        output_data: Output/target of arbitrary type.
        context: Free-form context dictionary.
        importance: Base weight used when prioritising replay (default 0.5).
        timestamp: Creation time as returned by ``time.time()``.
        replay_count: Number of times the experience has been replayed.
        success_rate: Success rate recorded for the experience (default 1.0).
        metadata: Free-form extra information.
    """
    experience_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    task_id: str = ""
    task_type: TaskType = TaskType.UNKNOWN
    input_data: Any = None
    output_data: Any = None
    context: Dict[str, Any] = field(default_factory=dict)
    importance: float = 0.5
    timestamp: float = field(default_factory=time.time)
    replay_count: int = 0
    success_rate: float = 1.0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a plain dictionary.

        Fields are copied recursively (``dataclasses.asdict``) and
        ``task_type`` is stored as the member *name* (e.g. ``"REASONING"``).
        """
        result = asdict(self)
        result['task_type'] = self.task_type.name
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Experience':
        """
        Create an Experience from a dictionary such as ``to_dict`` produces.

        ``task_type`` may be a member name (``"REASONING"``), a ``TaskType``
        member, or missing/``None`` (treated as ``TaskType.UNKNOWN``).
        The input mapping is not modified.

        Raises:
            KeyError: If ``task_type`` is a string that names no member.
            TypeError: If ``data`` contains keys that are not fields.
        """
        data = dict(data)
        raw_type = data.get('task_type')
        if isinstance(raw_type, TaskType):
            # Already a member, e.g. data built with dataclasses.asdict().
            data['task_type'] = raw_type
        elif raw_type is None:
            data['task_type'] = TaskType.UNKNOWN
        else:
            data['task_type'] = TaskType[raw_type]
        return cls(**data)


@dataclass
class TaskDescriptor:
    """
    Static description of a learning task.

    Attributes:
        task_id: Unique task identifier.
        task_type: Kind of task.
        name: Human-readable task name.
        description: Optional free-text description.
        input_schema: Mapping describing the task's inputs. Its keys feed
            task-signature matching.
        output_schema: Mapping describing the task's outputs. Its keys feed
            task-signature matching.
        priority: Integer priority (default 5). The scale's meaning is
            defined by callers; TODO confirm.
        created_at: Creation time as returned by ``time.time()``.
        sample_count: Counter of samples seen (starts at 0).
        success_rate: Recorded success rate (starts at 0.0).
        is_active: Whether the task is active (default True).
        dependencies: List of identifiers. Presumably ids of prerequisite
            tasks; verify against callers.
        metadata: Free-form extra information.
    """
    # Required fields (no defaults) come first, as dataclasses demand.
    task_id: str
    task_type: TaskType
    name: str
    # Optional descriptive fields.
    description: str = ""
    input_schema: Dict[str, Any] = field(default_factory=dict)
    output_schema: Dict[str, Any] = field(default_factory=dict)
    priority: int = 5
    created_at: float = field(default_factory=time.time)
    # Running bookkeeping.
    sample_count: int = 0
    success_rate: float = 0.0
    is_active: bool = True
    dependencies: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class WeightImportance:
    """
    Consolidated EWC importance record for one parameter.

    Attributes:
        weight_id: Identifier of the parameter.
        fisher_information: Consolidated Fisher importance (starts at 0.0).
        optimal_value: Anchor value the EWC penalty pulls the parameter toward.
        task_contributions: Fisher value contributed by each task, keyed by
            task id.
        last_updated: Time of the last update, from ``time.time()``.
        frozen: Set when the weight is marked as critical.
    """
    weight_id: str
    fisher_information: float = 0.0
    optimal_value: float = 0.0
    # A fresh dict per instance (never shared between records).
    task_contributions: Dict[str, float] = field(default_factory=dict)
    last_updated: float = field(default_factory=time.time)
    frozen: bool = False


@dataclass
class CapacityBlock:
    """
    A block of model capacity and its allocation state.

    Attributes:
        block_id: Identifier of the block.
        size: Size of the block. Units are defined by the capacity manager;
            TODO confirm.
        state: Allocation state, ``CapacityState.AVAILABLE`` initially.
        assigned_task: Id of the task holding the block, or None.
        utilization: Utilization figure (starts at 0.0). Presumably a
            fraction; verify against callers.
        created_at: Creation time as returned by ``time.time()``.
        frozen_at: Time the block was frozen, or None if it never was.
        performance_score: Performance figure (starts at 0.0).
    """
    block_id: str
    size: int
    state: CapacityState = CapacityState.AVAILABLE
    assigned_task: Optional[str] = None
    utilization: float = 0.0
    created_at: float = field(default_factory=time.time)
    frozen_at: Optional[float] = None
    performance_score: float = 0.0


@dataclass
class NetworkColumn:
    """
    One column of a progressive network, dedicated to a single task.

    Attributes:
        column_id: Identifier of the column.
        task_id: Task the column was created for.
        layer_sizes: Width of each layer in the column.
        lateral_connections: Per-layer adapter weights, keyed by the id of
            an earlier column.
        frozen: True once the column has been frozen.
        created_at: Creation time as returned by ``time.time()``.
        performance_history: Recorded performance scores, oldest first.
    """
    column_id: str
    task_id: str
    # Mutable defaults use factories so each column owns its own containers.
    layer_sizes: List[int] = field(default_factory=list)
    lateral_connections: Dict[str, List[float]] = field(default_factory=dict)
    frozen: bool = False
    created_at: float = field(default_factory=time.time)
    performance_history: List[float] = field(default_factory=list)


# ==============================================================================
# Fisher Information Calculator
# ==============================================================================

class FisherInformationCalculator:
    """
    Calculates a diagonal Fisher Information estimate for EWC.

    The estimate is the average of squared per-sample gradients (the
    empirical Fisher diagonal). Larger scores mark parameters whose change
    would most affect the loss on the sampled data.
    """

    def __init__(
        self,
        sample_size: int = 100,
        dampening: float = 1e-8,
        normalize: bool = True
    ):
        """
        Args:
            sample_size: Maximum number of data samples used per computation.
            dampening: Constant added to every averaged score.
            normalize: If True, divide all scores by the largest score, so the
                most important parameter scores 1.0.
        """
        self.sample_size = sample_size
        self.dampening = dampening
        self.normalize = normalize
        # Running per-task Fisher estimates kept by update_running_fisher.
        self.fisher_cache: Dict[str, Dict[str, float]] = {}

    def compute_fisher(
        self,
        parameters: Dict[str, float],
        gradient_fn: Callable[[Dict[str, float], Any], Dict[str, float]],
        data_samples: List[Any]
    ) -> Dict[str, float]:
        """
        Compute Fisher Information for each parameter.

        Args:
            parameters: Current parameter values; the result has exactly
                these keys.
            gradient_fn: ``gradient_fn(parameters, sample)`` returning a
                ``{param_name: gradient}`` dict. Gradients for names not in
                ``parameters`` are ignored.
            data_samples: Sequence of samples; at most ``sample_size`` of
                them are drawn at random without replacement.

        Returns:
            Fisher Information scores for each parameter. Returns 1.0 for
            every parameter (uniform importance) when there are no samples or
            every gradient computation failed.
        """
        samples = (
            random.sample(data_samples, min(self.sample_size, len(data_samples)))
            if data_samples else []
        )

        if not samples:
            logger.warning("No samples for Fisher computation, using uniform importance")
            return {p: 1.0 for p in parameters}

        fisher_scores = {p: 0.0 for p in parameters}
        successful = 0

        # Accumulate squared gradients. Each sample is all-or-nothing: its
        # squared gradients are computed completely before any is added, so a
        # failure can never leave a partially counted sample behind.
        for sample in samples:
            try:
                gradients = gradient_fn(parameters, sample)
                squared = {
                    name: value ** 2
                    for name, value in gradients.items()
                    if name in fisher_scores
                }
            except Exception as e:
                # Best effort: skip a failing sample rather than abort.
                logger.warning(f"Error computing gradient: {e}")
                continue
            for name, sq in squared.items():
                fisher_scores[name] += sq
            successful += 1

        if successful == 0:
            logger.warning("All gradient computations failed, using uniform importance")
            return {p: 1.0 for p in parameters}

        # Average over the samples that actually contributed, then dampen.
        for name in fisher_scores:
            fisher_scores[name] = fisher_scores[name] / successful + self.dampening

        if self.normalize and fisher_scores:
            max_fisher = max(fisher_scores.values())
            if max_fisher > 0:
                fisher_scores = {
                    p: v / max_fisher for p, v in fisher_scores.items()
                }

        return fisher_scores

    def update_running_fisher(
        self,
        task_id: str,
        new_fisher: Dict[str, float],
        momentum: float = 0.9
    ) -> Dict[str, float]:
        """
        Blend ``new_fisher`` into the running estimate for ``task_id``.

        The first call stores a copy of ``new_fisher``. Later calls compute
        ``momentum * old + (1 - momentum) * new`` per parameter, with a
        missing old value treated as 0.0.

        Returns:
            The cached estimate itself (not a copy).
        """
        if task_id not in self.fisher_cache:
            self.fisher_cache[task_id] = new_fisher.copy()
        else:
            cached = self.fisher_cache[task_id]
            for param, value in new_fisher.items():
                old_val = cached.get(param, 0.0)
                cached[param] = momentum * old_val + (1 - momentum) * value
        return self.fisher_cache[task_id]


# ==============================================================================
# Elastic Weight Consolidation
# ==============================================================================

class ElasticWeightConsolidation:
    """
    Implements Elastic Weight Consolidation (EWC) for continual learning.

    EWC prevents catastrophic forgetting by:
    1. Computing Fisher Information to identify important weights
    2. Adding a quadratic penalty for changing important weights
    3. Preserving knowledge from previous tasks

    Each parameter keeps one consolidated (Fisher, anchor) pair:

    * online EWC: the old Fisher is decayed by ``gamma`` and the new task's
      Fisher is added. The anchor moves to the newest task's value.
    * standard EWC: Fishers are summed and the anchor is the Fisher-weighted
      mean of the tasks' values. This gives the same penalty gradient as
      keeping one quadratic term per task.
    """

    def __init__(
        self,
        lambda_ewc: float = 1000.0,
        fisher_sample_size: int = 100,
        online_ewc: bool = True,
        gamma: float = 0.9
    ):
        """
        Initialize EWC.

        Args:
            lambda_ewc: Regularization strength
            fisher_sample_size: Samples for Fisher computation
            online_ewc: Use online EWC variant
            gamma: Decay factor for online EWC
        """
        self.lambda_ewc = lambda_ewc
        self.fisher_sample_size = fisher_sample_size
        self.online_ewc = online_ewc
        self.gamma = gamma

        self.fisher_calculator = FisherInformationCalculator(
            sample_size=fisher_sample_size
        )

        # Consolidated importance per parameter name, plus each task's final
        # parameter values.
        self.weight_importance: Dict[str, WeightImportance] = {}
        self.task_parameters: Dict[str, Dict[str, float]] = {}
        # Flat views of weight_importance used by the penalty computations.
        self.consolidated_fisher: Dict[str, float] = {}
        self.consolidated_optimal: Dict[str, float] = {}

        self._lock = threading.RLock()

        logger.info(
            f"EWC initialized: lambda={lambda_ewc}, online={online_ewc}"
        )

    def register_task(
        self,
        task_id: str,
        parameters: Dict[str, float],
        gradient_fn: Callable,
        data_samples: List[Any]
    ) -> Dict[str, float]:
        """
        Register a completed task and compute importance.

        Registering the same ``task_id`` twice folds its Fisher in twice.

        Args:
            task_id: Unique task identifier
            parameters: Final parameter values after training
            gradient_fn: Function to compute gradients
            data_samples: Training samples for Fisher computation

        Returns:
            Fisher Information scores
        """
        with self._lock:
            logger.info(f"Registering task '{task_id}' for EWC consolidation")

            fisher_scores = self.fisher_calculator.compute_fisher(
                parameters, gradient_fn, data_samples
            )

            self.task_parameters[task_id] = parameters.copy()

            now = time.time()
            for param_name, fisher_value in fisher_scores.items():
                value = parameters.get(param_name, 0.0)
                wi = self.weight_importance.get(param_name)

                if wi is None:
                    self.weight_importance[param_name] = WeightImportance(
                        weight_id=param_name,
                        fisher_information=fisher_value,
                        optimal_value=value,
                        task_contributions={task_id: fisher_value}
                    )
                    continue

                wi.task_contributions[task_id] = fisher_value
                # Fisher and anchor must move together; keeping the first
                # task's anchor would pull every later task toward task 1.
                wi.fisher_information, wi.optimal_value = self._consolidate_param(
                    wi.fisher_information, wi.optimal_value, fisher_value, value
                )
                wi.last_updated = now

            self._update_consolidated()

            logger.info(
                f"Task '{task_id}' registered: {len(fisher_scores)} parameters"
            )
            return fisher_scores

    def _consolidate_param(
        self,
        old_fisher: float,
        old_optimal: float,
        new_fisher: float,
        new_value: float
    ) -> Tuple[float, float]:
        """
        Merge a new task's (Fisher, value) into an existing (Fisher, anchor).

        Returns:
            The new ``(fisher, anchor)`` pair.
        """
        if self.online_ewc:
            # Online EWC: decayed Fisher, anchor at the latest task's optimum.
            return self.gamma * old_fisher + new_fisher, new_value

        # Standard EWC: the sum of F_i * (x - a_i)^2 equals (sum F_i) * (x - a)^2
        # plus a constant, where a is the Fisher-weighted mean of the a_i.
        total = old_fisher + new_fisher
        if total <= 0:
            return total, new_value
        return total, (old_fisher * old_optimal + new_fisher * new_value) / total

    def _update_consolidated(self):
        """Rebuild the flat Fisher/anchor lookups from weight_importance."""
        self.consolidated_fisher = {}
        self.consolidated_optimal = {}

        for param_name, wi in self.weight_importance.items():
            self.consolidated_fisher[param_name] = wi.fisher_information
            self.consolidated_optimal[param_name] = wi.optimal_value

    def compute_ewc_loss(
        self,
        current_parameters: Dict[str, float]
    ) -> float:
        """
        Compute the EWC penalty ``0.5 * lambda * sum F_i * (x_i - a_i)^2``.

        Parameters without consolidated Fisher information contribute nothing.

        Args:
            current_parameters: Current parameter values

        Returns:
            EWC penalty term
        """
        with self._lock:
            ewc_loss = 0.0

            for param_name, current_value in current_parameters.items():
                if param_name in self.consolidated_fisher:
                    fisher = self.consolidated_fisher[param_name]
                    optimal = self.consolidated_optimal.get(param_name, 0.0)
                    ewc_loss += fisher * (current_value - optimal) ** 2

            return 0.5 * self.lambda_ewc * ewc_loss

    def get_gradient_modifier(
        self,
        param_name: str,
        current_value: float
    ) -> float:
        """
        Get the gradient of the EWC penalty with respect to one parameter.

        Args:
            param_name: Parameter name
            current_value: Current parameter value

        Returns:
            ``lambda * F * (x - anchor)``, or 0.0 for unconsolidated
            parameters.
        """
        with self._lock:
            if param_name not in self.consolidated_fisher:
                return 0.0

            fisher = self.consolidated_fisher[param_name]
            optimal = self.consolidated_optimal.get(param_name, 0.0)

            return self.lambda_ewc * fisher * (current_value - optimal)

    def freeze_critical_weights(
        self,
        threshold: float = 0.9
    ) -> List[str]:
        """
        Mark weights whose importance, relative to the maximum, is >= threshold.

        Only sets ``WeightImportance.frozen``. The loss and gradient methods
        of this class do not read that flag.

        Args:
            threshold: Importance threshold (0-1)

        Returns:
            List of frozen parameter names
        """
        with self._lock:
            frozen = []
            max_importance = max(
                (wi.fisher_information for wi in self.weight_importance.values()),
                default=1.0
            )

            for param_name, wi in self.weight_importance.items():
                normalized = (
                    wi.fisher_information / max_importance if max_importance > 0 else 0
                )
                if normalized >= threshold:
                    wi.frozen = True
                    frozen.append(param_name)

            logger.info(f"Frozen {len(frozen)} critical weights")
            return frozen

    def get_importance_summary(self) -> Dict[str, Any]:
        """Get summary statistics of the consolidated weight importance."""
        with self._lock:
            if not self.weight_importance:
                return {"status": "no_weights_registered"}

            fisher_values = [
                wi.fisher_information for wi in self.weight_importance.values()
            ]

            return {
                "total_weights": len(self.weight_importance),
                "frozen_weights": sum(
                    1 for wi in self.weight_importance.values() if wi.frozen
                ),
                "mean_importance": sum(fisher_values) / len(fisher_values),
                "max_importance": max(fisher_values),
                "min_importance": min(fisher_values),
                "tasks_registered": len(self.task_parameters)
            }


# ==============================================================================
# Progressive Networks
# ==============================================================================

class ProgressiveNetworks:
    """
    Implements Progressive Neural Networks for continual learning.

    Progressive Networks:
    1. Add new network columns for each new task
    2. Freeze previous columns to prevent forgetting
    3. Add lateral connections for knowledge transfer

    ``max_columns`` is a soft limit. Reaching it triggers compression
    (halving the layers) of one frozen column, but columns are never
    removed, so the column count keeps growing.
    """

    # Smallest width _compress_columns will shrink a layer to.
    _MIN_LAYER_SIZE = 8

    def __init__(
        self,
        base_layer_sizes: Optional[List[int]] = None,
        max_columns: int = 10,
        lateral_connection_type: str = "adapter",
        compression_threshold: float = 0.8
    ):
        """
        Initialize Progressive Networks.

        Args:
            base_layer_sizes: Default layer sizes for new columns
                (``[256, 128, 64]`` when omitted or empty).
            max_columns: Column count at which adding a column first
                compresses an existing one.
            lateral_connection_type: Type of lateral connections. Stored
                only; not read elsewhere in this class.
            compression_threshold: Threshold for column compression. Stored
                only; not read elsewhere in this class.
        """
        self.base_layer_sizes = base_layer_sizes or [256, 128, 64]
        self.max_columns = max_columns
        self.lateral_connection_type = lateral_connection_type
        self.compression_threshold = compression_threshold

        self.columns: Dict[str, NetworkColumn] = {}
        # Column ids in creation order.
        self.column_order: List[str] = []
        # Latest column created for each task.
        self.task_to_column: Dict[str, str] = {}

        self._lock = threading.RLock()

        logger.info(
            f"ProgressiveNetworks initialized: max_columns={max_columns}"
        )

    def add_column(
        self,
        task_id: str,
        layer_sizes: Optional[List[int]] = None
    ) -> 'NetworkColumn':
        """
        Add a new column for a task.

        The new column gets one list of adapter weights (one scalar per
        layer, drawn from N(0, 0.1)) for every existing column.

        Args:
            task_id: Task identifier
            layer_sizes: Layer sizes for this column (copied). Defaults to a
                copy of ``base_layer_sizes`` when None or empty.

        Returns:
            Created NetworkColumn
        """
        with self._lock:
            if len(self.columns) >= self.max_columns:
                logger.warning("Max columns reached, compressing old columns")
                self._compress_columns()

            # Columns are never removed, so the running count keeps ids unique.
            column_id = f"col_{len(self.columns)}_{task_id[:8]}"
            sizes = list(layer_sizes) if layer_sizes else self.base_layer_sizes.copy()

            lateral_connections = {
                prev_col_id: [random.gauss(0, 0.1) for _ in range(len(sizes))]
                for prev_col_id in self.column_order
            }

            column = NetworkColumn(
                column_id=column_id,
                task_id=task_id,
                layer_sizes=sizes,
                lateral_connections=lateral_connections
            )

            self.columns[column_id] = column
            self.column_order.append(column_id)
            self.task_to_column[task_id] = column_id

            logger.info(
                f"Added column '{column_id}' for task '{task_id}' "
                f"with {len(lateral_connections)} lateral connections"
            )
            return column

    def freeze_column(self, task_id: str) -> bool:
        """Freeze the task's column; return False if the task has none."""
        with self._lock:
            if task_id not in self.task_to_column:
                return False

            column_id = self.task_to_column[task_id]
            self.columns[column_id].frozen = True
            logger.info(f"Frozen column '{column_id}'")
            return True

    @staticmethod
    def _shrink_sizes(sizes: List[int], min_size: int) -> List[int]:
        """Halve each width, never below ``min_size`` and never growing a layer."""
        return [s if s <= min_size else max(s // 2, min_size) for s in sizes]

    def _compress_columns(self):
        """Halve the layers of the worst-performing compressible column."""
        # Only frozen columns with recorded performance that still have a
        # layer above the minimum width can actually shrink.
        candidates = [
            (col_id, col) for col_id, col in self.columns.items()
            if col.frozen
            and col.performance_history
            and any(s > self._MIN_LAYER_SIZE for s in col.layer_sizes)
        ]

        if not candidates:
            logger.warning("No compressible columns found")
            return

        # Lowest mean performance; min() keeps the oldest column on ties.
        col_id, col = min(
            candidates,
            key=lambda item: sum(item[1].performance_history) / len(item[1].performance_history)
        )

        col.layer_sizes = self._shrink_sizes(col.layer_sizes, self._MIN_LAYER_SIZE)

        logger.info(f"Compressed column '{col_id}'")

    def forward(
        self,
        task_id: str,
        input_data: Any,
        layer_fn: Callable[[Any, int, List[float]], Any]
    ) -> Any:
        """
        Forward pass through progressive network.

        A column is created first if the task has none.

        Args:
            task_id: Current task
            input_data: Input to process
            layer_fn: Called once per layer as
                ``layer_fn(current, layer_size, lateral_inputs)``, where
                ``lateral_inputs`` holds this layer's adapter weight from each
                earlier column; its result feeds the next layer.

        Returns:
            Output of the last ``layer_fn`` call (``input_data`` if the
            column has no layers).
        """
        with self._lock:
            if task_id not in self.task_to_column:
                self.add_column(task_id)

            column = self.columns[self.task_to_column[task_id]]

            current = input_data
            for layer_idx, layer_size in enumerate(column.layer_sizes):
                lateral_inputs = [
                    weights[layer_idx]
                    for weights in column.lateral_connections.values()
                    if layer_idx < len(weights)
                ]
                current = layer_fn(current, layer_size, lateral_inputs)

            return current

    def update_performance(self, task_id: str, score: float):
        """Record a score for the task's column, keeping the last 100 scores."""
        with self._lock:
            column_id = self.task_to_column.get(task_id)
            if column_id is None:
                return
            history = self.columns[column_id].performance_history
            history.append(score)
            if len(history) > 100:
                history.pop(0)

    def get_network_summary(self) -> Dict[str, Any]:
        """
        Get a snapshot of the progressive network state.

        Lists in the result are copies, so callers cannot mutate internal
        state through them. ``total_parameters`` is the sum of layer widths
        over all columns.
        """
        with self._lock:
            return {
                "total_columns": len(self.columns),
                "frozen_columns": sum(1 for c in self.columns.values() if c.frozen),
                "active_columns": sum(1 for c in self.columns.values() if not c.frozen),
                "total_parameters": sum(
                    sum(c.layer_sizes) for c in self.columns.values()
                ),
                "column_order": list(self.column_order),
                "tasks": list(self.task_to_column.keys())
            }


# ==============================================================================
# Experience Replay
# ==============================================================================

class ExperienceReplay:
    """
    Implements Experience Replay for continual learning.

    Maintains a buffer of past experiences and strategically replays
    them during training to prevent forgetting.
    """

    def __init__(
        self,
        buffer_size: int = 10000,
        sampling_strategy: str = "priority",
        priority_exponent: float = 0.6,
        importance_sampling: bool = True,
        per_task_quota: Optional[int] = None
    ):
        """
        Initialize Experience Replay.

        Args:
            buffer_size: Maximum buffer size
            sampling_strategy: How to sample ("uniform", "priority", "reservoir")
            priority_exponent: Priority sampling exponent
            importance_sampling: Use importance sampling weights
            per_task_quota: Max experiences per task (None = no limit)
        """
        self.buffer_size = buffer_size
        self.sampling_strategy = sampling_strategy
        self.priority_exponent = priority_exponent
        self.importance_sampling = importance_sampling
        self.per_task_quota = per_task_quota

        # Priority queue for priority sampling
        self.priority_buffer: List[Tuple[float, Experience]] = []

        # Task-specific buffers
        self.task_buffers: Dict[str, List[Experience]] = defaultdict(list)

        # Statistics
        self.total_experiences = 0
        self.total_replays = 0
        self.replay_stats: Dict[str, int] = defaultdict(int)

        self._lock = threading.RLock()

        logger.info(
            f"ExperienceReplay initialized: size={buffer_size}, "
            f"strategy={sampling_strategy}"
        )

    def add_experience(self, experience: Experience) -> bool:
        """
        Add experience to replay buffer.

        Args:
            experience: Experience to add

        Returns:
            True if added successfully
        """
        with self._lock:
            task_id = experience.task_id

            # Check per-task quota
            if self.per_task_quota:
                if len(self.task_buffers[task_id]) >= self.per_task_quota:
                    # Remove oldest from this task
                    removed = self.task_buffers[task_id].pop(0)
                    self._remove_from_priority_buffer(removed.experience_id)

            # Check total buffer size
            while len(self.priority_buffer) >= self.buffer_size:
                # Remove lowest priority
                if self.priority_buffer:
                    _, removed_exp = heapq.heappop(self.priority_buffer)
                    self.task_buffers[removed_exp.task_id] = [
                        e for e in self.task_buffers[removed_exp.task_id]
                        if e.experience_id != removed_exp.experience_id
                    ]

            # Compute priority (negative for min-heap, we want max priority)
            priority = -self._compute_priority(experience)

            heapq.heappush(self.priority_buffer, (priority, experience))
            self.task_buffers[task_id].append(experience)
            self.total_experiences += 1

            return True

    def _remove_from_priority_buffer(self, experience_id: str):
        """Remove experience from priority buffer by ID."""
        self.priority_buffer = [
            (p, e) for p, e in self.priority_buffer
            if e.experience_id != experience_id
        ]
        heapq.heapify(self.priority_buffer)

    def _compute_priority(self, experience: Experience) -> float:
        """Compute priority score for an experience."""
        # Base priority from importance
        priority = experience.importance

        # Boost for lower success rate (harder examples)
        priority *= (2.0 - experience.success_rate)

        # Decay based on age
        age_hours = (time.time() - experience.timestamp) / 3600
        age_decay = math.exp(-0.01 * age_hours)
        priority *= age_decay

        # Boost for less replayed experiences
        replay_boost = 1.0 / (1.0 + experience.replay_count)
        priority *= replay_boost

        return priority ** self.priority_exponent

    def sample(
        self,
        batch_size: int,
        task_id: Optional[str] = None
    ) -> Tuple[List[Experience], List[float]]:
        """
        Sample experiences from buffer.

        Args:
            batch_size: Number of experiences to sample
            task_id: Optional task filter

        Returns:
            Tuple of (experiences, importance_weights)
        """
        with self._lock:
            if not self.priority_buffer:
                return [], []

            # Get candidate pool
            if task_id:
                candidates = [
                    (p, e) for p, e in self.priority_buffer
                    if e.task_id == task_id
                ]
            else:
                candidates = self.priority_buffer.copy()

            if not candidates:
                return [], []

            # Sample based on strategy
            if self.sampling_strategy == "uniform":
                sampled = random.sample(
                    candidates,
                    min(batch_size, len(candidates))
                )
            elif self.sampling_strategy == "priority":
                # Weighted sampling by priority
                weights = [abs(p) for p, _ in candidates]
                total_weight = sum(weights)
                if total_weight == 0:
                    sampled = random.sample(
                        candidates,
                        min(batch_size, len(candidates))
                    )
                else:
                    probs = [w / total_weight for w in weights]
                    indices = random.choices(
                        range(len(candidates)),
                        weights=probs,
                        k=min(batch_size, len(candidates))
                    )
                    sampled = [candidates[i] for i in indices]
            else:  # reservoir
                sampled = random.sample(
                    candidates,
                    min(batch_size, len(candidates))
                )

            experiences = [e for _, e in sampled]

            # Update replay counts
            for exp in experiences:
                exp.replay_count += 1
                self.replay_stats[exp.task_id] += 1
            self.total_replays += len(experiences)

            # Compute importance sampling weights
            if self.importance_sampling:
                N = len(self.priority_buffer)
                weights = []
                for priority, exp in sampled:
                    p = abs(priority)
                    # IS weight = 1 / (N * P(sample))
                    total_p = sum(abs(pp) for pp, _ in candidates)
                    if total_p > 0:
                        prob = p / total_p
                        weight = 1.0 / (N * prob + 1e-8)
                    else:
                        weight = 1.0
                    weights.append(weight)

                # Normalize weights
                max_weight = max(weights) if weights else 1.0
                weights = [w / max_weight for w in weights]
            else:
                weights = [1.0] * len(experiences)

            return experiences, weights

    def sample_balanced(
        self,
        batch_size: int,
        tasks: Optional[List[str]] = None
    ) -> Tuple[List[Experience], List[float]]:
        """
        Sample balanced across tasks.

        Args:
            batch_size: Total batch size
            tasks: Tasks to include (None = all)

        Returns:
            Tuple of (experiences, weights)
        """
        with self._lock:
            target_tasks = tasks or list(self.task_buffers.keys())

            if not target_tasks:
                return [], []

            per_task = max(1, batch_size // len(target_tasks))
            remainder = batch_size % len(target_tasks)

            all_experiences = []
            all_weights = []

            for i, task_id in enumerate(target_tasks):
                n = per_task + (1 if i < remainder else 0)
                exps, weights = self.sample(n, task_id)
                all_experiences.extend(exps)
                all_weights.extend(weights)

            return all_experiences, all_weights

    def get_statistics(self) -> Dict[str, Any]:
        """Get replay buffer statistics."""
        with self._lock:
            return {
                "buffer_size": len(self.priority_buffer),
                "max_size": self.buffer_size,
                "utilization": len(self.priority_buffer) / self.buffer_size,
                "total_experiences": self.total_experiences,
                "total_replays": self.total_replays,
                "tasks": len(self.task_buffers),
                "replay_stats": dict(self.replay_stats),
                "avg_replay_count": (
                    sum(e.replay_count for _, e in self.priority_buffer) /
                    len(self.priority_buffer) if self.priority_buffer else 0
                )
            }


# ==============================================================================
# Task Inferencer
# ==============================================================================

class TaskInferencer:
    """
    Infers the current task from context and input.

    Uses multiple signals to determine which task is being performed,
    enabling appropriate knowledge retrieval and capacity allocation.

    Each registered task gets a signature (input/output schema keys,
    description keywords, dependencies). An incoming input is reduced to a
    comparable signature (dict keys and keywords found in string content)
    and scored against every task signature. A short-term boost is added
    for tasks inferred recently. The best task is returned when its score
    reaches ``similarity_threshold``; otherwise an "unknown" task is returned.
    """

    # Common words ignored during keyword extraction.
    _STOPWORDS = frozenset({'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been'})

    def __init__(
        self,
        similarity_threshold: float = 0.7,
        max_task_history: int = 100,
        use_context_window: bool = True,
        context_window_size: int = 5
    ):
        """
        Initialize Task Inferencer.

        Args:
            similarity_threshold: Minimum score (0..1) for a task to be returned
            max_task_history: History size for pattern detection
            use_context_window: Use sliding context window
            context_window_size: Size of context window
        """
        self.similarity_threshold = similarity_threshold
        self.max_task_history = max_task_history
        self.use_context_window = use_context_window
        self.context_window_size = context_window_size

        self.registered_tasks: Dict[str, TaskDescriptor] = {}
        self.task_signatures: Dict[str, Dict[str, Any]] = {}
        self.inference_history: deque = deque(maxlen=max_task_history)
        self.context_window: deque = deque(maxlen=context_window_size)

        self._lock = threading.RLock()

        logger.info("TaskInferencer initialized")

    def register_task(self, task: TaskDescriptor):
        """Register a task (keyed by ``task.task_id``) and cache its signature."""
        with self._lock:
            self.registered_tasks[task.task_id] = task
            self.task_signatures[task.task_id] = self._compute_signature(task)
            logger.info(f"Registered task '{task.task_id}'")

    def _compute_signature(self, task: TaskDescriptor) -> Dict[str, Any]:
        """Compute the matching signature of a registered task."""
        return {
            "task_type": task.task_type,
            "input_keys": set(task.input_schema.keys()),
            "output_keys": set(task.output_schema.keys()),
            "keywords": self._extract_keywords(task.description),
            "dependencies": set(task.dependencies)
        }

    def _extract_keywords(self, text: str) -> Set[str]:
        """
        Extract lowercase keywords from free text.

        Tokens are split on whitespace and stripped of surrounding
        punctuation, so that "images." and "images" are the same keyword.
        Tokens of three characters or fewer and common stopwords are dropped.

        Args:
            text: Text to tokenize (empty/None yields an empty set)

        Returns:
            Set of keywords
        """
        import string  # only needed here

        if not text:
            return set()

        keywords = set()
        for token in text.lower().split():
            word = token.strip(string.punctuation)
            if len(word) > 3 and word not in self._STOPWORDS:
                keywords.add(word)
        return keywords

    def infer_task(
        self,
        input_data: Any,
        context: Optional[Dict[str, Any]] = None
    ) -> Tuple[str, float, TaskDescriptor]:
        """
        Infer the current task from input and context.

        Args:
            input_data: Input to process
            context: Additional context

        Returns:
            Tuple of (task_id, confidence, task_descriptor). The confidence
            is capped at 1.0. When no task reaches the threshold (or none is
            registered), returns ("unknown", 0.0, <unknown descriptor>).
        """
        with self._lock:
            if not self.registered_tasks:
                return self._create_unknown_task()

            # Update context window
            if self.use_context_window:
                self.context_window.append({
                    "input": input_data,
                    "context": context,
                    "timestamp": time.time()
                })

            input_sig = self._compute_input_signature(input_data, context)

            # Score each registered task
            scores = [
                (task_id, self._compute_similarity(input_sig, task_sig))
                for task_id, task_sig in self.task_signatures.items()
            ]

            # Favour tasks that were inferred recently
            if self.use_context_window and len(self.context_window) > 1:
                pattern_boost = self._compute_pattern_boost()
                scores = [
                    (task_id, score + 0.2 * pattern_boost[task_id])
                    if task_id in pattern_boost else (task_id, score)
                    for task_id, score in scores
                ]

            # Rank on the raw (boosted) score; the first registered task wins ties
            best_task_id, raw_score = max(scores, key=lambda item: item[1])
            # The boost can push the raw score above 1.0; report a bounded confidence
            confidence = min(raw_score, 1.0)

            # Record inference
            self.inference_history.append({
                "task_id": best_task_id,
                "confidence": confidence,
                "timestamp": time.time()
            })

            if raw_score >= self.similarity_threshold:
                return best_task_id, confidence, self.registered_tasks[best_task_id]
            return self._create_unknown_task()

    def _compute_input_signature(
        self,
        input_data: Any,
        context: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Compute a comparable signature from raw input.

        Dict inputs contribute their keys plus keywords from string values.
        String inputs contribute keywords only. Context contributes its keys.
        """
        signature = {
            "input_type": type(input_data).__name__,
            "input_keys": set(),
            "keywords": set(),
            "context_keys": set()
        }

        if isinstance(input_data, dict):
            signature["input_keys"] = set(input_data.keys())
            for value in input_data.values():
                if isinstance(value, str):
                    signature["keywords"].update(self._extract_keywords(value))
        elif isinstance(input_data, str):
            signature["keywords"] = self._extract_keywords(input_data)

        if context:
            signature["context_keys"] = set(context.keys())

        return signature

    def _compute_similarity(
        self,
        input_sig: Dict[str, Any],
        task_sig: Dict[str, Any]
    ) -> float:
        """
        Compute a similarity score in [0, 1] between input and task signatures.

        Two signals are compared:
        - keys: Jaccard overlap between input keys and the task's input schema
        - keywords: fraction of the task's keywords found in the input

        A signal is only counted when both sides provide it, and the result
        is normalized by the weight of the counted signals. Without this, an
        input lacking one signal (e.g. a plain string has no keys) could never
        reach a typical threshold such as the default 0.7.
        """
        weights = {"keys": 0.3, "keywords": 0.4}
        weighted_sum = 0.0
        applicable_weight = 0.0

        input_keys, task_keys = input_sig["input_keys"], task_sig["input_keys"]
        if input_keys and task_keys:
            # Union is non-empty because both sets are non-empty
            jaccard = len(input_keys & task_keys) / len(input_keys | task_keys)
            weighted_sum += weights["keys"] * jaccard
            applicable_weight += weights["keys"]

        input_words, task_words = input_sig["keywords"], task_sig["keywords"]
        if input_words and task_words:
            recall = len(input_words & task_words) / len(task_words)
            weighted_sum += weights["keywords"] * recall
            applicable_weight += weights["keywords"]

        if applicable_weight == 0.0:
            return 0.0
        return min(weighted_sum / applicable_weight, 1.0)

    def _compute_pattern_boost(self) -> Dict[str, float]:
        """Return +0.1 per occurrence of each task among the last 10 inferences."""
        boost = defaultdict(float)
        for entry in list(self.inference_history)[-10:]:
            boost[entry["task_id"]] += 0.1
        return dict(boost)

    def _create_unknown_task(self) -> Tuple[str, float, TaskDescriptor]:
        """Build the ("unknown", 0.0, descriptor) result for unmatched inputs."""
        unknown_task = TaskDescriptor(
            task_id="unknown",
            task_type=TaskType.UNKNOWN,
            name="Unknown Task",
            description="Unidentified task"
        )
        return "unknown", 0.0, unknown_task

    def get_task_distribution(self) -> Dict[str, float]:
        """
        Get the relative frequency of each task id in the inference history.

        Returns:
            Mapping task_id -> fraction of recorded inferences (empty if none)
        """
        # Lock: infer_task may append to the deque concurrently, which would
        # raise "deque mutated during iteration".
        with self._lock:
            if not self.inference_history:
                return {}

            counts = defaultdict(int)
            for entry in self.inference_history:
                counts[entry["task_id"]] += 1

            total = sum(counts.values())
            return {k: v / total for k, v in counts.items()}


# ==============================================================================
# Capacity Manager
# ==============================================================================

class CapacityManager:
    """
    Manages model capacity for continual learning.

    Allocates, reserves, and manages capacity blocks to ensure
    sufficient resources for new tasks while preserving old ones.

    Capacity is split into equally sized ``CapacityBlock`` objects. A
    ``reserve_ratio`` share of the initial blocks starts RESERVED and is
    only handed out once AVAILABLE blocks run short. Beyond that, the pool
    may grow (see ``_grow_capacity``). Frozen blocks are never released.
    """

    def __init__(
        self,
        total_capacity: int = 100000,
        block_size: int = 1000,
        reserve_ratio: float = 0.2,
        growth_factor: float = 1.5,
        min_task_capacity: int = 5000
    ):
        """
        Initialize Capacity Manager.

        Args:
            total_capacity: Total available capacity
            block_size: Size of each capacity block
            reserve_ratio: Ratio of initial blocks kept in reserve
            growth_factor: Multiplier applied to the shortfall when growing
            min_task_capacity: Capacity allocated when none is requested
        """
        self.total_capacity = total_capacity
        self.block_size = block_size
        self.reserve_ratio = reserve_ratio
        self.growth_factor = growth_factor
        self.min_task_capacity = min_task_capacity

        # Initialize capacity blocks
        self.blocks: Dict[str, CapacityBlock] = {}
        self._initialize_blocks()

        # Task allocations: task_id -> block ids
        self.task_allocations: Dict[str, List[str]] = defaultdict(list)

        # Statistics
        self.allocation_history: List[Dict[str, Any]] = []

        self._lock = threading.RLock()

        logger.info(
            f"CapacityManager initialized: total={total_capacity}, "
            f"blocks={len(self.blocks)}"
        )

    def _initialize_blocks(self):
        """
        Create the initial pool of blocks.

        ``total_capacity // block_size`` blocks are created; a remainder
        smaller than one block gets no block. The first
        ``int(num_blocks * reserve_ratio)`` blocks start RESERVED and the
        rest start AVAILABLE.
        """
        num_blocks = self.total_capacity // self.block_size
        reserve_blocks = int(num_blocks * self.reserve_ratio)

        for i in range(num_blocks):
            block_id = f"block_{i:04d}"
            state = CapacityState.RESERVED if i < reserve_blocks else CapacityState.AVAILABLE
            self.blocks[block_id] = CapacityBlock(
                block_id=block_id,
                size=self.block_size,
                state=state
            )

    def _blocks_in_state(self, state: CapacityState) -> List[CapacityBlock]:
        """Return the blocks currently in ``state``, in creation order."""
        return [b for b in self.blocks.values() if b.state == state]

    def allocate(
        self,
        task_id: str,
        requested_capacity: Optional[int] = None
    ) -> List[CapacityBlock]:
        """
        Allocate capacity for a task.

        Blocks are taken from AVAILABLE first, then RESERVED. If that is still
        not enough, the pool is grown. If growth is refused, as many blocks as
        exist are allocated and a warning is logged.

        Args:
            task_id: Task to allocate for
            requested_capacity: Requested capacity (None or 0 = minimum)

        Returns:
            List of allocated blocks (possibly fewer than requested)

        Raises:
            ValueError: If requested_capacity is negative
        """
        if requested_capacity is not None and requested_capacity < 0:
            raise ValueError(
                f"requested_capacity must be non-negative, got {requested_capacity}"
            )

        with self._lock:
            capacity_needed = requested_capacity or self.min_task_capacity
            blocks_needed = math.ceil(capacity_needed / self.block_size)

            available = self._blocks_in_state(CapacityState.AVAILABLE)

            if len(available) < blocks_needed:
                # Dip into the reserve
                reserved = self._blocks_in_state(CapacityState.RESERVED)
                available.extend(reserved[:blocks_needed - len(available)])

            if len(available) < blocks_needed:
                if self._grow_capacity(blocks_needed - len(available)):
                    # Hand out the blocks growth just created; otherwise they
                    # would sit AVAILABLE while this task is under-allocated.
                    already = {b.block_id for b in available}
                    available.extend(
                        b for b in self._blocks_in_state(CapacityState.AVAILABLE)
                        if b.block_id not in already
                    )
                else:
                    logger.warning(
                        f"Insufficient capacity for task '{task_id}': "
                        f"need {blocks_needed}, have {len(available)}"
                    )
                    # Allocate what we can
                    blocks_needed = len(available)

            allocated = available[:blocks_needed]
            for block in allocated:
                block.state = CapacityState.ALLOCATED
                block.assigned_task = task_id
                self.task_allocations[task_id].append(block.block_id)

            self.allocation_history.append({
                "task_id": task_id,
                "blocks": len(allocated),
                "capacity": len(allocated) * self.block_size,
                "timestamp": time.time()
            })

            logger.info(
                f"Allocated {len(allocated)} blocks ({len(allocated) * self.block_size} capacity) "
                f"for task '{task_id}'"
            )
            return allocated

    def _grow_capacity(self, blocks_needed: int) -> bool:
        """
        Attempt to add new AVAILABLE blocks to the pool.

        Adds ``growth_factor * blocks_needed`` blocks, but never fewer than
        ``blocks_needed``. The step is refused when it would add more than
        50% of the current ``total_capacity``. The limit is relative to the
        total at call time, so repeated growth compounds.

        Args:
            blocks_needed: Shortfall in blocks

        Returns:
            True if blocks were added, False if growth was refused
        """
        with self._lock:
            new_blocks = max(blocks_needed, int(blocks_needed * self.growth_factor))
            new_capacity = new_blocks * self.block_size

            max_growth = self.total_capacity * 0.5  # Max 50% growth per step
            if new_capacity > max_growth:
                return False

            current_count = len(self.blocks)
            for i in range(new_blocks):
                block_id = f"block_{current_count + i:04d}"
                self.blocks[block_id] = CapacityBlock(
                    block_id=block_id,
                    size=self.block_size,
                    state=CapacityState.AVAILABLE
                )

            self.total_capacity += new_capacity
            logger.info(f"Grew capacity by {new_capacity}, total: {self.total_capacity}")
            return True

    def freeze_task(self, task_id: str) -> int:
        """
        Mark every block allocated to ``task_id`` as FROZEN.

        Returns:
            Number of blocks frozen
        """
        with self._lock:
            frozen_count = 0
            for block_id in self.task_allocations.get(task_id, []):
                block = self.blocks.get(block_id)
                if block is not None:
                    block.state = CapacityState.FROZEN
                    block.frozen_at = time.time()
                    frozen_count += 1

            logger.info(f"Froze {frozen_count} blocks for task '{task_id}'")
            return frozen_count

    def release_task(self, task_id: str) -> int:
        """
        Return a task's non-frozen blocks to the AVAILABLE pool.

        Frozen blocks stay assigned to the task and remain listed in
        ``task_allocations``, so they can still be found by task id. The task's
        entry is removed only when no frozen blocks remain.

        Returns:
            Number of blocks released
        """
        with self._lock:
            released_count = 0
            retained: List[str] = []
            for block_id in self.task_allocations.get(task_id, []):
                block = self.blocks.get(block_id)
                if block is None:
                    continue
                if block.state == CapacityState.FROZEN:
                    retained.append(block_id)
                    continue
                block.state = CapacityState.AVAILABLE
                block.assigned_task = None
                released_count += 1

            if retained:
                self.task_allocations[task_id] = retained
            else:
                self.task_allocations.pop(task_id, None)

            logger.info(f"Released {released_count} blocks from task '{task_id}'")
            return released_count

    def get_utilization(self) -> Dict[str, Any]:
        """Get block counts per state, per-state fractions and per-task averages."""
        with self._lock:
            states = defaultdict(int)
            for block in self.blocks.values():
                states[block.state.name] += 1

            total = len(self.blocks)
            return {
                "total_blocks": total,
                "total_capacity": self.total_capacity,
                "states": dict(states),
                "utilization": {
                    state: count / total if total > 0 else 0
                    for state, count in states.items()
                },
                "tasks_with_allocation": len(self.task_allocations),
                "average_blocks_per_task": (
                    sum(len(blocks) for blocks in self.task_allocations.values()) /
                    len(self.task_allocations) if self.task_allocations else 0
                )
            }

    def defragment(self) -> int:
        """
        Count tasks whose capacity spans more than one block.

        This is a simplified placeholder: no blocks are moved. It reports
        each task that holds more than one block as a potential consolidation.

        Returns:
            Number of potential consolidations
        """
        with self._lock:
            blocks_by_task: Dict[str, List[CapacityBlock]] = defaultdict(list)
            for block in self.blocks.values():
                if block.assigned_task:
                    blocks_by_task[block.assigned_task].append(block)

            moves = sum(1 for blocks in blocks_by_task.values() if len(blocks) > 1)

            logger.info(f"Defragmentation identified {moves} potential consolidations")
            return moves


# ==============================================================================
# Catastrophic Forgetting Preventer
# ==============================================================================

class CatastrophicForgettingPreventer:
    """
    Master controller for preventing catastrophic forgetting.

    Coordinates multiple strategies:
    - Elastic Weight Consolidation
    - Progressive Networks
    - Experience Replay
    - Task Inference
    - Capacity Management

    The controller's own bookkeeping (active/completed tasks, performance
    history, forgetting events, experience count) is read and written under
    an internal re-entrant lock.
    """

    def __init__(
        self,
        strategy: ConsolidationStrategy = ConsolidationStrategy.HYBRID,
        ewc_lambda: float = 1000.0,
        replay_buffer_size: int = 10000,
        progressive_max_columns: int = 10,
        capacity_total: int = 100000
    ):
        """
        Initialize Catastrophic Forgetting Preventer.

        Args:
            strategy: Primary consolidation strategy
            ewc_lambda: EWC regularization strength
            replay_buffer_size: Experience replay buffer size
            progressive_max_columns: Max progressive network columns
            capacity_total: Total capacity to manage
        """
        self.strategy = strategy

        # Initialize components
        self.ewc = ElasticWeightConsolidation(lambda_ewc=ewc_lambda)
        self.progressive = ProgressiveNetworks(max_columns=progressive_max_columns)
        self.replay = ExperienceReplay(buffer_size=replay_buffer_size)
        self.task_inferencer = TaskInferencer()
        self.capacity_manager = CapacityManager(total_capacity=capacity_total)

        # Tracking
        self.active_tasks: Set[str] = set()
        self.completed_tasks: Set[str] = set()
        self.task_performance: Dict[str, List[float]] = defaultdict(list)
        self.forgetting_events: List[Dict[str, Any]] = []

        # Thresholds
        self.forgetting_threshold = 0.1  # absolute drop in mean success rate
        self.consolidation_interval = 100  # Compute EWC loss every N experiences

        self._lock = threading.RLock()
        self._experience_count = 0

        logger.info(
            f"CatastrophicForgettingPreventer initialized: strategy={strategy.name}"
        )

    def start_task(
        self,
        task_id: str,
        task_descriptor: TaskDescriptor
    ) -> Dict[str, Any]:
        """
        Start a new learning task.

        Registers the descriptor with the task inferencer and allocates
        capacity. For PROGRESSIVE/HYBRID strategies it also adds a
        progressive-network column.

        Args:
            task_id: Task identifier
            task_descriptor: Task description

        Returns:
            Task initialization info
        """
        with self._lock:
            logger.info(f"Starting task '{task_id}'")

            if task_descriptor.task_id != task_id:
                # The inferencer keys tasks by the descriptor's id, so inference
                # results will not match the id used everywhere else here.
                logger.warning(
                    f"Task id '{task_id}' differs from descriptor id "
                    f"'{task_descriptor.task_id}'"
                )

            self.task_inferencer.register_task(task_descriptor)

            blocks = self.capacity_manager.allocate(task_id)

            if self.strategy in (
                ConsolidationStrategy.PROGRESSIVE,
                ConsolidationStrategy.HYBRID
            ):
                self.progressive.add_column(task_id)

            self.active_tasks.add(task_id)

            return {
                "task_id": task_id,
                "allocated_capacity": len(blocks) * self.capacity_manager.block_size,
                "strategy": self.strategy.name,
                "active_tasks": len(self.active_tasks)
            }

    def process_experience(
        self,
        experience: Experience,
        parameters: Optional[Dict[str, float]] = None,
        gradient_fn: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Process a learning experience.

        Adds the experience to the replay buffer, records its success rate,
        checks completed tasks for forgetting, and updates the progressive
        column's performance if one exists for the task. Every
        ``consolidation_interval`` experiences, the EWC loss is computed,
        but only when both ``parameters`` and ``gradient_fn`` are given.

        Args:
            experience: Experience to process
            parameters: Current model parameters
            gradient_fn: Gradient computation function

        Returns:
            Processing results
        """
        with self._lock:
            self._experience_count += 1

            self.replay.add_experience(experience)

            self.task_performance[experience.task_id].append(experience.success_rate)

            forgetting_detected = self._check_forgetting(experience.task_id)

            ewc_loss = 0.0
            if self._experience_count % self.consolidation_interval == 0:
                if parameters and gradient_fn:
                    ewc_loss = self.ewc.compute_ewc_loss(parameters)

            if experience.task_id in self.progressive.task_to_column:
                self.progressive.update_performance(
                    experience.task_id,
                    experience.success_rate
                )

            return {
                "experience_id": experience.experience_id,
                "forgetting_detected": forgetting_detected,
                "ewc_loss": ewc_loss,
                "replay_buffer_size": len(self.replay.priority_buffer),
                "total_experiences": self._experience_count
            }

    def _check_forgetting(self, current_task: str) -> bool:
        """
        Check whether any other completed task shows a performance drop.

        For each completed task (other than ``current_task``) with at least
        10 recorded results, compare the mean of all but the last 5 results
        with the mean of the last 5. A drop larger than
        ``forgetting_threshold`` is logged and recorded as a forgetting event.

        Returns:
            True if forgetting was detected in at least one task
        """
        forgetting_detected = False

        for task_id in self.completed_tasks:
            if task_id == current_task:
                continue

            perf_history = self.task_performance.get(task_id, [])
            if len(perf_history) < 10:
                continue

            older, latest = perf_history[:-5], perf_history[-5:]
            historical = sum(older) / len(older)
            recent = sum(latest) / len(latest)

            if historical - recent > self.forgetting_threshold:
                forgetting_detected = True
                self.forgetting_events.append({
                    "task_id": task_id,
                    "current_task": current_task,
                    "historical_perf": historical,
                    "recent_perf": recent,
                    "drop": historical - recent,
                    "timestamp": time.time()
                })
                logger.warning(
                    f"Forgetting detected in task '{task_id}': "
                    f"{historical:.3f} -> {recent:.3f}"
                )

        return forgetting_detected

    def complete_task(
        self,
        task_id: str,
        final_parameters: Dict[str, float],
        gradient_fn: Callable,
        training_data: List[Any]
    ) -> Dict[str, Any]:
        """
        Complete a task and consolidate knowledge.

        Registers the task with EWC (EWC/HYBRID), freezes its progressive
        column (PROGRESSIVE/HYBRID), and freezes its capacity blocks.

        Args:
            task_id: Task identifier
            final_parameters: Final model parameters
            gradient_fn: Gradient computation function
            training_data: Training data samples

        Returns:
            Consolidation results
        """
        with self._lock:
            logger.info(f"Completing task '{task_id}'")

            if self.strategy in (ConsolidationStrategy.EWC, ConsolidationStrategy.HYBRID):
                fisher_scores = self.ewc.register_task(
                    task_id, final_parameters, gradient_fn, training_data
                )
            else:
                fisher_scores = {}

            if self.strategy in (
                ConsolidationStrategy.PROGRESSIVE,
                ConsolidationStrategy.HYBRID
            ):
                self.progressive.freeze_column(task_id)

            frozen_blocks = self.capacity_manager.freeze_task(task_id)

            self.active_tasks.discard(task_id)
            self.completed_tasks.add(task_id)

            return {
                "task_id": task_id,
                "fisher_parameters": len(fisher_scores),
                "frozen_capacity": frozen_blocks * self.capacity_manager.block_size,
                "completed_tasks": len(self.completed_tasks),
                "active_tasks": len(self.active_tasks)
            }

    def get_replay_batch(
        self,
        batch_size: int = 32,
        balanced: bool = True
    ) -> Tuple[List[Experience], List[float]]:
        """
        Get a batch of experiences for replay.

        Args:
            batch_size: Batch size
            balanced: Balance across completed tasks (in sorted task-id order)

        Returns:
            Tuple of (experiences, importance_weights)
        """
        with self._lock:
            # Snapshot under the lock: complete_task may add to the set
            # concurrently. Sorting makes the per-task split deterministic.
            replay_tasks = sorted(self.completed_tasks)

        if balanced:
            return self.replay.sample_balanced(batch_size, replay_tasks)
        return self.replay.sample(batch_size)

    def infer_current_task(
        self,
        input_data: Any,
        context: Optional[Dict[str, Any]] = None
    ) -> Tuple[str, float, TaskDescriptor]:
        """
        Infer the current task from input (delegates to the task inferencer).

        Args:
            input_data: Input data
            context: Optional context

        Returns:
            Tuple of (task_id, confidence, task_descriptor)
        """
        return self.task_inferencer.infer_task(input_data, context)

    def get_protection_status(self) -> Dict[str, Any]:
        """Get comprehensive protection status across all components."""
        with self._lock:
            return {
                "strategy": self.strategy.name,
                "active_tasks": list(self.active_tasks),
                "completed_tasks": list(self.completed_tasks),
                "total_experiences": self._experience_count,
                "forgetting_events": len(self.forgetting_events),
                "ewc_summary": self.ewc.get_importance_summary(),
                "progressive_summary": self.progressive.get_network_summary(),
                "replay_stats": self.replay.get_statistics(),
                "capacity_utilization": self.capacity_manager.get_utilization(),
                "task_distribution": self.task_inferencer.get_task_distribution()
            }

    def mitigate_forgetting(
        self,
        parameters: Dict[str, float],
        learning_rate: float = 0.01
    ) -> Dict[str, float]:
        """
        Apply forgetting mitigation to parameters.

        Each parameter is shifted by ``-learning_rate * modifier``, where the
        modifier comes from ``ewc.get_gradient_modifier``. The input mapping
        is not modified.

        Args:
            parameters: Current parameters
            learning_rate: Learning rate for mitigation

        Returns:
            Mitigated parameters (a copy of ``parameters``)
        """
        mitigated = parameters.copy()
        # Iterate the source mapping, not the copy being written to
        for param_name, value in parameters.items():
            modifier = self.ewc.get_gradient_modifier(param_name, value)
            mitigated[param_name] = value - learning_rate * modifier
        return mitigated

    def save_state(self, path: str):
        """
        Save preventer state to a JSON file.

        The snapshot is serialized while holding the lock, then written to a
        sibling ``.tmp`` file that replaces ``path``. A failed write therefore
        leaves any previous file at ``path`` intact.

        Args:
            path: Destination file path
        """
        with self._lock:
            state = {
                "strategy": self.strategy.name,
                "active_tasks": list(self.active_tasks),
                "completed_tasks": list(self.completed_tasks),
                "task_performance": dict(self.task_performance),
                "forgetting_events": self.forgetting_events,
                "experience_count": self._experience_count,
                "ewc_task_parameters": self.ewc.task_parameters,
                "ewc_consolidated_fisher": self.ewc.consolidated_fisher,
                "ewc_consolidated_optimal": self.ewc.consolidated_optimal,
                "progressive_columns": {
                    k: {
                        "column_id": v.column_id,
                        "task_id": v.task_id,
                        "layer_sizes": v.layer_sizes,
                        "frozen": v.frozen
                    }
                    for k, v in self.progressive.columns.items()
                },
                "capacity_allocations": dict(self.capacity_manager.task_allocations)
            }
            # Serialize inside the lock so the file reflects one consistent snapshot
            payload = json.dumps(state, indent=2)

        target = Path(path)
        tmp_path = target.with_name(target.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            tmp_path.replace(target)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

        logger.info(f"State saved to {path}")

    def load_state(self, path: str):
        """
        Load preventer state from a JSON file written by ``save_state``.

        Restores the strategy, task sets, performance history, forgetting
        events, experience count and EWC state. The ``progressive_columns``
        and ``capacity_allocations`` sections are not restored. All values
        are read before any attribute is assigned, so a missing key raises
        KeyError without leaving the object partially updated.

        Args:
            path: Source file path
        """
        with open(path, 'r', encoding='utf-8') as f:
            state = json.load(f)

        strategy = ConsolidationStrategy[state["strategy"]]
        active_tasks = set(state["active_tasks"])
        completed_tasks = set(state["completed_tasks"])
        task_performance = defaultdict(list, state["task_performance"])
        forgetting_events = state["forgetting_events"]
        experience_count = state["experience_count"]
        ewc_task_parameters = state["ewc_task_parameters"]
        ewc_consolidated_fisher = state["ewc_consolidated_fisher"]
        ewc_consolidated_optimal = state["ewc_consolidated_optimal"]

        with self._lock:
            self.strategy = strategy
            self.active_tasks = active_tasks
            self.completed_tasks = completed_tasks
            self.task_performance = task_performance
            self.forgetting_events = forgetting_events
            self._experience_count = experience_count

            self.ewc.task_parameters = ewc_task_parameters
            self.ewc.consolidated_fisher = ewc_consolidated_fisher
            self.ewc.consolidated_optimal = ewc_consolidated_optimal

        logger.info(f"State loaded from {path}")


# ==============================================================================
# Main Entry Point
# ==============================================================================

def create_continual_learner(
    strategy: Union[str, "ConsolidationStrategy"] = "hybrid",
    ewc_lambda: float = 1000.0,
    replay_buffer_size: int = 10000,
    progressive_max_columns: int = 10,
    capacity_total: int = 100000
) -> CatastrophicForgettingPreventer:
    """
    Factory function to create a configured continual learner.

    Args:
        strategy: Consolidation strategy. Accepts any of:
            - a ``ConsolidationStrategy`` member;
            - a member name in any case (e.g. ``"hybrid"``, ``"EWC"``);
            - a member value in any case (e.g. ``"elastic_weight_consolidation"``).
        ewc_lambda: EWC regularization strength
        replay_buffer_size: Replay buffer size
        progressive_max_columns: Max progressive network columns
        capacity_total: Total capacity

    Returns:
        Configured CatastrophicForgettingPreventer

    Raises:
        KeyError: If ``strategy`` matches no ``ConsolidationStrategy`` member.
    """
    if isinstance(strategy, ConsolidationStrategy):
        strategy_enum = strategy
    else:
        try:
            # Name lookup first: this is the form the factory always supported.
            strategy_enum = ConsolidationStrategy[strategy.upper()]
        except KeyError:
            # Fall back to matching the member's (string) value.
            wanted = strategy.lower()
            strategy_enum = next(
                (
                    member for member in ConsolidationStrategy
                    if isinstance(member.value, str)
                    and member.value.lower() == wanted
                ),
                None,
            )
            if strategy_enum is None:
                valid = ", ".join(m.name.lower() for m in ConsolidationStrategy)
                raise KeyError(
                    f"Unknown consolidation strategy {strategy!r}; "
                    f"expected one of: {valid}"
                ) from None

    return CatastrophicForgettingPreventer(
        strategy=strategy_enum,
        ewc_lambda=ewc_lambda,
        replay_buffer_size=replay_buffer_size,
        progressive_max_columns=progressive_max_columns,
        capacity_total=capacity_total
    )


# ==============================================================================
# Example Usage and Testing
# ==============================================================================

if __name__ == "__main__":
    # Demo / smoke-test script: exercises the public API end to end and
    # prints the results. It does not assert anything; it only reports.
    import tempfile

    def simulated_grad_fn(p, d):
        """Return random pseudo-gradients, one per key of ``p`` (demo only)."""
        return {k: random.gauss(0, 0.1) for k in p}

    print("=" * 70)
    print("AIVA Queen Continual Learning System - Test Suite")
    print("=" * 70)

    # Create continual learner
    learner = create_continual_learner(
        strategy="hybrid",
        ewc_lambda=1000.0,
        replay_buffer_size=1000,
        progressive_max_columns=5,
        capacity_total=50000
    )

    # Define sample tasks
    tasks = [
        TaskDescriptor(
            task_id="classification_task",
            task_type=TaskType.CLASSIFICATION,
            name="Image Classification",
            description="Classify images into categories",
            input_schema={"image": "tensor"},
            output_schema={"class": "int", "confidence": "float"}
        ),
        TaskDescriptor(
            task_id="generation_task",
            task_type=TaskType.GENERATION,
            name="Text Generation",
            description="Generate coherent text from prompts",
            input_schema={"prompt": "string"},
            output_schema={"text": "string"}
        ),
        TaskDescriptor(
            task_id="reasoning_task",
            task_type=TaskType.REASONING,
            name="Logical Reasoning",
            description="Perform logical reasoning over facts",
            input_schema={"facts": "list", "query": "string"},
            output_schema={"answer": "string", "explanation": "string"}
        )
    ]

    # Test Task 1: Classification
    print("\n--- Starting Classification Task ---")
    result = learner.start_task("classification_task", tasks[0])
    print(f"Task started: {result}")

    # Simulate training experiences
    num_classification = 50
    for i in range(num_classification):
        exp = Experience(
            task_id="classification_task",
            task_type=TaskType.CLASSIFICATION,
            input_data={"image": f"image_{i}"},
            output_data={"class": i % 10, "confidence": 0.9},
            importance=random.uniform(0.3, 1.0),
            success_rate=random.uniform(0.7, 1.0)
        )

        # Simulated parameters; gradients come from simulated_grad_fn.
        params = {f"w_{j}": random.gauss(0, 1) for j in range(10)}
        learner.process_experience(exp, params, simulated_grad_fn)

    print(f"Processed {num_classification} classification experiences")

    # Complete classification task
    final_params = {f"w_{j}": random.gauss(0, 1) for j in range(10)}
    training_data = [{"image": f"img_{i}"} for i in range(20)]

    result = learner.complete_task(
        "classification_task",
        final_params,
        simulated_grad_fn,
        training_data
    )
    print(f"Task completed: {result}")

    # Test Task 2: Generation
    print("\n--- Starting Generation Task ---")
    result = learner.start_task("generation_task", tasks[1])
    print(f"Task started: {result}")

    num_generation = 30
    for i in range(num_generation):
        exp = Experience(
            task_id="generation_task",
            task_type=TaskType.GENERATION,
            input_data={"prompt": f"prompt_{i}"},
            output_data={"text": f"generated_text_{i}"},
            importance=random.uniform(0.3, 1.0),
            success_rate=random.uniform(0.6, 0.95)
        )
        learner.process_experience(exp)

    print(f"Processed {num_generation} generation experiences")

    # Get replay batch
    print("\n--- Testing Experience Replay ---")
    experiences, weights = learner.get_replay_batch(batch_size=10, balanced=True)
    print(f"Sampled {len(experiences)} experiences with weights")
    for exp, weight in zip(experiences[:3], weights[:3]):
        print(f"  - Task: {exp.task_id}, Importance: {exp.importance:.3f}, Weight: {weight:.3f}")

    # Test task inference
    print("\n--- Testing Task Inference ---")
    task_id, confidence, descriptor = learner.infer_current_task(
        {"image": "test_image"},
        {"context": "visual_processing"}
    )
    print(f"Inferred task: {task_id} (confidence: {confidence:.3f})")

    # Get protection status
    print("\n--- Protection Status ---")
    status = learner.get_protection_status()
    print(f"Strategy: {status['strategy']}")
    print(f"Active tasks: {status['active_tasks']}")
    print(f"Completed tasks: {status['completed_tasks']}")
    print(f"Total experiences: {status['total_experiences']}")
    print(f"Forgetting events: {status['forgetting_events']}")
    print(f"Replay buffer size: {status['replay_stats']['buffer_size']}")
    print(f"Progressive network columns: {status['progressive_summary']['total_columns']}")

    # Test capacity management
    print("\n--- Capacity Utilization ---")
    capacity = status['capacity_utilization']
    print(f"Total capacity: {capacity['total_capacity']}")
    print(f"States: {capacity['states']}")
    print(f"Utilization: {capacity['utilization']}")

    # Test EWC
    print("\n--- EWC Summary ---")
    ewc_summary = status['ewc_summary']
    print(f"Total weights: {ewc_summary.get('total_weights', 0)}")
    print(f"Tasks registered: {ewc_summary.get('tasks_registered', 0)}")

    # Save and load state. A per-run temporary directory keeps this portable
    # (no hard-coded /tmp) and removes the state file when done.
    print("\n--- Testing State Persistence ---")
    with tempfile.TemporaryDirectory() as tmp_dir:
        state_path = str(Path(tmp_dir) / "continual_learner_state.json")
        learner.save_state(state_path)
        print(f"State saved to {state_path}")

        # Create new learner and load state
        new_learner = create_continual_learner()
        new_learner.load_state(state_path)
        print("State loaded successfully")

    new_status = new_learner.get_protection_status()
    print(f"Restored completed tasks: {new_status['completed_tasks']}")
    print(f"Restored experience count: {new_status['total_experiences']}")

    print("\n" + "=" * 70)
    print("All tests completed successfully!")
    print("=" * 70)
