#!/usr/bin/env python3
"""
AIVA Memory Consolidation Engine - Sleep Cycle Mimicry
=======================================================
Production-grade memory consolidation system that mimics biological sleep cycles.

This module implements the core consolidation mechanisms found in biological memory:
- NREM Sleep: Synaptic homeostasis, memory replay, schema integration
- REM Sleep: Memory pruning, interference resolution, creative synthesis

Based on neuroscience research:
- Synaptic Homeostasis Hypothesis (Tononi & Cirelli)
- Two-Stage Memory Model (Marr, McClelland, O'Reilly)
- Memory Reactivation & Replay (Wilson & McNaughton)

Components:
    1. ConsolidationEngine - Main orchestrator for sleep cycles
    2. SynapticHomeostasis - Balance and normalize memory weights
    3. MemoryReplay - Replay important memories for strengthening
    4. SchemaIntegration - Integrate new memories into existing schemas
    5. MemoryPruning - Remove redundant/weak memories
    6. InterferenceResolution - Handle conflicting memories

Author: Genesis AI System
Version: 1.0.0
Date: 2026-01-11
"""

import json
import hashlib
import heapq
import math
import random
import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from enum import Enum, auto
from pathlib import Path
from typing import (
    Dict, List, Any, Optional, Tuple, Callable,
    Set, Iterator, TypeVar, Generic, Protocol
)
import logging

# Configure logging
# NOTE(review): basicConfig() runs as a side effect of importing this module and
# configures the *root* logger (per the logging docs it does nothing if the root
# logger already has handlers). Consider moving it under a `__main__` guard.
logging.basicConfig(level=logging.INFO)
# Module logger used by every consolidation component below.
logger = logging.getLogger("aiva.consolidation")


# =============================================================================
# ENUMERATIONS & TYPE DEFINITIONS
# =============================================================================

class SleepPhase(Enum):
    """Biological sleep phases mapped to consolidation operations.

    Member values are lowercase strings; ConsolidationReport.to_dict
    serializes a phase through its ``.value``.
    """
    AWAKE = "awake"                    # Normal operation
    NREM_LIGHT = "nrem_light"          # Light NREM - Initial consolidation
    NREM_DEEP = "nrem_deep"            # Deep NREM - Synaptic homeostasis
    REM = "rem"                        # REM - Creative synthesis & pruning
    TRANSITION = "transition"          # Phase transitions


class ConsolidationState(Enum):
    """States of the consolidation engine.

    Values are assigned by ``auto()`` (consecutive integers in declaration
    order), so compare members directly rather than relying on their values.
    """
    IDLE = auto()
    COLLECTING = auto()
    REPLAYING = auto()
    INTEGRATING = auto()
    PRUNING = auto()
    RESOLVING = auto()
    COMPLETE = auto()
    ERROR = auto()


class MemoryStrength(Enum):
    """Memory strength classification.

    Values are strength levels on the same 0-1 scale as
    ``ConsolidationMemory.strength``. MemoryPruning never prunes a memory whose
    strength is >= ``CRITICAL.value``.
    """
    WEAK = 0.2
    MODERATE = 0.5
    STRONG = 0.75
    CRITICAL = 1.0


