"""
AIVA Memory Architecture Test Suite

Black-box, white-box, integration, and performance tests for the 3-tier memory system.

VERIFICATION_STAMP
Story: AIVA-003
Verified By: Claude Code Agent
Verified At: 2026-01-26
Test Coverage: Black-box + White-box
"""

import sys
sys.path.append('/mnt/e/genesis-system')
sys.path.append('/mnt/e/genesis-system/data/genesis-memory')

import json
import time
from datetime import datetime, timedelta, timezone
from typing import List

import pytest

from AIVA.memory import (
    MemoryManager,
    WorkingMemory,
    EpisodicMemory,
    SemanticMemory,
    MemoryConsolidator,
    SurpriseScorer
)


# =============================================================================
# BLACK BOX TESTS - Test from outside without implementation knowledge
# =============================================================================

class TestMemoryManagerBlackBox:
    """Black-box tests for the unified MemoryManager interface.

    Only the public API (store / recall_recent / query) is exercised; the
    assertions check the shape and contents of the returned dictionaries.
    """

    @pytest.fixture
    def manager(self):
        """Yield a fresh MemoryManager and close it after the test."""
        mgr = MemoryManager()
        yield mgr
        mgr.close()

    def test_store_and_recall_basic(self, manager):
        """A stored item is reported in the working tier and returned by recall_recent."""
        result = manager.store(
            content={"message": "Hello AIVA"},
            event_type="greeting",
            session_id="test_session_1"
        )

        # Storage must report at least the working tier.
        assert 'stored_in' in result
        assert 'working' in result['stored_in']

        recent = manager.recall_recent(limit=10)
        assert len(recent) > 0
        # Treat a missing or None `value` as empty so that one unrelated entry
        # cannot raise AttributeError in the middle of the scan.
        assert any(
            (item.get('value') or {}).get('message') == "Hello AIVA"
            for item in recent
        )

    def test_high_surprise_consolidation(self, manager):
        """A high-surprise (error) event is consolidated into episodic memory."""
        result = manager.store(
            content={"error": "Critical failure", "unexpected": True},
            event_type="error",
            session_id="test_session_2"
        )

        # Error events flagged as unexpected should exceed the consolidation threshold.
        assert result.get('consolidated'), result
        assert 'episodic' in result.get('stored_in', [])
        assert result.get('episode_id') is not None

    def test_cross_tier_query(self, manager):
        """query(..., query_type="auto") returns one key per tier plus 'merged'."""
        manager.store(
            content={"test": "cross tier query"},
            event_type="test_event"
        )

        results = manager.query("cross tier", query_type="auto", limit=5)

        # Only the result structure is checked, not the hits inside each tier.
        assert 'working' in results
        assert 'semantic' in results
        assert 'episodic' in results
        assert 'merged' in results

    def test_graceful_degradation(self, manager):
        """A query still produces merged results when no embedding is supplied."""
        # No embedding is supplied, so the semantic (vector) tier is expected to
        # be skipped; the remaining tiers should still answer the query.
        results = manager.query(
            "test query",
            query_type="auto",
            embedding=None
        )

        assert results is not None
        assert 'merged' in results


class TestWorkingMemoryBlackBox:
    """Black-box tests for the working memory tier's public operations."""

    @pytest.fixture
    def wm(self):
        """Yield a cleared WorkingMemory and clear it again on teardown."""
        memory = WorkingMemory()
        memory.clear()
        yield memory
        memory.clear()

    def test_add_and_get(self, wm):
        """An added value can be read back by its key."""
        wm.add("key1", {"data": "value1"})
        result = wm.get("key1")

        assert result is not None
        assert result['data'] == "value1"

    def test_get_nonexistent(self, wm):
        """get() returns None for a key that was never added."""
        result = wm.get("nonexistent")
        assert result is None

    def test_remove(self, wm):
        """remove() reports success and the key is no longer retrievable."""
        wm.add("key1", {"data": "value1"})
        removed = wm.remove("key1")

        assert removed, "remove() should report success for an existing key"
        assert wm.get("key1") is None

    def test_search(self, wm):
        """search() matches a term that appears in the stored values."""
        wm.add("test_key_1", {"message": "hello world"})
        wm.add("test_key_2", {"message": "goodbye world"})

        results = wm.search("world")
        assert len(results) >= 2

    def test_get_context(self, wm):
        """get_context() includes every item that was added."""
        wm.add("key1", {"data": "value1"})
        wm.add("key2", {"data": "value2"})

        context = wm.get_context()
        assert len(context) >= 2