class ConflictType(Enum):
    """Types of memory conflicts.

    ConflictRecord.to_dict serializes a conflict type through its ``.value``.
    """
    CONTRADICTION = "contradiction"    # Direct factual conflict
    TEMPORAL = "temporal"              # Time-based inconsistency
    SCHEMA = "schema"                  # Schema violation
    SEMANTIC = "semantic"              # Meaning overlap/confusion


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class ConsolidationMemory:
    """Memory representation for consolidation processing.

    Instances are mutable: the consolidation stages in this module update
    ``strength``, ``replay_count``, ``last_accessed``, ``schema_id`` and
    ``consolidation_score`` in place.
    """
    id: str
    content: str
    embedding: Optional[List[float]] = None
    strength: float = 0.5
    access_count: int = 0
    last_accessed: Optional[str] = None  # ISO-8601 timestamp
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    domain: str = "general"
    source: str = "unknown"
    schema_id: Optional[str] = None
    relations: List[str] = field(default_factory=list)
    surprise_score: float = 0.0
    replay_count: int = 0
    consolidation_score: float = 0.0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __lt__(self, other: 'ConsolidationMemory') -> bool:
        """Order by consolidation score (lower score = lower priority).

        Used when heap entries tie on their numeric priority and Python falls
        back to comparing the memories themselves.
        """
        return self.consolidation_score < other.consolidation_score

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dictionary for serialization (via ``asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ConsolidationMemory':
        """Create from a dictionary such as one produced by :meth:`to_dict`.

        Raises:
            TypeError: if ``data`` contains keys that are not dataclass fields.
        """
        return cls(**data)

    def compute_consolidation_score(self) -> float:
        """
        Compute, store and return the consolidation priority score.

        Higher score = more important to consolidate. The score is a weighted
        sum (weights add up to 1.0) of:
        - Surprise score (novel information)            weight 0.35
        - Access frequency (frequently accessed)        weight 0.25
        - Recency (recent memories prioritized)         weight 0.20
        - Strength (strong memories get stronger)       weight 0.20

        ``created_at`` may be a naive or timezone-aware ISO-8601 timestamp.
        """
        created = datetime.fromisoformat(self.created_at)
        # Take "now" in the timestamp's own timezone: subtracting an aware
        # datetime from a naive one raises TypeError.
        now = datetime.now(created.tzinfo)
        # Clamp at zero so a future timestamp (clock skew, bad import) cannot
        # push the recency factor above 1.0.
        age_hours = max(0.0, (now - created).total_seconds() / 3600)
        # Exponential decay with a 168 h (1-week) time constant: the factor is
        # 1/e ~= 0.37 after one week (this is not a half-life).
        recency_factor = math.exp(-age_hours / 168)

        # Access pattern factor; saturates at 1.0 (log1p(n) / 5 >= 1 at n >= ~147).
        # Negative counts are treated as zero to keep log1p in its domain.
        access_factor = min(1.0, math.log1p(max(0, self.access_count)) / 5)

        self.consolidation_score = (
            self.surprise_score * 0.35 +
            access_factor * 0.25 +
            recency_factor * 0.20 +
            self.strength * 0.20
        )
        return self.consolidation_score


@dataclass
class MemorySchema:
    """Schema representation for memory organization.

    A schema groups memories that share a ``domain``. ``prototype`` is a
    free-form dict; SchemaIntegration in this module fills it with
    ``domain``, ``keywords`` and ``avg_strength`` and will also read an
    optional ``embedding`` entry for vector similarity.
    """
    id: str
    name: str
    domain: str
    prototype: Dict[str, Any]
    member_ids: List[str] = field(default_factory=list)  # ids of member memories
    coherence_score: float = 1.0  # updated by SchemaIntegration from member similarity
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    # NOTE: created_at and last_updated defaults come from two separate now()
    # calls, so on a fresh instance they may differ by a few microseconds.
    last_updated: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict (via dataclasses.asdict)."""
        return asdict(self)


@dataclass
class ConflictRecord:
    """Record of a detected conflict between two memories.

    Identifies the two memories by id, classifies the conflict with a
    ConflictType, and carries optional resolution bookkeeping.
    """
    id: str
    memory_a_id: str
    memory_b_id: str
    conflict_type: ConflictType
    severity: float  # 0-1 scale
    resolution: Optional[str] = None
    resolved: bool = False
    detected_at: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO-8601
    resolved_at: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a dict, with ``conflict_type`` replaced by its ``.value`` string."""
        d = asdict(self)
        # asdict() leaves the Enum member in place; store its plain value instead.
        d['conflict_type'] = self.conflict_type.value
        return d


@dataclass
class ConsolidationReport:
    """Report generated after a consolidation cycle.

    Holds the cycle id, the sleep phase, timing, per-stage counters and any
    error messages collected during the cycle.
    """
    cycle_id: str
    phase: SleepPhase
    started_at: str
    completed_at: Optional[str] = None
    duration_seconds: float = 0.0
    memories_processed: int = 0
    memories_strengthened: int = 0
    memories_pruned: int = 0
    schemas_updated: int = 0
    conflicts_resolved: int = 0
    replay_events: int = 0
    homeostasis_adjustments: int = 0
    errors: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a dict, with ``phase`` replaced by its ``.value`` string."""
        d = asdict(self)
        # asdict() leaves the Enum member in place; store its plain value instead.
        d['phase'] = self.phase.value
        return d


# =============================================================================
# SYNAPTIC HOMEOSTASIS
# =============================================================================

class SynapticHomeostasis:
    """
    Implements synaptic homeostasis hypothesis.

    During "sleep", synaptic weights are globally downscaled while
    preserving relative differences. This:
    - Prevents saturation
    - Improves signal-to-noise ratio
    - Saves "resources" (storage/compute)
    - Enables new learning capacity

    Based on: Tononi & Cirelli (2006) "Sleep function and synaptic homeostasis"
    """

    def __init__(
        self,
        downscale_factor: float = 0.8,
        minimum_strength: float = 0.1,
        protection_threshold: float = 0.9
    ):
        """
        Initialize homeostasis controller.

        Args:
            downscale_factor: Base scaling factor. The effective per-memory
                multiplier is ``downscale_factor + 0.1 - 0.02 * strength``
                (about 0.88-0.90 for the default 0.8).
            minimum_strength: Floor for strength values after scaling
            protection_threshold: Memories at or above this strength are
                left untouched when ``preserve_critical`` is set
        """
        self.downscale_factor = downscale_factor
        self.minimum_strength = minimum_strength
        self.protection_threshold = protection_threshold
        # Statistics from the most recent apply_homeostasis() call.
        self.adjustments_made = 0
        self.total_strength_before = 0.0
        self.total_strength_after = 0.0

    @staticmethod
    def _age_seconds(created_at: str) -> float:
        """Return seconds elapsed since an ISO-8601 timestamp (naive or aware)."""
        created = datetime.fromisoformat(created_at)
        # Match the timestamp's timezone so aware values don't raise TypeError.
        return (datetime.now(created.tzinfo) - created).total_seconds()

    def apply_homeostasis(
        self,
        memories: List[ConsolidationMemory],
        preserve_critical: bool = True
    ) -> Tuple[List[ConsolidationMemory], Dict[str, Any]]:
        """
        Apply synaptic homeostasis to a memory collection.

        Memory strengths are modified in place.

        Args:
            memories: List of memories to process
            preserve_critical: Whether to leave memories at or above
                ``protection_threshold`` unscaled

        Returns:
            Tuple of (processed memories, statistics dict). For empty input the
            statistics are ``{"adjustments": 0, "error": "no memories"}``.
        """
        if not memories:
            return memories, {"adjustments": 0, "error": "no memories"}

        self.adjustments_made = 0
        self.total_strength_before = sum(m.strength for m in memories)

        processed = []
        protected_count = 0
        for memory in memories:
            # Protect critical memories
            if preserve_critical and memory.strength >= self.protection_threshold:
                protected_count += 1
                processed.append(memory)
                continue

            old_strength = memory.strength
            # Effective multiplier = downscale_factor + 0.1 - 0.02 * strength.
            # It shrinks slightly as strength grows, so stronger memories are
            # scaled down marginally *more* than weak ones. The result is
            # clamped to [minimum_strength, 1.0]; the floor can raise a memory
            # that started below minimum_strength.
            strength_factor = 1.0 - (old_strength * 0.2)
            new_strength = old_strength * (self.downscale_factor + strength_factor * 0.1)
            new_strength = max(self.minimum_strength, min(1.0, new_strength))

            if new_strength != old_strength:
                memory.strength = new_strength
                self.adjustments_made += 1

            processed.append(memory)

        self.total_strength_after = sum(m.strength for m in processed)

        stats = {
            "adjustments": self.adjustments_made,
            "total_before": round(self.total_strength_before, 4),
            "total_after": round(self.total_strength_after, 4),
            "reduction_ratio": round(
                self.total_strength_after / max(self.total_strength_before, 0.001),
                4
            ),
            # Memories actually skipped by the protection rule (0 when
            # preserve_critical is False).
            "memories_protected": protected_count
        }

        logger.info(f"Homeostasis: {self.adjustments_made} adjustments, "
                   f"reduction ratio: {stats['reduction_ratio']}")

        return processed, stats

    def calculate_homeostatic_pressure(
        self,
        memories: List[ConsolidationMemory]
    ) -> float:
        """
        Calculate current "homeostatic pressure", i.e. the need for consolidation.

        Higher pressure = more urgent need for sleep/consolidation. Combines:
        - weight pressure (mean strength, capped at 1.0), weight 0.4
        - time pressure (age of the oldest never-replayed memory over 24 h), weight 0.4
        - count pressure (memory count over 10,000), weight 0.2

        Returns:
            Pressure value 0-1 rounded to 4 places (1 = urgent need for consolidation)
        """
        if not memories:
            return 0.0

        # Factor 1: Total synaptic weight relative to an expected mean of 0.5
        total_strength = sum(m.strength for m in memories)
        expected_strength = len(memories) * 0.5
        weight_pressure = min(1.0, total_strength / (expected_strength * 2))

        # Factor 2: how long the longest-waiting unreplayed memory has waited.
        # max(), not min(): pressure must grow with the OLDEST such memory.
        oldest_unreplayed = max(
            (self._age_seconds(m.created_at) for m in memories if m.replay_count == 0),
            default=0.0
        )
        # Clamp so future timestamps cannot produce negative pressure; 24h max.
        time_pressure = min(1.0, max(0.0, oldest_unreplayed) / (24 * 3600))

        # Factor 3: Memory count pressure (10K threshold)
        count_pressure = min(1.0, len(memories) / 10000)

        pressure = (
            weight_pressure * 0.4 +
            time_pressure * 0.4 +
            count_pressure * 0.2
        )

        return round(pressure, 4)


# =============================================================================
# MEMORY REPLAY
# =============================================================================

class MemoryReplay:
    """
    Implements memory replay mechanism.

    During NREM sleep, hippocampal memories are "replayed" to cortex,
    strengthening important memories and enabling integration.

    Based on: Wilson & McNaughton (1994) "Reactivation of hippocampal
    ensemble memories during sleep"
    """

    def __init__(
        self,
        replay_buffer_size: int = 100,
        strengthening_factor: float = 1.15,
        max_replays_per_memory: int = 5
    ):
        """
        Initialize replay system.

        Args:
            replay_buffer_size: Max memories held in the priority queue
            strengthening_factor: Strength multiplier per replay (result capped at 1.0)
            max_replays_per_memory: Memories replayed this many times are
                dropped from the buffer instead of replayed again
        """
        self.buffer_size = replay_buffer_size
        self.strengthening_factor = strengthening_factor
        self.max_replays = max_replays_per_memory
        # Min-heap of (-consolidation_score, memory): the root is the
        # highest-scoring memory.
        self._replay_buffer: List[Tuple[float, ConsolidationMemory]] = []
        self.replay_count = 0

    def add_to_buffer(self, memory: ConsolidationMemory) -> bool:
        """
        Add a memory to the bounded replay buffer (priority queue).

        Recomputes the memory's consolidation score and uses it as the priority.
        When the buffer is full, the lowest-priority entry is evicted if the
        new memory outranks it.

        Returns:
            True if the memory was added, False if the buffer is full and every
            buffered memory has an equal or higher priority (or the buffer
            size is <= 0).
        """
        memory.compute_consolidation_score()
        # Negated so the min-heap root is the highest-scoring memory.
        priority = -memory.consolidation_score
        entry = (priority, memory)

        if len(self._replay_buffer) < self.buffer_size:
            heapq.heappush(self._replay_buffer, entry)
            return True
        if not self._replay_buffer:
            # buffer_size <= 0: nothing can ever be buffered.
            return False

        # The heap root is the BEST entry; the eviction victim is the worst
        # one (largest negated priority), which can be anywhere in the leaves.
        buffer = self._replay_buffer
        worst_index = max(range(len(buffer)), key=lambda i: buffer[i][0])
        if priority < buffer[worst_index][0]:
            buffer[worst_index] = entry
            heapq.heapify(buffer)  # restore the heap invariant after in-place replacement
            return True
        return False

    def replay_cycle(
        self,
        count: int = 10,
        callback: Optional[Callable[[ConsolidationMemory], None]] = None
    ) -> Tuple[List[ConsolidationMemory], Dict[str, Any]]:
        """
        Execute a replay cycle.

        Pops top-priority memories and strengthens them until ``count``
        memories have been replayed or the buffer is empty. Memories that
        already reached ``max_replays`` are removed from the buffer without
        using up one of the ``count`` slots.

        Args:
            count: Number of memories to replay
            callback: Optional function called for each replayed memory

        Returns:
            Tuple of (replayed memories, statistics)
        """
        replayed = []
        strengthened_count = 0

        while self._replay_buffer and len(replayed) < count:
            _, memory = heapq.heappop(self._replay_buffer)

            # Saturated memories are discarded, not replayed.
            if memory.replay_count >= self.max_replays:
                continue

            old_strength = memory.strength
            memory.strength = min(1.0, memory.strength * self.strengthening_factor)
            memory.replay_count += 1
            memory.last_accessed = datetime.now().isoformat()
            self.replay_count += 1

            # Already-capped memories (strength 1.0) are replayed but not strengthened.
            if memory.strength > old_strength:
                strengthened_count += 1

            if callback:
                callback(memory)

            replayed.append(memory)

        stats = {
            "replayed": len(replayed),
            "strengthened": strengthened_count,
            "buffer_remaining": len(self._replay_buffer),
            "total_replays": self.replay_count
        }

        logger.info(f"Replay cycle: {len(replayed)} memories replayed, "
                   f"{strengthened_count} strengthened")

        return replayed, stats

    def get_replay_candidates(
        self,
        memories: List[ConsolidationMemory],
        top_k: int = 50
    ) -> List[ConsolidationMemory]:
        """
        Select top candidates for replay based on consolidation priority.

        Recomputes every memory's consolidation score (in place), then returns
        the ``top_k`` highest-scoring memories in descending score order.
        The score favours:
        - High surprise score (novel/important)
        - Recent memories
        - Frequently accessed
        - Strong existing strength
        """
        for m in memories:
            m.compute_consolidation_score()

        sorted_memories = sorted(
            memories,
            key=lambda m: m.consolidation_score,
            reverse=True
        )

        return sorted_memories[:top_k]

    def clear_buffer(self):
        """Remove all entries from the replay buffer."""
        self._replay_buffer.clear()
        logger.debug("Replay buffer cleared")


# =============================================================================
# SCHEMA INTEGRATION
# =============================================================================

class SchemaIntegration:
    """
    Integrates new memories into existing knowledge schemas.

    Schemas are generalized knowledge structures that help organize
    and retrieve related memories. New memories are assimilated into
    existing schemas or trigger schema creation/modification.

    Based on: Bartlett (1932), Piaget, and modern schema theory
    """

    def __init__(
        self,
        similarity_threshold: float = 0.7,
        max_schemas: int = 1000,
        min_members_for_schema: int = 3
    ):
        """
        Initialize schema integration system.

        Args:
            similarity_threshold: Minimum similarity for schema membership
            max_schemas: Maximum number of schemas to maintain; no new schema
                is created once this many exist
            min_members_for_schema: Minimum members to form a schema
        """
        self.similarity_threshold = similarity_threshold
        self.max_schemas = max_schemas
        self.min_members = min_members_for_schema
        self.schemas: Dict[str, MemorySchema] = {}
        self._embedding_cache: Dict[str, List[float]] = {}

    def integrate_memory(
        self,
        memory: ConsolidationMemory,
        existing_memories: List[ConsolidationMemory]
    ) -> Tuple[Optional[str], Dict[str, Any]]:
        """
        Integrate a memory into the schema system.

        If no schema exists for the memory's domain, a new one is created from
        all same-domain memories once there are enough of them (and the
        ``max_schemas`` limit allows it). Otherwise the memory joins the most
        similar domain schema whose similarity reaches ``similarity_threshold``.
        Re-integrating a memory that is already a member does not duplicate it.

        Returns:
            Tuple of (schema_id if assigned, statistics dict). The ``action``
            key is one of ``created_schema``, ``insufficient_memories``,
            ``schema_limit_reached``, ``added_to_schema`` or
            ``no_matching_schema``.
        """
        domain_schemas = [
            s for s in self.schemas.values()
            if s.domain == memory.domain
        ]

        if not domain_schemas:
            # No domain schemas exist - check if we should create one
            domain_memories = [
                m for m in existing_memories
                if m.domain == memory.domain and m.id != memory.id
            ]

            # The seed memory itself counts as one member.
            if len(domain_memories) < self.min_members - 1:
                return None, {"action": "insufficient_memories"}

            if len(self.schemas) >= self.max_schemas:
                # Respect the configured capacity instead of growing unbounded.
                return None, {"action": "schema_limit_reached"}

            schema = self._create_schema(memory, domain_memories)
            return schema.id, {
                "action": "created_schema",
                "schema_id": schema.id,
                "members": len(schema.member_ids)
            }

        # Find best matching schema above the threshold
        best_schema = None
        best_similarity = 0.0

        for schema in domain_schemas:
            similarity = self._compute_schema_similarity(memory, schema)
            if similarity > best_similarity and similarity >= self.similarity_threshold:
                best_schema = schema
                best_similarity = similarity

        if best_schema is not None:
            # Avoid duplicate membership when a memory is integrated again
            # in a later cycle.
            if memory.id not in best_schema.member_ids:
                best_schema.member_ids.append(memory.id)
            best_schema.last_updated = datetime.now().isoformat()
            memory.schema_id = best_schema.id

            best_schema.coherence_score = self._compute_coherence(
                best_schema,
                existing_memories
            )

            return best_schema.id, {
                "action": "added_to_schema",
                "schema_id": best_schema.id,
                "similarity": round(best_similarity, 4),
                "coherence": round(best_schema.coherence_score, 4)
            }

        return None, {"action": "no_matching_schema"}

    def _create_schema(
        self,
        seed_memory: ConsolidationMemory,
        related_memories: List[ConsolidationMemory]
    ) -> MemorySchema:
        """Create and register a schema from a seed memory and related memories.

        Sets ``schema_id`` on the seed and every related memory.
        """
        # Id = first 12 hex chars of sha256("<domain>:<now>").
        schema_id = hashlib.sha256(
            f"{seed_memory.domain}:{datetime.now().isoformat()}".encode()
        ).hexdigest()[:12]

        # Build prototype from common features
        prototype = {
            "domain": seed_memory.domain,
            "keywords": self._extract_common_keywords(
                [seed_memory.content] + [m.content for m in related_memories]
            ),
            "avg_strength": sum(m.strength for m in [seed_memory] + related_memories) / (len(related_memories) + 1)
        }

        schema = MemorySchema(
            id=schema_id,
            name=f"{seed_memory.domain}_schema_{schema_id[:6]}",
            domain=seed_memory.domain,
            prototype=prototype,
            member_ids=[seed_memory.id] + [m.id for m in related_memories]
        )

        seed_memory.schema_id = schema_id
        for m in related_memories:
            m.schema_id = schema_id

        self.schemas[schema_id] = schema

        logger.info(f"Created schema {schema_id} with {len(schema.member_ids)} members")

        return schema

    def _compute_schema_similarity(
        self,
        memory: ConsolidationMemory,
        schema: MemorySchema
    ) -> float:
        """Compute similarity between a memory and a schema prototype.

        Uses cosine similarity when both sides have an embedding; otherwise the
        fraction of the prototype's keywords found among the memory's words
        (0.5 when the prototype has no keywords).
        """
        if memory.embedding and "embedding" in schema.prototype:
            return self._cosine_similarity(
                memory.embedding,
                schema.prototype["embedding"]
            )

        # Fallback to keyword overlap
        memory_keywords = set(memory.content.lower().split())
        schema_keywords = set(schema.prototype.get("keywords", []))

        if not schema_keywords:
            return 0.5  # Default for empty prototype

        overlap = len(memory_keywords & schema_keywords)
        return overlap / max(len(schema_keywords), 1)

    def _compute_coherence(
        self,
        schema: MemorySchema,
        all_memories: List[ConsolidationMemory]
    ) -> float:
        """Return the mean pairwise Jaccard word similarity of the schema's members.

        Only members present in ``all_memories`` are considered; with fewer
        than two such members the coherence is 1.0.
        """
        # Set lookup keeps this O(n) in len(all_memories).
        member_set = set(schema.member_ids)
        member_memories = [m for m in all_memories if m.id in member_set]

        if len(member_memories) < 2:
            return 1.0

        similarities = []
        for i, m1 in enumerate(member_memories):
            for m2 in member_memories[i+1:]:
                similarities.append(self._content_similarity(m1.content, m2.content))

        return sum(similarities) / len(similarities) if similarities else 1.0

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Cosine similarity of two vectors; 0.0 on length mismatch or zero magnitude."""
        if len(a) != len(b):
            return 0.0

        dot_product = sum(x * y for x, y in zip(a, b))
        magnitude_a = math.sqrt(sum(x * x for x in a))
        magnitude_b = math.sqrt(sum(x * x for x in b))

        if magnitude_a == 0 or magnitude_b == 0:
            return 0.0

        return dot_product / (magnitude_a * magnitude_b)

    def _content_similarity(self, a: str, b: str) -> float:
        """Jaccard similarity of the lowercased whitespace-split word sets."""
        words_a = set(a.lower().split())
        words_b = set(b.lower().split())

        if not words_a or not words_b:
            return 0.0

        overlap = len(words_a & words_b)
        union = len(words_a | words_b)

        return overlap / union if union > 0 else 0.0

    def _extract_common_keywords(self, contents: List[str], top_k: int = 10) -> List[str]:
        """Return up to ``top_k`` words (longer than 3 chars) found in at least
        half of ``contents``, most widespread first."""
        word_counts: Dict[str, int] = defaultdict(int)

        for content in contents:
            # set(): count each word once per content (document frequency)
            for word in set(content.lower().split()):
                if len(word) > 3:  # Skip short words
                    word_counts[word] += 1

        common = [
            word for word, count in word_counts.items()
            if count >= len(contents) * 0.5
        ]

        return sorted(common, key=lambda w: word_counts[w], reverse=True)[:top_k]

    def get_schema_stats(self) -> Dict[str, Any]:
        """Return schema count, mean member count, sorted domain list and mean coherence."""
        return {
            "total_schemas": len(self.schemas),
            "avg_members": sum(len(s.member_ids) for s in self.schemas.values()) / max(len(self.schemas), 1),
            # sorted() for a deterministic order (set iteration order is not).
            "domains": sorted({s.domain for s in self.schemas.values()}),
            "avg_coherence": sum(s.coherence_score for s in self.schemas.values()) / max(len(self.schemas), 1)
        }


# =============================================================================
# MEMORY PRUNING
# =============================================================================

class MemoryPruning:
    """
    Implements memory pruning during consolidation.

    Removes weak, redundant, or outdated memories to:
    - Free resources
    - Reduce interference
    - Improve retrieval efficiency

    Based on: Forgetting as optimization (Richards & Frankland, 2017)
    """

    # Strength threshold applied by prune_memories(aggressive=True), unless the
    # configured threshold is already stricter (higher).
    _AGGRESSIVE_STRENGTH_THRESHOLD = 0.25
    # Never-accessed candidates weaker than this are pruned outright.
    _NEGLECTED_STRENGTH_CUTOFF = 0.2

    def __init__(
        self,
        strength_threshold: float = 0.15,
        age_threshold_days: int = 30,
        redundancy_threshold: float = 0.9
    ):
        """
        Initialize pruning system.

        Args:
            strength_threshold: Memories below this strength may be pruned
            age_threshold_days: Consider pruning memories older than this
            redundancy_threshold: Similarity threshold for redundancy detection
        """
        self.strength_threshold = strength_threshold
        self.age_threshold = timedelta(days=age_threshold_days)
        self.redundancy_threshold = redundancy_threshold
        # Cumulative totals across all prune_memories() calls.
        self.pruned_count = 0
        self.pruned_ids: Set[str] = set()

    def identify_candidates(
        self,
        memories: List[ConsolidationMemory],
        strength_threshold: Optional[float] = None
    ) -> Tuple[List[ConsolidationMemory], List[ConsolidationMemory]]:
        """
        Split memories into those to keep and those eligible for pruning.

        A memory is a candidate when it is older than ``age_threshold`` AND
        weaker than the strength threshold, OR it has never been accessed and
        is more than 7 days old. Memories with strength at or above
        ``MemoryStrength.CRITICAL`` are always kept.

        Args:
            memories: Memories to classify (not modified)
            strength_threshold: Override for ``self.strength_threshold`` for
                this call only; ``None`` uses the configured value

        Returns:
            Tuple of (keep list, prune candidates list), each in input order
        """
        threshold = self.strength_threshold if strength_threshold is None else strength_threshold
        naive_now = datetime.now()
        keep = []
        candidates = []

        for memory in memories:
            # Never prune critical memories
            if memory.strength >= MemoryStrength.CRITICAL.value:
                keep.append(memory)
                continue

            created = datetime.fromisoformat(memory.created_at)
            # Use an aware "now" for aware timestamps so subtraction can't raise.
            now = naive_now if created.tzinfo is None else datetime.now(created.tzinfo)
            age = now - created

            is_old = age > self.age_threshold
            is_weak = memory.strength < threshold
            # Never accessed and more than a week old
            is_neglected = memory.access_count == 0 and age.days > 7

            if (is_old and is_weak) or is_neglected:
                candidates.append(memory)
            else:
                keep.append(memory)

        return keep, candidates

    def detect_redundancy(
        self,
        memories: List[ConsolidationMemory]
    ) -> List[Tuple[str, str]]:
        """
        Detect redundant memory pairs (similarity >= ``redundancy_threshold``).

        Compares every pair (O(n^2)). For each redundant pair, the memory with
        the higher ``(strength, access_count)`` is kept; on an exact tie the
        later one in the list is kept.

        Returns:
            List of (keep_id, prune_id) tuples.
        """
        redundant_pairs = []

        for i, m1 in enumerate(memories):
            for m2 in memories[i+1:]:
                if self._compute_similarity(m1, m2) < self.redundancy_threshold:
                    continue
                # Lexicographic comparison makes the choice independent of
                # which memory comes first (apart from exact ties).
                if (m1.strength, m1.access_count) > (m2.strength, m2.access_count):
                    redundant_pairs.append((m1.id, m2.id))
                else:
                    redundant_pairs.append((m2.id, m1.id))

        return redundant_pairs

    def prune_memories(
        self,
        memories: List[ConsolidationMemory],
        aggressive: bool = False
    ) -> Tuple[List[ConsolidationMemory], Dict[str, Any]]:
        """
        Execute memory pruning.

        Among the candidates from :meth:`identify_candidates`, a memory is
        pruned if it is the weaker side of a redundant pair, or if it has
        never been accessed and its strength is below 0.2. Other candidates
        survive.

        Args:
            memories: List of memories to process
            aggressive: If True, raise the strength threshold to at least 0.25
                for this call (``self.strength_threshold`` is not modified)

        Returns:
            Tuple of (surviving memories, statistics)
        """
        threshold = self.strength_threshold
        if aggressive:
            threshold = max(threshold, self._AGGRESSIVE_STRENGTH_THRESHOLD)

        keep, candidates = self.identify_candidates(memories, strength_threshold=threshold)

        # The losing side of every redundant pair is pruned
        redundant = self.detect_redundancy(candidates)
        prune_ids = {pruned_id for _, pruned_id in redundant}

        # Also prune remaining candidates that were never accessed and are weak;
        # all other candidates get one more chance.
        for candidate in candidates:
            if (candidate.id not in prune_ids
                    and candidate.access_count == 0
                    and candidate.strength < self._NEGLECTED_STRENGTH_CUTOFF):
                prune_ids.add(candidate.id)

        survivors = keep + [c for c in candidates if c.id not in prune_ids]
        pruned = [c for c in candidates if c.id in prune_ids]

        self.pruned_count += len(pruned)
        self.pruned_ids.update(prune_ids)

        stats = {
            "initial_count": len(memories),
            "survivors": len(survivors),
            "pruned": len(pruned),
            "redundant_pairs": len(redundant),
            "pruning_rate": round(len(pruned) / max(len(memories), 1), 4)
        }

        logger.info(f"Pruning: {len(pruned)}/{len(memories)} memories removed "
                   f"({stats['pruning_rate']*100:.1f}%)")

        return survivors, stats

    def _compute_similarity(
        self,
        m1: ConsolidationMemory,
        m2: ConsolidationMemory
    ) -> float:
        """Similarity of two memories: cosine on embeddings when both have
        one, otherwise Jaccard similarity of their lowercased word sets."""
        if m1.embedding and m2.embedding:
            return self._cosine_similarity(m1.embedding, m2.embedding)

        words1 = set(m1.content.lower().split())
        words2 = set(m2.content.lower().split())

        if not words1 or not words2:
            return 0.0

        overlap = len(words1 & words2)
        union = len(words1 | words2)

        return overlap / union if union > 0 else 0.0

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Cosine similarity; 0.0 on length mismatch or zero magnitude."""
        if len(a) != len(b):
            return 0.0

        dot = sum(x * y for x, y in zip(a, b))
        mag_a = math.sqrt(sum(x * x for x in a))
        mag_b = math.sqrt(sum(x * x for x in b))

        return dot / (mag_a * mag_b) if mag_a and mag_b else 0.0


# =============================================================================
# INTERFERENCE RESOLUTION
# =============================================================================

class InterferenceResolution:
    """
    Resolves conflicts and interference between memories.

    Memory interference occurs when:
    - New memories contradict old ones
    - Similar memories compete for retrieval
    - Temporal sequences are inconsistent

    Resolution strategies:
    - Temporal ordering (most recent wins)
    - Confidence weighting
    - Source reliability
    - Schema consistency
    """

    # Characters stripped from both ends of each whitespace-separated token so
    # that e.g. "not," / "(before)" / "last." match their bare keyword while
    # internal apostrophes ("don't") are preserved.
    _EDGE_PUNCTUATION = "\"'`.,;:!?()[]{}<>*_-"

    _NEGATION_WORDS = frozenset({
        "not", "no", "never", "none", "isn't", "aren't", "wasn't", "weren't",
        "don't", "doesn't", "didn't", "won't", "wouldn't", "can't", "couldn't",
    })

    _TEMPORAL_KEYWORDS = frozenset({
        "before", "after", "then", "first", "last", "earlier", "later",
    })

    def __init__(
        self,
        conflict_detection_threshold: float = 0.8,
        temporal_weight: float = 0.3,
        source_reliability: Optional[Dict[str, float]] = None
    ):
        """
        Initialize interference resolution system.

        Args:
            conflict_detection_threshold: Similarity threshold for conflict detection
            temporal_weight: Weight given to temporal ordering (the other
                factors use fixed weights 0.3 / 0.2 / 0.2; scores are
                normalized by the total weight)
            source_reliability: Dict mapping source names to reliability
                scores; unknown sources default to 0.5
        """
        self.conflict_threshold = conflict_detection_threshold
        self.temporal_weight = temporal_weight
        self.source_reliability = source_reliability or {}
        # Cumulative across calls: every detect_conflicts() run appends here.
        self.conflicts: List[ConflictRecord] = []
        self.resolutions: Dict[str, str] = {}  # conflict_id -> resolution

    def detect_conflicts(
        self,
        memories: List["ConsolidationMemory"]
    ) -> List["ConflictRecord"]:
        """
        Detect potential conflicts between every pair of memories.

        Pairwise (O(n^2)) comparison; detected conflicts are also appended to
        ``self.conflicts``.

        Returns:
            List of conflicts detected in this call.
        """
        detected = []

        for i, m1 in enumerate(memories):
            for m2 in memories[i+1:]:
                conflict = self._check_conflict(m1, m2)
                if conflict:
                    detected.append(conflict)

        self.conflicts.extend(detected)
        return detected

    def _check_conflict(
        self,
        m1: "ConsolidationMemory",
        m2: "ConsolidationMemory"
    ) -> Optional["ConflictRecord"]:
        """
        Check whether two memories conflict.

        Memories must share a domain and have word-set similarity >= 0.5.
        A negation mismatch yields a CONTRADICTION (severity = similarity);
        otherwise, both mentioning temporal-ordering words yields a TEMPORAL
        conflict (severity 0.5). Returns None when neither applies.
        """
        # Same domain required for meaningful conflict
        if m1.domain != m2.domain:
            return None

        # Detect contradiction (high similarity but different conclusions)
        similarity = self._content_similarity(m1.content, m2.content)

        if similarity < 0.5:
            return None  # Too different to conflict

        # Check for negation patterns
        has_negation = self._detect_negation_conflict(m1.content, m2.content)

        if has_negation:
            conflict_id = hashlib.sha256(
                f"{m1.id}:{m2.id}".encode()
            ).hexdigest()[:12]

            return ConflictRecord(
                id=conflict_id,
                memory_a_id=m1.id,
                memory_b_id=m2.id,
                conflict_type=ConflictType.CONTRADICTION,
                severity=similarity
            )

        # Check for temporal inconsistency
        if self._detect_temporal_conflict(m1, m2):
            conflict_id = hashlib.sha256(
                f"{m1.id}:{m2.id}:temporal".encode()
            ).hexdigest()[:12]

            return ConflictRecord(
                id=conflict_id,
                memory_a_id=m1.id,
                memory_b_id=m2.id,
                conflict_type=ConflictType.TEMPORAL,
                severity=0.5
            )

        return None

    def resolve_conflicts(
        self,
        memories: List["ConsolidationMemory"],
        conflicts: List["ConflictRecord"]
    ) -> Tuple[List["ConsolidationMemory"], Dict[str, Any]]:
        """
        Resolve detected conflicts, mutating the affected memories in place.

        Already-resolved conflicts and conflicts referencing memories not in
        ``memories`` are skipped. The losing memory of a keep/weaken decision
        has its strength halved; a merge appends the weaker memory's content
        to the stronger one and sets the weaker one's strength to 0.0 so the
        pruning stage can drop it.

        Returns:
            Tuple of (processed memories, resolution statistics)
        """
        memory_map = {m.id: m for m in memories}
        resolutions = []

        for conflict in conflicts:
            if conflict.resolved:
                continue

            m_a = memory_map.get(conflict.memory_a_id)
            m_b = memory_map.get(conflict.memory_b_id)

            if not m_a or not m_b:
                continue

            resolution = self._resolve_conflict(m_a, m_b, conflict)
            conflict.resolution = resolution["action"]
            conflict.resolved = True
            conflict.resolved_at = datetime.now().isoformat()
            resolutions.append(resolution)
            self.resolutions[conflict.id] = resolution["action"]

            # Apply resolution
            if resolution["action"] == "keep_a_weaken_b":
                m_b.strength *= 0.5
            elif resolution["action"] == "keep_b_weaken_a":
                m_a.strength *= 0.5
            elif resolution["action"] == "merge":
                # Merge content into the stronger memory (ties favor a); the
                # keeper already holds the larger strength, so it is unchanged.
                if m_a.strength >= m_b.strength:
                    keeper, absorbed = m_a, m_b
                else:
                    keeper, absorbed = m_b, m_a
                keeper.content = f"{keeper.content} [Updated: {absorbed.content}]"
                absorbed.strength = 0.0  # Mark for pruning

        stats = {
            "conflicts_processed": len(conflicts),
            "resolutions": len(resolutions),
            "resolution_types": self._count_resolution_types(resolutions)
        }

        logger.info("Resolved %d conflicts", len(resolutions))

        return list(memory_map.values()), stats

    def _resolve_conflict(
        self,
        m_a: "ConsolidationMemory",
        m_b: "ConsolidationMemory",
        conflict: "ConflictRecord"
    ) -> Dict[str, Any]:
        """
        Decide how to resolve a conflict between ``m_a`` and ``m_b``.

        Each factor is expressed as memory a's share in [0, 1], where 0.5
        means "no preference":
          - recency (1.0 if a is newer, 0.0 if older, 0.5 on ties/unparseable)
          - relative source reliability
          - relative strength
          - relative access frequency
        The weighted average is compared to fixed thresholds: > 0.6 keeps a,
        < 0.4 keeps b, otherwise the memories are merged.

        Returns:
            Dict with "action" and "confidence" keys.
        """
        source_a_rel = self.source_reliability.get(m_a.source, 0.5)
        source_b_rel = self.source_reliability.get(m_b.source, 0.5)

        factors = (
            (self._recency_score(m_a.created_at, m_b.created_at), self.temporal_weight),
            (self._share(source_a_rel, source_b_rel), 0.3),
            (self._share(m_a.strength, m_b.strength), 0.2),
            (self._share(m_a.access_count, m_b.access_count), 0.2),
        )

        # Normalize so the 0.6 / 0.4 thresholds stay meaningful for any
        # configured temporal_weight (with the default 0.3 the total is 1.0).
        total_weight = sum(weight for _, weight in factors)
        if total_weight > 0:
            score_a = sum(score * weight for score, weight in factors) / total_weight
        else:
            score_a = 0.5

        # Determine action
        if score_a > 0.6:
            return {"action": "keep_a_weaken_b", "confidence": score_a}
        elif score_a < 0.4:
            return {"action": "keep_b_weaken_a", "confidence": 1 - score_a}
        else:
            return {"action": "merge", "confidence": 0.5}

    @staticmethod
    def _share(value_a: float, value_b: float) -> float:
        """Return a's fraction of (a + b), or 0.5 (neutral) if the total is not positive."""
        total = value_a + value_b
        return value_a / total if total > 0 else 0.5

    @staticmethod
    def _recency_score(created_a: Any, created_b: Any) -> float:
        """
        Return 1.0 if ``created_a`` is later than ``created_b``, 0.0 if earlier.

        Equal timestamps, values ``datetime.fromisoformat`` cannot parse, and
        naive/aware mixes that cannot be compared all yield a neutral 0.5
        instead of aborting the whole resolution pass.
        """
        try:
            time_a = datetime.fromisoformat(created_a)
            time_b = datetime.fromisoformat(created_b)
            if time_a > time_b:
                return 1.0
            if time_a < time_b:
                return 0.0
        except (TypeError, ValueError):
            pass
        return 0.5

    @classmethod
    def _tokenize(cls, text: str) -> Set[str]:
        """
        Split ``text`` into a set of lowercased words.

        Curly apostrophes are normalized to ASCII and punctuation is stripped
        from token edges, so "Not," and "don’t" become "not" and "don't".
        """
        normalized = text.lower().replace("\u2019", "'")
        tokens = (raw.strip(cls._EDGE_PUNCTUATION) for raw in normalized.split())
        return {token for token in tokens if token}

    def _detect_negation_conflict(self, content_a: str, content_b: str) -> bool:
        """Return True when exactly one of the two texts contains a negation word."""
        has_neg_a = bool(self._tokenize(content_a) & self._NEGATION_WORDS)
        has_neg_b = bool(self._tokenize(content_b) & self._NEGATION_WORDS)
        return has_neg_a != has_neg_b

    def _detect_temporal_conflict(
        self,
        m_a: "ConsolidationMemory",
        m_b: "ConsolidationMemory"
    ) -> bool:
        """
        Return True when both memories mention a temporal-ordering word.

        Matching is by whole word, so e.g. "authentication" (contains "then")
        or "elastic" (contains "last") no longer trigger false positives.
        """
        has_temporal_a = bool(self._tokenize(m_a.content) & self._TEMPORAL_KEYWORDS)
        has_temporal_b = bool(self._tokenize(m_b.content) & self._TEMPORAL_KEYWORDS)
        return has_temporal_a and has_temporal_b

    def _content_similarity(self, a: str, b: str) -> float:
        """Jaccard index of the two texts' word sets; 0.0 if either is empty."""
        words_a = self._tokenize(a)
        words_b = self._tokenize(b)

        if not words_a or not words_b:
            return 0.0

        overlap = len(words_a & words_b)
        union = len(words_a | words_b)

        return overlap / union if union > 0 else 0.0

    def _count_resolution_types(self, resolutions: List[Dict]) -> Dict[str, int]:
        """Count how many resolutions used each action name."""
        counts: Dict[str, int] = defaultdict(int)
        for r in resolutions:
            counts[r["action"]] += 1
        return dict(counts)


# =============================================================================
# CONSOLIDATION ENGINE - MAIN ORCHESTRATOR
# =============================================================================

class ConsolidationEngine:
    """
    Main orchestrator for memory consolidation sleep cycles.

    Coordinates all consolidation components through biologically-inspired
    sleep phases:

    1. NREM Light - Initial collection and replay buffer building
    2. NREM Deep - Synaptic homeostasis and memory strengthening
    3. REM - Schema integration, pruning, and interference resolution

    The engine manages state, tracks metrics, and ensures robust
    consolidation with proper error handling. Cycle count, last run time,
    schemas and the most recent report dicts are persisted as JSON under
    ``data_path`` and reloaded on construction; ``consolidate`` runs under an
    internal lock so cycles do not overlap.
    """

    # Maximum number of report dicts kept in the persisted state file.
    _PERSISTED_REPORT_LIMIT = 10

    def __init__(
        self,
        data_path: str = "/mnt/e/genesis-system/data/consolidation",
        cycle_interval_hours: float = 8.0,
        nrem_duration_ratio: float = 0.7,
        rem_duration_ratio: float = 0.3
    ):
        """
        Initialize the consolidation engine.

        Creates ``data_path`` if needed and loads any previously saved state.

        Args:
            data_path: Path for persisting consolidation state
            cycle_interval_hours: Time between consolidation cycles
            nrem_duration_ratio: Proportion of cycle for NREM phases
            rem_duration_ratio: Proportion of cycle for REM phase
        """
        self.data_path = Path(data_path)
        self.data_path.mkdir(parents=True, exist_ok=True)

        self.cycle_interval = timedelta(hours=cycle_interval_hours)
        self.nrem_ratio = nrem_duration_ratio
        self.rem_ratio = rem_duration_ratio

        # Initialize components
        self.homeostasis = SynapticHomeostasis()
        self.replay = MemoryReplay()
        self.schema = SchemaIntegration()
        self.pruning = MemoryPruning()
        self.interference = InterferenceResolution()

        # State tracking
        self.state = ConsolidationState.IDLE
        self.current_phase = SleepPhase.AWAKE
        self.last_consolidation: Optional[datetime] = None
        self.cycle_count = 0
        self.reports: List[ConsolidationReport] = []
        # Report dicts restored from disk by _load_state(). They are kept as
        # plain dicts (no ConsolidationReport reconstruction) and are older
        # than anything in self.reports.
        self._persisted_reports: List[Dict[str, Any]] = []

        # Thread safety
        self._lock = threading.Lock()
        self._running = False

        # Load previous state
        self._load_state()

        logger.info("ConsolidationEngine initialized")

    def consolidate(
        self,
        memories: List[ConsolidationMemory],
        force: bool = False
    ) -> ConsolidationReport:
        """
        Execute a full consolidation cycle.

        This is the main entry point for memory consolidation. Exceptions
        raised during the cycle are caught: they are recorded in
        ``report.errors``, logged with a traceback, and the engine state is
        set to ERROR.

        Args:
            memories: List of memories to consolidate
            force: If True, run even if interval hasn't elapsed

        Returns:
            ConsolidationReport with cycle results
        """
        with self._lock:
            # Check if consolidation needed
            if not force and not self._should_consolidate():
                logger.info("Consolidation not needed yet")
                return self._create_empty_report("skipped - interval not elapsed")

            # Initialize report
            cycle_id = hashlib.sha256(
                f"cycle:{datetime.now().isoformat()}".encode()
            ).hexdigest()[:16]

            report = ConsolidationReport(
                cycle_id=cycle_id,
                phase=SleepPhase.AWAKE,
                started_at=datetime.now().isoformat(),
                memories_processed=len(memories)
            )

            try:
                self.state = ConsolidationState.COLLECTING
                self._running = True

                # ============================================================
                # PHASE 1: NREM LIGHT - Collection & Initial Processing
                # ============================================================
                report.phase = SleepPhase.NREM_LIGHT
                self.current_phase = SleepPhase.NREM_LIGHT
                logger.info(f"[NREM-LIGHT] Starting consolidation cycle {cycle_id}")

                # Build replay buffer from high-priority memories
                candidates = self.replay.get_replay_candidates(memories, top_k=100)
                for m in candidates:
                    self.replay.add_to_buffer(m)

                # ============================================================
                # PHASE 2: NREM DEEP - Homeostasis & Replay
                # ============================================================
                report.phase = SleepPhase.NREM_DEEP
                self.current_phase = SleepPhase.NREM_DEEP
                self.state = ConsolidationState.REPLAYING
                logger.info("[NREM-DEEP] Applying synaptic homeostasis")

                # Apply homeostasis
                memories, homeostasis_stats = self.homeostasis.apply_homeostasis(
                    memories,
                    preserve_critical=True
                )
                report.homeostasis_adjustments = homeostasis_stats["adjustments"]

                # Execute replay cycles
                logger.info("[NREM-DEEP] Running memory replay")
                total_replayed = 0
                for _ in range(3):  # Multiple replay passes
                    replayed, replay_stats = self.replay.replay_cycle(count=20)
                    total_replayed += len(replayed)
                    report.memories_strengthened += replay_stats["strengthened"]

                report.replay_events = total_replayed

                # ============================================================
                # PHASE 3: REM - Integration, Pruning, Resolution
                # ============================================================
                report.phase = SleepPhase.REM
                self.current_phase = SleepPhase.REM
                logger.info("[REM] Schema integration and pruning")

                # Schema integration
                self.state = ConsolidationState.INTEGRATING
                schemas_updated = 0
                for memory in memories:
                    schema_id, _ = self.schema.integrate_memory(memory, memories)
                    if schema_id:
                        schemas_updated += 1
                report.schemas_updated = schemas_updated

                # Conflict detection and resolution
                self.state = ConsolidationState.RESOLVING
                conflicts = self.interference.detect_conflicts(memories)
                if conflicts:
                    memories, resolution_stats = self.interference.resolve_conflicts(
                        memories, conflicts
                    )
                    report.conflicts_resolved = resolution_stats["resolutions"]

                # Memory pruning
                self.state = ConsolidationState.PRUNING
                memories, prune_stats = self.pruning.prune_memories(memories)
                report.memories_pruned = prune_stats["pruned"]

                # ============================================================
                # COMPLETION
                # ============================================================
                self.current_phase = SleepPhase.AWAKE
                self.state = ConsolidationState.COMPLETE

                report.completed_at = datetime.now().isoformat()
                report.duration_seconds = (
                    datetime.fromisoformat(report.completed_at) -
                    datetime.fromisoformat(report.started_at)
                ).total_seconds()

                # Update engine state
                self.last_consolidation = datetime.now()
                self.cycle_count += 1
                self.reports.append(report)

                # Persist state
                self._save_state()

                logger.info(f"[COMPLETE] Cycle {cycle_id} finished in "
                           f"{report.duration_seconds:.2f}s")

            except Exception as e:
                report.errors.append(str(e))
                self.state = ConsolidationState.ERROR
                # logger.exception keeps the traceback, which logger.error dropped.
                logger.exception("Consolidation error: %s", e)

            finally:
                self._running = False
                self.current_phase = SleepPhase.AWAKE

            return report

    def get_homeostatic_pressure(
        self,
        memories: List[ConsolidationMemory]
    ) -> float:
        """Get current homeostatic pressure indicating need for consolidation."""
        return self.homeostasis.calculate_homeostatic_pressure(memories)

    def get_status(self) -> Dict[str, Any]:
        """Get a JSON-serializable snapshot of the engine's current status."""
        return {
            "state": self.state.name,
            "current_phase": self.current_phase.value,
            "running": self._running,
            "cycle_count": self.cycle_count,
            "last_consolidation": self.last_consolidation.isoformat() if self.last_consolidation else None,
            "next_consolidation": (
                (self.last_consolidation + self.cycle_interval).isoformat()
                if self.last_consolidation else "immediate"
            ),
            "schema_stats": self.schema.get_schema_stats(),
            "pruned_total": self.pruning.pruned_count,
            "conflicts_total": len(self.interference.conflicts)
        }

    def get_recent_reports(self, count: int = 5) -> List[Dict[str, Any]]:
        """
        Get the ``count`` most recent consolidation reports as dicts.

        Includes reports restored from the state file (older) followed by
        reports produced by this engine instance. ``count <= 0`` returns an
        empty list (previously ``[-0:]`` returned every report).
        """
        if count <= 0:
            return []
        live = [r.to_dict() for r in self.reports[-count:]]
        return (self._persisted_reports + live)[-count:]

    def _should_consolidate(self) -> bool:
        """Check if enough time has passed since last consolidation."""
        if self.last_consolidation is None:
            return True

        elapsed = datetime.now() - self.last_consolidation
        return elapsed >= self.cycle_interval

    def _create_empty_report(self, reason: str) -> ConsolidationReport:
        """Create an empty report for skipped consolidation."""
        return ConsolidationReport(
            cycle_id="skipped",
            phase=SleepPhase.AWAKE,
            started_at=datetime.now().isoformat(),
            completed_at=datetime.now().isoformat(),
            errors=[reason]
        )

    def _save_state(self):
        """
        Persist engine state to ``<data_path>/consolidation_state.json``.

        The JSON is serialized fully, written to a sibling ``.tmp`` file and
        then moved over the real file with ``Path.replace``, so a failed
        serialization or interrupted write no longer truncates the existing
        state file. Write failures are logged, not raised.
        """
        state_file = self.data_path / "consolidation_state.json"
        limit = self._PERSISTED_REPORT_LIMIT

        # Restored reports are older than live ones; keep the newest `limit`.
        recent_reports = self._persisted_reports + [
            r.to_dict() for r in self.reports[-limit:]
        ]

        state_data = {
            "cycle_count": self.cycle_count,
            "last_consolidation": self.last_consolidation.isoformat() if self.last_consolidation else None,
            "schemas": {sid: s.to_dict() for sid, s in self.schema.schemas.items()},
            "recent_reports": recent_reports[-limit:],
            "saved_at": datetime.now().isoformat()
        }

        tmp_file = state_file.with_name(state_file.name + ".tmp")
        try:
            tmp_file.write_text(json.dumps(state_data, indent=2), encoding="utf-8")
            tmp_file.replace(state_file)
            logger.debug(f"State saved to {state_file}")
        except Exception as e:
            logger.error(f"Failed to save state: {e}")

    def _load_state(self):
        """
        Load engine state from ``<data_path>/consolidation_state.json``.

        Restores cycle count, last consolidation time, recent report dicts
        and schemas. A missing file is a no-op; any load error is logged and
        whatever was restored before the error is kept.
        """
        state_file = self.data_path / "consolidation_state.json"

        if not state_file.exists():
            return

        try:
            with open(state_file, encoding="utf-8") as f:
                state_data = json.load(f)

            self.cycle_count = state_data.get("cycle_count", 0)

            if state_data.get("last_consolidation"):
                self.last_consolidation = datetime.fromisoformat(
                    state_data["last_consolidation"]
                )

            # Restore recent reports (saved by _save_state but previously
            # never read back, so `reports` always looked empty on restart).
            self._persisted_reports = [
                r for r in (state_data.get("recent_reports") or [])
                if isinstance(r, dict)
            ]

            # Restore schemas
            for sid, sdata in state_data.get("schemas", {}).items():
                self.schema.schemas[sid] = MemorySchema(**sdata)

            logger.info(f"State loaded: {self.cycle_count} previous cycles, "
                       f"{len(self.schema.schemas)} schemas")

        except Exception as e:
            logger.error(f"Failed to load state: {e}")


# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================

def create_consolidation_engine(
    data_path: str = "/mnt/e/genesis-system/data/consolidation",
    **kwargs
) -> ConsolidationEngine:
    """
    Build a ConsolidationEngine whose state lives under ``data_path``.

    Any extra keyword arguments (e.g. ``cycle_interval_hours``) are forwarded
    unchanged to the ConsolidationEngine constructor.
    """
    engine_options = dict(kwargs, data_path=data_path)
    return ConsolidationEngine(**engine_options)


def run_consolidation_cycle(
    memories: List[Dict[str, Any]],
    force: bool = True,
    data_path: Optional[str] = None
) -> Dict[str, Any]:
    """
    Convenience function to run a single consolidation cycle.

    A fresh engine is created for the call, so its persisted state is loaded
    from, and the result saved to, the engine's data directory.

    Args:
        memories: Memory dicts to consolidate; items that are not dicts are
            assumed to already be ConsolidationMemory objects and pass through
        force: If True, run regardless of timing
        data_path: Optional state directory for the engine; when None the
            factory default directory is used (previous behavior)

    Returns:
        Consolidation report as dict
    """
    # Convert dicts to ConsolidationMemory objects
    memory_objects = [
        ConsolidationMemory.from_dict(m) if isinstance(m, dict) else m
        for m in memories
    ]

    if data_path is None:
        engine = create_consolidation_engine()
    else:
        engine = create_consolidation_engine(data_path=data_path)

    report = engine.consolidate(memory_objects, force=force)
    return report.to_dict()


# =============================================================================
# CLI INTERFACE
# =============================================================================

if __name__ == "__main__":
    import sys

    def print_usage():
        """Print CLI usage help to stdout."""
        print("""
AIVA Memory Consolidation Engine
=================================

Usage:
    python mem_04_consolidation.py status       Show engine status
    python mem_04_consolidation.py test         Run test consolidation
    python mem_04_consolidation.py reports      Show recent reports
    python mem_04_consolidation.py pressure     Check homeostatic pressure

Examples:
    python mem_04_consolidation.py status
    python mem_04_consolidation.py test --count 100
        """)

    if len(sys.argv) < 2:
        print_usage()
        sys.exit(0)

    command = sys.argv[1]

    # Validate before constructing the engine: construction creates the data
    # directory and loads persisted state, which is pointless for a typo.
    if command not in ("status", "test", "reports", "pressure"):
        print(f"Unknown command: {command}")
        print_usage()
        sys.exit(1)

    engine = create_consolidation_engine()

    if command == "status":
        status = engine.get_status()
        print("\n" + "="*60)
        print("CONSOLIDATION ENGINE STATUS")
        print("="*60)
        print(json.dumps(status, indent=2))

    elif command == "test":
        # Generate test memories
        count = 50
        if len(sys.argv) > 3 and sys.argv[2] == "--count":
            try:
                count = int(sys.argv[3])
            except ValueError:
                count = -1
            if count < 0:
                print(f"Invalid --count value: {sys.argv[3]!r} "
                      f"(expected a non-negative integer)")
                sys.exit(2)

        print(f"\nGenerating {count} test memories...")

        test_memories = []
        domains = ["technical", "learning", "error", "decision", "observation"]

        for i in range(count):
            mem = ConsolidationMemory(
                id=hashlib.sha256(f"test_{i}".encode()).hexdigest()[:12],
                content=f"Test memory {i} about {random.choice(domains)} topic with some additional context.",
                domain=random.choice(domains),
                source="test_generator",
                strength=random.uniform(0.1, 0.9),
                access_count=random.randint(0, 10),
                surprise_score=random.uniform(0.2, 0.8)
            )
            test_memories.append(mem)

        # NOTE: this is a real, forced cycle against the engine's data path:
        # it updates cycle_count / last_consolidation and saves state, so a
        # later non-forced consolidate() may be skipped until the interval elapses.
        print("Running consolidation cycle...")
        report = engine.consolidate(test_memories, force=True)

        print("\n" + "="*60)
        print("CONSOLIDATION REPORT")
        print("="*60)
        print(json.dumps(report.to_dict(), indent=2))

    elif command == "reports":
        reports = engine.get_recent_reports(5)
        print("\n" + "="*60)
        print("RECENT CONSOLIDATION REPORTS")
        print("="*60)
        if reports:
            print(json.dumps(reports, indent=2))
        else:
            print("No consolidation reports yet.")

    elif command == "pressure":
        # Generate sample memories for pressure calculation (synthetic data,
        # not the engine's real memory store).
        sample_memories = [
            ConsolidationMemory(
                id=f"sample_{i}",
                content=f"Sample memory {i}",
                strength=random.uniform(0.3, 0.8),
                replay_count=random.randint(0, 2)
            )
            for i in range(100)
        ]

        pressure = engine.get_homeostatic_pressure(sample_memories)
        print(f"\nHomeostatic Pressure: {pressure:.4f}")
        print(f"Interpretation: {'HIGH - consolidation recommended' if pressure > 0.6 else 'MODERATE' if pressure > 0.3 else 'LOW'}")