class TestEpisodicMemoryBlackBox:
    """Black-box tests for the episodic memory tier."""

    @pytest.fixture
    def em(self):
        """Yield an EpisodicMemory instance and close it on teardown."""
        memory = EpisodicMemory()
        yield memory
        memory.close()

    def test_store_and_recall(self, em):
        """A stored episode can be recalled by id with its type and content intact."""
        episode_id = em.store_episode(
            event_type="test_event",
            content={"test": "data"},
            importance=0.8
        )

        assert episode_id is not None

        episode = em.recall(episode_id)
        assert episode is not None
        assert episode['event_type'] == "test_event"
        assert episode['content']['test'] == "data"

    def test_search_by_content(self, em):
        """search_by_content finds an episode by a value in its content."""
        em.store_episode(
            event_type="test",
            content={"unique_marker": "searchable_content"}
        )

        results = em.search_by_content("searchable_content")
        assert len(results) > 0

    def test_search_by_time(self, em):
        """An episode stored now falls inside a +/- 1 hour window around now."""
        # Naive UTC timestamp: the same value the deprecated datetime.utcnow()
        # (Python 3.12+) returned, so comparisons with stored times are unchanged.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        one_hour_ago = now - timedelta(hours=1)

        em.store_episode(
            event_type="test",
            content={"test": "time search"}
        )

        results = em.search_by_time(
            start_time=one_hour_ago,
            end_time=now + timedelta(hours=1)
        )
        assert len(results) > 0


class TestSemanticMemoryBlackBox:
    """Black-box tests for the semantic (vector) memory tier."""

    # Embedding width used by every test here. It must match
    # SemanticMemory.vector_size, which the white-box suite asserts is 1536.
    EMBEDDING_DIM = 1536

    @pytest.fixture
    def sm(self):
        """Create a SemanticMemory bound to a dedicated test collection."""
        return SemanticMemory(collection_name="aiva_test_semantic")

    def test_store_and_retrieve(self, sm):
        """Storing an item with a full-size embedding returns a point id."""
        embedding = [0.1] * self.EMBEDDING_DIM

        point_id = sm.store(
            content="Test knowledge",
            embedding=embedding,
            knowledge_type="test"
        )

        assert point_id is not None

    def test_retrieve_similar(self, sm):
        """retrieve_similar returns a list for a query matching stored knowledge."""
        embedding1 = [0.1] * self.EMBEDDING_DIM
        sm.store(
            content="Similar test 1",
            embedding=embedding1,
            knowledge_type="test"
        )

        # Query with an embedding identical to the stored one.
        query_embedding = [0.1] * self.EMBEDDING_DIM
        results = sm.retrieve_similar(
            query_embedding=query_embedding,
            limit=5,
            score_threshold=0.5
        )

        # Only the return type is checked; hit counts depend on the distance metric.
        assert isinstance(results, list)

    def test_forget(self, sm):
        """Storing and then forgetting an item completes without error."""
        embedding = [0.1] * self.EMBEDDING_DIM
        point_id = sm.store(
            content="Temporary knowledge",
            embedding=embedding
        )
        # Without this check, a failed store would silently become forget(None).
        assert point_id is not None

        sm.forget(point_id)
        # NOTE(review): deletion itself is not verified; that would need a
        # lookup-by-id API on SemanticMemory.


class TestConsolidationBlackBox:
    """Black-box tests for the surprise-based consolidation trigger."""

    @pytest.fixture
    def manager(self):
        """Yield a MemoryManager and close it on teardown."""
        mgr = MemoryManager()
        yield mgr
        mgr.close()

    def test_surprise_threshold_triggers_consolidation(self, manager):
        """A surprise score >= 0.7 causes the event to be consolidated."""
        result = manager.store(
            content={"error": "Unexpected failure", "unexpected": True},
            event_type="error"
        )

        # A missing score defaults to 0 so that the threshold check fails loudly.
        assert result.get('surprise_score', 0) >= 0.7
        assert result.get('consolidated'), result

    def test_normal_events_stay_in_working(self, manager):
        """A routine event scores below 0.7 and is kept in working memory."""
        result = manager.store(
            content={"status": "normal operation"},
            event_type="normal_operation"
        )

        # A missing score defaults to 1.0 so that the threshold check fails loudly.
        assert result.get('surprise_score', 1.0) < 0.7
        assert 'working' in result.get('stored_in', [])


# =============================================================================
# WHITE BOX TESTS - Test internal implementation details
# =============================================================================

class TestWorkingMemoryWhiteBox:
    """White-box tests for working memory internals (capacity, eviction, TTL)."""

    @pytest.fixture
    def wm(self):
        """Yield a cleared WorkingMemory and clear it again on teardown."""
        memory = WorkingMemory()
        memory.clear()
        yield memory
        memory.clear()

    def test_capacity_scaling(self, wm):
        """current_capacity rises above its starting value after a burst of adds."""
        initial_capacity = wm.current_capacity
        # Expected starting (minimum) capacity for a freshly cleared store.
        assert initial_capacity == 8

        # Burst of activity. The short sleeps spread the adds over time,
        # presumably because capacity is derived from recent activity.
        for i in range(20):
            wm.add(f"key_{i}", {"data": f"value_{i}"})
            time.sleep(0.01)

        # Trigger the capacity recalculation explicitly.
        wm._update_capacity()

        assert wm.current_capacity > initial_capacity

    def test_lru_eviction(self, wm):
        """Overfilling never leaves more than max_capacity items stored.

        Only the item count is checked; which items were evicted (LRU order)
        is not verified.
        """
        # Add 5 more items than the store may hold, so eviction must occur.
        for i in range(wm.max_capacity + 5):
            wm.add(f"key_{i}", {"data": f"value_{i}"})

        stats = wm.get_stats()
        assert stats['current_items'] <= wm.max_capacity

    def test_ttl_expiry(self, wm):
        """The configured TTL is 3600 seconds (1 hour).

        Actual expiry is not exercised: that would require waiting out the TTL
        or controlling the clock.
        """
        assert wm.ttl_seconds == 3600


class TestEpisodicMemoryWhiteBox:
    """White-box tests for episodic memory internals (schema, archiving)."""

    @pytest.fixture
    def em(self):
        """Yield an EpisodicMemory instance and close it on teardown."""
        memory = EpisodicMemory()
        yield memory
        memory.close()

    def test_schema_creation(self, em):
        """The database connection is open and the aiva_episodes table exists."""
        assert em.conn is not None
        assert not em.conn.closed

        with em.conn.cursor() as cursor:
            cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'aiva_episodes'
                )
            """)
            exists = cursor.fetchone()[0]
            assert exists, "table 'aiva_episodes' not found"

    def test_archive_flag(self, em):
        """archive_old_episodes(days=0) marks a just-stored episode as archived."""
        episode_id = em.store_episode(
            event_type="test",
            content={"test": "archive"}
        )

        # days=0 is intended to archive everything, including the episode just stored.
        # NOTE(review): this also archives every pre-existing episode in the
        # database this suite connects to. Point the suite at a disposable DB.
        em.archive_old_episodes(days=0)

        episode = em.recall(episode_id)
        # Fail with a clear assertion rather than an AttributeError if recall
        # stops returning archived episodes.
        assert episode is not None, "archived episode could not be recalled"
        assert episode.get('is_archived'), episode


class TestSemanticMemoryWhiteBox:
    """White-box checks on SemanticMemory's Qdrant-backed internals."""

    @pytest.fixture
    def sm(self):
        """Provide a SemanticMemory bound to its own white-box collection."""
        instance = SemanticMemory(collection_name="aiva_test_semantic_wb")
        return instance

    def test_collection_creation(self, sm):
        """The configured collection appears in the Qdrant client's listing."""
        listing = sm.client.get_collections()
        known_names = {entry.name for entry in listing.collections}

        assert sm.collection_name in known_names

    def test_vector_dimensions(self, sm):
        """vector_size is 1536, and an embedding of the wrong width is rejected."""
        assert sm.vector_size == 1536

        undersized = [0.1] * 100  # deliberately not 1536 dimensions
        with pytest.raises(ValueError):
            sm.store(
                content="Test",
                embedding=undersized,
                knowledge_type="test",
            )


class TestSurpriseScorerWhiteBox:
    """White-box checks on SurpriseScorer.calculate_surprise."""

    @pytest.fixture
    def scorer(self):
        """Provide a default-constructed SurpriseScorer."""
        return SurpriseScorer()

    def test_base_scores(self, scorer):
        """Errors score high (>= 0.7) and routine operations score low (<= 0.3)."""
        error_score = scorer.calculate_surprise("error", {})
        assert error_score >= 0.7

        routine_score = scorer.calculate_surprise("normal_operation", {})
        assert routine_score <= 0.3

    def test_context_adjustment(self, scorer):
        """An error is less surprising when recent context holds several errors."""
        recent_errors = [{'metadata': {'event_type': 'error'}} for _ in range(4)]

        with_history = scorer.calculate_surprise("error", {}, recent_errors)
        without_history = scorer.calculate_surprise("error", {})

        assert with_history < without_history

    def test_explicit_surprise_indicators(self, scorer):
        """An explicit 'unexpected' flag in the content lifts the score to >= 0.8."""
        flagged = {"unexpected": True}
        result = scorer.calculate_surprise("normal_operation", flagged)
        assert result >= 0.8


# =============================================================================
# INTEGRATION TESTS
# =============================================================================

class TestMemoryIntegration:
    """Integration tests spanning the working, episodic and semantic tiers."""

    @pytest.fixture
    def manager(self):
        """Yield a MemoryManager and close it on teardown."""
        mgr = MemoryManager()
        yield mgr
        mgr.close()

    def test_full_consolidation_flow(self, manager):
        """A high-surprise event is stored in working memory, consolidated, and recallable."""
        result = manager.store(
            content={"critical": "system failure"},
            event_type="error",
            session_id="integration_test"
        )

        assert 'working' in result['stored_in']
        assert 'episodic' in result['stored_in']

        # Consolidation to episodic must yield an id. Assert it instead of
        # silently skipping the recall check when it is missing.
        episode_id = result.get('episode_id')
        assert episode_id is not None

        episode = manager.recall_episode(episode_id)
        assert episode is not None
        assert episode['content']['critical'] == "system failure"

    def test_cross_tier_deduplication(self, manager):
        """Merged query results contain no exact-duplicate payloads."""
        content = {"unique": "dedup_test_123"}
        manager.store(content=content, event_type="test")

        results = manager.query("dedup_test_123")
        merged = results['merged']

        # Fingerprint each merged item by its payload ('value' or 'content').
        # Items with neither fall back to the whole item, so that unrelated
        # entries do not all collide on "{}". default=str keeps non-JSON values
        # such as timestamps from raising.
        fingerprints = [
            json.dumps(
                item.get('value') or item.get('content') or item,
                sort_keys=True,
                default=str,
            )
            for item in merged
        ]

        # Equality (not <=, which always holds) is what proves there are no duplicates.
        assert len(set(fingerprints)) == len(fingerprints), fingerprints


# =============================================================================
# PERFORMANCE TESTS
# =============================================================================

class TestMemoryPerformance:
    """Coarse performance smoke tests with generous wall-clock budgets."""

    def test_working_memory_batch_operations(self):
        """Adding 100 items to working memory takes under 5 seconds."""
        wm = WorkingMemory()
        wm.clear()
        try:
            # perf_counter is monotonic, so system clock adjustments cannot skew the timing.
            start = time.perf_counter()
            for i in range(100):
                wm.add(f"perf_key_{i}", {"data": f"value_{i}"})
            elapsed = time.perf_counter() - start

            assert elapsed < 5.0, f"batch add took {elapsed:.2f}s"
        finally:
            # Clear even when the assertion fails, so that later tests start clean.
            wm.clear()

    def test_episodic_search_performance(self):
        """A content search over existing episodes completes within 3 seconds."""
        em = EpisodicMemory()
        try:
            start = time.perf_counter()
            em.search_by_content("test", limit=50)
            elapsed = time.perf_counter() - start

            assert elapsed < 3.0, f"search took {elapsed:.2f}s"
        finally:
            # Release the instance even when the search or the assertion fails.
            em.close()


if __name__ == "__main__":
    # Run this file under pytest and propagate its exit status, so shells and
    # CI can detect failures when the file is executed directly.
    sys.exit(pytest.main([__file__, "-v", "-s"]))
